Skip to content

Commit

Permalink
format pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Jun 21, 2024
1 parent f01c295 commit d64e1ae
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions onnxslim/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def graph_constant_fold_inplace(graph):
delete_node(node, idx)
logger.debug(f"removing Add op: {node.name}")
elif node.op == "Expand":
# tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256]
if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant) and np.all(node.inputs[1].values == 1):
delete_node(node)
logger.debug(f"removing Expand op: {node.name}")
Expand All @@ -179,10 +180,10 @@ class PadConvMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
"""
input input 0 1 pad_0
input input 0 1 pad_0
Pad pad_0 1+ 1 input conv_0
Conv conv_0 1+ 1 pad_0 output
output output 1 0 conv_0
output output 1 0 conv_0
"""
)
super().__init__(pattern, priority)
Expand Down Expand Up @@ -252,10 +253,10 @@ class ConvBatchNormMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
"""
input input 0 1 conv_0
Conv conv_0 3 1 input ? ? bn_0
BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
output output 1 0 bn_0
input input 0 1 conv_0
Conv conv_0 1+ 1 input bn_0
BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
output output 1 0 bn_0
"""
)
super().__init__(pattern, priority)
Expand Down Expand Up @@ -335,7 +336,7 @@ def __init__(self, priority):
input input 0 1 slice_0
Slice slice_0 5 1 input ? ? ? ? slice_1
Slice slice_1 5 1 slice_0 ? ? ? ? output
output output 1 0 slice_1
output output 1 0 slice_1
"""
) # to check here slice_0
super().__init__(pattern, priority)
Expand Down Expand Up @@ -434,10 +435,10 @@ class ReshapePatternMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
"""
input input 0 1 reshape_0
input input 0 1 reshape_0
Reshape reshape_0 2 1 input ? reshape_1
Reshape reshape_1 2 1 reshape_0 ? output
output output 1 0 reshape_1
output output 1 0 reshape_1
"""
)
super().__init__(pattern, priority)
Expand Down

0 comments on commit d64e1ae

Please sign in to comment.