Commit

UltralyticsAssistant committed Jun 20, 2024
1 parent 860ca7f commit 51c53e7

Showing 9 changed files with 100 additions and 58 deletions.
68 changes: 41 additions & 27 deletions onnxslim/core/graph_rewriter.py
@@ -1,9 +1,9 @@
import re
from abc import ABCMeta, abstractmethod
from abc import abstractmethod

from onnxslim.utils import logger
import onnxslim.onnx_graphsurgeon as gs
from onnxslim.onnx_graphsurgeon import Constant
from onnxslim.utils import logger


def get_node_users(node):
@@ -42,35 +42,35 @@ def get_name(name):
class NodeDescriptor:
def __init__(self, node_spec):
if not isinstance(node_spec, list):
raise ValueError('node_spec must be a list')
raise ValueError("node_spec must be a list")
if len(node_spec) < 4:
raise ValueError(f'node_spec must have at least 4 elements {node_spec}')
raise ValueError(f"node_spec must have at least 4 elements {node_spec}")

def get_input_info(io_spec):
if not io_spec.isdigit():
pattern_with_plus = re.search(r'(\d+)(\+)', io_spec)
pattern_with_plus = re.search(r"(\d+)(\+)", io_spec)
if pattern_with_plus:
return int(pattern_with_plus.group(1)), True
else:
raise ValueError(f'input_num and output_num must be integers {io_spec}')
raise ValueError(f"input_num and output_num must be integers {io_spec}")

return int(io_spec), False

self.op = node_spec[0]
self.name = node_spec[1]
self.input_num, self.coarse_input_num = get_input_info(node_spec[2])
self.output_num, self.coarse_output_num = get_input_info(node_spec[3])
self.input_names = node_spec[4:4 + self.input_num]
self.output_names = node_spec[4 + self.input_num:]
self.input_names = node_spec[4 : 4 + self.input_num]
self.output_names = node_spec[4 + self.input_num :]
assert len(self.input_names) == self.input_num
assert len(self.output_names) == self.output_num, f'{self.name} {len(self.output_names)} != {self.output_num}'
assert len(self.output_names) == self.output_num, f"{self.name} {len(self.output_names)} != {self.output_num}"

def __repr__(self):
return f'name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}'
return f"name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}"

def __dict__(self):
return {
'name': self,
"name": self,
}


@@ -80,7 +80,7 @@ def __init__(self, pattern):
self.nodes = self.parse_nodes()

def parse_nodes(self):
nodes = self.pattern.split('\n')
nodes = self.pattern.split("\n")
nodes = [line.strip().split() for line in nodes if line]
nodes = [NodeDescriptor(node) for node in nodes if node]
return nodes
@@ -97,7 +97,7 @@ def __init__(self, pattern, priority):
self.pattern = pattern
self.priority = priority
self.pattern_dict = {node.name: node for node in pattern.nodes}
self.output_names = [node.name for node in pattern.nodes if node.op == 'output']
self.output_names = [node.name for node in pattern.nodes if node.op == "output"]

def get_match_point(self):
return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]]
@@ -106,11 +106,11 @@ def match(self, node):
match_point = self.get_match_point()

def match_(node, pattern_node):
if pattern_node.op == 'input':
if pattern_node.op == "input":
return True

# node is an input variable
if not hasattr(node, 'op'):
if not hasattr(node, "op"):
return False

if node.op == pattern_node.op:
@@ -122,15 +122,18 @@ def match_(node, pattern_node):
return False
else:
if len(node_feeds) != len(pattern_node.input_names):
logger.debug('len(node_feeds) != len(pattern_node.input_names)',
len(node_feeds), len(pattern_node.input_names))
logger.debug(
"len(node_feeds) != len(pattern_node.input_names)",
len(node_feeds),
len(pattern_node.input_names),
)
return False

pattern_nodes = [self.pattern_dict[name] if name != "?" else None for name in pattern_node.input_names]
all_match = True
for node_feed, pattern_node in zip(node_feeds, pattern_nodes):
if pattern_node is not None:
node_match = match_(node_feed, pattern_node)
node_match = match_(node_feed, pattern_node)
if not node_match:
return False
setattr(self, pattern_node.name, node_feed)
@@ -140,15 +143,15 @@ def match_(node, pattern_node):
return False

if match_(node, match_point):
setattr(self, 'output', node.outputs)
setattr(self, "output", node.outputs)
if self.parameter_check():
return True

return False

@abstractmethod
def rewrite(self):
raise NotImplementedError('rewrite method must be implemented')
raise NotImplementedError("rewrite method must be implemented")

def parameter_check(self):
return True
@@ -167,20 +170,31 @@ def generate(self):
template = []
for input in inputs:
name = get_name(input.name)
template.append(' '.join(['input', name, '0', str(len(input.outputs))] +
[get_name(output.name) for output in input.outputs]))
template.append(
" ".join(
["input", name, "0", str(len(input.outputs))] + [get_name(output.name) for output in input.outputs]
)
)

for node in nodes:
if node.op != "Constant":
name = get_name(node.name)
feeds = get_node_feeds(node)
users = get_node_users(node)
template.append(' '.join([node.op, name, str(len(feeds)), str(len(users))] +
[get_name(feed.name) if not isinstance(feed, Constant) else '?' for feed in feeds] +
[get_name(user.name) if not isinstance(user, Constant) else '?' for user in users]))
template.append(
" ".join(
[node.op, name, str(len(feeds)), str(len(users))]
+ [get_name(feed.name) if not isinstance(feed, Constant) else "?" for feed in feeds]
+ [get_name(user.name) if not isinstance(user, Constant) else "?" for user in users]
)
)

for output in outputs:
name = get_name(output.name)
template.append(' '.join(['output', name, str(len(output.inputs)), '0'] + [get_name(input.name) for input in output.inputs]))
template.append(
" ".join(
["output", name, str(len(output.inputs)), "0"] + [get_name(input.name) for input in output.inputs]
)
)

return '\n'.join(template)
return "\n".join(template)
57 changes: 42 additions & 15 deletions onnxslim/core/optimizer.py
@@ -5,11 +5,16 @@
import onnx

import onnxslim.onnx_graphsurgeon as gs
from onnxslim.core.graph_rewriter import (
Pattern,
PatternMatcher,
get_node_feeds,
get_node_users,
)
from onnxslim.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
from onnxslim.onnx_graphsurgeon.ir.graph import Graph
from onnxslim.onnx_graphsurgeon.ir.tensor import Constant, Variable
from onnxslim.utils import logger
from onnxslim.core.graph_rewriter import PatternMatcher, Pattern, get_node_feeds, get_node_users

DEFAULT_FUSION_PATTERNS = OrderedDict()

@@ -173,12 +178,13 @@ def graph_constant_fold_inplace(graph):
class PadConvMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
'''
"""
input input 0 1 pad_0
Pad pad_0 1+ 1 input conv_0
Conv conv_0 1+ 1 pad_0 output
output output 1 0 conv_0
''')
"""
)
super().__init__(pattern, priority)

@property
@@ -234,17 +240,20 @@ def rewrite(self):

return match_case


register_fusion_pattern(PadConvMatcher(1))


class ConvBatchNormMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
'''
"""
input input 0 1 conv_0
Conv conv_0 3 1 input ? ? bn_0
BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
output output 1 0 bn_0
''')
"""
)
super().__init__(pattern, priority)

@property
@@ -309,17 +318,20 @@ def rewrite(self):

return match_case


register_fusion_pattern(ConvBatchNormMatcher(1))


class SlicePatternMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
'''
"""
input input 0 1 slice_0
Slice slice_0 5 1 input ? ? ? ? slice_1
Slice slice_1 5 1 slice_0 ? ? ? ? output
output output 1 0 slice_1
''') # to check here slice_0
"""
) # to check here slice_0
super().__init__(pattern, priority)

@property
@@ -406,17 +418,20 @@ def rewrite(self):

return match_case


register_fusion_pattern(SlicePatternMatcher(1))


class ReshapePatternMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
'''
"""
input input 0 1 reshape_0
Reshape reshape_0 2 1 input ? reshape_1
Reshape reshape_1 2 1 reshape_0 ? output
output output 1 0 reshape_1
''')
"""
)
super().__init__(pattern, priority)

@property
@@ -431,6 +446,7 @@ def rewrite(self):
first_reshape_node_users = get_node_users(first_reshape_node)
if len(first_reshape_node_users) == 1:
second_reshape_node = node

def check_constant_mergeable(reshape_node):
if isinstance(reshape_node.inputs[1], Constant):
input_shape = reshape_node.inputs[0].shape
@@ -463,17 +479,20 @@ def check_constant_mergeable(reshape_node):

return match_case


register_fusion_pattern(ReshapePatternMatcher(1))


class MatMulAddPatternMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
'''
"""
input input 0 1 matmul_0
MatMul matmul_0 2 1 input ? add_0
Add add_0 2 1 matmul_0 ? output
output output 1 0 add_0
''')
"""
)
super().__init__(pattern, priority)

@property
@@ -623,20 +642,23 @@ def rewrite(self):
)
return match_case


register_fusion_pattern(MatMulAddPatternMatcher(1))


class GeluPatternMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
'''
"""
input input 0 2 mul_0 div_0
Div div_0 2 1 input ? erf_0
Erf erf_0 1 1 div_0 add_0
Add add_0 2 1 erf_0 ? mul_0
Mul mul_0 2 1 input add_0 mul_1
Mul mul_1 2 1 mul_0 ? output
output output 1 0 mul_1
''')
"""
)
super().__init__(pattern, priority)

@property
@@ -664,17 +686,20 @@ def rewrite(self):

return match_case


# register_fusion_pattern(GeluPatternMatcher(1))


class ReducePatternMatcher(PatternMatcher):
def __init__(self, priority):
pattern = Pattern(
'''
"""
input input 0 1 reduce_0
ReduceSum reduce_0 1 1 input unsqueeze_0
Unsqueeze unsqueeze_0 1 1 reduce_0 output
output output 1 0 unsqueeze_0
''')
"""
)
super().__init__(pattern, priority)

@property
@@ -712,8 +737,10 @@ def rewrite(self, opset=11):

return match_case


register_fusion_pattern(ReducePatternMatcher(1))


@gs.Graph.register()
def replace_custom_layer(
self,
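A note on what ConvBatchNormMatcher (defined earlier in this file) ultimately computes: its rewrite body is collapsed in this diff, but a Conv followed by BatchNormalization is conventionally fused by folding the normalization into the convolution's weights and bias. The NumPy sketch below shows only that standard folding arithmetic; it is not the onnxslim implementation, and the helper name fold_conv_bn is illustrative.

import numpy as np

def fold_conv_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    # BatchNorm(y) = gamma * (y - mean) / sqrt(var + eps) + beta, with y = conv(x).
    # Folding yields a single equivalent Conv with scaled weights and a shifted bias.
    scale = gamma / np.sqrt(var + eps)                  # per output channel
    fused_weight = weight * scale[:, None, None, None]  # weight: (out_ch, in_ch, kH, kW)
    fused_bias = (bias - mean) * scale + beta
    return fused_weight, fused_bias

In a rewrite of this kind, the fused arrays would replace the Conv node's weight and bias constants, and the BatchNormalization node would be removed with its output rewired to the fused Conv.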
10 changes: 7 additions & 3 deletions onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py
@@ -346,9 +346,13 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto"
intersection = None
else:
intersection = (
{key: graph_constants_list[0][key] for key in graph_constants_list[0]
if all(key in d and graph_constants_list[0][key] == d[key] for d in graph_constants_list[1:])}
if graph_constants_list else None
{
key: graph_constants_list[0][key]
for key in graph_constants_list[0]
if all(key in d and graph_constants_list[0][key] == d[key] for d in graph_constants_list[1:])
}
if graph_constants_list
else None
)

onnx_graph = OnnxExporter.export_graph(
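The reformatted comprehension above keeps a constant only if it appears in every entry of graph_constants_list with an equal value, falling back to None when the list is empty. A plain-Python sketch of the same idea, using ordinary values instead of gs.Constant tensors (an assumption made for brevity) and the hypothetical name shared_constants:

def shared_constants(constant_dicts):
    # Keep only keys present in every dict with an equal value; None if nothing to intersect.
    if not constant_dicts:
        return None
    first, rest = constant_dicts[0], constant_dicts[1:]
    return {key: value for key, value in first.items() if all(key in d and d[key] == value for d in rest)}

print(shared_constants([{"a": 1, "b": 2}, {"a": 1, "b": 3}]))  # {'a': 1}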
3 changes: 2 additions & 1 deletion onnxslim/onnx_graphsurgeon/ir/graph.py
@@ -1178,8 +1178,9 @@ def should_eval_foldable(tensor):
names = [t.name for t in graph_clone.outputs]
try:
import os
import tempfile

import onnx
import tempfile
import onnxruntime as onnxrt

onnx_model = export_onnx(graph_clone, do_type_check=False)