Skip to content

Commit

Permalink
Add improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
henrmota committed Oct 13, 2018
1 parent 7b0afb9 commit dd0a062
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 40 deletions.
2 changes: 1 addition & 1 deletion models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (bc *BadComment) JSONAPILinks() *Links {
type Company struct {
ID string `jsonapi:"primary,companies"`
Name string `jsonapi:"attr,name"`
Boss Employee `jsonapi:"attr,boss"`
Boss *Employee `jsonapi:"attr,boss"`
Teams []Team `jsonapi:"attr,teams"`
FoundedAt time.Time `jsonapi:"attr,founded-at,iso8601"`
}
Expand Down
58 changes: 24 additions & 34 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,26 +248,14 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node)

structField := fieldType

if structField.Type.Kind() != reflect.Struct ||
fieldValue.Type() == reflect.TypeOf(new(time.Time)) ||
fieldValue.Type() == reflect.TypeOf(time.Time{}) {
value, err := unmarshalAttribute(attribute, args, structField, fieldValue)
if err != nil {
er = err
break
}
assign(fieldValue, value)
continue

} else {
structModel, err := unmarshalFromAttribute(attribute, fieldValue)
if err != nil {
er = err
break
}
fieldValue.Set((*structModel).Elem())
continue
value, err := unmarshalAttribute(attribute, args, structField, fieldValue)
if err != nil {
er = err
break
}

assign(fieldValue, value)
continue
} else if annotation == annotationRelation {
isSlice := fieldValue.Type().Kind() == reflect.Slice

Expand Down Expand Up @@ -346,21 +334,22 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node)
return er
}

func unmarshalFromAttribute(attribute interface{}, fieldValue reflect.Value) (*reflect.Value, error) {
func unmarshalFromAttribute(attribute interface{}, fieldValue reflect.Value) (reflect.Value, error) {
structData, err := json.Marshal(attribute)
if err != nil {
return nil, err
return reflect.Value{}, err
}
structNode := new(Node)
if err := json.Unmarshal(structData, &structNode.Attributes); err != nil {
return nil, err
return reflect.Value{}, err
}

structModel := reflect.New(fieldValue.Type())
if err := unmarshalNode(structNode, structModel, nil); err != nil {
return nil, err
return reflect.Value{}, err
}

return &structModel, nil
return structModel, nil
}

func fullNode(n *Node, included *map[string]*Node) *Node {
Expand All @@ -376,7 +365,7 @@ func fullNode(n *Node, included *map[string]*Node) *Node {
// assign will take the value specified and assign it to the field; if
// field is expecting a ptr assign will assign a ptr.
func assign(field, value reflect.Value) {
if field.Kind() == reflect.Ptr || field.Kind() == reflect.Struct{
if field.Kind() == reflect.Ptr {
field.Set(value)
} else {
field.Set(reflect.Indirect(value))
Expand All @@ -402,12 +391,13 @@ func unmarshalAttribute(
if fieldValue.Type() == reflect.TypeOf(time.Time{}) ||
fieldValue.Type() == reflect.TypeOf(new(time.Time)) {
value, err = handleTime(attribute, args, fieldValue)

return
}

// Handle field of type struct
if fieldValue.Type().Kind() == reflect.Struct {
value, err = handleStruct(attribute, fieldValue)
if fieldValue.Kind() == reflect.Struct {
value, err = unmarshalFromAttribute(attribute, fieldValue)
return
}

Expand All @@ -426,7 +416,7 @@ func unmarshalAttribute(

// Field was a Pointer type
if fieldValue.Kind() == reflect.Ptr {
value, err = handlePointer(attribute, args, fieldType, fieldValue, structField)
value, err = handlePointer(attribute, fieldType, fieldValue, structField)
return
}

Expand Down Expand Up @@ -482,7 +472,6 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value)
}

var at int64

if v.Kind() == reflect.Float64 {
at = int64(v.Interface().(float64))
} else if v.Kind() == reflect.Int {
Expand All @@ -492,7 +481,6 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value)
}

t := time.Unix(at, 0)

return reflect.ValueOf(t), nil
}

Expand Down Expand Up @@ -558,7 +546,6 @@ func handleNumeric(

func handlePointer(
attribute interface{},
args []string,
fieldType reflect.Type,
fieldValue reflect.Value,
structField reflect.StructField) (reflect.Value, error) {
Expand All @@ -574,12 +561,15 @@ func handlePointer(
concreteVal = reflect.ValueOf(&cVal)
case map[string]interface{}:
var err error
concreteVal, err = handleStruct(attribute, fieldValue)
fieldValueType := reflect.New(fieldValue.Type().Elem()).Elem()
concreteVal, err = unmarshalFromAttribute(attribute, fieldValueType)
if err != nil {
return reflect.Value{}, newErrUnsupportedPtrType(
reflect.ValueOf(attribute), fieldType, structField)
}

return concreteVal, err

default:
return reflect.Value{}, newErrUnsupportedPtrType(
reflect.ValueOf(attribute), fieldType, structField)
Expand Down Expand Up @@ -624,13 +614,13 @@ func handleStruct(
func handleStructSlice(
attribute interface{},
fieldValue reflect.Value) (reflect.Value, error) {

models := reflect.New(fieldValue.Type()).Elem()
dataMap := reflect.ValueOf(attribute).Interface().([]interface{})
for _, data := range dataMap {
model := reflect.New(fieldValue.Type().Elem()).Elem()

value, err := handleStruct(data, model)

value, err := unmarshalFromAttribute(data, model)
if err != nil {
continue
}
Expand Down
6 changes: 5 additions & 1 deletion response.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,12 @@ func visitModelNode(model interface{}, included *map[string]*Node,
newSlice[i] = nested.Attributes
}
node.Attributes[args[1]] = newSlice
} else if fieldValue.Kind() == reflect.Struct {
} else if fieldValue.Kind() == reflect.Struct ||
(fieldValue.Kind() == reflect.Ptr && fieldValue.Elem().Kind() == reflect.Struct) {
included := make(map[string]*Node)
if fieldValue.Kind() == reflect.Ptr {
fieldValue = fieldValue.Elem()
}
nested, err := visitModelNode(fieldValue, &included, true)
if err != nil {
er = err
Expand Down
32 changes: 28 additions & 4 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,13 +918,37 @@ func TestMarshalNestedStruct(t *testing.T) {
},
}

now := time.Now()
company := Company {
ID: "an_id",
Name: "Awesome Company",
Boss: &Employee{
Firstname: "Company",
Surname: "boss",
Age: 60,
},
Teams: []Team {
team,
},
FoundedAt: now,
}

buffer := bytes.NewBuffer(nil)
MarshalOnePayloadEmbedded(buffer, &team)
MarshalOnePayloadEmbedded(buffer, &company)
reader := bytes.NewReader(buffer.Bytes())
var finalTeam Team
UnmarshalPayload(reader, &finalTeam)
var finalCompany Company
UnmarshalPayload(reader, &finalCompany)

diff := company.FoundedAt.Sub(finalCompany.FoundedAt)

if diff.Seconds() > 1 {
t.Error("final unmarshal payload founded at must be approximately equal to the original.")
}

company.FoundedAt = time.Time{}
finalCompany.FoundedAt = time.Time{}

if !reflect.DeepEqual(team, finalTeam) {
if !reflect.DeepEqual(company, finalCompany) {
t.Error("final unmarshal payload should be equal to the original one.")
}
}
Expand Down

0 comments on commit dd0a062

Please sign in to comment.