-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add implementation for all sql.Null types
- Loading branch information
Showing
20 changed files
with
961 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package sqlnull | ||
|
||
import ( | ||
"database/sql" | ||
"encoding/json" | ||
"reflect" | ||
) | ||
|
||
type NullBool sql.NullBool | ||
|
||
func (n NullBool) MarshalJSON() ([]byte, error) { | ||
if !n.Valid { | ||
return json.Marshal(nil) | ||
} | ||
return json.Marshal(n.Bool) | ||
} | ||
|
||
func (n *NullBool) UnmarshalJSON(data []byte) error { | ||
var target *bool | ||
if err := json.Unmarshal(data, &target); err != nil { | ||
return err | ||
} | ||
|
||
n.Valid = target != nil | ||
if n.Valid { | ||
n.Bool = *target | ||
} else { | ||
n.Bool = false | ||
} | ||
return nil | ||
} | ||
|
||
func (n *NullBool) Scan(src any) error { | ||
var sqln sql.NullBool | ||
if err := sqln.Scan(src); err != nil { | ||
return err | ||
} | ||
n.Bool = sqln.Bool | ||
n.Valid = reflect.TypeOf(src) == nil | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package sqlnull_test | ||
|
||
import ( | ||
"encoding/json" | ||
"testing" | ||
|
||
"github.com/gkits/sqlnull" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func Test_Bool_MarshalJSON(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
in sqlnull.NullBool | ||
want string | ||
}{ | ||
{ | ||
name: "marshal false bool to null", | ||
in: sqlnull.NullBool{ | ||
Bool: false, | ||
Valid: false, | ||
}, | ||
want: "null", | ||
}, | ||
{ | ||
name: "marshal true bool to null", | ||
in: sqlnull.NullBool{ | ||
Bool: true, | ||
Valid: false, | ||
}, | ||
want: "null", | ||
}, | ||
{ | ||
name: "marshal false bool", | ||
in: sqlnull.NullBool{ | ||
Bool: false, | ||
Valid: true, | ||
}, | ||
want: "false", | ||
}, | ||
{ | ||
name: "marshal true bool", | ||
in: sqlnull.NullBool{ | ||
Bool: true, | ||
Valid: true, | ||
}, | ||
want: "true", | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
t.Run(c.name, func(t *testing.T) { | ||
got, err := json.Marshal(c.in) | ||
require.NoError(t, err) | ||
assert.JSONEq(t, c.want, string(got)) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package sqlnull | ||
|
||
import ( | ||
"database/sql" | ||
"encoding/json" | ||
"reflect" | ||
) | ||
|
||
type NullByte sql.NullByte | ||
|
||
func (n NullByte) MarshalJSON() ([]byte, error) { | ||
if !n.Valid { | ||
return json.Marshal(nil) | ||
} | ||
return json.Marshal(n.Byte) | ||
} | ||
|
||
func (n *NullByte) UnmarshalJSON(data []byte) error { | ||
var target *byte | ||
if err := json.Unmarshal(data, &target); err != nil { | ||
return err | ||
} | ||
|
||
n.Valid = target != nil | ||
if n.Valid { | ||
n.Byte = *target | ||
} else { | ||
n.Byte = 0x00 | ||
} | ||
return nil | ||
} | ||
|
||
func (n *NullByte) Scan(src any) error { | ||
var sqln sql.NullByte | ||
if err := sqln.Scan(src); err != nil { | ||
return err | ||
} | ||
n.Byte = sqln.Byte | ||
n.Valid = reflect.TypeOf(src) == nil | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package sqlnull_test | ||
|
||
import ( | ||
"encoding/json" | ||
"testing" | ||
|
||
"github.com/gkits/sqlnull" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func Test_Byte_MarshalJSON(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
in sqlnull.NullByte | ||
want string | ||
}{ | ||
{ | ||
name: "marshal zero byte to null", | ||
in: sqlnull.NullByte{ | ||
Byte: 0x00, | ||
Valid: false, | ||
}, | ||
want: "null", | ||
}, | ||
{ | ||
name: "marshal non zero byte to null", | ||
in: sqlnull.NullByte{ | ||
Byte: 0x99, | ||
Valid: false, | ||
}, | ||
want: "null", | ||
}, | ||
{ | ||
name: "marshal zero byte", | ||
in: sqlnull.NullByte{ | ||
Byte: 0x00, | ||
Valid: true, | ||
}, | ||
want: "0", | ||
}, | ||
{ | ||
name: "marshal non zero byte", | ||
in: sqlnull.NullByte{ | ||
Byte: 0x69, | ||
Valid: true, | ||
}, | ||
want: "105", | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
t.Run(c.name, func(t *testing.T) { | ||
got, err := json.Marshal(c.in) | ||
require.NoError(t, err) | ||
assert.JSONEq(t, c.want, string(got)) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package sqlnull | ||
|
||
import ( | ||
"database/sql" | ||
"encoding/json" | ||
"reflect" | ||
) | ||
|
||
type NullFloat64 sql.NullFloat64 | ||
|
||
func (n NullFloat64) MarshalJSON() ([]byte, error) { | ||
if !n.Valid { | ||
return json.Marshal(nil) | ||
} | ||
return json.Marshal(n.Float64) | ||
} | ||
|
||
func (n *NullFloat64) UnmarshalJSON(data []byte) error { | ||
var target *float64 | ||
if err := json.Unmarshal(data, &target); err != nil { | ||
return err | ||
} | ||
|
||
n.Valid = target != nil | ||
if n.Valid { | ||
n.Float64 = *target | ||
} else { | ||
n.Float64 = 0 | ||
} | ||
return nil | ||
} | ||
|
||
func (n *NullFloat64) Scan(src any) error { | ||
var sqln sql.NullFloat64 | ||
if err := sqln.Scan(src); err != nil { | ||
return err | ||
} | ||
n.Float64 = sqln.Float64 | ||
n.Valid = reflect.TypeOf(src) == nil | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package sqlnull_test | ||
|
||
import ( | ||
"encoding/json" | ||
"testing" | ||
|
||
"github.com/gkits/sqlnull" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func Test_Float64_MarshalJSON(t *testing.T) { | ||
cases := []struct { | ||
name string | ||
in sqlnull.NullFloat64 | ||
want string | ||
}{ | ||
{ | ||
name: "marshal 0 float to null", | ||
in: sqlnull.NullFloat64{ | ||
Float64: 0, | ||
Valid: false, | ||
}, | ||
want: "null", | ||
}, | ||
{ | ||
name: "marshal non 0 float to null", | ||
in: sqlnull.NullFloat64{ | ||
Float64: 1234.56789, | ||
Valid: false, | ||
}, | ||
want: "null", | ||
}, | ||
{ | ||
name: "marshal 0 float", | ||
in: sqlnull.NullFloat64{ | ||
Float64: 0.000, | ||
Valid: true, | ||
}, | ||
want: "0", | ||
}, | ||
{ | ||
name: "marshal non 0 float", | ||
in: sqlnull.NullFloat64{ | ||
Float64: 10101.0101, | ||
Valid: true, | ||
}, | ||
want: "10101.0101", | ||
}, | ||
} | ||
|
||
for _, c := range cases { | ||
t.Run(c.name, func(t *testing.T) { | ||
got, err := json.Marshal(c.in) | ||
require.NoError(t, err) | ||
assert.JSONEq(t, c.want, string(got)) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package sqlnull | ||
|
||
import ( | ||
"database/sql" | ||
"encoding/json" | ||
"reflect" | ||
) | ||
|
||
type Null[T any] sql.Null[T] | ||
|
||
func (n Null[T]) MarshalJSON() ([]byte, error) { | ||
if !n.Valid { | ||
return json.Marshal(nil) | ||
} | ||
return json.Marshal(n.V) | ||
} | ||
|
||
func (n *Null[T]) UnmarshalJSON(data []byte) error { | ||
var target *T | ||
if err := json.Unmarshal(data, &target); err != nil { | ||
return err | ||
} | ||
|
||
n.Valid = target != nil | ||
if n.Valid { | ||
n.V = *target | ||
} else { | ||
var zero T | ||
n.V = zero | ||
} | ||
return nil | ||
} | ||
|
||
func (n *Null[T]) Scan(src any) error { | ||
var sqln sql.Null[T] | ||
if err := sqln.Scan(src); err != nil { | ||
return err | ||
} | ||
n.V = sqln.V | ||
n.Valid = reflect.TypeOf(src) == nil | ||
return nil | ||
} |
Oops, something went wrong.