From 58e1a083c9fa40e31b73fc5f7bc5db44f48c1de2 Mon Sep 17 00:00:00 2001 From: Luca Zambarda Date: Fri, 6 Dec 2024 16:32:17 +0000 Subject: [PATCH 1/2] feat: support pointer values --- .github/workflows/lint-test.yml | 46 +++++++++++++++++++++++ marshal_test.go | 49 +++++++++++++++++++++++++ reflect.go | 25 ++++++++++--- reflect_internal_test.go | 41 +++++++++++++++++++++ unmarshal_test.go | 65 ++++++++++++++++++++++++++++++--- 5 files changed, 215 insertions(+), 11 deletions(-) create mode 100644 .github/workflows/lint-test.yml diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml new file mode 100644 index 0000000..d55f484 --- /dev/null +++ b/.github/workflows/lint-test.yml @@ -0,0 +1,46 @@ +name: Lint and Test + +on: + push: + branches: [master] + pull_request: + branches: + - master + - feature/* + - bugfix/* + - patch/* + - refactor/* + - chore/* + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout repo + uses: actions/checkout@v3 + with: + fetch-depth: 0 # needed because of the new-from-rev in golangci + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.23" + + - name: golangci-lint + uses: golangci/golangci-lint-action@v3 + with: + args: --timeout 5m0s + + test: + runs-on: ubuntu-latest + steps: + - name: Checkout repo + uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.23" + + - name: Running Unit Tests + run: go test ./... diff --git a/marshal_test.go b/marshal_test.go index 24a8bed..bd20728 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -12,6 +12,11 @@ import ( ) func TestMarshal(t *testing.T) { + t.Run("success", testMarshalSuccess) + t.Run("success pointer", testMarshalSuccessPointer) +} + +func testMarshalSuccess(t *testing.T) { expected, err := testdata.ReadFile("testdata/marshal/success.csv") if err != nil { t.Fatalf("read test file: %v", err) @@ -54,3 +59,47 @@ func TestMarshal(t *testing.T) { t.Errorf("(-expected, +got):\n%s", diff) } } + +func testMarshalSuccessPointer(t *testing.T) { + expected, err := testdata.ReadFile("testdata/marshal/success.csv") + if err != nil { + t.Fatalf("read test file: %v", err) + } + + type record struct { + FirstName string `flat:"first_name"` + LastName string `flat:"last_name"` + Ignore uint8 `flat:"-"` + Age int `flat:"age"` + Height float32 `flat:"height"` + } + + input := []*record{ + { + FirstName: "John", + LastName: "Doe", + Ignore: 123, + Age: 30, + Height: 1.75, + }, + { + FirstName: "Jane", + LastName: "Doe", + Ignore: 123, + Age: 25, + Height: 1.65, + }, + } + var got bytes.Buffer + + writer := csv.NewWriter(&got) + + err = goflat.MarshalSliceToWriter(context.Background(), input, writer, goflat.Options{}) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + if diff := cmp.Diff(string(expected), got.String()); diff != "" { + t.Errorf("(-expected, +got):\n%s", diff) + } +} diff --git a/reflect.go b/reflect.go index bf16b94..b17b953 100644 --- a/reflect.go +++ b/reflect.go @@ -9,6 +9,7 @@ import ( type structFactory[T any] struct { structType reflect.Type + pointer bool columnMap map[int]int columnValues []any columnNames []string @@ -23,17 +24,23 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er var v T t := reflect.TypeOf(v) + rv := reflect.ValueOf(v) - if t.Kind() == reflect.Pointer { - t = t.Elem() - } + pointer := false - if t.Kind() != reflect.Struct { + switch t.Kind() { + case reflect.Struct: + case reflect.Pointer: + pointer = true + t = t.Elem() + rv = reflect.New(t).Elem() + default: return nil, fmt.Errorf("type %T: %w", v, ErrNotAStruct) } factory := &structFactory[T]{ structType: t, + pointer: pointer, columnMap: make(map[int]int, len(headers)), columnValues: make([]any, t.NumField()), columnNames: make([]string, t.NumField()), @@ -41,8 +48,6 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er covered := make([]bool, len(headers)) - rv := reflect.ValueOf(v) - for i := range t.NumField() { fieldT := t.Field(i) fieldV := rv.Field(i) @@ -163,6 +168,10 @@ func (s *structFactory[T]) unmarshal(record []string) (T, error) { newStruct.Field(mappedIndex).Set(reflect.ValueOf(value)) } + if s.pointer { + newStruct = newStruct.Addr() + } + return newStruct.Interface().(T), nil } @@ -183,6 +192,10 @@ func (s *structFactory[T]) marshalHeaders() []string { func (s *structFactory[T]) marshal(t T, separator string) ([]string, error) { reflectValue := reflect.ValueOf(t) + if s.pointer { + reflectValue = reflectValue.Elem() + } + record := make([]string, 0, len(s.columnNames)) var strValue string diff --git a/reflect_internal_test.go b/reflect_internal_test.go index 914b2b7..8924e88 100644 --- a/reflect_internal_test.go +++ b/reflect_internal_test.go @@ -82,6 +82,7 @@ func testReflectSuccess(t *testing.T) { t.Run("duplicate", testReflectSuccessDuplicate) t.Run("simple", testReflectSuccessSimple) t.Run("subset struct", testReflectSuccessSubsetStruct) + t.Run("pointer", testReflectSuccessPointer) } func testReflectSuccessDuplicate(t *testing.T) { @@ -209,3 +210,43 @@ func testReflectSuccessSubsetStruct(t *testing.T) { t.Errorf("(-want +got):\\n%s", diff) } } + +func testReflectSuccessPointer(t *testing.T) { + type foo struct { + Name string `flat:"name"` + } + + headers := []string{"name"} + + got, err := newFactory[*foo](headers, Options{ + Strict: true, + ErrorIfDuplicateHeaders: true, + ErrorIfMissingHeaders: true, + }) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + expected := &structFactory[*foo]{ + structType: reflect.TypeOf(foo{}), + columnMap: map[int]int{0: 0}, + } + comparers := []cmp.Option{ + cmp.AllowUnexported(structFactory[*foo]{}), + cmp.Comparer(func(a, b structFactory[*foo]) bool { + if a.structType.String() != b.structType.String() { + return false + } + + if !maps.Equal(a.columnMap, b.columnMap) { + return false + } + + return true + }), + } + + if diff := cmp.Diff(expected, got, comparers...); diff != "" { + t.Errorf("(-want +got):\\n%s", diff) + } +} diff --git a/unmarshal_test.go b/unmarshal_test.go index 2cb3ce3..d4ffe0d 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -13,6 +13,7 @@ import ( func TestUnmarshal(t *testing.T) { t.Run("success", testUnmarshalSuccess) + t.Run("success pointer", testUnmarshalSuccessPointer) } //go:embed testdata @@ -53,7 +54,7 @@ func testUnmarshalSuccess(t *testing.T) { } channel := make(chan record) - assertChannel(t, channel, expected) + assertChannel(t, channel, expected, cmp.AllowUnexported(record{})) ctx := context.Background() @@ -74,7 +75,63 @@ func testUnmarshalSuccess(t *testing.T) { } } -func assertChannel[T any](t *testing.T, ch <-chan T, expected []T) { +func testUnmarshalSuccessPointer(t *testing.T) { + file, err := testdata.Open("testdata/unmarshal/success.csv") + if err != nil { + t.Fatalf("open test file: %v", err) + } + + type record struct { + FirstName string `flat:"first_name"` + LastName string `flat:"last_name"` + Age int `flat:"age"` + Height float32 `flat:"height"` + } + + expected := []*record{ + { + FirstName: "Guybrush", + LastName: "Threepwood", + Age: 28, + Height: 1.78, + }, + { + FirstName: "Elaine", + LastName: "Marley", + Age: 20, + Height: 1.6, + }, + { + FirstName: "LeChuck", + LastName: "", + Age: 100, + Height: 2.01, + }, + } + + channel := make(chan *record) + assertChannel(t, channel, expected, cmp.AllowUnexported(record{})) + + ctx := context.Background() + + csvReader, err := goflat.DetectReader(file) + if err != nil { + t.Fatalf("detect reader: %v", err) + } + + options := goflat.Options{ + Strict: true, + ErrorIfDuplicateHeaders: true, + ErrorIfMissingHeaders: true, + } + + err = goflat.UnmarshalToChannel(ctx, csvReader, options, channel) + if err != nil { + t.Fatalf("unmarshal: %v", err) + } +} + +func assertChannel[T any](t *testing.T, ch <-chan T, expected []T, cmpOpts ...cmp.Option) { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -98,11 +155,9 @@ func assertChannel[T any](t *testing.T, ch <-chan T, expected []T) { }() t.Cleanup(func() { - var zero T - <-ctx.Done() - if diff := cmp.Diff(expected, got, cmp.AllowUnexported(zero)); diff != "" { + if diff := cmp.Diff(expected, got, cmpOpts...); diff != "" { t.Errorf("(-expected,+got):\n%s", diff) } }) From bf2bf34e61ad8d447095c2ce7ffc949b79575e06 Mon Sep 17 00:00:00 2001 From: Luca Zambarda Date: Fri, 6 Dec 2024 16:34:04 +0000 Subject: [PATCH 2/2] lint --- reflect.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/reflect.go b/reflect.go index b17b953..53aff71 100644 --- a/reflect.go +++ b/reflect.go @@ -19,7 +19,7 @@ type structFactory[T any] struct { // work with them. const FieldTag = "flat" -//nolint:varnamelen // Fine-ish here. +//nolint:varnamelen,cyclop // Fine-ish here. func newFactory[T any](headers []string, options Options) (*structFactory[T], error) { var v T @@ -28,6 +28,7 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er pointer := false + //nolint:exhaustive // Fine here, there's a default. switch t.Kind() { case reflect.Struct: case reflect.Pointer: