Skip to content

Commit 9989dd0

Browse files
olavloiterahul2393
andauthored
feat(spanner): allow string values for Scan functions (#11898)
Support scanning string values for the various spanner.NullTypes. Use the Scan implementation of the underlying data type for spanner.NullDate, as civil.Date now implements the sql.Scanner interface. The Scan implementations are used by the database/sql driver and applications that depend on the database/sql driver. Co-authored-by: rahul2393 <irahul@google.com>
1 parent 12465b5 commit 9989dd0

File tree

2 files changed

+196
-1
lines changed

2 files changed

+196
-1
lines changed

spanner/value.go

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,20 @@ func (n *NullInt64) Scan(value interface{}) error {
268268
case NullInt64:
269269
n.Int64 = p.Int64
270270
n.Valid = p.Valid
271+
case string:
272+
i64, err := strconv.ParseInt(p, 10, 64)
273+
if err != nil {
274+
return err
275+
}
276+
n.Int64 = i64
277+
n.Valid = true
278+
case *string:
279+
i64, err := strconv.ParseInt(*p, 10, 64)
280+
if err != nil {
281+
return err
282+
}
283+
n.Int64 = i64
284+
n.Valid = true
271285
}
272286
return nil
273287
}
@@ -433,6 +447,20 @@ func (n *NullFloat64) Scan(value interface{}) error {
433447
case NullFloat64:
434448
n.Float64 = p.Float64
435449
n.Valid = p.Valid
450+
case string:
451+
f, err := strconv.ParseFloat(p, 64)
452+
if err != nil {
453+
return err
454+
}
455+
n.Float64 = f
456+
n.Valid = true
457+
case *string:
458+
f, err := strconv.ParseFloat(*p, 64)
459+
if err != nil {
460+
return err
461+
}
462+
n.Float64 = f
463+
n.Valid = true
436464
}
437465
return nil
438466
}
@@ -513,6 +541,20 @@ func (n *NullFloat32) Scan(value interface{}) error {
513541
case NullFloat32:
514542
n.Float32 = p.Float32
515543
n.Valid = p.Valid
544+
case string:
545+
f, err := strconv.ParseFloat(p, 32)
546+
if err != nil {
547+
return err
548+
}
549+
n.Float32 = float32(f)
550+
n.Valid = true
551+
case *string:
552+
f, err := strconv.ParseFloat(*p, 32)
553+
if err != nil {
554+
return err
555+
}
556+
n.Float32 = float32(f)
557+
n.Valid = true
516558
}
517559
return nil
518560
}
@@ -593,6 +635,20 @@ func (n *NullBool) Scan(value interface{}) error {
593635
case NullBool:
594636
n.Bool = p.Bool
595637
n.Valid = p.Valid
638+
case string:
639+
f, err := strconv.ParseBool(p)
640+
if err != nil {
641+
return err
642+
}
643+
n.Bool = f
644+
n.Valid = true
645+
case *string:
646+
f, err := strconv.ParseBool(*p)
647+
if err != nil {
648+
return err
649+
}
650+
n.Bool = f
651+
n.Valid = true
596652
}
597653
return nil
598654
}
@@ -678,6 +734,20 @@ func (n *NullTime) Scan(value interface{}) error {
678734
case NullTime:
679735
n.Time = p.Time
680736
n.Valid = p.Valid
737+
case string:
738+
f, err := time.Parse(time.RFC3339Nano, p)
739+
if err != nil {
740+
return err
741+
}
742+
n.Time = f
743+
n.Valid = true
744+
case *string:
745+
f, err := time.Parse(time.RFC3339Nano, *p)
746+
if err != nil {
747+
return err
748+
}
749+
n.Time = f
750+
n.Valid = true
681751
}
682752
return nil
683753
}
@@ -752,7 +822,12 @@ func (n *NullDate) Scan(value interface{}) error {
752822
n.Valid = true
753823
switch p := value.(type) {
754824
default:
755-
return spannerErrorf(codes.InvalidArgument, "invalid type for NullDate: %v", p)
825+
d := civil.Date{}
826+
if err := d.Scan(value); err != nil {
827+
return err
828+
}
829+
n.Date = d
830+
n.Valid = true
756831
case *civil.Date:
757832
n.Date = *p
758833
case civil.Date:
@@ -848,6 +923,20 @@ func (n *NullNumeric) Scan(value interface{}) error {
848923
case NullNumeric:
849924
n.Numeric = p.Numeric
850925
n.Valid = p.Valid
926+
case string:
927+
y, ok := (&big.Rat{}).SetString(p)
928+
if !ok {
929+
return errUnexpectedNumericStr(p)
930+
}
931+
n.Numeric = *y
932+
n.Valid = true
933+
case *string:
934+
y, ok := (&big.Rat{}).SetString(*p)
935+
if !ok {
936+
return errUnexpectedNumericStr(*p)
937+
}
938+
n.Numeric = *y
939+
n.Valid = true
851940
}
852941
return nil
853942
}

spanner/value_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3361,3 +3361,109 @@ func TestDecodeProtoArrayUsingBaseVariant(t *testing.T) {
33613361
t.Errorf("%s: got %+v, want %+v", "Test PROTO decode to [][]byte custom type", nb, b)
33623362
}
33633363
}
3364+
3365+
func TestScanNullInt64(t *testing.T) {
3366+
for _, val := range []any{"99", stringPointer("99")} {
3367+
n := NullInt64{}
3368+
if err := n.Scan(val); err != nil {
3369+
t.Fatal(err)
3370+
}
3371+
want := NullInt64{Int64: 99, Valid: true}
3372+
if g, w := n, want; !cmp.Equal(g, w) {
3373+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3374+
}
3375+
}
3376+
}
3377+
3378+
func TestScanNullString(t *testing.T) {
3379+
for _, val := range []any{"foo", stringPointer("foo")} {
3380+
n := NullString{}
3381+
if err := n.Scan(val); err != nil {
3382+
t.Fatal(err)
3383+
}
3384+
want := NullString{StringVal: "foo", Valid: true}
3385+
if g, w := n, want; !cmp.Equal(g, w) {
3386+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3387+
}
3388+
}
3389+
}
3390+
3391+
func TestScanNullFloat64(t *testing.T) {
3392+
for _, val := range []any{"3.14", stringPointer("3.14")} {
3393+
n := NullFloat64{}
3394+
if err := n.Scan(val); err != nil {
3395+
t.Fatal(err)
3396+
}
3397+
want := NullFloat64{Float64: 3.14, Valid: true}
3398+
if g, w := n, want; !cmp.Equal(g, w) {
3399+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3400+
}
3401+
}
3402+
}
3403+
3404+
func TestScanNullFloat32(t *testing.T) {
3405+
for _, val := range []any{"3.14", stringPointer("3.14")} {
3406+
n := NullFloat32{}
3407+
if err := n.Scan(val); err != nil {
3408+
t.Fatal(err)
3409+
}
3410+
want := NullFloat32{Float32: float32(3.14), Valid: true}
3411+
if g, w := n, want; !cmp.Equal(g, w) {
3412+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3413+
}
3414+
}
3415+
}
3416+
3417+
func TestScanNullBool(t *testing.T) {
3418+
for _, val := range []any{"true", stringPointer("true")} {
3419+
n := NullBool{}
3420+
if err := n.Scan(val); err != nil {
3421+
t.Fatal(err)
3422+
}
3423+
want := NullBool{Bool: true, Valid: true}
3424+
if g, w := n, want; !cmp.Equal(g, w) {
3425+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3426+
}
3427+
}
3428+
}
3429+
3430+
func TestScanNullTime(t *testing.T) {
3431+
for _, val := range []any{"2025-03-25T17:54:00+01:00", stringPointer("2025-03-25T17:54:00+01:00")} {
3432+
n := NullTime{}
3433+
if err := n.Scan(val); err != nil {
3434+
t.Fatal(err)
3435+
}
3436+
tm, _ := time.Parse(time.RFC3339Nano, "2025-03-25T16:54:00Z")
3437+
want := NullTime{Time: tm, Valid: true}
3438+
if g, w := n, want; !cmp.Equal(g, w) {
3439+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3440+
}
3441+
}
3442+
}
3443+
3444+
func TestScanNullDate(t *testing.T) {
3445+
for _, val := range []any{"2025-03-25", stringPointer("2025-03-25")} {
3446+
n := NullDate{}
3447+
if err := n.Scan(val); err != nil {
3448+
t.Fatal(err)
3449+
}
3450+
want := NullDate{Date: civil.Date{Year: 2025, Month: 3, Day: 25}, Valid: true}
3451+
if g, w := n, want; !cmp.Equal(g, w) {
3452+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3453+
}
3454+
}
3455+
}
3456+
3457+
func TestScanNullNumeric(t *testing.T) {
3458+
for _, val := range []any{"3.14", stringPointer("3.14")} {
3459+
n := NullNumeric{}
3460+
if err := n.Scan(val); err != nil {
3461+
t.Fatal(err)
3462+
}
3463+
r, _ := (&big.Rat{}).SetString("3.14")
3464+
want := NullNumeric{Numeric: *r, Valid: true}
3465+
if g, w := n, want; !reflect.DeepEqual(g, w) {
3466+
t.Fatalf("value mismatch\n Got: %v\nWant: %v", g, w)
3467+
}
3468+
}
3469+
}

0 commit comments

Comments
 (0)