diff --git a/errors.go b/errors.go index 115072c..4c267e6 100644 --- a/errors.go +++ b/errors.go @@ -16,9 +16,6 @@ var ( // does not appear in the input file. Only returned if // [Option.ErrorIfMissingHeaders] is set to true. ErrMissingHeader = errors.New("missing header") - // ErrMismatchedFields is returned when the input structs have inconsistent - // fields. In theory this will never be returned. - ErrMismatchedFields = errors.New("mismatched fields") // ErrUnsupportedType is returned when the unmarshaller encounters an // unsupported type. ErrUnsupportedType = errors.New("unsupported type") diff --git a/reflect.go b/reflect.go index 19adce8..bf16b94 100644 --- a/reflect.go +++ b/reflect.go @@ -9,7 +9,7 @@ import ( type structFactory[T any] struct { structType reflect.Type - columnMap []int + columnMap map[int]int columnValues []any columnNames []string } @@ -18,10 +18,6 @@ type structFactory[T any] struct { // work with them. const FieldTag = "flat" -// columnMapIgnore is used to mark a column as ignored. This is needed if there -// are duplicate headers that must be skipped. -const columnMapIgnore = -1 - //nolint:varnamelen // Fine-ish here. func newFactory[T any](headers []string, options Options) (*structFactory[T], error) { var v T @@ -38,7 +34,7 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er factory := &structFactory[T]{ structType: t, - columnMap: make([]int, len(headers)), + columnMap: make(map[int]int, len(headers)), columnValues: make([]any, t.NumField()), columnNames: make([]string, t.NumField()), } @@ -80,10 +76,6 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er return nil, fmt.Errorf("header %q, index %d and %d: %w", header, j, handledAt, ErrDuplicatedHeader) } - // If the duplicate headers error flag is diabled, then mark the - // column as ignored and continue. - factory.columnMap[j] = columnMapIgnore - continue } @@ -103,9 +95,6 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er //nolint:forcetypeassert,gocyclo,cyclop,ireturn // Fine for now. func (s *structFactory[T]) unmarshal(record []string) (T, error) { var zero T - if len(record) != len(s.columnMap) { - return zero, fmt.Errorf("expected %d fields, got %d: %w", len(s.columnMap), len(record), ErrMismatchedFields) - } newStruct := reflect.New(s.structType).Elem() @@ -114,11 +103,12 @@ func (s *structFactory[T]) unmarshal(record []string) (T, error) { //nolint:varnamelen // Fine here. for i, column := range record { - if s.columnMap[i] == columnMapIgnore { + mappedIndex, found := s.columnMap[i] + if !found { continue } - columnBaseValue := s.columnValues[s.columnMap[i]] + columnBaseValue := s.columnValues[mappedIndex] // special case if u, ok := columnBaseValue.(Unmarshaller); ok { @@ -170,7 +160,7 @@ func (s *structFactory[T]) unmarshal(record []string) (T, error) { return zero, fmt.Errorf("parse column %d: %w", i, err) } - newStruct.Field(s.columnMap[i]).Set(reflect.ValueOf(value)) + newStruct.Field(mappedIndex).Set(reflect.ValueOf(value)) } return newStruct.Interface().(T), nil diff --git a/reflect_internal_test.go b/reflect_internal_test.go index 8d628f3..914b2b7 100644 --- a/reflect_internal_test.go +++ b/reflect_internal_test.go @@ -1,8 +1,8 @@ package goflat import ( + "maps" "reflect" - "slices" "testing" "github.com/google/go-cmp/cmp" @@ -81,6 +81,7 @@ func testReflectErrorDuplicate(t *testing.T) { func testReflectSuccess(t *testing.T) { t.Run("duplicate", testReflectSuccessDuplicate) t.Run("simple", testReflectSuccessSimple) + t.Run("subset struct", testReflectSuccessSubsetStruct) } func testReflectSuccessDuplicate(t *testing.T) { @@ -102,7 +103,7 @@ func testReflectSuccessDuplicate(t *testing.T) { expected := &structFactory[foo]{ structType: reflect.TypeOf(foo{}), - columnMap: []int{0, 1, -1}, + columnMap: map[int]int{0: 0, 1: 1}, columnValues: []any{"", int(0)}, columnNames: []string{"name", "age"}, } @@ -113,7 +114,7 @@ func testReflectSuccessDuplicate(t *testing.T) { return false } - if !slices.Equal(a.columnMap, b.columnMap) { + if !maps.Equal(a.columnMap, b.columnMap) { return false } @@ -146,7 +147,7 @@ func testReflectSuccessSimple(t *testing.T) { expected := &structFactory[foo]{ structType: reflect.TypeOf(foo{}), - columnMap: []int{0, 1}, + columnMap: map[int]int{0: 0, 1: 1}, } comparers := []cmp.Option{ cmp.AllowUnexported(structFactory[foo]{}), @@ -155,7 +156,48 @@ func testReflectSuccessSimple(t *testing.T) { return false } - if !slices.Equal(a.columnMap, b.columnMap) { + if !maps.Equal(a.columnMap, b.columnMap) { + return false + } + + return true + }), + } + + if diff := cmp.Diff(expected, got, comparers...); diff != "" { + t.Errorf("(-want +got):\\n%s", diff) + } +} + +func testReflectSuccessSubsetStruct(t *testing.T) { + type foo struct { + Col2 float32 `flat:"col2"` + } + + headers := []string{"col1", "col2", "col3"} + + got, err := newFactory[foo](headers, Options{ + Strict: false, + ErrorIfDuplicateHeaders: false, + ErrorIfMissingHeaders: false, + }) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + expected := &structFactory[foo]{ + structType: reflect.TypeOf(foo{}), + columnMap: map[int]int{1: 0}, + columnValues: []any{float32(0)}, + } + comparers := []cmp.Option{ + cmp.AllowUnexported(structFactory[foo]{}), + cmp.Comparer(func(a, b structFactory[foo]) bool { + if a.structType.String() != b.structType.String() { + return false + } + + if !maps.Equal(a.columnMap, b.columnMap) { return false } diff --git a/unmarshal_test.go b/unmarshal_test.go index 7c53715..2cb3ce3 100644 --- a/unmarshal_test.go +++ b/unmarshal_test.go @@ -31,10 +31,29 @@ func testUnmarshalSuccess(t *testing.T) { Height float32 `flat:"height"` } - expected := []record{} + expected := []record{ + { + FirstName: "Guybrush", + LastName: "Threepwood", + Age: 28, + Height: 1.78, + }, + { + FirstName: "Elaine", + LastName: "Marley", + Age: 20, + Height: 1.6, + }, + { + FirstName: "LeChuck", + LastName: "", + Age: 100, + Height: 2.01, + }, + } channel := make(chan record) - go assertChannel(t, channel, expected) + assertChannel(t, channel, expected) ctx := context.Background() @@ -59,25 +78,32 @@ func assertChannel[T any](t *testing.T, ch <-chan T, expected []T) { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() var got []T -loop: - for { - select { - case <-ctx.Done(): - break loop - case v, ok := <-ch: - if !ok { - break loop - } + go func() { + defer cancel() + + for { + select { + case <-ctx.Done(): + return + case v, ok := <-ch: + if !ok { + return + } - got = append(got, v) + got = append(got, v) + } } - } + }() - var zero T - if diff := cmp.Diff(expected, got, cmp.AllowUnexported(zero)); diff != "" { - t.Errorf("(-expected,+got):\n%s", diff) - } + t.Cleanup(func() { + var zero T + + <-ctx.Done() + + if diff := cmp.Diff(expected, got, cmp.AllowUnexported(zero)); diff != "" { + t.Errorf("(-expected,+got):\n%s", diff) + } + }) }