diff --git a/cache.go b/cache.go index bf21697..065b8d6 100644 --- a/cache.go +++ b/cache.go @@ -197,6 +197,7 @@ func (c *cache) createField(field reflect.StructField, parentAlias string) *fiel isSliceOfStructs: isSlice && isStruct, isAnonymous: field.Anonymous, isRequired: options.Contains("required"), + defaultValue: options.getDefaultOptionValue(), } } @@ -246,8 +247,9 @@ type fieldInfo struct { // isSliceOfStructs indicates if the field type is a slice of structs. isSliceOfStructs bool // isAnonymous indicates whether the field is embedded in the struct. - isAnonymous bool - isRequired bool + isAnonymous bool + isRequired bool + defaultValue string } func (f *fieldInfo) paths(prefix string) []string { @@ -303,3 +305,13 @@ func (o tagOptions) Contains(option string) bool { } return false } + +func (o tagOptions) getDefaultOptionValue() string { + for _, s := range o { + if strings.HasPrefix(s, "default:") { + return strings.Split(s, ":")[1] + } + } + + return "" +} diff --git a/converter.go b/converter.go index 4f2116a..4bae6df 100644 --- a/converter.go +++ b/converter.go @@ -143,3 +143,80 @@ func convertUint64(value string) reflect.Value { } return invalidValue } + +func convertPointer(k reflect.Kind, value string) reflect.Value { + switch k { + case boolType: + if v := convertBool(value); v.IsValid() { + converted := v.Bool() + return reflect.ValueOf(&converted) + } + case float32Type: + if v := convertFloat32(value); v.IsValid() { + converted := float32(v.Float()) + return reflect.ValueOf(&converted) + } + case float64Type: + if v := convertFloat64(value); v.IsValid() { + converted := float64(v.Float()) + return reflect.ValueOf(&converted) + } + case intType: + if v := convertInt(value); v.IsValid() { + converted := int(v.Int()) + return reflect.ValueOf(&converted) + } + case int8Type: + if v := convertInt8(value); v.IsValid() { + converted := int8(v.Int()) + return reflect.ValueOf(&converted) + } + case int16Type: + if v := convertInt16(value); v.IsValid() { + converted := int16(v.Int()) + return reflect.ValueOf(&converted) + } + case int32Type: + if v := convertInt32(value); v.IsValid() { + converted := int32(v.Int()) + return reflect.ValueOf(&converted) + } + case int64Type: + if v := convertInt64(value); v.IsValid() { + converted := int64(v.Int()) + return reflect.ValueOf(&converted) + } + case stringType: + if v := convertString(value); v.IsValid() { + converted := v.String() + return reflect.ValueOf(&converted) + } + case uintType: + if v := convertUint(value); v.IsValid() { + converted := uint(v.Uint()) + return reflect.ValueOf(&converted) + } + case uint8Type: + if v := convertUint8(value); v.IsValid() { + converted := uint8(v.Uint()) + return reflect.ValueOf(&converted) + } + case uint16Type: + if v := convertUint16(value); v.IsValid() { + converted := uint16(v.Uint()) + return reflect.ValueOf(&converted) + } + case uint32Type: + if v := convertUint32(value); v.IsValid() { + converted := uint32(v.Uint()) + return reflect.ValueOf(&converted) + } + case uint64Type: + if v := convertUint64(value); v.IsValid() { + converted := uint64(v.Uint()) + return reflect.ValueOf(&converted) + } + } + + return invalidValue +} diff --git a/decoder.go b/decoder.go index 28b560b..98f072e 100644 --- a/decoder.go +++ b/decoder.go @@ -84,6 +84,7 @@ func (d *Decoder) Decode(dst interface{}, src map[string][]string) error { errors[path] = UnknownKeyError{Key: path} } } + errors.merge(d.setDefaults(t, v)) errors.merge(d.checkRequired(t, src)) if len(errors) > 0 { return errors @@ -91,6 +92,76 @@ func (d *Decoder) Decode(dst interface{}, src map[string][]string) error { return nil } +//setDefaults sets the default values when the `default` tag is specified, +//default is supported on basic/primitive types and their pointers, +//nested structs can also have default tags +func (d *Decoder) setDefaults(t reflect.Type, v reflect.Value) MultiError { + struc := d.cache.get(t) + if struc == nil { + // unexpect, cache.get never return nil + return MultiError{"default-" + t.Name(): errors.New("cache fail")} + } + + errs := MultiError{} + + for _, f := range struc.fields { + vCurrent := v.FieldByName(f.name) + + if vCurrent.Type().Kind() == reflect.Struct && f.defaultValue == "" { + errs.merge(d.setDefaults(vCurrent.Type(), vCurrent)) + } else if isPointerToStruct(vCurrent) && f.defaultValue == "" { + errs.merge(d.setDefaults(vCurrent.Elem().Type(), vCurrent.Elem())) + } + + if f.defaultValue != "" && f.isRequired { + errs.merge(MultiError{"default-" + f.name: errors.New("required fields cannot have a default value")}) + } else if f.defaultValue != "" && vCurrent.IsZero() && !f.isRequired { + if f.typ.Kind() == reflect.Struct { + errs.merge(MultiError{"default-" + f.name: errors.New("default option is supported only on: bool, float variants, string, unit variants types or their corresponding pointers or slices")}) + } else if f.typ.Kind() == reflect.Slice { + vals := strings.Split(f.defaultValue, "|") + + //check if slice has one of the supported types for defaults + if _, ok := builtinConverters[f.typ.Elem().Kind()]; !ok { + errs.merge(MultiError{"default-" + f.name: errors.New("default option is supported only on: bool, float variants, string, unit variants types or their corresponding pointers or slices")}) + continue + } + + defaultSlice := reflect.MakeSlice(f.typ, 0, cap(vals)) + for _, val := range vals { + //this check is to handle if the wrong value is provided + if convertedVal := builtinConverters[f.typ.Elem().Kind()](val); convertedVal.IsValid() { + defaultSlice = reflect.Append(defaultSlice, convertedVal) + } + } + vCurrent.Set(defaultSlice) + } else if f.typ.Kind() == reflect.Ptr { + t1 := f.typ.Elem() + + if t1.Kind() == reflect.Struct || t1.Kind() == reflect.Slice { + errs.merge(MultiError{"default-" + f.name: errors.New("default option is supported only on: bool, float variants, string, unit variants types or their corresponding pointers or slices")}) + } + + //this check is to handle if the wrong value is provided + if convertedVal := convertPointer(t1.Kind(), f.defaultValue); convertedVal.IsValid() { + vCurrent.Set(convertedVal) + } + } else { + //this check is to handle if the wrong value is provided + if convertedVal := builtinConverters[f.typ.Kind()](f.defaultValue); convertedVal.IsValid() { + vCurrent.Set(builtinConverters[f.typ.Kind()](f.defaultValue)) + } + } + } + } + + return errs +} + +func isPointerToStruct(v reflect.Value) bool { + return !v.IsZero() && v.Type().Kind() == reflect.Ptr && v.Elem().Type().Kind() == reflect.Struct +} + // checkRequired checks whether required fields are empty // // check type t recursively if t has struct fields. diff --git a/decoder_test.go b/decoder_test.go index f89a4c3..3c12218 100644 --- a/decoder_test.go +++ b/decoder_test.go @@ -2055,3 +2055,242 @@ func TestUnmashalPointerToEmbedded(t *testing.T) { t.Errorf("Expected %v errors, got %v", expected, s.Value) } } + +func TestDefaultValuesAreSet(t *testing.T) { + type N struct { + S1 string `schema:"s1,default:test1"` + I2 int `schema:"i2,default:22"` + R2 []float64 `schema:"r2,default:2|3.5|11.01"` + } + + type D struct { + N + S string `schema:"s,default:test1"` + I int `schema:"i,default:21"` + J int8 `schema:"j,default:2"` + K int16 `schema:"k,default:-455"` + L int32 `schema:"l,default:899"` + M int64 `schema:"m,default:12455"` + B bool `schema:"b,default:false"` + F float64 `schema:"f,default:3.14"` + G float32 `schema:"g,default:19.12"` + U uint `schema:"u,default:1"` + V uint8 `schema:"v,default:190"` + W uint16 `schema:"w,default:20000"` + Y uint32 `schema:"y,default:156666666"` + Z uint64 `schema:"z,default:1545465465465546"` + X []string `schema:"x,default:x1|x2"` + } + + data := map[string][]string{} + + d := D{} + + decoder := NewDecoder() + + if err := decoder.Decode(&d, data); err != nil { + t.Fatal("Error while decoding:", err) + } + + expected := D{ + N: N{ + S1: "test1", + I2: 22, + R2: []float64{2, 3.5, 11.01}, + }, + S: "test1", + I: 21, + J: 2, + K: -455, + L: 899, + M: 12455, + B: false, + F: 3.14, + G: 19.12, + U: 1, + V: 190, + W: 20000, + Y: 156666666, + Z: 1545465465465546, + X: []string{"x1", "x2"}, + } + + if !reflect.DeepEqual(expected, d) { + t.Errorf("Expected %v, got %v", expected, d) + } + + type P struct { + *N + S *string `schema:"s,default:test1"` + I *int `schema:"i,default:21"` + J *int8 `schema:"j,default:2"` + K *int16 `schema:"k,default:-455"` + L *int32 `schema:"l,default:899"` + M *int64 `schema:"m,default:12455"` + B *bool `schema:"b,default:false"` + F *float64 `schema:"f,default:3.14"` + G *float32 `schema:"g,default:19.12"` + U *uint `schema:"u,default:1"` + V *uint8 `schema:"v,default:190"` + W *uint16 `schema:"w,default:20000"` + Y *uint32 `schema:"y,default:156666666"` + Z *uint64 `schema:"z,default:1545465465465546"` + X []string `schema:"x,default:x1|x2"` + } + + p := P{N: &N{}} + + if err := decoder.Decode(&p, data); err != nil { + t.Fatal("Error while decoding:", err) + } + + vExpected := reflect.ValueOf(expected) + vActual := reflect.ValueOf(p) + + i := 0 + + for i < vExpected.NumField() { + if !reflect.DeepEqual(vExpected.Field(i).Interface(), reflect.Indirect(vActual.Field(i)).Interface()) { + t.Errorf("Expected %v, got %v", vExpected.Field(i).Interface(), reflect.Indirect(vActual.Field(i)).Interface()) + } + i++ + } +} + +func TestDefaultValuesAreIgnoredIfValuesAreProvided(t *testing.T) { + type D struct { + S string `schema:"s,default:test1"` + I int `schema:"i,default:21"` + B bool `schema:"b,default:false"` + F float64 `schema:"f,default:3.14"` + U uint `schema:"u,default:1"` + } + + data := map[string][]string{"s": {"s"}, "i": {"1"}, "b": {"true"}, "f": {"0.22"}, "u": {"14"}} + + d := D{} + + decoder := NewDecoder() + + if err := decoder.Decode(&d, data); err != nil { + t.Fatal("Error while decoding:", err) + } + + expected := D{ + S: "s", + I: 1, + B: true, + F: 0.22, + U: 14, + } + + if !reflect.DeepEqual(expected, d) { + t.Errorf("Expected %v, got %v", expected, d) + } +} + +func TestRequiredFieldsCannotHaveDefaults(t *testing.T) { + type D struct { + S string `schema:"s,required,default:test1"` + I int `schema:"i,required,default:21"` + B bool `schema:"b,required,default:false"` + F float64 `schema:"f,required,default:3.14"` + U uint `schema:"u,required,default:1"` + } + + data := map[string][]string{"s": {"s"}, "i": {"1"}, "b": {"true"}, "f": {"0.22"}, "u": {"14"}} + + d := D{} + + decoder := NewDecoder() + + err := decoder.Decode(&d, data) + + expected := "required fields cannot have a default value" + + if err == nil || !strings.Contains(err.Error(), expected) { + t.Errorf("decoding should fail with error msg %s got %q", expected, err) + } + +} + +func TestInvalidDefaultsValuesHaveNoEffect(t *testing.T) { + type D struct { + A []int `schema:"a,default:wrong1|wrong2"` + B bool `schema:"b,default:invalid"` + C *float32 `schema:"c,default:notAFloat"` + //uint types + D uint `schema:"d,default:notUint"` + E uint8 `schema:"e,default:notUint"` + F uint16 `schema:"f,default:notUint"` + G uint32 `schema:"g,default:notUint"` + H uint64 `schema:"h,default:notUint"` + // uint types pointers + I *uint `schema:"i,default:notUint"` + J *uint8 `schema:"j,default:notUint"` + K *uint16 `schema:"k,default:notUint"` + L *uint32 `schema:"l,default:notUint"` + M *uint64 `schema:"m,default:notUint"` + // int types + N int `schema:"n,default:notInt"` + O int8 `schema:"o,default:notInt"` + P int16 `schema:"p,default:notInt"` + Q int32 `schema:"q,default:notInt"` + R int64 `schema:"r,default:notInt"` + // int types pointers + S *int `schema:"s,default:notInt"` + T *int8 `schema:"t,default:notInt"` + U *int16 `schema:"u,default:notInt"` + V *int32 `schema:"v,default:notInt"` + W *int64 `schema:"w,default:notInt"` + // float + X float32 `schema:"c,default:notAFloat"` + Y float64 `schema:"c,default:notAFloat"` + Z *float64 `schema:"c,default:notAFloat"` + } + + d := D{} + + expected := D{A: []int{}} + + data := map[string][]string{} + + decoder := NewDecoder() + + err := decoder.Decode(&d, data) + + if err != nil { + t.Errorf("decoding should succeed but got error: %q", err) + } + + if !reflect.DeepEqual(expected, d) { + t.Errorf("expected %v but got %v", expected, d) + } +} + +func TestDefaultsAreNotSupportedForStructsAndStructSlices(t *testing.T) { + type C struct { + C string `schema:"c"` + } + + type D struct { + S S1 `schema:"s,default:{f1:0}"` + A []C `schema:"a,default:{c:test1}|{c:test2}"` + B []*int `schema:"b,default:12"` + E *C `schema:"e,default:{c:test3}"` + } + + d := D{} + + data := map[string][]string{} + + decoder := NewDecoder() + + err := decoder.Decode(&d, data) + + expected := "default option is supported only on: bool, float variants, string, unit variants types or their corresponding pointers or slices" + + if err == nil || !strings.Contains(err.Error(), expected) { + t.Errorf("decoding should fail with error msg %s got %q", expected, err) + } +}