Skip to content

Commit

Permalink
feat: add implementation for all sql.Null types
Browse files Browse the repository at this point in the history
  • Loading branch information
gKits committed Oct 23, 2024
1 parent 142e6f7 commit c05ea42
Show file tree
Hide file tree
Showing 20 changed files with 961 additions and 0 deletions.
41 changes: 41 additions & 0 deletions bool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package sqlnull

import (
"database/sql"
"encoding/json"
"reflect"
)

type NullBool sql.NullBool

func (n NullBool) MarshalJSON() ([]byte, error) {
if !n.Valid {
return json.Marshal(nil)
}
return json.Marshal(n.Bool)
}

func (n *NullBool) UnmarshalJSON(data []byte) error {
var target *bool
if err := json.Unmarshal(data, &target); err != nil {
return err
}

n.Valid = target != nil
if n.Valid {
n.Bool = *target
} else {
n.Bool = false
}
return nil
}

func (n *NullBool) Scan(src any) error {
var sqln sql.NullBool
if err := sqln.Scan(src); err != nil {
return err
}
n.Bool = sqln.Bool
n.Valid = reflect.TypeOf(src) == nil
return nil
}
59 changes: 59 additions & 0 deletions bool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package sqlnull_test

import (
"encoding/json"
"testing"

"github.com/gkits/sqlnull"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_Bool_MarshalJSON(t *testing.T) {
cases := []struct {
name string
in sqlnull.NullBool
want string
}{
{
name: "marshal false bool to null",
in: sqlnull.NullBool{
Bool: false,
Valid: false,
},
want: "null",
},
{
name: "marshal true bool to null",
in: sqlnull.NullBool{
Bool: true,
Valid: false,
},
want: "null",
},
{
name: "marshal false bool",
in: sqlnull.NullBool{
Bool: false,
Valid: true,
},
want: "false",
},
{
name: "marshal true bool",
in: sqlnull.NullBool{
Bool: true,
Valid: true,
},
want: "true",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := json.Marshal(c.in)
require.NoError(t, err)
assert.JSONEq(t, c.want, string(got))
})
}
}
41 changes: 41 additions & 0 deletions byte.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package sqlnull

import (
"database/sql"
"encoding/json"
"reflect"
)

type NullByte sql.NullByte

func (n NullByte) MarshalJSON() ([]byte, error) {
if !n.Valid {
return json.Marshal(nil)
}
return json.Marshal(n.Byte)
}

func (n *NullByte) UnmarshalJSON(data []byte) error {
var target *byte
if err := json.Unmarshal(data, &target); err != nil {
return err
}

n.Valid = target != nil
if n.Valid {
n.Byte = *target
} else {
n.Byte = 0x00
}
return nil
}

func (n *NullByte) Scan(src any) error {
var sqln sql.NullByte
if err := sqln.Scan(src); err != nil {
return err
}
n.Byte = sqln.Byte
n.Valid = reflect.TypeOf(src) == nil
return nil
}
59 changes: 59 additions & 0 deletions byte_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package sqlnull_test

import (
"encoding/json"
"testing"

"github.com/gkits/sqlnull"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_Byte_MarshalJSON(t *testing.T) {
cases := []struct {
name string
in sqlnull.NullByte
want string
}{
{
name: "marshal zero byte to null",
in: sqlnull.NullByte{
Byte: 0x00,
Valid: false,
},
want: "null",
},
{
name: "marshal non zero byte to null",
in: sqlnull.NullByte{
Byte: 0x99,
Valid: false,
},
want: "null",
},
{
name: "marshal zero byte",
in: sqlnull.NullByte{
Byte: 0x00,
Valid: true,
},
want: "0",
},
{
name: "marshal non zero byte",
in: sqlnull.NullByte{
Byte: 0x69,
Valid: true,
},
want: "105",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := json.Marshal(c.in)
require.NoError(t, err)
assert.JSONEq(t, c.want, string(got))
})
}
}
41 changes: 41 additions & 0 deletions float64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package sqlnull

import (
"database/sql"
"encoding/json"
"reflect"
)

type NullFloat64 sql.NullFloat64

func (n NullFloat64) MarshalJSON() ([]byte, error) {
if !n.Valid {
return json.Marshal(nil)
}
return json.Marshal(n.Float64)
}

func (n *NullFloat64) UnmarshalJSON(data []byte) error {
var target *float64
if err := json.Unmarshal(data, &target); err != nil {
return err
}

n.Valid = target != nil
if n.Valid {
n.Float64 = *target
} else {
n.Float64 = 0
}
return nil
}

func (n *NullFloat64) Scan(src any) error {
var sqln sql.NullFloat64
if err := sqln.Scan(src); err != nil {
return err
}
n.Float64 = sqln.Float64
n.Valid = reflect.TypeOf(src) == nil
return nil
}
59 changes: 59 additions & 0 deletions float64_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package sqlnull_test

import (
"encoding/json"
"testing"

"github.com/gkits/sqlnull"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_Float64_MarshalJSON(t *testing.T) {
cases := []struct {
name string
in sqlnull.NullFloat64
want string
}{
{
name: "marshal 0 float to null",
in: sqlnull.NullFloat64{
Float64: 0,
Valid: false,
},
want: "null",
},
{
name: "marshal non 0 float to null",
in: sqlnull.NullFloat64{
Float64: 1234.56789,
Valid: false,
},
want: "null",
},
{
name: "marshal 0 float",
in: sqlnull.NullFloat64{
Float64: 0.000,
Valid: true,
},
want: "0",
},
{
name: "marshal non 0 float",
in: sqlnull.NullFloat64{
Float64: 10101.0101,
Valid: true,
},
want: "10101.0101",
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got, err := json.Marshal(c.in)
require.NoError(t, err)
assert.JSONEq(t, c.want, string(got))
})
}
}
42 changes: 42 additions & 0 deletions generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package sqlnull

import (
"database/sql"
"encoding/json"
"reflect"
)

type Null[T any] sql.Null[T]

func (n Null[T]) MarshalJSON() ([]byte, error) {
if !n.Valid {
return json.Marshal(nil)
}
return json.Marshal(n.V)
}

func (n *Null[T]) UnmarshalJSON(data []byte) error {
var target *T
if err := json.Unmarshal(data, &target); err != nil {
return err
}

n.Valid = target != nil
if n.Valid {
n.V = *target
} else {
var zero T
n.V = zero
}
return nil
}

func (n *Null[T]) Scan(src any) error {
var sqln sql.Null[T]
if err := sqln.Scan(src); err != nil {
return err
}
n.V = sqln.V
n.Valid = reflect.TypeOf(src) == nil
return nil
}
Loading

0 comments on commit c05ea42

Please sign in to comment.