diff --git a/onnxslim/core/pattern/fusion/convbn.py b/onnxslim/core/pattern/fusion/convbn.py index f4aa0e0..56f1e13 100644 --- a/onnxslim/core/pattern/fusion/convbn.py +++ b/onnxslim/core/pattern/fusion/convbn.py @@ -29,7 +29,7 @@ def rewrite(self, opset=11): conv_transpose_node = self.conv_0 conv_transpose_node_users = get_node_users(conv_transpose_node) node = self.bn_0 - if len(conv_transpose_node_users) == 1: + if len(conv_transpose_node_users) == 1 and all([isinstance(value, gs.Constant) for value in node.inputs[1:]]): conv_transpose_weight = conv_transpose_node.inputs[1].values bn_node = node bn_scale = bn_node.inputs[1].values