diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000..9459732 --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,29 @@ +# Ultralytics 🚀 - AGPL-3.0 license +# Ultralytics Actions https://github.com/ultralytics/actions +# This workflow automatically formats code and documentation in PRs to official Ultralytics standards + +name: Ultralytics Actions + +on: + push: + branches: [main] + pull_request_target: + branches: [main] + types: [opened, closed, synchronize] + +jobs: + format: + runs-on: ubuntu-latest + steps: + - name: Run Ultralytics Formatting + uses: ultralytics/actions@main + with: + token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, do not modify + python: true # format Python code and docstrings + markdown: true # format Markdown + prettier: true # format YAML + spelling: true # check spelling + links: false # check broken links + # summary: true # print PR summary with GPT4 (requires 'openai_api_key' or 'openai_azure_api_key' and 'openai_azure_endpoint') + # openai_azure_api_key: ${{ secrets.OPENAI_AZURE_API_KEY }} + # openai_azure_endpoint: ${{ secrets.OPENAI_AZURE_ENDPOINT }} diff --git a/README.md b/README.md index c60bdb4..4e3053e 100644 --- a/README.md +++ b/README.md @@ -12,21 +12,24 @@ OnnxSlim can help you slim your ONNX model, with fewer operators but the same accuracy and better inference speed. - 🚀 OnnxSlim has been merged into [mnn-llm](https://github.com/wangzhaode/mnn-llm), performance increased by 5% -- 🚀 Rank 1st in the [AICAS 2024 LLM inference optimiztion challenge](https://tianchi.aliyun.com/competition/entrance/532170/customize440) held by Arm and T-head - +- 🚀 Rank 1st in the [AICAS 2024 LLM inference optimization challenge](https://tianchi.aliyun.com/competition/entrance/532170/customize440) held by Arm and T-head # Installation + ## Using Prebuilt + ```bash pip install onnxslim ``` + ## Build From Source + ``` pip install .
``` - # How to use + ``` onnxslim your_onnx_model slimmed_onnx_model ``` @@ -36,12 +39,14 @@ onnxslim your_onnx_model slimmed_onnx_model For more usage, see onnxslim -h or refer to our [examples](./examples) # References -> * [onnx-graphsurgeon](https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon) -> * [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy/polygraphy) -> * [onnx-simplifier](https://github.com/daquexian/onnx-simplifier) -> * [tabulate](https://github.com/astanin/python-tabulate) -> * [onnxruntime](https://github.com/microsoft/onnxruntime) + +> - [onnx-graphsurgeon](https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon) +> - [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy/polygraphy) +> - [onnx-simplifier](https://github.com/daquexian/onnx-simplifier) +> - [tabulate](https://github.com/astanin/python-tabulate) +> - [onnxruntime](https://github.com/microsoft/onnxruntime) # Contact + +Discord: https://discord.gg/nRw2Fd3VUS\ QQ Group: 873569894 diff --git a/docs/conf.py b/docs/conf.py index 55b4d89..53dfcf6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,4 +38,5 @@ def setup(app): + """Configure Sphinx application to include custom CSS from 'style.css'.""" app.add_css_file("style.css") diff --git a/examples/common_subexpression_elimination/README.md b/examples/common_subexpression_elimination/README.md index 0ac6641..ee63ebf 100644 --- a/examples/common_subexpression_elimination/README.md +++ b/examples/common_subexpression_elimination/README.md @@ -1,16 +1,20 @@ # Common SubExpression Elimination ## Introduction + Common Subexpression Elimination (CSE) is a powerful optimization technique commonly employed in compilers to improve the efficiency of code execution. It targets redundant computations within a program by identifying and removing duplicate expressions, thus reducing both computational overhead and memory usage. By eliminating redundant computations, CSE enhances the overall performance of the slimmed ONNX model. ## How CSE Works + In many programs, certain expressions are computed multiple times within a given scope, even though their results remain constant across these computations. Common subexpressions refer to these redundant expressions. CSE identifies such common subexpressions and replaces subsequent occurrences with references to the original computation result. This process effectively reduces the number of computations required during program execution. For example, consider the following code snippet: + ``` int a = b + c; int x = b + c; ``` + In this code, b + c is a common subexpression computed twice. With CSE, the redundant computation of b + c would be eliminated, and x would directly reference the result already computed for a.
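The `cse_demo.py` change later in this diff builds exactly this situation at the ONNX level: two LayerNorm computations over the same tensor export as duplicated subgraphs that OnnxSlim can deduplicate. A minimal end-to-end sketch of exercising it; the export call and the package-level `onnxslim.slim(input_path, output_path)` entry point are assumptions here, since only the CLI form `onnxslim in.onnx out.onnx` is confirmed by this README:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import onnxslim  # assumed to expose slim() mirroring the CLI


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm = nn.LayerNorm(10)

    def forward(self, x):
        # Both terms normalize x over the same trailing dimension, so the
        # exported graph contains two subgraphs doing identical normalization
        # work: a common subexpression at the graph level.
        return self.layer_norm(x) + F.layer_norm(x, [10])


torch.onnx.export(Model(), torch.randn(2, 10), "cse_demo.onnx")
onnxslim.slim("cse_demo.onnx", "cse_demo_slim.onnx")  # duplicate branch folded
```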
## Running the example @@ -31,7 +35,6 @@ After onnxslim, the output will look like this: ![../../image/after_cse.png](../../images/after_cse.png) - and the summary is as follows: -![../../image/cse.png](../../images/cse.png) \ No newline at end of file +![../../image/cse.png](../../images/cse.png) diff --git a/examples/common_subexpression_elimination/cse_demo.py b/examples/common_subexpression_elimination/cse_demo.py index ce8794d..a98d038 100644 --- a/examples/common_subexpression_elimination/cse_demo.py +++ b/examples/common_subexpression_elimination/cse_demo.py @@ -5,11 +5,13 @@ class Model(torch.nn.Module): def __init__(self): + """Initializes the Model class with a single LayerNorm layer of embedding dimension 10.""" super(Model, self).__init__() embedding_dim = 10 self.layer_norm = nn.LayerNorm(embedding_dim) def forward(self, x): + """Applies LayerNorm to the input tensor and adds it to an independently computed LayerNorm of the same tensor.""" return self.layer_norm(x) + F.layer_norm(x, [10]) diff --git a/examples/input_shape_modification/README.md b/examples/input_shape_modification/README.md index eb63d31..c2968c4 100644 --- a/examples/input_shape_modification/README.md +++ b/examples/input_shape_modification/README.md @@ -1,11 +1,13 @@ # Input Shape Modification ## Introduction + OnnxSlim includes an exploration of essential input shape modification techniques for ONNX models. This concise guide unveils techniques for seamlessly adjusting input tensor dimensions, ensuring optimal compatibility and performance within the dynamic landscape of neural network architectures. ## Running the example + Change the input shapes of the model by running: ```bash @@ -14,4 +16,4 @@ onnxslim UNetModel-fp16.onnx slim.onnx --input_shapes cc:1,1,768 The slimmed model will look like this: -![../../image/input_shape_modification.jpg](../../images/input_shape_modification.jpg) \ No newline at end of file +![../../image/input_shape_modification.jpg](../../images/input_shape_modification.jpg) diff --git a/examples/model_inspect/README.md b/examples/model_inspect/README.md index 456d19f..3b07a9e 100644 --- a/examples/model_inspect/README.md +++ b/examples/model_inspect/README.md @@ -1,9 +1,11 @@ # Model Inspect ## Introduction + Dive deep into the intricacies of your ONNX model using the powerful --inspect argument with OnnxSlim. This feature provides detailed insights into various aspects of your model, including input and output details, operator information, opset version, and more. ## Running the example + Unveil the secrets of your ONNX model by executing the following command: ```bash @@ -12,4 +14,4 @@ onnxslim --inspect UNetModel-fp16.onnx The output will look like this: -![../../image/model_inspect.jpg](../../images/model_inspect.jpg) \ No newline at end of file +![../../image/model_inspect.jpg](../../images/model_inspect.jpg) diff --git a/examples/output_modification/README.md b/examples/output_modification/README.md index 792d5c6..cac7f2c 100644 --- a/examples/output_modification/README.md +++ b/examples/output_modification/README.md @@ -1,11 +1,13 @@ # Output Modification ## Introduction + OnnxSlim provides capabilities for modifying the output specifications of ONNX models. This section explores techniques to customize the outputs, allowing for flexibility in handling diverse model requirements.
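Under the hood this is a rewrite of the graph's output list followed by a cleanup pass (see `output_modification` in `onnxslim/core/slim.py` further down this diff). For orientation, a rough stand-alone equivalent using only the `onnx` package; the file and tensor names come from the example command in the next section, and the empty value infos are an assumption, since OnnxSlim itself resolves output dtypes before exporting:

```python
import onnx
from onnx import helper

# Rough analogue of `onnxslim yolov5m.onnx slim.onnx --outputs 591 739 443`.
model = onnx.load("yolov5m.onnx")
del model.graph.output[:]  # replace, rather than extend, the graph outputs
for name in ["591", "739", "443"]:
    # dtype/shape left unset here; ONNX shape inference can fill them in later
    model.graph.output.append(helper.make_empty_tensor_value_info(name))
onnx.save(model, "slim.onnx")
```

Unlike this sketch, OnnxSlim also prunes nodes that no longer feed any output (`graph.cleanup(remove_unused_graph_inputs=True)` in the implementation).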
## Running the example + Change the outputs of a model by running: ```bash @@ -14,4 +16,4 @@ onnxslim yolov5m.onnx slim.onnx --outputs 591 739 443 The slimmed model will look like this: -![../../image/output_modification.jpg](../../images/output_modification.jpg) \ No newline at end of file +![../../image/output_modification.jpg](../../images/output_modification.jpg) diff --git a/onnxslim/__init__.py b/onnxslim/__init__.py index 56dc54c..0520295 100644 --- a/onnxslim/__init__.py +++ b/onnxslim/__init__.py @@ -5,10 +5,7 @@ from .core.optimizer import DEFAULT_FUSION_PATTERNS from .version import __version__ - -if os.path.dirname(os.path.realpath(__file__)) == os.path.join( - os.path.realpath(os.getcwd()), "onnxslim" -): +if os.path.dirname(os.path.realpath(__file__)) == os.path.join(os.path.realpath(os.getcwd()), "onnxslim"): message = ( "You are importing onnxslim within its own root folder ({}). " "This is not expected to work and may give errors. Please exit the " diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py index cc0ba3d..36b9ecb 100644 --- a/onnxslim/cli/_main.py +++ b/onnxslim/cli/_main.py @@ -1,6 +1,7 @@ from typing import Union import onnx + from onnxslim.utils.utils import logger @@ -83,11 +84,7 @@ def slim( init_logging(verbose) - MAX_ITER = ( - 10 - if not os.getenv("ONNXSLIM_MAX_ITER") - else int(os.getenv("ONNXSLIM_MAX_ITER")) - ) + MAX_ITER = 10 if not os.getenv("ONNXSLIM_MAX_ITER") else int(os.getenv("ONNXSLIM_MAX_ITER")) if isinstance(model, str): model_name = Path(model).name @@ -166,6 +163,7 @@ def slim( def main(): + """Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function.""" import argparse import onnxslim @@ -175,14 +173,10 @@ def main(): formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("input_model", help="input onnx model") - parser.add_argument( - "output_model", nargs="?", default=None, help="output onnx model" - ) + parser.add_argument("output_model", nargs="?", default=None, help="output onnx model") parser.add_argument("--model_check", action="store_true", help="enable model check") - parser.add_argument( - "-v", "--version", action="version", version=onnxslim.__version__ - ) + parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__) # Input Shape Modification parser.add_argument( @@ -259,9 +253,7 @@ def main(): ) # Verbose - parser.add_argument( - "--verbose", action="store_true", help="verbose mode, default False."
- ) + parser.add_argument("--verbose", action="store_true", help="verbose mode, default False.") args, unknown = parser.parse_known_args() diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py index f57f1ba..0cd6806 100644 --- a/onnxslim/core/optimizer.py +++ b/onnxslim/core/optimizer.py @@ -1,23 +1,21 @@ import contextlib from collections import Counter, OrderedDict - from typing import List, Union import numpy as np - import onnx -from onnxslim.utils.utils import logger import onnxslim.onnx_graphsurgeon as gs from onnxslim.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx from onnxslim.onnx_graphsurgeon.ir.graph import Graph from onnxslim.onnx_graphsurgeon.ir.tensor import Constant, Variable - +from onnxslim.utils.utils import logger DEFAULT_FUSION_PATTERNS = OrderedDict() def register_fusion_pattern(layer_type): + """Registers a fusion pattern function for a specific layer type in the DEFAULT_FUSION_PATTERNS dictionary.""" def insert(fn): if layer_type in DEFAULT_FUSION_PATTERNS.keys(): raise @@ -28,6 +26,7 @@ def insert(fn): def get_fusion_patterns(skip_fusion_patterns: str = None): + """Returns a dictionary of default fusion patterns, optionally excluding specific patterns.""" default_fusion_patterns = DEFAULT_FUSION_PATTERNS.copy() if skip_fusion_patterns: for pattern in skip_fusion_patterns: @@ -37,6 +36,7 @@ def get_fusion_patterns(skip_fusion_patterns: str = None): def get_node_users(node): + """Retrieve the list of users for a given node based on its outputs.""" users = [] for output in node.outputs: # output is a Variable for user in output.outputs: # user is a Node @@ -45,6 +45,7 @@ def get_node_users(node): def get_node_feeds(node): + """Retrieve the list of feed nodes for a given node based on its inputs.""" feeds = [] for input in node.inputs: # input is a Variable for feed in input.inputs: # feed is a Node @@ -53,6 +54,7 @@ def get_node_feeds(node): def get_previous_node_by_type(node, op_type, trajectory=[]): + """Retrieve the previous node of a specific type in the computational graph starting from the given node.""" node_feeds = get_node_feeds(node) for node_feed in node_feeds: if node_feed.op == op_type: @@ -64,12 +66,14 @@ def get_previous_node_by_type(node, op_type, trajectory=[]): def get_constant_variable(node, return_idx=False): + """Return the first constant variable found in a node's inputs, optionally returning the index.""" for idx, input in enumerate(list(node.inputs)): if isinstance(input, Constant): return input if not return_idx else (idx, input) def delete_node(node, input_var_idx=0, output_var_idx=0): + """Delete a node from the computation graph while redirecting its inputs to its outputs to maintain graph integrity.""" input_variable = node.inputs[input_var_idx] node_variable = node.outputs[output_var_idx] next_nodes = get_node_users(node) @@ -86,6 +90,7 @@ def delete_node(node, input_var_idx=0, output_var_idx=0): def check_shape(shapes): + """Verify that 'shapes' contains exactly one string and all other elements are positive integers.""" string_count = 0 non_negative_int_count = 0 @@ -99,6 +104,7 @@ def check_shape(shapes): def graph_constant_fold_inplace(graph): + """Perform in-place constant folding optimizations on the provided computational graph by eliminating redundant nodes.""" for node in graph.nodes: if node.op == "Identity" or node.op == "Dropout": delete_node(node) @@ -106,9 +112,7 @@ def graph_constant_fold_inplace(graph): elif node.op == "Pad": if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant): 
pad_value = node.inputs[1].values.tolist() - pad_value = ( - [pad_value] if not isinstance(pad_value, list) else pad_value - ) + pad_value = [pad_value] if not isinstance(pad_value, list) else pad_value if all([value == 0 for value in pad_value]): delete_node(node) elif node.op == "Cast": @@ -123,10 +127,7 @@ def graph_constant_fold_inplace(graph): else: node_output_shape = node.outputs[0].shape if node_output_shape and check_shape(node_output_shape): - shapes = [ - shape if isinstance(shape, int) else -1 - for shape in node_output_shape - ] + shapes = [shape if isinstance(shape, int) else -1 for shape in node_output_shape] reshape_const = gs.Constant( node.inputs[1].name + "_", values=np.array(shapes, dtype=np.int64), @@ -134,24 +135,16 @@ def graph_constant_fold_inplace(graph): node.inputs.pop(1) node.inputs.insert(1, reshape_const) elif node.op == "Mul": - if ( - isinstance(node.inputs[1], Constant) - and isinstance(node.inputs[0], Variable) - ) or ( - isinstance(node.inputs[0], Constant) - and isinstance(node.inputs[1], Variable) + if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or ( + isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable) ): idx, constant_variable = get_constant_variable(node, return_idx=True) if np.all(constant_variable.values == 1): var_idx = 0 if idx == 1 else 1 delete_node(node, var_idx) elif node.op == "Add": - if ( - isinstance(node.inputs[1], Constant) - and isinstance(node.inputs[0], Variable) - ) or ( - isinstance(node.inputs[0], Constant) - and isinstance(node.inputs[1], Variable) + if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or ( + isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable) ): idx, constant_variable = get_constant_variable(node, return_idx=True) if np.all(constant_variable.values == 0): @@ -166,7 +159,7 @@ def graph_constant_fold_inplace(graph): @register_fusion_pattern("FusionPadConv") def find_conv_nodes(node, opset): - # fmt: off + """Identify and match convolution nodes following a padding operation to update padding attributes for fusion purposes.""" ''' x | @@ -201,10 +194,7 @@ def find_conv_nodes(node, opset): len_conv_pads = int(len(conv_pads) / 2) len_pads = int(len(pad_value) / 2) - pads = ( - pad_value[len_pads - len_conv_pads : len_pads] - + pad_value[len_pads + len_conv_pads :] - ) + pads = pad_value[len_pads - len_conv_pads : len_pads] + pad_value[len_pads + len_conv_pads :] pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)] attrs["pads"] = pads @@ -227,13 +217,7 @@ def find_conv_nodes(node, opset): @register_fusion_pattern("FusionConvBN") def find_conv_transpose_nodes(node, opset): # fmt: off - ''' - x - | - Conv/ConvTranspose - | - BatchNormalization - ''' + """X | Conv/ConvTranspose | BatchNormalization.""" # fmt: on match = {} if node.op == "BatchNormalization": @@ -260,12 +244,8 @@ def find_conv_transpose_nodes(node, opset): shape[0] = -1 else: shape[1] = -1 - conv_w = conv_transpose_weight * (bn_scale * bn_var_rsqrt).reshape( - shape - ) - conv_b = ( - conv_transpose_bias - bn_running_mean - ) * bn_var_rsqrt * bn_scale + bn_bias + conv_w = conv_transpose_weight * (bn_scale * bn_var_rsqrt).reshape(shape) + conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt * bn_scale + bn_bias inputs = [] inputs.append(list(conv_transpose_node.inputs)[0]) @@ -300,7 +280,7 @@ def find_conv_transpose_nodes(node, opset): @register_fusion_pattern("EliminationSlice") def find_slice_nodes(node, opset): - 
# fmt: off + """Identify and combine consecutive 'Slice' nodes in a computational graph for optimization.""" ''' x | @@ -314,19 +294,11 @@ def find_slice_nodes(node, opset): if node.i(0).op == "Slice": first_slice_node = node.i(0) first_slice_node_inputs = list(first_slice_node.inputs) - if all( - [isinstance(input, Constant) for input in first_slice_node_inputs[1:]] - ): + if all([isinstance(input, Constant) for input in first_slice_node_inputs[1:]]): first_slice_node_users = get_node_users(first_slice_node) if all( [ - user.op == "Slice" - and all( - [ - isinstance(input, Constant) - for input in list(user.inputs)[1:] - ] - ) + user.op == "Slice" and all([isinstance(input, Constant) for input in list(user.inputs)[1:]]) for user in first_slice_node_users ] ): @@ -338,18 +310,10 @@ def find_slice_nodes(node, opset): for user_node in first_slice_node_users: second_slice_node = user_node second_slice_node_inputs = list(second_slice_node.inputs) - second_slice_node_starts = second_slice_node_inputs[ - 1 - ].values.tolist() - second_slice_node_ends = second_slice_node_inputs[ - 2 - ].values.tolist() - second_slice_node_axes = second_slice_node_inputs[ - 3 - ].values.tolist() - second_slice_node_steps = second_slice_node_inputs[ - 4 - ].values.tolist() + second_slice_node_starts = second_slice_node_inputs[1].values.tolist() + second_slice_node_ends = second_slice_node_inputs[2].values.tolist() + second_slice_node_axes = second_slice_node_inputs[3].values.tolist() + second_slice_node_steps = second_slice_node_inputs[4].values.tolist() new_starts = first_slice_node_starts + second_slice_node_starts new_ends = first_slice_node_ends + second_slice_node_ends @@ -423,7 +387,7 @@ def find_slice_nodes(node, opset): @register_fusion_pattern("EliminationReshape") def find_reshape_nodes(node, opset): - # fmt: off + """Identify consecutive 'Reshape' nodes in the computational graph and validate their mergeability based on input and output shapes.""" ''' x | @@ -454,9 +418,7 @@ def check_constant_mergeable(reshape_node): return False return True - if check_constant_mergeable( - first_reshape_node - ) and check_constant_mergeable(second_reshape_node): + if check_constant_mergeable(first_reshape_node) and check_constant_mergeable(second_reshape_node): inputs = [] inputs.append(first_reshape_node_inputs[0]) inputs.append(second_reshape_node.inputs[1]) @@ -483,7 +445,7 @@ def check_constant_mergeable(reshape_node): # @register_fusion_pattern("EliminationTranspose") def find_slice_nodes(node, opset): - # fmt: off + """Identifies and processes consecutive Transpose nodes, removing redundant ones to optimize the graph.""" ''' x | @@ -509,9 +471,7 @@ def find_slice_nodes(node, opset): ) last_node.inputs.pop(3) last_node.inputs.insert(3, slice_axis) - previous_transpose_node_variable = previous_transpose_node.outputs[ - 0 - ] # pad output variable + previous_transpose_node_variable = previous_transpose_node.outputs[0] # pad output variable previous_transpose_node_variable.outputs.remove(last_node) last_node.inputs.insert(0, previous_transpose_node.inputs[0]) for node in previous_nodes: @@ -525,7 +485,7 @@ def find_slice_nodes(node, opset): @register_fusion_pattern("FusionGemm") def find_matmul_add_nodes(node, opset): - # fmt: off + """Identifies and returns a pattern match for MatMul followed by Add operations for optimization in a computational graph.""" ''' x | @@ -539,14 +499,10 @@ def find_matmul_add_nodes(node, opset): if (isinstance(node.inputs[1], Constant) and node.i(0).op == "MatMul") or ( 
isinstance(node.inputs[0], Constant) and node.i(1).op == "MatMul" ): - matmul_node = ( - node.i(0) if isinstance(node.inputs[1], Constant) else node.i(1) - ) + matmul_node = node.i(0) if isinstance(node.inputs[1], Constant) else node.i(1) matmul_bias_variable = get_constant_variable(matmul_node) input_variable = ( - matmul_node.inputs[0] - if isinstance(matmul_node.inputs[1], Constant) - else matmul_node.inputs[1] + matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], Constant) else matmul_node.inputs[1] ) users = get_node_users(matmul_node) if len(users) == 1 and matmul_bias_variable: @@ -557,9 +513,7 @@ def find_matmul_add_nodes(node, opset): ): pre_reshape_const = gs.Constant( matmul_node.name + "_pre_reshape_in", - values=np.array( - [-1, matmul_bias_variable.values.shape[0]], dtype=np.int64 - ), + values=np.array([-1, matmul_bias_variable.values.shape[0]], dtype=np.int64), ) inputs = [] inputs.append(input_variable) @@ -573,8 +527,7 @@ def find_matmul_add_nodes(node, opset): match.update( { - matmul_node.name - + "_pre_reshape": { + matmul_node.name + "_pre_reshape": { "op": "Reshape", "inputs": inputs, "outputs": outputs, @@ -599,9 +552,7 @@ def find_matmul_add_nodes(node, opset): inputs.append(matmul_bias_transpose_constant) inputs.append(add_bias_variable) - gemm_out_variable = gs.Variable( - matmul_node.name + "_gemm_out", dtype=output_variable.dtype - ) + gemm_out_variable = gs.Variable(matmul_node.name + "_gemm_out", dtype=output_variable.dtype) outputs = [gemm_out_variable] match.update( @@ -622,9 +573,7 @@ def find_matmul_add_nodes(node, opset): } ) - values = input_variable.shape[:-1] + [ - matmul_bias_variable.values.shape[-1] - ] + values = input_variable.shape[:-1] + [matmul_bias_variable.values.shape[-1]] post_reshape_const = gs.Constant( matmul_node.name + "_post_reshape_in", values=np.array(values, dtype=np.int64), @@ -641,8 +590,7 @@ def find_matmul_add_nodes(node, opset): match.update( { - matmul_node.name - + "_post_reshape": { + matmul_node.name + "_post_reshape": { "op": "Reshape", "inputs": inputs, "outputs": outputs, @@ -696,7 +644,7 @@ def find_matmul_add_nodes(node, opset): # @register_fusion_pattern("FusionGelu") def find_gelu_nodes(node, opset): - # fmt: off + """Identifies GELU (Gaussian Error Linear Unit) activation pattern nodes in a computational graph.""" ''' x / \ @@ -744,7 +692,7 @@ def find_gelu_nodes(node, opset): @register_fusion_pattern("FusionReduce") def find_slice_nodes(node, opset): - # fmt: off + """Find and return a dictionary of matching 'ReduceSum' followed by 'Unsqueeze' nodes that match specific conditions in the graph.""" ''' x | @@ -815,6 +763,7 @@ def replace_custom_layer( def find_matches(graph: Graph, fusion_patterns: dict): + """Find matching patterns in the graph based on provided fusion patterns.""" opset = graph.opset match_map = {} counter = Counter() @@ -829,13 +778,7 @@ def find_matches(graph: Graph, fusion_patterns: dict): if "op" not in match: match.update({"op": layer_type}) if "name" not in match: - match.update( - { - "name": "{}_{}".format( - layer_type.lower(), counter[layer_type] - ) - } - ) + match.update({"name": "{}_{}".format(layer_type.lower(), counter[layer_type])}) counter.update([layer_type]) match_map.update(matches) @@ -843,6 +786,7 @@ def find_matches(graph: Graph, fusion_patterns: dict): def find_and_remove_replaceable_nodes(nodes): + """Find and remove duplicate or replaceable nodes in a given list of computational graph nodes.""" def get_node_key(node): input_names = [] for input_node in 
node.inputs: @@ -880,22 +824,17 @@ def replace_node_references(existing_node, to_be_removed_node): if keep_nodes[i]: for j in range(i + 1, len(bucketed_nodes)): if keep_nodes[j]: - logger.debug( - f"node.op {bucketed_nodes[0].op} idx i: {i}, idx j: {j}" - ) + logger.debug(f"node.op {bucketed_nodes[0].op} idx i: {i}, idx j: {j}") if can_be_replaced(node, bucketed_nodes[j]): keep_nodes[j] = False existing_node = node to_be_removed_node = bucketed_nodes[j] - replace_node_references( - existing_node, to_be_removed_node - ) - logger.debug( - f"Node {to_be_removed_node.name} can be replaced by {existing_node.name}" - ) + replace_node_references(existing_node, to_be_removed_node) + logger.debug(f"Node {to_be_removed_node.name} can be replaced by {existing_node.name}") def sequences_equal(seq1, seq2): + """Check if two sequences are equal by comparing their lengths and elements.""" length_match = len(seq1) == len(seq2) if not length_match: return False @@ -908,6 +847,7 @@ def sequences_equal(seq1, seq2): def can_be_replaced(node, other_node): + """Check if two nodes can be replaced based on their operations, attributes, and inputs.""" attrs_match = node.op == other_node.op and node.attrs == other_node.attrs inputs_match = sequences_equal(node.inputs, other_node.inputs) @@ -915,6 +855,7 @@ def can_be_replaced(node, other_node): def subexpression_elimination(graph): + """Perform subexpression elimination on a computational graph to optimize node operations.""" nodes_by_op = {} for node in graph.nodes: @@ -927,9 +868,7 @@ def subexpression_elimination(graph): find_and_remove_replaceable_nodes(nodes) -def optimize_model( - model: Union[onnx.ModelProto, gs.Graph], skip_fusion_patterns: str = None -) -> onnx.ModelProto: +def optimize_model(model: Union[onnx.ModelProto, gs.Graph], skip_fusion_patterns: str = None) -> onnx.ModelProto: if isinstance(model, gs.Graph): graph = model else: diff --git a/onnxslim/core/slim.py b/onnxslim/core/slim.py index 3b0ef96..a691612 100644 --- a/onnxslim/core/slim.py +++ b/onnxslim/core/slim.py @@ -1,5 +1,4 @@ import logging - import os import sys import tempfile @@ -16,36 +15,35 @@ from ..utils.utils import ( dump_model_info_to_disk, gen_onnxruntime_input_data, + logger, onnxruntime_inference, print_model_info_as_table, - logger ) - from .optimizer import delete_node, optimize_model from .symbolic_shape_infer import SymbolicShapeInference DEBUG = bool(os.getenv("ONNXSLIM_DEBUG")) -AUTO_MERGE = ( - True - if os.getenv("ONNXSLIM_AUTO_MERGE") is None - else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE"))) -) +AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE"))) def init_logging(verbose=False): - # Remove all handlers associated with the root logger object. 
+ """Configure the logging settings for the application, setting verbosity based on the 'verbose' parameter.""" for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) if verbose: # DEBUG - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler(sys.stderr)]) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stderr)], + ) G_LOGGER.severity = logging.DEBUG else: # ERROR - logging.basicConfig(level=logging.ERROR, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler(sys.stderr)]) + logging.basicConfig( + level=logging.ERROR, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stderr)], + ) G_LOGGER.severity = logging.ERROR G_LOGGER.colors = False @@ -77,9 +75,8 @@ def summarize_model(model: onnx.ModelProto) -> Dict: op_type_counts = {} def get_tensor_dtype_shape(tensor): - type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get( - tensor.type.tensor_type.elem_type, "Unknown" - ) + """Extract the data type and shape of an ONNX tensor.""" + type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown") shape = None if tensor.type.tensor_type.HasField("shape"): shape = [] @@ -104,9 +101,7 @@ def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]: return op_shape_info - value_info_dict = { - value_info.name: value_info for value_info in model.graph.value_info - } + value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info} for node in model.graph.node: op_type = node.op_type @@ -136,6 +131,7 @@ def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]: def model_save_as_external_data(model: onnx.ModelProto, model_path: str): + """Save an ONNX model with tensor data as an external file.""" location = os.path.basename(model_path) + ".data" if os.path.exists(location): os.remove(location) @@ -148,9 +144,7 @@ def model_save_as_external_data(model: onnx.ModelProto, model_path: str): ) -def input_shape_modification( - model: onnx.ModelProto, input_shapes: str -) -> onnx.ModelProto: +def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto: if not input_shapes: return @@ -162,9 +156,7 @@ def input_shape_modification( key, values = input_shape.rsplit(":", 1) values_list = [int(value) for value in values.split(",")] if key not in input_names: - raise Exception( - f"Input name {key} not found in model, available keys: {' '.join(input_names)}" - ) + raise Exception(f"Input name {key} not found in model, available keys: {' '.join(input_names)}") tensors[key].shape = values_list for _, tensor in tensors.items(): @@ -187,15 +179,11 @@ def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto if len(values) == 1: key = values[0] if key not in tensors.keys(): - raise Exception( - f"Output name {key} not found in model, available keys: {' '.join(tensors.keys())}" - ) + raise Exception(f"Output name {key} not found in model, available keys: {' '.join(tensors.keys())}") dtype = tensors[key].dtype if dtype == None: dtype = np.float32 - logger.warning( - f"Output layer {key} has no dtype, set to default {dtype}" - ) + logger.warning(f"Output layer {key} has no dtype, set to default {dtype}") else: key, dtype = values if dtype == "fp16": @@ -207,13 +195,9 @@ def output_modification(model: 
onnx.ModelProto, outputs: str) -> onnx.ModelProto elif dtype == "bool": dtype = bool else: - raise Exception( - f"Output layer {key} assigned unsupported dtype {dtype}" - ) + raise Exception(f"Output layer {key} assigned unsupported dtype {dtype}") - graph.outputs.append( - tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape) - ) + graph.outputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape)) graph.cleanup(remove_unused_graph_inputs=True).toposort() model = gs.export_onnx(graph) @@ -222,6 +206,7 @@ def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto def check_onnx(model: onnx.ModelProto, model_check_inputs=None): + """Validates an ONNX model by generating input data and performing inference to check outputs.""" input_data_dict = gen_onnxruntime_input_data(model, model_check_inputs) raw_onnx_output = onnxruntime_inference(model, input_data_dict) @@ -229,6 +214,7 @@ def check_onnx(model: onnx.ModelProto, model_check_inputs=None): def shape_infer(model: onnx.ModelProto): + """Infer tensor shapes in an ONNX model using onnxruntime or ONNX shape inference methods.""" logger.debug("Start shape inference.") try: logger.debug("try onnxruntime shape infer.") @@ -251,6 +237,7 @@ def shape_infer(model: onnx.ModelProto): def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None): + """Optimize the given ONNX model by converting to GraphSurgeon format, performing constant folding, and model optimizations.""" logger.debug("Start converting model to gs.") graph = gs.import_onnx(model).toposort() logger.debug("Finish converting model to gs.") @@ -267,6 +254,7 @@ def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None): def check_point(model: onnx.ModelProto): + """Imports an ONNX model into a graph from the specified model checkpoint.""" graph_check_point = gs.import_onnx(model) return graph_check_point @@ -310,6 +298,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto: def save(model: onnx.ModelProto, model_path: str, model_check: bool = False): + """Save an ONNX model to a specified path, with optional model checking for validity.""" if model_check: try: checker.check_model(model) @@ -338,12 +327,11 @@ def save(model: onnx.ModelProto, model_path: str, model_check: bool = False): def check_result(raw_onnx_output, slimmed_onnx_output): + """Verify the consistency of outputs between the raw and slimmed ONNX models, logging warnings if discrepancies are detected.""" if set(raw_onnx_output.keys()) != set(slimmed_onnx_output.keys()): logger.warning("Model output mismatch after slimming.") logger.warning("Raw model output keys: {}".format(raw_onnx_output.keys())) - logger.warning( - "Slimmed model output keys: {}".format(slimmed_onnx_output.keys()) - ) + logger.warning("Slimmed model output keys: {}".format(slimmed_onnx_output.keys())) logger.warning("Please check the model carefully.") return else: @@ -361,6 +349,7 @@ def check_result(raw_onnx_output, slimmed_onnx_output): def freeze(model: onnx.ModelProto): + """Freezes the ONNX model by removing inputs that are also present in the initializers.""" inputs = model.graph.input name_to_input = {} for input in inputs: diff --git a/onnxslim/core/symbolic_shape_infer.py b/onnxslim/core/symbolic_shape_infer.py index 1cf79ff..3573471 100644 --- a/onnxslim/core/symbolic_shape_infer.py +++ b/onnxslim/core/symbolic_shape_infer.py @@ -17,6 +17,7 @@ def get_attribute(node, attr_name, default_value=None): + """Retrieve the value of an attribute from an 
ONNX node, returning a default if the attribute is not found.""" found = [attr for attr in node.attribute if attr.name == attr_name] if found: return helper.get_attribute_value(found[0]) @@ -24,20 +25,19 @@ def get_attribute(node, attr_name, default_value=None): def get_dim_from_proto(dim): - return ( - getattr(dim, dim.WhichOneof("value")) - if type(dim.WhichOneof("value")) is str - else None - ) # noqa: E721 + """Retrieve the dimension value from the ONNX protobuf object 'dim'.""" + return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None # noqa: E721 def is_sequence(type_proto): + """Determine if the given ONNX 'type_proto' represents a sequence type.""" cls_type = type_proto.WhichOneof("value") assert cls_type in ["tensor_type", "sequence_type"] return cls_type == "sequence_type" def get_shape_from_type_proto(type_proto): + """Extract the shape from an ONNX 'type_proto' if it represents a tensor type.""" assert not is_sequence(type_proto) if type_proto.tensor_type.HasField("shape"): return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] @@ -46,6 +46,7 @@ def get_shape_from_type_proto(type_proto): def get_elem_type_from_type_proto(type_proto): + """Return the tensor element type from the provided type_proto, handling both sequence and non-sequence cases.""" if is_sequence(type_proto): return type_proto.sequence_type.elem_type.tensor_type.elem_type else: @@ -53,6 +54,7 @@ def get_elem_type_from_type_proto(type_proto): def get_shape_from_value_info(vi): + """Returns the shape of the tensor from the provided value information.""" cls_type = vi.type.WhichOneof("value") if cls_type is None: return None @@ -66,30 +68,30 @@ def get_shape_from_value_info(vi): def make_named_value_info(name): + """Create and return an ONNX ValueInfoProto object with the specified name.""" vi = onnx.ValueInfoProto() vi.name = name return vi def get_shape_from_sympy_shape(sympy_shape): - return [ - None if i is None else (int(i) if is_literal(i) else str(i)) - for i in sympy_shape - ] + """Convert a sympy shape to a list with int, str, or None elements.""" + return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape] def is_literal(dim): - return type(dim) in [int, np.int64, np.int32, sympy.Integer] or ( - hasattr(dim, "is_number") and dim.is_number - ) + """Check if a dimension is a literal number (int, np.int64, np.int32, sympy.Integer) or has an 'is_number' attribute.""" + return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number) def handle_negative_axis(axis, rank): + """Convert a potentially negative axis to a positive axis based on the given rank.""" assert axis < rank and axis >= -rank return axis if axis >= 0 else rank + axis def get_opset(mp, domain=None): + """Retrieve the opset version for a given model namespace, defaulting to common ONNX domains if no specific domain is provided.""" domain = domain or ["", "onnx", "ai.onnx"] if type(domain) != list: # noqa: E721 domain = [domain] @@ -101,6 +103,7 @@ def get_opset(mp, domain=None): def as_scalar(x): + """Convert input to scalar if input is a list with a single item or a NumPy ndarray.""" if type(x) == list: # noqa: E721 assert len(x) == 1 return x[0] @@ -111,6 +114,7 @@ def as_scalar(x): def as_list(x, keep_none): + """Convert input to list, optionally preserving None values.""" if type(x) == list: # noqa: E721 return x elif type(x) == np.ndarray: @@ -122,6 +126,7 @@ def as_list(x, keep_none): def 
sympy_reduce_product(x): + """Reduce a list of sympy expressions to a single product or return the input if not a list.""" if type(x) == list: # noqa: E721 value = sympy.Integer(1) for v in x: @@ -133,6 +138,7 @@ def sympy_reduce_product(x): class SymbolicShapeInference: def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): + """Initializes the SymbolicShapeInference class with configuration parameters for symbolic shape inference.""" self.dispatcher_ = { "Add": self._infer_symbolic_compute_ops, "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor, @@ -263,12 +269,8 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): self.prefix_ = prefix def _add_suggested_merge(self, symbols, apply=False): - assert all( - [ - (type(s) == str and s in self.symbolic_dims_) or is_literal(s) - for s in symbols - ] - ) # noqa: E721 + """Add suggested merges for input symbols, prioritizing literals, input symbolic dims, or existing symbolic dims.""" + assert all([(type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols]) # noqa: E721 symbols = set(symbols) for k, v in self.suggested_merge_.items(): if k in symbols: @@ -294,11 +296,7 @@ def _add_suggested_merge(self, symbols, apply=False): # when nothing to map to, use the shorter one if map_to is None: if self.verbose_ > 0: - logger.warning( - "Potential unsafe merge between symbolic expressions: ({})".format( - ",".join(symbols) - ) - ) + logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols))) symbols_list = list(symbols) lens = [len(s) for s in symbols_list] map_to = symbols_list[lens.index(min(lens))] @@ -317,11 +315,10 @@ def _add_suggested_merge(self, symbols, apply=False): self._apply_suggested_merge() def _apply_suggested_merge(self, graph_input_only=False): + """Applies suggested merges to graph dimensions based on predefined rules, optionally affecting graph input only.""" if not self.suggested_merge_: return - for i in list(self.out_mp_.graph.input) + ( - [] if graph_input_only else list(self.out_mp_.graph.value_info) - ): + for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)): for d in i.type.tensor_type.shape.dim: if d.dim_param in self.suggested_merge_: v = self.suggested_merge_[d.dim_param] @@ -331,6 +328,7 @@ def _apply_suggested_merge(self, graph_input_only=False): d.dim_param = v def _preprocess(self, in_mp): + """Preprocess ONNX model by copying its structure and updating graph input and initializer dictionaries.""" self.out_mp_ = onnx.ModelProto() self.out_mp_.CopyFrom(in_mp) self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)} @@ -344,13 +342,12 @@ def _preprocess(self, in_mp): ) def _merge_symbols(self, dims): + """Merge dimension symbols, handling automatic merging and validation of symbolic dimensions.""" if not all([type(d) == str for d in dims]): # noqa: E721 if self.auto_merge_: unique_dims = list(set(dims)) is_int = [is_literal(d) for d in unique_dims] - assert ( - sum(is_int) <= 1 - ) # if there are more than 1 unique ints, something is wrong + assert sum(is_int) <= 1 # if there are more than 1 unique ints, something is wrong if sum(is_int) == 1: int_dim = is_int.index(1) if self.verbose_ > 0: @@ -364,17 +361,13 @@ def _merge_symbols(self, dims): return unique_dims[int_dim] else: if self.verbose_ > 0: - logger.debug( - f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}" - ) + logger.debug(f"dim {unique_dims[1:]} 
has been merged with dim {unique_dims[0]}") return dims[0] else: return None if all([d == dims[0] for d in dims]): return dims[0] - merged = [ - self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims - ] + merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims] if all([d == merged[0] for d in merged]): assert merged[0] in self.symbolic_dims_ return merged[0] @@ -383,6 +376,7 @@ def _merge_symbols(self, dims): # broadcast from right to left, and merge symbolic dims if needed def _broadcast_shapes(self, shape1, shape2): + """Broadcasts two shapes from right to left, merging symbolic dimensions if necessary.""" new_shape = [] rank1 = len(shape1) rank2 = len(shape2) @@ -403,16 +397,12 @@ def _broadcast_shapes(self, shape1, shape2): if self.auto_merge_: self._add_suggested_merge([dim1, dim2], apply=True) else: - logger.warning( - "unsupported broadcast between " - + str(dim1) - + " " - + str(dim2) - ) + logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2)) new_shape = [new_dim, *new_shape] return new_shape def _get_shape(self, node, idx): + """Retrieve the shape of a tensor from a node's inputs based on known value info or initializers.""" name = node.input[idx] if name in self.known_vi_: vi = self.known_vi_[name] @@ -422,6 +412,7 @@ def _get_shape(self, node, idx): return list(self.initializers_[name].dims) def _try_get_shape(self, node, idx): + """Attempts to retrieve the shape of the input node at the specified index using known value info or initializers.""" if idx > len(node.input) - 1: return None name = node.input[idx] @@ -433,9 +424,11 @@ def _try_get_shape(self, node, idx): return None def _get_shape_rank(self, node, idx): + """Return the rank (number of dimensions) of the shape of the input tensor at the specified index for a given node.""" return len(self._get_shape(node, idx)) def _get_sympy_shape(self, node, idx): + """Return the symbolic shape dimensions using SymPy for the given input tensor at the specified index for a node.""" sympy_shape = [] for d in self._get_shape(node, idx): if type(d) == str: # noqa: E721 @@ -450,15 +443,13 @@ def _get_sympy_shape(self, node, idx): return sympy_shape def _get_value(self, node, idx): + """Retrieve the value associated with a node's input at a given index from sympy_data_ or initializers_.""" name = node.input[idx] assert name in self.sympy_data_ or name in self.initializers_ - return ( - self.sympy_data_[name] - if name in self.sympy_data_ - else numpy_helper.to_array(self.initializers_[name]) - ) + return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name]) def _try_get_value(self, node, idx): + """Attempt to retrieve a node's input value at a specified index or return None if the index is out of range.""" if idx >= len(node.input): return None name = node.input[idx] @@ -467,22 +458,21 @@ def _try_get_value(self, node, idx): return None def _update_computed_dims(self, new_sympy_shape): + """Update computed dimensions by replacing symbolic dimensions with suggested merged dimensions or adding new ones.""" for i, new_dim in enumerate(new_sympy_shape): if not is_literal(new_dim) and type(new_dim) != str: # noqa: E721 str_dim = str(new_dim) if str_dim in self.suggested_merge_: if is_literal(self.suggested_merge_[str_dim]): continue # no need to create dim for literals - new_sympy_shape[i] = self.symbolic_dims_[ - self.suggested_merge_[str_dim] - ] + new_sympy_shape[i] = 
self.symbolic_dims_[self.suggested_merge_[str_dim]] else: # add new_dim if it's a computational expression if str(new_dim) not in self.symbolic_dims_: self.symbolic_dims_[str(new_dim)] = new_dim def _onnx_infer_single_node(self, node): - # skip onnx shape inference for some ops, as they are handled in _infer_* + """Performs ONNX shape inference for a single node, skipping inference for specified operation types.""" skip_infer = node.op_type in [ "If", "Loop", @@ -545,23 +535,11 @@ def _onnx_infer_single_node(self, node): if node.output[0] in self.known_vi_: vi = self.known_vi_[node.output[0]] out_rank = len(get_shape_from_type_proto(vi.type)) - in_shapes = [ - self._get_shape(node, i) for i in range(len(node.input)) - ] + in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] for d in range( - out_rank - - ( - 2 - if node.op_type - in ["MatMul", "MatMulInteger", "MatMulInteger16"] - else 0 - ) + out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0) ): - in_dims = [ - s[len(s) - out_rank + d] - for s in in_shapes - if len(s) + d >= out_rank - ] + in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] if len(in_dims) > 1: self._check_merged_dims(in_dims, allow_broadcast=True) @@ -587,36 +565,23 @@ def _onnx_infer_single_node(self, node): vi.name = o self.known_vi_[o] = vi - def _onnx_infer_subgraph( - self, node, subgraph, use_node_input=True, inc_subgraph_id=True - ): + def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True): + """Infer shapes and types within a subgraph for a given ONNX node using temporary graphs and known value information.""" if self.verbose_ > 2: - logger.debug( - f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}" - ) + logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}") # node inputs are not passed directly to the subgraph # it's up to the node dispatcher to prepare subgraph input # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape # besides, inputs in subgraph could shadow implicit inputs - subgraph_inputs = { - i.name for i in list(subgraph.initializer) + list(subgraph.input) - } - subgraph_implicit_input = { - name for name in self.known_vi_ if name not in subgraph_inputs - } + subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)} + subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs} tmp_graph = helper.make_graph( list(subgraph.node), "tmp", list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input], [make_named_value_info(i.name) for i in subgraph.output], ) - tmp_graph.initializer.extend( - [ - i - for i in self.out_mp_.graph.initializer - if i.name in subgraph_implicit_input - ] - ) + tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input]) tmp_graph.initializer.extend(subgraph.initializer) self.tmp_mp_.graph.CopyFrom(tmp_graph) @@ -638,9 +603,7 @@ def _onnx_infer_subgraph( if use_node_input: # if subgraph uses node input, it needs to update to merged dims subgraph.ClearField("input") - subgraph.input.extend( - symbolic_shape_inference.out_mp_.graph.input[: len(node.input)] - ) + subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)]) subgraph.ClearField("output") subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) 
subgraph.ClearField("value_info") @@ -648,10 +611,7 @@ def _onnx_infer_subgraph( subgraph.ClearField("node") subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) # for new symbolic dims from subgraph output, add to main graph symbolic dims - subgraph_shapes = [ - get_shape_from_value_info(o) - for o in symbolic_shape_inference.out_mp_.graph.output - ] + subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output] subgraph_new_symbolic_dims = { d for s in subgraph_shapes @@ -667,6 +627,7 @@ def _onnx_infer_subgraph( return symbolic_shape_inference def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False): + """Extracts integer or float values from a node, with optional broadcasting and float value allowance.""" def int_or_float(value, allow_float_values): # If casting into int has precision loss: keep float output if allow_float_values and value % 1 != 0: @@ -704,15 +665,14 @@ def int_or_float(value, allow_float_values): return values def _compute_on_sympy_data(self, node, op_func): + """Calculate the result using Sympy data and a specified operation function.""" assert len(node.output) == 1 # Before mul & div operations - # cast inputs into interger might lose decimal part and reduce precision + # cast inputs into integer might lose decimal part and reduce precision # keep them as float, finish the operation, then cast the result into integer if node.op_type in ["Mul", "Div"]: - values = self._get_int_or_float_values( - node, broadcast=True, allow_float_values=True - ) + values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True) else: values = self._get_int_or_float_values(node, broadcast=True) @@ -725,6 +685,7 @@ def _compute_on_sympy_data(self, node, op_func): self.sympy_data_[node.output[0]] = op_func(values) def _pass_on_sympy_data(self, node): + """Pass Sympy data through a node, validating input length or node operation type 'Reshape', 'Unsqueeze', 'Squeeze'.""" assert len(node.input) == 1 or node.op_type in [ "Reshape", "Unsqueeze", @@ -733,6 +694,7 @@ def _pass_on_sympy_data(self, node): self._compute_on_sympy_data(node, lambda x: x[0]) def _pass_on_shape_and_type(self, node): + """Propagates the shape and type information from input to output for a given node.""" vi = self.known_vi_[node.output[0]] vi.CopyFrom( helper.make_tensor_value_info( @@ -743,6 +705,7 @@ def _pass_on_shape_and_type(self, node): ) def _new_symbolic_dim(self, prefix, dim): + """Create and return a new symbolic dimension, handling literal values and caching results.""" new_dim = f"{prefix}_d{dim}" if new_dim in self.suggested_merge_: v = self.suggested_merge_[new_dim] @@ -753,6 +716,7 @@ def _new_symbolic_dim(self, prefix, dim): return new_symbolic_dim def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): + """Generate a new symbolic dimension for a given node output using node operation type and indexing information.""" return self._new_symbolic_dim( "{}{}_{}_o{}_".format( node.op_type, @@ -764,11 +728,11 @@ def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): ) def _new_symbolic_shape(self, rank, node, out_idx=0): - return [ - self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank) - ] + """Return a list of new symbolic dimensions for a given node output based on the specified rank.""" + return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)] def _compute_conv_pool_shape(self, node, channels_last=False): + """Compute and return the 
output shape of a convolutional or pooling layer for a given node, considering the channel order.""" sympy_shape = self._get_sympy_shape(node, 0) if len(node.input) > 1: W_shape = self._get_sympy_shape(node, 1) # noqa: N806 @@ -783,9 +747,7 @@ def _compute_conv_pool_shape(self, node, channels_last=False): assert len(sympy_shape) == rank + 2 # only need to symbolic shape inference if input has symbolic dims in spatial axes - spatial_shape = ( - sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] - ) + spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] is_symbolic_dims = [not is_literal(i) for i in spatial_shape] if not any(is_symbolic_dims): @@ -793,34 +755,26 @@ def _compute_conv_pool_shape(self, node, channels_last=False): if len(shape) > 0: assert len(sympy_shape) == len(shape) if channels_last: - sympy_shape[-rank - 1 : -1] = [ - sympy.Integer(d) for d in shape[-rank - 1 : -1] - ] + sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]] else: sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] return sympy_shape dilations = get_attribute(node, "dilations", [1] * rank) strides = get_attribute(node, "strides", [1] * rank) - effective_kernel_shape = [ - (k - 1) * d + 1 for k, d in zip(kernel_shape, dilations) - ] + effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)] pads = get_attribute(node, "pads") if pads is None: pads = [0] * (2 * rank) auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") if auto_pad != "VALID" and auto_pad != "NOTSET": try: - residual = [ - sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides) - ] + residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] total_pads = [ max(0, (k - s) if r == 0 else (k - r)) for k, s, r in zip(effective_kernel_shape, strides, residual) ] - except ( - TypeError - ): # sympy may throw TypeError: cannot determine truth value of Relational + except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational total_pads = [ max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides) ] # assuming no residual if sympy throws error @@ -842,21 +796,19 @@ def _compute_conv_pool_shape(self, node, channels_last=False): (effective_input_size - effective_kernel_shape[i]) / strides[i] ) else: - strided_kernel_positions = ( - effective_input_size - effective_kernel_shape[i] - ) // strides[i] - sympy_shape[-rank + i + (-1 if channels_last else 0)] = ( - strided_kernel_positions + 1 - ) + strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i] + sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1 return sympy_shape def _check_merged_dims(self, dims, allow_broadcast=True): + """Checks merged dimensions for consistency, optionally allowing broadcasting.""" if allow_broadcast: dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)] if not all([d == dims[0] for d in dims]): self._add_suggested_merge(dims, apply=True) def _compute_matmul_shape(self, node, output_dtype=None): + """Compute the output shape for a matrix multiplication operation based on input shapes and optionally infer the output data type.""" lhs_shape = self._get_shape(node, 0) rhs_shape = self._get_shape(node, 1) lhs_rank = len(lhs_shape) @@ -889,23 +841,15 @@ def _compute_matmul_shape(self, node, output_dtype=None): # infer output_dtype from input type when not specified output_dtype = 
self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): - """ - update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches - """ + """Update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches.""" dst_tensor_type = ( - dst_type.sequence_type.elem_type.tensor_type - if is_sequence(dst_type) - else dst_type.tensor_type + dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type ) src_tensor_type = ( - src_type.sequence_type.elem_type.tensor_type - if is_sequence(src_type) - else src_type.tensor_type + src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type ) if dst_tensor_type.elem_type != src_tensor_type.elem_type: node_id = node.name if node.name else node.op_type @@ -915,22 +859,19 @@ def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" ) if dst_tensor_type.HasField("shape"): - for di, ds in enumerate( - zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim) - ): + for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): if ds[0] != ds[1]: # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type # for sequence_type, clear the dimension new_dim = onnx.TensorShapeProto.Dimension() if not is_sequence(dst_type): - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, out_idx, di) - ) + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di)) dst_tensor_type.shape.dim[di].CopyFrom(new_dim) else: dst_tensor_type.CopyFrom(src_tensor_type) def _infer_ArrayFeatureExtractor(self, node): # noqa: N802 + """Infer and update the shape and type information for the ArrayFeatureExtractor node using input data and indices shapes.""" data_shape = self._get_shape(node, 0) indices_shape = self._get_shape(node, 1) vi = self.known_vi_[node.output[0]] @@ -943,6 +884,7 @@ def _infer_ArrayFeatureExtractor(self, node): # noqa: N802 ) def _infer_symbolic_compute_ops(self, node): + """Handles symbolic computation for nodes using predefined operation mappings such as Add, Div, Equal, Floor, Max, Min, Mul, Sub, Where, and Neg.""" funcs = { "Add": lambda l: l[0] + l[1], # noqa: E741 "Div": lambda l: ( @@ -955,24 +897,14 @@ def _infer_symbolic_compute_ops(self, node): "Max": lambda l: ( l[1] # noqa: E741 if is_literal(l[0]) and int(l[0]) < -self.int_max_ - else ( - l[0] - if is_literal(l[1]) and int(l[1]) < -self.int_max_ - else sympy.Max(l[0], l[1]) - ) + else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])) ), "Min": lambda l: ( l[1] # noqa: E741 if is_literal(l[0]) and int(l[0]) > self.int_max_ - else ( - l[0] - if is_literal(l[1]) and int(l[1]) > self.int_max_ - else sympy.Min(l[0], l[1]) - ) + else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])) ), - "Mul": lambda l: ( - int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1] - ), # noqa: E741 + "Mul": lambda l: (int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1]), # noqa: E741 "Sub": lambda l: l[0] - l[1], # noqa: E741 "Where": lambda l: l[1] if l[0] else l[2], # noqa: 
E741 "Neg": lambda l: -l[0], # noqa: E741 @@ -981,22 +913,21 @@ def _infer_symbolic_compute_ops(self, node): self._compute_on_sympy_data(node, funcs[node.op_type]) def _infer_Cast(self, node): # noqa: N802 + """Pass node's data to SymPy representation for type casting operation.""" self._pass_on_sympy_data(node) def _infer_CategoryMapper(self, node): # noqa: N802 + """Infer the output type for a CategoryMapper node based on the input tensor type.""" input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type if input_type == onnx.TensorProto.STRING: output_type = onnx.TensorProto.INT64 else: output_type = onnx.TensorProto.STRING vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_type, self._get_shape(node, 0) - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0))) def _infer_Compress(self, node): # noqa: N802 + """Infer the output shape and type for the Compress operation based on input shape and axis attribute.""" input_shape = self._get_shape(node, 0) # create a new symbolic dimension for Compress output compress_len = str(self._new_symbolic_dim_from_output(node)) @@ -1017,6 +948,7 @@ def _infer_Compress(self, node): # noqa: N802 ) def _infer_Concat(self, node): # noqa: N802 + """Infer the output shape and type for the Concat operation based on input values.""" if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]): values = self._get_int_or_float_values(node) if all([v is not None for v in values]): @@ -1040,11 +972,7 @@ def _infer_Concat(self, node): # noqa: N802 for d in range(len(sympy_shape)): if d == axis: continue - dims = [ - self._get_shape(node, i_idx)[d] - for i_idx in range(len(node.input)) - if self._get_shape(node, i_idx) - ] + dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)] if all([d == dims[0] for d in dims]): continue merged = self._merge_symbols(dims) @@ -1062,11 +990,10 @@ def _infer_Concat(self, node): # noqa: N802 ) def _infer_ConcatFromSequence(self, node): # noqa: N802 + """Infers the output shape for ConcatFromSequence node based on its input sequence shape and axis attributes.""" seq_shape = self._get_shape(node, 0) new_axis = 1 if get_attribute(node, "new_axis") else 0 - axis = handle_negative_axis( - get_attribute(node, "axis"), len(seq_shape) + new_axis - ) + axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis) concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) new_shape = seq_shape if new_axis: @@ -1077,18 +1004,18 @@ def _infer_ConcatFromSequence(self, node): # noqa: N802 vi.CopyFrom( helper.make_tensor_value_info( node.output[0], - self.known_vi_[ - node.input[0] - ].type.sequence_type.elem_type.tensor_type.elem_type, + self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type, new_shape, ) ) def _infer_Constant(self, node): # noqa: N802 + """Infer the constant value for a given node and store it in sympy data.""" t = get_attribute(node, "value") self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) def _infer_ConstantOfShape(self, node): # noqa: N802 + """Infer the constant tensor of a given shape from a node and update sympy data accordingly.""" sympy_shape = self._get_int_or_float_values(node)[0] vi = self.known_vi_[node.output[0]] if sympy_shape is not None: @@ -1096,9 +1023,7 @@ def _infer_ConstantOfShape(self, node): # noqa: N802 sympy_shape = [sympy_shape] 
self._update_computed_dims(sympy_shape) # update sympy data if output type is int, and shape is known - if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all( - [is_literal(x) for x in sympy_shape] - ): + if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]): self.sympy_data_[node.output[0]] = np.ones( [int(x) for x in sympy_shape], dtype=np.int64 ) * numpy_helper.to_array(get_attribute(node, "value", 0)) @@ -1116,6 +1041,7 @@ def _infer_ConstantOfShape(self, node): # noqa: N802 ) def _infer_Conv(self, node): # noqa: N802 + """Infers the shape of the output tensor for a convolutional layer node.""" sympy_shape = self._compute_conv_pool_shape(node) self._update_computed_dims(sympy_shape) vi = self.known_vi_[node.output[0]] @@ -1128,6 +1054,7 @@ def _infer_Conv(self, node): # noqa: N802 ) def _infer_NhwcConv(self, node): # noqa: N802 + """Infer the shape of the output tensor for a convolutional layer with NHWC format.""" sympy_shape = self._compute_conv_pool_shape(node, channels_last=True) self._update_computed_dims(sympy_shape) vi = self.known_vi_[node.output[0]] @@ -1140,19 +1067,17 @@ def _infer_NhwcConv(self, node): # noqa: N802 ) def _infer_DequantizeLinear(self, node): # noqa: N802 - # Get the output data type from the scale input (index 1, required). + """Infers the output value info for DequantizeLinear node using the datatype from the scale input.""" output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type # Get the output shape from the first input. output_shape = self._get_shape(node, 0) vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_QuantizeLinear(self, node): # noqa: N802 - # Get the output data type from the zero-point input (index 2, optional). 
+ """Infer the output data type and shape for the QuantizeLinear ONNX node, defaulting to uint8 if not specified.""" # Otherwise, default to uint8 output_dtype = onnx.TensorProto.UINT8 if len(node.input) > 2 and node.input[2]: @@ -1162,12 +1087,10 @@ def _infer_QuantizeLinear(self, node): # noqa: N802 output_shape = self._get_shape(node, 0) vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_Einsum(self, node): # noqa: N802 - # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 + """Infer the output shape and type for the Einsum operation as per ONNX standards: https://github.com/onnx/onnx/blob/623dfaa/onnx/defs/math/defs.cc#L3275.""" equation = get_attribute(node, "equation") equation = equation.replace(b" ", b"") mid_index = equation.find(b"->") @@ -1226,19 +1149,16 @@ def _infer_Einsum(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape)) def _infer_Expand(self, node): # noqa: N802 + """Infers and updates the output shape for the Expand node based on broadcasted input shapes.""" expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) if expand_to_shape is not None: # new_shape's dim can come from shape value self._update_computed_dims(expand_to_shape) shape = self._get_shape(node, 0) - new_shape = self._broadcast_shapes( - shape, get_shape_from_sympy_shape(expand_to_shape) - ) + new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape)) vi = self.known_vi_[node.output[0]] vi.CopyFrom( helper.make_tensor_value_info( @@ -1249,6 +1169,7 @@ def _infer_Expand(self, node): # noqa: N802 ) def _infer_Gather(self, node): # noqa: N802 + """Infer the output shape of the Gather operation based on input shapes, axis, and indices properties.""" data_shape = self._get_shape(node, 0) axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape)) indices_shape = self._get_shape(node, 1) @@ -1261,11 +1182,7 @@ def _infer_Gather(self, node): # noqa: N802 ) ) # for 1D input, do some sympy compute - if ( - node.input[0] in self.sympy_data_ - and len(data_shape) == 1 - and get_attribute(node, "axis", 0) == 0 - ): + if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and get_attribute(node, "axis", 0) == 0: idx = self._try_get_value(node, 1) if idx is not None: data = self.sympy_data_[node.input[0]] @@ -1279,6 +1196,7 @@ def _infer_Gather(self, node): # noqa: N802 self.sympy_data_[node.output[0]] = data def _infer_GatherElements(self, node): # noqa: N802 + """Infer tensor value information for the GatherElements operation using the node's output and input shape.""" indices_shape = self._get_shape(node, 1) vi = self.known_vi_[node.output[0]] vi.CopyFrom( @@ -1290,6 +1208,7 @@ def _infer_GatherElements(self, node): # noqa: N802 ) def _infer_GatherND(self, node): # noqa: N802 + """Infers the output shape and type for the GatherND operation based on input data and indices shapes.""" data_shape = self._get_shape(node, 0) data_rank = len(data_shape) indices_shape = self._get_shape(node, 1) @@ -1307,7 +1226,7 @@ def _infer_GatherND(self, node): # noqa: 
N802 ) def _infer_If(self, node): # noqa: N802 - # special case for constant condition, in case there are mismatching shape from the non-executed branch + """Infer the output shape for an If node, handling constant conditions to ensure shape consistency between branches.""" subgraphs = [ get_attribute(node, "then_branch"), get_attribute(node, "else_branch"), @@ -1320,32 +1239,25 @@ def _infer_If(self, node): # noqa: N802 subgraphs[0].CopyFrom(subgraphs[1]) for i_sub, subgraph in enumerate(subgraphs): - subgraph_infer = self._onnx_infer_subgraph( - node, subgraph, use_node_input=False - ) + subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False) for i_out in range(len(node.output)): vi = self.known_vi_[node.output[i_out]] if i_sub == 0: vi.CopyFrom(subgraph.output[i_out]) vi.name = node.output[i_out] else: - self._fuse_tensor_type( - node, i_out, vi.type, subgraph.output[i_out].type - ) + self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type) # pass on sympy data from subgraph, if cond is constant if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else 1): if subgraph.output[i_out].name in subgraph_infer.sympy_data_: - self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[ - subgraph.output[i_out].name - ] + self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name] def _infer_Loop(self, node): # noqa: N802 + """Infer the shape and type of variables produced by the 'Loop' operation in an ONNX graph.""" subgraph = get_attribute(node, "body") assert len(subgraph.input) == len(node.input) - num_loop_carried = ( - len(node.input) - 2 - ) # minus the length and initial loop condition + num_loop_carried = len(node.input) - 2 # minus the length and initial loop condition # when sequence_type is used as loop carried input # needs to run subgraph infer twice if the tensor shape in sequence contains None for i, si in enumerate(subgraph.input): @@ -1367,9 +1279,7 @@ def _infer_Loop(self, node): # noqa: N802 # copy shape from output to input # note that loop input is [loop_len, cond, input_0, input_1, ...] # while loop output is [cond, output_0, output_1, ...] 
- subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom( - so.type.sequence_type.elem_type - ) + subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type) need_second_infer = True else: si = subgraph.input[i_out + 1] @@ -1377,9 +1287,7 @@ def _infer_Loop(self, node): # noqa: N802 for di, dims in enumerate(zip(si_shape, so_shape)): if dims[0] != dims[1]: new_dim = onnx.TensorShapeProto.Dimension() - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, i_out, di) - ) + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di)) si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) need_second_infer = True @@ -1397,13 +1305,9 @@ def _infer_Loop(self, node): # noqa: N802 loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) for i in range(len(node.output)): vi = self.known_vi_[node.output[i]] - vi.CopyFrom( - subgraph.output[i + 1] - ) # first subgraph output is condition, not in node output + vi.CopyFrom(subgraph.output[i + 1]) # first subgraph output is condition, not in node output if i >= num_loop_carried: - assert not is_sequence( - vi.type - ) # TODO: handle loop accumulation in sequence_type + assert not is_sequence(vi.type) # TODO: handle loop accumulation in sequence_type subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim vi.type.tensor_type.shape.ClearField("dim") vi_dim = vi.type.tensor_type.shape.dim @@ -1412,45 +1316,36 @@ def _infer_Loop(self, node): # noqa: N802 vi.name = node.output[i] def _infer_MatMul(self, node): # noqa: N802 + """Infer the output shape of a MatMul node by computing the matrix multiplication dimensions.""" self._compute_matmul_shape(node) def _infer_MatMulInteger(self, node): # noqa: N802 + """Infer the output shape of a MatMulInteger node by computing the matrix multiplication dimensions using INT32 data type.""" self._compute_matmul_shape(node, onnx.TensorProto.INT32) def _infer_NonMaxSuppression(self, node): # noqa: N802 + """Infer the output shape of a NonMaxSuppression node, returning selected indices as INT64 tensor with shape [selected, 3].""" selected = str(self._new_symbolic_dim_from_output(node)) vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], onnx.TensorProto.INT64, [selected, 3] - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3])) def _infer_NonZero(self, node): # noqa: N802 + """Infer the shape of NonZero operation output, assigning a new symbolic dimension.""" input_rank = self._get_shape_rank(node, 0) # create a new symbolic dimension for NonZero output nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len] - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len])) def _infer_OneHot(self, node): # noqa: N802 + """Infer the shape and type of the output tensor for the OneHot node operation.""" sympy_shape = self._get_sympy_shape(node, 0) depth = self._try_get_value(node, 1) axis = get_attribute(node, "axis", -1) axis = handle_negative_axis(axis, len(sympy_shape) + 1) new_shape = get_shape_from_sympy_shape( sympy_shape[:axis] - + [ - ( - self._new_symbolic_dim_from_output(node) - if not is_literal(depth) - else depth - ) - ] + + 
[(self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth)] + sympy_shape[axis:] ) vi = self.known_vi_[node.output[0]] @@ -1463,6 +1358,7 @@ def _infer_OneHot(self, node): # noqa: N802 ) def _infer_Pad(self, node): # noqa: N802 + """Infers the output shape and type for the Pad operation based on ONNX node attributes and opset version.""" if get_opset(self.out_mp_) <= 10: pads = get_attribute(node, "pads") else: @@ -1474,8 +1370,7 @@ def _infer_Pad(self, node): # noqa: N802 if pads is not None: assert len(pads) == 2 * rank new_sympy_shape = [ - d + pad_up + pad_down - for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) + d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) ] self._update_computed_dims(new_sympy_shape) else: @@ -1485,12 +1380,11 @@ def _infer_Pad(self, node): # noqa: N802 vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape) - ) + helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)) ) def _infer_Pool(self, node): # noqa: N802 + """Infer and update dimensions for pooling layers based on the input node.""" sympy_shape = self._compute_conv_pool_shape(node) self._update_computed_dims(sympy_shape) for o in node.output: @@ -1506,18 +1400,16 @@ def _infer_Pool(self, node): # noqa: N802 ) def _infer_aten_bitwise_or(self, node): + """Infers the output shape for Aten bitwise OR operation based on input node shapes.""" shape0 = self._get_shape(node, 0) shape1 = self._get_shape(node, 1) new_shape = self._broadcast_shapes(shape0, shape1) t0 = self.known_vi_[node.input[0]] vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], t0.type.tensor_type.elem_type, new_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape)) def _infer_aten_diagonal(self, node): + """Infers the shape of the diagonal of a tensor given a node using sympy shapes and specified dimensions.""" sympy_shape = self._get_sympy_shape(node, 0) rank = len(sympy_shape) offset = self._try_get_value(node, 1) @@ -1552,16 +1444,13 @@ def _infer_aten_diagonal(self, node): ) def _infer_aten_multinomial(self, node): + """Infers the output shape and type for the PyTorch multinomial operation in an ONNX graph node.""" sympy_shape = self._get_sympy_shape(node, 0) rank = len(sympy_shape) assert rank in [1, 2] num_samples = self._try_get_value(node, 1) di = rank - 1 - last_dim = ( - num_samples - if num_samples - else str(self._new_symbolic_dim_from_output(node, 0, di)) - ) + last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di)) output_shape = sympy_shape[:-1] + [last_dim] vi = self.known_vi_[node.output[0]] vi.CopyFrom( @@ -1573,28 +1462,20 @@ def _infer_aten_multinomial(self, node): ) def _infer_aten_pool2d(self, node): + """Infer the output shape of a 2D pooling operation in an ATen graph node.""" sympy_shape = self._get_sympy_shape(node, 0) assert len(sympy_shape) == 4 - sympy_shape[-2:] = [ - self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3] - ] + sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]] self._update_computed_dims(sympy_shape) for i, o in enumerate(node.output): if not o: continue vi = self.known_vi_[o] - elem_type = ( - onnx.TensorProto.INT64 - if i == 1 - else 
self.known_vi_[node.input[0]].type.tensor_type.elem_type - ) - vi.CopyFrom( - helper.make_tensor_value_info( - o, elem_type, get_shape_from_sympy_shape(sympy_shape) - ) - ) + elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape))) def _infer_aten_minmax(self, node): + """Infer the output shape and type for ATen Min/Max operations in an ONNX node.""" vi = self.known_vi_[node.output[0]] if len(node.input) == 1: vi.CopyFrom( @@ -1611,9 +1492,7 @@ def _infer_aten_minmax(self, node): dim = self._try_get_value(node, 1) if dim is None: rank = self._get_shape_rank(node, 0) - output_shape = self._new_symbolic_shape( - rank if keepdim else rank - 1, node - ) + output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) else: shape = self._get_sympy_shape(node, 0) dim = handle_negative_axis(dim, len(shape)) @@ -1631,13 +1510,10 @@ def _infer_aten_minmax(self, node): ) ) vi1 = self.known_vi_[node.output[1]] - vi1.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT64, output_shape - ) - ) + vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape)) def _infer_aten_unfold(self, node): + """Infer the tensor shape for the 'aten::unfold' operation based on input shape and parameters.""" sympy_shape = self._get_sympy_shape(node, 0) dimension = self._try_get_value(node, 1) size = self._try_get_value(node, 2) @@ -1661,6 +1537,7 @@ def _infer_aten_unfold(self, node): ) def _infer_aten_argmax(self, node): + """Infers the output shape for the ONNX ATen argmax operation.""" new_shape = None if not node.input[1]: # The argmax of the flattened input is returned. 
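# [Editor's note] A minimal, self-contained sketch (not part of this patch) of the shape rule that the
# `_infer_aten_unfold` hunk above reformats: for aten::unfold(dimension, size, step), the unfolded
# dimension holds (d - size) // step + 1 windows and a trailing dimension of length `size` is appended.
# The helper name `unfold_shape` is hypothetical, for illustration only; with sympy symbols, `//` stays
# symbolic as floor().
import sympy

def unfold_shape(shape, dimension, size, step):
    """Return the unfolded shape; `shape` may mix ints and sympy symbols."""
    out = list(shape)
    out[dimension] = (shape[dimension] - size) // step + 1  # number of windows along `dimension`
    out.append(size)  # each window contributes `size` trailing elements
    return out

seq_len = sympy.Symbol("seq_len", positive=True)
print(unfold_shape([2, seq_len, 8], dimension=1, size=4, step=2))  # e.g. [2, floor(seq_len/2 - 2) + 1, 8, 4]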
@@ -1678,27 +1555,18 @@ def _infer_aten_argmax(self, node): del sympy_shape[dim] else: rank = len(sympy_shape) - sympy_shape = self._new_symbolic_shape( - rank if keepdim else rank - 1, node - ) + sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) self._update_computed_dims(sympy_shape) new_shape = get_shape_from_sympy_shape(sympy_shape) if node.output[0] and new_shape is not None: vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], onnx.TensorProto.INT64, new_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape)) def _infer_aten_group_norm(self, node): + """Infers the output shapes and types for the ATen GroupNorm operation based on the provided node information.""" self._propagate_shape_and_type(node) input_shape = self._get_shape(node, 0) - N = ( - input_shape[0] - if input_shape is not None and len(input_shape) != 0 - else None - ) # noqa: N806 + N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None # noqa: N806 group = self._try_get_value(node, 6) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type for i in [1, 2]: @@ -1709,11 +1577,7 @@ def _infer_aten_group_norm(self, node): node.output[i], output_dtype, [ - ( - N - if N is not None - else str(self._new_symbolic_dim_from_output(node, i, 0)) - ), + (N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0))), ( as_scalar(group) if group is not None @@ -1724,30 +1588,24 @@ def _infer_aten_group_norm(self, node): ) def _infer_aten_upsample(self, node): + """Infers the output shape for the aten::upsample operation based on the input shape and output size values.""" new_shape = None input_shape = self._get_shape(node, 0) if input_shape is not None: new_shape = input_shape[:2] output_size = self._try_get_value(node, 1) if output_size is not None: - new_shape += [ - dim_size.item() if type(dim_size) == np.int64 else dim_size - for dim_size in output_size - ] + new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size] else: rank = len(input_shape) - new_shape += [ - str(self._new_symbolic_dim_from_output(node, 0, i)) - for i in range(2, rank) - ] + new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)] if node.output[0] and new_shape is not None: output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) def _infer_BatchNormalization(self, node): # noqa: N802 + """Propagate the shape and type information for the BatchNormalization node.""" self._propagate_shape_and_type(node) # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop @@ -1757,6 +1615,7 @@ def _infer_BatchNormalization(self, node): # noqa: N802 self._propagate_shape_and_type(node, input_index=1, output_index=i) def _infer_Range(self, node): # noqa: N802 + """Infers the shape and type for Range nodes based on the provided start, limit, and delta values.""" vi = self.known_vi_[node.output[0]] input_data = self._get_int_or_float_values(node) if all([i is not None for i in input_data]): @@ -1776,6 +1635,7 @@ def _infer_Range(self, node): # noqa: N802 ) def _infer_ReduceSum(self, node): # noqa: N802 + """Infer output shape for ReduceSum operation based on input 
shape, axes, and keep_dims attribute.""" keep_dims = get_attribute(node, "keepdims", 1) if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: # ReduceSum changes axes to input[1] in opset 13 @@ -1787,11 +1647,7 @@ def _infer_ReduceSum(self, node): # noqa: N802 helper.make_tensor_value_info( node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape( - self._get_shape_rank(node, 0), node - ) - ), + get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)), ) ) else: @@ -1813,6 +1669,7 @@ def _infer_ReduceSum(self, node): # noqa: N802 ) def _infer_ReduceProd(self, node): # noqa: N802 + """Infer the ReduceProd operation's output shape and sympy data given node attributes.""" axes = get_attribute(node, "axes") keep_dims = get_attribute(node, "keepdims", 1) if keep_dims == 0 and axes == [0]: @@ -1821,6 +1678,7 @@ def _infer_ReduceProd(self, node): # noqa: N802 self.sympy_data_[node.output[0]] = sympy_reduce_product(data) def _infer_RelativePositionBias(self, node): # noqa: N802 + """Infers the relative position bias for a given ONNX node.""" seq_len = self._try_get_value(node, 1) real_seq_len = self._try_get_value(node, 2) if seq_len is None or real_seq_len is None: @@ -1831,11 +1689,10 @@ def _infer_RelativePositionBias(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) def _infer_Reshape(self, node): # noqa: N802 + """Infer the output shape for the Reshape operation using given node attributes and input shapes.""" shape_value = self._try_get_value(node, 1) vi = self.known_vi_[node.output[0]] if shape_value is None: @@ -1847,9 +1704,7 @@ def _infer_Reshape(self, node): # noqa: N802 helper.make_tensor_value_info( node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape(shape_rank, node) - ), + get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)), ) ) else: @@ -1890,15 +1745,13 @@ def _infer_Reshape(self, node): # noqa: N802 self._pass_on_sympy_data(node) def _infer_Resize(self, node): # noqa: N802 + """Infers and updates the shape of the output tensor for a Resize node based on scales or sizes.""" vi = self.known_vi_[node.output[0]] input_sympy_shape = self._get_sympy_shape(node, 0) if get_opset(self.out_mp_) <= 10: scales = self._try_get_value(node, 1) if scales is not None: - new_sympy_shape = [ - sympy.simplify(sympy.floor(d * s)) - for d, s in zip(input_sympy_shape, scales) - ] + new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)] self._update_computed_dims(new_sympy_shape) vi.CopyFrom( helper.make_tensor_value_info( @@ -1916,10 +1769,7 @@ def _infer_Resize(self, node): # noqa: N802 self._update_computed_dims(new_sympy_shape) elif scales is not None: rank = len(scales) - if ( - get_attribute(node, "coordinate_transformation_mode") - == "tf_crop_and_resize" - ): + if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize": assert len(roi) == 2 * rank roi_start = list(roi)[:rank] roi_end = list(roi)[rank:] @@ -1929,15 +1779,11 @@ def _infer_Resize(self, node): # noqa: N802 scales = list(scales) new_sympy_shape = [ sympy.simplify(sympy.floor(d * (end - start) * scale)) - for d, start, end, scale in 
zip( - input_sympy_shape, roi_start, roi_end, scales - ) + for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales) ] self._update_computed_dims(new_sympy_shape) else: - new_sympy_shape = self._new_symbolic_shape( - self._get_shape_rank(node, 0), node - ) + new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) vi.CopyFrom( helper.make_tensor_value_info( @@ -1948,6 +1794,7 @@ def _infer_Resize(self, node): # noqa: N802 ) def _infer_Scan(self, node): # noqa: N802 + """Infer shape and type information for the ONNX 'Scan' operator node.""" subgraph = get_attribute(node, "body") num_scan_inputs = get_attribute(node, "num_scan_inputs") scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs) @@ -1965,36 +1812,25 @@ def _infer_Scan(self, node): # noqa: N802 si.CopyFrom(self.known_vi_[node.input[i]]) if i >= num_scan_states: scan_input_dim = si.type.tensor_type.shape.dim - scan_input_dim.remove( - scan_input_dim[scan_input_axes[i - num_scan_states]] - ) + scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]]) si.name = subgraph_name self._onnx_infer_subgraph(node, subgraph) num_scan_outputs = len(node.output) - num_scan_states - scan_output_axes = get_attribute( - node, "scan_output_axes", [0] * num_scan_outputs - ) - scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[ - scan_input_axes[-1] - ] + scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs) + scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] for i, o in enumerate(node.output): vi = self.known_vi_[o] if i >= num_scan_states: shape = get_shape_from_type_proto(subgraph.output[i].type) - new_dim = handle_negative_axis( - scan_output_axes[i - num_scan_states], len(shape) + 1 - ) + new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1) shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] - vi.CopyFrom( - helper.make_tensor_value_info( - o, subgraph.output[i].type.tensor_type.elem_type, shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape)) else: vi.CopyFrom(subgraph.output[i]) vi.name = o def _infer_ScatterElements(self, node): # noqa: N802 + """Infer the output shape and type for the ScatterElements operation.""" data_shape = self._get_shape(node, 0) vi = self.known_vi_[node.output[0]] vi.CopyFrom( @@ -2006,7 +1842,7 @@ def _infer_ScatterElements(self, node): # noqa: N802 ) def _infer_SequenceAt(self, node): # noqa: N802 - # need to create new symbolic dimension if sequence shape has None: + """Infers the shape and type for the output of the 'SequenceAt' ONNX operation, handling symbolic dimensions if necessary.""" seq_shape = self._get_shape(node, 0) vi = self.known_vi_[node.output[0]] if seq_shape is not None: @@ -2018,7 +1854,7 @@ def _infer_SequenceAt(self, node): # noqa: N802 vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim) def _infer_SequenceInsert(self, node): # noqa: N802 - # workaround bug in onnx's shape inference + """Work around ONNX's shape inference bug by fusing tensor types for sequence insert operations.""" vi_seq = self.known_vi_[node.input[0]] vi_tensor = self.known_vi_[node.input[1]] vi_out_seq = self.known_vi_[node.output[0]] @@ -2027,9 +1863,11 @@ def _infer_SequenceInsert(self, node): # noqa: N802 self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type) def _infer_Shape(self, node): # noqa: N802 + """Infers and sets
the symbolic shape for the output node in the computation graph.""" self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0) def _infer_Size(self, node): # noqa: N802 + """Infers and sets the size of the output node by computing the product of its shape in the computation graph.""" sympy_shape = self._get_sympy_shape(node, 0) self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) self.known_vi_[node.output[0]].CopyFrom( @@ -2037,6 +1875,7 @@ def _infer_Size(self, node): # noqa: N802 ) def _infer_Slice(self, node): # noqa: N802 + """Infer the shape and value information for the Slice node using SymPy and ONNX helper methods.""" # SymPy fails to prove that `x_0 + ... + x_n >= 0` if one of `x_i` is a `sympy.Min(a, b)`, # even when the relation holds for both `a` and `b`. # @@ -2045,18 +1884,13 @@ def _infer_Slice(self, node): # noqa: N802 # # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`. def flatten_min(expr): - assert isinstance( - expr, sympy.Add - ), f"Expected a sum of two arguments, got {expr}" - min_positions = [ - idx - for idx in range(len(expr.args)) - if isinstance(expr.args[idx], sympy.Min) - ] + assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}" + min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)] if len(min_positions) == 1: min_pos = min_positions[0] def replace_min_with_arg(arg_idx): + """Replace the sympy.Min() function at a specified position in a sympy.Add() expression with one of its arguments.""" replaced = list(expr.args) assert isinstance( replaced[min_pos], sympy.Min @@ -2074,6 +1908,7 @@ def replace_min_with_arg(arg_idx): return [expr] def less_equal(x, y): + """Returns True if x is less than or equal to y, otherwise False.""" try: return bool(x <= y) except TypeError: @@ -2097,7 +1932,7 @@ def less_equal(x, y): return all(bool(d >= 0) for d in flatten_min(y - x)) def handle_negative_index(index, bound): - """normalizes a negative index to be in [0, bound)""" + """Normalizes a negative index to be in [0, bound).""" try: if not less_equal(0, index): if is_literal(index) and index <= -self.int_max_: @@ -2162,18 +1997,14 @@ def handle_negative_index(index, bound): if not less_equal(e, new_sympy_shape[i]): e = new_sympy_shape[i] # noqa: PLW2901 except Exception: - logger.warning( - f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal" - ) + logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal") e = new_sympy_shape[i] # noqa: PLW2901 s = handle_negative_index(s, new_sympy_shape[i]) # noqa: PLW2901 if is_literal(new_sympy_shape[i]) and is_literal(s): s = max(0, min(s, new_sympy_shape[i])) # noqa: PLW2901 - new_sympy_shape[i] = sympy.simplify( - (e - s + t + (-1 if t > 0 else 1)) // t - ) + new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t) self._update_computed_dims(new_sympy_shape) @@ -2201,11 +2032,10 @@ def handle_negative_index(index, bound): if type(input_sympy_data) == list or ( # noqa: E721 type(input_sympy_data) == np.array and len(input_sympy_data.shape) == 1 ): - self.sympy_data_[node.output[0]] = input_sympy_data[ - starts[0] : ends[0] : steps[0] - ] + self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]] def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802 + """Infer the output shape and type for the SoftmaxCrossEntropyLoss node in the computation graph.""" vi = self.known_vi_[node.output[0]] elem_type =
self.known_vi_[node.input[0]].type.tensor_type.elem_type @@ -2223,10 +2053,9 @@ def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802 vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape)) def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802 + """Infers the output shape for the Split operator given an ONNX node and a function to create tensor value info.""" input_sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis( - get_attribute(node, "axis", 0), len(input_sympy_shape) - ) + axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) op_set = get_opset(self.out_mp_) # Depending on op-version 'split' are provided as attribute or via 2nd input @@ -2250,22 +2079,21 @@ def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802 make_value_info_func( node.output[i_o], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - input_sympy_shape[:axis] - + [split[i_o]] - + input_sympy_shape[axis + 1 :] - ), + get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :]), ) ) self.known_vi_[vi.name] = vi def _infer_Split(self, node): # noqa: N802 + """Infers the output shapes and types for the Split node using the common inference logic.""" self._infer_Split_Common(node, helper.make_tensor_value_info) def _infer_SplitToSequence(self, node): # noqa: N802 + """Infers the output shapes and types for the SplitToSequence node using the common inference logic.""" self._infer_Split_Common(node, helper.make_sequence_value_info) def _infer_Squeeze(self, node): # noqa: N802 + """Infers the output shapes and types for the Squeeze node using the input shape and operation set.""" input_shape = self._get_shape(node, 0) op_set = get_opset(self.out_mp_) @@ -2283,9 +2111,7 @@ def _infer_Squeeze(self, node): # noqa: N802 # For symbolic dimensions we guess they are !=1. output_shape = [s for s in input_shape if s != 1] if self.verbose_ > 0: - symbolic_dimensions = [ - s for s in input_shape if type(s) != int - ] # noqa: E721 + symbolic_dimensions = [s for s in input_shape if type(s) != int] # noqa: E721 if len(symbolic_dimensions) > 0: logger.debug( f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " @@ -2298,9 +2124,7 @@ def _infer_Squeeze(self, node): # noqa: N802 if i not in axes: output_shape.append(input_shape[i]) else: - assert ( - input_shape[i] == 1 or type(input_shape[i]) != int - ) # noqa: E721 + assert input_shape[i] == 1 or type(input_shape[i]) != int # noqa: E721 if self.verbose_ > 0 and type(input_shape[i]) != int: # noqa: E721 logger.debug( f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. 
" @@ -2318,6 +2142,7 @@ def _infer_Squeeze(self, node): # noqa: N802 self._pass_on_sympy_data(node) def _infer_Tile(self, node): # noqa: N802 + """Infers the output shape for the Tile operation in a computation graph based on input shape and repeat values.""" repeats_value = self._try_get_value(node, 1) new_sympy_shape = [] if repeats_value is not None: @@ -2327,9 +2152,7 @@ def _infer_Tile(self, node): # noqa: N802 new_sympy_shape.append(new_dim) self._update_computed_dims(new_sympy_shape) else: - new_sympy_shape = self._new_symbolic_shape( - self._get_shape_rank(node, 0), node - ) + new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) vi = self.known_vi_[node.output[0]] vi.CopyFrom( helper.make_tensor_value_info( @@ -2340,6 +2163,7 @@ def _infer_Tile(self, node): # noqa: N802 ) def _infer_TopK(self, node): # noqa: N802 + """Infers the output shape for the TopK operation in an ONNX graph node based on input shape and specified axis.""" rank = self._get_shape_rank(node, 0) axis = handle_negative_axis(get_attribute(node, "axis", -1), rank) new_shape = self._get_shape(node, 0) @@ -2366,26 +2190,20 @@ def _infer_TopK(self, node): # noqa: N802 for i_o in range(len(node.output)): vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[i_o], vi.type.tensor_type.elem_type, new_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape)) def _infer_Transpose(self, node): # noqa: N802 + """Infer and update the shape information for a Transpose node based on its input shape and permutation attributes.""" if node.input[0] in self.sympy_data_: data_shape = self._get_shape(node, 0) perm = get_attribute(node, "perm", reversed(list(range(len(data_shape))))) input_data = self.sympy_data_[node.input[0]] self.sympy_data_[node.output[0]] = ( - np.transpose( - np.array(input_data).reshape(*data_shape), axes=tuple(perm) - ) - .flatten() - .tolist() + np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist() ) def _infer_Unsqueeze(self, node): # noqa: N802 + """Infers the output shape for the Unsqueeze operation based on the input shape and operator set.""" input_shape = self._get_shape(node, 0) op_set = get_opset(self.out_mp_) @@ -2421,6 +2239,7 @@ def _infer_Unsqueeze(self, node): # noqa: N802 self._pass_on_sympy_data(node) def _infer_ZipMap(self, node): # noqa: N802 + """Infer the type of keys for a ZipMap node based on its class labels attribute.""" map_key_type = None if get_attribute(node, "classlabels_int64s") is not None: map_key_type = onnx.TensorProto.INT64 @@ -2430,22 +2249,19 @@ def _infer_ZipMap(self, node): # noqa: N802 assert map_key_type is not None new_vi = onnx.ValueInfoProto() new_vi.name = node.output[0] - new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = ( - onnx.TensorProto.FLOAT - ) + new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type vi = self.known_vi_[node.output[0]] vi.CopyFrom(new_vi) def _infer_Attention(self, node): # noqa: N802 + """Infer shape and data type for ONNX Attention node outputs given input shapes and attributes.""" shape = self._get_shape(node, 0) shape_weights = self._get_shape(node, 1) shape_bias = self._try_get_shape(node, 2) if shape_bias is not None: assert len(shape_bias) == 1 - tripled_hidden_size = ( - shape_bias[0] if shape_bias is not None else 
shape_weights[1] - ) + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] if shape and len(shape) == 3: qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") if qkv_hidden_sizes_attr is not None: @@ -2455,9 +2271,7 @@ def _infer_Attention(self, node): # noqa: N802 shape[2] = int(tripled_hidden_size / 3) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) if len(node.output) > 1: # input shape: (batch_size, sequence_length, hidden_size) @@ -2465,31 +2279,19 @@ def _infer_Attention(self, node): # noqa: N802 # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length input_shape = self._get_shape(node, 0) - past_shape = ( - self._get_shape(node, 4) - if len(node.input) > 4 and node.input[4] - else [] - ) - mask_shape = ( - self._get_shape(node, 3) - if len(node.input) > 3 and node.input[3] - else [] - ) + past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else [] + mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else [] if past_shape and len(past_shape) == 5: if mask_shape and len(mask_shape) in [2, 3]: past_shape[3] = mask_shape[-1] elif input_shape and len(input_shape) == 3: - if isinstance(input_shape[1], int) and isinstance( - past_shape[3], int - ): + if isinstance(input_shape[1], int) and isinstance(past_shape[3], int): past_shape[3] = input_shape[1] + past_shape[3] else: past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) # No past input but present output still exists else: num_heads = get_attribute(node, "num_heads") @@ -2502,13 +2304,10 @@ def _infer_Attention(self, node): # noqa: N802 head_size, ] vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, output_dtype, present_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) def _infer_GatedRelativePositionBias(self, node): # noqa: N802 + """Infer the shape for gated relative position bias given the node attributes.""" # When padding is removed: # query_layer: (token_count, num_heads x head_size) # token_offset: (batch_size, seq_len) @@ -2538,19 +2337,16 @@ def _infer_GatedRelativePositionBias(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_PackedAttention(self, node): # noqa: N802 + """Infer shape and data type for PackedAttention nodes in a given computational graph.""" shape = self._get_shape(node, 0) shape_weights = self._get_shape(node, 1) shape_bias = self._try_get_shape(node, 2) if shape_bias is not None: assert len(shape_bias) == 1 - tripled_hidden_size = ( - shape_bias[0] if shape_bias is not None else 
shape_weights[1] - ) + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] if shape and len(shape) == 2: qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") if qkv_hidden_sizes_attr is not None: @@ -2560,11 +2356,10 @@ def _infer_PackedAttention(self, node): # noqa: N802 shape[1] = int(tripled_hidden_size / 3) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) def _infer_PackedMultiHeadAttention(self, node): # noqa: N802 + """Infer the output shape for PackedMultiHeadAttention node in the computational graph.""" shape_value = self._try_get_shape(node, 2) if shape_value is not None and len(shape_value) == 2: output_shape = shape_value @@ -2575,51 +2370,34 @@ def _infer_PackedMultiHeadAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_RemovePadding(self, node): # noqa: N802 + """Infers the shape and data type for the output tensor after removing padding.""" shape = self._get_shape(node, 0) if shape and len(shape) == 3: output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, ["token_count", shape[2]] - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]])) vi_token_offset = self.known_vi_[node.output[1]] vi_token_offset.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]] - ) + helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]]) ) vi_cumulated_seq_len = self.known_vi_[node.output[2]] vi_cumulated_seq_len.CopyFrom( - helper.make_tensor_value_info( - node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"] - ) + helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"]) ) vi_max_seq_len = self.known_vi_[node.output[3]] - vi_max_seq_len.CopyFrom( - helper.make_tensor_value_info( - node.output[3], onnx.TensorProto.INT32, [1] - ) - ) + vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1])) def _infer_RestorePadding(self, node): # noqa: N802 + """Infers the output shape and type for the RestorePadding operation.""" shape_input = self._get_shape(node, 0) shape_token_offset = self._get_shape(node, 1) - if ( - shape_input - and len(shape_input) == 2 - and shape_token_offset - and len(shape_token_offset) == 2 - ): + if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2: output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] @@ -2628,16 +2406,14 @@ def _infer_RestorePadding(self, node): # noqa: N802 shape_token_offset[1], shape_input[1], ] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_BiasGelu(self, node): # noqa: N802 + """Propagate shape and type information 
for BiasGelu node during inference.""" self._propagate_shape_and_type(node) def _infer_MultiHeadAttention(self, node): # noqa: N802 + """Infer the output and present key/value shapes and types for the MultiHeadAttention node.""" # Output 0 has shape (batch_size, sequence_length, v_hidden_size) # Q, K and V without packing: # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) @@ -2668,11 +2444,7 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) elif len(query_shape) == 5: if isinstance(query_shape[2], int) and isinstance(query_shape[4], int): @@ -2692,11 +2464,7 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) if len(node.output) > 1: batch_size = query_shape[0] @@ -2715,14 +2483,10 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 past_shape = self._try_get_shape(node, 6) if past_shape is not None: - if isinstance(past_shape[2], int) and isinstance( - total_sequence_length, int - ): + if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int): total_sequence_length = past_shape[2] + total_sequence_length else: - total_sequence_length = ( - f"{past_shape[2]}+{total_sequence_length}" - ) + total_sequence_length = f"{past_shape[2]}+{total_sequence_length}" present_shape = [ batch_size, @@ -2734,19 +2498,12 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 assert output_dtype is not None if len(node.output) > 2 and node.output[1] and node.output[2]: vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, output_dtype, present_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, output_dtype, present_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802 + """Infers the output shape of the DecoderMaskedMultiHeadAttention node based on input shapes and attributes in the computational graph.""" # Output 0 has shape (batch_size, 1, v_hidden_size) # Q, K and V without packing: # Input 0 (query) has shape (batch_size, 1, hidden_size) @@ -2758,40 +2515,38 @@ def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type assert output_dtype is not None vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) if len(node.output) > 2 and node.output[1] and node.output[2]: past_shape = self._try_get_shape(node, 5) if past_shape is not None: vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype,
past_shape)) vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) def _infer_FastGelu(self, node): # noqa: N802 + """Infers the output shapes and types for the FastGelu node using shape propagation.""" self._propagate_shape_and_type(node) def _infer_Gelu(self, node): # noqa: N802 + """Infers the output shapes and types for the Gelu node using shape propagation.""" self._propagate_shape_and_type(node) def _infer_QuickGelu(self, node): # noqa: N802 + """Infers the output shapes and types for the QuickGelu node using shape propagation.""" self._propagate_shape_and_type(node) def _infer_GemmFastGelu(self, node): # noqa: N802 + """Infers the output shapes and types for the GemmFastGelu node using matrix multiplication shape computation.""" self._compute_matmul_shape(node) def _infer_GemmFloat8(self, node): # noqa: N802 + """Infers the output shapes and types for the GemmFloat8 node using matrix multiplication shape computation.""" self._compute_matmul_shape(node) def _infer_LayerNormalization(self, node): # noqa: N802 + """Infers the output shapes and types for the LayerNormalization node, and updates the node's output shape information.""" self._propagate_shape_and_type(node) if len(node.output) > 1: axis = get_attribute(node, "axis") @@ -2803,29 +2558,20 @@ def _infer_LayerNormalization(self, node): # noqa: N802 axis = handle_negative_axis(axis, rank) mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)] mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - if ( - mean_dtype == onnx.TensorProto.FLOAT16 - or mean_dtype == onnx.TensorProto.BFLOAT16 - ): + if mean_dtype == onnx.TensorProto.FLOAT16 or mean_dtype == onnx.TensorProto.BFLOAT16: mean_dtype = onnx.TensorProto.FLOAT vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[1], mean_dtype, mean_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape)) if len(node.output) > 2: vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[2], mean_dtype, mean_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape)) def _infer_LongformerAttention(self, node): # noqa: N802 + """Propagates shape and type information for a LongformerAttention node.""" self._propagate_shape_and_type(node) def _infer_EmbedLayerNormalization(self, node): # noqa: N802 + """Infers shape for EmbedLayerNormalization node, ensuring input_ids and word_embedding tensors have correct dimensions.""" input_ids_shape = self._get_shape(node, 0) word_embedding_shape = self._get_shape(node, 2) assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2 @@ -2833,32 +2579,21 @@ def _infer_EmbedLayerNormalization(self, node): # noqa: N802 word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], word_embedding_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape)) if len(node.output) > 1 and node.output[1]: mask_index_shape = [input_ids_shape[0]] vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT32, mask_index_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], 
onnx.TensorProto.INT32, mask_index_shape)) if len(node.output) > 2: # Optional output of add before layer normalization is done # shape is same as the output vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[2], word_embedding_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape)) def _infer_SkipLayerNormalization(self, node): # noqa: N802 + """Infer the shape and data type for the SkipLayerNormalization node and propagate them accordingly.""" self._propagate_shape_and_type(node) # If the SkipLayerNormalization node contains the optional @@ -2867,14 +2602,17 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 self._propagate_shape_and_type(node, 0, 3) def _infer_GroupNorm(self, node): # noqa: N802 + """Infers the shape and data type for a GroupNorm node.""" self._propagate_shape_and_type(node) def _infer_SkipGroupNorm(self, node): # noqa: N802 + """Infers the shape and type for a SkipGroupNorm node and propagates them accordingly based on the number of outputs.""" self._propagate_shape_and_type(node, 0, 0) if len(node.output) > 1: self._propagate_shape_and_type(node, 0, 1) def _infer_BiasSplitGelu(self, node): # noqa: N802 + """Infers the shape and type for a BiasSplitGelu node based on the input and bias shape and propagates the output accordingly.""" input_shape = self._get_shape(node, 0) bias_shape = self._get_shape(node, 1) if input_shape and bias_shape and isinstance(bias_shape[0], int): @@ -2882,39 +2620,32 @@ def _infer_BiasSplitGelu(self, node): # noqa: N802 output_shape[2] = int(bias_shape[0] / 2) vi = self.known_vi_[node.output[0]] output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) def _infer_BiasAdd(self, node): # noqa: N802 + """Infer the output shape and type for a BiasAdd node by propagating input shape and type information.""" self._propagate_shape_and_type(node) def _infer_RotaryEmbedding(self, node): # noqa: N802 + """Infer the output shape and type for a RotaryEmbedding node by appropriately propagating input shape and type information.""" if len(node.output) == 1: self._propagate_shape_and_type(node) elif len(node.output) == 2: # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` self._propagate_shape_and_type(node, input_index=1, output_index=0) - self._propagate_shape_and_type( - node, input_index=0, output_index=1 - ) # true output + self._propagate_shape_and_type(node, input_index=0, output_index=1) # true output elif len(node.output) == 3: # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` self._propagate_shape_and_type(node, input_index=1, output_index=0) self._propagate_shape_and_type(node, input_index=1, output_index=1) - self._propagate_shape_and_type( - node, input_index=0, output_index=2 - ) # true output + self._propagate_shape_and_type(node, input_index=0, output_index=2) # true output def _infer_PythonOp(self, node): # noqa: N802 + """Infer and propagate the shape and type information for a PythonOp node in the computation graph.""" output_tensor_types = get_attribute(node, "output_tensor_types") - assert ( - output_tensor_types - ), f"PythonOp '{node.name}' has no output_tensor_types attribute." 
+ assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute." output_tensor_ranks = get_attribute(node, "output_tensor_ranks") - assert ( - output_tensor_ranks - ), f"PythonOp '{node.name}' has no output_tensor_ranks attribute." + assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute." from onnxruntime.capi._pybind_state import get_shape_inference_function @@ -2924,9 +2655,7 @@ def _infer_PythonOp(self, node): # noqa: N802 # Set the context output separately. # The first output is torch.autograd.Function''s context. vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])) if shape_inferer is not None: input_shapes = [] @@ -2934,13 +2663,9 @@ def _infer_PythonOp(self, node): # noqa: N802 for input_index in range(len(node.input)): shape = self._get_shape(node, input_index) input_shapes.append(shape) - input_dtype = self.known_vi_[ - node.input[input_index] - ].type.tensor_type.elem_type + input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type input_dtypes.append(input_dtype) - output_shapes, output_dtypes = shape_inferer( - node, input_shapes, input_dtypes - ) + output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes) assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), ( f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, " f"but expected {len(node.output) - 1} outputs." @@ -2949,9 +2674,7 @@ def _infer_PythonOp(self, node): # noqa: N802 output_index = i + 1 vi = self.known_vi_[node.output[output_index]] vi.CopyFrom( - helper.make_tensor_value_info( - node.output[output_index], output_dtypes[i], output_shapes[i] - ) + helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i]) ) else: # General shape inference for PythonOp. 
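
[Not part of the patch] A minimal sketch of the shape/type propagation step that the `_infer_*` contrib-op handlers above all reduce to, assuming a `known_vi` dict mapping tensor names to `onnx.ValueInfoProto` entries. The function and parameter names here are hypothetical; the class itself keeps this state in `self.known_vi_` and wraps the pattern in `_propagate_shape_and_type`:

from onnx import helper

def propagate_shape_and_type(node, known_vi):
    """Copy the first input's inferred shape and element type onto the first output."""
    in_type = known_vi[node.input[0]].type.tensor_type
    # Symbolic dims carry a dim_param string; concrete dims carry a dim_value int.
    shape = [d.dim_param if d.dim_param else d.dim_value for d in in_type.shape.dim]
    out_vi = known_vi[node.output[0]]
    out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], in_type.elem_type, shape))
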
@@ -2962,24 +2685,18 @@ def _infer_PythonOp(self, node): # noqa: N802 vi = self.known_vi_[node.output[i + 1]] sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node) shape = get_shape_from_sympy_shape(sympy_shape) - value_info = helper.make_tensor_value_info( - node.output[i + 1], output_tensor_types[i], shape - ) + value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape) vi.CopyFrom(value_info) def _propagate_shape_and_type(self, node, input_index=0, output_index=0): + """Propagates the shape and type information from input to output nodes in a computational graph.""" shape = self._get_shape(node, input_index) - output_dtype = self.known_vi_[ - node.input[input_index] - ].type.tensor_type.elem_type + output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type vi = self.known_vi_[node.output[output_index]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[output_index], output_dtype, shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape)) def _is_none_dim(self, dim_value): + """Check if dimension value is a string representing an unknown dimension that is not in symbolic_dims_.""" if type(dim_value) != str: # noqa: E721 return False if "unk__" not in dim_value: @@ -2989,12 +2706,14 @@ def _is_none_dim(self, dim_value): return True def _is_shape_contains_none_dim(self, out_shape): + """Check if any dimension in the given shape contains 'None' and return that dimension, else return None.""" for out in out_shape: if self._is_none_dim(out): return out return None def _infer_impl(self, start_sympy_data=None): + """Run the main symbolic shape inference pass over the graph, updating symbolic data and input symbols.""" self.sympy_data_ = start_sympy_data or {} self.out_mp_.graph.ClearField("value_info") self._apply_suggested_merge(graph_input_only=True) @@ -3012,13 +2731,9 @@ def _infer_impl(self, start_sympy_data=None): for i_dim, dim in enumerate(input_shape): if dim is None: # some models use None for symbolic dim in input, replace it with a string - input_dims[i_dim].dim_param = str( - self._new_symbolic_dim(i.name, i_dim) - ) + input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim)) - self.input_symbols_.update( - [d for d in input_shape if type(d) == str] - ) # noqa: E721 + self.input_symbols_.update([d for d in input_shape if type(d) == str]) # noqa: E721 for s in self.input_symbols_: if s in self.suggested_merge_: @@ -3035,13 +2750,12 @@ def _infer_impl(self, start_sympy_data=None): self.tmp_mp_.CopyFrom(self.out_mp_) self.tmp_mp_.graph.ClearField("initializer") - # compute prerequesite for node for topological sort + # compute prerequisite for node for topological sort # node with subgraphs may have dependency on implicit inputs, which will affect topological sort - prereq_for_node = ( - {} - ) # map from node to all its inputs, including implicit ones in subgraph + prereq_for_node = {} # map from node to all its inputs, including implicit ones in subgraph def get_prereq(node): + """Compute the prerequisite inputs for a given node, including implicit inputs from subgraphs for topological sorting.""" names = {i for i in node.input if i} subgraphs = [] if node.op_type == "If": @@ -3057,13 +2771,7 @@ def get_prereq(node): for n in g.node: g_outputs_and_initializers.update(n.output) for n in g.node: - g_prereq.update( - [ - i - for i in get_prereq(n) - if i not in g_outputs_and_initializers - ] - ) + g_prereq.update([i for i in get_prereq(n) if i not in
g_outputs_and_initializers]) names.update(g_prereq) # remove subgraph inputs from g_prereq since those are local-only for i in g.input: @@ -3076,26 +2784,16 @@ def get_prereq(node): # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate sorted_nodes = [] - sorted_known_vi = { - i.name - for i in list(self.out_mp_.graph.input) - + list(self.out_mp_.graph.initializer) - } + sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)} if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): # Loop/Scan will have some graph output in graph inputs, so don't do topological sort sorted_nodes = self.out_mp_.graph.node else: - while not all( - [o.name in sorted_known_vi for o in self.out_mp_.graph.output] - ): + while not all([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): old_sorted_nodes_len = len(sorted_nodes) for node in self.out_mp_.graph.node: if (node.output[0] not in sorted_known_vi) and all( - [ - i in sorted_known_vi - for i in prereq_for_node[node.output[0]] - if i - ] + [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i] ): sorted_known_vi.update(node.output) sorted_nodes.append(node) @@ -3121,11 +2819,7 @@ def get_prereq(node): for attr in node.attribute: # TODO: Is overload_name needed? if attr.name == "operator": - aten_op_name = ( - attr.s.decode("utf-8") - if isinstance(attr.s, bytes) - else attr.s - ) + aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s if aten_op_name in self.aten_op_dispatcher_: known_aten_op = True self.aten_op_dispatcher_[aten_op_name](node) @@ -3135,9 +2829,7 @@ def get_prereq(node): logger.debug(node.op_type + ": " + node.name) for i, name in enumerate(node.input): logger.debug( - " Input {}: {} {}".format( - i, name, "initializer" if name in self.initializers_ else "" - ) + " Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "") ) # onnx automatically merge dims with value, i.e. 
Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] @@ -3156,20 +2848,8 @@ def get_prereq(node): vi = self.known_vi_[node.output[0]] out_rank = len(get_shape_from_type_proto(vi.type)) in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] - for d in range( - out_rank - - ( - 2 - if node.op_type - in ["MatMul", "MatMulInteger", "MatMulInteger16"] - else 0 - ) - ): - in_dims = [ - s[len(s) - out_rank + d] - for s in in_shapes - if len(s) + d >= out_rank - ] + for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): + in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] if len(in_dims) > 1: self._check_merged_dims(in_dims, allow_broadcast=True) @@ -3180,8 +2860,7 @@ def get_prereq(node): # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding # contrib op if ( - node.op_type == "SkipLayerNormalization" - or node.op_type == "SkipSimplifiedLayerNormalization" + node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" ) and i_o in [1, 2]: continue if node.op_type == "RotaryEmbedding" and len(node.output) > 1: @@ -3197,9 +2876,7 @@ def get_prereq(node): if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]: if self.verbose_ > 2: if out_type_kind == "sequence_type": - seq_cls_type = out_type.sequence_type.elem_type.WhichOneof( - "value" - ) + seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") if seq_cls_type == "tensor_type": logger.debug( " {}: sequence of {} {}".format( @@ -3211,38 +2888,27 @@ def get_prereq(node): ) ) else: - logger.debug( - f" {node.output[i_o]}: sequence of {seq_cls_type}" - ) + logger.debug(f" {node.output[i_o]}: sequence of {seq_cls_type}") else: logger.debug(f" {node.output[i_o]}: {out_type_kind}") continue out_shape = get_shape_from_value_info(vi) - out_type_undefined = ( - out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED - ) + out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED if self.verbose_ > 2: logger.debug( " {}: {} {}".format( node.output[i_o], str(out_shape), - onnx.TensorProto.DataType.Name( - vi.type.tensor_type.elem_type - ), + onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type), ) ) if node.output[i_o] in self.sympy_data_: - logger.debug( - " Sympy Data: " + str(self.sympy_data_[node.output[i_o]]) - ) + logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]])) # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain if ( - out_shape is not None - and ( - None in out_shape or self._is_shape_contains_none_dim(out_shape) - ) + out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) ) or out_type_undefined: if self.auto_merge_: if node.op_type in [ @@ -3264,36 +2930,21 @@ def get_prereq(node): "Min", "Max", ]: - shapes = [ - self._get_shape(node, i) for i in range(len(node.input)) - ] + shapes = [self._get_shape(node, i) for i in range(len(node.input))] if node.op_type in [ "MatMul", "MatMulInteger", "MatMulInteger16", ]: - if ( - None in out_shape - or self._is_shape_contains_none_dim(out_shape) - ): + if None in out_shape or self._is_shape_contains_none_dim(out_shape): if None in out_shape: idx = out_shape.index(None) else: - idx = out_shape.index( - self._is_shape_contains_none_dim(out_shape) - ) - dim_idx = [ - len(s) - len(out_shape) + idx for s in shapes - ] + idx = out_shape.index(self._is_shape_contains_none_dim(out_shape)) + dim_idx = [len(s) - 
len(out_shape) + idx for s in shapes] # only support auto merge for MatMul for dim < rank-2 when rank > 2 - assert ( - len(shapes[0]) > 2 - and dim_idx[0] < len(shapes[0]) - 2 - ) - assert ( - len(shapes[1]) > 2 - and dim_idx[1] < len(shapes[1]) - 2 - ) + assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 + assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2 elif node.op_type == "Expand": # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) shapes = [ @@ -3305,15 +2956,11 @@ def get_prereq(node): if shapes: for idx in range(len(out_shape)): - if out_shape[idx] is not None and not self._is_none_dim( - out_shape[idx] - ): + if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]): continue # note that the broadcasting rule aligns from right to left # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge - dim_idx = [ - len(s) - len(out_shape) + idx for s in shapes - ] + dim_idx = [len(s) - len(out_shape) + idx for s in shapes] if len(dim_idx) > 0: self._add_suggested_merge( [ @@ -3329,22 +2976,12 @@ def get_prereq(node): self.run_ = False # create new dynamic dims for ops not handled by symbolic shape inference - if ( - self.run_ is False - and node.op_type not in self.dispatcher_ - and not known_aten_op - ): - is_unknown_op = out_type_undefined and ( - out_shape is None or len(out_shape) == 0 - ) + if self.run_ is False and node.op_type not in self.dispatcher_ and not known_aten_op: + is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0) if is_unknown_op: # unknown op to ONNX, maybe from higher opset or other domain # only guess the output rank from input 0 when using guess_output_rank option - out_rank = ( - self._get_shape_rank(node, 0) - if self.guess_output_rank_ - else -1 - ) + out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1 else: # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape out_rank = len(out_shape) @@ -3353,9 +2990,7 @@ def get_prereq(node): new_shape = self._new_symbolic_shape(out_rank, node, i_o) if out_type_undefined: # guess output data type from input vi if not defined - out_dtype = self.known_vi_[ - node.input[0] - ].type.tensor_type.elem_type + out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type else: # otherwise, use original data type out_dtype = vi.type.tensor_type.elem_type @@ -3386,12 +3021,7 @@ def get_prereq(node): continue # continue the inference after guess, no need to stop as no merge is needed if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: - logger.debug( - "Stopping at incomplete shape inference at " - + node.op_type - + ": " - + node.name - ) + logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name) logger.debug("node inputs:") for i in node.input: if i in self.known_vi_: @@ -3412,21 +3042,19 @@ def get_prereq(node): return True def _update_output_from_vi(self): + """Update output attributes using known value information dictionary.""" for output in self.out_mp_.graph.output: if output.name in self.known_vi_: output.CopyFrom(self.known_vi_[output.name]) @staticmethod - def infer_shapes( - in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0 - ): + def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0): + """Perform symbolic shape inference on an ONNX model using the specified options to handle model shapes 
efficiently.""" onnx_opset = get_opset(in_mp) if (not onnx_opset) or onnx_opset < 7: logger.warning("Only support models of onnx opset 7 and above.") return None - symbolic_shape_inference = SymbolicShapeInference( - int_max, auto_merge, guess_output_rank, verbose - ) + symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose) all_shapes_inferred = False symbolic_shape_inference._preprocess(in_mp) while symbolic_shape_inference.run_: @@ -3443,6 +3071,7 @@ def infer_shapes( def parse_arguments(): + """Parses command-line arguments for ONNX model transformation options.""" parser = argparse.ArgumentParser() parser.add_argument("--input", required=True, help="The input model file") parser.add_argument("--output", help="The output model file") diff --git a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py index ba4f521..f1a2149 100644 --- a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py +++ b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py @@ -43,7 +43,7 @@ def dtype_to_onnx(dtype: Union[np.dtype, "onnx.TensorProto.DataType"]) -> int: def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING): - # Check if node names are unique. If not, log based on severity. + """Check if node names are unique and log any duplicates based on the specified severity level.""" # Note: # Empty string or None attribute values are not considered duplicates. @@ -64,7 +64,7 @@ def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING): def update_import_domains(graph): - # Update the import_domains field to contain the graph's ONNX opset, + """Update the import_domains field of a graph to include its ONNX opset and other used non-ONNX domains.""" # as well as other non-ONNX domains which are used by this graph's nodes. # Returns the updated value of the import_domains field. @@ -83,17 +83,14 @@ def update_import_domains(graph): DEFAULT_CUSTOM_OPSET_VERSION = 1 for used_domain in all_used_domains: if used_domain not in current_domains: - graph.import_domains.append( - onnx.helper.make_opsetid(used_domain, DEFAULT_CUSTOM_OPSET_VERSION) - ) + graph.import_domains.append(onnx.helper.make_opsetid(used_domain, DEFAULT_CUSTOM_OPSET_VERSION)) current_domains.add(used_domain) return graph.import_domains # Converts a fp32 gs.Constant to a bf16 onnx.TensorProto def tensor_to_onnx_bf16(tensor: Constant): - - # Converts the fp32 numpy array to bf16 values and store in a uint16 numpy array + """Converts an fp32 gs.Constant tensor to a bf16 onnx.TensorProto.""" def np_float32_to_bf16_as_uint16(arr): new_arr = np.empty(arr.size, dtype=np.uint16) flatten = arr.flatten() @@ -142,9 +139,7 @@ def export_sparse_tensor_proto(tensor: Constant) -> onnx.SparseTensorProto: return tensor._values.tensor @staticmethod - def export_value_info_proto( - tensor: Tensor, do_type_check: bool - ) -> onnx.ValueInfoProto: + def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueInfoProto: if do_type_check and tensor.dtype is None: G_LOGGER.critical( "Graph input and output tensors must include dtype information. 
Please set the dtype attribute for: {:}".format( @@ -154,9 +149,7 @@ def export_value_info_proto( if tensor.dtype is not None: if isinstance(tensor, Constant) or tensor.type == "tensor_type": - onnx_tensor = onnx.helper.make_tensor_value_info( - tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape - ) + onnx_tensor = onnx.helper.make_tensor_value_info(tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape) elif tensor.type == "sequence_type": onnx_tensor = onnx.helper.make_tensor_sequence_value_info( tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape @@ -186,9 +179,7 @@ def export_attributes(attrs: dict) -> List[onnx.AttributeProto]: # Netron has a bug which makes it crash if a Tensor attribute has no tensor data. # So provide some meaningless tensor data for Netron to read. if val.type == Tensor: - tensor_proto = OnnxExporter.export_tensor_proto( - Constant("", np.array([0], dtype=np.float32)) - ) + tensor_proto = OnnxExporter.export_tensor_proto(Constant("", np.array([0], dtype=np.float32))) onnx_attr.t.CopyFrom(tensor_proto) onnx_attr.ref_attr_name = val.name @@ -232,9 +223,7 @@ def export_function(func: Function) -> onnx.FunctionProto: for tensor in func.tensors().values(): if isinstance(tensor, Constant): # Copying the tensor prevents the new node from appearing in the Constant tensor's inputs. - new_const_nodes.append( - Node("Constant", attrs={"value": tensor}, outputs=[tensor.copy()]) - ) + new_const_nodes.append(Node("Constant", attrs={"value": tensor}, outputs=[tensor.copy()])) # Const nodes have no inputs, so this maintains a topological ordering. func_nodes = new_const_nodes + func_nodes @@ -281,20 +270,13 @@ def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: """ check_duplicate_node_names(graph.nodes, level=G_LOGGER.WARNING) nodes = [OnnxExporter.export_node(node) for node in graph.nodes] - inputs = [ - OnnxExporter.export_value_info_proto(inp, do_type_check) - for inp in graph.inputs - ] - outputs = [ - OnnxExporter.export_value_info_proto(out, do_type_check) - for out in graph.outputs - ] + inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs] + outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs] tensor_map = graph.tensors() initializer = [ OnnxExporter.export_tensor_proto(tensor) for tensor in tensor_map.values() - if isinstance(tensor, Constant) - and not isinstance(tensor._values, SparseValues) + if isinstance(tensor, Constant) and not isinstance(tensor._values, SparseValues) ] sparse_initializer = [ OnnxExporter.export_sparse_tensor_proto(tensor) @@ -309,9 +291,8 @@ def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: # Omit tensors from value_info if we don't know their shape/dtype def has_value_info(tensor): - return isinstance(tensor, Variable) and ( - tensor.dtype is not None or tensor.shape is not None - ) + """Check if a tensor is a Variable with either a defined dtype or shape.""" + return isinstance(tensor, Variable) and (tensor.dtype is not None or tensor.shape is not None) value_info = [ OnnxExporter.export_value_info_proto(tensor, do_type_check) diff --git a/onnxslim/onnx_graphsurgeon/graph_pattern/graph_pattern.py b/onnxslim/onnx_graphsurgeon/graph_pattern/graph_pattern.py index fdb8e09..a009bae 100644 --- a/onnxslim/onnx_graphsurgeon/graph_pattern/graph_pattern.py +++ b/onnxslim/onnx_graphsurgeon/graph_pattern/graph_pattern.py @@ -22,9 +22,7 @@ class PatternMapping(dict): - """ - Represents a graph pattern mapping result. 
- """ + """Represents a graph pattern mapping result.""" def __init__(self, onnx_node=None) -> None: super().__init__() @@ -39,36 +37,34 @@ def __init__(self, onnx_node=None) -> None: self.constants = dict() # constant name -> onnx tensor mapping def set_input_onnx_tensor(self, onnx_tensor, index): + """Sets an ONNX tensor at a specified index of the input list, expanding the list if necessary.""" length = len(self.inputs) for _ in range(index - length + 1): self.inputs.append(None) - if ( - self.inputs[index] is not None - and self.inputs[index].name != onnx_tensor.name - ): + if self.inputs[index] is not None and self.inputs[index].name != onnx_tensor.name: return False # This input tensor has been set up by another onnx tensor self.inputs[index] = onnx_tensor return True def set_output_onnx_tensor(self, onnx_tensor, index): + """Sets the output ONNX tensor at the given index within the outputs list.""" length = len(self.outputs) for _ in range(index - length + 1): self.outputs.append(None) - if ( - self.outputs[index] is not None - and self.outputs[index].name != onnx_tensor.name - ): + if self.outputs[index] is not None and self.outputs[index].name != onnx_tensor.name: return False # This output tensor has been set up by another onnx tensor self.outputs[index] = onnx_tensor return True def set_constant_onnx_tensor(self, onnx_tensor, name): + """Set an ONNX tensor as a constant if it hasn't already been set with a different name.""" if name in self.constants and self.constants[name].name != onnx_tensor.name: return False self.constants[name] = onnx_tensor return True def _get_node(self): + """Return the ONNX node associated with the current instance.""" return self.onnx_node def get(self, name: str): @@ -88,13 +84,7 @@ def get(self, name: str): def __str__(self) -> str: if self.onnx_node is None: - return ( - "{" - + str.join( - ", ", [f"{key}: {str(value)}" for key, value in self.items()] - ) - + "}" - ) + return "{" + str.join(", ", [f"{key}: {str(value)}" for key, value in self.items()]) + "}" return self.onnx_node.name @@ -223,6 +213,7 @@ def add( return tuple(self.node_outputs[name]) def _get_inbound(self, tensor_index): + """Retrieve the tensor id and first inbound node for a given tensor index.""" if len(self.input_tensors) > tensor_index: tensor_id = self.input_tensors[tensor_index] if len(self.tensor_outputs[tensor_id]): @@ -231,6 +222,7 @@ def _get_inbound(self, tensor_index): return None, None def _get_outbound(self, tensor_index): + """Retrieve the outbound node and tensor ID based on the specified tensor index, or return (None, None) if not found.""" if len(self.output_tensors) > tensor_index: tensor_id = self.output_tensors[tensor_index] if len(self.tensor_inputs[tensor_id]): @@ -252,22 +244,18 @@ def _single_node_match(self, onnx_node: Node) -> bool: if not self.check_func(onnx_node): G_LOGGER.info("No match because: check_func returned false.") return False - G_LOGGER.info( - "Single node is matched: {:}, {:}".format(self.op, onnx_node.name) - ) + G_LOGGER.info("Single node is matched: {:}, {:}".format(self.op, onnx_node.name)) return True - def _get_tensor_index_for_node( - self, node: str, tensor_id: int, is_node_input: bool - ): + def _get_tensor_index_for_node(self, node: str, tensor_id: int, is_node_input: bool): + """Returns the index of a tensor for a given node, based on whether it is an input or output tensor.""" if is_node_input: return self.node_inputs[node].index(tensor_id) else: return self.node_outputs[node].index(tensor_id) - def 
get_inbound_or_outbound_onnx_node( - self, mapping: PatternMapping, is_inbound: bool, tensor_index: int - ): + def get_inbound_or_outbound_onnx_node(self, mapping: PatternMapping, is_inbound: bool, tensor_index: int): + """Gets the ONNX node based on whether it's inbound or outbound for a specified tensor index and mapping.""" if self.op is not None: onnx_node = mapping._get_node() return onnx_node @@ -277,9 +265,7 @@ def get_inbound_or_outbound_onnx_node( return self.nodes[inbound_node].get_inbound_or_outbound_onnx_node( mapping[inbound_node], is_inbound=True, - tensor_index=self._get_tensor_index_for_node( - inbound_node, inbound_tensor, is_node_input=True - ), + tensor_index=self._get_tensor_index_for_node(inbound_node, inbound_tensor, is_node_input=True), ) else: @@ -288,9 +274,7 @@ def get_inbound_or_outbound_onnx_node( return self.nodes[outbound_node].get_inbound_or_outbound_onnx_node( mapping[outbound_node], is_inbound=False, - tensor_index=self._get_tensor_index_for_node( - outbound_node, outbound_tensor, is_node_input=False - ), + tensor_index=self._get_tensor_index_for_node(outbound_node, outbound_tensor, is_node_input=False), ) return None @@ -346,14 +330,8 @@ def _match_node( from_inbound: bool, ) -> bool: with G_LOGGER.indent(): - G_LOGGER.info( - "Checking node: {:} against pattern node: {:}.".format( - onnx_node.name, node_name - ) - ) - tensor_index_for_node = self._get_tensor_index_for_node( - node_name, from_tensor, is_node_input=from_inbound - ) + G_LOGGER.info("Checking node: {:} against pattern node: {:}.".format(onnx_node.name, node_name)) + tensor_index_for_node = self._get_tensor_index_for_node(node_name, from_tensor, is_node_input=from_inbound) subgraph_mapping = self.nodes[node_name].match( onnx_node, from_inbound, @@ -369,40 +347,30 @@ def _match_node( input_onnx_tensors = subgraph_mapping.inputs if len(input_onnx_tensors) != len(self.node_inputs[node_name]): return False # Number of node inputs should equal to number of input onnx tensors of the node. - for node_input_tensor, onnx_tensor in zip( - self.node_inputs[node_name], input_onnx_tensors - ): + for node_input_tensor, onnx_tensor in zip(self.node_inputs[node_name], input_onnx_tensors): if onnx_tensor is None: return False # tensor paired up. if node_input_tensor in self.input_tensors: - if not mapping.set_input_onnx_tensor( - onnx_tensor, self.input_tensors.index(node_input_tensor) - ): + if not mapping.set_input_onnx_tensor(onnx_tensor, self.input_tensors.index(node_input_tensor)): return False # this tensor is mapped to another onnx tensor continue if node_input_tensor in self.constant_tensors: if not isinstance(onnx_tensor, Constant): return False # constant tensor not match - if not mapping.set_constant_onnx_tensor( - onnx_tensor, self.constant_tensors[node_input_tensor] - ): + if not mapping.set_constant_onnx_tensor(onnx_tensor, self.constant_tensors[node_input_tensor]): # this constant tensor is mapped to another onnx tensor return False continue if len(self.tensor_inputs[node_input_tensor]) != len(onnx_tensor.inputs): return False - for input_node, input_onnx_node in zip( - self.tensor_inputs[node_input_tensor], onnx_tensor.inputs - ): + for input_node, input_onnx_node in zip(self.tensor_inputs[node_input_tensor], onnx_tensor.inputs): # dfs ends when revisiting a node. We need to check if the edges are matched. 
if input_node in mapping: outbound_tensor_index = self._get_tensor_index_for_node( input_node, node_input_tensor, is_node_input=False ) - outbound_onnx_node_of_input_node = self.nodes[ - input_node - ].get_inbound_or_outbound_onnx_node( + outbound_onnx_node_of_input_node = self.nodes[input_node].get_inbound_or_outbound_onnx_node( mapping[input_node], is_inbound=False, tensor_index=outbound_tensor_index, @@ -428,16 +396,12 @@ def _match_node( output_onnx_tensors = subgraph_mapping.outputs if len(output_onnx_tensors) != len(self.node_outputs[node_name]): return False # Number of node outputs should be equal to number of output onnx tensors of the node. - for node_output_tensor, onnx_tensor in zip( - self.node_outputs[node_name], output_onnx_tensors - ): + for node_output_tensor, onnx_tensor in zip(self.node_outputs[node_name], output_onnx_tensors): if onnx_tensor is None: return False # tensor matched if node_output_tensor in self.output_tensors: - if not mapping.set_output_onnx_tensor( - onnx_tensor, self.output_tensors.index(node_output_tensor) - ): + if not mapping.set_output_onnx_tensor(onnx_tensor, self.output_tensors.index(node_output_tensor)): return False # this tensor is mapped to another onnx tensor continue if onnx_tensor.name in onnx_graph_output_tensors: @@ -446,25 +410,20 @@ def _match_node( # For sub-patterns, each input tensor can only have 1 output node. Otherwise the following test will fail. if len(self.tensor_outputs[node_output_tensor]) != len(onnx_tensor.outputs): return False - for output_node, output_onnx_node in zip( - self.tensor_outputs[node_output_tensor], onnx_tensor.outputs - ): + for output_node, output_onnx_node in zip(self.tensor_outputs[node_output_tensor], onnx_tensor.outputs): # dfs ends when revisiting a node. We need to check if the edges are matched. 
if output_node in mapping: inbound_tensor_index = self._get_tensor_index_for_node( output_node, node_output_tensor, is_node_input=True ) - inbound_onnx_node_of_output_node = self.nodes[ - output_node - ].get_inbound_or_outbound_onnx_node( + inbound_onnx_node_of_output_node = self.nodes[output_node].get_inbound_or_outbound_onnx_node( mapping[output_node], is_inbound=True, tensor_index=inbound_tensor_index, ) if ( inbound_onnx_node_of_output_node is None - or inbound_onnx_node_of_output_node.name - != output_onnx_node.name + or inbound_onnx_node_of_output_node.name != output_onnx_node.name ): return False continue diff --git a/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py b/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py index f2bcc7f..453c82f 100644 --- a/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py +++ b/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py @@ -53,13 +53,9 @@ } -def get_onnx_tensor_shape( - onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto] -) -> List[int]: +def get_onnx_tensor_shape(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> List[int]: shape = None - if isinstance(onnx_tensor, onnx.TensorProto) or isinstance( - onnx_tensor, onnx.SparseTensorProto - ): + if isinstance(onnx_tensor, onnx.TensorProto) or isinstance(onnx_tensor, onnx.SparseTensorProto): shape = onnx_tensor.dims else: if onnx_tensor.type.tensor_type.HasField("shape"): @@ -75,10 +71,12 @@ def get_onnx_tensor_shape( def get_dtype_name(onnx_type): + """Get the ONNX data type name from its integer representation.""" return {val: key for key, val in onnx.TensorProto.DataType.items()}[onnx_type] def get_itemsize(dtype): + """Return the byte size of an element for a given ONNX data type.""" np_dtype = get_numpy_type(dtype) if np_dtype is not None: return np.dtype(np_dtype).itemsize @@ -97,6 +95,7 @@ def get_itemsize(dtype): def get_numpy_type(onnx_type): + """Convert an ONNX tensor type to a corresponding NumPy type, if supported.""" if not isinstance(onnx_type, int): # Already a NumPy type return onnx_type @@ -111,16 +110,13 @@ def get_numpy_type(onnx_type): # TENSOR_TYPE_TO_NP_TYPE maps types unsupported by NumPy to random other types. # This obviously breaks things, so we need to treat this as a special case. 
- if ( - onnx_type not in numpy_unsupported_types - and onnx_type in onnx.helper.get_all_tensor_dtypes() - ): + if onnx_type not in numpy_unsupported_types and onnx_type in onnx.helper.get_all_tensor_dtypes(): return onnx.helper.tensor_dtype_to_np_dtype(onnx_type) return None def get_onnx_tensor_dtype( - onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto] + onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto], ) -> Union[np.dtype, "onnx.TensorProto.DataType"]: if isinstance(onnx_tensor, onnx.TensorProto): onnx_dtype = onnx_tensor.data_type @@ -152,9 +148,7 @@ def get_onnx_tensor_dtype( return onnx_dtype -def get_onnx_tensor_type( - onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto] -) -> str: +def get_onnx_tensor_type(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> str: if isinstance(onnx_tensor, onnx.TensorProto): onnx_type = "tensor_type" else: @@ -176,9 +170,7 @@ def get_onnx_tensor_type( return onnx_type -def get_onnx_tensor_type( - onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto] -) -> str: +def get_onnx_tensor_type(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> str: if isinstance(onnx_tensor, onnx.TensorProto): onnx_type = "tensor_type" else: @@ -203,33 +195,25 @@ def get_onnx_tensor_type( class OnnxImporter(BaseImporter): @staticmethod def get_opset(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]): - class_name = ( - "Function" if isinstance(model_or_func, onnx.FunctionProto) else "Model" - ) + """Return the ONNX opset version for the given ONNX model or function, or None if the information is unavailable.""" + class_name = "Function" if isinstance(model_or_func, onnx.FunctionProto) else "Model" try: for importer in OnnxImporter.get_import_domains(model_or_func): if importer.domain == "" or importer.domain == "ai.onnx": return importer.version - G_LOGGER.warning( - f"{class_name} does not contain ONNX domain opset information! Using default opset." - ) + G_LOGGER.warning(f"{class_name} does not contain ONNX domain opset information! Using default opset.") return None except: - G_LOGGER.warning( - f"{class_name} does not contain opset information! Using default opset." - ) + G_LOGGER.warning(f"{class_name} does not contain opset information! 
Using default opset.") return None @staticmethod def get_import_domains(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]): + """Retrieve the opset import information from an ONNX model or function.""" return model_or_func.opset_import @staticmethod - def import_tensor( - onnx_tensor: Union[ - onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto - ] - ) -> Tensor: + def import_tensor(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto]) -> Tensor: if isinstance(onnx_tensor, onnx.SparseTensorProto): return Constant( name=onnx_tensor.values.name, @@ -237,11 +221,7 @@ def import_tensor( data_location=onnx_tensor.values.data_location, ) elif isinstance(onnx_tensor, onnx.TensorProto): - data_location = ( - int(onnx_tensor.data_location) - if onnx_tensor.HasField("data_location") - else None - ) + data_location = int(onnx_tensor.data_location) if onnx_tensor.HasField("data_location") else None return Constant( name=onnx_tensor.name, values=LazyValues(onnx_tensor), @@ -268,6 +248,7 @@ def import_attributes( for attr in onnx_attributes: def process_attr(attr_str: str): + """Process an ONNX attribute based on its type, handling strings, tensors, graphs, and numeric sequences.""" if attr.ref_attr_name: attr_type = misc.convert_from_onnx_attr_type(attr.type) return Node.AttributeRef(attr.ref_attr_name, attr_type) @@ -295,9 +276,7 @@ def process_attr(attr_str: str): attr_dict[attr.name] = process_attr(attr_str) else: G_LOGGER.warning( - "Attribute of type {:} is currently unsupported. Skipping attribute.".format( - attr_str - ) + "Attribute of type {:} is currently unsupported. Skipping attribute.".format(attr_str) ) else: G_LOGGER.warning( @@ -317,7 +296,7 @@ def import_node( ) -> Node: # Optional inputs/outputs are represented by empty tensors. All other tensors should already have been populated during shape inference. def get_tensor(name: str, check_outer_graph=True): - # Prioritize the subgraph even if check_outer_graph is set + """Retrieve a tensor by its name, prioritizing the subgraph tensor map and optionally checking the outer graph.""" if name in subgraph_tensor_map: return subgraph_tensor_map[name] @@ -371,9 +350,7 @@ def import_function( model_import_domains: onnx.OperatorSetIdProto = None, ) -> Function: opset = OnnxImporter.get_opset(onnx_function) or model_opset - import_domains = ( - OnnxImporter.get_import_domains(onnx_function) or model_import_domains - ) + import_domains = OnnxImporter.get_import_domains(onnx_function) or model_import_domains subgraph_tensor_map = OrderedDict() # Tensors in this function def make_tensor(name: str) -> Tensor: @@ -384,9 +361,7 @@ def make_tensor(name: str) -> Tensor: function_inputs = [make_tensor(inp) for inp in onnx_function.input] function_outputs = [make_tensor(out) for out in onnx_function.output] nodes = [ - OnnxImporter.import_node( - onnx_node, dict(), subgraph_tensor_map, opset, import_domains - ) + OnnxImporter.import_node(onnx_node, dict(), subgraph_tensor_map, opset, import_domains) for onnx_node in onnx_function.node ] @@ -438,18 +413,14 @@ def import_graph( functions (List[Function]): The list of custom functions which are available to use in the model. 
""" functions = misc.default_value(functions, []) - tensor_map = copy.copy( - misc.default_value(tensor_map, OrderedDict()) - ) # Outer graph tensors, read-only + tensor_map = copy.copy(misc.default_value(tensor_map, OrderedDict())) # Outer graph tensors, read-only subgraph_tensor_map = OrderedDict() # Tensors in this subgraph # Retrieves a Tensor from subgraph_tensor_map or the outer graph (tensor_map) if present, otherwise imports the tensor # If overwrite=True, this function will overwrite previously imported tensors # if the new tensor has more information available. def get_tensor( - onnx_tensor: Union[ - onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto - ], + onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto], overwrite=False, check_outer_graph=True, ) -> Tensor: @@ -462,12 +433,8 @@ def get_tensor( if overwrite: tensor = OnnxImporter.import_tensor(onnx_tensor) if isinstance(subgraph_tensor_map[name], Variable): - subgraph_tensor_map[name].dtype = ( - subgraph_tensor_map[name].dtype or tensor.dtype - ) - subgraph_tensor_map[name].shape = ( - subgraph_tensor_map[name].shape or tensor.shape - ) + subgraph_tensor_map[name].dtype = subgraph_tensor_map[name].dtype or tensor.dtype + subgraph_tensor_map[name].shape = subgraph_tensor_map[name].shape or tensor.shape return subgraph_tensor_map[name] if check_outer_graph and name in tensor_map: @@ -512,9 +479,7 @@ def get_tensor( G_LOGGER.verbose("Importing nodes") nodes = [] # List[Node] for onnx_node in onnx_graph.node: - node = OnnxImporter.import_node( - onnx_node, tensor_map, subgraph_tensor_map, opset, import_domains - ) + node = OnnxImporter.import_node(onnx_node, tensor_map, subgraph_tensor_map, opset, import_domains) nodes.append(node) return Graph( diff --git a/onnxslim/onnx_graphsurgeon/ir/function.py b/onnxslim/onnx_graphsurgeon/ir/function.py index aa2bd38..c50c827 100644 --- a/onnxslim/onnx_graphsurgeon/ir/function.py +++ b/onnxslim/onnx_graphsurgeon/ir/function.py @@ -27,15 +27,14 @@ class Function(Graph): """ - Represents a local function, which is a default implementation of a Custom Op. - This default implementation is represented as a Graph of other Ops. + Represents a local function, which is a default implementation of a Custom Op. This default implementation is + represented as a Graph of other Ops. Functions are used in a model by creating a Node with the same name and domain as the function. This can be done - using the __call__() method of a Function, which creates this new node and appends it to a Graph. - A Function is not a subgraph of a Graph, and its Nodes, Tensors, and subgraphs are entirely separate - from the main Graph. + using the __call__() method of a Function, which creates this new node and appends it to a Graph. A Function is not + a subgraph of a Graph, and its Nodes, Tensors, and subgraphs are entirely separate from the main Graph. - Functions can be composed of other functions, but cyclical or recursive defintions are not allowed in ONNX. + Functions can be composed of other functions, but cyclical or recursive definitions are not allowed in ONNX. """ DEFAULT_DOMAIN = "onnx_graphsurgeon" @@ -90,9 +89,7 @@ def __init__( @property def unique_id(self): - """ - Returns a tuple which uniquely identifies this function. 
- """ + """Returns a tuple which uniquely identifies this function.""" return (self.domain, self.name) def cleanup( @@ -102,9 +99,8 @@ def cleanup( remove_unused_graph_inputs=False, recurse_functions=False, ): - """ - See Graph.cleanup() - The only difference is that 'recurse_functions' defaults to False, so that only this Function is cleaned up. + """See Graph.cleanup() The only difference is that 'recurse_functions' defaults to False, so that only this + Function is cleaned up. """ if recurse_functions: G_LOGGER.warning( @@ -118,9 +114,8 @@ def cleanup( ) def fold_constants(self, recurse_functions=False, **kwargs): - """ - See Graph.fold_constants() - The only difference is that 'recurse_functions' defaults to False, so that only this Function's constants are folded. + """See Graph.fold_constants() The only difference is that 'recurse_functions' defaults to False, so that only + this Function's constants are folded. """ if recurse_functions: G_LOGGER.warning( @@ -134,10 +129,8 @@ def toposort( recurse_functions=False, mode="nodes", ): - """ - See Graph.toposort() - The only difference is that 'recurse_functions' defaults to False and mode defaults to "nodes", - so that by default only this function's nodes will be sorted. + """See Graph.toposort() The only difference is that 'recurse_functions' defaults to False and mode defaults to + "nodes", so that by default only this function's nodes will be sorted. """ if recurse_functions: G_LOGGER.warning( @@ -149,12 +142,10 @@ def toposort( mode=mode, ) - def __call__( - self, graph, inputs=None, outputs=None, *args, **kwargs - ) -> List[Tensor]: + def __call__(self, graph, inputs=None, outputs=None, *args, **kwargs) -> List[Tensor]: """ - Creates a Node which is an instance of this function. - The created node can be used in a Graph or another Function. + Creates a Node which is an instance of this function. The created node can be used in a Graph or another + Function. The provided inputs are processed the same way as in Graph.layer(). If outputs are not provided, they are created based on the Function's outputs. @@ -171,12 +162,8 @@ def __call__( List[Tensor]: The output tensors of the node. """ if inputs is not None and len(inputs) != len(self.inputs): - msg_template = ( - "Function {} expects {} inputs, but was called with {} inputs." - ) - G_LOGGER.warning( - msg_template.format(self.name, len(self.inputs), len(inputs)) - ) + msg_template = "Function {} expects {} inputs, but was called with {} inputs." + G_LOGGER.warning(msg_template.format(self.name, len(self.inputs), len(inputs))) new_output_indices = [] if outputs is None: @@ -184,16 +171,10 @@ def __call__( outputs = [out.name for out in self.outputs] new_output_indices = list(range(len(outputs))) elif len(outputs) != len(self.outputs): - msg_template = ( - "Function {} expects {} outputs, but was called with {} outputs." - ) - G_LOGGER.warning( - msg_template.format(self.name, len(self.outputs), len(outputs)) - ) + msg_template = "Function {} expects {} outputs, but was called with {} outputs." + G_LOGGER.warning(msg_template.format(self.name, len(self.outputs), len(outputs))) else: - new_output_indices = [ - i for i in range(len(outputs)) if not isinstance(outputs[i], Tensor) - ] + new_output_indices = [i for i in range(len(outputs)) if not isinstance(outputs[i], Tensor)] attrs = kwargs.get("attrs", None) if attrs is not None: @@ -213,7 +194,7 @@ def __call__( outputs=outputs, ) - # For newly created output tensors, set their shape and dtype to match the Function defintion. 
+ # For newly created output tensors, set their shape and dtype to match the Function definition. for i in new_output_indices: outputs[i].dtype = self.outputs[i].dtype outputs[i].shape = self.outputs[i].shape @@ -235,6 +216,7 @@ def copy(self): local_tensor_copies = {n: t.copy() for n, t in self.tensors().items()} def get_tensor(name): + """Retrieve a deep-copied tensor by its name from the local tensor copies.""" if not name: return Variable.empty() return local_tensor_copies[name] @@ -267,10 +249,9 @@ def get_tensor(name): ) def __eq__(self, other: "Function"): + """Checks equality of self with another Function object based on their attributes.""" def sequences_equal(seq1, seq2): - return len(seq1) == len(seq2) and all( - [elem1 == elem2 for elem1, elem2 in zip(seq1, seq2)] - ) + return len(seq1) == len(seq2) and all([elem1 == elem2 for elem1, elem2 in zip(seq1, seq2)]) return ( self.unique_id == other.unique_id @@ -282,6 +263,7 @@ def sequences_equal(seq1, seq2): ) def __str__(self): + """Returns a string representation of the function including its name, domain, opset, inputs, nodes, and outputs.""" nodes_str = "\n".join([str(node) for node in self.nodes]) out = f"Function {self.name}, Domain {self.domain}, Opset {self.opset}" out += f"\nInputs: {self.inputs}" diff --git a/onnxslim/onnx_graphsurgeon/ir/graph.py b/onnxslim/onnx_graphsurgeon/ir/graph.py index c31df88..b3e7a1f 100644 --- a/onnxslim/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/onnx_graphsurgeon/ir/graph.py @@ -17,7 +17,7 @@ import copy import numbers -from collections import defaultdict, OrderedDict +from collections import OrderedDict, defaultdict from typing import List, Sequence import numpy as np @@ -30,23 +30,23 @@ class NodeIDAdder(object): def __init__(self, graph): + """Initializes NodeIDAdder with a specified graph.""" self.graph = graph def __enter__(self): - # To get unique ids for each node, add an `id` attribute. This will be removed before the function returns. + """Assigns unique `id` attributes to each node in the graph upon entering the context.""" # Using the index in the node list allows the same object to count as different nodes. for index, node in enumerate(self.graph.nodes): node.id = index def __exit__(self, exc_type, exc_value, traceback): + """Removes the `id` attributes from each node in the graph upon exiting the context.""" for node in self.graph.nodes: del node.id class Graph(object): - """ - Represents a graph containing nodes and tensors. - """ + """Represents a graph containing nodes and tensors.""" DEFAULT_OPSET = 11 OPSET_FUNC_MAP = defaultdict(dict) # Ops registered for specific opsets. @@ -55,8 +55,8 @@ class Graph(object): @staticmethod def register(opsets=None): """ - Registers a function with the Graph class for the specified group of opsets. - After registering the function, it can be accessed like a normal member function. + Registers a function with the Graph class for the specified group of opsets. After registering the function, it + can be accessed like a normal member function. For example: :: @@ -77,6 +77,7 @@ def add(self, a, b): """ def register_func(func): + """Registers a function for different opsets, overwriting any previously registered function with the same name.""" if hasattr(Graph, func.__name__): G_LOGGER.warning( "Registered function: {:} is hidden by a Graph attribute or function with the same name. 
" @@ -142,6 +143,7 @@ def __init__( G_LOGGER.ultra_verbose(lambda: "Created Graph: {:}".format(self)) def __getattr__(self, name): + """Dynamically retrieves attributes and handles multiple registered functions gracefully.""" try: return super().__getattribute__(name) except AttributeError as err: @@ -150,14 +152,9 @@ def __getattr__(self, name): method_descs = [] # Opset specific ops always take priority over global ops. - if ( - self.opset in Graph.OPSET_FUNC_MAP - and name in Graph.OPSET_FUNC_MAP[self.opset] - ): + if self.opset in Graph.OPSET_FUNC_MAP and name in Graph.OPSET_FUNC_MAP[self.opset]: methods.append(Graph.OPSET_FUNC_MAP[self.opset][name]) - method_descs.append( - f'GraphSurgeon-registered function "{name}" with opset {self.opset}' - ) + method_descs.append(f'GraphSurgeon-registered function "{name}" with opset {self.opset}') # Registered ops take priority over Local Functions. if name in Graph.GLOBAL_FUNC_MAP: @@ -167,27 +164,19 @@ def __getattr__(self, name): for func in self.functions: if func.name == name: methods.append(func.__call__) - method_descs.append( - f'Local Function "{func.name}" with domain "{func.domain}"' - ) + method_descs.append(f'Local Function "{func.name}" with domain "{func.domain}"') if methods: if len(methods) > 1: msg_template = "Method name {} is overloaded with the following candidates: {}. " msg_template += "Choosing candidate {}" G_LOGGER.warning( - message=msg_template.format( - name, method_descs, method_descs[0] - ), + message=msg_template.format(name, method_descs, method_descs[0]), mode=LogMode.ONCE, ) return lambda *args, **kwargs: methods[0](self, *args, **kwargs) - found_in_other_opsets = { - opset - for opset, opset_map in Graph.OPSET_FUNC_MAP.items() - if name in opset_map - } + found_in_other_opsets = {opset for opset, opset_map in Graph.OPSET_FUNC_MAP.items() if name in opset_map} G_LOGGER.error( f"Function: '{name}' was not registered for opset {self.opset}. " @@ -200,7 +189,7 @@ def __getattr__(self, name): raise err def __setattr__(self, name, value): - # We don't want graph inputs/outputs to be SynchronizedLists + """Sets an attribute to the given value, converting 'inputs' and 'outputs' to lists.""" if name in ["inputs", "outputs"]: value = list(value) return super().__setattr__(name, value) @@ -211,7 +200,7 @@ def functions(self) -> "List[Function]": @functions.setter def functions(self, new_fns: "Sequence[Function]"): - # The 'self._functions' list object is shared between + """Get or set the list of functions, ensuring changes propagate to all associated subgraphs and functions.""" # this graph, its subgraphs, and its functions. # If the user sets a new value for self.functions, # all subgraphs and functions should also see this new value. @@ -219,6 +208,7 @@ def functions(self, new_fns: "Sequence[Function]"): self._functions += list(new_fns) def __eq__(self, other: "Graph"): + """Check for equality of two Graph objects by comparing their nodes, inputs, and outputs.""" nodes_match = misc.sequences_equal(self.nodes, other.nodes) if not nodes_match: return False @@ -229,9 +219,7 @@ def __eq__(self, other: "Graph"): if not outputs_match: return False - opset_matches = ( - self.opset == other.opset and self.import_domains == other.import_domains - ) + opset_matches = self.opset == other.opset and self.import_domains == other.import_domains if not opset_matches: return False @@ -254,65 +242,61 @@ def node_ids(self): # Gets the node ID for a node. All internal code should use this instead of accessing `node.id` directly. 
def _get_node_id(self, node): + """Gets the node ID for a node, ensuring all internal code uses this instead of accessing `node.id` directly.""" try: return node.id except AttributeError: G_LOGGER.critical( "Encountered a node not in the graph:\n{:}.\n\n" - "To fix this, please append the node to this graph's `nodes` attribute.".format( - node - ) + "To fix this, please append the node to this graph's `nodes` attribute.".format(node) ) # A tensor is local if it is produced in this graph, or is explicitly a graph input. def _local_tensors(self): - local_tensors = { - t.name: t for node in self.nodes for t in node.outputs if not t.is_empty() - } + """Return a dictionary of tensors that are local to the graph, including nodes' outputs, graph inputs, and constants.""" + local_tensors = {t.name: t for node in self.nodes for t in node.outputs if not t.is_empty()} local_tensors.update({t.name: t for t in self.inputs}) - local_tensors.update( - {t.name: t for t in self.tensors().values() if isinstance(t, Constant)} - ) + local_tensors.update({t.name: t for t in self.tensors().values() if isinstance(t, Constant)}) return local_tensors # Returns tensors used by this graph which are not present in the graph. # These may come from an outer graph for example. def _foreign_tensors(self): + """Returns tensors used by this graph which are not present in the graph, potentially from an outer graph.""" local_tensors = self._local_tensors() foreign_tensors = {} def is_foreign_tensor(tensor): + """Check if a tensor is not present in the local tensors of the current graph.""" return tensor.name not in local_tensors for node in self.nodes: - foreign_tensors.update( - {t.name: t for t in node.inputs if is_foreign_tensor(t)} - ) + foreign_tensors.update({t.name: t for t in node.inputs if is_foreign_tensor(t)}) for subgraph in node.subgraphs(): subgraph_foreign_tensors = subgraph._foreign_tensors() # Some of the foreign tensors from a subgraph may come from this graph. subgraph_foreign_tensors = { - t.name: t - for t in subgraph_foreign_tensors.values() - if is_foreign_tensor(t) + t.name: t for t in subgraph_foreign_tensors.values() if is_foreign_tensor(t) } foreign_tensors.update(subgraph_foreign_tensors) return foreign_tensors def _get_used_node_ids(self): + """Return the IDs of the nodes that are actually used in this graph, together with the tensors they use.""" local_tensors = self._local_tensors() # We only want to consider tensors that are local to this graph, because we can't # remove external tensors (e.g. from outer graphs) anyway. class IgnoreDupAndForeign(object): def __init__(self, initial_tensors=None): + """Initialize IgnoreDupAndForeign with an optional list of initial_tensors.""" tensors = misc.default_value(initial_tensors, []) self.seen_tensors = set([tensor.name for tensor in tensors]) def __call__(self, tensor): - # Returns True if a tensor should included, + """Determine whether a tensor should be included based on its emptiness and presence in seen and local tensors.""" # False if it should be filtered out. if tensor.is_empty(): return True @@ -345,11 +329,12 @@ def __call__(self, tensor): return used_node_ids, used_tensors def _merge_subgraph_functions(self): - # When a user adds a Graph as a node attr, that graph will have a different + """Merge function lists of subgraphs into the parent graph's function list.""" # function list than the parent graph. This function merges those lists.
func_ids = {func.unique_id for func in self.functions} def absorb_function_list(func_list): + """Absorb and merge unique functions from a provided list into the current graph's function list.""" for func in func_list: if func.unique_id not in func_ids: self.functions.append(func) @@ -366,8 +351,8 @@ def absorb_function_list(func_list): def subgraphs(self, recursive=False): """ - Convenience function to iterate over all subgraphs which are contained in this graph. - Subgraphs are found in the attributes of ONNX control flow nodes such as 'If' and 'Loop'. + Convenience function to iterate over all subgraphs which are contained in this graph. Subgraphs are found in the + attributes of ONNX control flow nodes such as 'If' and 'Loop'. Args: recursive (bool): Whether to recursively search this graph's subgraphs for more subgraphs. Defaults to False. @@ -387,8 +372,8 @@ def cleanup( recurse_functions=True, ): """ - Removes unused nodes and tensors from the graph. - A node or tensor is considered unused if it does not contribute to any of the graph outputs. + Removes unused nodes and tensors from the graph. A node or tensor is considered unused if it does not contribute + to any of the graph outputs. Additionally, any producer nodes of graph input tensors, as well as consumer nodes of graph output tensors that are not in the graph, are removed from the graph. @@ -412,6 +397,7 @@ def cleanup( """ def cleanup_subgraphs(): + """Clean up subgraphs by removing unused node outputs and graph inputs, optionally recursing into subgraphs and local functions.""" for subgraph in self.subgraphs(): subgraph.cleanup( remove_unused_node_outputs=remove_unused_node_outputs, @@ -465,10 +451,9 @@ def cleanup_subgraphs(): for node in nodes: def is_hanging_tensor(tensor): + """Checks if a tensor is hanging by verifying it has no outputs and its name is not in graph_output_names.""" return ( - not tensor.is_empty() - and len(tensor.outputs) == 0 - and tensor.name not in graph_output_names + not tensor.is_empty() and len(tensor.outputs) == 0 and tensor.name not in graph_output_names ) to_remove = [out for out in node.outputs if is_hanging_tensor(out)] @@ -523,9 +508,7 @@ def toposort( if sort_nodes and recurse_subgraphs: for subgraph in self.subgraphs(): - subgraph.toposort( - recurse_subgraphs=True, recurse_functions=False, mode="nodes" - ) + subgraph.toposort(recurse_subgraphs=True, recurse_functions=False, mode="nodes") G_LOGGER.debug("Topologically sorting {:}".format(self.name)) @@ -533,10 +516,12 @@ def toposort( # 0 corresponds to an input node, N corresponds to a node with N layers of inputs. 
class HierarchyDescriptor(object): def __init__(self, node_or_func, level=None): + """Initializes a HierarchyDescriptor with a node or function and an optional level in the graph hierarchy.""" self.node_or_func = node_or_func self.level = level def __lt__(self, other): + """Defines less-than comparison behavior based on hierarchy levels.""" return self.level < other.level hierarchy_levels = {} # Dict[int, HierarchyDescriptor] @@ -545,11 +530,13 @@ def __lt__(self, other): func_id_to_func = dict() def get_id(node_or_func): + """Returns the unique ID for a Node object or a function.""" if isinstance(node_or_func, Node): return self._get_node_id(node_or_func) return node_or_func.unique_id def get_hierarchy_level(node_or_func, visited=None): + """Returns the hierarchy level of a node or function, optionally using a 'visited' set to track processed elements.""" from onnxslim.onnx_graphsurgeon.ir.function import Function visited = misc.default_value(visited, set()) @@ -558,17 +545,16 @@ def get_hierarchy_level(node_or_func, visited=None): if isinstance(node_or_func, Function): G_LOGGER.critical("Cycle detected in function definitions!") - G_LOGGER.critical( - "Cycle detected in graph! Are there tensors with duplicate names in the graph?" - ) + G_LOGGER.critical("Cycle detected in graph! Are there tensors with duplicate names in the graph?") visited.add(get_id(node_or_func)) def get_inputs(node_or_func): - # Find all nodes used by this node. + """Find all nodes used by a given node or function.""" def get_used_nodes(node): inputs = {} def add_local_producers(tensor): + """Add local tensors and their producer nodes to the inputs dictionary.""" nonlocal inputs if tensor.name in local_tensors: for inp_node in tensor.inputs: @@ -586,6 +572,7 @@ def add_local_producers(tensor): # Find all functions used in this list of nodes. def get_used_funcs(nodes): + """Return a dictionary of functions used in the provided list of nodes.""" inputs = {} for subgraph in self.subgraphs(): inputs.update(get_used_funcs(subgraph.nodes)) @@ -606,25 +593,17 @@ def get_used_funcs(nodes): # The level of a node is the level of its highest input + 1. 
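(Editorial aside: the comment above is the whole algorithm - a node's level is one more than the maximum level of its producers, and sorting by level yields a topological order. A self-contained sketch of the same idea on a plain dependency dict, not the Graph API:)

```python
def toposort_by_level(deps):
    """deps maps node -> list of nodes it depends on; raises on cycles."""
    levels = {}

    def level(n, path=frozenset()):
        if n in path:
            raise ValueError("cycle detected")
        if n not in levels:
            # Input nodes (no dependencies) get level 0 via the -1 default.
            levels[n] = 1 + max((level(d, path | {n}) for d in deps.get(n, ())), default=-1)
        return levels[n]

    for n in deps:
        level(n)
    return sorted(deps, key=levels.get)


print(toposort_by_level({"c": ["a", "b"], "b": ["a"], "a": []}))  # ['a', 'b', 'c']
```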
max_input_level = max( - [ - get_hierarchy_level(inp, visited=visited) - for inp in get_inputs(node_or_func) - ] - + [-1] + [get_hierarchy_level(inp, visited=visited) for inp in get_inputs(node_or_func)] + [-1] ) visited.remove(get_id(node_or_func)) - hierarchy_levels[get_id(node_or_func)] = HierarchyDescriptor( - node_or_func, level=max_input_level + 1 - ) + hierarchy_levels[get_id(node_or_func)] = HierarchyDescriptor(node_or_func, level=max_input_level + 1) return max_input_level + 1 if sort_nodes: with self.node_ids(): for node in self.nodes: - hierarchy_levels[get_id(node)] = HierarchyDescriptor( - node, level=get_hierarchy_level(node) - ) + hierarchy_levels[get_id(node)] = HierarchyDescriptor(node, level=get_hierarchy_level(node)) self.nodes = [hd.node_or_func for hd in sorted(hierarchy_levels.values())] if sort_functions: @@ -632,18 +611,15 @@ def get_used_funcs(nodes): func_id_to_func.update({func.unique_id: func for func in self.functions}) hierarchy_levels.clear() for func in self.functions: - hierarchy_levels[func.unique_id] = HierarchyDescriptor( - func, level=get_hierarchy_level(func) - ) - self.functions = [ - hd.node_or_func for hd in sorted(hierarchy_levels.values()) - ] + hierarchy_levels[func.unique_id] = HierarchyDescriptor(func, level=get_hierarchy_level(func)) + self.functions = [hd.node_or_func for hd in sorted(hierarchy_levels.values())] return self def tensors(self, check_duplicates=False): """ - Creates a tensor map of all the tensors used by this graph by walking over all nodes. Empty tensors are omitted from this map. + Creates a tensor map of all the tensors used by this graph by walking over all nodes. Empty tensors are omitted + from this map. Tensors are guaranteed to be in order of the nodes in the graph. Hence, if the graph is topologically sorted, the tensor map will be too. @@ -659,19 +635,20 @@ def tensors(self, check_duplicates=False): tensor_map = OrderedDict() def add_to_tensor_map(tensor): + """Add a tensor to the tensor map, ensuring no duplicate names exist by checking tensor IDs.""" if not tensor.is_empty(): - if tensor.name in tensor_map and not ( - tensor_map[tensor.name] is tensor - ): + if tensor.name in tensor_map and not (tensor_map[tensor.name] is tensor): msg = "Found distinct tensors that share the same name:\n[id: {:}] {:}\n[id: {:}] {:}\n".format( id(tensor_map[tensor.name]), tensor_map[tensor.name], id(tensor), tensor, ) - msg += "Note: Producer node(s) of first tensor:\n{:}\nProducer node(s) of second tensor:\n{:}".format( - tensor_map[tensor.name].inputs, - tensor.inputs, + msg += ( + "Note: Producer node(s) of first tensor:\n{:}\nProducer node(s) of second tensor:\n{:}".format( + tensor_map[tensor.name].inputs, + tensor.inputs, + ) ) if check_duplicates: @@ -705,8 +682,8 @@ def fold_constants( recurse_functions=True, ): """ - Folds constants in-place in the graph. The graph's nodes and functions must be topologically - sorted prior to calling this function (see `toposort()`). + Folds constants in-place in the graph. The graph's nodes and functions must be topologically sorted prior to + calling this function (see `toposort()`). This function will not remove constants after folding them. In order to get rid of these hanging nodes, you can run the `cleanup()` function. @@ -730,8 +707,8 @@ def fold_constants( - None: Do not partition the graph. If inference fails, no constants are folded. - "basic": Partition the graph. If inference fails in one partition, other partitions will remain unaffected. 
- - "recursive": Parition the graph recursively. If inference fails in a partition, the partition - will be further paritioned. + - "recursive": Partition the graph recursively. If inference fails in a partition, the partition + will be further partitioned. Defaults to None. error_ok (bool): @@ -767,12 +744,11 @@ def fold_constants( export_onnx, ) - custom_should_exclude_node = misc.default_value( - should_exclude_node, lambda node: False - ) + custom_should_exclude_node = misc.default_value(should_exclude_node, lambda node: False) # Don't fold nodes with attribute values which are variable. def should_exclude_node(node): + """Determine if an ONNX graph node should be excluded based on its attributes.""" for attr_val in node.attrs.values(): if isinstance(attr_val, Node.AttributeRef): return True @@ -780,11 +756,7 @@ def should_exclude_node(node): PARTITIONING_MODES = [None, "basic", "recursive"] if partitioning not in PARTITIONING_MODES: - G_LOGGER.critical( - "Argument for parameter 'partitioning' must be one of: {:}".format( - PARTITIONING_MODES - ) - ) + G_LOGGER.critical("Argument for parameter 'partitioning' must be one of: {:}".format(PARTITIONING_MODES)) ORT_PROVIDERS = ["CPUExecutionProvider"] G_LOGGER.debug("Folding constants in {:}".format(self.name)) @@ -805,9 +777,7 @@ def should_exclude_node(node): node = tensor.inputs[0] if node.op == "Constant": if len(node.attrs) != 1: - G_LOGGER.warning( - "Constant node must contain exactly one attribute" - ) + G_LOGGER.warning("Constant node must contain exactly one attribute") continue attr_name, attr_val = list(node.attrs.items())[0] allowed_attrs = { @@ -818,9 +788,7 @@ def should_exclude_node(node): "value_ints", } if attr_name not in allowed_attrs: - G_LOGGER.warning( - f"Unsupported attribute for Constant node: {attr_name}" - ) + G_LOGGER.warning(f"Unsupported attribute for Constant node: {attr_name}") continue if isinstance(attr_val, Node.AttributeRef): continue @@ -833,6 +801,7 @@ def should_exclude_node(node): # Pass 2: Run shape-tensor cast elision def run_cast_elision(node): + """Perform cast elision optimization on an ONNX node to eliminate unnecessary cast operations.""" import onnx # Search for Cast(s) (from int -> float) -> intermediate operator (with float constants) -> Cast(s) (back to int) @@ -868,8 +837,7 @@ def run_cast_elision(node): inp_node for inp_tensor in node.inputs for inp_node in inp_tensor.inputs - if inp_node.op == "Cast" - and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT + if inp_node.op == "Cast" and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT ] # No cast nodes found, return early @@ -877,9 +845,7 @@ def run_cast_elision(node): return # Ensure that all input cast nodes are casting from the same type - inp_dtypes = [ - dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts - ] + inp_dtypes = [dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts] if len(set(inp_dtypes)) != 1: return @@ -891,8 +857,7 @@ def run_cast_elision(node): for out_tensor in node.outputs for out_node in out_tensor.outputs if out_node.op == "Cast" - and out_node.attrs["to"] - in [onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64] + and out_node.attrs["to"] in [onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64] ] # No cast node found on outputs, return early @@ -912,9 +877,7 @@ def run_cast_elision(node): # `cast_node.inputs[0].outputs[0] == cast_node`. 
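(Editorial aside: in plain numpy terms, the elision searched for above is safe because routing small integer shape values through float32 and back is an identity, so the surrounding casts can be dropped. An illustrative sketch, not the graph pass itself:)

```python
import numpy as np

x = np.array([2, 3, 4], dtype=np.int64)
y = np.array([10, 20, 30], dtype=np.int64)

# Before elision: Cast(int -> float) -> Add -> Cast(float -> int)
before = (x.astype(np.float32) + y.astype(np.float32)).astype(np.int64)
# After elision: the operator runs directly on the integer type
after = x + y
assert np.array_equal(before, after)
```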
for index, inp in enumerate(node.inputs): if isinstance(inp, Constant): - inp.values = inp.values.astype( - onnx.helper.tensor_dtype_to_np_dtype(final_type) - ) + inp.values = inp.values.astype(onnx.helper.tensor_dtype_to_np_dtype(final_type)) for cast in inp_casts: if cast.outputs[0] == inp: @@ -929,9 +892,7 @@ def run_cast_elision(node): if fold_shapes: # Perform shape tensor cast elision prior to most other folding - G_LOGGER.debug( - "Performing shape tensor cast elision in {:}".format(self.name) - ) + G_LOGGER.debug("Performing shape tensor cast elision in {:}".format(self.name)) try: with self.node_ids(): for node in self.nodes: @@ -939,11 +900,7 @@ def run_cast_elision(node): except Exception as err: if not error_ok: raise err - G_LOGGER.warning( - "'{:}' routine failed with: {:}".format( - "Shape tensor cast elision", err - ) - ) + G_LOGGER.warning("'{:}' routine failed with: {:}".format("Shape tensor cast elision", err)) # Note that most of the remaining passes operate on a clone of the original graph. # Pass 3: Find all descendants of constant tensors @@ -956,6 +913,7 @@ def run_cast_elision(node): graph_clone.producer_version = "" def update_foldable_outputs(graph_constants): + """Updates the graph's outputs to ensure certain operations remain foldable.""" def is_foldable(node): NO_FOLD_OPS = [ "QuantizeLinear", @@ -966,10 +924,8 @@ def is_foldable(node): return False def all_tensors_const(tensors): - # Ignore omitted optional inputs. - return all( - [t.name in graph_constants for t in tensors if not t.is_empty()] - ) + """Check if all tensors in the given list are constants, excluding omitted optional inputs.""" + return all([t.name in graph_constants for t in tensors if not t.is_empty()]) if not all_tensors_const(node.inputs): return False @@ -977,13 +933,9 @@ def all_tensors_const(tensors): all_subgraph_foreign_tensors_const = True for subgraph in node.subgraphs(): foreign_tensors = subgraph._foreign_tensors().values() - all_subgraph_foreign_tensors_const &= all_tensors_const( - foreign_tensors - ) + all_subgraph_foreign_tensors_const &= all_tensors_const(foreign_tensors) - return all_subgraph_foreign_tensors_const and not should_exclude_node( - node - ) + return all_subgraph_foreign_tensors_const and not should_exclude_node(node) # Walks along the outputs of graph_constants to see if they can also be computed statically. # Since the graph is topologically sorted, this should find all constant nodes in the graph. @@ -992,19 +944,13 @@ def all_tensors_const(tensors): graph_constants.update({out.name: out for out in node.outputs}) return graph_constants - graph_constants = { - name: tensor - for name, tensor in clone_tensors.items() - if isinstance(tensor, Constant) - } + graph_constants = {name: tensor for name, tensor in clone_tensors.items() if isinstance(tensor, Constant)} graph_constants = update_foldable_outputs(graph_constants) # Pass 4: Shape Folding def get_producer(tensor, op): - """ - Get the producer of the specified tensor iff it matches op - """ + """Get the producer of the specified tensor iff it matches op.""" if len(tensor.inputs) != 1: return None @@ -1014,9 +960,7 @@ def get_producer(tensor, op): return node def get_input(node, index=0): - """ - Get the input tensor of a node iff the input tensor is not already marked a graph constant. 
-            """
+            """Get the input tensor of a node iff the input tensor is not already marked a graph constant."""
             if node is None:
                 return None

@@ -1029,15 +973,14 @@ def get_input(node, index=0):
             return inp

         def get_scalar_value(tensor):
-            """
-            Gets the scalar value of a constant tensor with a single item
-            """
+            """Gets the scalar value of a constant tensor with a single item."""
             if not tensor.shape:
                 return tensor.values
             else:
                 return list(tensor.values)[0]

         def fold_shape(tensor):
+            """Fold a Shape node into a constant when its input's shape is statically known; returns None otherwise."""
             inp = get_input(get_producer(tensor, "Shape"))
             if inp is None:
                 return None
@@ -1047,6 +990,7 @@ def fold_shape(tensor):
             return np.array(inp.shape, dtype=np.int64)

         def fold_shape_gather(tensor):
+            """Fold a Shape followed by Gather into a constant when the gathered dimensions are static; returns None otherwise."""
             gather = get_producer(tensor, "Gather")
             if gather is None:
                 return None
@@ -1074,6 +1018,7 @@ def fold_shape_gather(tensor):
             return np.array(shape, dtype=np.int64)

         def fold_shape_slice(tensor):
+            """Fold a Shape followed by Slice into a constant when the sliced dimensions are static; returns None otherwise."""
             slice = get_producer(tensor, "Slice")
             if slice is None:
                 return None
@@ -1130,27 +1075,20 @@ def fold_shape_slice(tensor):

                     shape_of = shape_fold_func(tensor)
                     if shape_of is not None:
-                        G_LOGGER.ultra_verbose(
-                            "Folding shape tensor: {:} to: {:}".format(
-                                tensor.name, shape_of
-                            )
-                        )
+                        G_LOGGER.ultra_verbose("Folding shape tensor: {:} to: {:}".format(tensor.name, shape_of))
                         graph_constants[tensor.name] = tensor.to_constant(shape_of)
                         graph_constants[tensor.name].inputs.clear()
                 except Exception as err:
                     if not error_ok:
                         raise err
-                    G_LOGGER.warning(
-                        "'{:}' routine failed with:\n{:}".format(
-                            shape_fold_func.__name__, err
-                        )
-                    )
+                    G_LOGGER.warning("'{:}' routine failed with:\n{:}".format(shape_fold_func.__name__, err))
                 else:
                     graph_constants = update_foldable_outputs(graph_constants)

         # Pass 5: Evaluate all tensors descended from constants with ONNX-Runtime and replace them with constant values.
         def partition_and_infer(subgraph):
+            """Evaluates and partitions the subgraph to infer constant values using ONNX-Runtime."""
             def get_out_node_ids():
                 # Gets the final output nodes - producer nodes of graph output tensors without other outputs.
                 with subgraph.node_ids():
@@ -1183,11 +1121,7 @@ def get_out_node_ids():
                     )
                     values = sess.run(names, {})
                 except Exception as err:
-                    G_LOGGER.warning(
-                        "Inference failed for subgraph: {:}. Note: Error was:\n{:}".format(
-                            part.name, err
-                        )
-                    )
+                    G_LOGGER.warning("Inference failed for subgraph: {:}. Note: Error was:\n{:}".format(part.name, err))
                     if partitioning == "recursive":
                         G_LOGGER.verbose("Attempting to recursively partition subgraph")
                         # Partition failed, peel off last node.
@@ -1197,17 +1131,13 @@ def get_out_node_ids():
                         out_node.outputs.clear()
                         out_node.inputs.clear()
                     else:
-                        G_LOGGER.info(
-                            "You may see better results if you set partitioning='recursive'"
-                        )
+                        G_LOGGER.info("You may see better results if you set partitioning='recursive'")
                         if not error_ok:
                             raise err

                     constant_values.update(partition_and_infer(part))
                 else:
-                    constant_values.update(
-                        {name: val for name, val in zip(names, values)}
-                    )
+                    constant_values.update({name: val for name, val in zip(names, values)})

             return constant_values

@@ -1215,39 +1145,28 @@ def get_out_node_ids():
         # Otherwise, if all the outputs are foldable, then we can just evaluate the outputs directly.
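(Editorial aside: all three `fold_shape*` helpers above reduce to the same observation - when the upstream tensor's shape is statically known, `Shape`, optionally followed by `Gather` or `Slice`, can be evaluated ahead of time. A numpy sketch of the `Shape -> Gather` case, with an assumed static shape:)

```python
import numpy as np

static_shape = (1, 3, 224, 224)  # assumed shape of the tensor feeding the Shape op
shape_out = np.array(static_shape, dtype=np.int64)  # what fold_shape produces
indices = np.array([2, 3], dtype=np.int64)          # the Gather indices
print(np.take(shape_out, indices))                  # [224 224] -- the folded constant
```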
# Additionally, if we can determine tensor size, do not evaluate tensors whose sizes exceed the size threshold. def should_eval_foldable(tensor): + """Determine if foldable values should be evaluated based on output nature and tensor size constraints.""" from onnxslim.onnx_graphsurgeon.importers.onnx_importer import get_itemsize non_const = not isinstance(tensor, Constant) is_graph_output = not tensor.outputs - has_non_foldable_outputs = any( - out.name not in graph_constants for out in tensor.outputs - ) + has_non_foldable_outputs = any(out.name not in graph_constants for out in tensor.outputs) exceeds_size_threshold = ( tensor.shape is not None and not misc.is_dynamic_shape(tensor.shape) and tensor.dtype is not None and size_threshold is not None - ) and ( - misc.volume(tensor.shape) * get_itemsize(tensor.dtype) > size_threshold - ) + ) and (misc.volume(tensor.shape) * get_itemsize(tensor.dtype) > size_threshold) - return ( - non_const - and (is_graph_output or has_non_foldable_outputs) - and not exceeds_size_threshold - ) + return non_const and (is_graph_output or has_non_foldable_outputs) and not exceeds_size_threshold - graph_clone.outputs = [ - t for t in graph_constants.values() if should_eval_foldable(t) - ] + graph_clone.outputs = [t for t in graph_constants.values() if should_eval_foldable(t)] G_LOGGER.debug("Folding tensors: {:}".format(graph_clone.outputs)) graph_clone.cleanup(remove_unused_graph_inputs=True, recurse_functions=False) # Using ._values avoids a deep copy of the values. constant_values = { - name: tensor._values - for name, tensor in graph_constants.items() - if isinstance(tensor, Constant) + name: tensor._values for name, tensor in graph_constants.items() if isinstance(tensor, Constant) } if graph_clone.outputs: if partitioning: @@ -1258,15 +1177,11 @@ def should_eval_foldable(tensor): import onnxruntime as onnxrt sess = onnxrt.InferenceSession( - export_onnx( - graph_clone, do_type_check=False - ).SerializeToString(), + export_onnx(graph_clone, do_type_check=False).SerializeToString(), providers=ORT_PROVIDERS, ) values = sess.run(names, {}) - constant_values.update( - {name: val for name, val in zip(names, values)} - ) + constant_values.update({name: val for name, val in zip(names, values)}) except Exception as err: G_LOGGER.warning( "Inference failed. You may want to try enabling partitioning to see better results. " @@ -1307,20 +1222,18 @@ def should_eval_foldable(tensor): if large_tensors: large_tensors_mib = { - tensor_name: "{:} MiB".format(value // (1 << 20)) - for tensor_name, value in large_tensors.items() + tensor_name: "{:} MiB".format(value // (1 << 20)) for tensor_name, value in large_tensors.items() } G_LOGGER.warning( "It looks like this model contains foldable nodes that produce large outputs.\n" "In order to avoid bloating the model, you may want to set a constant-folding size threshold.\n" - "Note: Large tensors and their corresponding sizes were: {:}".format( - large_tensors_mib - ), + "Note: Large tensors and their corresponding sizes were: {:}".format(large_tensors_mib), mode=LogMode.ONCE, ) # Folding subgraphs after the outer graph can lead to better folding. 
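(Editorial aside: putting the passes together - `toposort()` first, a documented precondition of `fold_constants()`, then fold, then `cleanup()` to drop the now-disconnected producer nodes. A hedged sketch reusing `graph` from the earlier example; the 1 MiB threshold is an arbitrary illustration:)

```python
graph.toposort().fold_constants(
    fold_shapes=True,
    partitioning="basic",    # keep ONNX-Runtime failures local to one partition
    size_threshold=1 << 20,  # skip folding tensors larger than 1 MiB
).cleanup()
```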
def fold_subgraphs(): + """Folds constants within subgraphs to optimize the performance of the outer graph.""" for subgraph in self.subgraphs(): subgraph.fold_constants( fold_shapes=fold_shapes, @@ -1343,9 +1256,7 @@ def fold_subgraphs(): if node.op == "If" and isinstance(node.inputs[0], Constant): G_LOGGER.debug("Flattening conditional: {:}".format(node)) cond = get_scalar_value(node.inputs[0]) - subgraph = ( - node.attrs["then_branch"] if cond else node.attrs["else_branch"] - ) + subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"] # Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors for tensor in subgraph._local_tensors().values(): tensor.name += "_subg_{:}_{:}".format(index, subgraph.name) @@ -1385,7 +1296,7 @@ def fold_subgraphs(): return self def _generate_name(self, prefix: str, existing_names: set): - # `existing_names` will ensure that generated name does not clash existing names. + """Generate a unique name by appending an index to the given prefix, ensuring it does not clash with existing names.""" # Generation is done by appending an index to the prefix. while True: name = "{}_{}".format(prefix, self.name_idx) @@ -1430,7 +1341,7 @@ def layer(self, inputs=None, outputs=None, *args, **kwargs): outputs = misc.default_value(outputs, []) def process_io(io, existing_names): - # Note: modifies `existing_names` in-place + """Processes input/output elements, converting them to Tensor, Variable, or Constant, and ensuring unique names.""" new_io = [] for elem in io: if isinstance(elem, Tensor): @@ -1440,27 +1351,15 @@ def process_io(io, existing_names): tensor = Variable(name=name) new_io.append(tensor) elif isinstance(elem, np.ndarray): - name = self._generate_name( - "onnx_graphsurgeon_constant", existing_names - ) + name = self._generate_name("onnx_graphsurgeon_constant", existing_names) new_io.append(Constant(name=name, values=elem)) - elif ( - isinstance(elem, list) - or isinstance(elem, tuple) - or isinstance(elem, numbers.Number) - ): + elif isinstance(elem, list) or isinstance(elem, tuple) or isinstance(elem, numbers.Number): if isinstance(elem, list) or isinstance(elem, tuple): - dtype = ( - np.float32 - if any([isinstance(x, float) for x in elem]) - else np.int64 - ) + dtype = np.float32 if any([isinstance(x, float) for x in elem]) else np.int64 else: dtype = np.float32 if isinstance(elem, float) else np.int64 arr = np.array(elem, dtype=dtype) - name = self._generate_name( - "onnx_graphsurgeon_lst_constant", existing_names - ) + name = self._generate_name("onnx_graphsurgeon_lst_constant", existing_names) new_io.append(Constant(name=name, values=arr)) else: G_LOGGER.critical( @@ -1477,9 +1376,7 @@ def process_io(io, existing_names): outputs = process_io(outputs, existing_names) if "name" not in kwargs: - kwargs["name"] = self._generate_name( - "onnx_graphsurgeon_node", {node.name for node in self.nodes} - ) + kwargs["name"] = self._generate_name("onnx_graphsurgeon_node", {node.name for node in self.nodes}) node = Node(*args, **kwargs, inputs=inputs, outputs=outputs) self.nodes.append(node) @@ -1511,11 +1408,10 @@ def copy(self, tensor_map: "OrderedDict[str, Tensor]" = None): # However, we should prioritize copies already made by the outer graph. local_tensor_copies.update(tensor_map) # And locally produced tensors should take precedence over everything else. 
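(Editorial aside: `layer()` is where `process_io` above earns its keep - numpy arrays, Python scalars/lists, and bare strings are promoted to `Constant`/`Variable` tensors with generated unique names. A minimal sketch, assuming upstream-style constructors:)

```python
import numpy as np
import onnxslim.onnx_graphsurgeon as gs  # assumed upstream-style exports

graph = gs.Graph(opset=13)
x = gs.Variable("x", dtype=np.float32, shape=(1, 3))

# The ndarray becomes a Constant; the string output becomes a Variable with a
# generated unique name; a new Add node is appended to graph.nodes.
(y,) = graph.layer(op="Add", inputs=[x, np.ones((1, 3), dtype=np.float32)], outputs=["y"])
graph.inputs, graph.outputs = [x], [y]
```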
- local_tensor_copies.update( - {n: t.copy() for n, t in self._local_tensors().items()} - ) + local_tensor_copies.update({n: t.copy() for n, t in self._local_tensors().items()}) def get_tensor(name): + """Retrieve a tensor by its name from local copies, or return an empty variable if no name is provided.""" if not name: return Variable.empty() return local_tensor_copies[name] @@ -1544,6 +1440,7 @@ def get_tensor(name): ) def __str__(self): + """Return a string representation of the graph including its name, opset, local functions, inputs, nodes, and outputs.""" nodes_str = "\n".join([str(node) for node in self.nodes]) functions_str = ",".join([str(func.name) for func in self.functions]) out = f"Graph {self.name} (Opset {self.opset})" @@ -1554,4 +1451,5 @@ def __str__(self): return out def __repr__(self): + """Returns a string representation of the object.""" return self.__str__() diff --git a/onnxslim/onnx_graphsurgeon/ir/node.py b/onnxslim/onnx_graphsurgeon/ir/node.py index 307d20a..4b0a9e5 100644 --- a/onnxslim/onnx_graphsurgeon/ir/node.py +++ b/onnxslim/onnx_graphsurgeon/ir/node.py @@ -25,12 +25,11 @@ class Node(object): - @dataclass class AttributeRef: """ - An AttributeRef is an attribute value which references an attribute in the parent function. - A node's attribute can only be an AttributeRef if the node lives inside a Function. + An AttributeRef is an attribute value which references an attribute in the parent function. A node's attribute + can only be an AttributeRef if the node lives inside a Function. Args: name (str): The name of the referenced attribute in the parent Function. @@ -64,18 +63,14 @@ def __init__( self.op = op self.name = misc.default_value(name, "") self.attrs = misc.default_value(attrs, OrderedDict()) - self.inputs = misc.SynchronizedList( - self, field_name="outputs", initial=misc.default_value(inputs, []) - ) - self.outputs = misc.SynchronizedList( - self, field_name="inputs", initial=misc.default_value(outputs, []) - ) + self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=misc.default_value(inputs, [])) + self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=misc.default_value(outputs, [])) self.domain = domain def i(self, tensor_idx=0, producer_idx=0): """ - Convenience function to get a producer node of one of this node's input tensors. - Note that the parameters are swapped compared to the o() function; this is because tensors are likely to have only a single producer + Convenience function to get a producer node of one of this node's input tensors. Note that the parameters are + swapped compared to the o() function; this is because tensors are likely to have only a single producer. For example: :: @@ -113,8 +108,8 @@ def o(self, consumer_idx=0, tensor_idx=0): def subgraphs(self, recursive=False): """ - Convenience function to iterate over all subgraphs which are contained in this node. - Node subgraphs are found in attributes of ONNX control flow nodes such as 'If' and 'Loop'. + Convenience function to iterate over all subgraphs which are contained in this node. Node subgraphs are found in + attributes of ONNX control flow nodes such as 'If' and 'Loop'. Args: recursive (bool): Whether to recurse into the subgraph nodes when looking for subgraphs. Defaults to False. 
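(Editorial aside: the `i()`/`o()` helpers documented in this hunk are just index chains over the synchronized input/output lists; a small self-contained check, upstream-style constructors assumed:)

```python
import numpy as np
import onnxslim.onnx_graphsurgeon as gs  # assumed upstream-style exports

x = gs.Variable("x", dtype=np.float32)
t = gs.Variable("t", dtype=np.float32)
y = gs.Variable("y", dtype=np.float32)
relu = gs.Node("Relu", inputs=[x], outputs=[t])
neg = gs.Node("Neg", inputs=[t], outputs=[y])

assert neg.i() is relu  # neg.inputs[0].inputs[0]: producer of neg's first input
assert relu.o() is neg  # relu.outputs[0].outputs[0]: first consumer of relu's output
```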
@@ -139,6 +134,7 @@ def subgraphs(self, recursive=False):
             yield attr

     def __setattr__(self, name, value):
+        """Sets the attribute 'name' to 'value', handling special cases for 'inputs' and 'outputs' attributes."""
         if name in ["inputs", "outputs"]:
             try:
                 attr = getattr(self, name)
@@ -184,9 +180,11 @@ def copy(
         )

     def __str__(self):
+        """Return a string representation of the object showing its name and operation."""
         ret = "{:} ({:})".format(self.name, self.op)

         def add_io(name, io):
+            """Append the given input or output tensor list to the node's string representation."""
             nonlocal ret
             ret += "\n\t{:}: [".format(name)
             for elem in io:
@@ -205,18 +203,13 @@ def add_io(name, io):
         return ret

     def __repr__(self):
+        """Return the string representation of the node."""
         return self.__str__()

     def __eq__(self, other):
-        """
-        Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs.
-        """
+        """Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs."""
         G_LOGGER.verbose("Comparing node: {:} with {:}".format(self.name, other.name))
-        attrs_match = (
-            self.name == other.name
-            and self.op == other.op
-            and self.attrs == other.attrs
-        )
+        attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs
         if not attrs_match:
             return False
diff --git a/onnxslim/onnx_graphsurgeon/ir/tensor.py b/onnxslim/onnx_graphsurgeon/ir/tensor.py
index b90da39..f6ae28e 100644
--- a/onnxslim/onnx_graphsurgeon/ir/tensor.py
+++ b/onnxslim/onnx_graphsurgeon/ir/tensor.py
@@ -24,17 +24,16 @@


 class Tensor(object):
-    """Abstract base class for tensors in a graph"""
+    """Abstract base class for tensors in a graph."""

     DYNAMIC = -1

     def __init__(self):
-        """
-        **This class is abstract and cannot be constructed directly.**
-        """
+        """**This class is abstract and cannot be constructed directly.**"""
         raise NotImplementedError("Tensor is an abstract class")

     def __setattr__(self, name, value):
+        """Set an attribute, ensuring special handling for "inputs" and "outputs" properties."""
         if name in ["inputs", "outputs"]:
             try:
                 attr = getattr(self, name)
@@ -69,7 +68,8 @@ def to_constant(
         export_dtype: Union[np.dtype, "onnx.TensorProto.DataType"] = None,
     ):
         """
-        Modifies this tensor in-place to convert it to a Constant. This means that all consumers/producers of the tensor will see the update.
+        Modifies this tensor in-place to convert it to a Constant. This means that all consumers/producers of the tensor
+        will see the update.

         Args:
             values (np.ndarray): The values in this tensor
@@ -95,7 +95,8 @@ def to_variable(
         shape: Sequence[Union[int, str]] = [],
     ):
         """
-        Modifies this tensor in-place to convert it to a Variable. This means that all consumers/producers of the tensor will see the update.
+        Modifies this tensor in-place to convert it to a Variable. This means that all consumers/producers of the tensor
+        will see the update.

         Args:
             dtype (Union[numpy.dtype, onnx.TensorProto.DataType]): The data type of the tensor.
@@ -115,8 +116,8 @@ def to_variable(

     def i(self, tensor_idx=0, producer_idx=0):
         """
-        Convenience function to get an input tensor of one of this tensor's input nodes.
-        Note that the parameters are swapped compared to the o() function; this is because tensors are likely to have only a single producer
+        Convenience function to get an input tensor of one of this tensor's input nodes. Note that the parameters are
+        swapped compared to the o() function; this is because tensors are likely to have only a single producer.
For example: :: @@ -153,11 +154,11 @@ def o(self, consumer_idx=0, tensor_idx=0): return self.outputs[consumer_idx].outputs[tensor_idx] def __str__(self): - return "{:} ({:}): (shape={:}, dtype={:})".format( - type(self).__name__, self.name, self.shape, self.dtype - ) + """Returns a string representation of the object including its type, name, shape, and data type.""" + return "{:} ({:}): (shape={:}, dtype={:})".format(type(self).__name__, self.name, self.shape, self.dtype) def __repr__(self): # Hack to make logging output pretty. + """Returns a string representation of the object for logging output.""" return self.__str__() def __eq__(self, other): @@ -172,6 +173,7 @@ def __eq__(self, other): class Variable(Tensor): @staticmethod def empty(): + """Create and return an empty Variable tensor with an empty name.""" return Variable(name="") def __init__( @@ -216,44 +218,27 @@ def copy(self): return Variable(self.name, self.dtype, self.shape) def __eq__(self, other): - """ - Perform a check to see if two variables are equal. - """ + """Perform a check to see if two variables are equal.""" if not isinstance(other, Variable): return False name_match = self.name == other.name inputs_match = len(self.inputs) == len(other.inputs) and all( - [ - inp.name == other_inp.name - for inp, other_inp in zip(self.inputs, other.inputs) - ] + [inp.name == other_inp.name for inp, other_inp in zip(self.inputs, other.inputs)] ) outputs_match = len(self.outputs) == len(other.outputs) and all( - [ - out.name == other_out.name - for out, other_out in zip(self.outputs, other.outputs) - ] + [out.name == other_out.name for out, other_out in zip(self.outputs, other.outputs)] ) dtype_match = self.dtype == other.dtype shape_match = self.shape == other.shape type_match = self.type == other.type - return ( - name_match - and inputs_match - and outputs_match - and dtype_match - and shape_match - and type_match - ) + return name_match and inputs_match and outputs_match and dtype_match and shape_match and type_match class LazyValues(object): - """ - A special object that represents constant tensor values that should be lazily loaded. - """ + """A special object that represents constant tensor values that should be lazily loaded.""" def __init__(self, tensor): """ @@ -297,15 +282,15 @@ def load(self): return np.array(onnx.numpy_helper.to_array(self.tensor)) def __str__(self): + """Returns a string representation of the LazyValues object with its shape and dtype.""" return "LazyValues (shape={:}, dtype={:})".format(self.shape, self.dtype) def __repr__(self): # Hack to make logging output pretty. + """Returns a string representation of the LazyValues object for logging purposes.""" return self.__str__() def __eq__(self, other): - """ - Perform a check to see if two variables are equal. 
-        """
+        """Perform a check to see if two variables are equal."""
         if not isinstance(other, LazyValues):
             return False

@@ -317,9 +302,7 @@ def __eq__(self, other):


 class SparseValues(LazyValues):
-    """
-    A special object that represents constant tensor values that is sparse
-    """
+    """A special object that represents constant tensor values that is sparse."""

     def load(self):
         """
@@ -343,9 +326,7 @@ def load(self):
             )

         if self.tensor.values.data_type == onnx.TensorProto.FLOAT16:
-            values_data = np.asarray(
-                self.tensor.values.int32_data, dtype=np.uint16
-            ).view(np.float16)
+            values_data = np.asarray(self.tensor.values.int32_data, dtype=np.uint16).view(np.float16)
         else:
             field_name = onnx.helper.tensor_dtype_to_field(self.tensor.values.data_type)
             values = getattr(self.tensor.values, field_name)
@@ -366,13 +347,12 @@ def load(self):
                 for i in range(len(values_data)):
                     values[tuple(indices_data[i])] = values_data[i]
         else:
-            G_LOGGER.critical(
-                f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}"
-            )
+            G_LOGGER.critical(f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}")

         return values

     def __str__(self):
+        """Return a string representation of the SparseValues object with its shape and data type."""
         return "SparseValues (shape={:}, dtype={:})".format(self.shape, self.dtype)


@@ -411,17 +391,14 @@ def __init__(
             G_LOGGER.critical(
-                "Provided `values` argument is not a NumPy array, a LazyValues instance or a"
+                "Provided `values` argument is not a NumPy array, a LazyValues instance or a "
                 "SparseValues instance. Please provide a NumPy array or LazyValues instance "
-                "to construct a Constant. Note: Provided `values` parameter was: {:}".format(
-                    values
-                )
+                "to construct a Constant. Note: Provided `values` parameter was: {:}".format(values)
             )

         self._values = values
         self.data_location = data_location
         self._export_dtype = export_dtype

-    def to_variable(
-        self, dtype: np.dtype = None, shape: Sequence[Union[int, str]] = []
-    ):
+    def to_variable(self, dtype: np.dtype = None, shape: Sequence[Union[int, str]] = []):
+        """Convert instance values to a variable with specified dtype and shape attributes."""
         var_dtype = self.export_dtype

         del self._export_dtype
@@ -442,25 +419,29 @@ def copy(self):

     @property
     def values(self):
-        # Load values when they are first accesed
+        """Returns the values of the tensor, loading them if they are accessed for the first time."""
         if isinstance(self._values, LazyValues):
             self._values = self._values.load()
         return self._values

     @values.setter
     def values(self, values: Union[np.ndarray, LazyValues]):
+        """Set the values of the tensor."""
         self._values = values

     @property
     def shape(self):
+        """Returns the shape of the tensor values."""
         return self._values.shape

     @property
     def dtype(self):
+        """Returns the data type of the tensor values."""
         return self._values.dtype

     @property
     def export_dtype(self):
+        """Returns the export data type of the tensor values if specified, otherwise defaults to the tensor's dtype."""
         if self._export_dtype is not None:
             return self._export_dtype

@@ -468,23 +449,21 @@ def export_dtype(self):

     @export_dtype.setter
     def export_dtype(self, export_dtype):
+        """Set the export data type of the tensor values."""
         self._export_dtype = export_dtype

     def __repr__(self):  # Hack to make logging output pretty.
+        """Provides a string representation of the object including its values for enhanced logging output readability."""
         ret = self.__str__()
         ret += "\n{:}".format(self._values)
         return ret

     def __eq__(self, other):
-        """
-        Perform a check to see if two variables are equal.
-        """
+        """Perform a check to see if two variables are equal."""
         if not isinstance(other, Constant):
             return False

-        if isinstance(self._values, LazyValues) and isinstance(
-            other._values, LazyValues
-        ):
+        if isinstance(self._values, LazyValues) and isinstance(other._values, LazyValues):
             value_match = self._values == other._values
         else:
             value_match = np.array_equal(self.values, other.values)
diff --git a/onnxslim/onnx_graphsurgeon/logger/logger.py b/onnxslim/onnx_graphsurgeon/logger/logger.py
index 6e46bd0..71f0282 100644
--- a/onnxslim/onnx_graphsurgeon/logger/logger.py
+++ b/onnxslim/onnx_graphsurgeon/logger/logger.py
@@ -16,7 +16,6 @@
 #

 import enum
-
 import inspect
 import os
 import sys
@@ -28,30 +27,36 @@

 # Context manager to apply indentation to messages
 class LoggerIndent(object):
     def __init__(self, logger, indent):
+        """Initialize the LoggerIndent context manager with the specified logger and indentation level."""
         self.logger = logger
         self.old_indent = self.logger.logging_indent
         self.indent = indent

     def __enter__(self):
+        """Set logger indentation level on entering the context."""
         self.logger.logging_indent = self.indent
         return self

     def __exit__(self, exc_type, exc_value, traceback):
+        """Reset logger indentation level on exiting the context."""
         self.logger.logging_indent = self.old_indent


 # Context manager to suppress messages
 class LoggerSuppress(object):
     def __init__(self, logger, severity):
+        """Initialize a LoggerSuppress object with a logger and severity level."""
         self.logger = logger
         self.old_severity = self.logger.severity
         self.severity = severity

     def __enter__(self):
+        """Set logger severity to a specified level when entering the context."""
         self.logger.severity = self.severity
         return self

     def __exit__(self, exc_type, exc_value, traceback):
+        """Reset logger severity to its original level when exiting the context."""
         self.logger.severity = self.old_severity


@@ -89,9 +94,7 @@ class Logger(object):
         CRITICAL: "red_1",
     }

-    def __init__(
-        self, severity=INFO, colors=True, letter=True, timestamp=False, line_info=False
-    ):
+    def __init__(self, severity=INFO, colors=True, letter=True, timestamp=False, line_info=False):
         """
         Logger.

@@ -104,9 +107,7 @@ def __init__(
         """
         self._severity = severity
         self.logging_indent = 0
-        self.root_dir = os.path.abspath(
-            os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
-        )
+        self.root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
         self.once_logged = set()
         self.colors = colors
         self.letter = letter
@@ -116,18 +117,20 @@ def __init__(

     @property
     def severity(self):
+        """Returns the logging severity level."""
         return self._severity

     @severity.setter
     def severity(self, value):
+        """Set the logging severity level and notify registered callbacks."""
         self._severity = value
         for callback in self.logger_callbacks:
             callback(self._severity)

     def register_callback(self, callback):
         """
-        Registers a callback with the logger, which will be invoked when the logging severity is modified.
-        The callback is guaranteed to be called at least once in the register_callback function.
+        Registers a callback with the logger, which will be invoked when the logging severity is modified.
The callback + is guaranteed to be called at least once in the register_callback function. Args: callback (Callable(Logger.Severity)): A callback that accepts the current logger severity. @@ -136,9 +139,7 @@ def register_callback(self, callback): self.logger_callbacks.append(callback) def indent(self, level=1): - """ - Returns a context manager that indents all strings logged by the specified amount. - """ + """Returns a context manager that indents all strings logged by the specified amount.""" return LoggerIndent(self, level + self.logging_indent) def suppress(self, severity=CRITICAL): @@ -153,6 +154,7 @@ def suppress(self, severity=CRITICAL): # If once is True, the logger will only log this message a single time. Useful in loops. # message may be a callable which returns a message. This way, only if the message needs to be logged is it ever generated. def log(self, message, severity, mode=LogMode.EACH, stack_depth=2): + """Logs the given message with specified severity and log mode, optionally including stack trace depth for line information.""" def process_message(message, stack_depth): def get_prefix(): def get_line_info(): @@ -165,9 +167,7 @@ def get_line_info(): # If the file is not located in trt_smeagol, use its basename instead. if os.pardir in filename: filename = os.path.basename(filename) - return "[{:}:{:}] ".format( - filename, sys._getframe(stack_depth).f_lineno - ) + return "[{:}:{:}] ".format(filename, sys._getframe(stack_depth).f_lineno) prefix = "" if self.letter: @@ -179,12 +179,12 @@ def get_line_info(): return prefix def apply_indentation(message): + """Indent each line in the message by the specified logging_indent level.""" message_lines = str(message).splitlines() - return "\n".join( - ["\t" * self.logging_indent + line for line in message_lines] - ) + return "\n".join(["\t" * self.logging_indent + line for line in message_lines]) def apply_color(message): + """Apply color formatting to the message if color support is enabled.""" if self.colors: try: import colored @@ -204,6 +204,7 @@ def apply_color(message): return apply_color("{:}{:}".format(prefix, message)) def should_log(message): + """Determines if a message should be logged based on the severity level and logging mode.""" should = severity >= self._severity if mode == LogMode.ONCE: message_hash = hash(message) @@ -220,25 +221,32 @@ def should_log(message): print(process_message(message, stack_depth=stack_depth)) def ultra_verbose(self, message, mode=LogMode.EACH): + """Logs an ultra-verbose message with a specified logging mode and stack depth of 3.""" self.log(message, Logger.ULTRA_VERBOSE, mode=mode, stack_depth=3) def verbose(self, message, mode=LogMode.EACH): + """Logs a verbose message with a specified logging mode and stack depth of 3.""" self.log(message, Logger.VERBOSE, mode=mode, stack_depth=3) def debug(self, message, mode=LogMode.EACH): + """Logs a debug message with a specified logging mode and stack depth of 3.""" self.log(message, Logger.DEBUG, mode=mode, stack_depth=3) def info(self, message, mode=LogMode.EACH): + """Logs an informational message with a specified logging mode and stack depth of 3.""" self.log(message, Logger.INFO, mode=mode, stack_depth=3) def warning(self, message, mode=LogMode.EACH): + """Logs a warning message with a specified logging mode and stack depth of 3.""" self.log(message, Logger.WARNING, mode=mode, stack_depth=3) def error(self, message, mode=LogMode.EACH): + """Logs an error message with a specified logging mode and stack depth of 3.""" self.log(message, 
Logger.ERROR, mode=mode, stack_depth=3) # Like error, but immediately exits. def critical(self, message): + """Logs a critical error message with a stack depth of 3 and raises an OnnxGraphSurgeonException.""" self.log(message, Logger.CRITICAL, stack_depth=3) raise OnnxGraphSurgeonException(message) from None # Erase exception chain diff --git a/onnxslim/onnx_graphsurgeon/util/exception.py b/onnxslim/onnx_graphsurgeon/util/exception.py index addf20a..deb84a6 100644 --- a/onnxslim/onnx_graphsurgeon/util/exception.py +++ b/onnxslim/onnx_graphsurgeon/util/exception.py @@ -17,6 +17,4 @@ class OnnxGraphSurgeonException(Exception): - """ - An exception raised by ONNX-GraphSurgeon. - """ + """An exception raised by ONNX-GraphSurgeon.""" diff --git a/onnxslim/onnx_graphsurgeon/util/misc.py b/onnxslim/onnx_graphsurgeon/util/misc.py index 97b1424..113d387 100644 --- a/onnxslim/onnx_graphsurgeon/util/misc.py +++ b/onnxslim/onnx_graphsurgeon/util/misc.py @@ -53,12 +53,15 @@ # >>> y.value # [] def default_value(value, default): + """Return the value if not None, otherwise return the default value.""" return value if value is not None else default def combine_dicts(dict0, dict1): """ - Combine two dictionaries. Values in the second will overwrite values in the first. + Combine two dictionaries. + + Values in the second will overwrite values in the first. """ combined = OrderedDict() combined.update(dict0) @@ -67,14 +70,17 @@ def combine_dicts(dict0, dict1): def is_dynamic_dimension(dim): + """Check if a dimension is dynamic (non-integer or negative).""" return not isinstance(dim, int) or dim < 0 def is_dynamic_shape(shape): + """Determine if any dimension in the given shape is dynamic (non-integer or negative).""" return any(is_dynamic_dimension(dim) for dim in shape) def volume(obj): + """Calculate the volume by multiplying the elements of the given iterable object.""" vol = 1 for elem in obj: vol *= elem @@ -87,6 +93,7 @@ def volume(obj): # This method prevents circular import of Tensor and Graph def _init_dicts(): + """Initialize the mapping dictionaries between ONNX attribute types and GraphSurgeon types.""" global _ONNX_ATTR_TYPE_TO_GS_TYPE global _GS_TYPE_TO_ONNX_ATTR_TYPE if _ONNX_ATTR_TYPE_TO_GS_TYPE and _GS_TYPE_TO_ONNX_ATTR_TYPE: @@ -116,11 +123,13 @@ def _init_dicts(): def convert_from_onnx_attr_type(onnx_attr_type): + """Converts an ONNX attribute type to its corresponding GS attribute type.""" _init_dicts() return _ONNX_ATTR_TYPE_TO_GS_TYPE[onnx_attr_type] def convert_to_onnx_attr_type(any_type): + """Converts a given type to its corresponding ONNX attribute type.""" _init_dicts() if any_type in _GS_TYPE_TO_ONNX_ATTR_TYPE: return _GS_TYPE_TO_ONNX_ATTR_TYPE[any_type] @@ -139,69 +148,83 @@ def convert_to_onnx_attr_type(any_type): # See test_ir.TestNodeIO for functional tests class SynchronizedList(list): def __init__(self, parent_obj, field_name, initial): + """Initialize a SynchronizedList with a parent object, a field name, and an initial set of elements.""" self.parent_obj = parent_obj self.field_name = field_name self.extend(initial) def _add_to_elem(self, elem): - # Explicitly avoid SynchronizedList overrides to prevent infinite recursion + """Append the parent_obj to the list attribute defined by field_name in the provided elem object.""" list.append(getattr(elem, self.field_name), self.parent_obj) def _remove_from_elem(self, elem): - # Explicitly avoid SynchronizedList overrides to prevent infinite recursion + """Remove the parent_obj from the list attribute defined by field_name in the 
provided elem object."""
         list.remove(getattr(elem, self.field_name), self.parent_obj)

     def __delitem__(self, index):
+        """Remove the element at the specified index and update the corresponding list attribute in the parent object."""
         self._remove_from_elem(self[index])
         super().__delitem__(index)

     def __setitem__(self, index, elem):
+        """Update the element at the specified index and modify the associated list attribute in the parent object."""
         self._remove_from_elem(self[index])
         super().__setitem__(index, elem)
         self._add_to_elem(elem)

     def append(self, x):
+        """Append an element to the list and update the associated list attribute in the parent object."""
         super().append(x)
         self._add_to_elem(x)

     def extend(self, iterable: Sequence[object]):
+        """Extend the list with elements from the iterable and update the associated list attribute in the parent object."""
         super().extend(iterable)
         for elem in iterable:
             self._add_to_elem(elem)

     def insert(self, i, x):
+        """Insert an element at a specified position and update the associated list attribute in the parent object."""
         super().insert(i, x)
         self._add_to_elem(x)

     def remove(self, x):
+        """Remove an element and update the associated list attribute in the parent object."""
         super().remove(x)
         self._remove_from_elem(x)

     def pop(self, i=-1):
+        """Remove the element at a given position in the list and update the associated list attribute in the parent object."""
         elem = super().pop(i)
         self._remove_from_elem(elem)
         return elem

     def clear(self):
+        """Clear all elements from the list and update the associated list attribute in the parent object."""
         for elem in self:
             self._remove_from_elem(elem)
         super().clear()

     def __add__(self, other_list: List[object]):
+        """Concatenate the current list with another list and return the resulting list."""
         return list(self) + list(other_list)

     def __iadd__(self, other_list: List[object]):
+        """Append elements from another list to the current list and return the modified list."""
         self.extend(other_list)
         return self

     def __copy__(self):
+        """Return a shallow copy of the current list."""
         return list(self)

     def __deepcopy__(self, memo):
+        """Return a plain list copy; like __copy__, the elements themselves are not duplicated."""
         return list(self)


 def sequences_equal(seq1, seq2):
+    """Check if two sequences are equal by comparing their lengths and elements."""
     length_match = len(seq1) == len(seq2)
     if not length_match:
         return False
diff --git a/onnxslim/utils/tabulate.py b/onnxslim/utils/tabulate.py
index 7e79985..e080f37 100644
--- a/onnxslim/utils/tabulate.py
+++ b/onnxslim/utils/tabulate.py
@@ -10,7 +10,8 @@
 from collections.abc import Iterable, Sized
 from functools import partial, reduce
 from html import escape as htmlescape
-from itertools import chain, zip_longest as izip_longest
+from itertools import chain
+from itertools import zip_longest as izip_longest

 try:
     import wcwidth  # optional wide-character (CJK) support
@@ -19,6 +20,7 @@


 def _is_file(f):
+    """Check if an object 'f' is an instance of io.IOBase."""
     return isinstance(f, io.IOBase)


@@ -103,17 +105,18 @@ def _is_file(f):


 def _is_separating_line(row):
+    """Determine if a row is a separating line based on its type and specific content criteria."""
     row_type = type(row)
     is_sl = (row_type == list or row_type == str) and (
-        (len(row) >= 1 and row[0] == SEPARATING_LINE)
-        or (len(row) >= 2 and row[1] == SEPARATING_LINE)
+        (len(row) >= 1 and row[0] == SEPARATING_LINE) or (len(row) >= 2 and row[1] == SEPARATING_LINE)
     )
     return is_sl


 def _pipe_segment_with_colons(align, colwidth):
-    """Return a segment of a horizontal line with optional
colons which
-    indicate column's alignment (as in `pipe` output format)."""
+    """Return a segment of a horizontal line with optional colons which indicate column's alignment (as in `pipe` output
+    format).
+    """
     w = colwidth
     if align in ["right", "decimal"]:
         return ("-" * (w - 1)) + ":"
@@ -126,8 +129,7 @@ def _pipe_segment_with_colons(align, colwidth):


 def _pipe_line_with_colons(colwidths, colaligns):
-    """Return a horizontal line with optional colons to indicate column's
-    alignment (as in `pipe` output format)."""
+    """Return a horizontal line with optional colons to indicate column's alignment (as in `pipe` output format)."""
     if not colaligns:  # e.g. printing an empty data frame (github issue #15)
         colaligns = [""] * len(colwidths)
     segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)]
@@ -135,6 +137,7 @@ def _pipe_line_with_colons(colwidths, colaligns):


 def _mediawiki_row_with_attrs(separator, cell_values, colwidths, colaligns):
+    """Returns a MediaWiki table row with specific alignment attributes for each cell based on given parameters."""
     alignment = {
         "left": "",
         "right": 'style="text-align: right;"| ',
@@ -143,14 +146,13 @@ def _mediawiki_row_with_attrs(separator, cell_values, colwidths, colaligns):
     }
     # hard-coded padding _around_ align attribute and value together
     # rather than padding parameter which affects only the value
-    values_with_attrs = [
-        " " + alignment.get(a, "") + c + " " for c, a in zip(cell_values, colaligns)
-    ]
+    values_with_attrs = [" " + alignment.get(a, "") + c + " " for c, a in zip(cell_values, colaligns)]
     colsep = separator * 2
     return (separator + colsep.join(values_with_attrs)).rstrip()


 def _textile_row_with_attrs(cell_values, colwidths, colaligns):
+    """Generate a Textile-formatted table row with specified cell values, column widths, and alignments."""
     cell_values[0] += " "
     alignment = {"left": "<.", "right": ">.", "center": "=.", "decimal": ">."}
     values = (alignment.get(a, "") + v for a, v in zip(colaligns, cell_values))
@@ -158,11 +160,12 @@ def _textile_row_with_attrs(cell_values, colwidths, colaligns):


 def _html_begin_table_without_header(colwidths_ignore, colaligns_ignore):
-    # this table header will be suppressed if there is a header row
+    """Generate the beginning of an HTML table without a header row."""
     return "<table>\n"


 def _html_row_with_attrs(celltag, unsafe, cell_values, colwidths, colaligns):
+    """Generate an HTML table row with specified attributes for each cell."""
     alignment = {
         "left": "",
         "right": ' style="text-align: right;"',
@@ -171,8 +174,7 @@ def _html_row_with_attrs(celltag, unsafe, cell_values, colwidths, colaligns):
     }
     if unsafe:
         values_with_attrs = [
-            "<{0}{1}>{2}</{0}>".format(celltag, alignment.get(a, ""), c)
-            for c, a in zip(cell_values, colaligns)
+            "<{0}{1}>{2}</{0}>".format(celltag, alignment.get(a, ""), c) for c, a in zip(cell_values, colaligns)
         ]
     else:
         values_with_attrs = [
@@ -186,6 +188,7 @@ def _html_row_with_attrs(celltag, unsafe, cell_values, colwidths, colaligns):


 def _moin_row_with_attrs(celltag, cell_values, colwidths, colaligns, header=""):
+    """Generate a row of MoinMoin table cells with specified attributes like alignment and headers."""
     alignment = {
         "left": "",
         "right": '<style="text-align: right;">',
@@ -193,39 +196,33 @@ def _moin_row_with_attrs(celltag, cell_values, colwidths, colaligns, header=""):
         "decimal": '<style="text-align: right;">',
     }
     values_with_attrs = [
-        "{}{} {} ".format(celltag, alignment.get(a, ""), header + c + header)
-        for c, a in zip(cell_values, colaligns)
+        "{}{} {} ".format(celltag, alignment.get(a, ""), header + c + header) for c,
a in zip(cell_values, colaligns) ] return "".join(values_with_attrs) + "||" def _latex_line_begin_tabular(colwidths, colaligns, booktabs=False, longtable=False): + """Generate LaTeX tabular or longtable environment start with specified column widths, alignments, and booktabs option.""" alignment = {"left": "l", "right": "r", "center": "c", "decimal": "r"} tabular_columns_fmt = "".join([alignment.get(a, "l") for a in colaligns]) return "\n".join( [ - ("\\begin{tabular}{" if not longtable else "\\begin{longtable}{") - + tabular_columns_fmt - + "}", + ("\\begin{tabular}{" if not longtable else "\\begin{longtable}{") + tabular_columns_fmt + "}", "\\toprule" if booktabs else "\\hline", ] ) def _asciidoc_row(is_header, *args): - """handle header and data rows for asciidoc format""" + """Handle header and data rows for asciidoc format.""" def make_header_line(is_header, colwidths, colaligns): # generate the column specifiers alignment = {"left": "<", "right": ">", "center": "^", "decimal": ">"} # use the column widths generated by tabulate for the asciidoc column width specifiers - asciidoc_alignments = zip( - colwidths, [alignment[colalign] for colalign in colaligns] - ) - asciidoc_column_specifiers = [ - "{:d}{}".format(width, align) for width, align in asciidoc_alignments - ] + asciidoc_alignments = zip(colwidths, [alignment[colalign] for colalign in colaligns]) + asciidoc_column_specifiers = ["{:d}{}".format(width, align) for width, align in asciidoc_alignments] header_list = ['cols="' + (",".join(asciidoc_column_specifiers)) + '"'] # generate the list of options (currently only "header") @@ -282,6 +279,7 @@ def make_header_line(is_header, colwidths, colaligns): def _latex_row(cell_values, colwidths, colaligns, escrules=LATEX_ESCAPE_RULES): + """Generates a LaTeX table row with escaped special characters based on provided cell values, column widths, and alignments.""" def escape_char(c): return escrules.get(c, c) @@ -291,6 +289,7 @@ def escape_char(c): def _rst_escape_first_column(rows, headers): + """Escapes empty values in the first column of rows and headers for reStructuredText (RST) formatting.""" def escape_empty(val): if isinstance(val, (str, bytes)) and not val.strip(): return ".." @@ -772,18 +771,16 @@ def escape_empty(val): _ansi_codes_bytes = re.compile(_ansi_escape_pat.encode("utf8"), re.VERBOSE) _ansi_color_reset_code = "\033[0m" -_float_with_thousands_separators = re.compile( - r"^(([+-]?[0-9]{1,3})(?:,([0-9]{3}))*)?(?(1)\.[0-9]*|\.[0-9]+)?$" -) +_float_with_thousands_separators = re.compile(r"^(([+-]?[0-9]{1,3})(?:,([0-9]{3}))*)?(?(1)\.[0-9]*|\.[0-9]+)?$") def simple_separated_format(separator): - """Construct a simple TableFormat with columns separated by a separator. + """ + Construct a simple TableFormat with columns separated by a separator. 
>>> tsv = simple_separated_format("\\t") ; \
        tabulate([["foo", 1], ["spam", 23]], tablefmt=tsv) == 'foo \\t 1\\nspam\\t23'
    True
-
    """
    return TableFormat(
        None,
@@ -831,6 +828,7 @@ def _isnumber_with_thousands_separator(string):


 def _isconvertible(conv, string):
+    """Check if a string can be converted to a specified type without raising a ValueError or TypeError."""
     try:
         conv(string)
         return True
@@ -853,9 +851,7 @@ def _isnumber(string):
     """
     if not _isconvertible(float, string):
         return False
-    elif isinstance(string, (str, bytes)) and (
-        math.isinf(float(string)) or math.isnan(float(string))
-    ):
+    elif isinstance(string, (str, bytes)) and (math.isinf(float(string)) or math.isnan(float(string))):
         return string.lower() in ["inf", "-inf", "nan"]
     return True

@@ -873,9 +869,7 @@ def _isint(string, inttype=int):
         (hasattr(string, "is_integer") or hasattr(string, "__array__"))
         and str(type(string)).startswith("<class 'numpy.int")
     )  # numpy.int64 and similar


 def _isbool(string):
     """
     >>> _isbool(True)
     True
     >>> _isbool("False")
     True
     >>> _isbool(1)
     False
     """
-    return type(string) is bool or (
-        isinstance(string, (bytes, str)) and string in ("True", "False")
-    )
+    return type(string) is bool or (isinstance(string, (bytes, str)) and string in ("True", "False"))


 def _type(string, has_invisible=True, numparse=True):
-    """The least generic type (type(None), int, float, str, unicode).
+    """
+    The least generic type (type(None), int, float, str, unicode).

     >>> _type(None) is type(None)
     True
@@ -906,7 +899,6 @@ def _type(string, has_invisible=True, numparse=True):
     True
     >>> _type('\x1b[31m42\x1b[0m') is type(42)
     True
-
     """

     if has_invisible and isinstance(string, (str, bytes)):
@@ -929,7 +921,8 @@ def _type(string, has_invisible=True, numparse=True):


 def _afterpoint(string):
-    """Symbols after a decimal point, -1 if the string lacks the decimal point.
+    """
+    Symbols after a decimal point, -1 if the string lacks the decimal point.

     >>> _afterpoint("123.45")
     2
@@ -941,7 +934,6 @@ def _afterpoint(string):
     2
     >>> _afterpoint("123,456.78")
     2
-
     """
     if _isnumber(string) or _isnumber_with_thousands_separator(string):
         if _isint(string):
@@ -958,44 +950,46 @@ def _afterpoint(string):


 def _padleft(width, s):
-    """Flush right.
+    """
+    Flush right.

     >>> _padleft(6, '\u044f\u0439\u0446\u0430') == '  \u044f\u0439\u0446\u0430'
     True
-
     """
     fmt = "{0:>%ds}" % width
     return fmt.format(s)


 def _padright(width, s):
-    """Flush left.
+    """
+    Flush left.

     >>> _padright(6, '\u044f\u0439\u0446\u0430') == '\u044f\u0439\u0446\u0430  '
     True
-
     """
     fmt = "{0:<%ds}" % width
     return fmt.format(s)


 def _padboth(width, s):
-    """Center string.
+    """
+    Center string.

     >>> _padboth(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430 '
     True
-
     """
     fmt = "{0:^%ds}" % width
     return fmt.format(s)


 def _padnone(ignore_width, s):
+    """Returns the input string without padding."""
     return s


 def _strip_ansi(s):
-    r"""Remove ANSI escape sequences, both CSI (color codes, etc) and OSC hyperlinks.
+    r"""
+    Remove ANSI escape sequences, both CSI (color codes, etc) and OSC hyperlinks.

     CSI sequences are simply removed from the output, while OSC hyperlinks are replaced
     with the link text. Note: it may be desirable to show the URI instead but this is not
@@ -1006,7 +1000,6 @@

     >>> repr(_strip_ansi('\x1b[31mred\x1b[0m text'))
     "'red text'"
-
     """
     if isinstance(s, str):
         return _ansi_codes.sub(r"\4", s)
@@ -1015,11 +1008,11 @@


 def _visible_width(s):
-    """Visible width of a printed string. ANSI color codes are removed.
+    """
+    Visible width of a printed string. ANSI color codes are removed.
>>> _visible_width('\x1b[31mhello\x1b[0m'), _visible_width("world") (5, 5) - """ # optional wide-character support if wcwidth is not None and WIDE_CHARS_MODE: @@ -1033,6 +1026,7 @@ def _visible_width(s): def _is_multiline(s): + """Check if the input string or bytestring contains multiline ANSI escape codes.""" if isinstance(s, str): return bool(re.search(_multiline_codes, s)) else: # a bytestring @@ -1060,6 +1054,7 @@ def _choose_width_fn(has_invisible, enable_widechars, is_multiline): def _align_column_choose_padfn(strings, alignment, has_invisible): + """Selects the appropriate padding function based on alignment and visibility of invisible characters for given strings.""" if alignment == "right": if not PRESERVE_WHITESPACE: strings = [s.strip() for s in strings] @@ -1086,6 +1081,7 @@ def _align_column_choose_padfn(strings, alignment, has_invisible): def _align_column_choose_width_fn(has_invisible, enable_widechars, is_multiline): + """Choose the appropriate width function for aligning text columns based on visibility, wide characters support, and multiline status.""" if has_invisible: line_width_fn = _visible_width elif enable_widechars: # optional wide-character support if available @@ -1105,6 +1101,7 @@ def _align_column_multiline_width(multiline_s, line_width_fn=len): def _flat_list(nested_list): + """Flatten a nested list into a single list.""" ret = [] for item in nested_list: if isinstance(item, list): @@ -1125,26 +1122,18 @@ def _align_column( ): """[string] -> [padded_string]""" strings, padfn = _align_column_choose_padfn(strings, alignment, has_invisible) - width_fn = _align_column_choose_width_fn( - has_invisible, enable_widechars, is_multiline - ) + width_fn = _align_column_choose_width_fn(has_invisible, enable_widechars, is_multiline) s_widths = list(map(width_fn, strings)) maxwidth = max(max(_flat_list(s_widths)), minwidth) # TODO: refactor column alignment in single-line and multiline modes if is_multiline: if not enable_widechars and not has_invisible: - padded_strings = [ - "\n".join([padfn(maxwidth, s) for s in ms.splitlines()]) - for ms in strings - ] + padded_strings = ["\n".join([padfn(maxwidth, s) for s in ms.splitlines()]) for ms in strings] else: # enable wide-character width corrections s_lens = [[len(s) for s in re.split("[\r\n]", ms)] for ms in strings] - visible_widths = [ - [maxwidth - (w - l) for w, l in zip(mw, ml)] - for mw, ml in zip(s_widths, s_lens) - ] + visible_widths = [[maxwidth - (w - l) for w, l in zip(mw, ml)] for mw, ml in zip(s_widths, s_lens)] # wcswidth and _visible_width don't count invisible characters; # padfn doesn't need to apply another correction padded_strings = [ @@ -1165,6 +1154,7 @@ def _align_column( def _more_generic(type1, type2): + """Return the more generic type between type1 and type2 based on a predefined hierarchy.""" types = { type(None): 0, bool: 1, @@ -1186,7 +1176,8 @@ def _more_generic(type1, type2): def _column_type(strings, has_invisible=True, numparse=True): - """The least generic type all column values are convertible to. + """ + The least generic type all column values are convertible to. 
    >>> _column_type([True, False]) is bool
    True
@@ -1205,14 +1196,14 @@
     >>> import datetime as dt
     >>> _column_type([dt.datetime(1991,2,19), dt.time(17,35)]) is str
     True
-
     """
     types = [_type(s, has_invisible, numparse) for s in strings]
     return reduce(_more_generic, types, bool)
 
 
 def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True):
-    """Format a value according to its type.
+    """
+    Format a value according to its type.
 
     Unicode is supported:
 
@@ -1221,7 +1212,6 @@
        good_result = '\\u0431\\u0443\\u043a\\u0432\\u0430 \\u0446\\u0438\\u0444\\u0440\\u0430\\n------- -------\\n\\u0430\\u0437 2\\n\\u0431\\u0443\\u043a\\u0438 4' ; \
        tabulate(tbl, headers=hrow) == good_result
     True
-
     """  # noqa
     if val is None:
         return missingval
@@ -1247,15 +1237,11 @@
     return f"{val}"
 
 
-def _align_header(
-    header, alignment, width, visible_width, is_multiline=False, width_fn=None
-):
-    "Pad string header to width chars given known visible_width of the header."
+def _align_header(header, alignment, width, visible_width, is_multiline=False, width_fn=None):
+    """Pad string header to width chars given known visible_width of the header."""
     if is_multiline:
         header_lines = re.split(_multiline_codes, header)
-        padded_lines = [
-            _align_header(h, alignment, width, width_fn(h)) for h in header_lines
-        ]
+        padded_lines = [_align_header(h, alignment, width, width_fn(h)) for h in header_lines]
         return "\n".join(padded_lines)
     # else: not multiline
     ninvisible = len(header) - visible_width
@@ -1271,6 +1258,7 @@ def _align_header(
 
 
 def _remove_separating_lines(rows):
+    """Remove rows that are separating lines and return the filtered rows along with indices of separating lines if input is a list."""
     if type(rows) == list:
         separating_lines = []
         sans_rows = []
@@ -1285,6 +1273,7 @@ def _remove_separating_lines(rows):
 
 
 def _reinsert_separating_lines(rows, separating_lines):
+    """Reinserts separating lines back into their original positions in the rows."""
     if separating_lines:
         for index in separating_lines:
             rows.insert(index, SEPARATING_LINE)
@@ -1311,6 +1300,6 @@ def _prepend_row_index(rows, index):
 
 
 def _bool(val):
-    "A wrapper around standard bool() which doesn't throw on NumPy arrays"
+    """Convert a value to a boolean without throwing an exception on NumPy arrays."""
     try:
         return bool(val)
@@ -1319,7 +1309,8 @@ def _bool(val):
 
 
 def _normalize_tabular_data(tabular_data, headers, showindex="default"):
-    """Transform a supported data type to a list of lists, and a list of headers, with headers padding.
+    """
+    Transform a supported data type to a list of lists, and a list of headers, with headers padding.
 
     Supported tabular data types:
 
@@ -1348,7 +1339,6 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
     If showindex="always", show row indices for all types of data.
     If showindex="never", don't show row indices for all types of data.
     If showindex is an iterable, show its values as row indices.
- """ try: @@ -1364,16 +1354,11 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): if hasattr(tabular_data.values, "__call__"): # likely a conventional dict keys = tabular_data.keys() - rows = list( - izip_longest(*tabular_data.values()) - ) # columns have to be transposed + rows = list(izip_longest(*tabular_data.values())) # columns have to be transposed elif hasattr(tabular_data, "index"): # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0) keys = list(tabular_data) - if ( - showindex in ["default", "always", True] - and tabular_data.index.name is not None - ): + if showindex in ["default", "always", True] and tabular_data.index.name is not None: if isinstance(tabular_data.index.name, list): keys[:0] = tabular_data.index.name else: @@ -1394,19 +1379,10 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): if headers == "keys" and not rows: # an empty table (issue #81) headers = [] - elif ( - headers == "keys" - and hasattr(tabular_data, "dtype") - and getattr(tabular_data.dtype, "names") - ): + elif headers == "keys" and hasattr(tabular_data, "dtype") and getattr(tabular_data.dtype, "names"): # numpy record array headers = tabular_data.dtype.names - elif ( - headers == "keys" - and len(rows) > 0 - and isinstance(rows[0], tuple) - and hasattr(rows[0], "_fields") - ): + elif headers == "keys" and len(rows) > 0 and isinstance(rows[0], tuple) and hasattr(rows[0], "_fields"): # namedtuple headers = list(map(str, rows[0]._fields)) elif len(rows) > 0 and hasattr(rows[0], "keys") and hasattr(rows[0], "values"): @@ -1437,9 +1413,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): else: headers = [] elif headers: - raise ValueError( - "headers for a list of dicts is not a dict or a keyword" - ) + raise ValueError("headers for a list of dicts is not a dict or a keyword") rows = [[row.get(k) for k in keys] for row in rows] elif ( @@ -1452,11 +1426,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): # print tabulate(cursor, headers='keys') headers = [column[0] for column in tabular_data.description] - elif ( - dataclasses is not None - and len(rows) > 0 - and dataclasses.is_dataclass(rows[0]) - ): + elif dataclasses is not None and len(rows) > 0 and dataclasses.is_dataclass(rows[0]): # Python 3.7+'s dataclass field_names = [field.name for field in dataclasses.fields(rows[0])] if headers == "keys": @@ -1508,6 +1478,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): def _wrap_text_to_colwidths(list_of_lists, colwidths, numparses=True): + """Wrap text in each cell of a list of lists to fit specified column widths, optionally parsing numbers.""" if len(list_of_lists): num_cols = len(list_of_lists[0]) else: @@ -1528,14 +1499,8 @@ def _wrap_text_to_colwidths(list_of_lists, colwidths, numparses=True): # Cast based on our internal type handling # Any future custom formatting of types (such as datetimes) # may need to be more explicit than just `str` of the object - casted_cell = ( - str(cell) if _isnumber(cell) else _type(cell, False, numparse)(cell) - ) - wrapped = [ - "\n".join(wrapper.wrap(line)) - for line in casted_cell.splitlines() - if line.strip() != "" - ] + casted_cell = str(cell) if _isnumber(cell) else _type(cell, False, numparse)(cell) + wrapped = ["\n".join(wrapper.wrap(line)) for line in casted_cell.splitlines() if line.strip() != ""] new_row.append("\n".join(wrapped)) else: new_row.append(cell) @@ -1588,7 +1553,8 @@ def tabulate( 
rowalign=None, maxheadercolwidths=None, ): - """Format a fixed width table for pretty printing. + """ + Format a fixed width table for pretty printing. >>> print(tabulate([[1, 2.34], [-56, "8.999"], ["2", "10001"]])) --- --------- @@ -2079,15 +2045,12 @@ def tabulate( +------------+------------+-------------------------------+ Header column width can be specified in a similar way using `maxheadercolwidth` - """ if tabular_data is None: tabular_data = [] - list_of_lists, headers, headers_pad = _normalize_tabular_data( - tabular_data, headers, showindex=showindex - ) + list_of_lists, headers, headers_pad = _normalize_tabular_data(tabular_data, headers, showindex=showindex) list_of_lists, separating_lines = _remove_separating_lines(list_of_lists) if maxcolwidths is not None: @@ -2101,23 +2064,17 @@ def tabulate( maxcolwidths = _expand_iterable(maxcolwidths, num_cols, None) numparses = _expand_numparse(disable_numparse, num_cols) - list_of_lists = _wrap_text_to_colwidths( - list_of_lists, maxcolwidths, numparses=numparses - ) + list_of_lists = _wrap_text_to_colwidths(list_of_lists, maxcolwidths, numparses=numparses) if maxheadercolwidths is not None: num_cols = len(list_of_lists[0]) if isinstance(maxheadercolwidths, int): # Expand scalar for all columns - maxheadercolwidths = _expand_iterable( - maxheadercolwidths, num_cols, maxheadercolwidths - ) + maxheadercolwidths = _expand_iterable(maxheadercolwidths, num_cols, maxheadercolwidths) else: # Ignore col width for any 'trailing' columns maxheadercolwidths = _expand_iterable(maxheadercolwidths, num_cols, None) numparses = _expand_numparse(disable_numparse, num_cols) - headers = _wrap_text_to_colwidths( - [headers], maxheadercolwidths, numparses=numparses - )[0] + headers = _wrap_text_to_colwidths([headers], maxheadercolwidths, numparses=numparses)[0] # empty values in the first column of RST tables should be escaped (issue #82) # "" should be escaped as "\\ " or ".." 
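The hunks above only reflow how `tabulate()` expands `maxcolwidths` and `maxheadercolwidths` before wrapping cell text; behavior is unchanged. For readers skimming the diff, here is a minimal sketch of what that wrapping does, assuming the vendored module is importable as `onnxslim.utils.tabulate` (as the imports in `onnxslim/utils/utils.py` further down suggest); the sample rows are illustrative only:

```python
from onnxslim.utils.tabulate import tabulate

rows = [["a-rather-long-first-cell", 3.14159], ["short", 2.71828]]

# A scalar width is expanded to one entry per column; a list may leave
# trailing columns unwrapped (None). Header widths are expanded the same way.
print(
    tabulate(
        rows,
        headers=["description", "value"],
        tablefmt="grid",
        maxcolwidths=[12, None],  # wrap column 1 at 12 chars, leave column 2 alone
        maxheadercolwidths=8,     # cap every header column at 8 chars
        floatfmt=".3f",
    )
)
```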
@@ -2156,11 +2113,7 @@ def tabulate( has_invisible = _ansi_codes.search(plain_text) is not None enable_widechars = wcwidth is not None and WIDE_CHARS_MODE - if ( - not isinstance(tablefmt, TableFormat) - and tablefmt in multiline_formats - and _is_multiline(plain_text) - ): + if not isinstance(tablefmt, TableFormat) and tablefmt in multiline_formats and _is_multiline(plain_text): tablefmt = multiline_formats.get(tablefmt, tablefmt) is_multiline = True else: @@ -2172,17 +2125,13 @@ def tabulate( numparses = _expand_numparse(disable_numparse, len(cols)) coltypes = [_column_type(col, numparse=np) for col, np in zip(cols, numparses)] if isinstance(floatfmt, str): # old version - float_formats = len(cols) * [ - floatfmt - ] # just duplicate the string to use in each column + float_formats = len(cols) * [floatfmt] # just duplicate the string to use in each column else: # if floatfmt is list, tuple etc we have one per column float_formats = list(floatfmt) if len(float_formats) < len(cols): float_formats.extend((len(cols) - len(float_formats)) * [_DEFAULT_FLOATFMT]) if isinstance(intfmt, str): # old version - int_formats = len(cols) * [ - intfmt - ] # just duplicate the string to use in each column + int_formats = len(cols) * [intfmt] # just duplicate the string to use in each column else: # if intfmt is list, tuple etc we have one per column int_formats = list(intfmt) if len(int_formats) < len(cols): @@ -2195,9 +2144,7 @@ def tabulate( missing_vals.extend((len(cols) - len(missing_vals)) * [_DEFAULT_MISSINGVAL]) cols = [ [_format(v, ct, fl_fmt, int_fmt, miss_v, has_invisible) for v in c] - for c, ct, fl_fmt, int_fmt, miss_v in zip( - cols, coltypes, float_formats, int_formats, missing_vals - ) + for c, ct, fl_fmt, int_fmt, miss_v in zip(cols, coltypes, float_formats, int_formats, missing_vals) ] # align columns @@ -2206,7 +2153,7 @@ def tabulate( aligns = [colglobalalign] * len(cols) else: # default aligns = [numalign if ct in [int, float] else stralign for ct in coltypes] - # then specific alignements + # then specific alignments if colalign is not None: assert isinstance(colalign, Iterable) if isinstance(colalign, str): @@ -2219,9 +2166,7 @@ def tabulate( break elif align != "global": aligns[idx] = align - minwidths = ( - [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols) - ) + minwidths = [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols) cols = [ _align_column(c, a, minw, has_invisible, enable_widechars, is_multiline) for c, a, minw in zip(cols, aligns, minwidths) @@ -2236,7 +2181,7 @@ def tabulate( aligns_headers = [headersglobalalign] * len(t_cols) else: # default aligns_headers = aligns or [stralign] * len(headers) - # then specific header alignements + # then specific header alignments if headersalign is not None: assert isinstance(headersalign, Iterable) if isinstance(headersalign, str): @@ -2252,10 +2197,7 @@ def tabulate( aligns_headers[hidx] = aligns[hidx] elif align != "global": aligns_headers[hidx] = align - minwidths = [ - max(minw, max(width_fn(cl) for cl in c)) - for minw, c in zip(minwidths, t_cols) - ] + minwidths = [max(minw, max(width_fn(cl) for cl in c)) for minw, c in zip(minwidths, t_cols)] headers = [ _align_header(h, a, minw, width_fn(h), is_multiline, width_fn) for h, a, minw in zip(headers, aligns_headers, minwidths) @@ -2286,8 +2228,9 @@ def tabulate( def _expand_numparse(disable_numparse, column_count): """ - Return a list of bools of length `column_count` which indicates whether - number parsing should be used on 
each column.
+    Return a list of bools of length `column_count` which indicates whether number
+    parsing should be used on each column.
+
     If `disable_numparse` is a list of indices, each of those indices are False,
     and everything else is True.
     If `disable_numparse` is a bool, then the returned list is all the same.
@@ -2303,8 +2246,9 @@ def _expand_iterable(original, num_desired, default):
     """
-    Expands the `original` argument to return a return a list of
-    length `num_desired`. If `original` is shorter than `num_desired`, it will
+    Expands the `original` argument to return a list of length `num_desired`.
+
+    If `original` is shorter than `num_desired`, it will
     be padded with the value in `default`.
     If `original` is not a list to begin with (i.e. scalar value) a list of
-    length `num_desired` completely populated with `default will be returned
+    length `num_desired` completely populated with `default` will be returned
@@ -2316,6 +2260,7 @@ def _expand_iterable(original, num_desired, default):
 
 
 def _pad_row(cells, padding):
+    """Pads the strings in a list `cells` with spaces of length `padding` on both sides."""
     if cells:
         pad = " " * padding
         padded_cells = [pad + cell + pad for cell in cells]
@@ -2325,12 +2270,12 @@ def _pad_row(cells, padding):
 
 
 def _build_simple_row(padded_cells, rowfmt):
-    "Format row according to DataRow format without padding."
+    """Format a list of padded cells into a table row according to the specified DataRow format, without padding."""
     begin, sep, end = rowfmt
     return (begin + sep.join(padded_cells) + end).rstrip()
 
 
 def _build_row(padded_cells, colwidths, colaligns, rowfmt):
-    "Return a string which represents a row of data cells."
+    """Format a list of padded cells into a table row according to specified format or custom row formatting function."""
     if not rowfmt:
         return None
@@ -2341,12 +2288,13 @@ def _build_row(padded_cells, colwidths, colaligns, rowfmt):
 
 
 def _append_basic_row(lines, padded_cells, colwidths, colaligns, rowfmt, rowalign=None):
-    # NOTE: rowalign is ignored and exists for api compatibility with _append_multiline_row
+    """Append a formatted row to lines; rowalign is ignored and exists for API compatibility with _append_multiline_row."""
     lines.append(_build_row(padded_cells, colwidths, colaligns, rowfmt))
     return lines
 
 
 def _align_cell_veritically(text_lines, num_lines, column_width, row_alignment):
+    """Adjust vertical alignment of text lines within a column based on the specified row alignment."""
     delta_lines = num_lines - len(text_lines)
     blank = [" " * column_width]
     if row_alignment == "bottom":
@@ -2359,9 +2307,8 @@ def _align_cell_veritically(text_lines, num_lines, column_width, row_alignment):
     return text_lines + blank * delta_lines
 
 
-def _append_multiline_row(
-    lines, padded_multiline_cells, padded_widths, colaligns, rowfmt, pad, rowalign=None
-):
+def _append_multiline_row(lines, padded_multiline_cells, padded_widths, colaligns, rowfmt, pad, rowalign=None):
+    """Append a multiline row to the table lines with specified alignments and padding."""
     colwidths = [w - 2 * pad for w in padded_widths]
     cells_lines = [c.splitlines() for c in padded_multiline_cells]
     nlines = max(map(len, cells_lines))  # number of lines in the row
@@ -2370,10 +2317,7 @@ def _append_multiline_row(
     #     (cl + [" " * w] * (nlines - len(cl))) for cl, w in zip(cells_lines, colwidths)
     # ]
-    cells_lines = [
-        _align_cell_veritically(cl, nlines, w, rowalign)
-        for cl, w in zip(cells_lines, colwidths)
-    ]
+    cells_lines = [_align_cell_veritically(cl, nlines, w, rowalign) for cl, w in zip(cells_lines, colwidths)]
     lines_cells = [[cl[i] for cl in cells_lines] for i in range(nlines)]
     for ln in lines_cells:
         padded_ln = _pad_row(ln, pad)
@@ -2382,6 +2326,6 @@
 
 
 def _build_line(colwidths, colaligns, linefmt):
-    "Return a string which represents a horizontal line."
+    """Return a string representing a horizontal line formatted with column widths and alignments using the specified format."""
     if not linefmt:
         return None
@@ -2394,26 +2339,25 @@ def _build_line(colwidths, colaligns, linefmt):
 
 
 def _append_line(lines, colwidths, colaligns, linefmt):
+    """Append a formatted line to the list of lines based on column widths, alignments, and line format."""
     lines.append(_build_line(colwidths, colaligns, linefmt))
     return lines
 
 
 class JupyterHTMLStr(str):
-    """Wrap the string with a _repr_html_ method so that Jupyter
-    displays the HTML table"""
+    """Wrap the string with a _repr_html_ method so that Jupyter displays the HTML table."""
 
     def _repr_html_(self):
+        """Return the HTML representation of the JupyterHTMLStr object for proper display in Jupyter Notebooks."""
         return self
 
     @property
     def str(self):
-        """add a .str property so that the raw string is still accessible"""
+        """Add a .str property so that the raw string is still accessible."""
        return self
 
 
-def _format_table(
-    fmt, headers, headersaligns, rows, colwidths, colaligns, is_multiline, rowaligns
-):
+def _format_table(fmt, headers, headersaligns, rows, colwidths, colaligns, is_multiline, rowaligns):
     """Produce a plain-text representation of the table."""
     lines = []
     hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else []
@@ -2442,9 +2386,7 @@
     if padded_rows and fmt.linebetweenrows and "linebetweenrows" not in hidden:
         # initial rows with a line below
         for row, ralign in zip(padded_rows[:-1], rowaligns):
-            append_row(
-                lines, row, padded_widths, colaligns, fmt.datarow, rowalign=ralign
-            )
+            append_row(lines, row, padded_widths, colaligns, fmt.datarow, rowalign=ralign)
             _append_line(lines, padded_widths, colaligns, fmt.linebetweenrows)
         # the last row without a line below
         append_row(
@@ -2457,11 +2399,7 @@
         )
     else:
         separating_line = (
-            fmt.linebetweenrows
-            or fmt.linebelowheader
-            or fmt.linebelow
-            or fmt.lineabove
-            or Line("", "", "", "")
+            fmt.linebetweenrows or fmt.linebelowheader or fmt.linebelow or fmt.lineabove or Line("", "", "", "")
         )
         for row in padded_rows:
             # test to see if either the 1st column or the 2nd column (account for showindex) has
@@ -2485,7 +2423,9 @@
 
 
 class _CustomTextWrap(textwrap.TextWrapper):
-    """A custom implementation of CPython's textwrap.TextWrapper. This supports
-    both wide characters (Korea, Japanese, Chinese) - including mixed string.
+    """
+    A custom implementation of CPython's textwrap.TextWrapper.
+
+    This supports both wide characters (Korean, Japanese, Chinese), including mixed strings.
     For the most part, the `_handle_long_word` and `_wrap_chunks` functions were
     copy pasted out of the CPython baseline, and updated with our custom length
@@ -2493,14 +2434,14 @@ class _CustomTextWrap(textwrap.TextWrapper):
     """
 
     def __init__(self, *args, **kwargs):
+        """Initialize the wrapper with support for wide characters and custom length logic."""
         self._active_codes = []
         self.max_lines = None  # For python2 compatibility
         textwrap.TextWrapper.__init__(self, *args, **kwargs)
 
     @staticmethod
     def _len(item):
-        """Custom len that gets console column width for wide
-        and non-wide characters as well as ignores color codes"""
+        """Custom len that gets console column width for wide and non-wide characters as well as ignores color codes."""
         stripped = _strip_ansi(item)
         if wcwidth:
             return wcwidth.wcswidth(stripped)
@@ -2508,15 +2449,12 @@ def _len(item):
             return len(stripped)
 
     def _update_lines(self, lines, new_line):
-        """Adds a new line to the list of lines the text is being wrapped into
-        This function will also track any ANSI color codes in this string as well
-        as add any colors from previous lines order to preserve the same formatting
+        """Adds a new line to the list of lines the text is being wrapped into. This function will also track any ANSI
+        color codes in this string, as well as add any colors from previous lines, in order to preserve the same formatting
         as a single unwrapped string.
         """
         code_matches = [x for x in _ansi_codes.finditer(new_line)]
-        color_codes = [
-            code.string[code.span()[0] : code.span()[1]] for code in code_matches
-        ]
+        color_codes = [code.string[code.span()[0] : code.span()[1]] for code in code_matches]
 
         # Add color codes from earlier in the unwrapped line, and then track any new ones we add.
         new_line = "".join(self._active_codes) + new_line
@@ -2527,7 +2465,7 @@ def _update_lines(self, lines, new_line):
         else:
             # A single reset code resets everything
             self._active_codes = []
-        # Always ensure each line is color terminted if any colors are
+        # Always ensure each line is color terminated if any colors are
         # still active, otherwise colors will bleed into other cells on the console
         if len(self._active_codes) > 0:
             new_line = new_line + _ansi_color_reset_code
@@ -2573,9 +2511,11 @@ def _handle_long_word(self, reversed_chunks, cur_line, cur_len, width):
         # devoted to the long word that we can't handle right now.
 
     def _wrap_chunks(self, chunks):
-        """_wrap_chunks(chunks : [string]) -> [string]
-        Wrap a sequence of text chunks and return a list of lines of
-        length 'self.width' or less. (If 'break_long_words' is false,
+        """
+        _wrap_chunks(chunks : [string]) -> [string]
+
+        Wrap a sequence of text chunks and return a list of lines of length
+        'self.width' or less. (If 'break_long_words' is false,
         some lines may be longer than this.)
Chunks correspond roughly to words and the whitespace between them: each chunk is indivisible (modulo 'break_long_words'), but a line break can @@ -2646,12 +2586,7 @@ def _wrap_chunks(self, chunks): if ( self.max_lines is None or len(lines) + 1 < self.max_lines - or ( - not chunks - or self.drop_whitespace - and len(chunks) == 1 - and not chunks[0].strip() - ) + or (not chunks or self.drop_whitespace and len(chunks) == 1 and not chunks[0].strip()) and cur_len <= width ): # Convert current line back to a string and store it in @@ -2659,10 +2594,7 @@ def _wrap_chunks(self, chunks): self._update_lines(lines, indent + "".join(cur_line)) else: while cur_line: - if ( - cur_line[-1].strip() - and cur_len + self._len(self.placeholder) <= width - ): + if cur_line[-1].strip() and cur_len + self._len(self.placeholder) <= width: cur_line.append(self.placeholder) self._update_lines(lines, indent + "".join(cur_line)) break @@ -2671,10 +2603,7 @@ def _wrap_chunks(self, chunks): else: if lines: prev_line = lines[-1].rstrip() - if ( - self._len(prev_line) + self._len(self.placeholder) - <= self.width - ): + if self._len(prev_line) + self._len(self.placeholder) <= self.width: lines[-1] = prev_line + self.placeholder break self._update_lines(lines, indent + self.placeholder.lstrip()) @@ -2782,6 +2711,7 @@ def _main(): def _pprint_file(fobject, headers, tablefmt, sep, floatfmt, intfmt, file, colalign): + """Pretty prints a tabulated version of a file-like object's content using specified formatting parameters.""" rows = fobject.readlines() table = [re.split(sep, r.rstrip()) for r in rows if r.strip()] print( diff --git a/onnxslim/utils/utils.py b/onnxslim/utils/utils.py index 4121763..46994f2 100644 --- a/onnxslim/utils/utils.py +++ b/onnxslim/utils/utils.py @@ -1,24 +1,18 @@ +import logging from typing import Dict, List, Optional, Tuple, Union import numpy as np - import onnx from ..utils.font import GREEN, WHITE from ..utils.tabulate import SEPARATING_LINE, tabulate - -import logging - # Configure logging logging.basicConfig( level=logging.ERROR, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - handlers=[ - logging.FileHandler("app.log"), - logging.StreamHandler() - ] + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.FileHandler("app.log"), logging.StreamHandler()], ) # Create a logger @@ -96,10 +90,7 @@ def gen_onnxruntime_input_data( if "data" in info: input_data_dict[name] = info["data"] else: - shapes = [ - shape if (shape != -1 and not isinstance(shape, str)) else 1 - for shape in info["shape"] - ] + shapes = [shape if (shape != -1 and not isinstance(shape, str)) else 1 for shape in info["shape"]] shapes = shapes if shapes else [1] dtype = info["dtype"] @@ -112,14 +103,10 @@ def gen_onnxruntime_input_data( return input_data_dict -def onnxruntime_inference( - model: onnx.ModelProto, input_data: dict -) -> Dict[str, np.array]: +def onnxruntime_inference(model: onnx.ModelProto, input_data: dict) -> Dict[str, np.array]: import onnxruntime as rt - sess = rt.InferenceSession( - model.SerializeToString(), providers=["CPUExecutionProvider"] - ) + sess = rt.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) onnx_output = sess.run(None, input_data) output_names = [output.name for output in sess.get_outputs()] @@ -128,12 +115,9 @@ def onnxruntime_inference( return onnx_output -def print_model_info_as_table( - model_name: str, model_info_list: List[Dict], elapsed_time: 
float = 0.0
-):
-    assert (
-        len(model_info_list) > 0
-    ), "model_info_list must contain more than one model info"
+def print_model_info_as_table(model_name: str, model_info_list: List[Dict], elapsed_time: float = 0.0):
+    """Prints detailed model information as a formatted table."""
+    assert len(model_info_list) > 0, "model_info_list must contain at least one model info"
 
     final_op_info = []
     if len(model_info_list) == 1:
@@ -142,15 +126,11 @@
         final_op_info.append(["Op Set ", model_info_list[0]["op_set"]])
     else:
         final_op_info.append(
-            ["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"]]
-            + [""] * (len(model_info_list) - 2)
+            ["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"]] + [""] * (len(model_info_list) - 2)
         )
         final_op_info.append([SEPARATING_LINE])
-    final_op_info.append(
-        ["Model Info", "Original Model"]
-        + ["Slimmed Model"] * (len(model_info_list) - 1)
-    )
+    final_op_info.append(["Model Info", "Original Model"] + ["Slimmed Model"] * (len(model_info_list) - 1))
     final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
     all_inputs = list(model_info_list[0]["op_input_info"].keys())
@@ -164,11 +144,7 @@
             input_info_list.append(inputs_shape)
         final_op_info.append(input_info_list)
-    all_outputs = set(
-        op_type
-        for model_info in model_info_list
-        for op_type in model_info.get("op_output_info", {})
-    )
+    all_outputs = set(op_type for model_info in model_info_list for op_type in model_info.get("op_output_info", {}))
     for outputs in all_outputs:
         output_info_list = [
@@ -181,11 +157,7 @@
     final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
-    all_ops = set(
-        op_type
-        for model_info in model_info_list
-        for op_type in model_info.get("op_type_counts", {})
-    )
+    all_ops = set(op_type for model_info in model_info_list for op_type in model_info.get("op_type_counts", {}))
     sorted_ops = list(all_ops)
     sorted_ops.sort()
     for op in sorted_ops:
@@ -200,10 +172,7 @@
         final_op_info.append(op_info_list)
     final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
-    final_op_info.append(
-        ["Model Size"]
-        + [format_bytes(model_info["model_size"]) for model_info in model_info_list]
-    )
+    final_op_info.append(["Model Size"] + [format_bytes(model_info["model_size"]) for model_info in model_info_list])
     final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
     final_op_info.append(["Elapsed Time"] + [f"{elapsed_time:.2f} s"])
     lines = tabulate(
@@ -214,11 +183,7 @@
     ).split("\n")
     time_row = lines[-2].split("|")
-    time_row[-3] = (
-        time_row[-2][: len(time_row[-2]) // 2 + 1]
-        + time_row[-3]
-        + time_row[-2][len(time_row[-2]) // 2 :]
-    )
+    time_row[-3] = time_row[-2][: len(time_row[-2]) // 2 + 1] + time_row[-3] + time_row[-2][len(time_row[-2]) // 2 :]
     time_row.pop(-2)
     lines[-2] = "|".join(time_row)
     output = "\n".join([line if line != "| \x01 |" else lines[0] for line in lines])
@@ -227,6 +192,7 @@
 
 
 def dump_model_info_to_disk(model_name: str, model_info: Dict):
+    """Dumps detailed model information to a CSV file for a given model and its associated operation information."""
     import csv
     import os
diff --git a/tests/test_folder.py b/tests/test_folder.py
index 8667838..1bee0ee 100644
--- a/tests/test_folder.py
+++ b/tests/test_folder.py
@@ -7,6 +7,7 @@
 
 
 def parse_arguments():
+    """Parses command-line arguments for specifying the ONNX model directory."""
     parser = argparse.ArgumentParser(description="Test script for ONNX models")
     parser.add_argument(
         "--model-dir",
@@ -22,10 +23,12 @@ def parse_arguments():
 
 
 @pytest.fixture(params=glob.glob(f"{args.model_dir}/*/*.onnx"))
 def model_file(request):
+    """Yields ONNX model file paths from the specified directory for parameterized testing."""
     yield request.param
 
 
 def test_model_file(model_file):
+    """Tests the slimming of an ONNX model file using the onnxslim command, validates success, and cleans up generated files."""
     slim_model_file = model_file.replace(".onnx", "_slim.onnx")
     command = f"onnxslim {model_file} {slim_model_file}"
     result = subprocess.run(command, shell=True, capture_output=True, text=True)
diff --git a/tests/test_onnx_nets.py b/tests/test_onnx_nets.py
index 63d9f2a..524f5ec 100644
--- a/tests/test_onnx_nets.py
+++ b/tests/test_onnx_nets.py
@@ -7,7 +7,6 @@
 
 import torch
 import torchvision.models as models
-
 FUSE = True
 PRETRAINED = False
@@ -25,6 +24,7 @@ class TestTorchVisionClass:
         ),
     )
     def test_torchvision(self, request, model, shape=(1, 3, 224, 224)):
+        """Test various TorchVision models with random input tensors of a specified shape."""
         model = model(pretrained=PRETRAINED)
         x = torch.rand(shape)
         os.makedirs("tmp/" + request.node.name, exist_ok=True)
@@ -47,9 +47,11 @@ def test_torchvision(self, request, model, shape=(1, 3, 224, 224)):
 
 
 class TestTimmClass:
     @pytest.fixture(params=timm.list_models())
     def model_name(self, request):
+        """Yields names of models available in TIMM (https://github.com/rwightman/pytorch-image-models) for pytest fixture parameterization."""
         yield request.param
 
     def test_timm(self, request, model_name):
+        """Tests a TIMM model's forward pass with a random input tensor of the appropriate size."""
         model = timm.create_model(model_name, pretrained=PRETRAINED)
         input_size = model.default_cfg.get("input_size")
         x = torch.randn((1,) + input_size)
diff --git a/tests/test_onnxslim.py b/tests/test_onnxslim.py
index c5f1f77..10e5d73 100644
--- a/tests/test_onnxslim.py
+++ b/tests/test_onnxslim.py
@@ -1,7 +1,6 @@
 import subprocess
 
 import pytest
-
 from utils import download_onnx_from_url
 
 
@@ -20,9 +19,8 @@
 )
 class TestOnnxModel:
     def test_onnx_model(self, request, name):
-        filename = download_onnx_from_url(
-            f"http://120.224.26.32:15030/aifarm/onnx/{name}.onnx"
-        )
+        """Test downloading an ONNX model by its name and running the 'onnxslim' command to slim the model."""
+        filename = download_onnx_from_url(f"http://120.224.26.32:15030/aifarm/onnx/{name}.onnx")
         command = f"onnxslim {filename} {name}_slim.onnx"
         result = subprocess.run(command, shell=True, capture_output=True, text=True)
         output = result.stderr.strip()
@@ -31,24 +29,21 @@ class TestOnnxModel:
         assert result.returncode == 0
 
     def test_onnxslim_python_api(self, request, name):
+        """Tests the ONNX model slimming Python API by calling slim() on a model downloaded for the given name."""
         import onnx
+
         from onnxslim import slim
 
-        filename = download_onnx_from_url(
-            f"http://120.224.26.32:15030/aifarm/onnx/{name}.onnx"
-        )
+        filename = download_onnx_from_url(f"http://120.224.26.32:15030/aifarm/onnx/{name}.onnx")
         model_slim = slim(filename)
         onnx.save(model_slim, f"{name}_slim.onnx")
 
 
 class TestFeat:
     def test_input_shape_modification(self, request):
-        filename = download_onnx_from_url(
-            f"http://120.224.26.32:15030/aifarm/onnx/UNetModel-fp16.onnx"
-        )
-        command = (
-            f"onnxslim {filename} UNetModel-fp16_slim.onnx --input_shapes cc:1,1,768"
-        )
+        """Test the modification of input shapes for a UNet model using the onnxslim
command.""" + filename = download_onnx_from_url(f"http://120.224.26.32:15030/aifarm/onnx/UNetModel-fp16.onnx") + command = f"onnxslim {filename} UNetModel-fp16_slim.onnx --input_shapes cc:1,1,768" result = subprocess.run(command, shell=True, capture_output=True, text=True) output = result.stderr.strip() # Assert the expected return code @@ -56,9 +51,8 @@ def test_input_shape_modification(self, request): assert result.returncode == 0 def test_fp162fp32_conversion(self, request): - filename = download_onnx_from_url( - f"http://120.224.26.32:15030/aifarm/onnx/UNetModel-fp16.onnx" - ) + """Test the conversion of an ONNX model from FP16 to FP32 using the onnxslim tool with specified input shapes.""" + filename = download_onnx_from_url(f"http://120.224.26.32:15030/aifarm/onnx/UNetModel-fp16.onnx") command = f"onnxslim {filename} UNetModel-fp16_slim.onnx --input_shapes cc:1,1,768 --dtype fp32" result = subprocess.run(command, shell=True, capture_output=True, text=True) output = result.stderr.strip() @@ -67,9 +61,8 @@ def test_fp162fp32_conversion(self, request): assert result.returncode == 0 def test_output_modification(self, request): - filename = download_onnx_from_url( - f"http://120.224.26.32:15030/aifarm/onnx/yolov5m.onnx" - ) + """Tests output modification of an ONNX model by running a slimming command and checking for successful execution.""" + filename = download_onnx_from_url(f"http://120.224.26.32:15030/aifarm/onnx/yolov5m.onnx") command = f"onnxslim {filename} yolov5m_slim.onnx --outputs 591 739 443" result = subprocess.run(command, shell=True, capture_output=True, text=True) output = result.stderr.strip() diff --git a/tests/utils.py b/tests/utils.py index 5902f3f..a676e82 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,6 @@ import tempfile import warnings import zipfile - from urllib.error import HTTPError from urllib.parse import urlparse # noqa: F401 from urllib.request import Request, urlopen @@ -39,6 +38,7 @@ def __init__( # ignore unit, unit_scale, unit_divisor; they're just for real tqdm def update(self, n): + """Updates the progress by incrementing the counter 'n' by a specified amount.""" if self.disable: return @@ -46,18 +46,19 @@ def update(self, n): if self.total is None: sys.stderr.write("\r{0:.1f} bytes".format(self.n)) else: - sys.stderr.write( - "\r{0:.1f}%".format(100 * self.n / float(self.total)) - ) + sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) sys.stderr.flush() def close(self): + """Disables the progress indicator.""" self.disable = True def __enter__(self): + """Enters the runtime context related to this object.""" return self def __exit__(self, exc_type, exc_val, exc_tb): + """Exits the runtime context related to this object, cleaning up if the progress indicator is enabled.""" if self.disable: return @@ -79,6 +80,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Copied from tools/shared/module_loader to be included in torch package def import_module(name, path): + """Dynamically import a module given its name and file path.""" import importlib.util from importlib.abc import Loader @@ -90,6 +92,7 @@ def import_module(name, path): def _remove_if_exists(path): + """Remove a file or directory if it exists at the specified path.""" if os.path.exists(path): if os.path.isfile(path): os.remove(path) @@ -98,19 +101,19 @@ def _remove_if_exists(path): def _git_archive_link(repo_owner, repo_name, branch): - return "https://github.com/{}/{}/archive/{}.zip".format( - repo_owner, repo_name, branch - ) + """Generate a GitHub archive 
link for a specific repository owner, name, and branch.""" + return "https://github.com/{}/{}/archive/{}.zip".format(repo_owner, repo_name, branch) def _load_attr_from_module(module, func_name): - # Check if callable is defined in the module + """Load an attribute by name from a module if it exists.""" if func_name not in dir(module): return None return getattr(module, func_name) def _get_torch_home(): + """Get the directory path where Torch caches data.""" torch_home = os.path.expanduser( os.getenv( ENV_TORCH_HOME, @@ -121,6 +124,7 @@ def _get_torch_home(): def _parse_repo_info(github): + """Parse the GitHub repository information and determine the default branch if not specified.""" if ":" in github: repo_info, branch = github.split(":") else: @@ -143,12 +147,13 @@ def _parse_repo_info(github): def _read_url(url): + """Fetches and decodes the content from a specified URL.""" with urlopen(url) as r: return r.read().decode(r.headers.get_content_charset("utf-8")) def _validate_not_a_forked_repo(repo_owner, repo_name, branch): - # Use urlopen to avoid depending on local git. + """Ensures the specified branch exists in the given GitHub repository and is not from a forked repository.""" headers = {"Accept": "application/vnd.github.v3+json"} token = os.environ.get(ENV_GITHUB_TOKEN) if token is not None: @@ -176,7 +181,7 @@ def _validate_not_a_forked_repo(repo_owner, repo_name, branch): def _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=False): - # Setup hub_dir to save downloaded files + """Retrieve cached repository or reload it from GitHub if necessary.""" hub_dir = get_dir() if not os.path.exists(hub_dir): os.makedirs(hub_dir) @@ -225,23 +230,24 @@ def _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=Fal def _check_module_exists(name): + """Check if a module exists by name using importlib.util.find_spec().""" import importlib.util return importlib.util.find_spec(name) is not None def _check_dependencies(m): + """Verify that all dependencies defined in the specified module are installed, raising a RuntimeError if any are missing.""" dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) if dependencies is not None: missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] if len(missing_deps): - raise RuntimeError( - "Missing dependencies: {}".format(", ".join(missing_deps)) - ) + raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) def _load_entry_from_hubconf(m, model): + """Load a callable function from hubconf while checking for required dependencies and valid input model string.""" if not isinstance(model, str): raise ValueError("Invalid input: model should be a string of function name") @@ -306,9 +312,7 @@ def list(github, force_reload=False, skip_validation=False): Example: >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) """ - repo_dir = _get_cache_or_reload( - github, force_reload, verbose=True, skip_validation=skip_validation - ) + repo_dir = _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=skip_validation) sys.path.insert(0, repo_dir) @@ -318,11 +322,7 @@ def list(github, force_reload=False, skip_validation=False): sys.path.remove(repo_dir) # We take functions starts with '_' as internal helper functions - entrypoints = [ - f - for f in dir(hub_module) - if callable(getattr(hub_module, f)) and not f.startswith("_") - ] + entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith("_")] return 
entrypoints @@ -346,9 +346,7 @@ def help(github, model, force_reload=False, skip_validation=False): Example: >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) """ - repo_dir = _get_cache_or_reload( - github, force_reload, verbose=True, skip_validation=skip_validation - ) + repo_dir = _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=skip_validation) sys.path.insert(0, repo_dir) @@ -424,14 +422,10 @@ def load( source = source.lower() if source not in ("github", "local"): - raise ValueError( - f'Unknown source: "{source}". Allowed values: "github" | "local".' - ) + raise ValueError(f'Unknown source: "{source}". Allowed values: "github" | "local".') if source == "github": - repo_or_dir = _get_cache_or_reload( - repo_or_dir, force_reload, verbose, skip_validation - ) + repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose, skip_validation) model = _load_local(repo_or_dir, model, *args, **kwargs) return model @@ -470,7 +464,8 @@ def _load_local(hubconf_dir, model, *args, **kwargs): def download_url_to_file(url, dst, hash_prefix=None, progress=True): - r"""Download object at the given URL to a local path. + r""" + Download object at the given URL to a local path. Args: url (string): URL of the object to download @@ -482,7 +477,6 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): Example: >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') - """ file_size = None req = Request(url, headers={"User-Agent": "torch.hub"}) @@ -525,11 +519,7 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): if hash_prefix is not None: digest = sha256.hexdigest() if digest[: len(hash_prefix)] != hash_prefix: - raise RuntimeError( - 'invalid hash value (expected "{}", got "{}")'.format( - hash_prefix, digest - ) - ) + raise RuntimeError('invalid hash value (expected "{}", got "{}")'.format(hash_prefix, digest)) shutil.move(f.name, dst) finally: f.close() @@ -538,6 +528,7 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): def _download_url_to_file(url, dst, hash_prefix=None, progress=True): + """Download a file from a URL to a destination path with an optional hash prefix and progress display.""" warnings.warn( "torch.hub._download_url_to_file has been renamed to\ torch.hub.download_url_to_file to be a public API,\ @@ -550,6 +541,7 @@ def _download_url_to_file(url, dst, hash_prefix=None, progress=True): # The legacy zip format expects only one file from torch.save() < 1.6 in the zip. # We should remove this support since zipfile is now default zipfile format for torch.save(). def _is_legacy_zip_format(filename): + """Checks if the given zip file is in the legacy format, expecting only one non-directory file.""" if zipfile.is_zipfile(filename): infolist = zipfile.ZipFile(filename).infolist() return len(infolist) == 1 and not infolist[0].is_dir() @@ -557,6 +549,7 @@ def _is_legacy_zip_format(filename): def _legacy_zip_load(filename, model_dir, map_location): + """Load a legacy zip file, extract its contents, and load the extracted file using torch.load().""" warnings.warn( "Falling back to the old format < 1.6. This support will be " "deprecated in favor of default zipfile format introduced in 1.6. 
" @@ -575,9 +568,8 @@ def _legacy_zip_load(filename, model_dir, map_location): return torch.load(extracted_file, map_location=map_location) -def download_onnx_from_url( - url, model_dir=None, progress=True, check_hash=False, file_name=None -): +def download_onnx_from_url(url, model_dir=None, progress=True, check_hash=False, file_name=None): + """Download an ONNX file from a URL and save it to the specified directory.""" if model_dir is None: hub_dir = get_dir() model_dir = os.path.join(hub_dir, "onnx")