diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py index 7b0a18d..23c6565 100644 --- a/onnxslim/core/optimizer.py +++ b/onnxslim/core/optimizer.py @@ -162,6 +162,7 @@ def graph_constant_fold_inplace(graph): delete_node(node, idx) logger.debug(f"removing Add op: {node.name}") elif node.op == "Expand": + # tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256] if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant) and np.all(node.inputs[1].values == 1): delete_node(node) logger.debug(f"removing Expand op: {node.name}") @@ -179,10 +180,10 @@ class PadConvMatcher(PatternMatcher): def __init__(self, priority): pattern = Pattern( """ - input input 0 1 pad_0 + input input 0 1 pad_0 Pad pad_0 1+ 1 input conv_0 Conv conv_0 1+ 1 pad_0 output - output output 1 0 conv_0 + output output 1 0 conv_0 """ ) super().__init__(pattern, priority) @@ -252,10 +253,10 @@ class ConvBatchNormMatcher(PatternMatcher): def __init__(self, priority): pattern = Pattern( """ - input input 0 1 conv_0 - Conv conv_0 3 1 input ? ? bn_0 - BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output - output output 1 0 bn_0 + input input 0 1 conv_0 + Conv conv_0 1+ 1 input bn_0 + BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output + output output 1 0 bn_0 """ ) super().__init__(pattern, priority) @@ -335,7 +336,7 @@ def __init__(self, priority): input input 0 1 slice_0 Slice slice_0 5 1 input ? ? ? ? slice_1 Slice slice_1 5 1 slice_0 ? ? ? ? output - output output 1 0 slice_1 + output output 1 0 slice_1 """ ) # to check here slice_0 super().__init__(pattern, priority) @@ -434,10 +435,10 @@ class ReshapePatternMatcher(PatternMatcher): def __init__(self, priority): pattern = Pattern( """ - input input 0 1 reshape_0 + input input 0 1 reshape_0 Reshape reshape_0 2 1 input ? reshape_1 Reshape reshape_1 2 1 reshape_0 ? output - output output 1 0 reshape_1 + output output 1 0 reshape_1 """ ) super().__init__(pattern, priority)