Skip to content

Commit

Permalink
Merge pull request #1 from lzambarda/feat/support-pointer
Browse files Browse the repository at this point in the history
feat: support pointer values
  • Loading branch information
lzambarda authored Dec 6, 2024
2 parents 82a4f90 + bf2bf34 commit c7830d3
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 12 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/lint-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Lint and Test

on:
push:
branches: [master]
pull_request:
branches:
- master
- feature/*
- bugfix/*
- patch/*
- refactor/*
- chore/*

jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v3
with:
fetch-depth: 0 # needed because of the new-from-rev in golangci

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: "1.23"

- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
args: --timeout 5m0s

test:
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v3

- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: "1.23"

- name: Running Unit Tests
run: go test ./...
49 changes: 49 additions & 0 deletions marshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import (
)

func TestMarshal(t *testing.T) {
t.Run("success", testMarshalSuccess)
t.Run("success pointer", testMarshalSuccessPointer)
}

func testMarshalSuccess(t *testing.T) {
expected, err := testdata.ReadFile("testdata/marshal/success.csv")
if err != nil {
t.Fatalf("read test file: %v", err)
Expand Down Expand Up @@ -54,3 +59,47 @@ func TestMarshal(t *testing.T) {
t.Errorf("(-expected, +got):\n%s", diff)
}
}

func testMarshalSuccessPointer(t *testing.T) {
expected, err := testdata.ReadFile("testdata/marshal/success.csv")
if err != nil {
t.Fatalf("read test file: %v", err)
}

type record struct {
FirstName string `flat:"first_name"`
LastName string `flat:"last_name"`
Ignore uint8 `flat:"-"`
Age int `flat:"age"`
Height float32 `flat:"height"`
}

input := []*record{
{
FirstName: "John",
LastName: "Doe",
Ignore: 123,
Age: 30,
Height: 1.75,
},
{
FirstName: "Jane",
LastName: "Doe",
Ignore: 123,
Age: 25,
Height: 1.65,
},
}
var got bytes.Buffer

writer := csv.NewWriter(&got)

err = goflat.MarshalSliceToWriter(context.Background(), input, writer, goflat.Options{})
if err != nil {
t.Fatalf("marshal: %v", err)
}

if diff := cmp.Diff(string(expected), got.String()); diff != "" {
t.Errorf("(-expected, +got):\n%s", diff)
}
}
28 changes: 21 additions & 7 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

type structFactory[T any] struct {
structType reflect.Type
pointer bool
columnMap map[int]int
columnValues []any
columnNames []string
Expand All @@ -18,31 +19,36 @@ type structFactory[T any] struct {
// work with them.
const FieldTag = "flat"

//nolint:varnamelen // Fine-ish here.
//nolint:varnamelen,cyclop // Fine-ish here.
func newFactory[T any](headers []string, options Options) (*structFactory[T], error) {
var v T

t := reflect.TypeOf(v)
rv := reflect.ValueOf(v)

if t.Kind() == reflect.Pointer {
t = t.Elem()
}
pointer := false

if t.Kind() != reflect.Struct {
//nolint:exhaustive // Fine here, there's a default.
switch t.Kind() {
case reflect.Struct:
case reflect.Pointer:
pointer = true
t = t.Elem()
rv = reflect.New(t).Elem()
default:
return nil, fmt.Errorf("type %T: %w", v, ErrNotAStruct)
}

factory := &structFactory[T]{
structType: t,
pointer: pointer,
columnMap: make(map[int]int, len(headers)),
columnValues: make([]any, t.NumField()),
columnNames: make([]string, t.NumField()),
}

covered := make([]bool, len(headers))

rv := reflect.ValueOf(v)

for i := range t.NumField() {
fieldT := t.Field(i)
fieldV := rv.Field(i)
Expand Down Expand Up @@ -163,6 +169,10 @@ func (s *structFactory[T]) unmarshal(record []string) (T, error) {
newStruct.Field(mappedIndex).Set(reflect.ValueOf(value))
}

if s.pointer {
newStruct = newStruct.Addr()
}

return newStruct.Interface().(T), nil
}

Expand All @@ -183,6 +193,10 @@ func (s *structFactory[T]) marshalHeaders() []string {
func (s *structFactory[T]) marshal(t T, separator string) ([]string, error) {
reflectValue := reflect.ValueOf(t)

if s.pointer {
reflectValue = reflectValue.Elem()
}

record := make([]string, 0, len(s.columnNames))

var strValue string
Expand Down
41 changes: 41 additions & 0 deletions reflect_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func testReflectSuccess(t *testing.T) {
t.Run("duplicate", testReflectSuccessDuplicate)
t.Run("simple", testReflectSuccessSimple)
t.Run("subset struct", testReflectSuccessSubsetStruct)
t.Run("pointer", testReflectSuccessPointer)
}

func testReflectSuccessDuplicate(t *testing.T) {
Expand Down Expand Up @@ -209,3 +210,43 @@ func testReflectSuccessSubsetStruct(t *testing.T) {
t.Errorf("(-want +got):\\n%s", diff)
}
}

func testReflectSuccessPointer(t *testing.T) {
type foo struct {
Name string `flat:"name"`
}

headers := []string{"name"}

got, err := newFactory[*foo](headers, Options{
Strict: true,
ErrorIfDuplicateHeaders: true,
ErrorIfMissingHeaders: true,
})
if err != nil {
t.Errorf("expected no error, got %v", err)
}

expected := &structFactory[*foo]{
structType: reflect.TypeOf(foo{}),
columnMap: map[int]int{0: 0},
}
comparers := []cmp.Option{
cmp.AllowUnexported(structFactory[*foo]{}),
cmp.Comparer(func(a, b structFactory[*foo]) bool {
if a.structType.String() != b.structType.String() {
return false
}

if !maps.Equal(a.columnMap, b.columnMap) {
return false
}

return true
}),
}

if diff := cmp.Diff(expected, got, comparers...); diff != "" {
t.Errorf("(-want +got):\\n%s", diff)
}
}
65 changes: 60 additions & 5 deletions unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

func TestUnmarshal(t *testing.T) {
t.Run("success", testUnmarshalSuccess)
t.Run("success pointer", testUnmarshalSuccessPointer)
}

//go:embed testdata
Expand Down Expand Up @@ -53,7 +54,7 @@ func testUnmarshalSuccess(t *testing.T) {
}

channel := make(chan record)
assertChannel(t, channel, expected)
assertChannel(t, channel, expected, cmp.AllowUnexported(record{}))

ctx := context.Background()

Expand All @@ -74,7 +75,63 @@ func testUnmarshalSuccess(t *testing.T) {
}
}

func assertChannel[T any](t *testing.T, ch <-chan T, expected []T) {
func testUnmarshalSuccessPointer(t *testing.T) {
file, err := testdata.Open("testdata/unmarshal/success.csv")
if err != nil {
t.Fatalf("open test file: %v", err)
}

type record struct {
FirstName string `flat:"first_name"`
LastName string `flat:"last_name"`
Age int `flat:"age"`
Height float32 `flat:"height"`
}

expected := []*record{
{
FirstName: "Guybrush",
LastName: "Threepwood",
Age: 28,
Height: 1.78,
},
{
FirstName: "Elaine",
LastName: "Marley",
Age: 20,
Height: 1.6,
},
{
FirstName: "LeChuck",
LastName: "",
Age: 100,
Height: 2.01,
},
}

channel := make(chan *record)
assertChannel(t, channel, expected, cmp.AllowUnexported(record{}))

ctx := context.Background()

csvReader, err := goflat.DetectReader(file)
if err != nil {
t.Fatalf("detect reader: %v", err)
}

options := goflat.Options{
Strict: true,
ErrorIfDuplicateHeaders: true,
ErrorIfMissingHeaders: true,
}

err = goflat.UnmarshalToChannel(ctx, csvReader, options, channel)
if err != nil {
t.Fatalf("unmarshal: %v", err)
}
}

func assertChannel[T any](t *testing.T, ch <-chan T, expected []T, cmpOpts ...cmp.Option) {
t.Helper()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
Expand All @@ -98,11 +155,9 @@ func assertChannel[T any](t *testing.T, ch <-chan T, expected []T) {
}()

t.Cleanup(func() {
var zero T

<-ctx.Done()

if diff := cmp.Diff(expected, got, cmp.AllowUnexported(zero)); diff != "" {
if diff := cmp.Diff(expected, got, cmpOpts...); diff != "" {
t.Errorf("(-expected,+got):\n%s", diff)
}
})
Expand Down

0 comments on commit c7830d3

Please sign in to comment.