Commit 5d0856f6f343bc58afc148d86672a64e0f871690
1 parent
07daaf4e56
Exists in
master
when parsing time make sure to include the date as well... because reasons
Showing
1 changed file
with
56 additions
and
57 deletions
Show diff stats
nullables.go
1 | package webutility | 1 | package webutility |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "database/sql" | 4 | "database/sql" |
5 | "database/sql/driver" | 5 | "database/sql/driver" |
6 | "encoding/json" | 6 | "encoding/json" |
7 | "fmt" | 7 | "fmt" |
8 | "time" | 8 | "time" |
9 | ) | 9 | ) |
10 | 10 | ||
11 | // NullBool is a wrapper for sql.NullBool with added JSON (un)marshalling | 11 | // NullBool is a wrapper for sql.NullBool with added JSON (un)marshalling |
12 | type NullBool sql.NullBool | 12 | type NullBool sql.NullBool |
13 | 13 | ||
14 | // Scan ... | 14 | // Scan ... |
15 | func (nb *NullBool) Scan(value interface{}) error { | 15 | func (nb *NullBool) Scan(value interface{}) error { |
16 | var b sql.NullBool | 16 | var b sql.NullBool |
17 | if err := b.Scan(value); err != nil { | 17 | if err := b.Scan(value); err != nil { |
18 | nb.Bool, nb.Valid = false, false | 18 | nb.Bool, nb.Valid = false, false |
19 | return err | 19 | return err |
20 | } | 20 | } |
21 | nb.Bool, nb.Valid = b.Bool, b.Valid | 21 | nb.Bool, nb.Valid = b.Bool, b.Valid |
22 | return nil | 22 | return nil |
23 | } | 23 | } |
24 | 24 | ||
25 | // Value ... | 25 | // Value ... |
26 | func (nb *NullBool) Value() (driver.Value, error) { | 26 | func (nb *NullBool) Value() (driver.Value, error) { |
27 | if !nb.Valid { | 27 | if !nb.Valid { |
28 | return nil, nil | 28 | return nil, nil |
29 | } | 29 | } |
30 | return nb.Bool, nil | 30 | return nb.Bool, nil |
31 | } | 31 | } |
32 | 32 | ||
33 | // MarshalJSON ... | 33 | // MarshalJSON ... |
34 | func (nb NullBool) MarshalJSON() ([]byte, error) { | 34 | func (nb NullBool) MarshalJSON() ([]byte, error) { |
35 | if nb.Valid { | 35 | if nb.Valid { |
36 | return json.Marshal(nb.Bool) | 36 | return json.Marshal(nb.Bool) |
37 | } | 37 | } |
38 | 38 | ||
39 | return json.Marshal(nil) | 39 | return json.Marshal(nil) |
40 | } | 40 | } |
41 | 41 | ||
42 | // UnmarshalJSON ... | 42 | // UnmarshalJSON ... |
43 | func (nb *NullBool) UnmarshalJSON(b []byte) error { | 43 | func (nb *NullBool) UnmarshalJSON(b []byte) error { |
44 | var temp *bool | 44 | var temp *bool |
45 | if err := json.Unmarshal(b, &temp); err != nil { | 45 | if err := json.Unmarshal(b, &temp); err != nil { |
46 | return err | 46 | return err |
47 | } | 47 | } |
48 | if temp != nil { | 48 | if temp != nil { |
49 | nb.Valid = true | 49 | nb.Valid = true |
50 | nb.Bool = *temp | 50 | nb.Bool = *temp |
51 | } else { | 51 | } else { |
52 | nb.Valid = false | 52 | nb.Valid = false |
53 | } | 53 | } |
54 | return nil | 54 | return nil |
55 | } | 55 | } |
56 | 56 | ||
57 | // CastToSQL ... | 57 | // CastToSQL ... |
58 | func (nb *NullBool) CastToSQL() sql.NullBool { | 58 | func (nb *NullBool) CastToSQL() sql.NullBool { |
59 | return sql.NullBool(*nb) | 59 | return sql.NullBool(*nb) |
60 | } | 60 | } |
61 | 61 | ||
62 | // NullString is a wrapper for sql.NullString with added JSON (un)marshalling | 62 | // NullString is a wrapper for sql.NullString with added JSON (un)marshalling |
63 | type NullString sql.NullString | 63 | type NullString sql.NullString |
64 | 64 | ||
65 | // Scan ... | 65 | // Scan ... |
66 | func (ns *NullString) Scan(value interface{}) error { | 66 | func (ns *NullString) Scan(value interface{}) error { |
67 | var s sql.NullString | 67 | var s sql.NullString |
68 | if err := s.Scan(value); err != nil { | 68 | if err := s.Scan(value); err != nil { |
69 | ns.String, ns.Valid = "", false | 69 | ns.String, ns.Valid = "", false |
70 | return err | 70 | return err |
71 | } | 71 | } |
72 | ns.String, ns.Valid = s.String, s.Valid | 72 | ns.String, ns.Valid = s.String, s.Valid |
73 | return nil | 73 | return nil |
74 | } | 74 | } |
75 | 75 | ||
76 | // Value ... | 76 | // Value ... |
77 | func (ns *NullString) Value() (driver.Value, error) { | 77 | func (ns *NullString) Value() (driver.Value, error) { |
78 | if !ns.Valid { | 78 | if !ns.Valid { |
79 | return nil, nil | 79 | return nil, nil |
80 | } | 80 | } |
81 | return ns.String, nil | 81 | return ns.String, nil |
82 | } | 82 | } |
83 | 83 | ||
84 | // MarshalJSON ... | 84 | // MarshalJSON ... |
85 | func (ns NullString) MarshalJSON() ([]byte, error) { | 85 | func (ns NullString) MarshalJSON() ([]byte, error) { |
86 | if ns.Valid { | 86 | if ns.Valid { |
87 | return json.Marshal(ns.String) | 87 | return json.Marshal(ns.String) |
88 | } | 88 | } |
89 | return json.Marshal(nil) | 89 | return json.Marshal(nil) |
90 | } | 90 | } |
91 | 91 | ||
92 | // UnmarshalJSON ... | 92 | // UnmarshalJSON ... |
93 | func (ns *NullString) UnmarshalJSON(b []byte) error { | 93 | func (ns *NullString) UnmarshalJSON(b []byte) error { |
94 | var temp *string | 94 | var temp *string |
95 | if err := json.Unmarshal(b, &temp); err != nil { | 95 | if err := json.Unmarshal(b, &temp); err != nil { |
96 | return err | 96 | return err |
97 | } | 97 | } |
98 | if temp != nil { | 98 | if temp != nil { |
99 | ns.Valid = true | 99 | ns.Valid = true |
100 | ns.String = *temp | 100 | ns.String = *temp |
101 | } else { | 101 | } else { |
102 | ns.Valid = false | 102 | ns.Valid = false |
103 | } | 103 | } |
104 | return nil | 104 | return nil |
105 | } | 105 | } |
106 | 106 | ||
107 | // CastToSQL ... | 107 | // CastToSQL ... |
108 | func (ns *NullString) CastToSQL() sql.NullString { | 108 | func (ns *NullString) CastToSQL() sql.NullString { |
109 | return sql.NullString(*ns) | 109 | return sql.NullString(*ns) |
110 | } | 110 | } |
111 | 111 | ||
112 | // NullInt64 is a wrapper for sql.NullInt64 with added JSON (un)marshalling | 112 | // NullInt64 is a wrapper for sql.NullInt64 with added JSON (un)marshalling |
113 | type NullInt64 sql.NullInt64 | 113 | type NullInt64 sql.NullInt64 |
114 | 114 | ||
115 | // Scan ... | 115 | // Scan ... |
116 | func (ni *NullInt64) Scan(value interface{}) error { | 116 | func (ni *NullInt64) Scan(value interface{}) error { |
117 | var i sql.NullInt64 | 117 | var i sql.NullInt64 |
118 | if err := i.Scan(value); err != nil { | 118 | if err := i.Scan(value); err != nil { |
119 | ni.Int64, ni.Valid = 0, false | 119 | ni.Int64, ni.Valid = 0, false |
120 | return err | 120 | return err |
121 | } | 121 | } |
122 | ni.Int64, ni.Valid = i.Int64, i.Valid | 122 | ni.Int64, ni.Valid = i.Int64, i.Valid |
123 | return nil | 123 | return nil |
124 | } | 124 | } |
125 | 125 | ||
126 | // ScanPtr ... | 126 | // ScanPtr ... |
127 | func (ni *NullInt64) ScanPtr(v interface{}) error { | 127 | func (ni *NullInt64) ScanPtr(v interface{}) error { |
128 | if ip, ok := v.(*int64); ok && ip != nil { | 128 | if ip, ok := v.(*int64); ok && ip != nil { |
129 | return ni.Scan(*ip) | 129 | return ni.Scan(*ip) |
130 | } | 130 | } |
131 | return nil | 131 | return nil |
132 | } | 132 | } |
133 | 133 | ||
134 | // Value ... | 134 | // Value ... |
135 | func (ni *NullInt64) Value() (driver.Value, error) { | 135 | func (ni *NullInt64) Value() (driver.Value, error) { |
136 | if !ni.Valid { | 136 | if !ni.Valid { |
137 | return nil, nil | 137 | return nil, nil |
138 | } | 138 | } |
139 | return ni.Int64, nil | 139 | return ni.Int64, nil |
140 | } | 140 | } |
141 | 141 | ||
142 | func (ni *NullInt64) Val() int64 { | 142 | func (ni *NullInt64) Val() int64 { |
143 | return ni.Int64 | 143 | return ni.Int64 |
144 | } | 144 | } |
145 | 145 | ||
146 | // Add | 146 | // Add |
147 | func (ni *NullInt64) Add(i NullInt64) { | 147 | func (ni *NullInt64) Add(i NullInt64) { |
148 | ni.Valid = true | 148 | ni.Valid = true |
149 | ni.Int64 += i.Int64 | 149 | ni.Int64 += i.Int64 |
150 | } | 150 | } |
151 | 151 | ||
152 | func (ni *NullInt64) Set(i int64) { | 152 | func (ni *NullInt64) Set(i int64) { |
153 | ni.Valid = true | 153 | ni.Valid = true |
154 | ni.Int64 = i | 154 | ni.Int64 = i |
155 | } | 155 | } |
156 | 156 | ||
157 | // MarshalJSON ... | 157 | // MarshalJSON ... |
158 | func (ni NullInt64) MarshalJSON() ([]byte, error) { | 158 | func (ni NullInt64) MarshalJSON() ([]byte, error) { |
159 | if ni.Valid { | 159 | if ni.Valid { |
160 | return json.Marshal(ni.Int64) | 160 | return json.Marshal(ni.Int64) |
161 | } | 161 | } |
162 | return json.Marshal(nil) | 162 | return json.Marshal(nil) |
163 | } | 163 | } |
164 | 164 | ||
165 | // UnmarshalJSON ... | 165 | // UnmarshalJSON ... |
166 | func (ni *NullInt64) UnmarshalJSON(b []byte) error { | 166 | func (ni *NullInt64) UnmarshalJSON(b []byte) error { |
167 | var temp *int64 | 167 | var temp *int64 |
168 | if err := json.Unmarshal(b, &temp); err != nil { | 168 | if err := json.Unmarshal(b, &temp); err != nil { |
169 | return err | 169 | return err |
170 | } | 170 | } |
171 | if temp != nil { | 171 | if temp != nil { |
172 | ni.Valid = true | 172 | ni.Valid = true |
173 | ni.Int64 = *temp | 173 | ni.Int64 = *temp |
174 | } else { | 174 | } else { |
175 | ni.Valid = false | 175 | ni.Valid = false |
176 | } | 176 | } |
177 | return nil | 177 | return nil |
178 | } | 178 | } |
179 | 179 | ||
180 | // CastToSQL ... | 180 | // CastToSQL ... |
181 | func (ni *NullInt64) CastToSQL() sql.NullInt64 { | 181 | func (ni *NullInt64) CastToSQL() sql.NullInt64 { |
182 | return sql.NullInt64(*ni) | 182 | return sql.NullInt64(*ni) |
183 | } | 183 | } |
184 | 184 | ||
185 | // NullFloat64 is a wrapper for sql.NullFloat64 with added JSON (un)marshalling | 185 | // NullFloat64 is a wrapper for sql.NullFloat64 with added JSON (un)marshalling |
186 | type NullFloat64 sql.NullFloat64 | 186 | type NullFloat64 sql.NullFloat64 |
187 | 187 | ||
188 | // Scan ... | 188 | // Scan ... |
189 | func (nf *NullFloat64) Scan(value interface{}) error { | 189 | func (nf *NullFloat64) Scan(value interface{}) error { |
190 | var f sql.NullFloat64 | 190 | var f sql.NullFloat64 |
191 | if err := f.Scan(value); err != nil { | 191 | if err := f.Scan(value); err != nil { |
192 | nf.Float64, nf.Valid = 0.0, false | 192 | nf.Float64, nf.Valid = 0.0, false |
193 | return err | 193 | return err |
194 | } | 194 | } |
195 | nf.Float64, nf.Valid = f.Float64, f.Valid | 195 | nf.Float64, nf.Valid = f.Float64, f.Valid |
196 | return nil | 196 | return nil |
197 | } | 197 | } |
198 | 198 | ||
199 | // ScanPtr ... | 199 | // ScanPtr ... |
200 | func (nf *NullFloat64) ScanPtr(v interface{}) error { | 200 | func (nf *NullFloat64) ScanPtr(v interface{}) error { |
201 | if fp, ok := v.(*float64); ok && fp != nil { | 201 | if fp, ok := v.(*float64); ok && fp != nil { |
202 | return nf.Scan(*fp) | 202 | return nf.Scan(*fp) |
203 | } | 203 | } |
204 | return nil | 204 | return nil |
205 | } | 205 | } |
206 | 206 | ||
207 | // Value ... | 207 | // Value ... |
208 | func (nf *NullFloat64) Value() (driver.Value, error) { | 208 | func (nf *NullFloat64) Value() (driver.Value, error) { |
209 | if !nf.Valid { | 209 | if !nf.Valid { |
210 | return nil, nil | 210 | return nil, nil |
211 | } | 211 | } |
212 | return nf.Float64, nil | 212 | return nf.Float64, nil |
213 | } | 213 | } |
214 | 214 | ||
215 | // Val ... | 215 | // Val ... |
216 | func (nf *NullFloat64) Val() float64 { | 216 | func (nf *NullFloat64) Val() float64 { |
217 | return nf.Float64 | 217 | return nf.Float64 |
218 | } | 218 | } |
219 | 219 | ||
220 | // Add ... | 220 | // Add ... |
221 | func (nf *NullFloat64) Add(f NullFloat64) { | 221 | func (nf *NullFloat64) Add(f NullFloat64) { |
222 | nf.Valid = true | 222 | nf.Valid = true |
223 | nf.Float64 += f.Float64 | 223 | nf.Float64 += f.Float64 |
224 | } | 224 | } |
225 | 225 | ||
226 | func (nf *NullFloat64) Set(f float64) { | 226 | func (nf *NullFloat64) Set(f float64) { |
227 | nf.Valid = true | 227 | nf.Valid = true |
228 | nf.Float64 = f | 228 | nf.Float64 = f |
229 | } | 229 | } |
230 | 230 | ||
231 | // MarshalJSON ... | 231 | // MarshalJSON ... |
232 | func (nf NullFloat64) MarshalJSON() ([]byte, error) { | 232 | func (nf NullFloat64) MarshalJSON() ([]byte, error) { |
233 | if nf.Valid { | 233 | if nf.Valid { |
234 | return json.Marshal(nf.Float64) | 234 | return json.Marshal(nf.Float64) |
235 | } | 235 | } |
236 | return json.Marshal(nil) | 236 | return json.Marshal(nil) |
237 | } | 237 | } |
238 | 238 | ||
239 | // UnmarshalJSON ... | 239 | // UnmarshalJSON ... |
240 | func (nf *NullFloat64) UnmarshalJSON(b []byte) error { | 240 | func (nf *NullFloat64) UnmarshalJSON(b []byte) error { |
241 | var temp *float64 | 241 | var temp *float64 |
242 | if err := json.Unmarshal(b, &temp); err != nil { | 242 | if err := json.Unmarshal(b, &temp); err != nil { |
243 | return err | 243 | return err |
244 | } | 244 | } |
245 | if temp != nil { | 245 | if temp != nil { |
246 | nf.Valid = true | 246 | nf.Valid = true |
247 | nf.Float64 = *temp | 247 | nf.Float64 = *temp |
248 | } else { | 248 | } else { |
249 | nf.Valid = false | 249 | nf.Valid = false |
250 | } | 250 | } |
251 | return nil | 251 | return nil |
252 | } | 252 | } |
253 | 253 | ||
254 | // CastToSQL ... | 254 | // CastToSQL ... |
255 | func (nf *NullFloat64) CastToSQL() sql.NullFloat64 { | 255 | func (nf *NullFloat64) CastToSQL() sql.NullFloat64 { |
256 | return sql.NullFloat64(*nf) | 256 | return sql.NullFloat64(*nf) |
257 | } | 257 | } |
258 | 258 | ||
259 | // NullDateTime ... | 259 | // NullDateTime ... |
260 | type NullDateTime struct { | 260 | type NullDateTime struct { |
261 | Time time.Time | 261 | Time time.Time |
262 | Valid bool // Valid is true if Time is not NULL | 262 | Valid bool // Valid is true if Time is not NULL |
263 | } | 263 | } |
264 | 264 | ||
265 | // Scan ... | 265 | // Scan ... |
266 | func (nt *NullDateTime) Scan(value interface{}) (err error) { | 266 | func (nt *NullDateTime) Scan(value interface{}) (err error) { |
267 | if value == nil { | 267 | if value == nil { |
268 | nt.Time, nt.Valid = time.Time{}, false | 268 | nt.Time, nt.Valid = time.Time{}, false |
269 | return | 269 | return |
270 | } | 270 | } |
271 | 271 | ||
272 | switch v := value.(type) { | 272 | switch v := value.(type) { |
273 | case time.Time: | 273 | case time.Time: |
274 | nt.Time, nt.Valid = v, true | 274 | nt.Time, nt.Valid = v, true |
275 | return | 275 | return |
276 | case []byte: | 276 | case []byte: |
277 | nt.Time, err = parseSQLDateTime(string(v), time.UTC) | 277 | nt.Time, err = parseSQLDateTime(string(v), time.UTC) |
278 | nt.Valid = (err == nil) | 278 | nt.Valid = (err == nil) |
279 | return | 279 | return |
280 | case string: | 280 | case string: |
281 | nt.Time, err = parseSQLDateTime(v, time.UTC) | 281 | nt.Time, err = parseSQLDateTime(v, time.UTC) |
282 | nt.Valid = (err == nil) | 282 | nt.Valid = (err == nil) |
283 | return | 283 | return |
284 | } | 284 | } |
285 | 285 | ||
286 | nt.Valid = false | 286 | nt.Valid = false |
287 | return fmt.Errorf("Can't convert %T to time.Time", value) | 287 | return fmt.Errorf("Can't convert %T to time.Time", value) |
288 | } | 288 | } |
289 | 289 | ||
290 | // Value implements the driver Valuer interface. | 290 | // Value implements the driver Valuer interface. |
291 | func (nt NullDateTime) Value() (driver.Value, error) { | 291 | func (nt NullDateTime) Value() (driver.Value, error) { |
292 | if !nt.Valid { | 292 | if !nt.Valid { |
293 | return nil, nil | 293 | return nil, nil |
294 | } | 294 | } |
295 | return nt.Time, nil | 295 | return nt.Time, nil |
296 | } | 296 | } |
297 | 297 | ||
298 | // MarshalJSON ... | 298 | // MarshalJSON ... |
299 | func (nt NullDateTime) MarshalJSON() ([]byte, error) { | 299 | func (nt NullDateTime) MarshalJSON() ([]byte, error) { |
300 | if nt.Valid { | 300 | if nt.Valid { |
301 | format := nt.Time.Format("2006-01-02 15:04:05") | 301 | format := nt.Time.Format("2006-01-02 15:04:05") |
302 | return json.Marshal(format) | 302 | return json.Marshal(format) |
303 | } | 303 | } |
304 | return json.Marshal(nil) | 304 | return json.Marshal(nil) |
305 | } | 305 | } |
306 | 306 | ||
307 | // UnmarshalJSON ... | 307 | // UnmarshalJSON ... |
308 | func (nt *NullDateTime) UnmarshalJSON(b []byte) error { | 308 | func (nt *NullDateTime) UnmarshalJSON(b []byte) error { |
309 | var temp *time.Time | 309 | var temp *time.Time |
310 | var t1 time.Time | 310 | var t1 time.Time |
311 | var err error | 311 | var err error |
312 | 312 | ||
313 | s1 := string(b) | 313 | s1 := string(b) |
314 | s2 := s1[1 : len(s1)-1] | 314 | s2 := s1[1 : len(s1)-1] |
315 | if s1 == "null" { | 315 | if s1 == "null" { |
316 | temp = nil | 316 | temp = nil |
317 | } else { | 317 | } else { |
318 | t1, err = time.Parse("2006-01-02 15:04:05", s2) | 318 | t1, err = time.Parse("2006-01-02 15:04:05", s2) |
319 | if err != nil { | 319 | if err != nil { |
320 | return err | 320 | return err |
321 | } | 321 | } |
322 | temp = &t1 | 322 | temp = &t1 |
323 | } | 323 | } |
324 | 324 | ||
325 | if temp != nil { | 325 | if temp != nil { |
326 | nt.Valid = true | 326 | nt.Valid = true |
327 | nt.Time = *temp | 327 | nt.Time = *temp |
328 | } else { | 328 | } else { |
329 | nt.Valid = false | 329 | nt.Valid = false |
330 | } | 330 | } |
331 | return nil | 331 | return nil |
332 | } | 332 | } |
333 | 333 | ||
334 | func (nt *NullDateTime) CastToSQL() NullDateTime { | 334 | func (nt *NullDateTime) CastToSQL() NullDateTime { |
335 | return *nt | 335 | return *nt |
336 | } | 336 | } |
337 | 337 | ||
338 | func parseSQLDateTime(str string, loc *time.Location) (t time.Time, err error) { | ||
339 | base := "0000-00-00 00:00:00.0000000" | ||
340 | timeFormat := "2006-01-02 15:04:05.999999" | ||
341 | switch len(str) { | ||
342 | case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" | ||
343 | if str == base[:len(str)] { | ||
344 | return | ||
345 | } | ||
346 | t, err = time.Parse(timeFormat[:len(str)], str) | ||
347 | default: | ||
348 | err = fmt.Errorf("invalid time string: %s", str) | ||
349 | return | ||
350 | } | ||
351 | |||
352 | // Adjust location | ||
353 | if err == nil && loc != time.UTC { | ||
354 | y, mo, d := t.Date() | ||
355 | h, mi, s := t.Clock() | ||
356 | t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil | ||
357 | } | ||
358 | |||
359 | return | ||
360 | } | ||
361 | |||
338 | // NullDate ... | 362 | // NullDate ... |
339 | type NullDate struct { | 363 | type NullDate struct { |
340 | Time time.Time | 364 | Time time.Time |
341 | Valid bool // Valid is true if Time is not NULL | 365 | Valid bool // Valid is true if Time is not NULL |
342 | } | 366 | } |
343 | 367 | ||
344 | // Scan ... | 368 | // Scan ... |
345 | func (nt *NullDate) Scan(value interface{}) (err error) { | 369 | func (nt *NullDate) Scan(value interface{}) (err error) { |
346 | if value == nil { | 370 | if value == nil { |
347 | nt.Time, nt.Valid = time.Time{}, false | 371 | nt.Time, nt.Valid = time.Time{}, false |
348 | return | 372 | return |
349 | } | 373 | } |
350 | 374 | ||
351 | switch v := value.(type) { | 375 | switch v := value.(type) { |
352 | case time.Time: | 376 | case time.Time: |
353 | nt.Time, nt.Valid = v, true | 377 | nt.Time, nt.Valid = v, true |
354 | return | 378 | return |
355 | case []byte: | 379 | case []byte: |
356 | nt.Time, err = parseSQLDate(string(v), time.UTC) | 380 | nt.Time, err = parseSQLDate(string(v), time.UTC) |
357 | nt.Valid = (err == nil) | 381 | nt.Valid = (err == nil) |
358 | return | 382 | return |
359 | case string: | 383 | case string: |
360 | nt.Time, err = parseSQLDate(v, time.UTC) | 384 | nt.Time, err = parseSQLDate(v, time.UTC) |
361 | nt.Valid = (err == nil) | 385 | nt.Valid = (err == nil) |
362 | return | 386 | return |
363 | } | 387 | } |
364 | 388 | ||
365 | nt.Valid = false | 389 | nt.Valid = false |
366 | return fmt.Errorf("Can't convert %T to time.Time", value) | 390 | return fmt.Errorf("Can't convert %T to time.Time", value) |
367 | } | 391 | } |
368 | 392 | ||
369 | // Value implements the driver Valuer interface. | 393 | // Value implements the driver Valuer interface. |
370 | func (nt NullDate) Value() (driver.Value, error) { | 394 | func (nt NullDate) Value() (driver.Value, error) { |
371 | if !nt.Valid { | 395 | if !nt.Valid { |
372 | return nil, nil | 396 | return nil, nil |
373 | } | 397 | } |
374 | return nt.Time, nil | 398 | return nt.Time, nil |
375 | } | 399 | } |
376 | 400 | ||
377 | // MarshalJSON ... | 401 | // MarshalJSON ... |
378 | func (nt NullDate) MarshalJSON() ([]byte, error) { | 402 | func (nt NullDate) MarshalJSON() ([]byte, error) { |
379 | if nt.Valid { | 403 | if nt.Valid { |
380 | format := nt.Time.Format("2006-01-02") | 404 | format := nt.Time.Format("2006-01-02") |
381 | return json.Marshal(format) | 405 | return json.Marshal(format) |
382 | } | 406 | } |
383 | return json.Marshal(nil) | 407 | return json.Marshal(nil) |
384 | } | 408 | } |
385 | 409 | ||
386 | // UnmarshalJSON ... | 410 | // UnmarshalJSON ... |
387 | func (nt *NullDate) UnmarshalJSON(b []byte) error { | 411 | func (nt *NullDate) UnmarshalJSON(b []byte) error { |
388 | var temp *time.Time | 412 | var temp *time.Time |
389 | var t1 time.Time | 413 | var t1 time.Time |
390 | var err error | 414 | var err error |
391 | 415 | ||
392 | s1 := string(b) | 416 | s1 := string(b) |
393 | s2 := s1[1 : len(s1)-1] | 417 | s2 := s1[1 : len(s1)-1] |
394 | if s1 == "null" { | 418 | if s1 == "null" { |
395 | temp = nil | 419 | temp = nil |
396 | } else { | 420 | } else { |
397 | t1, err = time.Parse("2006-01-02", s2) | 421 | t1, err = time.Parse("2006-01-02", s2) |
398 | if err != nil { | 422 | if err != nil { |
399 | return err | 423 | return err |
400 | } | 424 | } |
401 | temp = &t1 | 425 | temp = &t1 |
402 | } | 426 | } |
403 | 427 | ||
404 | if temp != nil { | 428 | if temp != nil { |
405 | nt.Valid = true | 429 | nt.Scan(t1) |
406 | nt.Time = *temp | ||
407 | } else { | 430 | } else { |
408 | nt.Valid = false | 431 | nt.Valid = false |
409 | } | 432 | } |
410 | return nil | 433 | return nil |
411 | } | 434 | } |
412 | 435 | ||
413 | func (nt *NullDate) CastToSQL() NullDate { | 436 | func (nt *NullDate) CastToSQL() NullDate { |
414 | return *nt | 437 | return *nt |
415 | } | 438 | } |
416 | 439 | ||
440 | func parseSQLDate(str string, loc *time.Location) (t time.Time, err error) { | ||
441 | base := "0000-00-00" | ||
442 | timeFormat := "2006-01-02" | ||
443 | switch len(str) { | ||
444 | case 10: | ||
445 | if str == base[:len(str)] { | ||
446 | return | ||
447 | } | ||
448 | t, err = time.Parse(timeFormat[:len(str)], str) | ||
449 | default: | ||
450 | err = fmt.Errorf("invalid time string: %s", str) | ||
451 | return | ||
452 | } | ||
453 | |||
454 | // Adjust location | ||
455 | if err == nil && loc != time.UTC { | ||
456 | y, mo, d := t.Date() | ||
457 | h, mi, s := t.Clock() | ||
458 | t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil | ||
459 | } | ||
460 | |||
461 | return | ||
462 | } | ||
463 | |||
417 | // NullTime ... | 464 | // NullTime ... |
418 | type NullTime struct { | 465 | type NullTime struct { |
419 | Time time.Time | 466 | Time time.Time |
420 | Valid bool // Valid is true if Time is not NULL | 467 | Valid bool // Valid is true if Time is not NULL |
421 | } | 468 | } |
422 | 469 | ||
423 | // Scan ... | 470 | // Scan ... |
424 | func (nt *NullTime) Scan(value interface{}) (err error) { | 471 | func (nt *NullTime) Scan(value interface{}) (err error) { |
425 | if value == nil { | 472 | if value == nil { |
426 | nt.Time, nt.Valid = time.Time{}, false | 473 | nt.Time, nt.Valid = time.Time{}, false |
427 | return | 474 | return |
428 | } | 475 | } |
429 | 476 | ||
430 | switch v := value.(type) { | 477 | switch v := value.(type) { |
431 | case time.Time: | 478 | case time.Time: |
432 | nt.Time, nt.Valid = v, true | 479 | nt.Time, nt.Valid = v, true |
433 | return | 480 | return |
434 | case []byte: | 481 | case []byte: |
435 | nt.Time, err = parseSQLTime(string(v), time.UTC) | 482 | nt.Time, err = parseSQLTime(string(v), time.UTC) |
436 | nt.Valid = (err == nil) | 483 | nt.Valid = (err == nil) |
437 | return | 484 | return |
438 | case string: | 485 | case string: |
439 | nt.Time, err = parseSQLTime(v, time.UTC) | 486 | nt.Time, err = parseSQLTime(v, time.UTC) |
440 | nt.Valid = (err == nil) | 487 | nt.Valid = (err == nil) |
441 | return | 488 | return |
442 | } | 489 | } |
443 | 490 | ||
444 | nt.Valid = false | 491 | nt.Valid = false |
445 | return fmt.Errorf("Can't convert %T to time.Time", value) | 492 | return fmt.Errorf("Can't convert %T to time.Time", value) |
446 | } | 493 | } |
447 | 494 | ||
448 | // Value implements the driver Valuer interface. | 495 | // Value implements the driver Valuer interface. |
449 | func (nt NullTime) Value() (driver.Value, error) { | 496 | func (nt NullTime) Value() (driver.Value, error) { |
450 | if !nt.Valid { | 497 | if !nt.Valid { |
451 | return nil, nil | 498 | return nil, nil |
452 | } | 499 | } |
453 | return nt.Time, nil | 500 | return nt.Time, nil |
454 | } | 501 | } |
455 | 502 | ||
456 | // MarshalJSON ... | 503 | // MarshalJSON ... |
457 | func (nt NullTime) MarshalJSON() ([]byte, error) { | 504 | func (nt NullTime) MarshalJSON() ([]byte, error) { |
458 | if nt.Valid { | 505 | if nt.Valid { |
459 | format := nt.Time.Format("15:04:05") | 506 | format := nt.Time.Format("15:04:05") |
460 | return json.Marshal(format) | 507 | return json.Marshal(format) |
461 | } | 508 | } |
462 | return json.Marshal(nil) | 509 | return json.Marshal(nil) |
463 | } | 510 | } |
464 | 511 | ||
465 | // UnmarshalJSON ... | 512 | // UnmarshalJSON ... |
466 | func (nt *NullTime) UnmarshalJSON(b []byte) error { | 513 | func (nt *NullTime) UnmarshalJSON(b []byte) error { |
467 | var temp *time.Time | 514 | var temp *time.Time |
468 | var t1 time.Time | 515 | var t1 time.Time |
469 | var err error | 516 | var err error |
470 | 517 | ||
471 | s1 := string(b) | 518 | s1 := string(b) |
472 | s2 := s1[1 : len(s1)-1] | 519 | s2 := s1[1 : len(s1)-1] |
473 | if s1 == "null" { | 520 | if s1 == "null" { |
474 | temp = nil | 521 | temp = nil |
475 | } else { | 522 | } else { |
476 | t1, err = time.Parse("15:04:05", s2) | 523 | t1, err = time.Parse("2006-05-04 15:04:05", "1970-01-01 "+s2) |
477 | if err != nil { | 524 | if err != nil { |
478 | return err | 525 | return err |
479 | } | 526 | } |
480 | temp = &t1 | 527 | temp = &t1 |
481 | } | 528 | } |
482 | 529 | ||
483 | if temp != nil { | 530 | if temp != nil { |
484 | nt.Valid = true | 531 | nt.Scan(t1) |
485 | nt.Time = *temp | ||
486 | } else { | 532 | } else { |
487 | nt.Valid = false | 533 | nt.Valid = false |
488 | } | 534 | } |
489 | return nil | 535 | return nil |
490 | } | 536 | } |
491 | 537 | ||
492 | func (nt *NullTime) CastToSQL() NullTime { | 538 | func (nt *NullTime) CastToSQL() NullTime { |
493 | return *nt | 539 | return *nt |
494 | } | 540 | } |
495 | 541 | ||
496 | func parseSQLDateTime(str string, loc *time.Location) (t time.Time, err error) { | 542 | // NOTE(marko): Date must be included because database can't convert it to TIME otherwise. |
497 | base := "0000-00-00 00:00:00.0000000" | ||
498 | timeFormat := "2006-01-02 15:04:05.999999" | ||
499 | switch len(str) { | ||
500 | case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" | ||
501 | if str == base[:len(str)] { | ||
502 | return | ||
503 | } | ||
504 | t, err = time.Parse(timeFormat[:len(str)], str) | ||
505 | default: | ||
506 | err = fmt.Errorf("invalid time string: %s", str) | ||
507 | return | ||
508 | } | ||
509 | |||
510 | // Adjust location | ||
511 | if err == nil && loc != time.UTC { | ||
512 | y, mo, d := t.Date() | ||
513 | h, mi, s := t.Clock() | ||
514 | t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil | ||
515 | } | ||
516 | |||
517 | return | ||
518 | } | ||
519 |