From 86a2948b99be591db84039ebaf61ca3dff9bb1b0 Mon Sep 17 00:00:00 2001 From: Gabriel <45315755+JustAGabriel@users.noreply.github.com> Date: Wed, 29 Mar 2023 17:27:12 +0200 Subject: [PATCH] DEV-19660: prevent marshalling of `nil` ExtensionData (#212) --- bo/bilanzierung.go | 3 --- bo/marktlokation.go | 3 --- bo/messlokation.go | 3 --- bo/messlokation_test.go | 12 ++++-------- bo/zaehler_test.go | 5 ++++- com/zaehlwerk.go | 3 --- com/zaehlwerk_test.go | 4 +++- internal/testing_utils.go | 19 +++++++++++++++++++ .../unmappeddatamarshaller.go | 4 ++++ .../unmappeddatamarshaller_test.go | 4 ---- 10 files changed, 34 insertions(+), 26 deletions(-) create mode 100644 internal/testing_utils.go diff --git a/bo/bilanzierung.go b/bo/bilanzierung.go index 4a2e7bbf..edb5e40e 100644 --- a/bo/bilanzierung.go +++ b/bo/bilanzierung.go @@ -50,9 +50,6 @@ func (bila Bilanzierung) GetDefaultJsonTags() []string { } func (bila *Bilanzierung) UnmarshalJSON(bytes []byte) (err error) { - if bila.ExtensionData == nil { - bila.ExtensionData = map[string]any{} - } return unmappeddatamarshaller.UnmarshallWithUnmappedData(bila, &bila.ExtensionData, bytes) } diff --git a/bo/marktlokation.go b/bo/marktlokation.go index ad0d9bc6..8e3060f9 100644 --- a/bo/marktlokation.go +++ b/bo/marktlokation.go @@ -61,9 +61,6 @@ type Marktlokation struct { type marktlokationForUnmarshal Marktlokation func (malo *Marktlokation) UnmarshalJSON(bytes []byte) (err error) { - if malo.ExtensionData == nil { - malo.ExtensionData = map[string]any{} - } return unmappeddatamarshaller.UnmarshallWithUnmappedData(malo, &malo.ExtensionData, bytes) } diff --git a/bo/messlokation.go b/bo/messlokation.go index 71fa8421..dda8e1a9 100644 --- a/bo/messlokation.go +++ b/bo/messlokation.go @@ -53,9 +53,6 @@ func (melo Messlokation) GetDefaultJsonTags() []string { } func (melo *Messlokation) UnmarshalJSON(bytes []byte) (err error) { - if melo.ExtensionData == nil { - melo.ExtensionData = map[string]any{} - } return unmappeddatamarshaller.UnmarshallWithUnmappedData(melo, &melo.ExtensionData, bytes) } diff --git a/bo/messlokation_test.go b/bo/messlokation_test.go index 488023f0..f1a9cf12 100644 --- a/bo/messlokation_test.go +++ b/bo/messlokation_test.go @@ -83,16 +83,12 @@ func (s *Suite) Test_Messlokation_Deserialization() { err = json.Unmarshal(serializedMelo, &deserializedMelo) then.AssertThat(s.T(), err, is.Nil()) - // compare maps by value not by reference - isUnmappedDataEqual := melo.ExtensionData.CompareTo(deserializedMelo.ExtensionData) - then.AssertThat(s.T(), isUnmappedDataEqual, is.True()) - - // ignore reference inequality for 'Melo.UnmappedData.ExtensionData' field - deserializedMelo.ExtensionData = melo.ExtensionData - then.AssertThat(s.T(), deserializedMelo, is.EqualTo(melo)) + areEqual, err := internal.CompareAsJson(melo, deserializedMelo) + then.AssertThat(s.T(), err, is.Nil()) + then.AssertThat(s.T(), areEqual, is.True()) } -// Test_Failed_MesslokationValidation verifies that the validators of Messlokation work +// Test_Failed_MesslokationValidation verify that the validators of Messlokation work func (s *Suite) Test_Failed_MesslokationValidation() { validate := validator.New() validate.RegisterStructValidation(bo.XorStructLevelMesslokationValidation, bo.Messlokation{}) diff --git a/bo/zaehler_test.go b/bo/zaehler_test.go index e2852144..0fdc539e 100644 --- a/bo/zaehler_test.go +++ b/bo/zaehler_test.go @@ -57,7 +57,10 @@ func (s *Suite) Test_Zaehler_Deserialization() { var deserializedMeter bo.Zaehler err = json.Unmarshal(serializedMeter, &deserializedMeter) then.AssertThat(s.T(), err, is.Nil()) - then.AssertThat(s.T(), deserializedMeter, is.EqualTo(meter)) + + areEqual, err := internal.CompareAsJson(meter, deserializedMeter) + then.AssertThat(s.T(), err, is.Nil()) + then.AssertThat(s.T(), areEqual, is.True()) } // TestFailedZaehlerValidation verifies that the validators of a Zaehler work diff --git a/com/zaehlwerk.go b/com/zaehlwerk.go index a723d83a..5e99146d 100644 --- a/com/zaehlwerk.go +++ b/com/zaehlwerk.go @@ -40,9 +40,6 @@ type Zaehlwerk struct { } func (zw *Zaehlwerk) UnmarshalJSON(bytes []byte) (err error) { - if zw.ExtensionData == nil { - zw.ExtensionData = map[string]any{} - } return unmappeddatamarshaller.UnmarshallWithUnmappedData(zw, &zw.ExtensionData, bytes) } diff --git a/com/zaehlwerk_test.go b/com/zaehlwerk_test.go index f507ffdc..a680bf5e 100644 --- a/com/zaehlwerk_test.go +++ b/com/zaehlwerk_test.go @@ -27,12 +27,14 @@ func (s *Suite) Test_Zaehlwerk_Deserialization() { jsonString := string(serializedZaehlwerk) then.AssertThat(s.T(), strings.Contains(jsonString, "KWH"), is.True()) // stringified enum then.AssertThat(s.T(), strings.Contains(jsonString, "AUSSP"), is.True()) // stringified enum - then.AssertThat(s.T(), strings.Contains(jsonString, "wandlerfaktor\":1.2"), is.True()) // no quotes around die decimal + then.AssertThat(s.T(), strings.Contains(jsonString, "wandlerfaktor\":1.2"), is.True()) // no quotes around the decimal then.AssertThat(s.T(), err, is.Nil()) then.AssertThat(s.T(), serializedZaehlwerk, is.Not(is.Nil())) var deserializedZaehlwerk com.Zaehlwerk err = json.Unmarshal(serializedZaehlwerk, &deserializedZaehlwerk) then.AssertThat(s.T(), err, is.Nil()) + then.AssertThat(s.T(), deserializedZaehlwerk.ExtensionData.CompareTo(zaehlwerk.ExtensionData), is.True()) + zaehlwerk.ExtensionData = deserializedZaehlwerk.ExtensionData then.AssertThat(s.T(), deserializedZaehlwerk, is.EqualTo(zaehlwerk)) } diff --git a/internal/testing_utils.go b/internal/testing_utils.go new file mode 100644 index 00000000..05be67dc --- /dev/null +++ b/internal/testing_utils.go @@ -0,0 +1,19 @@ +package internal + +import ( + "bytes" + "encoding/json" +) + +func CompareAsJson[T any, T2 any](a T, a2 T2) (areEqual bool, err error) { + bytesFromA, err := json.Marshal(a) + if err != nil { + return + } + bytesFromA2, err := json.Marshal(a2) + if err != nil { + return + } + + return bytes.Equal(bytesFromA, bytesFromA2), nil +} diff --git a/internal/unmappeddatamarshaller/unmappeddatamarshaller.go b/internal/unmappeddatamarshaller/unmappeddatamarshaller.go index 7a7d555f..c7879c40 100644 --- a/internal/unmappeddatamarshaller/unmappeddatamarshaller.go +++ b/internal/unmappeddatamarshaller/unmappeddatamarshaller.go @@ -50,6 +50,10 @@ func HandleUnmappedDataPropertyMarshalling(b []byte) (bytes []byte, err error) { // UnmarshallWithUnmappedData will unmarshal a given type by mapping all strong-typed fields to the 'targetStruct'. All // other fields will be preserved in the 'unmappedDataInTargetStruct' dictionary. func UnmarshallWithUnmappedData[T any](targetStruct *T, unmappedDataInTargetStruct *ExtensionData, bytes []byte) (err error) { + if *unmappedDataInTargetStruct == nil { + *unmappedDataInTargetStruct = ExtensionData{} + } + var unmarshalledFields map[string]any err = json.Unmarshal(bytes, &unmarshalledFields) if err != nil { diff --git a/internal/unmappeddatamarshaller/unmappeddatamarshaller_test.go b/internal/unmappeddatamarshaller/unmappeddatamarshaller_test.go index 1fcb7607..ab845107 100644 --- a/internal/unmappeddatamarshaller/unmappeddatamarshaller_test.go +++ b/internal/unmappeddatamarshaller/unmappeddatamarshaller_test.go @@ -25,10 +25,6 @@ func (s SomeStruct) MarshalJSON() (b []byte, e error) { } func (s *SomeStruct) UnmarshalJSON(bytes []byte) (err error) { - if s.ExtensionData == nil { - s.ExtensionData = map[string]any{} - } - return UnmarshallWithUnmappedData(s, &s.ExtensionData, bytes) }