From 52405453d5020f6e83a6e972ae1e0b8db7a0cc3c Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Tue, 10 Jul 2018 17:16:31 -0700 Subject: [PATCH] Allow declaring Output> with tuples (Issue #2980) For AOT, this produces an output buffer for each tuple element. --- src/Generator.cpp | 9 ++-- src/Generator.h | 37 ++++++++------ test/generator/metadata_tester_aottest.cpp | 51 +++++++++++++++----- test/generator/metadata_tester_generator.cpp | 2 + test/generator/stubtest_aottest.cpp | 5 ++ test/generator/stubtest_generator.cpp | 3 ++ test/generator/stubuser_aottest.cpp | 6 ++- test/generator/stubuser_generator.cpp | 3 ++ 8 files changed, 85 insertions(+), 31 deletions(-) diff --git a/src/Generator.cpp b/src/Generator.cpp index aaea6284e8ba..c24a1570e017 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -1277,9 +1277,12 @@ void GeneratorBase::track_parameter_values(bool include_outputs) { internal_assert(!output->funcs().empty()); for (auto &f : output->funcs()) { user_assert(f.defined()) << "Output " << output->name() << " is not fully defined."; - Parameter p = f.output_buffer().parameter(); - // This must use p.name(), *not* output->name() - get_value_tracker()->track_values(p.name(), parameter_constraints(p)); + auto output_buffers = f.output_buffers(); + for (auto &o : output_buffers) { + Parameter p = o.parameter(); + // This must use p.name(), *not* output->name() + get_value_tracker()->track_values(p.name(), parameter_constraints(p)); + } } } } diff --git a/src/Generator.h b/src/Generator.h index 73dd84189da8..aaed625933d1 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -2074,18 +2074,21 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { internal_assert(f.defined()); - const auto &output_types = f.output_types(); - user_assert(output_types.size() == 1) - << "Output " << this->name() << " should have size=1 but saw size=" << output_types.size() << "\n"; - - Buffer<> other(output_types.at(0), nullptr, std::vector(f.dimensions(), 1)); - user_assert(T::can_convert_from(other)) - << "Cannot assign to the Output \"" << this->name() - << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n"; + if (TBase::has_static_halide_type) { + Buffer<> other(f.output_types().at(0), nullptr, std::vector(f.dimensions(), 1)); + user_assert(T::can_convert_from(other)) + << "Cannot assign to the Output \"" << this->name() + << "\": the expression is not convertible to the same Buffer type and/or dimensions.\n"; + } if (this->types_defined()) { - user_assert(output_types.at(0) == this->type()) - << "Output " << this->name() << " should have type=" << this->type() << " but saw type=" << output_types.at(0) << "\n"; + const auto &my_types = this->types(); + user_assert(my_types.size() == f.output_types().size()) + << "Output " << this->name() << " requires a Func with " << my_types.size() << " type(s) but tried to assign one with " << f.output_types().size() << " type(s)\n"; + for (size_t i = 0; i < my_types.size(); i++) { + user_assert(my_types[i] == f.output_types().at(i)) + << "Output " << this->name() << " should have type[" << i << "]=" << my_types[i] << " but saw type[" << i << "]=" << f.output_types().at(i) << "\n"; + } } if (this->dims_defined()) { user_assert(f.dimensions() == this->dims()) @@ -2100,19 +2103,21 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { protected: using TBase = typename Super::TBase; - static std::vector my_types() { - return TBase::has_static_halide_type ? std::vector{ TBase::static_halide_type() } : std::vector{}; + static std::vector my_types(const std::vector &t) { + if (TBase::has_static_halide_type) { + user_assert(t.empty()) << "Cannot pass a Type argument for an Output with a non-void static type\n"; + return std::vector{ TBase::static_halide_type() }; + } + return t; } protected: GeneratorOutput_Buffer(const std::string &name, const std::vector &t = {}, int d = -1) - : Super(name, IOKind::Buffer, my_types(), d) { - user_assert(t.empty()) << "You cannot specify a Type argument for Output>\n"; + : Super(name, IOKind::Buffer, my_types(t), d) { } GeneratorOutput_Buffer(size_t array_size, const std::string &name, const std::vector &t = {}, int d = -1) - : Super(array_size, name, IOKind::Buffer, my_types(), d) { - user_assert(t.empty()) << "You cannot specify a Type argument for Output>\n"; + : Super(array_size, name, IOKind::Buffer, my_types(t), d) { } HALIDE_NO_USER_CODE_INLINE std::string get_c_type() const override { diff --git a/test/generator/metadata_tester_aottest.cpp b/test/generator/metadata_tester_aottest.cpp index e2d8217f2161..8fb88192b663 100644 --- a/test/generator/metadata_tester_aottest.cpp +++ b/test/generator/metadata_tester_aottest.cpp @@ -125,14 +125,15 @@ Buffer make_image() { return im; } -template -void verify(const Buffer &input, - const Buffer &output0, - const Buffer &output1, - const Buffer &output_scalar, - const Buffer &output_array0, - const Buffer &output_array1, - const Buffer &untyped_output_buffer) { +void verify(const Buffer &input, + const Buffer &output0, + const Buffer &output1, + const Buffer &output_scalar, + const Buffer &output_array0, + const Buffer &output_array1, + const Buffer &untyped_output_buffer, + const Buffer &tupled_output_buffer0, + const Buffer &tupled_output_buffer1) { if (output_scalar.dimensions() != 0) { fprintf(stderr, "output_scalar should be zero-dimensional\n"); exit(-1); @@ -144,9 +145,9 @@ void verify(const Buffer &input, for (int x = 0; x < kSize; x++) { for (int y = 0; y < kSize; y++) { for (int c = 0; c < 3; c++) { - const OutputType expected0 = static_cast(input(x, y, c)); + const float expected0 = static_cast(input(x, y, c)); const float expected1 = expected0 + 1; - const OutputType actual0 = output0(x, y, c); + const float actual0 = output0(x, y, c); const float actual1 = output1(x, y, c); if (expected0 != actual0) { fprintf(stderr, "img0[%d, %d, %d] = %f, expected %f\n", x, y, c, (double)actual0, (double)expected0); @@ -168,6 +169,10 @@ void verify(const Buffer &input, fprintf(stderr, "untyped_output_buffer[%d, %d, %d] = %f, expected %f\n", x, y, c, untyped_output_buffer(x, y, c), expected1); exit(-1); } + if (tupled_output_buffer0(x, y, c) != expected1) { + fprintf(stderr, "tupled_output_buffer0[%d, %d, %d] = %f, expected %f\n", x, y, c, tupled_output_buffer0(x, y, c), expected1); + exit(-1); + } } } } @@ -812,6 +817,24 @@ void check_metadata(const halide_filter_metadata_t &md, bool expect_ucon_at_0) { nullptr, nullptr, }, + { + "tupled_output_buffer.0", + halide_argument_kind_output_buffer, + 3, + halide_type_t(halide_type_float, 32), + nullptr, + nullptr, + nullptr, + }, + { + "tupled_output_buffer.1", + halide_argument_kind_output_buffer, + 3, + halide_type_t(halide_type_int, 32), + nullptr, + nullptr, + nullptr, + }, { "output_scalar", halide_argument_kind_output_buffer, @@ -1033,6 +1056,8 @@ int main(int argc, char **argv) { Buffer type_only_output_buffer(kSize, kSize, 3); Buffer dim_only_output_buffer(kSize, kSize, 3); Buffer untyped_output_buffer(kSize, kSize, 3); + Buffer tupled_output_buffer0(kSize, kSize, 3); + Buffer tupled_output_buffer1(kSize, kSize, 3); Buffer output_scalar = Buffer::make_scalar(); Buffer output_array[2] = {{kSize, kSize, 3}, {kSize, kSize, 3}}; Buffer output_array2[4] = {{kSize, kSize, 3}, {kSize, kSize, 3}, {kSize, kSize, 3}, {kSize, kSize, 3}}; @@ -1086,6 +1111,8 @@ int main(int argc, char **argv) { type_only_output_buffer, // Output> dim_only_output_buffer, // Output>(3) untyped_output_buffer, // Output> + tupled_output_buffer0, // Output> with tuple type + tupled_output_buffer1, // Output> with tuple type output_scalar, // Output output_array[0], output_array[1], // Output output_array2[0], output_array2[1], output_array2[2], output_array2[3], // Output(Tuple) @@ -1142,6 +1169,8 @@ int main(int argc, char **argv) { type_only_output_buffer, // Output> dim_only_output_buffer, // Output>(3) untyped_output_buffer, // Output> + tupled_output_buffer0, // Output> with tuple type + tupled_output_buffer1, // Output> with tuple type output_scalar, // Output output_array[0], output_array[1], // Output output_array2[0], output_array2[1], output_array2[2], output_array2[3], // Output(Tuple) @@ -1155,7 +1184,7 @@ int main(int argc, char **argv) { ); EXPECT_EQ(0, result); - verify(input, output0, output1, output_scalar, output_array[0], output_array[1], untyped_output_buffer); + verify(input, output0, output1, output_scalar, output_array[0], output_array[1], untyped_output_buffer, tupled_output_buffer0, tupled_output_buffer1); check_metadata(*metadata_tester_metadata(), false); if (!strcmp(metadata_tester_metadata()->name, "metadata_tester_metadata")) { diff --git a/test/generator/metadata_tester_generator.cpp b/test/generator/metadata_tester_generator.cpp index 873d72a04622..9652f4428466 100644 --- a/test/generator/metadata_tester_generator.cpp +++ b/test/generator/metadata_tester_generator.cpp @@ -53,6 +53,7 @@ class MetadataTester : public Halide::Generator { Output> type_only_output_buffer{ "type_only_output_buffer" }; // untyped outputs can have type and/or dimensions inferred Output> dim_only_output_buffer{ "dim_only_output_buffer", 3 }; // untyped outputs can have type and/or dimensions inferred Output> untyped_output_buffer{ "untyped_output_buffer" }; // untyped outputs can have type and/or dimensions inferred + Output> tupled_output_buffer{ "tupled_output_buffer", { Float(32), Int(32) }, 3 }; Output output_scalar{ "output_scalar" }; Output array_outputs{ "array_outputs", Float(32), 3 }; // must be overridden to size=2 Output array_outputs2{ "array_outputs2", { Float(32), Float(32) }, 3 }; @@ -100,6 +101,7 @@ class MetadataTester : public Halide::Generator { typed_output_buffer(x, y, c) = f1(x, y, c); type_only_output_buffer(x, y, c) = f1(x, y, c); dim_only_output_buffer(x, y, c) = f1(x, y, c); + tupled_output_buffer(x, y, c) = Tuple(f2(x, y, c), cast(f2(x, y, c) + 1.5f)); // verify that we can assign a Func to an Output> untyped_output_buffer = f2; output_scalar() = 1234.25f; diff --git a/test/generator/stubtest_aottest.cpp b/test/generator/stubtest_aottest.cpp index 4483265fa13b..b333a1615449 100644 --- a/test/generator/stubtest_aottest.cpp +++ b/test/generator/stubtest_aottest.cpp @@ -50,6 +50,8 @@ int main(int argc, char **argv) { Buffer array_input1 = make_image(1); Buffer typed_buffer_output(kSize, kSize, 3); Buffer untyped_buffer_output(kSize, kSize, 3); + Buffer tupled_output0(kSize, kSize, 3); + Buffer tupled_output1(kSize, kSize, 3); Buffer array_buffer_input0 = make_image(0); Buffer array_buffer_input1 = make_image(1); Buffer simple_output(kSize, kSize, 3); @@ -72,6 +74,7 @@ int main(int argc, char **argv) { array_output0, array_output1, typed_buffer_output, untyped_buffer_output, + tupled_output0, tupled_output1, static_compiled_buffer_output, array_buffer_output0, array_buffer_output1 ); @@ -79,6 +82,8 @@ int main(int argc, char **argv) { verify(buffer_input, 1.f, 0, typed_buffer_output); verify(buffer_input, 1.f, 0, untyped_buffer_output); verify(simple_input, 1.f, 0, simple_output); + verify(simple_input, 1.f, 0, tupled_output0); + verify(simple_input, 1.f, 1, tupled_output1); verify(array_input0, 1.f, 0, simple_output); verify(array_input0, 1.25f, 0, tuple_output0); verify(array_input0, 1.25f, 33, tuple_output1); diff --git a/test/generator/stubtest_generator.cpp b/test/generator/stubtest_generator.cpp index e35d262f6ed0..4edca7879b7d 100644 --- a/test/generator/stubtest_generator.cpp +++ b/test/generator/stubtest_generator.cpp @@ -42,6 +42,7 @@ class StubTest : public Halide::Generator { Output array_output{ "array_output", Int(16), 2}; // leave ArraySize unspecified Output> typed_buffer_output{ "typed_buffer_output" }; Output> untyped_buffer_output{ "untyped_buffer_output" }; + Output> tupled_output{ "tupled_output", { Float(32), Int(32) }, 3 }; Output> static_compiled_buffer_output{ "static_compiled_buffer_output", 3 }; Output[2]> array_buffer_output{ "array_buffer_output", 3 }; @@ -55,6 +56,8 @@ class StubTest : public Halide::Generator { // explicit GeneratorParam to allow us to set it. untyped_buffer_output(x, y, c) = cast(untyped_buffer_output_type, untyped_buffer_input(x, y, c)); + tupled_output(x, y, c) = Tuple(simple_output(x, y, c), cast(simple_output(x, y, c)) + 1); + for (int i = 0; i < 2; ++i) { array_buffer_output[i](x, y, c) = array_buffer_input[i](x, y,c) + 1 + i; } diff --git a/test/generator/stubuser_aottest.cpp b/test/generator/stubuser_aottest.cpp index 65bd7a94acd6..aae6da13884f 100644 --- a/test/generator/stubuser_aottest.cpp +++ b/test/generator/stubuser_aottest.cpp @@ -53,12 +53,16 @@ int main(int argc, char **argv) { Buffer float32_buffer_output(kSize, kSize, 3); Buffer<> int32_buffer_output(halide_type_t(halide_type_int, 32), kSize, kSize, 3); Buffer array_test_output(kSize, kSize, 3); + Buffer tupled_output0(kSize, kSize, 3); + Buffer tupled_output1(kSize, kSize, 3); - stubuser(input, calculated_output, float32_buffer_output, int32_buffer_output, array_test_output); + stubuser(input, calculated_output, float32_buffer_output, int32_buffer_output, array_test_output, tupled_output0, tupled_output1); verify(input, kFloatArg, kIntArg, kOffset, calculated_output); verify(input, 1.f, 0, 0.f, float32_buffer_output); verify(input, 1.f, 0, 0.f, int32_buffer_output); verify(input, 1.f, 0, 2, array_test_output); + verify(input, 1.f, 0, 0, tupled_output0); + verify(input, 1.f, 1, 0, tupled_output1); printf("Success!\n"); return 0; diff --git a/test/generator/stubuser_generator.cpp b/test/generator/stubuser_generator.cpp index 72bb7a8e52ea..8a7d5ce7d6bf 100644 --- a/test/generator/stubuser_generator.cpp +++ b/test/generator/stubuser_generator.cpp @@ -28,6 +28,8 @@ class StubUser : public Halide::Generator { Output> float32_buffer_output{"float32_buffer_output" }; Output> int32_buffer_output{"int32_buffer_output" }; Output> array_test_output{"array_test_output" }; + // We can infer the tupled-output-type from the Stub + Output> tupled_output{ "tupled_output", 3 }; void generate() { Var x{"x"}, y{"y"}, c{"c"}; @@ -60,6 +62,7 @@ class StubUser : public Halide::Generator { float32_buffer_output = out.typed_buffer_output; int32_buffer_output = out.untyped_buffer_output; array_test_output = out.array_buffer_output[1]; + tupled_output = out.tupled_output; const float kOffset = 2.f; calculated_output(x, y, c) = cast(out.tuple_output(x, y, c)[1] + kOffset);