From 83708b8b8553e4263d221cb678126a8115fc9a80 Mon Sep 17 00:00:00 2001 From: inisis Date: Wed, 19 Jun 2024 16:22:21 +0000 Subject: [PATCH] add PatternGenerator and test --- onnxslim/core/graph_rewriter.py | 49 ++++++++++++++++++++- tests/test_pattern_generator.py | 77 +++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) create mode 100644 tests/test_pattern_generator.py diff --git a/onnxslim/core/graph_rewriter.py b/onnxslim/core/graph_rewriter.py index 7c6dbdf..8db8311 100644 --- a/onnxslim/core/graph_rewriter.py +++ b/onnxslim/core/graph_rewriter.py @@ -2,6 +2,7 @@ from abc import ABCMeta, abstractmethod from onnxslim.utils import logger +import onnxslim.onnx_graphsurgeon as gs from onnxslim.onnx_graphsurgeon import Constant @@ -9,6 +10,8 @@ def get_node_users(node): """Retrieve the list of nodes that use the outputs of the given node.""" users = [] for output in node.outputs: # output is a Variable + if len(output.outputs) == 0: + users.append(output) users.extend(iter(output.outputs)) return users @@ -30,6 +33,15 @@ def get_node_feeds(node): return feeds +def get_name(name): + _illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") + sanitized_name = _illegal_char_regex.sub("_", name) + if sanitized_name.isdigit(): + sanitized_name = "_" + sanitized_name + + return sanitized_name + + class NodeDescriptor: def __init__(self, node_spec): if not isinstance(node_spec, list): @@ -88,9 +100,10 @@ def __init__(self, pattern, priority): self.pattern = pattern self.priority = priority self.pattern_dict = {node.name: node for node in pattern.nodes} + self.output_names = [node.name for node in pattern.nodes if node.op == 'output'] def get_match_point(self): - return self.pattern_dict[self.pattern_dict['output'].input_names[0]] + return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]] def match(self, node): match_point = self.get_match_point() @@ -141,4 +154,36 @@ def rewrite(self): raise NotImplementedError('rewrite method must be implemented') def parameter_check(self): - return True \ No newline at end of file + return True + + +class PatternGenerator: + def __init__(self, onnx_model): + self.graph = gs.import_onnx(onnx_model) + self.graph.fold_constants().cleanup().toposort() + + def generate(self): + inputs = self.graph.inputs + outputs = self.graph.outputs + nodes = self.graph.nodes + + template = [] + for input in inputs: + name = get_name(input.name) + template.append(' '.join(['input', name, '0', str(len(input.outputs))] + + [get_name(output.name) for output in input.outputs])) + + for node in nodes: + if node.op != "Constant": + name = get_name(node.name) + feeds = get_node_feeds(node) + users = get_node_users(node) + template.append(' '.join([node.op, name, str(len(feeds)), str(len(users))] + + [get_name(feed.name) if not isinstance(feed, Constant) else '?' for feed in feeds] + + [get_name(user.name) if not isinstance(user, Constant) else '?' for user in users])) + + for output in outputs: + name = get_name(output.name) + template.append(' '.join(['output', name, str(len(output.inputs)), '0'] + [get_name(input.name) for input in output.inputs])) + + return '\n'.join(template) diff --git a/tests/test_pattern_generator.py b/tests/test_pattern_generator.py new file mode 100644 index 0000000..97389ac --- /dev/null +++ b/tests/test_pattern_generator.py @@ -0,0 +1,77 @@ +import os +import pytest + +import torch +import torch.nn as nn + +import onnx +from onnxslim import slim +from onnxslim import register_fusion_pattern +from onnxslim.core.graph_rewriter import PatternMatcher, Pattern, PatternGenerator + + +class TestPatternGenerator: + def test_gelu(self, request): + class PatternModel(nn.Module): + def __init__(self): + super(PatternModel, self).__init__() + self.gelu = nn.GELU() + + def forward(self, x): + x = self.gelu(x) + return x + + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + self.relu0 = nn.ReLU() + self.pattern = PatternModel() + self.relu1 = nn.ReLU() + + def forward(self, x): + x = self.relu0(x) + x = self.pattern(x) + x = self.relu1(x) + return x + + input = torch.randn(2) + p = PatternModel() + m = Model() + directory = "tmp/" + request.node.name + os.makedirs(directory, exist_ok=True) + + pattern_filename = f"{directory}/{request.node.name}.onnx" + torch.onnx.export(p, input, pattern_filename) + + model_filename = f"{directory}/{request.node.name}.onnx" + torch.onnx.export(m, input, model_filename) + + model = onnx.load(pattern_filename) + pgen = PatternGenerator(model) + template = pgen.generate() + pattern = Pattern(template) + + class GeluMatcher(PatternMatcher): + def __init__(self, pattern, priority): + super().__init__(pattern, priority) + + @property + def name(self): + return "FusionGelu" + + def rewrite(self): + raise Exception("Pattern Matched") + + register_fusion_pattern(GeluMatcher(pattern, 1)) + slim(model_filename, f"{directory}/{request.node.name}_slim.onnx") + + +if __name__ == "__main__": + pytest.main( + [ + "-p", + "no:warnings", + "-sv", + "tests/test_pattern_generator.py", + ] + )