Skip to content

Commit

Permalink
feat: support concrete types and custom types
Browse files Browse the repository at this point in the history
  • Loading branch information
wchargin committed Jul 1, 2024
1 parent 8d0f294 commit 267fac6
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 30 deletions.
2 changes: 0 additions & 2 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ linters:
- inamedparam
- ineffassign
- interfacebloat
- ireturn
- loggercheck
- makezero
- mirror
Expand Down Expand Up @@ -81,7 +80,6 @@ linters:
- tparallel
- typecheck
- unconvert
- unparam
- unused
- usestdlibvars
- wastedassign
Expand Down
167 changes: 140 additions & 27 deletions pgxgeom.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/binary"
"encoding/hex"
"errors"
"fmt"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
Expand Down Expand Up @@ -69,9 +70,122 @@ func (c codec) PreferredFormat() int16 {
return pgtype.BinaryFormatCode
}

// GeomScanner enables PostGIS geometry/geography values to be scanned into
// arbitrary Go types. For more context, see section "Extending Existing
// PostgreSQL Type Support" of the README for jackc/pgx/v5/pgtype.
type GeomScanner interface {
ScanGeom(v geom.T) error
}

// GeomValuer enables PostGIS geometry/geography values to be marshaled from
// arbitrary Go types. For more context, see section "Extending Existing
// PostgreSQL Type Support" of the README for jackc/pgx/v5/pgtype.
type GeomValuer interface {
GeomValue() (geom.T, error)
}

// unexpectedTypeError indicates that a PostGIS value did not meet the type
// constraints to be scanned into a particular Go value. For example, this
// occurs when attempting to scan a `geometry(point)` into a `*geom.Polygon`.
type unexpectedTypeError struct {
Got any
Want any
}

func (e unexpectedTypeError) Error() string {
return fmt.Sprintf("pgxgeom: got %T, want %T", e.Got, e.Want)
}

// unsupportedTypeError indicates that a given Go value could not be converted to
// a GeomScanner/GeomValuer. For example, this occurs if you attempt to scan
// into a `*bool`.
type unsupportedTypeError struct {
Got any
}

func (e unsupportedTypeError) Error() string {
return fmt.Sprintf("pgxgeom: unsupported type %T", e.Got)
}

// genericGeomValuer can be used to marshal generic geom.T values as well as
// any concrete value like a *geom.Point.
type genericGeomValuer struct {
value geom.T
}

func (gv genericGeomValuer) GeomValue() (geom.T, error) {
return gv.value, nil
}

// genericGeomScanner can only be used to scan into generic geom.T values. To
// scan into concrete values like a *geom.Point, a more specific scanner type
// is needed to perform the appropriate error checking.
type genericGeomScanner struct {
target *geom.T
}

func (sc genericGeomScanner) ScanGeom(v geom.T) error {
*sc.target = v
return nil
}

// concreteScanner is used to scan into a specific, concrete geom.T type.
// The type parameter T should be in *non-pointer* form, like `geom.Point`,
// such that `*T` implements `geom.T`.
type concreteScanner[T any] struct {
target *T
}

func (sc concreteScanner[T]) ScanGeom(v geom.T) error {
var vv any = v // work around "impossible type assertion" compiler error
concrete, ok := vv.(*T)
if !ok {
return unexpectedTypeError{Got: v, Want: sc.target}
}
*sc.target = *concrete
return nil
}

func getGeomScanner(v any) (GeomScanner, error) {
switch v := v.(type) {
case GeomScanner:
return v, nil
case *geom.T:
return genericGeomScanner{v}, nil
case *geom.Point:
return concreteScanner[geom.Point]{v}, nil
case *geom.LineString:
return concreteScanner[geom.LineString]{v}, nil
case *geom.Polygon:
return concreteScanner[geom.Polygon]{v}, nil
case *geom.MultiPoint:
return concreteScanner[geom.MultiPoint]{v}, nil
case *geom.MultiLineString:
return concreteScanner[geom.MultiLineString]{v}, nil
case *geom.MultiPolygon:
return concreteScanner[geom.MultiPolygon]{v}, nil
case *geom.GeometryCollection:
return concreteScanner[geom.GeometryCollection]{v}, nil
default:
return nil, unsupportedTypeError{v}
}
}

