diff --git a/cpp/test/generated/binary/protocols.cc b/cpp/test/generated/binary/protocols.cc index 75a664d2..a96689fe 100644 --- a/cpp/test/generated/binary/protocols.cc +++ b/cpp/test/generated/binary/protocols.cc @@ -3534,6 +3534,34 @@ void MultiDArraysReader::CloseImpl() { stream_.VerifyFinished(); } +void ComplexArraysWriter::WriteFloatsImpl(yardl::DynamicNDArray> const& value) { + yardl::binary::WriteDynamicNDArray, yardl::binary::WriteFloatingPoint>(stream_, value); +} + +void ComplexArraysWriter::WriteDoublesImpl(yardl::NDArray, 2> const& value) { + yardl::binary::WriteNDArray, yardl::binary::WriteFloatingPoint, 2>(stream_, value); +} + +void ComplexArraysWriter::Flush() { + stream_.Flush(); +} + +void ComplexArraysWriter::CloseImpl() { + stream_.Flush(); +} + +void ComplexArraysReader::ReadFloatsImpl(yardl::DynamicNDArray>& value) { + yardl::binary::ReadDynamicNDArray, yardl::binary::ReadFloatingPoint>(stream_, value); +} + +void ComplexArraysReader::ReadDoublesImpl(yardl::NDArray, 2>& value) { + yardl::binary::ReadNDArray, yardl::binary::ReadFloatingPoint, 2>(stream_, value); +} + +void ComplexArraysReader::CloseImpl() { + stream_.VerifyFinished(); +} + void MapsWriter::WriteStringToIntImpl(std::unordered_map const& value) { yardl::binary::WriteMap(stream_, value); } diff --git a/cpp/test/generated/binary/protocols.h b/cpp/test/generated/binary/protocols.h index e088180b..f11a7b00 100644 --- a/cpp/test/generated/binary/protocols.h +++ b/cpp/test/generated/binary/protocols.h @@ -905,6 +905,44 @@ class MultiDArraysReader : public test_model::MultiDArraysReaderBase, yardl::bin size_t current_block_remaining_ = 0; }; +// Binary writer for the ComplexArrays protocol. +class ComplexArraysWriter : public test_model::ComplexArraysWriterBase, yardl::binary::BinaryWriter { + public: + ComplexArraysWriter(std::ostream& stream, Version version = Version::Current) + : yardl::binary::BinaryWriter(stream, test_model::ComplexArraysWriterBase::SchemaFromVersion(version)), version_(version) {} + + ComplexArraysWriter(std::string file_name, Version version = Version::Current) + : yardl::binary::BinaryWriter(file_name, test_model::ComplexArraysWriterBase::SchemaFromVersion(version)), version_(version) {} + + void Flush() override; + + protected: + void WriteFloatsImpl(yardl::DynamicNDArray> const& value) override; + void WriteDoublesImpl(yardl::NDArray, 2> const& value) override; + void CloseImpl() override; + + Version version_; +}; + +// Binary reader for the ComplexArrays protocol. +class ComplexArraysReader : public test_model::ComplexArraysReaderBase, yardl::binary::BinaryReader { + public: + ComplexArraysReader(std::istream& stream) + : yardl::binary::BinaryReader(stream), version_(test_model::ComplexArraysReaderBase::VersionFromSchema(schema_read_)) {} + + ComplexArraysReader(std::string file_name) + : yardl::binary::BinaryReader(file_name), version_(test_model::ComplexArraysReaderBase::VersionFromSchema(schema_read_)) {} + + Version GetVersion() { return version_; } + + protected: + void ReadFloatsImpl(yardl::DynamicNDArray>& value) override; + void ReadDoublesImpl(yardl::NDArray, 2>& value) override; + void CloseImpl() override; + + Version version_; +}; + // Binary writer for the Maps protocol. class MapsWriter : public test_model::MapsWriterBase, yardl::binary::BinaryWriter { public: diff --git a/cpp/test/generated/factories.cc b/cpp/test/generated/factories.cc index 2f94b6ee..3f038ae1 100644 --- a/cpp/test/generated/factories.cc +++ b/cpp/test/generated/factories.cc @@ -595,6 +595,34 @@ std::unique_ptr CreateReader +std::unique_ptr CreateWriter(Format format, std::string const& filename) { + switch (format) { + case Format::kHdf5: + return std::make_unique(filename); + case Format::kBinary: + return std::make_unique(filename); + case Format::kNDJson: + return std::make_unique(filename); + default: + throw std::runtime_error("Unknown format"); + } +} + +template<> +std::unique_ptr CreateReader(Format format, std::string const& filename) { + switch (format) { + case Format::kHdf5: + return std::make_unique(filename); + case Format::kBinary: + return std::make_unique(filename); + case Format::kNDJson: + return std::make_unique(filename); + default: + throw std::runtime_error("Unknown format"); + } +} + template<> std::unique_ptr CreateWriter(Format format, std::string const& filename) { switch (format) { diff --git a/cpp/test/generated/hdf5/protocols.cc b/cpp/test/generated/hdf5/protocols.cc index 6723854c..9fc67f52 100644 --- a/cpp/test/generated/hdf5/protocols.cc +++ b/cpp/test/generated/hdf5/protocols.cc @@ -2699,6 +2699,30 @@ bool MultiDArraysReader::ReadFramesImpl(std::vector> const& value) { + yardl::hdf5::WriteScalarDataset, std::complex>, yardl::DynamicNDArray>>(group_, "floats", yardl::hdf5::DynamicNDArrayDdl, std::complex>(yardl::hdf5::ComplexTypeDdl()), value); +} + +void ComplexArraysWriter::WriteDoublesImpl(yardl::NDArray, 2> const& value) { + yardl::hdf5::WriteScalarDataset, std::complex, 2>, yardl::NDArray, 2>>(group_, "doubles", yardl::hdf5::NDArrayDdl, std::complex, 2>(yardl::hdf5::ComplexTypeDdl()), value); +} + +ComplexArraysReader::ComplexArraysReader(std::string path) + : yardl::hdf5::Hdf5Reader::Hdf5Reader(path, "ComplexArrays", schema_) { +} + +void ComplexArraysReader::ReadFloatsImpl(yardl::DynamicNDArray>& value) { + yardl::hdf5::ReadScalarDataset, std::complex>, yardl::DynamicNDArray>>(group_, "floats", yardl::hdf5::DynamicNDArrayDdl, std::complex>(yardl::hdf5::ComplexTypeDdl()), value); +} + +void ComplexArraysReader::ReadDoublesImpl(yardl::NDArray, 2>& value) { + yardl::hdf5::ReadScalarDataset, std::complex, 2>, yardl::NDArray, 2>>(group_, "doubles", yardl::hdf5::NDArrayDdl, std::complex, 2>(yardl::hdf5::ComplexTypeDdl()), value); +} + MapsWriter::MapsWriter(std::string path) : yardl::hdf5::Hdf5Writer::Hdf5Writer(path, "Maps", schema_) { } diff --git a/cpp/test/generated/hdf5/protocols.h b/cpp/test/generated/hdf5/protocols.h index 29abd127..ab19bb93 100644 --- a/cpp/test/generated/hdf5/protocols.h +++ b/cpp/test/generated/hdf5/protocols.h @@ -700,6 +700,31 @@ class MultiDArraysReader : public test_model::MultiDArraysReaderBase, public yar std::unique_ptr frames_dataset_state_; }; +// HDF5 writer for the ComplexArrays protocol. +class ComplexArraysWriter : public test_model::ComplexArraysWriterBase, public yardl::hdf5::Hdf5Writer { + public: + ComplexArraysWriter(std::string path); + + protected: + void WriteFloatsImpl(yardl::DynamicNDArray> const& value) override; + + void WriteDoublesImpl(yardl::NDArray, 2> const& value) override; + + private: +}; + +// HDF5 reader for the ComplexArrays protocol. +class ComplexArraysReader : public test_model::ComplexArraysReaderBase, public yardl::hdf5::Hdf5Reader { + public: + ComplexArraysReader(std::string path); + + void ReadFloatsImpl(yardl::DynamicNDArray>& value) override; + + void ReadDoublesImpl(yardl::NDArray, 2>& value) override; + + private: +}; + // HDF5 writer for the Maps protocol. class MapsWriter : public test_model::MapsWriterBase, public yardl::hdf5::Hdf5Writer { public: diff --git a/cpp/test/generated/mocks.cc b/cpp/test/generated/mocks.cc index 7234b5ed..3556f281 100644 --- a/cpp/test/generated/mocks.cc +++ b/cpp/test/generated/mocks.cc @@ -2446,6 +2446,87 @@ class TestMultiDArraysWriterBase : public MultiDArraysWriterBase { bool close_called_ = false; }; +class MockComplexArraysWriter : public ComplexArraysWriterBase { + public: + void WriteFloatsImpl (yardl::DynamicNDArray> const& value) override { + if (WriteFloatsImpl_expected_values_.empty()) { + throw std::runtime_error("Unexpected call to WriteFloatsImpl"); + } + if (WriteFloatsImpl_expected_values_.front() != value) { + throw std::runtime_error("Unexpected argument value for call to WriteFloatsImpl"); + } + WriteFloatsImpl_expected_values_.pop(); + } + + std::queue>> WriteFloatsImpl_expected_values_; + + void ExpectWriteFloatsImpl (yardl::DynamicNDArray> const& value) { + WriteFloatsImpl_expected_values_.push(value); + } + + void WriteDoublesImpl (yardl::NDArray, 2> const& value) override { + if (WriteDoublesImpl_expected_values_.empty()) { + throw std::runtime_error("Unexpected call to WriteDoublesImpl"); + } + if (WriteDoublesImpl_expected_values_.front() != value) { + throw std::runtime_error("Unexpected argument value for call to WriteDoublesImpl"); + } + WriteDoublesImpl_expected_values_.pop(); + } + + std::queue, 2>> WriteDoublesImpl_expected_values_; + + void ExpectWriteDoublesImpl (yardl::NDArray, 2> const& value) { + WriteDoublesImpl_expected_values_.push(value); + } + + void Verify() { + if (!WriteFloatsImpl_expected_values_.empty()) { + throw std::runtime_error("Expected call to WriteFloatsImpl was not received"); + } + if (!WriteDoublesImpl_expected_values_.empty()) { + throw std::runtime_error("Expected call to WriteDoublesImpl was not received"); + } + } +}; + +class TestComplexArraysWriterBase : public ComplexArraysWriterBase { + public: + TestComplexArraysWriterBase(std::unique_ptr writer, std::function()> create_reader) : writer_(std::move(writer)), create_reader_(create_reader) { + } + + ~TestComplexArraysWriterBase() { + if (!close_called_ && !std::uncaught_exceptions()) { + ADD_FAILURE() << "Close() needs to be called on 'TestComplexArraysWriterBase' to verify mocks"; + } + } + + protected: + void WriteFloatsImpl(yardl::DynamicNDArray> const& value) override { + writer_->WriteFloats(value); + mock_writer_.ExpectWriteFloatsImpl(value); + } + + void WriteDoublesImpl(yardl::NDArray, 2> const& value) override { + writer_->WriteDoubles(value); + mock_writer_.ExpectWriteDoublesImpl(value); + } + + void CloseImpl() override { + close_called_ = true; + writer_->Close(); + std::unique_ptr reader = create_reader_(); + reader->CopyTo(mock_writer_); + mock_writer_.Verify(); + } + + private: + std::unique_ptr writer_; + std::function()> create_reader_; + MockComplexArraysWriter mock_writer_; + bool close_called_ = false; +}; + class MockMapsWriter : public MapsWriterBase { public: void WriteStringToIntImpl (std::unordered_map const& value) override { @@ -4411,6 +4492,14 @@ std::unique_ptr CreateValidatingWriter +std::unique_ptr CreateValidatingWriter(Format format, std::string const& filename) { + return std::make_unique( + CreateWriter(format, filename), + [format, filename](){ return CreateReader(format, filename);} + ); +} + template<> std::unique_ptr CreateValidatingWriter(Format format, std::string const& filename) { return std::make_unique( diff --git a/cpp/test/generated/model.json b/cpp/test/generated/model.json index 047e650a..02aa6a64 100644 --- a/cpp/test/generated/model.json +++ b/cpp/test/generated/model.json @@ -4780,6 +4780,28 @@ } ] }, + { + "name": "ComplexArrays", + "sequence": [ + { + "name": "floats", + "type": { + "array": { + "items": "complexfloat32" + } + } + }, + { + "name": "doubles", + "type": { + "array": { + "items": "complexfloat64", + "dimensions": 2 + } + } + } + ] + }, { "name": "Maps", "sequence": [ diff --git a/cpp/test/generated/ndjson/protocols.cc b/cpp/test/generated/ndjson/protocols.cc index 2d3ad6de..a0581bcf 100644 --- a/cpp/test/generated/ndjson/protocols.cc +++ b/cpp/test/generated/ndjson/protocols.cc @@ -3291,6 +3291,34 @@ void MultiDArraysReader::CloseImpl() { VerifyFinished(); } +void ComplexArraysWriter::WriteFloatsImpl(yardl::DynamicNDArray> const& value) { + ordered_json json_value = value; + yardl::ndjson::WriteProtocolValue(stream_, "floats", json_value);} + +void ComplexArraysWriter::WriteDoublesImpl(yardl::NDArray, 2> const& value) { + ordered_json json_value = value; + yardl::ndjson::WriteProtocolValue(stream_, "doubles", json_value);} + +void ComplexArraysWriter::Flush() { + stream_.flush(); +} + +void ComplexArraysWriter::CloseImpl() { + stream_.flush(); +} + +void ComplexArraysReader::ReadFloatsImpl(yardl::DynamicNDArray>& value) { + yardl::ndjson::ReadProtocolValue(stream_, line_, "floats", true, unused_step_, value); +} + +void ComplexArraysReader::ReadDoublesImpl(yardl::NDArray, 2>& value) { + yardl::ndjson::ReadProtocolValue(stream_, line_, "doubles", true, unused_step_, value); +} + +void ComplexArraysReader::CloseImpl() { + VerifyFinished(); +} + void MapsWriter::WriteStringToIntImpl(std::unordered_map const& value) { ordered_json json_value = value; yardl::ndjson::WriteProtocolValue(stream_, "stringToInt", json_value);} diff --git a/cpp/test/generated/ndjson/protocols.h b/cpp/test/generated/ndjson/protocols.h index d1b50b62..5a336a26 100644 --- a/cpp/test/generated/ndjson/protocols.h +++ b/cpp/test/generated/ndjson/protocols.h @@ -815,6 +815,42 @@ class MultiDArraysReader : public test_model::MultiDArraysReaderBase, yardl::ndj void CloseImpl() override; }; +// NDJSON writer for the ComplexArrays protocol. +class ComplexArraysWriter : public test_model::ComplexArraysWriterBase, yardl::ndjson::NDJsonWriter { + public: + ComplexArraysWriter(std::ostream& stream) + : yardl::ndjson::NDJsonWriter(stream, schema_) { + } + + ComplexArraysWriter(std::string file_name) + : yardl::ndjson::NDJsonWriter(file_name, schema_) { + } + + void Flush() override; + + protected: + void WriteFloatsImpl(yardl::DynamicNDArray> const& value) override; + void WriteDoublesImpl(yardl::NDArray, 2> const& value) override; + void CloseImpl() override; +}; + +// NDJSON reader for the ComplexArrays protocol. +class ComplexArraysReader : public test_model::ComplexArraysReaderBase, yardl::ndjson::NDJsonReader { + public: + ComplexArraysReader(std::istream& stream) + : yardl::ndjson::NDJsonReader(stream, schema_) { + } + + ComplexArraysReader(std::string file_name) + : yardl::ndjson::NDJsonReader(file_name, schema_) { + } + + protected: + void ReadFloatsImpl(yardl::DynamicNDArray>& value) override; + void ReadDoublesImpl(yardl::NDArray, 2>& value) override; + void CloseImpl() override; +}; + // NDJSON writer for the Maps protocol. class MapsWriter : public test_model::MapsWriterBase, yardl::ndjson::NDJsonWriter { public: diff --git a/cpp/test/generated/protocols.cc b/cpp/test/generated/protocols.cc index cff90358..2375176b 100644 --- a/cpp/test/generated/protocols.cc +++ b/cpp/test/generated/protocols.cc @@ -3875,6 +3875,122 @@ void MultiDArraysReaderBase::CopyTo(MultiDArraysWriterBase& writer, size_t image } } +namespace { +void ComplexArraysWriterBaseInvalidState(uint8_t attempted, [[maybe_unused]] bool end, uint8_t current) { + std::string expected_method; + switch (current) { + case 0: expected_method = "WriteFloats()"; break; + case 1: expected_method = "WriteDoubles()"; break; + } + std::string attempted_method; + switch (attempted) { + case 0: attempted_method = "WriteFloats()"; break; + case 1: attempted_method = "WriteDoubles()"; break; + case 2: attempted_method = "Close()"; break; + } + throw std::runtime_error("Expected call to " + expected_method + " but received call to " + attempted_method + " instead."); +} + +void ComplexArraysReaderBaseInvalidState(uint8_t attempted, uint8_t current) { + auto f = [](uint8_t i) -> std::string { + switch (i/2) { + case 0: return "ReadFloats()"; + case 1: return "ReadDoubles()"; + case 2: return "Close()"; + default: return ""; + } + }; + throw std::runtime_error("Expected call to " + f(current) + " but received call to " + f(attempted) + " instead."); +} + +} // namespace + +std::string ComplexArraysWriterBase::schema_ = R"({"protocol":{"name":"ComplexArrays","sequence":[{"name":"floats","type":{"array":{"items":"complexfloat32"}}},{"name":"doubles","type":{"array":{"items":"complexfloat64","dimensions":2}}}]},"types":null})"; + +std::vector ComplexArraysWriterBase::previous_schemas_ = { +}; + +std::string ComplexArraysWriterBase::SchemaFromVersion(Version version) { + switch (version) { + case Version::Current: return ComplexArraysWriterBase::schema_; break; + default: throw std::runtime_error("The version does not correspond to any schema supported by protocol ComplexArrays."); + } + +} +void ComplexArraysWriterBase::WriteFloats(yardl::DynamicNDArray> const& value) { + if (unlikely(state_ != 0)) { + ComplexArraysWriterBaseInvalidState(0, false, state_); + } + + WriteFloatsImpl(value); + state_ = 1; +} + +void ComplexArraysWriterBase::WriteDoubles(yardl::NDArray, 2> const& value) { + if (unlikely(state_ != 1)) { + ComplexArraysWriterBaseInvalidState(1, false, state_); + } + + WriteDoublesImpl(value); + state_ = 2; +} + +void ComplexArraysWriterBase::Close() { + if (unlikely(state_ != 2)) { + ComplexArraysWriterBaseInvalidState(2, false, state_); + } + + CloseImpl(); +} + +std::string ComplexArraysReaderBase::schema_ = ComplexArraysWriterBase::schema_; + +std::vector ComplexArraysReaderBase::previous_schemas_ = ComplexArraysWriterBase::previous_schemas_; + +Version ComplexArraysReaderBase::VersionFromSchema(std::string const& schema) { + if (schema == ComplexArraysWriterBase::schema_) { + return Version::Current; + } + throw std::runtime_error("The schema does not match any version supported by protocol ComplexArrays."); +} +void ComplexArraysReaderBase::ReadFloats(yardl::DynamicNDArray>& value) { + if (unlikely(state_ != 0)) { + ComplexArraysReaderBaseInvalidState(0, state_); + } + + ReadFloatsImpl(value); + state_ = 2; +} + +void ComplexArraysReaderBase::ReadDoubles(yardl::NDArray, 2>& value) { + if (unlikely(state_ != 2)) { + ComplexArraysReaderBaseInvalidState(2, state_); + } + + ReadDoublesImpl(value); + state_ = 4; +} + +void ComplexArraysReaderBase::Close() { + if (unlikely(state_ != 4)) { + ComplexArraysReaderBaseInvalidState(4, state_); + } + + CloseImpl(); +} +void ComplexArraysReaderBase::CopyTo(ComplexArraysWriterBase& writer) { + { + yardl::DynamicNDArray> value; + ReadFloats(value); + writer.WriteFloats(value); + } + { + yardl::NDArray, 2> value; + ReadDoubles(value); + writer.WriteDoubles(value); + } +} + namespace { void MapsWriterBaseInvalidState(uint8_t attempted, [[maybe_unused]] bool end, uint8_t current) { std::string expected_method; diff --git a/cpp/test/generated/protocols.h b/cpp/test/generated/protocols.h index f3decd7c..c8c3edb8 100644 --- a/cpp/test/generated/protocols.h +++ b/cpp/test/generated/protocols.h @@ -1655,6 +1655,70 @@ class MultiDArraysReaderBase { uint8_t state_ = 0; }; +// Abstract writer for the ComplexArrays protocol. +class ComplexArraysWriterBase { + public: + // Ordinal 0. + void WriteFloats(yardl::DynamicNDArray> const& value); + + // Ordinal 1. + void WriteDoubles(yardl::NDArray, 2> const& value); + + // Optionaly close this writer before destructing. Validates that all steps were completed. + void Close(); + + virtual ~ComplexArraysWriterBase() = default; + + // Flushes all buffered data. + virtual void Flush() {} + + protected: + virtual void WriteFloatsImpl(yardl::DynamicNDArray> const& value) = 0; + virtual void WriteDoublesImpl(yardl::NDArray, 2> const& value) = 0; + virtual void CloseImpl() {} + + static std::string schema_; + + static std::vector previous_schemas_; + + static std::string SchemaFromVersion(Version version); + + private: + uint8_t state_ = 0; + + friend class ComplexArraysReaderBase; +}; + +// Abstract reader for the ComplexArrays protocol. +class ComplexArraysReaderBase { + public: + // Ordinal 0. + void ReadFloats(yardl::DynamicNDArray>& value); + + // Ordinal 1. + void ReadDoubles(yardl::NDArray, 2>& value); + + // Optionaly close this writer before destructing. Validates that all steps were completely read. + void Close(); + + void CopyTo(ComplexArraysWriterBase& writer); + + virtual ~ComplexArraysReaderBase() = default; + + protected: + virtual void ReadFloatsImpl(yardl::DynamicNDArray>& value) = 0; + virtual void ReadDoublesImpl(yardl::NDArray, 2>& value) = 0; + virtual void CloseImpl() {} + static std::string schema_; + + static std::vector previous_schemas_; + + static Version VersionFromSchema(const std::string& schema); + + private: + uint8_t state_ = 0; +}; + // Abstract writer for the Maps protocol. class MapsWriterBase { public: diff --git a/cpp/test/generated/translator_impl.cc b/cpp/test/generated/translator_impl.cc index 8fab9a21..0fbdfd1f 100644 --- a/cpp/test/generated/translator_impl.cc +++ b/cpp/test/generated/translator_impl.cc @@ -249,6 +249,17 @@ void TranslateStream(std::string const& protocol_name, yardl::testing::Format in reader->CopyTo(*writer); return; } + if (protocol_name == "ComplexArrays") { + auto reader = input_format == yardl::testing::Format::kBinary + ? std::unique_ptr(new test_model::binary::ComplexArraysReader(input)) + : std::unique_ptr(new test_model::ndjson::ComplexArraysReader(input)); + + auto writer = output_format == yardl::testing::Format::kBinary + ? std::unique_ptr(new test_model::binary::ComplexArraysWriter(output)) + : std::unique_ptr(new test_model::ndjson::ComplexArraysWriter(output)); + reader->CopyTo(*writer); + return; + } if (protocol_name == "Maps") { auto reader = input_format == yardl::testing::Format::kBinary ? std::unique_ptr(new test_model::binary::MapsReader(input)) diff --git a/matlab/generated/+test_model/+binary/ComplexArraysReader.m b/matlab/generated/+test_model/+binary/ComplexArraysReader.m new file mode 100644 index 00000000..e61fb8a3 --- /dev/null +++ b/matlab/generated/+test_model/+binary/ComplexArraysReader.m @@ -0,0 +1,28 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef ComplexArraysReader < yardl.binary.BinaryProtocolReader & test_model.ComplexArraysReaderBase + % Binary reader for the ComplexArrays protocol + properties (Access=protected) + floats_serializer + doubles_serializer + end + + methods + function self = ComplexArraysReader(filename) + self@test_model.ComplexArraysReaderBase(); + self@yardl.binary.BinaryProtocolReader(filename, test_model.ComplexArraysReaderBase.schema); + self.floats_serializer = yardl.binary.DynamicNDArraySerializer(yardl.binary.Complexfloat32Serializer); + self.doubles_serializer = yardl.binary.NDArraySerializer(yardl.binary.Complexfloat64Serializer, 2); + end + end + + methods (Access=protected) + function value = read_floats_(self) + value = self.floats_serializer.read(self.stream_); + end + + function value = read_doubles_(self) + value = self.doubles_serializer.read(self.stream_); + end + end +end diff --git a/matlab/generated/+test_model/+binary/ComplexArraysWriter.m b/matlab/generated/+test_model/+binary/ComplexArraysWriter.m new file mode 100644 index 00000000..1f3ed4bd --- /dev/null +++ b/matlab/generated/+test_model/+binary/ComplexArraysWriter.m @@ -0,0 +1,28 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef ComplexArraysWriter < yardl.binary.BinaryProtocolWriter & test_model.ComplexArraysWriterBase + % Binary writer for the ComplexArrays protocol + properties (Access=protected) + floats_serializer + doubles_serializer + end + + methods + function self = ComplexArraysWriter(filename) + self@test_model.ComplexArraysWriterBase(); + self@yardl.binary.BinaryProtocolWriter(filename, test_model.ComplexArraysWriterBase.schema); + self.floats_serializer = yardl.binary.DynamicNDArraySerializer(yardl.binary.Complexfloat32Serializer); + self.doubles_serializer = yardl.binary.NDArraySerializer(yardl.binary.Complexfloat64Serializer, 2); + end + end + + methods (Access=protected) + function write_floats_(self, value) + self.floats_serializer.write(self.stream_, value); + end + + function write_doubles_(self, value) + self.doubles_serializer.write(self.stream_, value); + end + end +end diff --git a/matlab/generated/+test_model/+testing/MockComplexArraysWriter.m b/matlab/generated/+test_model/+testing/MockComplexArraysWriter.m new file mode 100644 index 00000000..6f8a67f7 --- /dev/null +++ b/matlab/generated/+test_model/+testing/MockComplexArraysWriter.m @@ -0,0 +1,49 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef MockComplexArraysWriter < matlab.mixin.Copyable & test_model.ComplexArraysWriterBase + properties + testCase_ + expected_floats + expected_doubles + end + + methods + function self = MockComplexArraysWriter(testCase) + self.testCase_ = testCase; + self.expected_floats = yardl.None; + self.expected_doubles = yardl.None; + end + + function expect_write_floats_(self, value) + self.expected_floats = yardl.Optional(value); + end + + function expect_write_doubles_(self, value) + self.expected_doubles = yardl.Optional(value); + end + + function verify(self) + self.testCase_.verifyEqual(self.expected_floats, yardl.None, "Expected call to write_floats_ was not received"); + self.testCase_.verifyEqual(self.expected_doubles, yardl.None, "Expected call to write_doubles_ was not received"); + end + end + + methods (Access=protected) + function write_floats_(self, value) + self.testCase_.verifyTrue(self.expected_floats.has_value(), "Unexpected call to write_floats_"); + self.testCase_.verifyEqual(value, self.expected_floats.value, "Unexpected argument value for call to write_floats_"); + self.expected_floats = yardl.None; + end + + function write_doubles_(self, value) + self.testCase_.verifyTrue(self.expected_doubles.has_value(), "Unexpected call to write_doubles_"); + self.testCase_.verifyEqual(value, self.expected_doubles.value, "Unexpected argument value for call to write_doubles_"); + self.expected_doubles = yardl.None; + end + + function close_(self) + end + function end_stream_(self) + end + end +end diff --git a/matlab/generated/+test_model/+testing/TestComplexArraysWriter.m b/matlab/generated/+test_model/+testing/TestComplexArraysWriter.m new file mode 100644 index 00000000..70e70fce --- /dev/null +++ b/matlab/generated/+test_model/+testing/TestComplexArraysWriter.m @@ -0,0 +1,66 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef TestComplexArraysWriter < test_model.ComplexArraysWriterBase + properties (Access = private) + writer_ + create_reader_ + mock_writer_ + close_called_ + filename_ + format_ + end + + methods + function self = TestComplexArraysWriter(testCase, format, create_writer, create_reader) + self.filename_ = tempname(); + self.format_ = format; + self.writer_ = create_writer(self.filename_); + self.create_reader_ = create_reader; + self.mock_writer_ = test_model.testing.MockComplexArraysWriter(testCase); + self.close_called_ = false; + end + + function delete(self) + delete(self.filename_); + if ~self.close_called_ + % ADD_FAILURE() << ...; + throw(yardl.RuntimeError("Close() must be called on 'TestComplexArraysWriter' to verify mocks")); + end + end + end + + methods (Access=protected) + function write_floats_(self, value) + self.writer_.write_floats(value); + self.mock_writer_.expect_write_floats_(value); + end + + function write_doubles_(self, value) + self.writer_.write_doubles(value); + self.mock_writer_.expect_write_doubles_(value); + end + + function close_(self) + self.close_called_ = true; + self.writer_.close(); + mock_copy = copy(self.mock_writer_); + + reader = self.create_reader_(self.filename_); + reader.copy_to(self.mock_writer_); + reader.close(); + self.mock_writer_.verify(); + self.mock_writer_.close(); + + translated = invoke_translator(self.filename_, self.format_, self.format_); + reader = self.create_reader_(translated); + reader.copy_to(mock_copy); + reader.close(); + mock_copy.verify(); + mock_copy.close(); + delete(translated); + end + + function end_stream_(self) + end + end +end diff --git a/matlab/generated/+test_model/ComplexArraysReaderBase.m b/matlab/generated/+test_model/ComplexArraysReaderBase.m new file mode 100644 index 00000000..578c20b9 --- /dev/null +++ b/matlab/generated/+test_model/ComplexArraysReaderBase.m @@ -0,0 +1,77 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +classdef ComplexArraysReaderBase < handle + properties (Access=protected) + state_ + end + + methods + function self = ComplexArraysReaderBase() + self.state_ = 0; + end + + function close(self) + self.close_(); + if self.state_ ~= 2 + expected_method = self.state_to_method_name_(self.state_); + throw(yardl.ProtocolError("Protocol reader closed before all data was consumed. Expected call to '%s'.", expected_method)); + end + end + + % Ordinal 0 + function value = read_floats(self) + if self.state_ ~= 0 + self.raise_unexpected_state_(0); + end + + value = self.read_floats_(); + self.state_ = 1; + end + + % Ordinal 1 + function value = read_doubles(self) + if self.state_ ~= 1 + self.raise_unexpected_state_(1); + end + + value = self.read_doubles_(); + self.state_ = 2; + end + + function copy_to(self, writer) + writer.write_floats(self.read_floats()); + writer.write_doubles(self.read_doubles()); + end + end + + methods (Static) + function res = schema() + res = test_model.ComplexArraysWriterBase.schema; + end + end + + methods (Abstract, Access=protected) + read_floats_(self) + read_doubles_(self) + + close_(self) + end + + methods (Access=private) + function raise_unexpected_state_(self, actual) + actual_method = self.state_to_method_name_(actual); + expected_method = self.state_to_method_name_(self.state_); + throw(yardl.ProtocolError("Expected call to '%s' but received call to '%s'.", expected_method, actual_method)); + end + + function name = state_to_method_name_(self, state) + if state == 0 + name = "read_floats"; + elseif state == 1 + name = "read_doubles"; + else + name = ""; + end + end + end +end diff --git a/matlab/generated/+test_model/ComplexArraysWriterBase.m b/matlab/generated/+test_model/ComplexArraysWriterBase.m new file mode 100644 index 00000000..456bb689 --- /dev/null +++ b/matlab/generated/+test_model/ComplexArraysWriterBase.m @@ -0,0 +1,74 @@ +% This file was generated by the "yardl" tool. DO NOT EDIT. + +% Abstract writer for protocol ComplexArrays +classdef (Abstract) ComplexArraysWriterBase < handle + properties (Access=protected) + state_ + end + + methods + function self = ComplexArraysWriterBase() + self.state_ = 0; + end + + function close(self) + self.close_(); + if self.state_ ~= 2 + expected_method = self.state_to_method_name_(self.state_); + throw(yardl.ProtocolError("Protocol writer closed before all steps were called. Expected call to '%s'.", expected_method)); + end + end + + % Ordinal 0 + function write_floats(self, value) + if self.state_ ~= 0 + self.raise_unexpected_state_(0); + end + + self.write_floats_(value); + self.state_ = 1; + end + + % Ordinal 1 + function write_doubles(self, value) + if self.state_ ~= 1 + self.raise_unexpected_state_(1); + end + + self.write_doubles_(value); + self.state_ = 2; + end + end + + methods (Static) + function res = schema() + res = string('{"protocol":{"name":"ComplexArrays","sequence":[{"name":"floats","type":{"array":{"items":"complexfloat32"}}},{"name":"doubles","type":{"array":{"items":"complexfloat64","dimensions":2}}}]},"types":null}'); + end + end + + methods (Abstract, Access=protected) + write_floats_(self, value) + write_doubles_(self, value) + + end_stream_(self) + close_(self) + end + + methods (Access=private) + function raise_unexpected_state_(self, actual) + expected_method = self.state_to_method_name_(self.state_); + actual_method = self.state_to_method_name_(actual); + throw(yardl.ProtocolError("Expected call to '%s' but received call to '%s'", expected_method, actual_method)); + end + + function name = state_to_method_name_(self, state) + if state == 0 + name = "write_floats"; + elseif state == 1 + name = "write_doubles"; + else + name = ''; + end + end + end +end diff --git a/models/test/unittests.yml b/models/test/unittests.yml index 0ecf5826..7bf86702 100644 --- a/models/test/unittests.yml +++ b/models/test/unittests.yml @@ -304,6 +304,11 @@ MultiDArrays: !protocol frames: !stream items: float[ch:1, z:1, y:64, x:32] +ComplexArrays: !protocol + sequence: + floats: complexfloat32[] + doubles: complexfloat64[,] + Maps: !protocol sequence: stringToInt: string->int diff --git a/python/test_model/__init__.py b/python/test_model/__init__.py index 9b49f57e..b138924d 100644 --- a/python/test_model/__init__.py +++ b/python/test_model/__init__.py @@ -133,6 +133,8 @@ def _parse_version(version: str) -> _Tuple[int, ...]: BenchmarkSmallRecordWithOptionalsReaderBase, BenchmarkSmallRecordWithOptionalsWriterBase, BenchmarkSmallRecordWriterBase, + ComplexArraysReaderBase, + ComplexArraysWriterBase, DynamicNDArraysReaderBase, DynamicNDArraysWriterBase, EnumsReaderBase, @@ -201,6 +203,8 @@ def _parse_version(version: str) -> _Tuple[int, ...]: BinaryBenchmarkSmallRecordWithOptionalsReader, BinaryBenchmarkSmallRecordWithOptionalsWriter, BinaryBenchmarkSmallRecordWriter, + BinaryComplexArraysReader, + BinaryComplexArraysWriter, BinaryDynamicNDArraysReader, BinaryDynamicNDArraysWriter, BinaryEnumsReader, @@ -269,6 +273,8 @@ def _parse_version(version: str) -> _Tuple[int, ...]: NDJsonBenchmarkSmallRecordWithOptionalsReader, NDJsonBenchmarkSmallRecordWithOptionalsWriter, NDJsonBenchmarkSmallRecordWriter, + NDJsonComplexArraysReader, + NDJsonComplexArraysWriter, NDJsonDynamicNDArraysReader, NDJsonDynamicNDArraysWriter, NDJsonEnumsReader, diff --git a/python/test_model/binary.py b/python/test_model/binary.py index 3cf168c6..1dd4085d 100644 --- a/python/test_model/binary.py +++ b/python/test_model/binary.py @@ -730,6 +730,35 @@ def _read_images(self) -> collections.abc.Iterable[npt.NDArray[np.float32]]: def _read_frames(self) -> collections.abc.Iterable[npt.NDArray[np.float32]]: return _binary.StreamSerializer(_binary.FixedNDArraySerializer(_binary.float32_serializer, (1, 1, 64, 32,))).read(self._stream) +class BinaryComplexArraysWriter(_binary.BinaryProtocolWriter, ComplexArraysWriterBase): + """Binary writer for the ComplexArrays protocol.""" + + + def __init__(self, stream: typing.Union[typing.BinaryIO, str]) -> None: + ComplexArraysWriterBase.__init__(self) + _binary.BinaryProtocolWriter.__init__(self, stream, ComplexArraysWriterBase.schema) + + def _write_floats(self, value: npt.NDArray[np.complex64]) -> None: + _binary.DynamicNDArraySerializer(_binary.complexfloat32_serializer).write(self._stream, value) + + def _write_doubles(self, value: npt.NDArray[np.complex128]) -> None: + _binary.NDArraySerializer(_binary.complexfloat64_serializer, 2).write(self._stream, value) + + +class BinaryComplexArraysReader(_binary.BinaryProtocolReader, ComplexArraysReaderBase): + """Binary writer for the ComplexArrays protocol.""" + + + def __init__(self, stream: typing.Union[io.BufferedReader, io.BytesIO, typing.BinaryIO, str]) -> None: + ComplexArraysReaderBase.__init__(self) + _binary.BinaryProtocolReader.__init__(self, stream, ComplexArraysReaderBase.schema) + + def _read_floats(self) -> npt.NDArray[np.complex64]: + return _binary.DynamicNDArraySerializer(_binary.complexfloat32_serializer).read(self._stream) + + def _read_doubles(self) -> npt.NDArray[np.complex128]: + return _binary.NDArraySerializer(_binary.complexfloat64_serializer, 2).read(self._stream) + class BinaryMapsWriter(_binary.BinaryProtocolWriter, MapsWriterBase): """Binary writer for the Maps protocol.""" diff --git a/python/test_model/ndjson.py b/python/test_model/ndjson.py index 1025db3a..91a8168f 100644 --- a/python/test_model/ndjson.py +++ b/python/test_model/ndjson.py @@ -3379,6 +3379,43 @@ def _read_frames(self) -> collections.abc.Iterable[npt.NDArray[np.float32]]: while (json_object := self._read_json_line("frames", False)) is not _ndjson.MISSING_SENTINEL: yield converter.from_json(json_object) +class NDJsonComplexArraysWriter(_ndjson.NDJsonProtocolWriter, ComplexArraysWriterBase): + """NDJson writer for the ComplexArrays protocol.""" + + + def __init__(self, stream: typing.Union[typing.TextIO, str]) -> None: + ComplexArraysWriterBase.__init__(self) + _ndjson.NDJsonProtocolWriter.__init__(self, stream, ComplexArraysWriterBase.schema) + + def _write_floats(self, value: npt.NDArray[np.complex64]) -> None: + converter = _ndjson.DynamicNDArrayConverter(_ndjson.complexfloat32_converter) + json_value = converter.to_json(value) + self._write_json_line({"floats": json_value}) + + def _write_doubles(self, value: npt.NDArray[np.complex128]) -> None: + converter = _ndjson.NDArrayConverter(_ndjson.complexfloat64_converter, 2) + json_value = converter.to_json(value) + self._write_json_line({"doubles": json_value}) + + +class NDJsonComplexArraysReader(_ndjson.NDJsonProtocolReader, ComplexArraysReaderBase): + """NDJson writer for the ComplexArrays protocol.""" + + + def __init__(self, stream: typing.Union[io.BufferedReader, typing.TextIO, str]) -> None: + ComplexArraysReaderBase.__init__(self) + _ndjson.NDJsonProtocolReader.__init__(self, stream, ComplexArraysReaderBase.schema) + + def _read_floats(self) -> npt.NDArray[np.complex64]: + json_object = self._read_json_line("floats", True) + converter = _ndjson.DynamicNDArrayConverter(_ndjson.complexfloat32_converter) + return converter.from_json(json_object) + + def _read_doubles(self) -> npt.NDArray[np.complex128]: + json_object = self._read_json_line("doubles", True) + converter = _ndjson.NDArrayConverter(_ndjson.complexfloat64_converter, 2) + return converter.from_json(json_object) + class NDJsonMapsWriter(_ndjson.NDJsonProtocolWriter, MapsWriterBase): """NDJson writer for the Maps protocol.""" diff --git a/python/test_model/protocols.py b/python/test_model/protocols.py index 9276f020..778b8c71 100644 --- a/python/test_model/protocols.py +++ b/python/test_model/protocols.py @@ -3936,6 +3936,164 @@ def _state_to_method_name(self, state: int) -> str: return 'read_frames' return "" +class ComplexArraysWriterBase(abc.ABC): + """Abstract writer for the ComplexArrays protocol.""" + + + def __init__(self) -> None: + self._state = 0 + + schema = r"""{"protocol":{"name":"ComplexArrays","sequence":[{"name":"floats","type":{"array":{"items":"complexfloat32"}}},{"name":"doubles","type":{"array":{"items":"complexfloat64","dimensions":2}}}]},"types":null}""" + + def close(self) -> None: + self._close() + if self._state != 4: + expected_method = self._state_to_method_name((self._state + 1) & ~1) + raise ProtocolError(f"Protocol writer closed before all steps were called. Expected to call to '{expected_method}'.") + + def __enter__(self): + return self + + def __exit__(self, exc_type: typing.Optional[type[BaseException]], exc: typing.Optional[BaseException], traceback: object) -> None: + try: + self.close() + except Exception as e: + if exc is None: + raise e + + def write_floats(self, value: npt.NDArray[np.complex64]) -> None: + """Ordinal 0""" + + if self._state != 0: + self._raise_unexpected_state(0) + + self._write_floats(value) + self._state = 2 + + def write_doubles(self, value: npt.NDArray[np.complex128]) -> None: + """Ordinal 1""" + + if self._state != 2: + self._raise_unexpected_state(2) + + self._write_doubles(value) + self._state = 4 + + @abc.abstractmethod + def _write_floats(self, value: npt.NDArray[np.complex64]) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def _write_doubles(self, value: npt.NDArray[np.complex128]) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def _close(self) -> None: + pass + + @abc.abstractmethod + def _end_stream(self) -> None: + pass + + def _raise_unexpected_state(self, actual: int) -> None: + expected_method = self._state_to_method_name(self._state) + actual_method = self._state_to_method_name(actual) + raise ProtocolError(f"Expected to call to '{expected_method}' but received call to '{actual_method}'.") + + def _state_to_method_name(self, state: int) -> str: + if state == 0: + return 'write_floats' + if state == 2: + return 'write_doubles' + return "" + +class ComplexArraysReaderBase(abc.ABC): + """Abstract reader for the ComplexArrays protocol.""" + + + def __init__(self) -> None: + self._state = 0 + + def close(self) -> None: + self._close() + if self._state != 4: + if self._state % 2 == 1: + previous_method = self._state_to_method_name(self._state - 1) + raise ProtocolError(f"Protocol reader closed before all data was consumed. The iterable returned by '{previous_method}' was not fully consumed.") + else: + expected_method = self._state_to_method_name(self._state) + raise ProtocolError(f"Protocol reader closed before all data was consumed. Expected call to '{expected_method}'.") + + + schema = ComplexArraysWriterBase.schema + + def __enter__(self): + return self + + def __exit__(self, exc_type: typing.Optional[type[BaseException]], exc: typing.Optional[BaseException], traceback: object) -> None: + try: + self.close() + except Exception as e: + if exc is None: + raise e + + @abc.abstractmethod + def _close(self) -> None: + raise NotImplementedError() + + def read_floats(self) -> npt.NDArray[np.complex64]: + """Ordinal 0""" + + if self._state != 0: + self._raise_unexpected_state(0) + + value = self._read_floats() + self._state = 2 + return value + + def read_doubles(self) -> npt.NDArray[np.complex128]: + """Ordinal 1""" + + if self._state != 2: + self._raise_unexpected_state(2) + + value = self._read_doubles() + self._state = 4 + return value + + def copy_to(self, writer: ComplexArraysWriterBase) -> None: + writer.write_floats(self.read_floats()) + writer.write_doubles(self.read_doubles()) + + @abc.abstractmethod + def _read_floats(self) -> npt.NDArray[np.complex64]: + raise NotImplementedError() + + @abc.abstractmethod + def _read_doubles(self) -> npt.NDArray[np.complex128]: + raise NotImplementedError() + + T = typing.TypeVar('T') + def _wrap_iterable(self, iterable: collections.abc.Iterable[T], final_state: int) -> collections.abc.Iterable[T]: + yield from iterable + self._state = final_state + + def _raise_unexpected_state(self, actual: int) -> None: + actual_method = self._state_to_method_name(actual) + if self._state % 2 == 1: + previous_method = self._state_to_method_name(self._state - 1) + raise ProtocolError(f"Received call to '{actual_method}' but the iterable returned by '{previous_method}' was not fully consumed.") + else: + expected_method = self._state_to_method_name(self._state) + raise ProtocolError(f"Expected to call to '{expected_method}' but received call to '{actual_method}'.") + + def _state_to_method_name(self, state: int) -> str: + if state == 0: + return 'read_floats' + if state == 2: + return 'read_doubles' + return "" + class MapsWriterBase(abc.ABC): """Abstract writer for the Maps protocol.""" diff --git a/python/tests/test_protocol_roundtrip.py b/python/tests/test_protocol_roundtrip.py index 1d27adc6..36a59948 100644 --- a/python/tests/test_protocol_roundtrip.py +++ b/python/tests/test_protocol_roundtrip.py @@ -213,6 +213,21 @@ def test_fixed_arrays(format: Format): w.write_named_array(named_fixed_array) +def test_complex_arrays(format: Format): + with create_validating_writer_class(format, tm.ComplexArraysWriterBase)() as w: + fs = np.zeros((2, 16), dtype=np.complex64) + w.write_floats(fs) + ds = np.zeros((2, 16), dtype=np.complex128) + w.write_doubles(ds) + + # Again but with arrays in Fortran order + with create_validating_writer_class(format, tm.ComplexArraysWriterBase)() as w: + fs = np.zeros((2, 16), dtype=np.complex64, order="F") + w.write_floats(fs) + ds = np.zeros((2, 16), dtype=np.complex128, order="F") + w.write_doubles(ds) + + def test_subarrays(format: Format): with create_validating_writer_class(format, tm.SubarraysWriterBase)() as w: with pytest.raises( diff --git a/tooling/internal/python/static_files/_binary.py b/tooling/internal/python/static_files/_binary.py index 8e193074..8c3c7d86 100644 --- a/tooling/internal/python/static_files/_binary.py +++ b/tooling/internal/python/static_files/_binary.py @@ -628,6 +628,9 @@ def __init__(self) -> None: def write(self, stream: CodedOutputStream, value: ComplexFloat) -> None: stream.write(self._struct, value.real, value.imag) + def write_numpy(self, stream: CodedOutputStream, value: np.complex64) -> None: + stream.write(self._struct, value.real, value.imag) + def read(self, stream: CodedInputStream) -> ComplexFloat: return ComplexFloat(*stream.read(self._struct)) @@ -649,6 +652,9 @@ def __init__(self) -> None: def write(self, stream: CodedOutputStream, value: ComplexDouble) -> None: stream.write(self._struct, value.real, value.imag) + def write_numpy(self, stream: CodedOutputStream, value: np.complex128) -> None: + stream.write(self._struct, value.real, value.imag) + def read(self, stream: CodedInputStream) -> ComplexDouble: return ComplexDouble(*stream.read(self._struct))