Skip to content

Commit

Permalink
Add Python Docstrings to functions and methods (#7)
Browse files Browse the repository at this point in the history
* Auto-format by https://ultralytics.com/actions

* Add Python docstrings

* Auto-format by https://ultralytics.com/actions

* Update README.md

* Auto-format by https://ultralytics.com/actions

* fix import error

* Auto-format by https://ultralytics.com/actions

---------

Co-authored-by: UltralyticsAssistant <[email protected]>
Co-authored-by: inisis <[email protected]>
  • Loading branch information
3 people authored Jun 1, 2024
1 parent 308003b commit 4625ccf
Show file tree
Hide file tree
Showing 22 changed files with 518 additions and 57 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,4 @@ For more usage, see onnxslim -h or refer to our [examples](./examples)
# Contact

Discord: https://discord.gg/nRw2Fd3VUS
QQ Group: 873569894
Discord: https://discord.gg/nRw2Fd3VUS QQ Group: 873569894
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@


def setup(app):
"""Configure the Sphinx app by adding a custom CSS file ('style.css')."""
app.add_css_file("style.css")
4 changes: 4 additions & 0 deletions examples/common_subexpression_elimination/cse_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@

class Model(torch.nn.Module):
    """Toy module that layer-normalizes the same input twice and sums the results.

    One normalization goes through an ``nn.LayerNorm`` submodule, the other
    through the functional ``F.layer_norm`` call, both over the trailing
    ``embedding_dim`` features.
    """

    def __init__(self, embedding_dim=10):
        """Build the module with a single LayerNorm layer.

        Args:
            embedding_dim: Size of the trailing feature dimension to normalize.
                Defaults to 10, the value that was previously hard-coded.
        """
        super().__init__()
        # Keep the size on the instance so forward() normalizes over the same
        # shape as the LayerNorm submodule instead of repeating a literal.
        self.embedding_dim = embedding_dim
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Return ``LayerNorm(x)`` from the submodule plus a functional layer norm of ``x``.

        Args:
            x: Input tensor whose last dimension has size ``embedding_dim``.
        """
        return self.layer_norm(x) + F.layer_norm(x, [self.embedding_dim])


Expand Down
1 change: 1 addition & 0 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def slim(


def main():
"""Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function."""
import argparse

import onnxslim
Expand Down
70 changes: 49 additions & 21 deletions onnxslim/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@


def register_fusion_pattern(layer_type):
"""Registers a fusion pattern function for a specified layer type in the DEFAULT_FUSION_PATTERNS dictionary."""

def insert(fn):
if layer_type in DEFAULT_FUSION_PATTERNS.keys():
raise
Expand All @@ -25,6 +27,7 @@ def insert(fn):


def get_fusion_patterns(skip_fusion_patterns: str = None):
"""Returns a copy of the default fusion patterns, optionally excluding specific patterns."""
default_fusion_patterns = DEFAULT_FUSION_PATTERNS.copy()
if skip_fusion_patterns:
for pattern in skip_fusion_patterns:
Expand All @@ -34,6 +37,7 @@ def get_fusion_patterns(skip_fusion_patterns: str = None):


def get_node_users(node):
"""Retrieve the list of nodes that use the outputs of the given node."""
users = []
for output in node.outputs: # output is a Variable
for user in output.outputs: # user is a Node
Expand All @@ -42,6 +46,7 @@ def get_node_users(node):


def get_node_feeds(node):
"""Retrieve the list of nodes that provide inputs to the given node."""
feeds = []
for input in node.inputs: # input is a Variable
for feed in input.inputs: # feed is a Node
Expand All @@ -50,6 +55,7 @@ def get_node_feeds(node):


def get_previous_node_by_type(node, op_type, trajectory=[]):
"""Recursively find and return the first preceding node of a specified type in the computation graph."""
node_feeds = get_node_feeds(node)
for node_feed in node_feeds:
if node_feed.op == op_type:
Expand All @@ -61,12 +67,14 @@ def get_previous_node_by_type(node, op_type, trajectory=[]):


def get_constant_variable(node, return_idx=False):
"""Return the first constant variable found in a node's inputs, optionally including the index."""
for idx, input in enumerate(list(node.inputs)):
if isinstance(input, Constant):
return input if not return_idx else (idx, input)


def delete_node(node, input_var_idx=0, output_var_idx=0):
"""Delete a node from the computation graph while re-linking its input and output to maintain graph integrity."""
input_variable = node.inputs[input_var_idx]
node_variable = node.outputs[output_var_idx]
next_nodes = get_node_users(node)
Expand All @@ -83,6 +91,7 @@ def delete_node(node, input_var_idx=0, output_var_idx=0):


def check_shape(shapes):
"""Verify that 'shapes' contains exactly one string and all other elements are positive integers."""
string_count = 0
non_negative_int_count = 0

Expand All @@ -96,6 +105,9 @@ def check_shape(shapes):


def graph_constant_fold_inplace(graph):
"""Perform in-place constant folding optimizations on the given computational graph by eliminating redundant
nodes.
"""
for node in graph.nodes:
if node.op == "Identity" or node.op == "Dropout":
delete_node(node)
Expand Down Expand Up @@ -150,14 +162,16 @@ def graph_constant_fold_inplace(graph):

@register_fusion_pattern("FusionPadConv")
def find_conv_nodes(node, opset):
# fmt: off
'''
"""Identify and match convolution nodes following a padding operation to update padding attributes for fusion
purposes.
"""
"""
x
|
Pad
|
Conv
'''
"""
# fmt: on
match = {}
if node.op == "Conv":
Expand Down Expand Up @@ -271,14 +285,14 @@ def find_conv_transpose_nodes(node, opset):

@register_fusion_pattern("EliminationSlice")
def find_slice_nodes(node, opset):
# fmt: off
'''
"""Identify and combine consecutive 'Slice' nodes in a computational graph for optimization purposes."""
"""
x
|
Slice
|
Slice
'''
"""
# fmt: on
match = {}
if node.op == "Slice":
Expand Down Expand Up @@ -378,14 +392,16 @@ def find_slice_nodes(node, opset):

@register_fusion_pattern("EliminationReshape")
def find_reshape_nodes(node, opset):
# fmt: off
'''
"""Identify consecutive 'Reshape' nodes in the computational graph for potential fusion, returning a matching
dictionary when criteria are met.
"""
"""
x
|
Reshape
|
Reshape
'''
"""
# fmt: on
match = {}
if node.op == "Reshape":
Expand Down Expand Up @@ -436,14 +452,14 @@ def check_constant_mergeable(reshape_node):

# @register_fusion_pattern("EliminationTranspose")
def find_slice_nodes(node, opset):
# fmt: off
'''
"""Identifies and processes patterns of consecutive Transpose nodes in a computational graph."""
"""
x
|
Transpose
|
Transpose
'''
"""
# fmt: on
match = {}
if node.op == "Transpose":
Expand Down Expand Up @@ -476,14 +492,16 @@ def find_slice_nodes(node, opset):

@register_fusion_pattern("FusionGemm")
def find_matmul_add_nodes(node, opset):
# fmt: off
'''
"""Identifies and returns a pattern match for MatMul followed by Add operations for optimization in a computational
graph.
"""
"""
x
|
MatMul
|
Add
'''
"""
# fmt: on
match = {}
if node.op == "Add":
Expand Down Expand Up @@ -635,8 +653,10 @@ def find_matmul_add_nodes(node, opset):

# @register_fusion_pattern("FusionGelu")
def find_gelu_nodes(node, opset):
# fmt: off
'''
"""Identifies GELU (Gaussian Error Linear Unit) activation pattern nodes in a computational graph based on given
conditions.
"""
"""
x
/ \
| Div
Expand All @@ -648,7 +668,7 @@ def find_gelu_nodes(node, opset):
Mul
|
Mul
'''
"""
# fmt: on
match = {}
if node.op == "Mul":
Expand Down Expand Up @@ -683,14 +703,16 @@ def find_gelu_nodes(node, opset):

@register_fusion_pattern("FusionReduce")
def find_slice_nodes(node, opset):
# fmt: off
'''
"""Find and return a dictionary of matching 'ReduceSum' followed by 'Unsqueeze' nodes that match specific conditions
in the graph.
"""
"""
x
|
ReduceSum
|
Unsqueeze
'''
"""
# fmt: on
match = {}
if node.op == "Unsqueeze":
Expand Down Expand Up @@ -754,6 +776,7 @@ def replace_custom_layer(


def find_matches(graph: Graph, fusion_patterns: dict):
"""Find matching patterns in the graph based on provided fusion patterns."""
opset = graph.opset
match_map = {}
counter = Counter()
Expand All @@ -776,6 +799,8 @@ def find_matches(graph: Graph, fusion_patterns: dict):


def find_and_remove_replaceable_nodes(nodes):
"""Find and remove duplicate or replaceable nodes in a given list of computational graph nodes."""

def get_node_key(node):
input_names = []
for input_node in node.inputs:
Expand Down Expand Up @@ -823,6 +848,7 @@ def replace_node_references(existing_node, to_be_removed_node):


def sequences_equal(seq1, seq2):
"""Check if two sequences are equal by comparing their lengths and elements."""
length_match = len(seq1) == len(seq2)
if not length_match:
return False
Expand All @@ -835,13 +861,15 @@ def sequences_equal(seq1, seq2):


def can_be_replaced(node, other_node):
    """Tell whether ``node`` and ``other_node`` have the same op, attrs and inputs.

    The op and attrs are compared with ``==``; the two input sequences are
    compared with ``sequences_equal``.
    """
    same_op = node.op == other_node.op
    same_signature = same_op and node.attrs == other_node.attrs
    # Evaluated unconditionally, as before, regardless of the op/attrs result.
    same_inputs = sequences_equal(node.inputs, other_node.inputs)

    return same_signature and same_inputs


def subexpression_elimination(graph):
"""Perform subexpression elimination on a computational graph to optimize node operations."""
nodes_by_op = {}

for node in graph.nodes:
Expand Down
13 changes: 12 additions & 1 deletion onnxslim/core/slim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


def init_logging(verbose=False):
# Remove all handlers associated with the root logger object.
"""Configure the logging settings for the application based on the verbosity level."""
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)

Expand Down Expand Up @@ -75,6 +75,7 @@ def summarize_model(model: onnx.ModelProto) -> Dict:
op_type_counts = {}

def get_tensor_dtype_shape(tensor):
"""Extract the data type and shape of an ONNX tensor."""
type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown")
shape = None
if tensor.type.tensor_type.HasField("shape"):
Expand Down Expand Up @@ -130,6 +131,7 @@ def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]:


def model_save_as_external_data(model: onnx.ModelProto, model_path: str):
"""Save an ONNX model with tensor data as an external file."""
location = os.path.basename(model_path) + ".data"
if os.path.exists(location):
os.remove(location)
Expand Down Expand Up @@ -204,13 +206,15 @@ def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto


def check_onnx(model: onnx.ModelProto, model_check_inputs=None):
    """Generate input data for ``model``, run inference on it, and return both.

    Args:
        model: ONNX model passed to the input generator and to inference.
        model_check_inputs: Forwarded unchanged to ``gen_onnxruntime_input_data``.

    Returns:
        A 2-tuple: the generated input data, then the result of
        ``onnxruntime_inference`` on that data.
    """
    feeds = gen_onnxruntime_input_data(model, model_check_inputs)
    return feeds, onnxruntime_inference(model, feeds)


def shape_infer(model: onnx.ModelProto):
"""Infer tensor shapes in an ONNX model using symbolic and static shape inference techniques."""
logger.debug("Start shape inference.")
try:
logger.debug("try onnxruntime shape infer.")
Expand All @@ -233,6 +237,7 @@ def shape_infer(model: onnx.ModelProto):


def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None):
"""Optimize the given ONNX model with options to skip specific fusion patterns and return the optimized model."""
logger.debug("Start converting model to gs.")
graph = gs.import_onnx(model).toposort()
logger.debug("Finish converting model to gs.")
Expand All @@ -249,6 +254,7 @@ def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None):


def check_point(model: onnx.ModelProto):
    """Import ``model`` through ``gs.import_onnx`` and return the resulting graph."""
    return gs.import_onnx(model)
Expand Down Expand Up @@ -292,6 +298,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:


def save(model: onnx.ModelProto, model_path: str, model_check: bool = False):
"""Save an ONNX model to a specified path, with optional model checking for validity."""
if model_check:
try:
checker.check_model(model)
Expand Down Expand Up @@ -320,6 +327,9 @@ def save(model: onnx.ModelProto, model_path: str, model_check: bool = False):


def check_result(raw_onnx_output, slimmed_onnx_output):
"""Verify the consistency of outputs between the raw and slimmed ONNX models, logging warnings if discrepancies are
detected.
"""
if set(raw_onnx_output.keys()) != set(slimmed_onnx_output.keys()):
logger.warning("Model output mismatch after slimming.")
logger.warning("Raw model output keys: {}".format(raw_onnx_output.keys()))
Expand All @@ -341,6 +351,7 @@ def check_result(raw_onnx_output, slimmed_onnx_output):


def freeze(model: onnx.ModelProto):
"""Freeze the input layers of an ONNX model by removing the initializers from the input graph."""
inputs = model.graph.input
name_to_input = {}
for input in inputs:
Expand Down
Loading

0 comments on commit 4625ccf

Please sign in to comment.