Skip to content

Commit

Permalink
fix(reflect): correctly handle missing fields
Browse files Browse the repository at this point in the history
  • Loading branch information
lzambarda committed Nov 26, 2024
1 parent 5365be7 commit 82a4f90
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 42 deletions.
3 changes: 0 additions & 3 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ var (
// does not appear in the input file. Only returned if
// [Option.ErrorIfMissingHeaders] is set to true.
ErrMissingHeader = errors.New("missing header")
// ErrMismatchedFields is returned when the input structs have inconsistent
// fields. In theory this will never be returned.
ErrMismatchedFields = errors.New("mismatched fields")
// ErrUnsupportedType is returned when the unmarshaller encounters an
// unsupported type.
ErrUnsupportedType = errors.New("unsupported type")
Expand Down
22 changes: 6 additions & 16 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

type structFactory[T any] struct {
structType reflect.Type
columnMap []int
columnMap map[int]int
columnValues []any
columnNames []string
}
Expand All @@ -18,10 +18,6 @@ type structFactory[T any] struct {
// work with them.
const FieldTag = "flat"

// columnMapIgnore is used to mark a column as ignored. This is needed if there
// are duplicate headers that must be skipped.
const columnMapIgnore = -1

//nolint:varnamelen // Fine-ish here.
func newFactory[T any](headers []string, options Options) (*structFactory[T], error) {
var v T
Expand All @@ -38,7 +34,7 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er

factory := &structFactory[T]{
structType: t,
columnMap: make([]int, len(headers)),
columnMap: make(map[int]int, len(headers)),
columnValues: make([]any, t.NumField()),
columnNames: make([]string, t.NumField()),
}
Expand Down Expand Up @@ -80,10 +76,6 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er
return nil, fmt.Errorf("header %q, index %d and %d: %w", header, j, handledAt, ErrDuplicatedHeader)
}

// If the duplicate headers error flag is diabled, then mark the
// column as ignored and continue.
factory.columnMap[j] = columnMapIgnore

continue
}

Expand All @@ -103,9 +95,6 @@ func newFactory[T any](headers []string, options Options) (*structFactory[T], er
//nolint:forcetypeassert,gocyclo,cyclop,ireturn // Fine for now.
func (s *structFactory[T]) unmarshal(record []string) (T, error) {
var zero T
if len(record) != len(s.columnMap) {
return zero, fmt.Errorf("expected %d fields, got %d: %w", len(s.columnMap), len(record), ErrMismatchedFields)
}

newStruct := reflect.New(s.structType).Elem()

Expand All @@ -114,11 +103,12 @@ func (s *structFactory[T]) unmarshal(record []string) (T, error) {

//nolint:varnamelen // Fine here.
for i, column := range record {
if s.columnMap[i] == columnMapIgnore {
mappedIndex, found := s.columnMap[i]
if !found {
continue
}

columnBaseValue := s.columnValues[s.columnMap[i]]
columnBaseValue := s.columnValues[mappedIndex]

// special case
if u, ok := columnBaseValue.(Unmarshaller); ok {
Expand Down Expand Up @@ -170,7 +160,7 @@ func (s *structFactory[T]) unmarshal(record []string) (T, error) {
return zero, fmt.Errorf("parse column %d: %w", i, err)
}

newStruct.Field(s.columnMap[i]).Set(reflect.ValueOf(value))
newStruct.Field(mappedIndex).Set(reflect.ValueOf(value))
}

return newStruct.Interface().(T), nil
Expand Down
52 changes: 47 additions & 5 deletions reflect_internal_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package goflat

import (
"maps"
"reflect"
"slices"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -81,6 +81,7 @@ func testReflectErrorDuplicate(t *testing.T) {
func testReflectSuccess(t *testing.T) {
t.Run("duplicate", testReflectSuccessDuplicate)
t.Run("simple", testReflectSuccessSimple)
t.Run("subset struct", testReflectSuccessSubsetStruct)
}

func testReflectSuccessDuplicate(t *testing.T) {
Expand All @@ -102,7 +103,7 @@ func testReflectSuccessDuplicate(t *testing.T) {

expected := &structFactory[foo]{
structType: reflect.TypeOf(foo{}),
columnMap: []int{0, 1, -1},
columnMap: map[int]int{0: 0, 1: 1},
columnValues: []any{"", int(0)},
columnNames: []string{"name", "age"},
}
Expand All @@ -113,7 +114,7 @@ func testReflectSuccessDuplicate(t *testing.T) {
return false
}

if !slices.Equal(a.columnMap, b.columnMap) {
if !maps.Equal(a.columnMap, b.columnMap) {
return false
}

Expand Down Expand Up @@ -146,7 +147,7 @@ func testReflectSuccessSimple(t *testing.T) {

expected := &structFactory[foo]{
structType: reflect.TypeOf(foo{}),
columnMap: []int{0, 1},
columnMap: map[int]int{0: 0, 1: 1},
}
comparers := []cmp.Option{
cmp.AllowUnexported(structFactory[foo]{}),
Expand All @@ -155,7 +156,48 @@ func testReflectSuccessSimple(t *testing.T) {
return false
}

if !slices.Equal(a.columnMap, b.columnMap) {
if !maps.Equal(a.columnMap, b.columnMap) {
return false
}

return true
}),
}

if diff := cmp.Diff(expected, got, comparers...); diff != "" {
t.Errorf("(-want +got):\\n%s", diff)
}
}

func testReflectSuccessSubsetStruct(t *testing.T) {
type foo struct {
Col2 float32 `flat:"col2"`
}

headers := []string{"col1", "col2", "col3"}

got, err := newFactory[foo](headers, Options{
Strict: false,
ErrorIfDuplicateHeaders: false,
ErrorIfMissingHeaders: false,
})
if err != nil {
t.Errorf("expected no error, got %v", err)
}

expected := &structFactory[foo]{
structType: reflect.TypeOf(foo{}),
columnMap: map[int]int{1: 0},
columnValues: []any{float32(0)},
}
comparers := []cmp.Option{
cmp.AllowUnexported(structFactory[foo]{}),
cmp.Comparer(func(a, b structFactory[foo]) bool {
if a.structType.String() != b.structType.String() {
return false
}

if !maps.Equal(a.columnMap, b.columnMap) {
return false
}

Expand Down
62 changes: 44 additions & 18 deletions unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,29 @@ func testUnmarshalSuccess(t *testing.T) {
Height float32 `flat:"height"`
}

expected := []record{}
expected := []record{
{
FirstName: "Guybrush",
LastName: "Threepwood",
Age: 28,
Height: 1.78,
},
{
FirstName: "Elaine",
LastName: "Marley",
Age: 20,
Height: 1.6,
},
{
FirstName: "LeChuck",
LastName: "",
Age: 100,
Height: 2.01,
},
}

channel := make(chan record)
go assertChannel(t, channel, expected)
assertChannel(t, channel, expected)

ctx := context.Background()

Expand All @@ -59,25 +78,32 @@ func assertChannel[T any](t *testing.T, ch <-chan T, expected []T) {
t.Helper()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
var got []T

loop:
for {
select {
case <-ctx.Done():
break loop
case v, ok := <-ch:
if !ok {
break loop
}
go func() {
defer cancel()

for {
select {
case <-ctx.Done():
return
case v, ok := <-ch:
if !ok {
return
}

got = append(got, v)
got = append(got, v)
}
}
}
}()

var zero T
if diff := cmp.Diff(expected, got, cmp.AllowUnexported(zero)); diff != "" {
t.Errorf("(-expected,+got):\n%s", diff)
}
t.Cleanup(func() {
var zero T

<-ctx.Done()

if diff := cmp.Diff(expected, got, cmp.AllowUnexported(zero)); diff != "" {
t.Errorf("(-expected,+got):\n%s", diff)
}
})
}

0 comments on commit 82a4f90

Please sign in to comment.