Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
UltralyticsAssistant committed Dec 3, 2024
1 parent 538fcf1 commit 3f3c754
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions onnxslim/core/pattern/fusion/padconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ def rewrite(self, opset=11):
pad_node_users = get_node_users(pad_node)

pad_inputs = len(pad_node.inputs)
if pad_inputs < 3 or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0)):
if isinstance(pad_node.inputs[1], gs.Constant) and pad_node.attrs["mode"] == "constant" and conv_node.inputs[1].shape:
if pad_inputs < 3 or (
pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0)
):
if (
isinstance(pad_node.inputs[1], gs.Constant)
and pad_node.attrs["mode"] == "constant"
and conv_node.inputs[1].shape
):
conv_weight_dim = len(conv_node.inputs[1].shape)
pad_value = pad_node.inputs[1].values.tolist()

if all(pad == 0 for pad in (pad_value[:2] + pad_value[conv_weight_dim: conv_weight_dim+2])):
spatial_dim = conv_weight_dim - 2
if all(pad == 0 for pad in (pad_value[:2] + pad_value[conv_weight_dim : conv_weight_dim + 2])):
conv_weight_dim - 2
input_variable = self.pad_0.inputs[0]
pad_variable = pad_node.outputs[0] # pad output variable
index = conv_node.inputs.index(pad_variable)
Expand All @@ -61,19 +67,20 @@ def rewrite(self, opset=11):
pad_node.inputs.clear()
pad_node.outputs.clear()
conv_pads = attrs["pads"]
pads = pad_value[2 : conv_weight_dim] + pad_value[conv_weight_dim + 2:]
pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2 :]
pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]

attrs["pads"] = pads
match_case[conv_node.name] = {
"op": "Conv",
"inputs": inputs,
"outputs": outputs,
"name": conv_node.name,
"attrs": conv_node.attrs,
"domain": None,
}
"op": "Conv",
"inputs": inputs,
"outputs": outputs,
"name": conv_node.name,
"attrs": conv_node.attrs,
"domain": None,
}

return match_case


register_fusion_pattern(PadConvMatcher(1))

0 comments on commit 3f3c754

Please sign in to comment.