diff --git a/cpp/evolution/v1/generated/ndjson/protocols.cc b/cpp/evolution/v1/generated/ndjson/protocols.cc index d0a81a61..1460a237 100644 --- a/cpp/evolution/v1/generated/ndjson/protocols.cc +++ b/cpp/evolution/v1/generated/ndjson/protocols.cc @@ -87,21 +87,21 @@ struct adl_serializer> { }; template <> -struct adl_serializer> { - static void to_json(ordered_json& j, std::variant const& value) { +struct adl_serializer> { + static void to_json(ordered_json& j, std::variant const& value) { switch (value.index()) { case 0: j = ordered_json{ {"RecordWithChanges", std::get(value)} }; break; case 1: - j = ordered_json{ {"RenamedRecord", std::get(value)} }; + j = ordered_json{ {"RenamedRecord", std::get(value)} }; break; default: throw std::runtime_error("Invalid union value"); } } - static void from_json(ordered_json const& j, std::variant& value) { + static void from_json(ordered_json const& j, std::variant& value) { auto it = j.begin(); std::string tag = it.key(); if (tag == "RecordWithChanges") { @@ -109,21 +109,21 @@ struct adl_serializer(); + value = it.value().get(); return; } } }; template <> -struct adl_serializer> { - static void to_json(ordered_json& j, std::variant const& value) { +struct adl_serializer> { + static void to_json(ordered_json& j, std::variant const& value) { std::visit([&j](auto const& v) {j = v;}, value); } - static void from_json(ordered_json const& j, std::variant& value) { + static void from_json(ordered_json const& j, std::variant& value) { if ((j.is_object())) { - value = j.get(); + value = j.get(); return; } if ((j.is_string())) { diff --git a/cpp/test/generated/binary/protocols.cc b/cpp/test/generated/binary/protocols.cc index 75a664d2..5ab2aa65 100644 --- a/cpp/test/generated/binary/protocols.cc +++ b/cpp/test/generated/binary/protocols.cc @@ -1840,6 +1840,24 @@ template ReadT1, typename T2, yardl::bina test_model::binary::ReadFruits(stream, value); } +[[maybe_unused]] void WriteAliasedSimpleRecord(yardl::binary::CodedOutputStream& stream, test_model::AliasedSimpleRecord const& value) { + if constexpr (yardl::binary::IsTriviallySerializable::value) { + yardl::binary::WriteTriviallySerializable(stream, value); + return; + } + + test_model::binary::WriteSimpleRecord(stream, value); +} + +[[maybe_unused]] void ReadAliasedSimpleRecord(yardl::binary::CodedInputStream& stream, test_model::AliasedSimpleRecord& value) { + if constexpr (yardl::binary::IsTriviallySerializable::value) { + yardl::binary::ReadTriviallySerializable(stream, value); + return; + } + + test_model::binary::ReadSimpleRecord(stream, value); +} + template WriteT1, typename T2, yardl::binary::Writer WriteT2> [[maybe_unused]] void WriteAliasedOpenGeneric(yardl::binary::CodedOutputStream& stream, test_model::AliasedOpenGeneric const& value) { if constexpr (yardl::binary::IsTriviallySerializable>::value) { @@ -2394,6 +2412,24 @@ template ReadA, typename B, yardl::binary:: ReadUnion(stream, value); } +[[maybe_unused]] void WriteAliasedIntOrAliasedSimpleRecord(yardl::binary::CodedOutputStream& stream, test_model::AliasedIntOrAliasedSimpleRecord const& value) { + if constexpr (yardl::binary::IsTriviallySerializable::value) { + yardl::binary::WriteTriviallySerializable(stream, value); + return; + } + + WriteUnion(stream, value); +} + +[[maybe_unused]] void ReadAliasedIntOrAliasedSimpleRecord(yardl::binary::CodedInputStream& stream, test_model::AliasedIntOrAliasedSimpleRecord& value) { + if constexpr (yardl::binary::IsTriviallySerializable::value) { + yardl::binary::ReadTriviallySerializable(stream, value); + return; + } + + ReadUnion(stream, value); +} + [[maybe_unused]] void WriteAliasedNullableIntSimpleRecord(yardl::binary::CodedOutputStream& stream, test_model::AliasedNullableIntSimpleRecord const& value) { if constexpr (yardl::binary::IsTriviallySerializable::value) { yardl::binary::WriteTriviallySerializable(stream, value); diff --git a/cpp/test/generated/model.json b/cpp/test/generated/model.json index 047e650a..b8f3f5af 100644 --- a/cpp/test/generated/model.json +++ b/cpp/test/generated/model.json @@ -1615,6 +1615,12 @@ "type": "TestModel.Fruits" } }, + { + "alias": { + "name": "AliasedSimpleRecord", + "type": "TestModel.SimpleRecord" + } + }, { "alias": { "name": "AliasedOpenGeneric", @@ -2271,6 +2277,21 @@ ] } }, + { + "alias": { + "name": "AliasedIntOrAliasedSimpleRecord", + "type": [ + { + "tag": "int32", + "type": "int32" + }, + { + "tag": "AliasedSimpleRecord", + "type": "TestModel.AliasedSimpleRecord" + } + ] + } + }, { "alias": { "name": "AliasedNullableIntSimpleRecord", diff --git a/cpp/test/generated/ndjson/protocols.cc b/cpp/test/generated/ndjson/protocols.cc index 2d3ad6de..309b5d0f 100644 --- a/cpp/test/generated/ndjson/protocols.cc +++ b/cpp/test/generated/ndjson/protocols.cc @@ -678,21 +678,21 @@ struct adl_serializer> { }; template <> -struct adl_serializer>> { - static void to_json(ordered_json& j, std::variant> const& value) { +struct adl_serializer>> { + static void to_json(ordered_json& j, std::variant> const& value) { switch (value.index()) { case 0: j = ordered_json{ {"acquisition", std::get(value)} }; break; case 1: - j = ordered_json{ {"image", std::get>(value)} }; + j = ordered_json{ {"image", std::get>(value)} }; break; default: throw std::runtime_error("Invalid union value"); } } - static void from_json(ordered_json const& j, std::variant>& value) { + static void from_json(ordered_json const& j, std::variant>& value) { auto it = j.begin(); std::string tag = it.key(); if (tag == "acquisition") { @@ -700,7 +700,7 @@ struct adl_serializer>(); + value = it.value().get>(); return; } } @@ -745,58 +745,58 @@ struct adl_serializer> { }; template <> -struct adl_serializer>> { - static void to_json(ordered_json& j, std::variant> const& value) { +struct adl_serializer, yardl::NDArray>> { + static void to_json(ordered_json& j, std::variant, yardl::NDArray> const& value) { switch (value.index()) { case 0: - j = ordered_json{ {"imageFloat", std::get(value)} }; + j = ordered_json{ {"imageFloat", std::get>(value)} }; break; case 1: - j = ordered_json{ {"imageDouble", std::get>(value)} }; + j = ordered_json{ {"imageDouble", std::get>(value)} }; break; default: throw std::runtime_error("Invalid union value"); } } - static void from_json(ordered_json const& j, std::variant>& value) { + static void from_json(ordered_json const& j, std::variant, yardl::NDArray>& value) { auto it = j.begin(); std::string tag = it.key(); if (tag == "imageFloat") { - value = it.value().get(); + value = it.value().get>(); return; } if (tag == "imageDouble") { - value = it.value().get>(); + value = it.value().get>(); return; } } }; template <> -struct adl_serializer> { - static void to_json(ordered_json& j, std::variant const& value) { +struct adl_serializer> { + static void to_json(ordered_json& j, std::variant const& value) { switch (value.index()) { case 0: - j = ordered_json{ {"T1", std::get(value)} }; + j = ordered_json{ {"T1", std::get(value)} }; break; case 1: - j = ordered_json{ {"T2", std::get(value)} }; + j = ordered_json{ {"T2", std::get(value)} }; break; default: throw std::runtime_error("Invalid union value"); } } - static void from_json(ordered_json const& j, std::variant& value) { + static void from_json(ordered_json const& j, std::variant& value) { auto it = j.begin(); std::string tag = it.key(); if (tag == "T1") { - value = it.value().get(); + value = it.value().get(); return; } if (tag == "T2") { - value = it.value().get(); + value = it.value().get(); return; } } diff --git a/cpp/test/generated/types.h b/cpp/test/generated/types.h index e51f96d1..fae96bf3 100644 --- a/cpp/test/generated/types.h +++ b/cpp/test/generated/types.h @@ -657,6 +657,8 @@ using AliasedString = std::string; using AliasedEnum = test_model::Fruits; +using AliasedSimpleRecord = test_model::SimpleRecord; + template using AliasedOpenGeneric = test_model::AliasedTuple; @@ -883,6 +885,8 @@ struct RecordContainingNestedGenericRecords { using AliasedIntOrSimpleRecord = std::variant; +using AliasedIntOrAliasedSimpleRecord = std::variant; + using AliasedNullableIntSimpleRecord = std::variant; struct RecordWithComputedFields { diff --git a/matlab/generated/+test_model/AliasedIntOrAliasedSimpleRecord.m b/matlab/generated/+test_model/AliasedIntOrAliasedSimpleRecord.m new file mode 100644 index 00000000..fe4f0655 --- /dev/null +++ b/matlab/generated/+test_model/AliasedIntOrAliasedSimpleRecord.m @@ -0,0 +1,49 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef AliasedIntOrAliasedSimpleRecord < yardl.Union + methods (Static) + function res = Int32(value) + res = test_model.AliasedIntOrAliasedSimpleRecord(1, value); + end + + function res = AliasedSimpleRecord(value) + res = test_model.AliasedIntOrAliasedSimpleRecord(2, value); + end + + function z = zeros(varargin) + elem = test_model.AliasedIntOrAliasedSimpleRecord(0, yardl.None); + if nargin == 0 + z = elem; + return; + end + sz = [varargin{:}]; + if isscalar(sz) + sz = [sz, sz]; + end + z = reshape(repelem(elem, prod(sz)), sz); + end + end + + methods + function res = isInt32(self) + res = self.index == 1; + end + + function res = isAliasedSimpleRecord(self) + res = self.index == 2; + end + + function eq = eq(self, other) + eq = isa(other, "test_model.AliasedIntOrAliasedSimpleRecord") && isequal(self.index, other.index) && isequal(self.value, other.value); + end + + function ne = ne(self, other) + ne = ~self.eq(other); + end + + function t = tag(self) + tags_ = ["Int32", "AliasedSimpleRecord"]; + t = tags_(self.index_); + end + end +end diff --git a/matlab/generated/+test_model/AliasedSimpleRecord.m b/matlab/generated/+test_model/AliasedSimpleRecord.m new file mode 100644 index 00000000..bffe4286 --- /dev/null +++ b/matlab/generated/+test_model/AliasedSimpleRecord.m @@ -0,0 +1,5 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +function c = AliasedSimpleRecord(varargin) + c = test_model.SimpleRecord(varargin{:}); +end diff --git a/models/test/unittests.yml b/models/test/unittests.yml index 0ecf5826..936a0845 100644 --- a/models/test/unittests.yml +++ b/models/test/unittests.yml @@ -428,6 +428,7 @@ AdvancedGenerics: !protocol AliasedString: string AliasedEnum: Fruits +AliasedSimpleRecord: SimpleRecord AliasedOpenGeneric: AliasedTuple AliasedClosedGeneric: AliasedTuple AliasedOptional: int? @@ -528,6 +529,7 @@ Aliases: !protocol items: AliasedGenericUnion2 AliasedIntOrSimpleRecord: [int, SimpleRecord] +AliasedIntOrAliasedSimpleRecord: [int, AliasedSimpleRecord] AliasedNullableIntSimpleRecord: [null, int, SimpleRecord] StreamsOfAliasedUnions: !protocol diff --git a/python/test_model/__init__.py b/python/test_model/__init__.py index 9b49f57e..29ef2cc5 100644 --- a/python/test_model/__init__.py +++ b/python/test_model/__init__.py @@ -33,12 +33,14 @@ def _parse_version(version: str) -> _Tuple[int, ...]: AliasedGenericRank2Array, AliasedGenericUnion2, AliasedGenericVector, + AliasedIntOrAliasedSimpleRecord, AliasedIntOrSimpleRecord, AliasedMap, AliasedMultiGenericOptional, AliasedNullableIntSimpleRecord, AliasedOpenGeneric, AliasedOptional, + AliasedSimpleRecord, AliasedString, AliasedTuple, AliasedVectorOfGenericRecords, diff --git a/python/test_model/types.py b/python/test_model/types.py index 9cd9911d..a531a22f 100644 --- a/python/test_model/types.py +++ b/python/test_model/types.py @@ -1003,6 +1003,8 @@ def __repr__(self) -> str: AliasedEnum = Fruits +AliasedSimpleRecord = SimpleRecord + AliasedOpenGeneric = AliasedTuple AliasedClosedGeneric = AliasedTuple[AliasedString, AliasedEnum] @@ -1377,6 +1379,17 @@ class AliasedIntOrSimpleRecordUnionCase(AliasedIntOrSimpleRecord, yardl.UnionCas AliasedIntOrSimpleRecord.SimpleRecord = type("AliasedIntOrSimpleRecord.SimpleRecord", (AliasedIntOrSimpleRecordUnionCase,), {"index": 1, "tag": "SimpleRecord"}) del AliasedIntOrSimpleRecordUnionCase +class AliasedIntOrAliasedSimpleRecord: + Int32: typing.ClassVar[type["AliasedIntOrAliasedSimpleRecordUnionCase[yardl.Int32]"]] + AliasedSimpleRecord: typing.ClassVar[type["AliasedIntOrAliasedSimpleRecordUnionCase[AliasedSimpleRecord]"]] + +class AliasedIntOrAliasedSimpleRecordUnionCase(AliasedIntOrAliasedSimpleRecord, yardl.UnionCase[_T]): + pass + +AliasedIntOrAliasedSimpleRecord.Int32 = type("AliasedIntOrAliasedSimpleRecord.Int32", (AliasedIntOrAliasedSimpleRecordUnionCase,), {"index": 0, "tag": "int32"}) +AliasedIntOrAliasedSimpleRecord.AliasedSimpleRecord = type("AliasedIntOrAliasedSimpleRecord.AliasedSimpleRecord", (AliasedIntOrAliasedSimpleRecordUnionCase,), {"index": 1, "tag": "AliasedSimpleRecord"}) +del AliasedIntOrAliasedSimpleRecordUnionCase + class AliasedNullableIntSimpleRecord: Int32: typing.ClassVar[type["AliasedNullableIntSimpleRecordUnionCase[yardl.Int32]"]] SimpleRecord: typing.ClassVar[type["AliasedNullableIntSimpleRecordUnionCase[SimpleRecord]"]] @@ -2023,6 +2036,7 @@ def _mk_get_dtype(): dtype_map.setdefault(RecordWithAliasedGenerics, np.dtype([('my_strings', get_dtype(types.GenericAlias(tuples.Tuple, (str, str,)))), ('aliased_strings', get_dtype(types.GenericAlias(tuples.Tuple, (str, str,))))], align=True)) dtype_map.setdefault(AliasedString, np.dtype(np.object_)) dtype_map.setdefault(AliasedEnum, get_dtype(basic_types.Fruits)) + dtype_map.setdefault(AliasedSimpleRecord, get_dtype(SimpleRecord)) dtype_map.setdefault(AliasedOpenGeneric, lambda type_args: get_dtype(types.GenericAlias(tuples.Tuple, (type_args[0], type_args[1],)))) dtype_map.setdefault(AliasedClosedGeneric, get_dtype(types.GenericAlias(tuples.Tuple, (AliasedString, AliasedEnum,)))) dtype_map.setdefault(AliasedOptional, np.dtype([('has_value', np.dtype(np.bool_)), ('value', np.dtype(np.int32))], align=True)) @@ -2043,6 +2057,7 @@ def _mk_get_dtype(): dtype_map.setdefault(RecordContainingGenericRecords, lambda type_args: np.dtype([('g1', get_dtype(types.GenericAlias(RecordWithOptionalGenericField, (type_args[0],)))), ('g1a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericField, (type_args[0],)))), ('g2', get_dtype(types.GenericAlias(RecordWithOptionalGenericUnionField, (type_args[0], type_args[1],)))), ('g2a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericUnionField, (type_args[0], type_args[1],)))), ('g3', get_dtype(types.GenericAlias(tuples.Tuple, (type_args[0], type_args[1],)))), ('g3a', get_dtype(types.GenericAlias(tuples.Tuple, (type_args[0], type_args[1],)))), ('g4', get_dtype(types.GenericAlias(RecordWithGenericVectors, (type_args[1],)))), ('g5', get_dtype(types.GenericAlias(RecordWithGenericFixedVectors, (type_args[1],)))), ('g6', get_dtype(types.GenericAlias(RecordWithGenericArrays, (type_args[1],)))), ('g7', get_dtype(types.GenericAlias(RecordWithGenericMaps, (type_args[0], type_args[1],))))], align=True)) dtype_map.setdefault(RecordContainingNestedGenericRecords, np.dtype([('f1', get_dtype(types.GenericAlias(RecordWithOptionalGenericField, (str,)))), ('f1a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericField, (str,)))), ('f2', get_dtype(types.GenericAlias(RecordWithOptionalGenericUnionField, (str, yardl.Int32,)))), ('f2a', get_dtype(types.GenericAlias(RecordWithAliasedOptionalGenericUnionField, (str, yardl.Int32,)))), ('nested', get_dtype(types.GenericAlias(RecordContainingGenericRecords, (str, yardl.Int32,))))], align=True)) dtype_map.setdefault(AliasedIntOrSimpleRecord, np.dtype(np.object_)) + dtype_map.setdefault(AliasedIntOrAliasedSimpleRecord, np.dtype(np.object_)) dtype_map.setdefault(AliasedNullableIntSimpleRecord, np.dtype(np.object_)) dtype_map.setdefault(typing.Optional[AliasedNullableIntSimpleRecord], np.dtype(np.object_)) dtype_map.setdefault(Int32OrFloat32, np.dtype(np.object_)) diff --git a/tooling/internal/cpp/ndjson/ndjson.go b/tooling/internal/cpp/ndjson/ndjson.go index 60083bc3..8fb8aeae 100644 --- a/tooling/internal/cpp/ndjson/ndjson.go +++ b/tooling/internal/cpp/ndjson/ndjson.go @@ -94,6 +94,9 @@ func WriteNdJson(env *dsl.Environment, options packaging.CppCodegenOptions) erro self.Visit(t.ResolvedDefinition) case *dsl.GeneralizedType: if t.Cases.IsUnion() { + // Convert the union cases to their u types so we don't generate + // duplicate `adl_serializer` specializations for the same type. + t = dsl.ToUnionOfUnderlyingTypes(t) scalarType := dsl.NormalizeGenericTypeParameters(t.ToScalar()).(*dsl.GeneralizedType) typeSyntax := common.TypeSyntax(scalarType) if _, ok := unionsBySyntax[typeSyntax]; !ok { diff --git a/tooling/pkg/dsl/typefunctions.go b/tooling/pkg/dsl/typefunctions.go index 82feba6d..9f9f0434 100644 --- a/tooling/pkg/dsl/typefunctions.go +++ b/tooling/pkg/dsl/typefunctions.go @@ -522,6 +522,18 @@ func GetUnderlyingType(t Type) Type { return t } +// Convert a Union's TypeCases to their respective underlying Types +func ToUnionOfUnderlyingTypes(t *GeneralizedType) *GeneralizedType { + u := *t + u.Cases = make([]*TypeCase, len(t.Cases)) + for i, c := range t.Cases { + nc := *c + nc.Type = GetUnderlyingType(c.Type) + u.Cases[i] = &nc + } + return &u +} + func GetPrimitiveType(t Type) (primitive PrimitiveDefinition, ok bool) { switch t := GetUnderlyingType(t).(type) { case *SimpleType: