Skip to content

Commit

Permalink
Merge pull request halide#3117 from halide/srj-output-tuple
Browse files Browse the repository at this point in the history
Allow declaring Output<Buffer<>> with tuples (Issue halide#2980)
  • Loading branch information
Zalman Stern authored Jul 17, 2018
2 parents 6acc274 + a0c7a39 commit 7bc00ba
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 31 deletions.
9 changes: 6 additions & 3 deletions src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1277,9 +1277,12 @@ void GeneratorBase::track_parameter_values(bool include_outputs) {
internal_assert(!output->funcs().empty());
for (auto &f : output->funcs()) {
user_assert(f.defined()) << "Output " << output->name() << " is not fully defined.";
Parameter p = f.output_buffer().parameter();
// This must use p.name(), *not* output->name()
get_value_tracker()->track_values(p.name(), parameter_constraints(p));
auto output_buffers = f.output_buffers();
for (auto &o : output_buffers) {
Parameter p = o.parameter();
// This must use p.name(), *not* output->name()
get_value_tracker()->track_values(p.name(), parameter_constraints(p));
}
}
}
}
Expand Down
37 changes: 21 additions & 16 deletions src/Generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2074,18 +2074,21 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl<T> {

internal_assert(f.defined());

const auto &output_types = f.output_types();
user_assert(output_types.size() == 1)
<< "Output " << this->name() << " should have size=1 but saw size=" << output_types.size() << "\n";

Buffer<> other(output_types.at(0), nullptr, std::vector<int>(f.dimensions(), 1));
user_assert(T::can_convert_from(other))
<< "Cannot assign to the Output \"" << this->name()
<< "\": the expression is not convertible to the same Buffer type and/or dimensions.\n";
if (TBase::has_static_halide_type) {
Buffer<> other(f.output_types().at(0), nullptr, std::vector<int>(f.dimensions(), 1));
user_assert(T::can_convert_from(other))
<< "Cannot assign to the Output \"" << this->name()
<< "\": the expression is not convertible to the same Buffer type and/or dimensions.\n";
}

if (this->types_defined()) {
user_assert(output_types.at(0) == this->type())
<< "Output " << this->name() << " should have type=" << this->type() << " but saw type=" << output_types.at(0) << "\n";
const auto &my_types = this->types();
user_assert(my_types.size() == f.output_types().size())
<< "Output " << this->name() << " requires a Func with " << my_types.size() << " type(s) but tried to assign one with " << f.output_types().size() << " type(s)\n";
for (size_t i = 0; i < my_types.size(); i++) {
user_assert(my_types[i] == f.output_types().at(i))
<< "Output " << this->name() << " should have type[" << i << "]=" << my_types[i] << " but saw type[" << i << "]=" << f.output_types().at(i) << "\n";
}
}
if (this->dims_defined()) {
user_assert(f.dimensions() == this->dims())
Expand All @@ -2100,19 +2103,21 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl<T> {
protected:
using TBase = typename Super::TBase;

static std::vector<Type> my_types() {
return TBase::has_static_halide_type ? std::vector<Type>{ TBase::static_halide_type() } : std::vector<Type>{};
static std::vector<Type> my_types(const std::vector<Type> &t) {
if (TBase::has_static_halide_type) {
user_assert(t.empty()) << "Cannot pass a Type argument for an Output<Buffer> with a non-void static type\n";
return std::vector<Type>{ TBase::static_halide_type() };
}
return t;
}

protected:
GeneratorOutput_Buffer(const std::string &name, const std::vector<Type> &t = {}, int d = -1)
: Super(name, IOKind::Buffer, my_types(), d) {
user_assert(t.empty()) << "You cannot specify a Type argument for Output<Buffer<>>\n";
: Super(name, IOKind::Buffer, my_types(t), d) {
}

GeneratorOutput_Buffer(size_t array_size, const std::string &name, const std::vector<Type> &t = {}, int d = -1)
: Super(array_size, name, IOKind::Buffer, my_types(), d) {
user_assert(t.empty()) << "You cannot specify a Type argument for Output<Buffer<>>\n";
: Super(array_size, name, IOKind::Buffer, my_types(t), d) {
}

HALIDE_NO_USER_CODE_INLINE std::string get_c_type() const override {
Expand Down
51 changes: 40 additions & 11 deletions test/generator/metadata_tester_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,15 @@ Buffer<Type> make_image() {
return im;
}

template <typename InputType, typename OutputType>
void verify(const Buffer<InputType> &input,
const Buffer<OutputType> &output0,
const Buffer<OutputType> &output1,
const Buffer<OutputType> &output_scalar,
const Buffer<OutputType> &output_array0,
const Buffer<OutputType> &output_array1,
const Buffer<OutputType> &untyped_output_buffer) {
void verify(const Buffer<uint8_t> &input,
const Buffer<float> &output0,
const Buffer<float> &output1,
const Buffer<float> &output_scalar,
const Buffer<float> &output_array0,
const Buffer<float> &output_array1,
const Buffer<float> &untyped_output_buffer,
const Buffer<float> &tupled_output_buffer0,
const Buffer<int32_t> &tupled_output_buffer1) {
if (output_scalar.dimensions() != 0) {
fprintf(stderr, "output_scalar should be zero-dimensional\n");
exit(-1);
Expand All @@ -144,9 +145,9 @@ void verify(const Buffer<InputType> &input,
for (int x = 0; x < kSize; x++) {
for (int y = 0; y < kSize; y++) {
for (int c = 0; c < 3; c++) {
const OutputType expected0 = static_cast<OutputType>(input(x, y, c));
const float expected0 = static_cast<float>(input(x, y, c));
const float expected1 = expected0 + 1;
const OutputType actual0 = output0(x, y, c);
const float actual0 = output0(x, y, c);
const float actual1 = output1(x, y, c);
if (expected0 != actual0) {
fprintf(stderr, "img0[%d, %d, %d] = %f, expected %f\n", x, y, c, (double)actual0, (double)expected0);
Expand All @@ -168,6 +169,10 @@ void verify(const Buffer<InputType> &input,
fprintf(stderr, "untyped_output_buffer[%d, %d, %d] = %f, expected %f\n", x, y, c, untyped_output_buffer(x, y, c), expected1);
exit(-1);
}
if (tupled_output_buffer0(x, y, c) != expected1) {
fprintf(stderr, "tupled_output_buffer0[%d, %d, %d] = %f, expected %f\n", x, y, c, tupled_output_buffer0(x, y, c), expected1);
exit(-1);
}
}
}
}
Expand Down Expand Up @@ -812,6 +817,24 @@ void check_metadata(const halide_filter_metadata_t &md, bool expect_ucon_at_0) {
nullptr,
nullptr,
},
{
"tupled_output_buffer.0",
halide_argument_kind_output_buffer,
3,
halide_type_t(halide_type_float, 32),
nullptr,
nullptr,
nullptr,
},
{
"tupled_output_buffer.1",
halide_argument_kind_output_buffer,
3,
halide_type_t(halide_type_int, 32),
nullptr,
nullptr,
nullptr,
},
{
"output_scalar",
halide_argument_kind_output_buffer,
Expand Down Expand Up @@ -1033,6 +1056,8 @@ int main(int argc, char **argv) {
Buffer<float> type_only_output_buffer(kSize, kSize, 3);
Buffer<float> dim_only_output_buffer(kSize, kSize, 3);
Buffer<float> untyped_output_buffer(kSize, kSize, 3);
Buffer<float> tupled_output_buffer0(kSize, kSize, 3);
Buffer<int32_t> tupled_output_buffer1(kSize, kSize, 3);
Buffer<float> output_scalar = Buffer<float>::make_scalar();
Buffer<float> output_array[2] = {{kSize, kSize, 3}, {kSize, kSize, 3}};
Buffer<float> output_array2[4] = {{kSize, kSize, 3}, {kSize, kSize, 3}, {kSize, kSize, 3}, {kSize, kSize, 3}};
Expand Down Expand Up @@ -1086,6 +1111,8 @@ int main(int argc, char **argv) {
type_only_output_buffer, // Output<Buffer<float>>
dim_only_output_buffer, // Output<Buffer<>>(3)
untyped_output_buffer, // Output<Buffer<>>
tupled_output_buffer0, // Output<Buffer<>> with tuple type
tupled_output_buffer1, // Output<Buffer<>> with tuple type
output_scalar, // Output<float>
output_array[0], output_array[1], // Output<Func[]>
output_array2[0], output_array2[1], output_array2[2], output_array2[3], // Output<Func[2]>(Tuple)
Expand Down Expand Up @@ -1142,6 +1169,8 @@ int main(int argc, char **argv) {
type_only_output_buffer, // Output<Buffer<float>>
dim_only_output_buffer, // Output<Buffer<>>(3)
untyped_output_buffer, // Output<Buffer<>>
tupled_output_buffer0, // Output<Buffer<>> with tuple type
tupled_output_buffer1, // Output<Buffer<>> with tuple type
output_scalar, // Output<float>
output_array[0], output_array[1], // Output<Func[]>
output_array2[0], output_array2[1], output_array2[2], output_array2[3], // Output<Func[2]>(Tuple)
Expand All @@ -1155,7 +1184,7 @@ int main(int argc, char **argv) {
);
EXPECT_EQ(0, result);

verify(input, output0, output1, output_scalar, output_array[0], output_array[1], untyped_output_buffer);
verify(input, output0, output1, output_scalar, output_array[0], output_array[1], untyped_output_buffer, tupled_output_buffer0, tupled_output_buffer1);

check_metadata(*metadata_tester_metadata(), false);
if (!strcmp(metadata_tester_metadata()->name, "metadata_tester_metadata")) {
Expand Down
2 changes: 2 additions & 0 deletions test/generator/metadata_tester_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class MetadataTester : public Halide::Generator<MetadataTester> {
Output<Buffer<float>> type_only_output_buffer{ "type_only_output_buffer" }; // untyped outputs can have type and/or dimensions inferred
Output<Buffer<>> dim_only_output_buffer{ "dim_only_output_buffer", 3 }; // untyped outputs can have type and/or dimensions inferred
Output<Buffer<>> untyped_output_buffer{ "untyped_output_buffer" }; // untyped outputs can have type and/or dimensions inferred
Output<Buffer<>> tupled_output_buffer{ "tupled_output_buffer", { Float(32), Int(32) }, 3 };
Output<float> output_scalar{ "output_scalar" };
Output<Func[]> array_outputs{ "array_outputs", Float(32), 3 }; // must be overridden to size=2
Output<Func[2]> array_outputs2{ "array_outputs2", { Float(32), Float(32) }, 3 };
Expand Down Expand Up @@ -100,6 +101,7 @@ class MetadataTester : public Halide::Generator<MetadataTester> {
typed_output_buffer(x, y, c) = f1(x, y, c);
type_only_output_buffer(x, y, c) = f1(x, y, c);
dim_only_output_buffer(x, y, c) = f1(x, y, c);
tupled_output_buffer(x, y, c) = Tuple(f2(x, y, c), cast<int32_t>(f2(x, y, c) + 1.5f));
// verify that we can assign a Func to an Output<Buffer<>>
untyped_output_buffer = f2;
output_scalar() = 1234.25f;
Expand Down
5 changes: 5 additions & 0 deletions test/generator/stubtest_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ int main(int argc, char **argv) {
Buffer<float> array_input1 = make_image<float>(1);
Buffer<float> typed_buffer_output(kSize, kSize, 3);
Buffer<float> untyped_buffer_output(kSize, kSize, 3);
Buffer<float> tupled_output0(kSize, kSize, 3);
Buffer<int32_t> tupled_output1(kSize, kSize, 3);
Buffer<uint8_t> array_buffer_input0 = make_image<uint8_t>(0);
Buffer<uint8_t> array_buffer_input1 = make_image<uint8_t>(1);
Buffer<float> simple_output(kSize, kSize, 3);
Expand All @@ -72,13 +74,16 @@ int main(int argc, char **argv) {
array_output0, array_output1,
typed_buffer_output,
untyped_buffer_output,
tupled_output0, tupled_output1,
static_compiled_buffer_output,
array_buffer_output0, array_buffer_output1
);

verify(buffer_input, 1.f, 0, typed_buffer_output);
verify(buffer_input, 1.f, 0, untyped_buffer_output);
verify(simple_input, 1.f, 0, simple_output);
verify(simple_input, 1.f, 0, tupled_output0);
verify(simple_input, 1.f, 1, tupled_output1);
verify(array_input0, 1.f, 0, simple_output);
verify(array_input0, 1.25f, 0, tuple_output0);
verify(array_input0, 1.25f, 33, tuple_output1);
Expand Down
3 changes: 3 additions & 0 deletions test/generator/stubtest_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class StubTest : public Halide::Generator<StubTest> {
Output<Func[]> array_output{ "array_output", Int(16), 2}; // leave ArraySize unspecified
Output<Buffer<float>> typed_buffer_output{ "typed_buffer_output" };
Output<Buffer<>> untyped_buffer_output{ "untyped_buffer_output" };
Output<Buffer<>> tupled_output{ "tupled_output", { Float(32), Int(32) }, 3 };
Output<Buffer<uint8_t>> static_compiled_buffer_output{ "static_compiled_buffer_output", 3 };
Output<Buffer<uint8_t>[2]> array_buffer_output{ "array_buffer_output", 3 };

Expand All @@ -55,6 +56,8 @@ class StubTest : public Halide::Generator<StubTest> {
// explicit GeneratorParam to allow us to set it.
untyped_buffer_output(x, y, c) = cast(untyped_buffer_output_type, untyped_buffer_input(x, y, c));

tupled_output(x, y, c) = Tuple(simple_output(x, y, c), cast<int32_t>(simple_output(x, y, c)) + 1);

for (int i = 0; i < 2; ++i) {
array_buffer_output[i](x, y, c) = array_buffer_input[i](x, y,c) + 1 + i;
}
Expand Down
6 changes: 5 additions & 1 deletion test/generator/stubuser_aottest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,16 @@ int main(int argc, char **argv) {
Buffer<float> float32_buffer_output(kSize, kSize, 3);
Buffer<> int32_buffer_output(halide_type_t(halide_type_int, 32), kSize, kSize, 3);
Buffer<uint8_t> array_test_output(kSize, kSize, 3);
Buffer<float> tupled_output0(kSize, kSize, 3);
Buffer<int32_t> tupled_output1(kSize, kSize, 3);

stubuser(input, calculated_output, float32_buffer_output, int32_buffer_output, array_test_output);
stubuser(input, calculated_output, float32_buffer_output, int32_buffer_output, array_test_output, tupled_output0, tupled_output1);
verify(input, kFloatArg, kIntArg, kOffset, calculated_output);
verify(input, 1.f, 0, 0.f, float32_buffer_output);
verify<uint8_t, int32_t>(input, 1.f, 0, 0.f, int32_buffer_output);
verify(input, 1.f, 0, 2, array_test_output);
verify(input, 1.f, 0, 0, tupled_output0);
verify(input, 1.f, 1, 0, tupled_output1);

printf("Success!\n");
return 0;
Expand Down
3 changes: 3 additions & 0 deletions test/generator/stubuser_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class StubUser : public Halide::Generator<StubUser> {
Output<Buffer<float>> float32_buffer_output{"float32_buffer_output" };
Output<Buffer<int32_t>> int32_buffer_output{"int32_buffer_output" };
Output<Buffer<uint8_t>> array_test_output{"array_test_output" };
// We can infer the tupled-output-type from the Stub
Output<Buffer<>> tupled_output{ "tupled_output", 3 };

void generate() {
Var x{"x"}, y{"y"}, c{"c"};
Expand Down Expand Up @@ -60,6 +62,7 @@ class StubUser : public Halide::Generator<StubUser> {
float32_buffer_output = out.typed_buffer_output;
int32_buffer_output = out.untyped_buffer_output;
array_test_output = out.array_buffer_output[1];
tupled_output = out.tupled_output;

const float kOffset = 2.f;
calculated_output(x, y, c) = cast<uint8_t>(out.tuple_output(x, y, c)[1] + kOffset);
Expand Down

0 comments on commit 7bc00ba

Please sign in to comment.