//nolint:ireturn
func getGeomValuer(v any) (GeomValuer, error) {
switch v := v.(type) {
case GeomValuer:
return v, nil
case geom.T:
return genericGeomValuer{v}, nil
default:
return nil, unsupportedTypeError{v}
}
}

// PlanEncode implements [github.com/jackc/pgx/v5/pgtype.Codec.PlanEncode].
func (c codec) PlanEncode(m *pgtype.Map, old uint32, format int16, value any) pgtype.EncodePlan {
if _, ok := value.(geom.T); !ok {
if _, err := getGeomValuer(value); err != nil {
return nil
}
switch format {
Expand All @@ -86,7 +200,7 @@ func (c codec) PlanEncode(m *pgtype.Map, old uint32, format int16, value any) pg

// PlanScan implements [github.com/jackc/pgx/v5/pgtype.Codec.PlanScan].
func (c codec) PlanScan(m *pgtype.Map, old uint32, format int16, target any) pgtype.ScanPlan {
if _, ok := target.(*geom.T); !ok {
if _, err := getGeomScanner(target); err != nil {
return nil
}
switch format {
Expand Down Expand Up @@ -122,13 +236,21 @@ func (c codec) DecodeValue(m *pgtype.Map, oid uint32, format int16, src []byte)
}
}

func encodeGeomValue(value any) (ewkbBuf []byte, err error) {
valuer, err := getGeomValuer(value)
if err != nil {
return nil, err
}
g, err := valuer.GeomValue()
if err != nil {
return nil, err
}
return ewkb.Marshal(g, nativeEndian)
}

// Encode implements [github.com/jackc/pgx/v5/pgtype.EncodePlan.Encode].
func (p binaryEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) {
g, ok := value.(geom.T)
if !ok {
return buf, errors.ErrUnsupported
}
data, err := ewkb.Marshal(g, nativeEndian)
data, err := encodeGeomValue(value)
if err != nil {
return buf, err
}
Expand All @@ -137,11 +259,7 @@ func (p binaryEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err erro

// Encode implements [github.com/jackc/pgx/v5/pgtype.EncodePlan.Encode].
func (p textEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) {
g, ok := value.(geom.T)
if !ok {
return buf, errors.ErrUnsupported
}
data, err := ewkb.Marshal(g, nativeEndian)
data, err := encodeGeomValue(value)
if err != nil {
return buf, err
}
Expand All @@ -150,33 +268,29 @@ func (p textEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error)

// Scan implements [github.com/jackc/pgx/v5/pgtype.ScanPlan.Scan].
func (p binaryScanPlan) Scan(src []byte, target any) error {
pg, ok := target.(*geom.T)
if !ok {
return errors.ErrUnsupported
scanner, err := getGeomScanner(target)
if err != nil {
return err
}
if len(src) == 0 {
*pg = nil
return nil
return scanner.ScanGeom(nil)
}
g, err := ewkb.Unmarshal(src)
if err != nil {
return err
}
*pg = g
return nil
return scanner.ScanGeom(g)
}

// Scan implements [github.com/jackc/pgx/v5/pgtype.ScanPlan.Scan].
func (p textScanPlan) Scan(src []byte, target any) error {
pg, ok := target.(*geom.T)
if !ok {
return errors.ErrUnsupported
scanner, err := getGeomScanner(target)
if err != nil {
return err
}
if len(src) == 0 {
*pg = nil
return nil
return scanner.ScanGeom(nil)
}
var err error
src, err = hex.DecodeString(string(src))
if err != nil {
return err
Expand All @@ -185,8 +299,7 @@ func (p textScanPlan) Scan(src []byte, target any) error {
if err != nil {
return err
}
*pg = g
return nil
return scanner.ScanGeom(g)
}

// Register registers a codec for [github.com/twpayne/go-geom.T] types on conn.
Expand Down
111 changes: 110 additions & 1 deletion pgxgeom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgxgeom_test
import (
"context"
"encoding/binary"
"errors"
"strconv"
"testing"

Expand Down Expand Up @@ -72,7 +73,33 @@ func TestCodecDecodeNullValue(t *testing.T) {
tb.Helper()

rows, err := conn.Query(ctx, "select NULL::geometry AS geom", pgx.QueryResultFormats{format})
assert.NoError(tb, err)
assert.NoError(t, err)

value, err := pgx.CollectExactlyOneRow(rows, pgx.RowToStructByName[s])
assert.NoError(t, err)
assert.Zero(t, value)
})
}
})
}

func TestCodecDecodeNullValuePolymorphic(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, tb testing.TB, conn *pgx.Conn) {
tb.Helper()

type s struct {
Geom *geom.Point `db:"geom"`
}

for _, format := range []int16{
pgx.BinaryFormatCode,
pgx.TextFormatCode,
} {
tb.(*testing.T).Run(strconv.Itoa(int(format)), func(t *testing.T) {
tb.Helper()

rows, err := conn.Query(ctx, "select NULL::geometry AS geom", pgx.QueryResultFormats{format})
assert.NoError(t, err)

value, err := pgx.CollectExactlyOneRow(rows, pgx.RowToStructByName[s])
assert.NoError(t, err)
Expand Down Expand Up @@ -115,6 +142,88 @@ func TestCodecScanValue(t *testing.T) {
})
}

func TestCodecScanValuePolymorphic(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, tb testing.TB, conn *pgx.Conn) {
tb.Helper()
for _, format := range []int16{
pgx.BinaryFormatCode,
pgx.TextFormatCode,
} {
tb.(*testing.T).Run(strconv.Itoa(int(format)), func(t *testing.T) {
var point geom.Point
var polygon geom.Polygon
var err error
query := "select ST_SetSRID('POLYGON((0 0,1 0,1 1,0 1,0 0))'::geometry, 4326)"

err = conn.QueryRow(ctx, query, pgx.QueryResultFormats{format}).Scan(&polygon)
assert.NoError(t, err)
assert.Equal(t, mustNewGeomFromWKT(t, "POLYGON((0 0,1 0,1 1,0 1,0 0))", 4326), geom.T(&polygon))

err = conn.QueryRow(ctx, query, pgx.QueryResultFormats{format}).Scan(&point)
assert.EqualError(t, err, "can't scan into dest[0]: pgxgeom: got *geom.Polygon, want *geom.Point")
})
}
})
}

type CustomPoint struct {
*geom.Point
}

var errCustomPointScan = errors.New("invalid target for CustomPoint")

func (c *CustomPoint) ScanGeom(v geom.T) error {
concrete, ok := v.(*geom.Point)
if !ok {
return errCustomPointScan
}
c.Point = concrete
return nil
}

func (c *CustomPoint) GeomValue() (geom.T, error) {
return c.Point, nil
}

func TestCodecEncodeValueCustom(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, tb testing.TB, conn *pgx.Conn) {
tb.Helper()
point := CustomPoint{geom.NewPointFlat(geom.XY, []float64{1, 2}).SetSRID(4326)}

var bytes []byte
err := conn.QueryRow(ctx, "select $1::geometry::bytea", &point).Scan(&bytes)
assert.NoError(t, err)

g, err := ewkb.Unmarshal(bytes)
assert.NoError(t, err)
assert.Equal(t, mustNewGeomFromWKT(t, "POINT(1 2)", 4326), g)
})
}

func TestCodecScanValueCustom(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, tb testing.TB, conn *pgx.Conn) {
tb.Helper()
for _, format := range []int16{
pgx.BinaryFormatCode,
pgx.TextFormatCode,
} {
tb.(*testing.T).Run(strconv.Itoa(int(format)), func(t *testing.T) {
var point CustomPoint
var err error
pointQuery := "select ST_SetSRID('POINT(1 2)'::geometry, 4326)"
polygonQuery := "select ST_SetSRID('POLYGON((0 0,1 0,1 1,0 1,0 0))'::geometry, 4326)"

err = conn.QueryRow(ctx, pointQuery, pgx.QueryResultFormats{format}).Scan(&point)
assert.NoError(t, err)
assert.Equal(t, mustNewGeomFromWKT(t, "POINT(1 2)", 4326), geom.T(point.Point))

err = conn.QueryRow(ctx, polygonQuery, pgx.QueryResultFormats{format}).Scan(&point)
assert.EqualError(t, err, "can't scan into dest[0]: invalid target for CustomPoint")
})
}
})
}

func mustEWKB(tb testing.TB, g geom.T) []byte {
tb.Helper()
data, err := ewkb.Marshal(g, binary.LittleEndian)
Expand Down

0 comments on commit 267fac6

Please sign in to comment.