Commit 07daaf4e563562e18240031d27d8936d45d0af2d
1 parent
6a5c6931eb
Exists in
master
separate nullable time, date and timedate structs
Showing
2 changed files
with
224 additions
and
13 deletions
Show diff stats
nullables.go
... | ... | @@ -54,8 +54,8 @@ func (nb *NullBool) UnmarshalJSON(b []byte) error { |
54 | 54 | return nil |
55 | 55 | } |
56 | 56 | |
57 | -// SQLCast ... | |
58 | -func (nb *NullBool) SQLCast() sql.NullBool { | |
57 | +// CastToSQL ... | |
58 | +func (nb *NullBool) CastToSQL() sql.NullBool { | |
59 | 59 | return sql.NullBool(*nb) |
60 | 60 | } |
61 | 61 | |
... | ... | @@ -104,8 +104,8 @@ func (ns *NullString) UnmarshalJSON(b []byte) error { |
104 | 104 | return nil |
105 | 105 | } |
106 | 106 | |
107 | -// SQLCast ... | |
108 | -func (ns *NullString) SQLCast() sql.NullString { | |
107 | +// CastToSQL ... | |
108 | +func (ns *NullString) CastToSQL() sql.NullString { | |
109 | 109 | return sql.NullString(*ns) |
110 | 110 | } |
111 | 111 | |
... | ... | @@ -177,8 +177,8 @@ func (ni *NullInt64) UnmarshalJSON(b []byte) error { |
177 | 177 | return nil |
178 | 178 | } |
179 | 179 | |
180 | -// SQLCast ... | |
181 | -func (ni *NullInt64) SQLCast() sql.NullInt64 { | |
180 | +// CastToSQL ... | |
181 | +func (ni *NullInt64) CastToSQL() sql.NullInt64 { | |
182 | 182 | return sql.NullInt64(*ni) |
183 | 183 | } |
184 | 184 | |
... | ... | @@ -251,11 +251,169 @@ func (nf *NullFloat64) UnmarshalJSON(b []byte) error { |
251 | 251 | return nil |
252 | 252 | } |
253 | 253 | |
254 | -// SQLCast ... | |
255 | -func (nf *NullFloat64) SQLCast() sql.NullFloat64 { | |
254 | +// CastToSQL ... | |
255 | +func (nf *NullFloat64) CastToSQL() sql.NullFloat64 { | |
256 | 256 | return sql.NullFloat64(*nf) |
257 | 257 | } |
258 | 258 | |
259 | +// NullDateTime ... | |
260 | +type NullDateTime struct { | |
261 | + Time time.Time | |
262 | + Valid bool // Valid is true if Time is not NULL | |
263 | +} | |
264 | + | |
265 | +// Scan ... | |
266 | +func (nt *NullDateTime) Scan(value interface{}) (err error) { | |
267 | + if value == nil { | |
268 | + nt.Time, nt.Valid = time.Time{}, false | |
269 | + return | |
270 | + } | |
271 | + | |
272 | + switch v := value.(type) { | |
273 | + case time.Time: | |
274 | + nt.Time, nt.Valid = v, true | |
275 | + return | |
276 | + case []byte: | |
277 | + nt.Time, err = parseSQLDateTime(string(v), time.UTC) | |
278 | + nt.Valid = (err == nil) | |
279 | + return | |
280 | + case string: | |
281 | + nt.Time, err = parseSQLDateTime(v, time.UTC) | |
282 | + nt.Valid = (err == nil) | |
283 | + return | |
284 | + } | |
285 | + | |
286 | + nt.Valid = false | |
287 | + return fmt.Errorf("Can't convert %T to time.Time", value) | |
288 | +} | |
289 | + | |
290 | +// Value implements the driver Valuer interface. | |
291 | +func (nt NullDateTime) Value() (driver.Value, error) { | |
292 | + if !nt.Valid { | |
293 | + return nil, nil | |
294 | + } | |
295 | + return nt.Time, nil | |
296 | +} | |
297 | + | |
298 | +// MarshalJSON ... | |
299 | +func (nt NullDateTime) MarshalJSON() ([]byte, error) { | |
300 | + if nt.Valid { | |
301 | + format := nt.Time.Format("2006-01-02 15:04:05") | |
302 | + return json.Marshal(format) | |
303 | + } | |
304 | + return json.Marshal(nil) | |
305 | +} | |
306 | + | |
307 | +// UnmarshalJSON ... | |
308 | +func (nt *NullDateTime) UnmarshalJSON(b []byte) error { | |
309 | + var temp *time.Time | |
310 | + var t1 time.Time | |
311 | + var err error | |
312 | + | |
313 | + s1 := string(b) | |
314 | + s2 := s1[1 : len(s1)-1] | |
315 | + if s1 == "null" { | |
316 | + temp = nil | |
317 | + } else { | |
318 | + t1, err = time.Parse("2006-01-02 15:04:05", s2) | |
319 | + if err != nil { | |
320 | + return err | |
321 | + } | |
322 | + temp = &t1 | |
323 | + } | |
324 | + | |
325 | + if temp != nil { | |
326 | + nt.Valid = true | |
327 | + nt.Time = *temp | |
328 | + } else { | |
329 | + nt.Valid = false | |
330 | + } | |
331 | + return nil | |
332 | +} | |
333 | + | |
334 | +func (nt *NullDateTime) CastToSQL() NullDateTime { | |
335 | + return *nt | |
336 | +} | |
337 | + | |
338 | +// NullDate ... | |
339 | +type NullDate struct { | |
340 | + Time time.Time | |
341 | + Valid bool // Valid is true if Time is not NULL | |
342 | +} | |
343 | + | |
344 | +// Scan ... | |
345 | +func (nt *NullDate) Scan(value interface{}) (err error) { | |
346 | + if value == nil { | |
347 | + nt.Time, nt.Valid = time.Time{}, false | |
348 | + return | |
349 | + } | |
350 | + | |
351 | + switch v := value.(type) { | |
352 | + case time.Time: | |
353 | + nt.Time, nt.Valid = v, true | |
354 | + return | |
355 | + case []byte: | |
356 | + nt.Time, err = parseSQLDate(string(v), time.UTC) | |
357 | + nt.Valid = (err == nil) | |
358 | + return | |
359 | + case string: | |
360 | + nt.Time, err = parseSQLDate(v, time.UTC) | |
361 | + nt.Valid = (err == nil) | |
362 | + return | |
363 | + } | |
364 | + | |
365 | + nt.Valid = false | |
366 | + return fmt.Errorf("Can't convert %T to time.Time", value) | |
367 | +} | |
368 | + | |
369 | +// Value implements the driver Valuer interface. | |
370 | +func (nt NullDate) Value() (driver.Value, error) { | |
371 | + if !nt.Valid { | |
372 | + return nil, nil | |
373 | + } | |
374 | + return nt.Time, nil | |
375 | +} | |
376 | + | |
377 | +// MarshalJSON ... | |
378 | +func (nt NullDate) MarshalJSON() ([]byte, error) { | |
379 | + if nt.Valid { | |
380 | + format := nt.Time.Format("2006-01-02") | |
381 | + return json.Marshal(format) | |
382 | + } | |
383 | + return json.Marshal(nil) | |
384 | +} | |
385 | + | |
386 | +// UnmarshalJSON ... | |
387 | +func (nt *NullDate) UnmarshalJSON(b []byte) error { | |
388 | + var temp *time.Time | |
389 | + var t1 time.Time | |
390 | + var err error | |
391 | + | |
392 | + s1 := string(b) | |
393 | + s2 := s1[1 : len(s1)-1] | |
394 | + if s1 == "null" { | |
395 | + temp = nil | |
396 | + } else { | |
397 | + t1, err = time.Parse("2006-01-02", s2) | |
398 | + if err != nil { | |
399 | + return err | |
400 | + } | |
401 | + temp = &t1 | |
402 | + } | |
403 | + | |
404 | + if temp != nil { | |
405 | + nt.Valid = true | |
406 | + nt.Time = *temp | |
407 | + } else { | |
408 | + nt.Valid = false | |
409 | + } | |
410 | + return nil | |
411 | +} | |
412 | + | |
413 | +func (nt *NullDate) CastToSQL() NullDate { | |
414 | + return *nt | |
415 | +} | |
416 | + | |
259 | 417 | // NullTime ... |
260 | 418 | type NullTime struct { |
261 | 419 | Time time.Time |
... | ... | @@ -274,11 +432,11 @@ func (nt *NullTime) Scan(value interface{}) (err error) { |
274 | 432 | nt.Time, nt.Valid = v, true |
275 | 433 | return |
276 | 434 | case []byte: |
277 | - nt.Time, err = parseDateTime(string(v), time.UTC) | |
435 | + nt.Time, err = parseSQLTime(string(v), time.UTC) | |
278 | 436 | nt.Valid = (err == nil) |
279 | 437 | return |
280 | 438 | case string: |
281 | - nt.Time, err = parseDateTime(v, time.UTC) | |
439 | + nt.Time, err = parseSQLTime(v, time.UTC) | |
282 | 440 | nt.Valid = (err == nil) |
283 | 441 | return |
284 | 442 | } |
... | ... | @@ -298,7 +456,7 @@ func (nt NullTime) Value() (driver.Value, error) { |
298 | 456 | // MarshalJSON ... |
299 | 457 | func (nt NullTime) MarshalJSON() ([]byte, error) { |
300 | 458 | if nt.Valid { |
301 | - format := nt.Time.Format("2006-01-02 15:04:05") | |
459 | + format := nt.Time.Format("15:04:05") | |
302 | 460 | return json.Marshal(format) |
303 | 461 | } |
304 | 462 | return json.Marshal(nil) |
... | ... | @@ -315,7 +473,7 @@ func (nt *NullTime) UnmarshalJSON(b []byte) error { |
315 | 473 | if s1 == "null" { |
316 | 474 | temp = nil |
317 | 475 | } else { |
318 | - t1, err = time.Parse("2006-01-02 15:04:05", s2) | |
476 | + t1, err = time.Parse("15:04:05", s2) | |
319 | 477 | if err != nil { |
320 | 478 | return err |
321 | 479 | } |
... | ... | @@ -331,7 +489,11 @@ func (nt *NullTime) UnmarshalJSON(b []byte) error { |
331 | 489 | return nil |
332 | 490 | } |
333 | 491 | |
334 | -func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { | |
492 | +func (nt *NullTime) CastToSQL() NullTime { | |
493 | + return *nt | |
494 | +} | |
495 | + | |
496 | +func parseSQLDateTime(str string, loc *time.Location) (t time.Time, err error) { | |
335 | 497 | base := "0000-00-00 00:00:00.0000000" |
336 | 498 | timeFormat := "2006-01-02 15:04:05.999999" |
337 | 499 | switch len(str) { |
... | ... | @@ -354,3 +516,51 @@ func parseDateTime(str string, loc *time.Location) (t time.Time, err error) { |
354 | 516 | |
355 | 517 | return |
356 | 518 | } |
519 | + | |
520 | +func parseSQLDate(str string, loc *time.Location) (t time.Time, err error) { | |
521 | + base := "0000-00-00" | |
522 | + timeFormat := "2006-01-02" | |
523 | + switch len(str) { | |
524 | + case 10: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" | |
525 | + if str == base[:len(str)] { | |
526 | + return | |
527 | + } | |
528 | + t, err = time.Parse(timeFormat[:len(str)], str) | |
529 | + default: | |
530 | + err = fmt.Errorf("invalid time string: %s", str) | |
531 | + return | |
532 | + } | |
533 | + | |
534 | + // Adjust location | |
535 | + if err == nil && loc != time.UTC { | |
536 | + y, mo, d := t.Date() | |
537 | + h, mi, s := t.Clock() | |
538 | + t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil | |
539 | + } | |
540 | + | |
541 | + return | |
542 | +} | |
543 | + | |
544 | +func parseSQLTime(str string, loc *time.Location) (t time.Time, err error) { | |
545 | + base := "00:00:00.0000000" | |
546 | + timeFormat := "15:04:05.999999" | |
547 | + switch len(str) { | |
548 | + case 12, 15: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" | |
549 | + if str == base[:len(str)] { | |
550 | + return | |
551 | + } | |
552 | + t, err = time.Parse(timeFormat[:len(str)], str) | |
553 | + default: | |
554 | + err = fmt.Errorf("invalid time string: %s", str) | |
555 | + return | |
556 | + } | |
557 | + | |
558 | + // Adjust location | |
559 | + if err == nil && loc != time.UTC { | |
560 | + y, mo, d := t.Date() | |
561 | + h, mi, s := t.Clock() | |
562 | + t, err = time.Date(y, mo, d, h, mi, s, t.Nanosecond(), loc), nil | |
563 | + } | |
564 | + | |
565 | + return | |
566 | +} | ... | ... |