diff --git a/onnxslim/core/pattern/fusion/padconv.py b/onnxslim/core/pattern/fusion/padconv.py index a8993ba..42bf176 100644 --- a/onnxslim/core/pattern/fusion/padconv.py +++ b/onnxslim/core/pattern/fusion/padconv.py @@ -36,13 +36,19 @@ def rewrite(self, opset=11): pad_node_users = get_node_users(pad_node) pad_inputs = len(pad_node.inputs) - if pad_inputs < 3 or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0)): - if isinstance(pad_node.inputs[1], gs.Constant) and pad_node.attrs["mode"] == "constant" and conv_node.inputs[1].shape: + if pad_inputs < 3 or ( + pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0) + ): + if ( + isinstance(pad_node.inputs[1], gs.Constant) + and pad_node.attrs["mode"] == "constant" + and conv_node.inputs[1].shape + ): conv_weight_dim = len(conv_node.inputs[1].shape) pad_value = pad_node.inputs[1].values.tolist() - if all(pad == 0 for pad in (pad_value[:2] + pad_value[conv_weight_dim: conv_weight_dim+2])): - spatial_dim = conv_weight_dim - 2 + if all(pad == 0 for pad in (pad_value[:2] + pad_value[conv_weight_dim : conv_weight_dim + 2])): + conv_weight_dim - 2 input_variable = self.pad_0.inputs[0] pad_variable = pad_node.outputs[0] # pad output variable index = conv_node.inputs.index(pad_variable) @@ -61,19 +67,20 @@ def rewrite(self, opset=11): pad_node.inputs.clear() pad_node.outputs.clear() conv_pads = attrs["pads"] - pads = pad_value[2 : conv_weight_dim] + pad_value[conv_weight_dim + 2:] + pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2 :] pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)] attrs["pads"] = pads match_case[conv_node.name] = { - "op": "Conv", - "inputs": inputs, - "outputs": outputs, - "name": conv_node.name, - "attrs": conv_node.attrs, - "domain": None, - } + "op": "Conv", + "inputs": inputs, + "outputs": outputs, + "name": conv_node.name, + "attrs": conv_node.attrs, + "domain": None, + } return match_case + register_fusion_pattern(PadConvMatcher(1))