
Commit 3f2dd54

Cherry-pick of commit 6b4bd2f: 【Hackathon 7th No.46】 Add support for IfElse operators that return constants (PaddlePaddle#1383)

* wip

* fix

* update due to comment

* Add missing implementation

* Restore code format

* Restore code format
Asthestarsfalll authored and 0x3878f committed Dec 6, 2024
1 parent 7f1e54d commit 3f2dd54
Showing 4 changed files with 173 additions and 20 deletions.
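For context, the pattern this change enables is an if/else whose branch returns a plain Python constant; such a branch has no conditional_block behind it, so select_input previously could not find a sub-graph for it. Below is a minimal sketch of the newly supported pattern, mirroring the tests added in tests/test_ifelse.py further down (the class name and the APIOnnx arguments are illustrative; opset 11 matches the tests):

import paddle
from onnxbase import APIOnnx  # test helper from tests/onnxbase.py


class ConstBranchNet(paddle.nn.Layer):
    def forward(self, inputs):
        if inputs == 1:
            return inputs + 1  # tensor branch -> conditional_block
        else:
            return 2           # constant branch -> fill_constant


net = ConstBranchNet()
net.eval()
exporter = APIOnnx(net, "ifelse_const", [11])  # export with opset 11
exporter.set_input_data("input_data", paddle.to_tensor(2))
exporter.run()  # converts to ONNX and checks outputs against Paddle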
72 changes: 60 additions & 12 deletions paddle2onnx/mapper/exporter.cc
@@ -474,6 +474,27 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock(
parser, sub_block_idx, temp_parameters, temp_inputs, temp_outputs));
}

ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant(
const PaddleParser& parser,
OnnxHelper* temp_helper,
int32_t block_id,
int32_t op_id,
const std::string& output_names) {
ONNX_NAMESPACE::GraphProto graph;
graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id));
auto op = parser.GetOpDesc(block_id, op_id); // fill_constant
auto out_info = parser.GetOpOutput(block_id, op_id, "Out");

*(graph.add_output()) = (*MakeValueInfo(out_info[0]));
for (auto& item : temp_helper->nodes) {
if (item->output(0) == output_names) {
*(graph.add_node()) = (*item.get());
break;
}
}

return std::move(graph);
}
ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
PaddlePirParser& pir_parser,
pir::Block* block,
@@ -618,27 +639,51 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
Assert(input_info.size() == 2,
"Only support when number of select_input's input_node is 2.");

// 构建 else 分支图
// Build else sub graph
auto else_node_name = input_info[0].name;
auto conditional_block_cood_it = sub_block_map_.find(else_node_name);
Assert(conditional_block_cood_it != sub_block_map_.end(),
"Don't find select_input else_input node.");
"Can't find select_input else_input node.");
auto conditional_block_cood = conditional_block_cood_it->second;
auto else_graph = ExportConditionalBlock(parser,
conditional_block_cood.first,
conditional_block_cood.second,
else_node_name);
ONNX_NAMESPACE::GraphProto else_graph, then_graph;
auto else_node = parser.GetOpDesc(conditional_block_cood.first,
conditional_block_cood.second);

if (else_node.type().find("conditional_block") != std::string::npos) {
else_graph = ExportConditionalBlock(parser,
conditional_block_cood.first,
conditional_block_cood.second,
else_node_name);
} else {
else_graph = ExportFillConstant(parser,
&temp_helper,
conditional_block_cood.first,
conditional_block_cood.second,
else_node_name);
}

// 构建 then 分支图
// Build then sub graph
auto then_node_name = input_info[1].name;
conditional_block_cood_it = sub_block_map_.find(then_node_name);
Assert(conditional_block_cood_it != sub_block_map_.end(),
"Don't find select_input then_input node.");
"Can't find select_input then_input node.");
conditional_block_cood = conditional_block_cood_it->second;
auto then_graph = ExportConditionalBlock(parser,
conditional_block_cood.first,
conditional_block_cood.second,
then_node_name);
auto then_node = parser.GetOpDesc(conditional_block_cood.first,
conditional_block_cood.second);

// use node.type() to make sure correctness
if (then_node.type().find("conditional_block") != std::string::npos) {
then_graph = ExportConditionalBlock(parser,
conditional_block_cood.first,
conditional_block_cood.second,
then_node_name);
} else {
then_graph = ExportFillConstant(parser,
&temp_helper,
conditional_block_cood.first,
conditional_block_cood.second,
then_node_name);
}

auto cond_info = parser.GetOpInput(block_id, op_id, "Mask");
auto output_info = parser.GetOpOutput(block_id, op_id, "Out");
@@ -649,6 +694,9 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(
AddAttribute(node, "then_branch", then_graph);
AddAttribute(node, "else_branch", else_graph);
continue;
} else if (op.type() == "fill_constant") {
auto out_info = parser.GetOpOutput(block_id, op_id, "Out");
sub_block_map_[out_info[0].name] = {block_id, op_id};
}
ExportOp(parser, &temp_helper, opset_version_, block_id, op_id, verbose_);
}
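For reference, the structure that the select_input mapping ultimately emits is an ONNX If node whose then_branch/else_branch attributes are sub-graphs with no inputs and a single output: ExportConditionalBlock builds that sub-graph from a conditional_block, and the new ExportFillConstant instead wraps the single node that produces the constant output. A rough hand-built equivalent using onnx.helper is sketched below (names and the use of Constant nodes are illustrative; the real exporter copies the already-mapped fill_constant node into the branch graph):

from onnx import TensorProto, helper

# Each branch is a GraphProto with no inputs and one output.
then_out = helper.make_tensor_value_info("then_out", TensorProto.INT64, [])
then_graph = helper.make_graph(
    nodes=[helper.make_node(
        "Constant", inputs=[], outputs=["then_out"],
        value=helper.make_tensor("t", TensorProto.INT64, [], [1]))],
    name="then_branch", inputs=[], outputs=[then_out])

else_out = helper.make_tensor_value_info("else_out", TensorProto.INT64, [])
else_graph = helper.make_graph(
    nodes=[helper.make_node(
        "Constant", inputs=[], outputs=["else_out"],
        value=helper.make_tensor("e", TensorProto.INT64, [], [2]))],
    name="else_branch", inputs=[], outputs=[else_out])

# select_input(Mask, [else_input, then_input]) maps to a single If node.
if_node = helper.make_node("If", inputs=["cond"], outputs=["out"],
                           then_branch=then_graph, else_branch=else_graph)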
8 changes: 8 additions & 0 deletions paddle2onnx/mapper/exporter.h
@@ -206,12 +206,20 @@ class ModelExporter {
ONNX_NAMESPACE::GraphProto ExportIfBlock(PaddlePirParser& pir_parser,
pir::Block& block);

ONNX_NAMESPACE::GraphProto ExportFillConstant(
const PaddleParser& parser,
OnnxHelper* temp_helper,
int32_t block_id,
int32_t op_id,
const std::string& output_names);

ONNX_NAMESPACE::GraphProto ExportBlock(
const PaddleParser& parser,
int32_t block_id,
std::vector<std::shared_ptr<ONNX_NAMESPACE::NodeProto>>& parameters,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& inputs,
std::vector<std::shared_ptr<ONNX_NAMESPACE::ValueInfoProto>>& outputs);

ONNX_NAMESPACE::GraphProto ExportBlock(
PaddlePirParser& pir_parser,
pir::Block* block,
6 changes: 5 additions & 1 deletion tests/onnxbase.py
@@ -89,7 +89,11 @@ def compare(result, expect, delta=1e-10, rtol=1e-10):
# Convert Paddle Tensor to Numpy array
if isinstance(expect, list):
expect = expect[0]
expect = expect.numpy()

if isinstance(expect, paddle.Tensor):
expect = expect.numpy()
else:
expect = np.array(expect)

# For result_shape is (1) and expect_shape shape is ()
expect = expect.squeeze()
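The reason for the fallback above: when a branch returns a Python constant, the reference output collected from the Paddle layer is a plain int/float rather than a paddle.Tensor, so the old unconditional expect.numpy() raised AttributeError. A minimal repro of the problem the change fixes (values are illustrative):

import numpy as np
import paddle

expect = 2  # e.g. the else branch of BaseNet4 returns a bare constant
if isinstance(expect, paddle.Tensor):
    expect = expect.numpy()
else:
    expect = np.array(expect)  # old code called expect.numpy() here and failed
print(expect.squeeze())  # 2, now comparable with the converted model's output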
107 changes: 100 additions & 7 deletions tests/test_ifelse.py
@@ -14,7 +14,7 @@

import paddle
from onnxbase import APIOnnx
from onnxbase import randtool


class BaseNet1(paddle.nn.Layer):
def __init__(self):
@@ -26,46 +26,139 @@ def forward(self, inputs):
else:
return inputs * 3


def test_ifelse_1_true():
op = BaseNet1()
op.eval()
obj = APIOnnx(op, 'ifelse', [11])
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(1))
obj.run()


def test_ifelse_1_false():
op = BaseNet1()
op.eval()
obj = APIOnnx(op, 'ifelse', [11])
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(2))
obj.run()


class BaseNet2(paddle.nn.Layer):
def __init__(self):
super(BaseNet2, self).__init__()

def forward(self, cond, inputs):
if cond == 1:
return inputs * 1, inputs * 2
return inputs * 1, inputs * 2
else:
return inputs * 3, inputs * 4


def test_ifelse_2_true():
op = BaseNet2()
op.eval()
obj = APIOnnx(op, 'ifelse', [11])
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(1), paddle.to_tensor(1))
obj.run()


def test_ifelse_2_false():
op = BaseNet2()
op.eval()
obj = APIOnnx(op, 'ifelse', [11])
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(2), paddle.to_tensor(1))
obj.run()


class BaseNet3(paddle.nn.Layer):
def __init__(self):
super(BaseNet3, self).__init__()

def forward(self, inputs):
if inputs == 1:
return 1
else:
return 2


def test_ifelse_3_true():
op = BaseNet3()
op.eval()
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(1))
obj.run()


def test_ifelse_3_false():
op = BaseNet3()
op.eval()
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(2))
obj.run()


class BaseNet4(paddle.nn.Layer):
def __init__(self):
super(BaseNet4, self).__init__()

def forward(self, inputs):
if inputs == 1:
return inputs + 1
else:
return 2


def test_ifelse_4_true():
op = BaseNet4()
op.eval()
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(1))
obj.run()


def test_ifelse_4_false():
op = BaseNet4()
op.eval()
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(2))
obj.run()


class BaseNet5(paddle.nn.Layer):
def __init__(self):
super(BaseNet5, self).__init__()

def forward(self, inputs):
if inputs == 1:
return 1, 2
else:
return 2, 3


def test_ifelse_5_true():
op = BaseNet5()
op.eval()
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(1))
obj.run()


def test_ifelse_5_false():
op = BaseNet5()
op.eval()
obj = APIOnnx(op, "ifelse", [11])
obj.set_input_data("input_data", paddle.to_tensor(2))
obj.run()


if __name__ == "__main__":
test_ifelse_1_true()
test_ifelse_1_false()
test_ifelse_2_true()
test_ifelse_2_false()
test_ifelse_3_true()
test_ifelse_3_false()
test_ifelse_4_true()
test_ifelse_4_false()
test_ifelse_5_true()
test_ifelse_5_false()
