From 69acd6bc3dbd1a73426780018e96b2d7cc2572f3 Mon Sep 17 00:00:00 2001 From: Joe Naegele Date: Wed, 2 Oct 2024 21:23:18 +0000 Subject: [PATCH 1/2] Fix #179 - invalid defaulting of union in record --- cpp/test/generated/binary/protocols.cc | 34 ++++++++++++- cpp/test/generated/hdf5/protocols.cc | 26 +++++++++- cpp/test/generated/model.json | 24 +++++++++ cpp/test/generated/ndjson/protocols.cc | 41 ++++++++++++++++ cpp/test/generated/protocols.cc | 2 +- cpp/test/generated/types.h | 16 +++++- .../+binary/RecordWithStringSerializer.m | 24 +++++++++ .../+binary/RecordWithUnionsSerializer.m | 5 +- .../generated/+basic_types/RecordWithString.m | 41 ++++++++++++++++ .../+basic_types/RecordWithStringOrInt32.m | 49 +++++++++++++++++++ .../generated/+basic_types/RecordWithUnions.m | 6 ++- .../generated/+test_model/UnionsWriterBase.m | 2 +- models/basic-types/basic.yml | 5 ++ python/test_model/basic_types/__init__.py | 2 + python/test_model/basic_types/binary.py | 26 ++++++++-- python/test_model/basic_types/ndjson.py | 44 +++++++++++++++++ python/test_model/basic_types/types.py | 46 +++++++++++++++-- python/test_model/protocols.py | 2 +- python/test_model/types.py | 6 ++- tooling/internal/matlab/types/types.go | 5 -- tooling/internal/python/types/types.go | 5 -- 21 files changed, 383 insertions(+), 28 deletions(-) create mode 100644 matlab/generated/+basic_types/+binary/RecordWithStringSerializer.m create mode 100644 matlab/generated/+basic_types/RecordWithString.m create mode 100644 matlab/generated/+basic_types/RecordWithStringOrInt32.m diff --git a/cpp/test/generated/binary/protocols.cc b/cpp/test/generated/binary/protocols.cc index 28c687ee..656c5712 100644 --- a/cpp/test/generated/binary/protocols.cc +++ b/cpp/test/generated/binary/protocols.cc @@ -25,6 +25,15 @@ struct IsTriviallySerializable> { offsetof(__T__, v1) < offsetof(__T__, v2); }; +template <> +struct IsTriviallySerializable { + using __T__ = basic_types::RecordWithString; + static constexpr bool value = + std::is_standard_layout_v<__T__> && + IsTriviallySerializable::value && + (sizeof(__T__) == (sizeof(__T__::i))); +}; + template <> struct IsTriviallySerializable { using __T__ = basic_types::RecordWithUnions; @@ -33,8 +42,9 @@ struct IsTriviallySerializable { IsTriviallySerializable::value && IsTriviallySerializable::value && IsTriviallySerializable::value && - (sizeof(__T__) == (sizeof(__T__::null_or_int_or_string) + sizeof(__T__::date_or_datetime) + sizeof(__T__::null_or_fruits_or_days_of_week))) && - offsetof(__T__, null_or_int_or_string) < offsetof(__T__, date_or_datetime) && offsetof(__T__, date_or_datetime) < offsetof(__T__, null_or_fruits_or_days_of_week); + IsTriviallySerializable::value && + (sizeof(__T__) == (sizeof(__T__::null_or_int_or_string) + sizeof(__T__::date_or_datetime) + sizeof(__T__::null_or_fruits_or_days_of_week) + sizeof(__T__::record_or_int))) && + offsetof(__T__, null_or_int_or_string) < offsetof(__T__, date_or_datetime) && offsetof(__T__, date_or_datetime) < offsetof(__T__, null_or_fruits_or_days_of_week) && offsetof(__T__, null_or_fruits_or_days_of_week) < offsetof(__T__, record_or_int); }; template @@ -908,6 +918,24 @@ template ReadT> yardl::binary::ReadVector(stream, value); } +[[maybe_unused]] void WriteRecordWithString(yardl::binary::CodedOutputStream& stream, basic_types::RecordWithString const& value) { + if constexpr (yardl::binary::IsTriviallySerializable::value) { + yardl::binary::WriteTriviallySerializable(stream, value); + return; + } + + yardl::binary::WriteString(stream, value.i); +} + +[[maybe_unused]] void ReadRecordWithString(yardl::binary::CodedInputStream& stream, basic_types::RecordWithString& value) { + if constexpr (yardl::binary::IsTriviallySerializable::value) { + yardl::binary::ReadTriviallySerializable(stream, value); + return; + } + + yardl::binary::ReadString(stream, value.i); +} + [[maybe_unused]] void WriteRecordWithUnions(yardl::binary::CodedOutputStream& stream, basic_types::RecordWithUnions const& value) { if constexpr (yardl::binary::IsTriviallySerializable::value) { yardl::binary::WriteTriviallySerializable(stream, value); @@ -917,6 +945,7 @@ template ReadT> WriteUnion(stream, value.null_or_int_or_string); WriteUnion(stream, value.date_or_datetime); basic_types::binary::WriteGenericNullableUnion2, basic_types::DaysOfWeek, yardl::binary::WriteFlags>(stream, value.null_or_fruits_or_days_of_week); + WriteUnion(stream, value.record_or_int); } [[maybe_unused]] void ReadRecordWithUnions(yardl::binary::CodedInputStream& stream, basic_types::RecordWithUnions& value) { @@ -928,6 +957,7 @@ template ReadT> ReadUnion(stream, value.null_or_int_or_string); ReadUnion(stream, value.date_or_datetime); basic_types::binary::ReadGenericNullableUnion2, basic_types::DaysOfWeek, yardl::binary::ReadFlags>(stream, value.null_or_fruits_or_days_of_week); + ReadUnion(stream, value.record_or_int); } template WriteT0, typename T1, yardl::binary::Writer WriteT1> diff --git a/cpp/test/generated/hdf5/protocols.cc b/cpp/test/generated/hdf5/protocols.cc index 42bfb9cf..c2d786b4 100644 --- a/cpp/test/generated/hdf5/protocols.cc +++ b/cpp/test/generated/hdf5/protocols.cc @@ -413,23 +413,39 @@ namespace { return t; } +struct _Inner_RecordWithString { + _Inner_RecordWithString() {} + _Inner_RecordWithString(basic_types::RecordWithString const& o) + : i(o.i) { + } + + void ToOuter (basic_types::RecordWithString& o) const { + yardl::hdf5::ToOuter(i, o.i); + } + + yardl::hdf5::InnerVlenString i; +}; + struct _Inner_RecordWithUnions { _Inner_RecordWithUnions() {} _Inner_RecordWithUnions(basic_types::RecordWithUnions const& o) : null_or_int_or_string(o.null_or_int_or_string), date_or_datetime(o.date_or_datetime), - null_or_fruits_or_days_of_week(o.null_or_fruits_or_days_of_week) { + null_or_fruits_or_days_of_week(o.null_or_fruits_or_days_of_week), + record_or_int(o.record_or_int) { } void ToOuter (basic_types::RecordWithUnions& o) const { yardl::hdf5::ToOuter(null_or_int_or_string, o.null_or_int_or_string); yardl::hdf5::ToOuter(date_or_datetime, o.date_or_datetime); yardl::hdf5::ToOuter(null_or_fruits_or_days_of_week, o.null_or_fruits_or_days_of_week); + yardl::hdf5::ToOuter(record_or_int, o.record_or_int); } ::InnerUnion2 null_or_int_or_string; ::InnerUnion2 date_or_datetime; ::InnerUnion2 null_or_fruits_or_days_of_week; + ::InnerUnion2 record_or_int; }; template @@ -446,12 +462,20 @@ struct _Inner_GenericRecordWithComputedFields { ::InnerUnion2<_T0_Inner, T0, _T1_Inner, T1> f1; }; +[[maybe_unused]] H5::CompType GetRecordWithStringHdf5Ddl() { + using RecordType = basic_types::hdf5::_Inner_RecordWithString; + H5::CompType t(sizeof(RecordType)); + t.insertMember("i", HOFFSET(RecordType, i), yardl::hdf5::InnerVlenStringDdl()); + return t; +} + [[maybe_unused]] H5::CompType GetRecordWithUnionsHdf5Ddl() { using RecordType = basic_types::hdf5::_Inner_RecordWithUnions; H5::CompType t(sizeof(RecordType)); t.insertMember("nullOrIntOrString", HOFFSET(RecordType, null_or_int_or_string), ::InnerUnion2Ddl(true, H5::PredType::NATIVE_INT32, "int32", yardl::hdf5::InnerVlenStringDdl(), "string")); t.insertMember("dateOrDatetime", HOFFSET(RecordType, date_or_datetime), ::InnerUnion2Ddl(false, yardl::hdf5::TimeTypeDdl(), "time", yardl::hdf5::DateTimeTypeDdl(), "datetime")); t.insertMember("nullOrFruitsOrDaysOfWeek", HOFFSET(RecordType, null_or_fruits_or_days_of_week), ::InnerUnion2Ddl(true, basic_types::hdf5::GetFruitsHdf5Ddl(), "T1", H5::PredType::NATIVE_INT32, "T2")); + t.insertMember("recordOrInt", HOFFSET(RecordType, record_or_int), ::InnerUnion2Ddl(false, basic_types::hdf5::GetRecordWithStringHdf5Ddl(), "RecordWithString", H5::PredType::NATIVE_INT32, "int32")); return t; } diff --git a/cpp/test/generated/model.json b/cpp/test/generated/model.json index f9b2bd90..f95c8dff 100644 --- a/cpp/test/generated/model.json +++ b/cpp/test/generated/model.json @@ -192,6 +192,17 @@ } } }, + { + "record": { + "name": "RecordWithString", + "fields": [ + { + "name": "i", + "type": "string" + } + ] + } + }, { "record": { "name": "RecordWithUnions", @@ -232,6 +243,19 @@ "BasicTypes.DaysOfWeek" ] } + }, + { + "name": "recordOrInt", + "type": [ + { + "tag": "RecordWithString", + "type": "BasicTypes.RecordWithString" + }, + { + "tag": "int32", + "type": "int32" + } + ] } ] } diff --git a/cpp/test/generated/ndjson/protocols.cc b/cpp/test/generated/ndjson/protocols.cc index 4c7398fa..36d1c89d 100644 --- a/cpp/test/generated/ndjson/protocols.cc +++ b/cpp/test/generated/ndjson/protocols.cc @@ -25,6 +25,9 @@ void from_json(ordered_json const& j, basic_types::DaysOfWeek& value); void to_json(ordered_json& j, basic_types::TextFormat const& value); void from_json(ordered_json const& j, basic_types::TextFormat& value); +void to_json(ordered_json& j, basic_types::RecordWithString const& value); +void from_json(ordered_json const& j, basic_types::RecordWithString& value); + void to_json(ordered_json& j, basic_types::RecordWithUnions const& value); void from_json(ordered_json const& j, basic_types::RecordWithUnions& value); @@ -353,6 +356,25 @@ struct adl_serializer +struct adl_serializer> { + static void to_json(ordered_json& j, std::variant const& value) { + std::visit([&j](auto const& v) {j = v;}, value); + } + + static void from_json(ordered_json const& j, std::variant& value) { + if ((j.is_object())) { + value = j.get(); + return; + } + if ((j.is_number())) { + value = j.get(); + return; + } + throw std::runtime_error("Invalid union value"); + } +}; + template <> struct adl_serializer, int32_t>> { static void to_json(ordered_json& j, std::variant, int32_t> const& value) { @@ -1125,6 +1147,19 @@ void from_json(ordered_json const& j, basic_types::TextFormat& value) { } } +void to_json(ordered_json& j, basic_types::RecordWithString const& value) { + j = ordered_json::object(); + if (yardl::ndjson::ShouldSerializeFieldValue(value.i)) { + j.push_back({"i", value.i}); + } +} + +void from_json(ordered_json const& j, basic_types::RecordWithString& value) { + if (auto it = j.find("i"); it != j.end()) { + it->get_to(value.i); + } +} + void to_json(ordered_json& j, basic_types::RecordWithUnions const& value) { j = ordered_json::object(); if (yardl::ndjson::ShouldSerializeFieldValue(value.null_or_int_or_string)) { @@ -1136,6 +1171,9 @@ void to_json(ordered_json& j, basic_types::RecordWithUnions const& value) { if (yardl::ndjson::ShouldSerializeFieldValue(value.null_or_fruits_or_days_of_week)) { j.push_back({"nullOrFruitsOrDaysOfWeek", value.null_or_fruits_or_days_of_week}); } + if (yardl::ndjson::ShouldSerializeFieldValue(value.record_or_int)) { + j.push_back({"recordOrInt", value.record_or_int}); + } } void from_json(ordered_json const& j, basic_types::RecordWithUnions& value) { @@ -1148,6 +1186,9 @@ void from_json(ordered_json const& j, basic_types::RecordWithUnions& value) { if (auto it = j.find("nullOrFruitsOrDaysOfWeek"); it != j.end()) { it->get_to(value.null_or_fruits_or_days_of_week); } + if (auto it = j.find("recordOrInt"); it != j.end()) { + it->get_to(value.record_or_int); + } } template diff --git a/cpp/test/generated/protocols.cc b/cpp/test/generated/protocols.cc index 1d82f978..7c2d7300 100644 --- a/cpp/test/generated/protocols.cc +++ b/cpp/test/generated/protocols.cc @@ -4221,7 +4221,7 @@ void UnionsReaderBaseInvalidState(uint8_t attempted, uint8_t current) { } // namespace -std::string UnionsWriterBase::schema_ = R"({"protocol":{"name":"Unions","sequence":[{"name":"intOrSimpleRecord","type":[{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"intOrRecordWithVlens","type":[{"tag":"int32","type":"int32"},{"tag":"RecordWithVlens","type":"TestModel.RecordWithVlens"}]},{"name":"monosotateOrIntOrSimpleRecord","type":[null,{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"recordWithUnions","type":"BasicTypes.RecordWithUnions"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"GenericNullableUnion2","typeParameters":["T1","T2"],"type":[null,{"tag":"T1","type":"T1"},{"tag":"T2","type":"T2"}]},{"name":"RecordWithUnions","fields":[{"name":"nullOrIntOrString","type":[null,{"tag":"int32","type":"int32"},{"tag":"string","type":"string"}]},{"name":"dateOrDatetime","type":[{"tag":"time","type":"time"},{"tag":"datetime","type":"datetime"}]},{"name":"nullOrFruitsOrDaysOfWeek","type":{"name":"BasicTypes.GenericNullableUnion2","typeArguments":["BasicTypes.Fruits","BasicTypes.DaysOfWeek"]}}]},{"name":"RecordWithVlens","fields":[{"name":"a","type":{"vector":{"items":"TestModel.SimpleRecord"}}},{"name":"b","type":"int32"},{"name":"c","type":"int32"}]},{"name":"SimpleRecord","fields":[{"name":"x","type":"int32"},{"name":"y","type":"int32"},{"name":"z","type":"int32"}]}]})"; +std::string UnionsWriterBase::schema_ = R"({"protocol":{"name":"Unions","sequence":[{"name":"intOrSimpleRecord","type":[{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"intOrRecordWithVlens","type":[{"tag":"int32","type":"int32"},{"tag":"RecordWithVlens","type":"TestModel.RecordWithVlens"}]},{"name":"monosotateOrIntOrSimpleRecord","type":[null,{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"recordWithUnions","type":"BasicTypes.RecordWithUnions"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"GenericNullableUnion2","typeParameters":["T1","T2"],"type":[null,{"tag":"T1","type":"T1"},{"tag":"T2","type":"T2"}]},{"name":"RecordWithString","fields":[{"name":"i","type":"string"}]},{"name":"RecordWithUnions","fields":[{"name":"nullOrIntOrString","type":[null,{"tag":"int32","type":"int32"},{"tag":"string","type":"string"}]},{"name":"dateOrDatetime","type":[{"tag":"time","type":"time"},{"tag":"datetime","type":"datetime"}]},{"name":"nullOrFruitsOrDaysOfWeek","type":{"name":"BasicTypes.GenericNullableUnion2","typeArguments":["BasicTypes.Fruits","BasicTypes.DaysOfWeek"]}},{"name":"recordOrInt","type":[{"tag":"RecordWithString","type":"BasicTypes.RecordWithString"},{"tag":"int32","type":"int32"}]}]},{"name":"RecordWithVlens","fields":[{"name":"a","type":{"vector":{"items":"TestModel.SimpleRecord"}}},{"name":"b","type":"int32"},{"name":"c","type":"int32"}]},{"name":"SimpleRecord","fields":[{"name":"x","type":"int32"},{"name":"y","type":"int32"},{"name":"z","type":"int32"}]}]})"; std::vector UnionsWriterBase::previous_schemas_ = { }; diff --git a/cpp/test/generated/types.h b/cpp/test/generated/types.h index d31481cc..c1322c20 100644 --- a/cpp/test/generated/types.h +++ b/cpp/test/generated/types.h @@ -70,15 +70,29 @@ using GenericNullableUnion2 = std::variant; template using GenericVector = std::vector; +struct RecordWithString { + std::string i{}; + + bool operator==(const RecordWithString& other) const { + return i == other.i; + } + + bool operator!=(const RecordWithString& other) const { + return !(*this == other); + } +}; + struct RecordWithUnions { std::variant null_or_int_or_string{}; std::variant date_or_datetime{}; basic_types::GenericNullableUnion2 null_or_fruits_or_days_of_week{}; + std::variant record_or_int{}; bool operator==(const RecordWithUnions& other) const { return null_or_int_or_string == other.null_or_int_or_string && date_or_datetime == other.date_or_datetime && - null_or_fruits_or_days_of_week == other.null_or_fruits_or_days_of_week; + null_or_fruits_or_days_of_week == other.null_or_fruits_or_days_of_week && + record_or_int == other.record_or_int; } bool operator!=(const RecordWithUnions& other) const { diff --git a/matlab/generated/+basic_types/+binary/RecordWithStringSerializer.m b/matlab/generated/+basic_types/+binary/RecordWithStringSerializer.m new file mode 100644 index 00000000..dd5f8aea --- /dev/null +++ b/matlab/generated/+basic_types/+binary/RecordWithStringSerializer.m @@ -0,0 +1,24 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef RecordWithStringSerializer < yardl.binary.RecordSerializer + methods + function self = RecordWithStringSerializer() + field_serializers{1} = yardl.binary.StringSerializer; + self@yardl.binary.RecordSerializer('basic_types.RecordWithString', field_serializers); + end + + function write(self, outstream, value) + arguments + self + outstream (1,1) yardl.binary.CodedOutputStream + value (1,1) basic_types.RecordWithString + end + self.write_(outstream, value.i); + end + + function value = read(self, instream) + fields = self.read_(instream); + value = basic_types.RecordWithString(i=fields{1}); + end + end +end diff --git a/matlab/generated/+basic_types/+binary/RecordWithUnionsSerializer.m b/matlab/generated/+basic_types/+binary/RecordWithUnionsSerializer.m index 9ca1bfa1..63587856 100644 --- a/matlab/generated/+basic_types/+binary/RecordWithUnionsSerializer.m +++ b/matlab/generated/+basic_types/+binary/RecordWithUnionsSerializer.m @@ -6,6 +6,7 @@ field_serializers{1} = yardl.binary.UnionSerializer('basic_types.Int32OrString', {yardl.binary.NoneSerializer, yardl.binary.Int32Serializer, yardl.binary.StringSerializer}, {yardl.None, @basic_types.Int32OrString.Int32, @basic_types.Int32OrString.String}); field_serializers{2} = yardl.binary.UnionSerializer('basic_types.TimeOrDatetime', {yardl.binary.TimeSerializer, yardl.binary.DatetimeSerializer}, {@basic_types.TimeOrDatetime.Time, @basic_types.TimeOrDatetime.Datetime}); field_serializers{3} = yardl.binary.UnionSerializer('basic_types.GenericNullableUnion2', {yardl.binary.NoneSerializer, yardl.binary.EnumSerializer('basic_types.Fruits', @basic_types.Fruits, yardl.binary.Int32Serializer), yardl.binary.EnumSerializer('basic_types.DaysOfWeek', @basic_types.DaysOfWeek, yardl.binary.Int32Serializer)}, {yardl.None, @basic_types.GenericNullableUnion2.T1, @basic_types.GenericNullableUnion2.T2}); + field_serializers{4} = yardl.binary.UnionSerializer('basic_types.RecordWithStringOrInt32', {basic_types.binary.RecordWithStringSerializer(), yardl.binary.Int32Serializer}, {@basic_types.RecordWithStringOrInt32.RecordWithString, @basic_types.RecordWithStringOrInt32.Int32}); self@yardl.binary.RecordSerializer('basic_types.RecordWithUnions', field_serializers); end @@ -15,12 +16,12 @@ function write(self, outstream, value) outstream (1,1) yardl.binary.CodedOutputStream value (1,1) basic_types.RecordWithUnions end - self.write_(outstream, value.null_or_int_or_string, value.date_or_datetime, value.null_or_fruits_or_days_of_week); + self.write_(outstream, value.null_or_int_or_string, value.date_or_datetime, value.null_or_fruits_or_days_of_week, value.record_or_int); end function value = read(self, instream) fields = self.read_(instream); - value = basic_types.RecordWithUnions(null_or_int_or_string=fields{1}, date_or_datetime=fields{2}, null_or_fruits_or_days_of_week=fields{3}); + value = basic_types.RecordWithUnions(null_or_int_or_string=fields{1}, date_or_datetime=fields{2}, null_or_fruits_or_days_of_week=fields{3}, record_or_int=fields{4}); end end end diff --git a/matlab/generated/+basic_types/RecordWithString.m b/matlab/generated/+basic_types/RecordWithString.m new file mode 100644 index 00000000..37982fd8 --- /dev/null +++ b/matlab/generated/+basic_types/RecordWithString.m @@ -0,0 +1,41 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef RecordWithString < handle + properties + i + end + + methods + function self = RecordWithString(kwargs) + arguments + kwargs.i = ""; + end + self.i = kwargs.i; + end + + function res = eq(self, other) + res = ... + isa(other, "basic_types.RecordWithString") && ... + isequal(self.i, other.i); + end + + function res = ne(self, other) + res = ~self.eq(other); + end + end + + methods (Static) + function z = zeros(varargin) + elem = basic_types.RecordWithString(); + if nargin == 0 + z = elem; + return; + end + sz = [varargin{:}]; + if isscalar(sz) + sz = [sz, sz]; + end + z = reshape(repelem(elem, prod(sz)), sz); + end + end +end diff --git a/matlab/generated/+basic_types/RecordWithStringOrInt32.m b/matlab/generated/+basic_types/RecordWithStringOrInt32.m new file mode 100644 index 00000000..1a014df2 --- /dev/null +++ b/matlab/generated/+basic_types/RecordWithStringOrInt32.m @@ -0,0 +1,49 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef RecordWithStringOrInt32 < yardl.Union + methods (Static) + function res = RecordWithString(value) + res = basic_types.RecordWithStringOrInt32(1, value); + end + + function res = Int32(value) + res = basic_types.RecordWithStringOrInt32(2, value); + end + + function z = zeros(varargin) + elem = basic_types.RecordWithStringOrInt32(0, yardl.None); + if nargin == 0 + z = elem; + return; + end + sz = [varargin{:}]; + if isscalar(sz) + sz = [sz, sz]; + end + z = reshape(repelem(elem, prod(sz)), sz); + end + end + + methods + function res = isRecordWithString(self) + res = self.index == 1; + end + + function res = isInt32(self) + res = self.index == 2; + end + + function eq = eq(self, other) + eq = isa(other, "basic_types.RecordWithStringOrInt32") && isequal(self.index, other.index) && isequal(self.value, other.value); + end + + function ne = ne(self, other) + ne = ~self.eq(other); + end + + function t = tag(self) + tags_ = ["RecordWithString", "Int32"]; + t = tags_(self.index_); + end + end +end diff --git a/matlab/generated/+basic_types/RecordWithUnions.m b/matlab/generated/+basic_types/RecordWithUnions.m index 4dc6416e..34b27752 100644 --- a/matlab/generated/+basic_types/RecordWithUnions.m +++ b/matlab/generated/+basic_types/RecordWithUnions.m @@ -5,6 +5,7 @@ null_or_int_or_string date_or_datetime null_or_fruits_or_days_of_week + record_or_int end methods @@ -13,10 +14,12 @@ kwargs.null_or_int_or_string = yardl.None; kwargs.date_or_datetime = basic_types.TimeOrDatetime.Time(yardl.Time()); kwargs.null_or_fruits_or_days_of_week = yardl.None; + kwargs.record_or_int = basic_types.RecordWithStringOrInt32.RecordWithString(basic_types.RecordWithString()); end self.null_or_int_or_string = kwargs.null_or_int_or_string; self.date_or_datetime = kwargs.date_or_datetime; self.null_or_fruits_or_days_of_week = kwargs.null_or_fruits_or_days_of_week; + self.record_or_int = kwargs.record_or_int; end function res = eq(self, other) @@ -24,7 +27,8 @@ isa(other, "basic_types.RecordWithUnions") && ... isequal(self.null_or_int_or_string, other.null_or_int_or_string) && ... isequal(self.date_or_datetime, other.date_or_datetime) && ... - isequal(self.null_or_fruits_or_days_of_week, other.null_or_fruits_or_days_of_week); + isequal(self.null_or_fruits_or_days_of_week, other.null_or_fruits_or_days_of_week) && ... + isequal(self.record_or_int, other.record_or_int); end function res = ne(self, other) diff --git a/matlab/generated/+test_model/UnionsWriterBase.m b/matlab/generated/+test_model/UnionsWriterBase.m index 7f8e3096..893da943 100644 --- a/matlab/generated/+test_model/UnionsWriterBase.m +++ b/matlab/generated/+test_model/UnionsWriterBase.m @@ -62,7 +62,7 @@ function write_record_with_unions(self, value) methods (Static) function res = schema() - res = string('{"protocol":{"name":"Unions","sequence":[{"name":"intOrSimpleRecord","type":[{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"intOrRecordWithVlens","type":[{"tag":"int32","type":"int32"},{"tag":"RecordWithVlens","type":"TestModel.RecordWithVlens"}]},{"name":"monosotateOrIntOrSimpleRecord","type":[null,{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"recordWithUnions","type":"BasicTypes.RecordWithUnions"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"GenericNullableUnion2","typeParameters":["T1","T2"],"type":[null,{"tag":"T1","type":"T1"},{"tag":"T2","type":"T2"}]},{"name":"RecordWithUnions","fields":[{"name":"nullOrIntOrString","type":[null,{"tag":"int32","type":"int32"},{"tag":"string","type":"string"}]},{"name":"dateOrDatetime","type":[{"tag":"time","type":"time"},{"tag":"datetime","type":"datetime"}]},{"name":"nullOrFruitsOrDaysOfWeek","type":{"name":"BasicTypes.GenericNullableUnion2","typeArguments":["BasicTypes.Fruits","BasicTypes.DaysOfWeek"]}}]},{"name":"RecordWithVlens","fields":[{"name":"a","type":{"vector":{"items":"TestModel.SimpleRecord"}}},{"name":"b","type":"int32"},{"name":"c","type":"int32"}]},{"name":"SimpleRecord","fields":[{"name":"x","type":"int32"},{"name":"y","type":"int32"},{"name":"z","type":"int32"}]}]}'); + res = string('{"protocol":{"name":"Unions","sequence":[{"name":"intOrSimpleRecord","type":[{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"intOrRecordWithVlens","type":[{"tag":"int32","type":"int32"},{"tag":"RecordWithVlens","type":"TestModel.RecordWithVlens"}]},{"name":"monosotateOrIntOrSimpleRecord","type":[null,{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"recordWithUnions","type":"BasicTypes.RecordWithUnions"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"GenericNullableUnion2","typeParameters":["T1","T2"],"type":[null,{"tag":"T1","type":"T1"},{"tag":"T2","type":"T2"}]},{"name":"RecordWithString","fields":[{"name":"i","type":"string"}]},{"name":"RecordWithUnions","fields":[{"name":"nullOrIntOrString","type":[null,{"tag":"int32","type":"int32"},{"tag":"string","type":"string"}]},{"name":"dateOrDatetime","type":[{"tag":"time","type":"time"},{"tag":"datetime","type":"datetime"}]},{"name":"nullOrFruitsOrDaysOfWeek","type":{"name":"BasicTypes.GenericNullableUnion2","typeArguments":["BasicTypes.Fruits","BasicTypes.DaysOfWeek"]}},{"name":"recordOrInt","type":[{"tag":"RecordWithString","type":"BasicTypes.RecordWithString"},{"tag":"int32","type":"int32"}]}]},{"name":"RecordWithVlens","fields":[{"name":"a","type":{"vector":{"items":"TestModel.SimpleRecord"}}},{"name":"b","type":"int32"},{"name":"c","type":"int32"}]},{"name":"SimpleRecord","fields":[{"name":"x","type":"int32"},{"name":"y","type":"int32"},{"name":"z","type":"int32"}]}]}'); end end diff --git a/models/basic-types/basic.yml b/models/basic-types/basic.yml index 73630668..7fcc1092 100644 --- a/models/basic-types/basic.yml +++ b/models/basic-types/basic.yml @@ -33,11 +33,16 @@ GenericNullableUnion2: [null, T1, T2] GenericVector: T* +RecordWithString: !record + fields: + i: string + RecordWithUnions: !record fields: nullOrIntOrString: [null, int, string] dateOrDatetime: [time, datetime] nullOrFruitsOrDaysOfWeek: GenericNullableUnion2 + recordOrInt: [RecordWithString, int] GenericRecordWithComputedFields: !record fields: diff --git a/python/test_model/basic_types/__init__.py b/python/test_model/basic_types/__init__.py index a26a377a..1de7c5df 100644 --- a/python/test_model/basic_types/__init__.py +++ b/python/test_model/basic_types/__init__.py @@ -14,6 +14,8 @@ GenericVector, Int32OrString, MyTuple, + RecordWithString, + RecordWithStringOrInt32, RecordWithUnions, T0OrT1, TextFormat, diff --git a/python/test_model/basic_types/binary.py b/python/test_model/basic_types/binary.py index 705f5892..15f6bb0d 100644 --- a/python/test_model/basic_types/binary.py +++ b/python/test_model/basic_types/binary.py @@ -18,22 +18,40 @@ from .. import _binary from .. import yardl_types as yardl +class RecordWithStringSerializer(_binary.RecordSerializer[RecordWithString]): + def __init__(self) -> None: + super().__init__([("i", _binary.string_serializer)]) + + def write(self, stream: _binary.CodedOutputStream, value: RecordWithString) -> None: + if isinstance(value, np.void): + self.write_numpy(stream, value) + return + self._write(stream, value.i) + + def write_numpy(self, stream: _binary.CodedOutputStream, value: np.void) -> None: + self._write(stream, value['i']) + + def read(self, stream: _binary.CodedInputStream) -> RecordWithString: + field_values = self._read(stream) + return RecordWithString(i=field_values[0]) + + class RecordWithUnionsSerializer(_binary.RecordSerializer[RecordWithUnions]): def __init__(self) -> None: - super().__init__([("null_or_int_or_string", _binary.UnionSerializer(Int32OrString, [None, (Int32OrString.Int32, _binary.int32_serializer), (Int32OrString.String, _binary.string_serializer)])), ("date_or_datetime", _binary.UnionSerializer(TimeOrDatetime, [(TimeOrDatetime.Time, _binary.time_serializer), (TimeOrDatetime.Datetime, _binary.datetime_serializer)])), ("null_or_fruits_or_days_of_week", _binary.UnionSerializer(GenericNullableUnion2, [None, (GenericNullableUnion2.T1, _binary.EnumSerializer(_binary.int32_serializer, Fruits)), (GenericNullableUnion2.T2, _binary.EnumSerializer(_binary.int32_serializer, DaysOfWeek))]))]) + super().__init__([("null_or_int_or_string", _binary.UnionSerializer(Int32OrString, [None, (Int32OrString.Int32, _binary.int32_serializer), (Int32OrString.String, _binary.string_serializer)])), ("date_or_datetime", _binary.UnionSerializer(TimeOrDatetime, [(TimeOrDatetime.Time, _binary.time_serializer), (TimeOrDatetime.Datetime, _binary.datetime_serializer)])), ("null_or_fruits_or_days_of_week", _binary.UnionSerializer(GenericNullableUnion2, [None, (GenericNullableUnion2.T1, _binary.EnumSerializer(_binary.int32_serializer, Fruits)), (GenericNullableUnion2.T2, _binary.EnumSerializer(_binary.int32_serializer, DaysOfWeek))])), ("record_or_int", _binary.UnionSerializer(RecordWithStringOrInt32, [(RecordWithStringOrInt32.RecordWithString, RecordWithStringSerializer()), (RecordWithStringOrInt32.Int32, _binary.int32_serializer)]))]) def write(self, stream: _binary.CodedOutputStream, value: RecordWithUnions) -> None: if isinstance(value, np.void): self.write_numpy(stream, value) return - self._write(stream, value.null_or_int_or_string, value.date_or_datetime, value.null_or_fruits_or_days_of_week) + self._write(stream, value.null_or_int_or_string, value.date_or_datetime, value.null_or_fruits_or_days_of_week, value.record_or_int) def write_numpy(self, stream: _binary.CodedOutputStream, value: np.void) -> None: - self._write(stream, value['null_or_int_or_string'], value['date_or_datetime'], value['null_or_fruits_or_days_of_week']) + self._write(stream, value['null_or_int_or_string'], value['date_or_datetime'], value['null_or_fruits_or_days_of_week'], value['record_or_int']) def read(self, stream: _binary.CodedInputStream) -> RecordWithUnions: field_values = self._read(stream) - return RecordWithUnions(null_or_int_or_string=field_values[0], date_or_datetime=field_values[1], null_or_fruits_or_days_of_week=field_values[2]) + return RecordWithUnions(null_or_int_or_string=field_values[0], date_or_datetime=field_values[1], null_or_fruits_or_days_of_week=field_values[2], record_or_int=field_values[3]) class GenericRecordWithComputedFieldsSerializer(typing.Generic[T0, T0_NP, T1, T1_NP], _binary.RecordSerializer[GenericRecordWithComputedFields[T0, T1]]): diff --git a/python/test_model/basic_types/ndjson.py b/python/test_model/basic_types/ndjson.py index 759923fe..08701d0d 100644 --- a/python/test_model/basic_types/ndjson.py +++ b/python/test_model/basic_types/ndjson.py @@ -45,15 +45,55 @@ } text_format_value_to_name_map = {v: n for n, v in text_format_name_to_value_map.items()} +class RecordWithStringConverter(_ndjson.JsonConverter[RecordWithString, np.void]): + def __init__(self) -> None: + self._i_converter = _ndjson.string_converter + super().__init__(np.dtype([ + ("i", self._i_converter.overall_dtype()), + ])) + + def to_json(self, value: RecordWithString) -> object: + if not isinstance(value, RecordWithString): # pyright: ignore [reportUnnecessaryIsInstance] + raise TypeError("Expected 'RecordWithString' instance") + json_object = {} + + json_object["i"] = self._i_converter.to_json(value.i) + return json_object + + def numpy_to_json(self, value: np.void) -> object: + if not isinstance(value, np.void): # pyright: ignore [reportUnnecessaryIsInstance] + raise TypeError("Expected 'np.void' instance") + json_object = {} + + json_object["i"] = self._i_converter.numpy_to_json(value["i"]) + return json_object + + def from_json(self, json_object: object) -> RecordWithString: + if not isinstance(json_object, dict): + raise TypeError("Expected 'dict' instance") + return RecordWithString( + i=self._i_converter.from_json(json_object["i"],), + ) + + def from_json_to_numpy(self, json_object: object) -> np.void: + if not isinstance(json_object, dict): + raise TypeError("Expected 'dict' instance") + return ( + self._i_converter.from_json_to_numpy(json_object["i"]), + ) # type:ignore + + class RecordWithUnionsConverter(_ndjson.JsonConverter[RecordWithUnions, np.void]): def __init__(self) -> None: self._null_or_int_or_string_converter = _ndjson.UnionConverter(Int32OrString, [None, (Int32OrString.Int32, _ndjson.int32_converter, [int, float]), (Int32OrString.String, _ndjson.string_converter, [str])], True) self._date_or_datetime_converter = _ndjson.UnionConverter(TimeOrDatetime, [(TimeOrDatetime.Time, _ndjson.time_converter, [int, float]), (TimeOrDatetime.Datetime, _ndjson.datetime_converter, [int, float])], False) self._null_or_fruits_or_days_of_week_converter = _ndjson.UnionConverter(GenericNullableUnion2, [None, (GenericNullableUnion2.T1, _ndjson.EnumConverter(Fruits, np.int32, fruits_name_to_value_map, fruits_value_to_name_map), [int, float, str]), (GenericNullableUnion2.T2, _ndjson.FlagsConverter(DaysOfWeek, np.int32, days_of_week_name_to_value_map, days_of_week_value_to_name_map), [list])], True) + self._record_or_int_converter = _ndjson.UnionConverter(RecordWithStringOrInt32, [(RecordWithStringOrInt32.RecordWithString, RecordWithStringConverter(), [dict]), (RecordWithStringOrInt32.Int32, _ndjson.int32_converter, [int, float])], True) super().__init__(np.dtype([ ("null_or_int_or_string", self._null_or_int_or_string_converter.overall_dtype()), ("date_or_datetime", self._date_or_datetime_converter.overall_dtype()), ("null_or_fruits_or_days_of_week", self._null_or_fruits_or_days_of_week_converter.overall_dtype()), + ("record_or_int", self._record_or_int_converter.overall_dtype()), ])) def to_json(self, value: RecordWithUnions) -> object: @@ -66,6 +106,7 @@ def to_json(self, value: RecordWithUnions) -> object: json_object["dateOrDatetime"] = self._date_or_datetime_converter.to_json(value.date_or_datetime) if value.null_or_fruits_or_days_of_week is not None: json_object["nullOrFruitsOrDaysOfWeek"] = self._null_or_fruits_or_days_of_week_converter.to_json(value.null_or_fruits_or_days_of_week) + json_object["recordOrInt"] = self._record_or_int_converter.to_json(value.record_or_int) return json_object def numpy_to_json(self, value: np.void) -> object: @@ -78,6 +119,7 @@ def numpy_to_json(self, value: np.void) -> object: json_object["dateOrDatetime"] = self._date_or_datetime_converter.numpy_to_json(value["date_or_datetime"]) if (field_val := value["null_or_fruits_or_days_of_week"]) is not None: json_object["nullOrFruitsOrDaysOfWeek"] = self._null_or_fruits_or_days_of_week_converter.numpy_to_json(field_val) + json_object["recordOrInt"] = self._record_or_int_converter.numpy_to_json(value["record_or_int"]) return json_object def from_json(self, json_object: object) -> RecordWithUnions: @@ -87,6 +129,7 @@ def from_json(self, json_object: object) -> RecordWithUnions: null_or_int_or_string=self._null_or_int_or_string_converter.from_json(json_object.get("nullOrIntOrString")), date_or_datetime=self._date_or_datetime_converter.from_json(json_object["dateOrDatetime"],), null_or_fruits_or_days_of_week=self._null_or_fruits_or_days_of_week_converter.from_json(json_object.get("nullOrFruitsOrDaysOfWeek")), + record_or_int=self._record_or_int_converter.from_json(json_object["recordOrInt"],), ) def from_json_to_numpy(self, json_object: object) -> np.void: @@ -96,6 +139,7 @@ def from_json_to_numpy(self, json_object: object) -> np.void: self._null_or_int_or_string_converter.from_json_to_numpy(json_object.get("nullOrIntOrString")), self._date_or_datetime_converter.from_json_to_numpy(json_object["dateOrDatetime"]), self._null_or_fruits_or_days_of_week_converter.from_json_to_numpy(json_object.get("nullOrFruitsOrDaysOfWeek")), + self._record_or_int_converter.from_json_to_numpy(json_object["recordOrInt"]), ) # type:ignore diff --git a/python/test_model/basic_types/types.py b/python/test_model/basic_types/types.py index 623e75fa..56edde0b 100644 --- a/python/test_model/basic_types/types.py +++ b/python/test_model/basic_types/types.py @@ -99,6 +99,27 @@ class GenericNullableUnion2UnionCase(GenericNullableUnion2[T1, T2], yardl.UnionC GenericVector = list[T] +class RecordWithString: + i: str + + def __init__(self, *, + i: str = "", + ): + self.i = i + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, RecordWithString) + and self.i == other.i + ) + + def __str__(self) -> str: + return f"RecordWithString(i={self.i})" + + def __repr__(self) -> str: + return f"RecordWithString(i={repr(self.i)})" + + class Int32OrString: Int32: typing.ClassVar[type["Int32OrStringUnionCase[yardl.Int32]"]] String: typing.ClassVar[type["Int32OrStringUnionCase[str]"]] @@ -121,19 +142,33 @@ class TimeOrDatetimeUnionCase(TimeOrDatetime, yardl.UnionCase[_T]): TimeOrDatetime.Datetime = type("TimeOrDatetime.Datetime", (TimeOrDatetimeUnionCase,), {"index": 1, "tag": "datetime"}) del TimeOrDatetimeUnionCase +class RecordWithStringOrInt32: + RecordWithString: typing.ClassVar[type["RecordWithStringOrInt32UnionCase[RecordWithString]"]] + Int32: typing.ClassVar[type["RecordWithStringOrInt32UnionCase[yardl.Int32]"]] + +class RecordWithStringOrInt32UnionCase(RecordWithStringOrInt32, yardl.UnionCase[_T]): + pass + +RecordWithStringOrInt32.RecordWithString = type("RecordWithStringOrInt32.RecordWithString", (RecordWithStringOrInt32UnionCase,), {"index": 0, "tag": "RecordWithString"}) +RecordWithStringOrInt32.Int32 = type("RecordWithStringOrInt32.Int32", (RecordWithStringOrInt32UnionCase,), {"index": 1, "tag": "int32"}) +del RecordWithStringOrInt32UnionCase + class RecordWithUnions: null_or_int_or_string: typing.Optional[Int32OrString] date_or_datetime: TimeOrDatetime null_or_fruits_or_days_of_week: typing.Optional[GenericNullableUnion2[Fruits, DaysOfWeek]] + record_or_int: RecordWithStringOrInt32 def __init__(self, *, null_or_int_or_string: typing.Optional[Int32OrString] = None, date_or_datetime: TimeOrDatetime = TimeOrDatetime.Time(yardl.Time()), null_or_fruits_or_days_of_week: typing.Optional[GenericNullableUnion2[Fruits, DaysOfWeek]] = None, + record_or_int: typing.Optional[RecordWithStringOrInt32] = None, ): self.null_or_int_or_string = null_or_int_or_string self.date_or_datetime = date_or_datetime self.null_or_fruits_or_days_of_week = null_or_fruits_or_days_of_week + self.record_or_int = record_or_int if record_or_int is not None else RecordWithStringOrInt32.RecordWithString(RecordWithString()) def __eq__(self, other: object) -> bool: return ( @@ -141,13 +176,14 @@ def __eq__(self, other: object) -> bool: and self.null_or_int_or_string == other.null_or_int_or_string and self.date_or_datetime == other.date_or_datetime and self.null_or_fruits_or_days_of_week == other.null_or_fruits_or_days_of_week + and self.record_or_int == other.record_or_int ) def __str__(self) -> str: - return f"RecordWithUnions(nullOrIntOrString={self.null_or_int_or_string}, dateOrDatetime={self.date_or_datetime}, nullOrFruitsOrDaysOfWeek={self.null_or_fruits_or_days_of_week})" + return f"RecordWithUnions(nullOrIntOrString={self.null_or_int_or_string}, dateOrDatetime={self.date_or_datetime}, nullOrFruitsOrDaysOfWeek={self.null_or_fruits_or_days_of_week}, recordOrInt={self.record_or_int})" def __repr__(self) -> str: - return f"RecordWithUnions(nullOrIntOrString={repr(self.null_or_int_or_string)}, dateOrDatetime={repr(self.date_or_datetime)}, nullOrFruitsOrDaysOfWeek={repr(self.null_or_fruits_or_days_of_week)})" + return f"RecordWithUnions(nullOrIntOrString={repr(self.null_or_int_or_string)}, dateOrDatetime={repr(self.date_or_datetime)}, nullOrFruitsOrDaysOfWeek={repr(self.null_or_fruits_or_days_of_week)}, recordOrInt={repr(self.record_or_int)})" class T0OrT1(typing.Generic[T0, T1]): @@ -202,13 +238,17 @@ def _mk_get_dtype(): dtype_map.setdefault(GenericUnion2, lambda type_args: np.dtype(np.object_)) dtype_map.setdefault(GenericNullableUnion2, lambda type_args: np.dtype(np.object_)) dtype_map.setdefault(GenericNullableUnion2, lambda type_args: np.dtype(np.object_)) - dtype_map.setdefault(RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_)), ('null_or_fruits_or_days_of_week', np.dtype(np.object_))], align=True)) + dtype_map.setdefault(RecordWithString, np.dtype([('i', np.dtype(np.object_))], align=True)) + dtype_map.setdefault(RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_)), ('null_or_fruits_or_days_of_week', np.dtype(np.object_)), ('record_or_int', np.dtype(np.object_))], align=True)) dtype_map.setdefault(Int32OrString, np.dtype(np.object_)) dtype_map.setdefault(Int32OrString.Int32, np.dtype(np.int32)) dtype_map.setdefault(Int32OrString.String, np.dtype(np.object_)) dtype_map.setdefault(TimeOrDatetime, np.dtype(np.object_)) dtype_map.setdefault(TimeOrDatetime.Time, np.dtype(np.timedelta64)) dtype_map.setdefault(TimeOrDatetime.Datetime, np.dtype(np.datetime64)) + dtype_map.setdefault(RecordWithStringOrInt32, np.dtype(np.object_)) + dtype_map.setdefault(RecordWithStringOrInt32.RecordWithString, get_dtype(RecordWithString)) + dtype_map.setdefault(RecordWithStringOrInt32.Int32, np.dtype(np.int32)) dtype_map.setdefault(GenericRecordWithComputedFields, lambda type_args: np.dtype([('f1', np.dtype(np.object_))], align=True)) dtype_map.setdefault(T0OrT1, np.dtype(np.object_)) diff --git a/python/test_model/protocols.py b/python/test_model/protocols.py index d58b11a8..92d0a418 100644 --- a/python/test_model/protocols.py +++ b/python/test_model/protocols.py @@ -4355,7 +4355,7 @@ class UnionsWriterBase(abc.ABC): def __init__(self) -> None: self._state = 0 - schema = r"""{"protocol":{"name":"Unions","sequence":[{"name":"intOrSimpleRecord","type":[{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"intOrRecordWithVlens","type":[{"tag":"int32","type":"int32"},{"tag":"RecordWithVlens","type":"TestModel.RecordWithVlens"}]},{"name":"monosotateOrIntOrSimpleRecord","type":[null,{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"recordWithUnions","type":"BasicTypes.RecordWithUnions"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"GenericNullableUnion2","typeParameters":["T1","T2"],"type":[null,{"tag":"T1","type":"T1"},{"tag":"T2","type":"T2"}]},{"name":"RecordWithUnions","fields":[{"name":"nullOrIntOrString","type":[null,{"tag":"int32","type":"int32"},{"tag":"string","type":"string"}]},{"name":"dateOrDatetime","type":[{"tag":"time","type":"time"},{"tag":"datetime","type":"datetime"}]},{"name":"nullOrFruitsOrDaysOfWeek","type":{"name":"BasicTypes.GenericNullableUnion2","typeArguments":["BasicTypes.Fruits","BasicTypes.DaysOfWeek"]}}]},{"name":"RecordWithVlens","fields":[{"name":"a","type":{"vector":{"items":"TestModel.SimpleRecord"}}},{"name":"b","type":"int32"},{"name":"c","type":"int32"}]},{"name":"SimpleRecord","fields":[{"name":"x","type":"int32"},{"name":"y","type":"int32"},{"name":"z","type":"int32"}]}]}""" + schema = r"""{"protocol":{"name":"Unions","sequence":[{"name":"intOrSimpleRecord","type":[{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"intOrRecordWithVlens","type":[{"tag":"int32","type":"int32"},{"tag":"RecordWithVlens","type":"TestModel.RecordWithVlens"}]},{"name":"monosotateOrIntOrSimpleRecord","type":[null,{"tag":"int32","type":"int32"},{"tag":"SimpleRecord","type":"TestModel.SimpleRecord"}]},{"name":"recordWithUnions","type":"BasicTypes.RecordWithUnions"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"GenericNullableUnion2","typeParameters":["T1","T2"],"type":[null,{"tag":"T1","type":"T1"},{"tag":"T2","type":"T2"}]},{"name":"RecordWithString","fields":[{"name":"i","type":"string"}]},{"name":"RecordWithUnions","fields":[{"name":"nullOrIntOrString","type":[null,{"tag":"int32","type":"int32"},{"tag":"string","type":"string"}]},{"name":"dateOrDatetime","type":[{"tag":"time","type":"time"},{"tag":"datetime","type":"datetime"}]},{"name":"nullOrFruitsOrDaysOfWeek","type":{"name":"BasicTypes.GenericNullableUnion2","typeArguments":["BasicTypes.Fruits","BasicTypes.DaysOfWeek"]}},{"name":"recordOrInt","type":[{"tag":"RecordWithString","type":"BasicTypes.RecordWithString"},{"tag":"int32","type":"int32"}]}]},{"name":"RecordWithVlens","fields":[{"name":"a","type":{"vector":{"items":"TestModel.SimpleRecord"}}},{"name":"b","type":"int32"},{"name":"c","type":"int32"}]},{"name":"SimpleRecord","fields":[{"name":"x","type":"int32"},{"name":"y","type":"int32"},{"name":"z","type":"int32"}]}]}""" def close(self) -> None: self._close() diff --git a/python/test_model/types.py b/python/test_model/types.py index e67453a0..3439b641 100644 --- a/python/test_model/types.py +++ b/python/test_model/types.py @@ -2094,13 +2094,17 @@ def _mk_get_dtype(): dtype_map.setdefault(basic_types.GenericUnion2, lambda type_args: np.dtype(np.object_)) dtype_map.setdefault(basic_types.GenericNullableUnion2, lambda type_args: np.dtype(np.object_)) dtype_map.setdefault(basic_types.GenericNullableUnion2, lambda type_args: np.dtype(np.object_)) - dtype_map.setdefault(basic_types.RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_)), ('null_or_fruits_or_days_of_week', np.dtype(np.object_))], align=True)) + dtype_map.setdefault(basic_types.RecordWithString, np.dtype([('i', np.dtype(np.object_))], align=True)) + dtype_map.setdefault(basic_types.RecordWithUnions, np.dtype([('null_or_int_or_string', np.dtype(np.object_)), ('date_or_datetime', np.dtype(np.object_)), ('null_or_fruits_or_days_of_week', np.dtype(np.object_)), ('record_or_int', np.dtype(np.object_))], align=True)) dtype_map.setdefault(basic_types.Int32OrString, np.dtype(np.object_)) dtype_map.setdefault(basic_types.Int32OrString.Int32, np.dtype(np.int32)) dtype_map.setdefault(basic_types.Int32OrString.String, np.dtype(np.object_)) dtype_map.setdefault(basic_types.TimeOrDatetime, np.dtype(np.object_)) dtype_map.setdefault(basic_types.TimeOrDatetime.Time, np.dtype(np.timedelta64)) dtype_map.setdefault(basic_types.TimeOrDatetime.Datetime, np.dtype(np.datetime64)) + dtype_map.setdefault(basic_types.RecordWithStringOrInt32, np.dtype(np.object_)) + dtype_map.setdefault(basic_types.RecordWithStringOrInt32.RecordWithString, get_dtype(basic_types.RecordWithString)) + dtype_map.setdefault(basic_types.RecordWithStringOrInt32.Int32, np.dtype(np.int32)) dtype_map.setdefault(basic_types.GenericRecordWithComputedFields, lambda type_args: np.dtype([('f1', np.dtype(np.object_))], align=True)) dtype_map.setdefault(basic_types.T0OrT1, np.dtype(np.object_)) dtype_map.setdefault(SmallBenchmarkRecord, np.dtype([('a', np.dtype(np.float64)), ('b', np.dtype(np.float32)), ('c', np.dtype(np.float32))], align=True)) diff --git a/tooling/internal/matlab/types/types.go b/tooling/internal/matlab/types/types.go index 70f0a911..17145bea 100644 --- a/tooling/internal/matlab/types/types.go +++ b/tooling/internal/matlab/types/types.go @@ -925,11 +925,6 @@ func typeDefault(t dsl.Type, contextNamespace string, namedType string, st dsl.S case defaultValueKindImmutable: return fmt.Sprintf(`%s(%s)`, unionCaseConstructor, defaultExpression), defaultKind case defaultValueKindMutable: - if t, ok := dsl.GetUnderlyingType(t.Cases[0].Type).(*dsl.SimpleType); ok { - if _, ok := t.ResolvedDefinition.(*dsl.RecordDefinition); ok { - return fmt.Sprintf(`%s(%s())`, unionCaseConstructor, defaultExpression), defaultValueKindMutable - } - } return fmt.Sprintf(`%s(%s)`, unionCaseConstructor, defaultExpression), defaultValueKindMutable } diff --git a/tooling/internal/python/types/types.go b/tooling/internal/python/types/types.go index 18b8096e..416d5573 100644 --- a/tooling/internal/python/types/types.go +++ b/tooling/internal/python/types/types.go @@ -852,11 +852,6 @@ func typeDefault(t dsl.Type, contextNamespace string, namedType string, st dsl.S case defaultValueKindImmutable: return fmt.Sprintf(`%s(%s)`, unionCaseConstructor, defaultExpression), defaultKind case defaultValueKindMutable: - if t, ok := dsl.GetUnderlyingType(t.Cases[0].Type).(*dsl.SimpleType); ok { - if _, ok := t.ResolvedDefinition.(*dsl.RecordDefinition); ok { - return fmt.Sprintf(`%s(%s())`, unionCaseConstructor, defaultExpression), defaultValueKindMutable - } - } return fmt.Sprintf(`%s(%s)`, unionCaseConstructor, defaultExpression), defaultValueKindMutable } From 8329173c88360efbab0c2d7fe723afa177800d4d Mon Sep 17 00:00:00 2001 From: Joe Naegele Date: Thu, 3 Oct 2024 14:40:01 +0000 Subject: [PATCH 2/2] Collapse duplicate switch cases --- tooling/internal/matlab/types/types.go | 6 +----- tooling/internal/python/types/types.go | 5 +---- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/tooling/internal/matlab/types/types.go b/tooling/internal/matlab/types/types.go index 17145bea..182105c9 100644 --- a/tooling/internal/matlab/types/types.go +++ b/tooling/internal/matlab/types/types.go @@ -922,14 +922,10 @@ func typeDefault(t dsl.Type, contextNamespace string, namedType string, st dsl.S switch defaultKind { case defaultValueKindNone: return "", defaultKind - case defaultValueKindImmutable: + default: return fmt.Sprintf(`%s(%s)`, unionCaseConstructor, defaultExpression), defaultKind - case defaultValueKindMutable: - return fmt.Sprintf(`%s(%s)`, unionCaseConstructor, defaultExpression), defaultValueKindMutable } - return fmt.Sprintf(`("%s", %s)`, t.Cases[0].Tag, defaultExpression), defaultValueKindImmutable - case *dsl.Vector: scalar := t.ToScalar() if dsl.TypeContainsGenericTypeParameter(scalar) { diff --git a/tooling/internal/python/types/types.go b/tooling/internal/python/types/types.go index 416d5573..42ff3a3b 100644 --- a/tooling/internal/python/types/types.go +++ b/tooling/internal/python/types/types.go @@ -849,13 +849,10 @@ func typeDefault(t dsl.Type, contextNamespace string, namedType string, st dsl.S switch defaultKind { case defaultValueKindNone: return "", defaultKind - case defaultValueKindImmutable: + default: return fmt.Sprintf(`%s(%s)`, unionCaseConstructor, defaultExpression), defaultKind - case defaultValueKindMutable: - return fmt.Sprintf(`%s(%s)`, unionCaseConstructor, defaultExpression), defaultValueKindMutable } - return fmt.Sprintf(`("%s", %s)`, t.Cases[0].Tag, defaultExpression), defaultValueKindImmutable case *dsl.Vector: if td.Length == nil { return "[]", defaultValueKindMutable