From cef5e277d0a159ee9d4d92c15e5cab8bbbe10e41 Mon Sep 17 00:00:00 2001 From: Sam Lown Date: Wed, 16 Oct 2024 09:19:17 +0000 Subject: [PATCH] UUID: compatibiltiy for SQL serialization --- CHANGELOG.md | 4 ++++ uuid/sql.go | 37 +++++++++++++++++++++++++++++++++++ uuid/sql_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++ uuid/uuid.go | 25 ++++++++++++++++++++++++ uuid/uuid_test.go | 37 +++++++++++++++++++++++++++++++++++ 5 files changed, 153 insertions(+) create mode 100644 uuid/sql.go create mode 100644 uuid/sql_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 662d435e..55b5b25c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/) and this p ## [Unreleased] +### Added + +- `uuid` - SQL library compatibility for type conversion. + ### Fixed - `bill.Invoice` - remove empty taxes instances. diff --git a/uuid/sql.go b/uuid/sql.go new file mode 100644 index 00000000..3dd16328 --- /dev/null +++ b/uuid/sql.go @@ -0,0 +1,37 @@ +package uuid + +import ( + "database/sql" + "database/sql/driver" + "fmt" +) + +var _ driver.Valuer = UUID("") +var _ sql.Scanner = (*UUID)(nil) + +// Value implements the driver.Valuer interface. +func (u UUID) Value() (driver.Value, error) { + return u.String(), nil +} + +// Scan implements the sql.Scanner interface. +// A 16-byte slice will be handled by UnmarshalBinary, while +// a longer byte slice or a string will be handled by UnmarshalText. +func (u *UUID) Scan(src interface{}) error { + switch src := src.(type) { + case UUID: + *u = src + return nil + case []byte: + if len(src) == Size { + return u.UnmarshalBinary(src) + } + return u.UnmarshalText(src) + case string: + uu, err := Parse(src) + *u = uu + return err + } + + return fmt.Errorf("cannot convert %T to UUID", src) +} diff --git a/uuid/sql_test.go b/uuid/sql_test.go new file mode 100644 index 00000000..4fe7beb8 --- /dev/null +++ b/uuid/sql_test.go @@ -0,0 +1,50 @@ +package uuid_test + +import ( + "testing" + + "github.com/invopop/gobl/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValue(t *testing.T) { + u := uuid.MustParse("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + v, err := u.Value() + require.NoError(t, err) + assert.Equal(t, u.String(), v) +} + +func TestScan(t *testing.T) { + u := uuid.MustParse("6ba7b810-9dad-11d1-80b4-00c04fd430c8") + t.Run("with UUID", func(t *testing.T) { + var uu uuid.UUID + err := uu.Scan(u) + require.NoError(t, err) + assert.Equal(t, u, uu) + }) + t.Run("with string", func(t *testing.T) { + var uu uuid.UUID + err := uu.Scan(u.String()) + require.NoError(t, err) + assert.Equal(t, u, uu) + }) + t.Run("with []byte text", func(t *testing.T) { + var uu uuid.UUID + err := uu.Scan([]byte(u.String())) + require.NoError(t, err) + assert.Equal(t, u, uu) + }) + t.Run("with bytes", func(t *testing.T) { + var uu uuid.UUID + err := uu.Scan(u.Bytes()) + require.NoError(t, err) + assert.Equal(t, u, uu) + }) + + t.Run("with int", func(t *testing.T) { + var uu uuid.UUID + err := uu.Scan(42) + require.ErrorContains(t, err, "cannot convert int to UUI") + }) +} diff --git a/uuid/uuid.go b/uuid/uuid.go index 47717b09..3d1e7140 100644 --- a/uuid/uuid.go +++ b/uuid/uuid.go @@ -32,6 +32,9 @@ const ( Zero UUID = "00000000-0000-0000-0000-000000000000" ) +// Size is the number of bytes in a UUID. +const Size = 16 + var ( regexpSimpleUUID = regexp.MustCompile("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$") ) @@ -195,6 +198,12 @@ func (u UUID) Base64() string { return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b) } +// Bytes returns a byte slice of the UUID. +func (u UUID) Bytes() []byte { + id := parse(u) + return id[:] +} + // ParseBase64 will attempt to decode a Base64 string into a UUID. If the string // is already a regular UUID, it will be parsed and returned using the regular // Parse method. @@ -273,6 +282,22 @@ func Normalize(u *UUID) { } } +// UnmarshalBinary will convert a 16 byte slice into a UUID +func (u *UUID) UnmarshalBinary(data []byte) error { + id, err := uuid.FromBytes(data) + if err != nil { + return err + } + *u = UUID(id.String()) + return nil +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +func (u *UUID) MarshalBinary() ([]byte, error) { + id := parse(*u) + return id.MarshalBinary() +} + // UnmarshalText will ensure the UUID is always a valid UUID when unmarshalling // and just return an empty value if incorrect. // TODO: Remove this and instead depend on validation to provide more readable errors. diff --git a/uuid/uuid_test.go b/uuid/uuid_test.go index 1878262a..027ad12c 100644 --- a/uuid/uuid_test.go +++ b/uuid/uuid_test.go @@ -36,6 +36,12 @@ func TestUUIDParsing(t *testing.T) { assert.True(t, u1.IsZero()) u1 = uuid.ShouldParse(v1s) assert.Equal(t, v1s, u1.String()) + + t.Run("must parse", func(t *testing.T) { + assert.Panics(t, func() { + uuid.MustParse("fooo") + }) + }) } func TestUUIDIsZero(t *testing.T) { @@ -73,6 +79,15 @@ func TestUUIDTimestasmp(t *testing.T) { assert.True(t, ts.IsZero()) } +func TestNodeID(t *testing.T) { + a := uuid.NodeID() + assert.Len(t, a, 12) + + uuid.SetRandomNodeID() + assert.Len(t, uuid.NodeID(), 12) + assert.NotEqual(t, a, uuid.NodeID()) +} + func TestUUIDJSON(t *testing.T) { v1s := "03907310-8daa-11eb-8dcd-0242ac130003" type testJSON struct { @@ -225,6 +240,9 @@ func TestNormalize(t *testing.T) { uuid.Normalize(&u3) assert.Equal(t, "03907310-8daa-11eb-8dcd-0242ac130003", u3.String()) + assert.NotPanics(t, func() { + uuid.Normalize(nil) + }) } func TestUUIDv3(t *testing.T) { @@ -257,3 +275,22 @@ func TestUUIDBase64(t *testing.T) { require.NoError(t, err) assert.Equal(t, u.String(), u2.String()) } + +func TestUUIDBunary(t *testing.T) { + u := uuid.MustParse("f47ac10b-58cc-0372-8567-0e02b2c3d479") + + b := u.Bytes() + assert.Equal(t, 16, len(b)) + + out, err := u.MarshalBinary() + require.NoError(t, err) + assert.Equal(t, b, out) + + u2 := new(uuid.UUID) + err = u2.UnmarshalBinary(b) + require.NoError(t, err) + + u2 = new(uuid.UUID) + err = u2.UnmarshalBinary([]byte("invalid")) + assert.ErrorContains(t, err, "invalid UUID (got 7 bytes)") +}