Skip to content

Commit

Permalink
add PatternGenerator and test
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Jun 19, 2024
1 parent 87e491b commit 83708b8
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 2 deletions.
49 changes: 47 additions & 2 deletions onnxslim/core/graph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from abc import ABCMeta, abstractmethod

from onnxslim.utils import logger
import onnxslim.onnx_graphsurgeon as gs
from onnxslim.onnx_graphsurgeon import Constant


def get_node_users(node):
"""Retrieve the list of nodes that use the outputs of the given node."""
users = []
for output in node.outputs: # output is a Variable
if len(output.outputs) == 0:
users.append(output)
users.extend(iter(output.outputs))
return users

Expand All @@ -30,6 +33,15 @@ def get_node_feeds(node):
return feeds


def get_name(name):
_illegal_char_regex = re.compile("[^0-9a-zA-Z_]+")
sanitized_name = _illegal_char_regex.sub("_", name)
if sanitized_name.isdigit():
sanitized_name = "_" + sanitized_name

return sanitized_name


class NodeDescriptor:
def __init__(self, node_spec):
if not isinstance(node_spec, list):
Expand Down Expand Up @@ -88,9 +100,10 @@ def __init__(self, pattern, priority):
self.pattern = pattern
self.priority = priority
self.pattern_dict = {node.name: node for node in pattern.nodes}
self.output_names = [node.name for node in pattern.nodes if node.op == 'output']

def get_match_point(self):
return self.pattern_dict[self.pattern_dict['output'].input_names[0]]
return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]]

def match(self, node):
match_point = self.get_match_point()
Expand Down Expand Up @@ -141,4 +154,36 @@ def rewrite(self):
raise NotImplementedError('rewrite method must be implemented')

def parameter_check(self):
return True
return True


class PatternGenerator:
def __init__(self, onnx_model):
self.graph = gs.import_onnx(onnx_model)
self.graph.fold_constants().cleanup().toposort()

def generate(self):
inputs = self.graph.inputs
outputs = self.graph.outputs
nodes = self.graph.nodes

template = []
for input in inputs:
name = get_name(input.name)
template.append(' '.join(['input', name, '0', str(len(input.outputs))] +
[get_name(output.name) for output in input.outputs]))

for node in nodes:
if node.op != "Constant":
name = get_name(node.name)
feeds = get_node_feeds(node)
users = get_node_users(node)
template.append(' '.join([node.op, name, str(len(feeds)), str(len(users))] +
[get_name(feed.name) if not isinstance(feed, Constant) else '?' for feed in feeds] +
[get_name(user.name) if not isinstance(user, Constant) else '?' for user in users]))

for output in outputs:
name = get_name(output.name)
template.append(' '.join(['output', name, str(len(output.inputs)), '0'] + [get_name(input.name) for input in output.inputs]))

return '\n'.join(template)
77 changes: 77 additions & 0 deletions tests/test_pattern_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os
import pytest

import torch
import torch.nn as nn

import onnx
from onnxslim import slim
from onnxslim import register_fusion_pattern
from onnxslim.core.graph_rewriter import PatternMatcher, Pattern, PatternGenerator


class TestPatternGenerator:
def test_gelu(self, request):
class PatternModel(nn.Module):
def __init__(self):
super(PatternModel, self).__init__()
self.gelu = nn.GELU()

def forward(self, x):
x = self.gelu(x)
return x

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.relu0 = nn.ReLU()
self.pattern = PatternModel()
self.relu1 = nn.ReLU()

def forward(self, x):
x = self.relu0(x)
x = self.pattern(x)
x = self.relu1(x)
return x

input = torch.randn(2)
p = PatternModel()
m = Model()
directory = "tmp/" + request.node.name
os.makedirs(directory, exist_ok=True)

pattern_filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(p, input, pattern_filename)

model_filename = f"{directory}/{request.node.name}.onnx"
torch.onnx.export(m, input, model_filename)

model = onnx.load(pattern_filename)
pgen = PatternGenerator(model)
template = pgen.generate()
pattern = Pattern(template)

class GeluMatcher(PatternMatcher):
def __init__(self, pattern, priority):
super().__init__(pattern, priority)

@property
def name(self):
return "FusionGelu"

def rewrite(self):
raise Exception("Pattern Matched")

register_fusion_pattern(GeluMatcher(pattern, 1))
slim(model_filename, f"{directory}/{request.node.name}_slim.onnx")


if __name__ == "__main__":
pytest.main(
[
"-p",
"no:warnings",
"-sv",
"tests/test_pattern_generator.py",
]
)

0 comments on commit 83708b8

Please sign in to comment.