diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 10d3d6c..7543470 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,32 +2,32 @@ name: CI on: push: - branches: [ "main" ] + branches: ["main"] pull_request: - branches: [ "main" ] + branches: ["main"] jobs: build: runs-on: self-hosted steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' + - uses: actions/setup-python@v4 + with: + python-version: "3.10" - - name: model zoo test - run: | - python -m pip install --upgrade pip wheel setuptools - pip install . - pip install pytest onnxruntime - python tests/test_folder.py --model-dir /root/data/modelzoo + - name: model zoo test + run: | + python -m pip install --upgrade pip wheel setuptools + pip install . + pip install pytest onnxruntime + python tests/test_folder.py --model-dir /root/data/modelzoo - - name: pattern matcher test - run: | - python tests/test_pattern_matcher.py + - name: pattern matcher test + run: | + python tests/test_pattern_matcher.py - - name: pattern generator test - run: | - python tests/test_pattern_generator.py + - name: pattern generator test + run: | + python tests/test_pattern_generator.py diff --git a/.github/workflows/nightly-test.yml b/.github/workflows/nightly-test.yml index fbe1463..f17e31d 100644 --- a/.github/workflows/nightly-test.yml +++ b/.github/workflows/nightly-test.yml @@ -2,23 +2,22 @@ name: nightly-test on: schedule: - - cron: '0 18 * * *' # Runs at 6:00 PM UTC every day, which is 2:00 AM Beijing Time the next day - + - cron: "0 18 * * *" # Runs at 6:00 PM UTC every day, which is 2:00 AM Beijing Time the next day jobs: build: runs-on: self-hosted steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 - with: - python-version: '3.10' + - uses: actions/setup-python@v4 + with: + python-version: "3.10" - - name: model test - run: | - python -m pip install --upgrade pip wheel setuptools - pip install . - pip install pytest pytest-xdist onnxruntime timm torchvision --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu - python tests/test_onnx_nets.py + - name: model test + run: | + python -m pip install --upgrade pip wheel setuptools + pip install . + pip install pytest pytest-xdist onnxruntime timm torchvision --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu + python tests/test_onnx_nets.py diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index f53f19c..58d0741 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -17,23 +17,22 @@ permissions: jobs: deploy: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v3 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip wheel setuptools - pip install build - - name: Build package - run: python -m build - - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip wheel setuptools + pip install build + - name: Build package + run: python -m build + - name: Publish package + uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + with: + user: __token__ + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/onnxslim/core/graph_rewriter.py b/onnxslim/core/graph_rewriter.py index 1fc3fd5..bef1413 100644 --- a/onnxslim/core/graph_rewriter.py +++ b/onnxslim/core/graph_rewriter.py @@ -25,12 +25,14 @@ def get_node_feeds(node): elif isinstance(input, Constant): feeds.append(input) else: - for feed in input.inputs: - feeds.append(input if feed.op == "Split" else feed) + feeds.extend(input if feed.op == "Split" else feed for feed in input.inputs) return feeds def get_name(name): + """Sanitizes the input string by replacing illegal characters with underscores and prefixing with an underscore if + numeric. + """ _illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") sanitized_name = _illegal_char_regex.sub("_", name) if sanitized_name.isdigit(): @@ -41,16 +43,20 @@ def get_name(name): class NodeDescriptor: def __init__(self, node_spec): + """Initialize NodeDescriptor with node_spec list requiring at least 4 elements.""" if not isinstance(node_spec, list): raise ValueError("node_spec must be a list") if len(node_spec) < 4: raise ValueError(f"node_spec must have at least 4 elements {node_spec}") def get_input_info(io_spec): + """Parses io_spec to return a tuple of (integer, boolean) indicating the presence of a plus sign in the + input. + """ if not io_spec.isdigit(): pattern_with_plus = re.search(r"(\d+)(\+)", io_spec) if pattern_with_plus: - return int(pattern_with_plus.group(1)), True + return int(pattern_with_plus[1]), True else: raise ValueError(f"input_num and output_num must be integers {io_spec}") @@ -66,9 +72,13 @@ def get_input_info(io_spec): assert len(self.output_names) == self.output_num, f"{self.name} {len(self.output_names)} != {self.output_num}" def __repr__(self): + """Return a string representation of the object, including its name, operation type, input/output counts, and + input/output names. + """ return f"name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}" def __dict__(self): + """Returns a dictionary representation of the object, with 'name' as the key.""" return { "name": self, } @@ -76,36 +86,50 @@ def __dict__(self): class Pattern: def __init__(self, pattern): + """Initialize the Pattern class with a given pattern and parse its nodes.""" self.pattern = pattern self.nodes = self.parse_nodes() def parse_nodes(self): + """Parse pattern into a list of NodeDescriptor objects from non-empty, stripped, and split lines.""" nodes = self.pattern.split("\n") nodes = [line.strip().split() for line in nodes if line] nodes = [NodeDescriptor(node) for node in nodes if node] return nodes def match(self, node): + """Match a node against a precompiled pattern.""" return self.pattern.match(node) def __repr__(self): + """Return a string representation of the pattern attribute.""" return self.pattern class PatternMatcher: def __init__(self, pattern, priority): + """Initialize the PatternMatcher with a given pattern and priority, and prepare node references and output + names. + """ self.pattern = pattern self.priority = priority self.pattern_dict = {node.name: node for node in pattern.nodes} self.output_names = [node.name for node in pattern.nodes if node.op == "output"] def get_match_point(self): + """Retrieve the match point node from the pattern dictionary based on output node input names.""" return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]] def match(self, node): + """Match a given node to a pattern by comparing input names with the match point node from the pattern + dictionary. + """ match_point = self.get_match_point() def match_(node, pattern_node): + """Match a given node to a pattern by comparing input names with the match point node from the pattern + dictionary. + """ if pattern_node.op == "input": return True @@ -151,18 +175,22 @@ def match_(node, pattern_node): @abstractmethod def rewrite(self): + """Abstract method to rewrite the graph based on matched patterns, to be implemented by subclasses.""" raise NotImplementedError("rewrite method must be implemented") def parameter_check(self): + """Check and validate parameters, returning True if valid.""" return True class PatternGenerator: def __init__(self, onnx_model): + """Initialize the PatternGenerator class with an ONNX model and process its graph.""" self.graph = gs.import_onnx(onnx_model) self.graph.fold_constants().cleanup().toposort() def generate(self): + """Generate the inputs, outputs, and nodes from the graph of the initialized ONNX model.""" inputs = self.graph.inputs outputs = self.graph.outputs nodes = self.graph.nodes @@ -184,8 +212,8 @@ def generate(self): template.append( " ".join( [node.op, name, str(len(feeds)), str(len(users))] - + [get_name(feed.name) if not isinstance(feed, Constant) else "?" for feed in feeds] - + [get_name(user.name) if not isinstance(user, Constant) else "?" for user in users] + + ["?" if isinstance(feed, Constant) else get_name(feed.name) for feed in feeds] + + ["?" if isinstance(user, Constant) else get_name(user.name) for user in users] ) ) diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py index b850397..7b0a18d 100644 --- a/onnxslim/core/optimizer.py +++ b/onnxslim/core/optimizer.py @@ -189,17 +189,21 @@ def __init__(self, priority): @property def name(self): + """Returns the name of the fusion pattern used.""" return "FusionPadConv" def parameter_check(self): + """Validates if the padding parameter for a convolutional node is a constant.""" pad_node = self.pad_0 - if not isinstance(pad_node.inputs[1], Constant): - return False - return True + +def parameter_check(self) -> bool: + return isinstance(pad_node.inputs[1], Constant) def rewrite(self): - match_case = {} + """Rewrites the padding parameter for a convolutional node to use a constant if the current parameter is not a + constant. + """ node = self.conv_0 pad_node = self.pad_0 input_variable = self.pad_0.inputs[0] @@ -229,17 +233,17 @@ def rewrite(self): pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)] attrs["pads"] = pads - match_case[node.name] = { - "op": "Conv", - "inputs": inputs, - "outputs": outputs, - "name": node.name, - "attrs": node.attrs, - "domain": None, + return { + node.name: { + "op": "Conv", + "inputs": inputs, + "outputs": outputs, + "name": node.name, + "attrs": node.attrs, + "domain": None, + } } - return match_case - register_fusion_pattern(PadConvMatcher(1)) @@ -258,9 +262,11 @@ def __init__(self, priority): @property def name(self): + """Returns the name of the FusionConvBN pattern.""" return "FusionConvBN" def rewrite(self): + """Rewrites the weights and biases of a BatchNormalization layer fused with a convolution layer.""" match_case = {} conv_transpose_node = self.conv_0 conv_transpose_node_users = get_node_users(conv_transpose_node) @@ -281,7 +287,7 @@ def rewrite(self): bn_var_rsqrt = 1.0 / np.sqrt(bn_running_var + bn_eps) shape = [1] * len(conv_transpose_weight.shape) - if node.i(0).op == "Conv": + if bn_node.i(0).op == "Conv": shape[0] = -1 else: shape[1] = -1 @@ -294,7 +300,7 @@ def rewrite(self): if weight_name.endswith("weight"): bias_name = f"{weight_name[:-6]}bias" else: - bias_name = weight_name + "_bias" + bias_name = f"{weight_name}_bias" inputs.extend( ( gs.Constant(weight_name, values=conv_w), @@ -336,9 +342,11 @@ def __init__(self, priority): @property def name(self): + """Returns the name of the elimination pattern, 'EliminationSlice'.""" return "EliminationSlice" def rewrite(self): + """Rewrites an elimination pattern for slice nodes by optimizing nested slice operations.""" match_case = {} first_slice_node = self.slice_0 first_slice_node_inputs = list(first_slice_node.inputs) @@ -436,9 +444,13 @@ def __init__(self, priority): @property def name(self): + """Returns the name 'EliminationReshape'.""" return "EliminationReshape" def rewrite(self): + """Rewrite the computational graph by eliminating redundant reshape operations when certain conditions are + met. + """ match_case = {} node = self.reshape_1 first_reshape_node = node.i(0) @@ -448,10 +460,13 @@ def rewrite(self): second_reshape_node = node def check_constant_mergeable(reshape_node): + """Check if a reshape node's shape input, containing zero dimensions, can be merged with its input + node's shape. + """ if isinstance(reshape_node.inputs[1], Constant): input_shape = reshape_node.inputs[0].shape reshape_shape = reshape_node.inputs[1].values - if input_shape != None and np.any(reshape_shape == 0): + if input_shape is not None and np.any(reshape_shape == 0): shape = [ input_shape[i] if dim_size == 0 else dim_size for i, dim_size in enumerate(reshape_shape) ] @@ -497,9 +512,13 @@ def __init__(self, priority): @property def name(self): + """Returns the name of the fusion pattern as a string 'FusionGemm'.""" return "FusionGemm" def rewrite(self): + """Rewrites the graph for the fusion pattern 'FusionGemm' based on matching criteria and constant variables in + matmul nodes. + """ match_case = {} node = self.add_0 matmul_node = self.matmul_0 @@ -663,10 +682,11 @@ def __init__(self, priority): @property def name(self): + """Returns the name of the fusion pattern, 'FusionGelu'.""" return "FusionGelu" def rewrite(self): - match_case = {} + """Rewrite the computation graph pattern to fuse GELU operations.""" input_variable = self.div_0.inputs[0] mul_node = self.mul_0 div_node = self.div_0 @@ -677,15 +697,15 @@ def rewrite(self): output_variable = self.mul_1.outputs[0] output_variable.inputs.clear() - match_case[self.mul_1.name] = { - "op": "Gelu", - "inputs": [input_variable], - "outputs": [output_variable], - "domain": None, + return { + self.mul_1.name: { + "op": "Gelu", + "inputs": [input_variable], + "outputs": [output_variable], + "domain": None, + } } - return match_case - # register_fusion_pattern(GeluPatternMatcher(1)) @@ -704,9 +724,11 @@ def __init__(self, priority): @property def name(self): + """Returns the name of the fusion pattern 'FusionReduce'.""" return "FusionReduce" def rewrite(self, opset=11): + """Rewrites the graph pattern based on opset version; reuses Reduce and Unsqueeze nodes if possible.""" match_case = {} node = self.unsqueeze_0 reduce_node = self.reduce_0 @@ -763,7 +785,6 @@ def replace_custom_layer( def find_matches(graph: Graph, fusion_patterns: dict): """Find matching patterns in the graph based on provided fusion patterns.""" - opset = graph.opset match_map = {} counter = Counter() for node in reversed(graph.nodes): diff --git a/onnxslim/core/symbolic_shape_infer.py b/onnxslim/core/symbolic_shape_infer.py index 3b0d4d9..3473460 100644 --- a/onnxslim/core/symbolic_shape_infer.py +++ b/onnxslim/core/symbolic_shape_infer.py @@ -398,7 +398,7 @@ def _broadcast_shapes(self, shape1, shape2): if self.auto_merge_: self._add_suggested_merge([dim1, dim2], apply=True) else: - logger.warning(f"unsupported broadcast between {str(dim1)} " + str(dim2)) + logger.warning(f"unsupported broadcast between {str(dim1)} {str(dim2)}") new_shape = [new_dim, *new_shape] return new_shape @@ -635,7 +635,7 @@ def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=Fal """Extracts integer or float values from a node, with options for broadcasting and allowing float values.""" def int_or_float(value, allow_float_values): - # If casting into int has precision loss: keep float output + """Converts a value to an integer unless precision loss occurs and allow_float_values is True.""" return value if allow_float_values and value % 1 != 0 else int(value) values = [self._try_get_value(node, i) for i in range(len(node.input))] @@ -1900,6 +1900,9 @@ def _infer_Slice(self, node): # noqa: N802 # # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`. def flatten_min(expr): + """Returns a list with expressions split by min() for inequality proof or original expr if no single min() + found. + """ assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}" min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)] if len(min_positions) == 1: @@ -2810,11 +2813,11 @@ def get_prereq(node): # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate sorted_nodes = [] sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)} - if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): + if any(o.name in sorted_known_vi for o in self.out_mp_.graph.output): # Loop/Scan will have some graph output in graph inputs, so don't do topological sort sorted_nodes = self.out_mp_.graph.node else: - while not all([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): + while not all(o.name in sorted_known_vi for o in self.out_mp_.graph.output): old_sorted_nodes_len = len(sorted_nodes) for node in self.out_mp_.graph.node: if (node.output[0] not in sorted_known_vi) and all( @@ -2986,7 +2989,7 @@ def get_prereq(node): # note that the broadcasting rule aligns from right to left # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge dim_idx = [len(s) - len(out_shape) + idx for s in shapes] - if len(dim_idx) > 0: + if dim_idx: self._add_suggested_merge( [ s[i] if is_literal(s[i]) else str(s[i]) diff --git a/onnxslim/misc/tabulate.py b/onnxslim/misc/tabulate.py index f2f7778..68bccd2 100644 --- a/onnxslim/misc/tabulate.py +++ b/onnxslim/misc/tabulate.py @@ -2167,7 +2167,7 @@ def tabulate( ) for idx, align in enumerate(headersalign): hidx = headers_pad + idx - if not hidx < len(aligns_headers): + if hidx >= len(aligns_headers): break elif align == "same" and hidx < len(aligns): # same as column align aligns_headers[hidx] = aligns[hidx] diff --git a/onnxslim/onnx_graphsurgeon/ir/function.py b/onnxslim/onnx_graphsurgeon/ir/function.py index d86b91a..664f303 100644 --- a/onnxslim/onnx_graphsurgeon/ir/function.py +++ b/onnxslim/onnx_graphsurgeon/ir/function.py @@ -250,6 +250,7 @@ def __eq__(self, other: "Function"): """Checks equality of self with another Function object based on their attributes.""" def sequences_equal(seq1, seq2): + """Checks if two sequences are equal in length and elements.""" return len(seq1) == len(seq2) and all(elem1 == elem2 for elem1, elem2 in zip(seq1, seq2)) return ( diff --git a/onnxslim/onnx_graphsurgeon/ir/graph.py b/onnxslim/onnx_graphsurgeon/ir/graph.py index 3b86f7c..0399383 100644 --- a/onnxslim/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/onnx_graphsurgeon/ir/graph.py @@ -556,6 +556,7 @@ def get_inputs(node_or_func): """Find all nodes used by a given node or function.""" def get_used_nodes(node): + """Find all nodes that are used as inputs by a given node.""" inputs = {} def add_local_producers(tensor): @@ -921,6 +922,7 @@ def update_foldable_outputs(graph_constants): """Updates the graph's outputs to ensure certain operations remain foldable.""" def is_foldable(node): + """Determines if a given node operation is foldable based on its type.""" NO_FOLD_OPS = [ "QuantizeLinear", "DequantizeLinear", @@ -1093,7 +1095,9 @@ def partition_and_infer(subgraph): """Evaluates and partitions the subgraph to infer constant values using ONNX-Runtime.""" def get_out_node_ids(): - # Gets the final output nodes - producer nodes of graph output tensors without other outputs. + """Gets the final output nodes, identifying producer nodes of graph output tensors with no other + outputs. + """ with subgraph.node_ids(): out_node_ids = set() for out in subgraph.outputs: @@ -1187,7 +1191,7 @@ def should_eval_foldable(tensor): if onnx_model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: tmp_dir = tempfile.TemporaryDirectory() tmp_path = os.path.join(tmp_dir.name, "tmp.onnx") - location = os.path.basename(tmp_path) + ".data" + location = f"{os.path.basename(tmp_path)}.data" if os.path.exists(location): os.remove(location) onnx.save( diff --git a/onnxslim/onnx_graphsurgeon/logger/logger.py b/onnxslim/onnx_graphsurgeon/logger/logger.py index 567288f..4c55010 100644 --- a/onnxslim/onnx_graphsurgeon/logger/logger.py +++ b/onnxslim/onnx_graphsurgeon/logger/logger.py @@ -159,6 +159,8 @@ def log(self, message, severity, mode=LogMode.EACH, stack_depth=2): """ def process_message(message, stack_depth): + """Generates a log message prefix with file name and line number based on the specified stack depth.""" + def get_prefix(): def get_line_info(): module = inspect.getmodule(sys._getframe(stack_depth + 3)) or inspect.getmodule( diff --git a/onnxslim/utils.py b/onnxslim/utils.py index 892a04a..f159d78 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -145,7 +145,7 @@ def onnxruntime_inference(model: onnx.ModelProto, input_data: dict) -> Dict[str, if model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: tmp_dir = tempfile.TemporaryDirectory() tmp_path = os.path.join(tmp_dir.name, "tmp.onnx") - location = os.path.basename(tmp_path) + ".data" + location = f"{os.path.basename(tmp_path)}.data" if os.path.exists(location): os.remove(location) onnx.save( @@ -484,6 +484,7 @@ def check_result(raw_onnx_output, slimmed_onnx_output): def calculate_tensor_size(tensor): + """Calculates the size of an ONNX tensor in bytes based on its shape and data type size.""" shape = tensor.dims num_elements = np.prod(shape) if shape else 0 element_size = data_type_sizes.get(tensor.data_type, 0) @@ -491,6 +492,7 @@ def calculate_tensor_size(tensor): def get_model_size_and_initializer_size(model): + """Calculates and prints the model size and initializer size for an ONNX model in bytes.""" initializer_size = 0 for tensor in model.graph.initializer: tensor_size = calculate_tensor_size(tensor) @@ -501,6 +503,7 @@ def get_model_size_and_initializer_size(model): def get_model_subgraph_size(model): + """Calculate and print the size of subgraphs in an ONNX model in bytes.""" graph = model.graph for node in graph.node: for attr in node.attribute: diff --git a/tests/test_onnx_nets.py b/tests/test_onnx_nets.py index 106ac81..f45d4d5 100644 --- a/tests/test_onnx_nets.py +++ b/tests/test_onnx_nets.py @@ -28,7 +28,7 @@ def test_torchvision(self, request, model, shape=(1, 3, 224, 224)): """Test various TorchVision models with random input tensors of a specified shape.""" model = model(pretrained=PRETRAINED) x = torch.rand(shape) - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) filename = f"{directory}/{request.node.name}.onnx" @@ -57,7 +57,7 @@ def test_timm(self, request, model_name): model = timm.create_model(model_name, pretrained=PRETRAINED) input_size = model.default_cfg.get("input_size") x = torch.randn((1,) + input_size) - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" try: os.makedirs(directory, exist_ok=True) diff --git a/tests/test_pattern_generator.py b/tests/test_pattern_generator.py index 355390e..9fd61c6 100644 --- a/tests/test_pattern_generator.py +++ b/tests/test_pattern_generator.py @@ -11,23 +11,30 @@ class TestPatternGenerator: def test_gelu(self, request): + """Test the GELU activation function within the PatternModel class.""" + class PatternModel(nn.Module): def __init__(self): super(PatternModel, self).__init__() self.gelu = nn.GELU() def forward(self, x): + """Applies the GELU activation function to the input tensor.""" x = self.gelu(x) return x class Model(nn.Module): def __init__(self): + """Initializes the Model class with ReLU and PatternModel components.""" super(Model, self).__init__() self.relu0 = nn.ReLU() self.pattern = PatternModel() self.relu1 = nn.ReLU() def forward(self, x): + """Applies the ReLU activation function, the PatternModel, and another ReLU activation sequentially to + the input tensor. + """ x = self.relu0(x) x = self.pattern(x) x = self.relu1(x) @@ -36,7 +43,7 @@ def forward(self, x): input = torch.randn(2) p = PatternModel() m = Model() - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) pattern_filename = f"{directory}/{request.node.name}.onnx" @@ -52,13 +59,16 @@ def forward(self, x): class GeluMatcher(PatternMatcher): def __init__(self, pattern, priority): + """Initialize a GeluMatcher with a given pattern and priority.""" super().__init__(pattern, priority) @property def name(self): + """Return the name of the matcher as 'FusionGelu'.""" return "FusionGelu" def rewrite(self): + """Raise an exception indicating a pattern match in GeluMatcher.""" raise Exception("Pattern Matched") register_fusion_pattern(GeluMatcher(pattern, 1)) diff --git a/tests/test_pattern_matcher.py b/tests/test_pattern_matcher.py index c184e49..9e58618 100644 --- a/tests/test_pattern_matcher.py +++ b/tests/test_pattern_matcher.py @@ -9,6 +9,8 @@ class TestPatternMatcher: def test_gelu(self, request): + """Test the GELU activation function in a neural network model using an instance of nn.Module.""" + class Model(nn.Module): def __init__(self): super(Model, self).__init__() @@ -17,6 +19,9 @@ def __init__(self): self.relu1 = nn.ReLU() def forward(self, x): + """Performs a forward pass through the model applying ReLU, GELU, and ReLU activations sequentially to + the input tensor x. + """ x = self.relu0(x) x = self.gelu(x) x = self.relu1(x) @@ -24,7 +29,7 @@ def forward(self, x): input = torch.randn(2) m = Model() - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) filename = f"{directory}/{request.node.name}.onnx" @@ -33,6 +38,8 @@ def forward(self, x): slim(filename, filename) def test_pad_conv(self, request): + """Test padding followed by 2D convolution within a neural network module.""" + class Model(nn.Module): def __init__(self): super(Model, self).__init__() @@ -43,6 +50,7 @@ def __init__(self): self.conv_1 = nn.Conv2d(1, 1, 3, bias=False) def forward(self, x): + """Applies padding and convolutional layers to the input tensor x.""" x0 = self.pad_0(x) x0 = self.conv_0(x0) @@ -53,7 +61,7 @@ def forward(self, x): input = torch.randn(1, 1, 24, 24) m = Model() - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) filename = f"{directory}/{request.node.name}.onnx" @@ -62,6 +70,8 @@ def forward(self, x): slim(filename, filename) def test_conv_bn(self, request): + """Test the convolutional layer followed by batch normalization export and re-import via ONNX.""" + class Model(nn.Module): def __init__(self): super(Model, self).__init__() @@ -69,13 +79,14 @@ def __init__(self): self.bn = nn.BatchNorm2d(1) def forward(self, x): + """Perform convolution followed by batch normalization on input tensor x.""" x = self.conv(x) x = self.bn(x) return x input = torch.randn(1, 1, 24, 24) m = Model() - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) filename = f"{directory}/{request.node.name}.onnx" @@ -84,6 +95,10 @@ def forward(self, x): slim(filename, filename) def test_consecutive_slice(self, request): + """Tests consecutive slicing operations on a model by exporting it to ONNX format and then slimming the ONNX + file. + """ + class Model(nn.Module): def __init__(self): super(Model, self).__init__() @@ -91,11 +106,12 @@ def __init__(self): self.bn = nn.BatchNorm2d(1) def forward(self, x): + """Performs slicing operation on the input tensor x by returning the section x[1:2, :2].""" return x[1:2, :2] input = torch.randn(3, 4) m = Model() - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) filename = f"{directory}/{request.node.name}.onnx" @@ -104,16 +120,19 @@ def forward(self, x): slim(filename, filename) def test_consecutive_reshape(self, request): + """Test the functionality of consecutive reshape operations in a model and export it to ONNX format.""" + class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x): + """Reshape tensor sequentially to (2, 6) and then to (12, 1).""" return x.view(2, 6).view(12, 1) input = torch.randn(3, 4) m = Model() - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) filename = f"{directory}/{request.node.name}.onnx" @@ -122,19 +141,22 @@ def forward(self, x): slim(filename, filename) def test_matmul_add(self, request): + """Tests matrix multiplication followed by an addition operation within a neural network model.""" + class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.data = torch.randn(4, 3) def forward(self, x): + """Performs matrix multiplication of input 'x' with pre-defined data, adds 1, and returns the result.""" x = torch.matmul(x, self.data) x += 1 return x input = torch.randn(3, 4) m = Model() - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) filename = f"{directory}/{request.node.name}.onnx" @@ -143,18 +165,25 @@ def forward(self, x): slim(filename, filename) def test_reduce(self, request): + """Tests model reduction by exporting a PyTorch model to ONNX format, slimming it, and saving to a specified + directory. + """ + class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x): + """Performs a reduction summing over the last dimension of the input tensor and then unsqueezes the + tensor along the same dimension. + """ x = torch.sum(x, dim=[-1], keepdim=False) x = x.unsqueeze(-1) return x input = torch.randn(3, 4) m = Model() - directory = "tmp/" + request.node.name + directory = f"tmp/{request.node.name}" os.makedirs(directory, exist_ok=True) filename = f"{directory}/{request.node.name}.onnx"