Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generics to chan and callback-using functions #273

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ^1.11
go-version: ^1.18

- name: Check out code into the Go module directory
uses: actions/checkout@v2
Expand Down
143 changes: 57 additions & 86 deletions csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,12 @@ func MarshalWithoutHeaders(in interface{}, out io.Writer) (err error) {
}

// MarshalChan returns the CSV read from the channel.
func MarshalChan(c <-chan interface{}, out CSVWriter) error {
func MarshalChan[T any](c <-chan T, out CSVWriter) error {
return writeFromChan(out, c, false)
}

// MarshalChanWithoutHeaders returns the CSV read from the channel.
func MarshalChanWithoutHeaders(c <-chan interface{}, out CSVWriter) error {
func MarshalChanWithoutHeaders[T any](c <-chan T, out CSVWriter) error {
return writeFromChan(out, c, true)
}

Expand Down Expand Up @@ -278,24 +278,23 @@ func UnmarshalCSVToMap(in CSVReader, out interface{}) error {

// UnmarshalToChan parses the CSV from the reader and send each value in the chan c.
// The channel must have a concrete type.
func UnmarshalToChan(in io.Reader, c interface{}) error {
func UnmarshalToChan[T any](in io.Reader, c chan<- T) error {
if c == nil {
return fmt.Errorf("goscv: channel is %v", c)
}
return readEach(newSimpleDecoderFromReader(in), nil, c)
}

// UnmarshalToChanWithErrorHandler parses the CSV from the reader in the interface.
func UnmarshalToChanWithErrorHandler(in io.Reader, errorHandler ErrorHandler, c interface{}) error {
// UnmarshalToChanWithErrorHandler parses the CSV from the reader in the channel.
func UnmarshalToChanWithErrorHandler[T any](in io.Reader, errorHandler ErrorHandler, c chan<- T) error {
if c == nil {
return fmt.Errorf("goscv: channel is %v", c)
}
return readEach(newSimpleDecoderFromReader(in), errorHandler, c)
}

// UnmarshalToChanWithoutHeaders parses the CSV from the reader and send each value in the chan c.
// The channel must have a concrete type.
func UnmarshalToChanWithoutHeaders(in io.Reader, c interface{}) error {
func UnmarshalToChanWithoutHeaders[T any](in io.Reader, c chan<- T) error {
if c == nil {
return fmt.Errorf("goscv: channel is %v", c)
}
Expand All @@ -304,125 +303,99 @@ func UnmarshalToChanWithoutHeaders(in io.Reader, c interface{}) error {

// UnmarshalDecoderToChan parses the CSV from the decoder and send each value in the chan c.
// The channel must have a concrete type.
func UnmarshalDecoderToChan(in SimpleDecoder, c interface{}) error {
func UnmarshalDecoderToChan[T any](in SimpleDecoder, c chan<- T) error {
if c == nil {
return fmt.Errorf("goscv: channel is %v", c)
}
return readEach(in, nil, c)
}

// UnmarshalStringToChan parses the CSV from the string and send each value in the chan c.
// The channel must have a concrete type.
func UnmarshalStringToChan(in string, c interface{}) error {
// UnmarshalStringToChan parses the CSV from the string and send each value in
// the chan c.
func UnmarshalStringToChan[T any](in string, c chan<- T) error {
return UnmarshalToChan(strings.NewReader(in), c)
}

// UnmarshalBytesToChan parses the CSV from the bytes and send each value in the chan c.
// The channel must have a concrete type.
func UnmarshalBytesToChan(in []byte, c interface{}) error {
// UnmarshalBytesToChan parses the CSV from the bytes and send each value in the
// chan c.
func UnmarshalBytesToChan[T any](in []byte, c chan<- T) error {
return UnmarshalToChan(bytes.NewReader(in), c)
}

// UnmarshalToCallback parses the CSV from the reader and send each value to the given func f.
// The func must look like func(Struct).
func UnmarshalToCallback(in io.Reader, f interface{}) error {
valueFunc := reflect.ValueOf(f)
t := reflect.TypeOf(f)
if t.NumIn() != 1 {
return fmt.Errorf("the given function must have exactly one parameter")
}
// UnmarshalToCallback parses the CSV from the reader and send each value to the
// given func callback.
func UnmarshalToCallback[T any](in io.Reader, callback func(T) error) error {
cerr := make(chan error)
c := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, t.In(0)), 0)
c := make(chan T)
go func() {
cerr <- UnmarshalToChan(in, c.Interface())
cerr <- UnmarshalToChan(in, c)
}()
for {
select {
case err := <-cerr:
return err
default:
}
v, notClosed := c.Recv()
if !notClosed || v.Interface() == nil {
v, notClosed := <-c
if !notClosed {
break
}
callResults := valueFunc.Call([]reflect.Value{v})
// if last returned value from Call() is an error, return it
if len(callResults) > 0 {
if err, ok := callResults[len(callResults)-1].Interface().(error); ok {
return err
}
err := callback(v)
if err != nil {
return err
}
}
return <-cerr
}

// UnmarshalDecoderToCallback parses the CSV from the decoder and send each value to the given func f.
// The func must look like func(Struct).
func UnmarshalDecoderToCallback(in SimpleDecoder, f interface{}) error {
valueFunc := reflect.ValueOf(f)
t := reflect.TypeOf(f)
if t.NumIn() != 1 {
return fmt.Errorf("the given function must have exactly one parameter")
}
// UnmarshalDecoderToCallback parses the CSV from the decoder and send each value to the given func callback.
func UnmarshalDecoderToCallback[T any](in SimpleDecoder, callback func(T) error) error {
cerr := make(chan error)
c := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, t.In(0)), 0)
c := make(chan T)
go func() {
cerr <- UnmarshalDecoderToChan(in, c.Interface())
cerr <- UnmarshalDecoderToChan(in, c)
}()
for {
select {
case err := <-cerr:
return err
default:
}
v, notClosed := c.Recv()
if !notClosed || v.Interface() == nil {
v, notClosed := <-c
if !notClosed {
break
}
valueFunc.Call([]reflect.Value{v})
err := callback(v)
if err != nil {
return err
}
}
return <-cerr
}

// UnmarshalBytesToCallback parses the CSV from the bytes and send each value to the given func f.
// The func must look like func(Struct).
func UnmarshalBytesToCallback(in []byte, f interface{}) error {
return UnmarshalToCallback(bytes.NewReader(in), f)
// UnmarshalBytesToCallback parses the CSV from the bytes and send each value to
// the given func callback.
func UnmarshalBytesToCallback[T any](in []byte, callback func(T) error) error {
return UnmarshalToCallback(bytes.NewReader(in), callback)
}

// UnmarshalStringToCallback parses the CSV from the string and send each value to the given func f.
// The func must look like func(Struct).
func UnmarshalStringToCallback(in string, c interface{}) (err error) {
return UnmarshalToCallback(strings.NewReader(in), c)
// UnmarshalStringToCallback parses the CSV from the string and send each value
// to the given func callback.
func UnmarshalStringToCallback[T any](in string, callback func(T) error) error {
return UnmarshalToCallback(strings.NewReader(in), callback)
}

// UnmarshalToCallbackWithError parses the CSV from the reader and
// send each value to the given func f.
// send each value to the given func callback.
//
// If func returns error, it will stop processing, drain the
// parser and propagate the error to caller.
//
// The func must look like func(Struct) error.
func UnmarshalToCallbackWithError(in io.Reader, f interface{}) error {
valueFunc := reflect.ValueOf(f)
t := reflect.TypeOf(f)
if t.NumIn() != 1 {
return fmt.Errorf("the given function must have exactly one parameter")
}
if t.NumOut() != 1 {
return fmt.Errorf("the given function must have exactly one return value")
}
if !isErrorType(t.Out(0)) {
return fmt.Errorf("the given function must only return error")
}

func UnmarshalToCallbackWithError[T any](in io.Reader, callback func(T) error) error {
cerr := make(chan error)
c := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, t.In(0)), 0)
c := make(chan T)
go func() {
cerr <- UnmarshalToChan(in, c.Interface())
cerr <- UnmarshalToChan(in, c)
}()

var fErr error
for {
select {
Expand All @@ -433,39 +406,37 @@ func UnmarshalToCallbackWithError(in io.Reader, f interface{}) error {
return fErr
default:
}
v, notClosed := c.Recv()
if !notClosed || v.Interface() == nil {
v, notClosed := <-c
if !notClosed {
if err := <-cerr; err != nil {
fErr = err
}
break
}

// callback f has already returned an error, stop processing but keep draining the chan c
// callback has already returned an error, stop processing but keep draining the chan c
if fErr != nil {
continue
}

results := valueFunc.Call([]reflect.Value{v})

// If the callback f returns an error, stores it and returns it in future.
errValue := results[0].Interface()
if errValue != nil {
fErr = errValue.(error)
// If the callback returns an error, stores it and returns it in future.
err := callback(v)
if err != nil {
fErr = err
}
}
return fErr
}

// UnmarshalBytesToCallbackWithError parses the CSV from the bytes and
// send each value to the given func f.
// send each value to the given func callback.
//
// If func returns error, it will stop processing, drain the
// parser and propagate the error to caller.
//
// The func must look like func(Struct) error.
func UnmarshalBytesToCallbackWithError(in []byte, f interface{}) error {
return UnmarshalToCallbackWithError(bytes.NewReader(in), f)
func UnmarshalBytesToCallbackWithError[T any](in []byte, callback func(T) error) error {
return UnmarshalToCallbackWithError(bytes.NewReader(in), callback)
}

// UnmarshalStringToCallbackWithError parses the CSV from the string and
Expand All @@ -475,8 +446,8 @@ func UnmarshalBytesToCallbackWithError(in []byte, f interface{}) error {
// parser and propagate the error to caller.
//
// The func must look like func(Struct) error.
func UnmarshalStringToCallbackWithError(in string, c interface{}) (err error) {
return UnmarshalToCallbackWithError(strings.NewReader(in), c)
func UnmarshalStringToCallbackWithError[T any](in string, callback func(T) error) error {
return UnmarshalToCallbackWithError(strings.NewReader(in), callback)
}

// CSVToMap creates a simple map from a CSV of 2 columns.
Expand Down
6 changes: 3 additions & 3 deletions csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import (

func TestUnmarshalToCallback_ReaderError(t *testing.T) {
type Dummy struct{}
var reader = &errorReader{}
reader := &errorReader{}

err := UnmarshalToCallback(reader, func(Dummy) {})
err := UnmarshalToCallback(reader, func(Dummy) error { return nil })
if !errors.Is(err, readerErr) {
t.Error("UnmarshalToCallback should return first reader error")
}

err = UnmarshalDecoderToCallback(newSimpleDecoderFromReader(reader), func(Dummy) {})
err = UnmarshalDecoderToCallback(newSimpleDecoderFromReader(reader), func(Dummy) error { return nil })
if !errors.Is(err, readerErr) {
t.Error("UnmarshalDecoderToCallback should return first reader error")
}
Expand Down
6 changes: 4 additions & 2 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,8 +616,9 @@ func TestUnmarshalToCallback(t *testing.T) {
aa,bb,11,cc,dd,ee
ff,gg,22,hh,ii,jj`)
var samples []SkipFieldSample
if err := UnmarshalBytesToCallback(b.Bytes(), func(s SkipFieldSample) {
if err := UnmarshalBytesToCallback(b.Bytes(), func(s SkipFieldSample) error {
samples = append(samples, s)
return nil
}); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1020,8 +1021,9 @@ f,1,baz,, *string
e,3,b,, `)

var samples []Sample
if err := UnmarshalDecoderToCallback(&trimDecoder{LazyCSVReader(b)}, func(s Sample) {
if err := UnmarshalDecoderToCallback(&trimDecoder{LazyCSVReader(b)}, func(s Sample) error {
samples = append(samples, s)
return nil
}); err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 2 additions & 4 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ import (
"reflect"
)

var (
ErrChannelIsClosed = errors.New("channel is closed")
)
var ErrChannelIsClosed = errors.New("channel is closed")

type encoder struct {
out io.Writer
Expand All @@ -19,7 +17,7 @@ func newEncoder(out io.Writer) *encoder {
return &encoder{out}
}

func writeFromChan(writer CSVWriter, c <-chan interface{}, omitHeaders bool) error {
func writeFromChan[T any](writer CSVWriter, c <-chan T, omitHeaders bool) error {
// Get the first value. It wil determine the header structure.
firstValue, ok := <-c
if !ok {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/gocarina/gocsv

go 1.13
go 1.18
Loading