Skip to content

Commit

Permalink
Add Python docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed May 31, 2024
1 parent 30b34f4 commit 6e4e32e
Show file tree
Hide file tree
Showing 21 changed files with 359 additions and 36 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@


def setup(app):
"""Configure Sphinx application to include custom CSS from 'style.css'."""
app.add_css_file("style.css")
2 changes: 2 additions & 0 deletions examples/common_subexpression_elimination/cse_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

class Model(torch.nn.Module):
def __init__(self):
"""Initializes the Model class with a single LayerNorm layer of embedding dimension 10."""
super(Model, self).__init__()
embedding_dim = 10
self.layer_norm = nn.LayerNorm(embedding_dim)

def forward(self, x):
"""Applies LayerNorm to the input tensor and adds it to an independently computed LayerNorm of the same tensor."""
return self.layer_norm(x) + F.layer_norm(x, [10])


Expand Down
1 change: 1 addition & 0 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def slim(


def main():
"""Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function."""
import argparse

import onnxslim
Expand Down
28 changes: 21 additions & 7 deletions onnxslim/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


def register_fusion_pattern(layer_type):
"""Registers a fusion pattern function for a specific layer type in the DEFAULT_FUSION_PATTERNS dictionary."""
def insert(fn):
if layer_type in DEFAULT_FUSION_PATTERNS.keys():
raise
Expand All @@ -25,6 +26,7 @@ def insert(fn):


def get_fusion_patterns(skip_fusion_patterns: str = None):
"""Returns a dictionary of default fusion patterns, optionally excluding specific patterns."""
default_fusion_patterns = DEFAULT_FUSION_PATTERNS.copy()
if skip_fusion_patterns:
for pattern in skip_fusion_patterns:
Expand All @@ -34,6 +36,7 @@ def get_fusion_patterns(skip_fusion_patterns: str = None):


def get_node_users(node):
"""Retrieve the list of users for a given node based on its outputs."""
users = []
for output in node.outputs: # output is a Variable
for user in output.outputs: # user is a Node
Expand All @@ -42,6 +45,7 @@ def get_node_users(node):


def get_node_feeds(node):
"""Retrieve the list of feed nodes for a given node based on its inputs."""
feeds = []
for input in node.inputs: # input is a Variable
for feed in input.inputs: # feed is a Node
Expand All @@ -50,6 +54,7 @@ def get_node_feeds(node):


def get_previous_node_by_type(node, op_type, trajectory=[]):
"""Retrieve the previous node of a specific type in the computational graph starting from the given node."""
node_feeds = get_node_feeds(node)
for node_feed in node_feeds:
if node_feed.op == op_type:
Expand All @@ -61,12 +66,14 @@ def get_previous_node_by_type(node, op_type, trajectory=[]):


def get_constant_variable(node, return_idx=False):
"""Return the first constant variable found in a node's inputs, optionally returning the index."""
for idx, input in enumerate(list(node.inputs)):
if isinstance(input, Constant):
return input if not return_idx else (idx, input)


def delete_node(node, input_var_idx=0, output_var_idx=0):
"""Delete a node from the computation graph while redirecting its inputs to its outputs to maintain graph integrity."""
input_variable = node.inputs[input_var_idx]
node_variable = node.outputs[output_var_idx]
next_nodes = get_node_users(node)
Expand All @@ -83,6 +90,7 @@ def delete_node(node, input_var_idx=0, output_var_idx=0):


def check_shape(shapes):
"""Verify that 'shapes' contains exactly one string and all other elements are positive integers."""
string_count = 0
non_negative_int_count = 0

Expand All @@ -96,6 +104,7 @@ def check_shape(shapes):


def graph_constant_fold_inplace(graph):
"""Perform in-place constant folding optimizations on the provided computational graph by eliminating redundant nodes."""
for node in graph.nodes:
if node.op == "Identity" or node.op == "Dropout":
delete_node(node)
Expand Down Expand Up @@ -150,7 +159,7 @@ def graph_constant_fold_inplace(graph):

@register_fusion_pattern("FusionPadConv")
def find_conv_nodes(node, opset):
# fmt: off
"""Identify and match convolution nodes following a padding operation to update padding attributes for fusion purposes."""
'''
x
|
Expand Down Expand Up @@ -271,7 +280,7 @@ def find_conv_transpose_nodes(node, opset):

@register_fusion_pattern("EliminationSlice")
def find_slice_nodes(node, opset):
# fmt: off
"""Identify and combine consecutive 'Slice' nodes in a computational graph for optimization."""
'''
x
|
Expand Down Expand Up @@ -378,7 +387,7 @@ def find_slice_nodes(node, opset):

@register_fusion_pattern("EliminationReshape")
def find_reshape_nodes(node, opset):
# fmt: off
"""Identify consecutive 'Reshape' nodes in the computational graph and validate their mergeability based on input and output shapes."""
'''
x
|
Expand Down Expand Up @@ -436,7 +445,7 @@ def check_constant_mergeable(reshape_node):

# @register_fusion_pattern("EliminationTranspose")
def find_slice_nodes(node, opset):
# fmt: off
"""Identifies and processes consecutive Transpose nodes, removing redundant ones to optimize the graph."""
'''
x
|
Expand Down Expand Up @@ -476,7 +485,7 @@ def find_slice_nodes(node, opset):

@register_fusion_pattern("FusionGemm")
def find_matmul_add_nodes(node, opset):
# fmt: off
"""Identifies and returns a pattern match for MatMul followed by Add operations for optimization in a computational graph."""
'''
x
|
Expand Down Expand Up @@ -635,7 +644,7 @@ def find_matmul_add_nodes(node, opset):

# @register_fusion_pattern("FusionGelu")
def find_gelu_nodes(node, opset):
# fmt: off
"""Identifies GELU (Gaussian Error Linear Unit) activation pattern nodes in a computational graph."""
'''
x
/ \
Expand Down Expand Up @@ -683,7 +692,7 @@ def find_gelu_nodes(node, opset):

@register_fusion_pattern("FusionReduce")
def find_slice_nodes(node, opset):
# fmt: off
"""Find and return a dictionary of matching 'ReduceSum' followed by 'Unsqueeze' nodes that match specific conditions in the graph."""
'''
x
|
Expand Down Expand Up @@ -754,6 +763,7 @@ def replace_custom_layer(


def find_matches(graph: Graph, fusion_patterns: dict):
"""Find matching patterns in the graph based on provided fusion patterns."""
opset = graph.opset
match_map = {}
counter = Counter()
Expand All @@ -776,6 +786,7 @@ def find_matches(graph: Graph, fusion_patterns: dict):


def find_and_remove_replaceable_nodes(nodes):
"""Find and remove duplicate or replaceable nodes in a given list of computational graph nodes."""
def get_node_key(node):
input_names = []
for input_node in node.inputs:
Expand Down Expand Up @@ -823,6 +834,7 @@ def replace_node_references(existing_node, to_be_removed_node):


def sequences_equal(seq1, seq2):
"""Check if two sequences are equal by comparing their lengths and elements."""
length_match = len(seq1) == len(seq2)
if not length_match:
return False
Expand All @@ -835,13 +847,15 @@ def sequences_equal(seq1, seq2):


def can_be_replaced(node, other_node):
"""Check if two nodes can be replaced based on their operations, attributes, and inputs."""
attrs_match = node.op == other_node.op and node.attrs == other_node.attrs
inputs_match = sequences_equal(node.inputs, other_node.inputs)

return attrs_match and inputs_match


def subexpression_elimination(graph):
"""Perform subexpression elimination on a computational graph to optimize node operations."""
nodes_by_op = {}

for node in graph.nodes:
Expand Down
11 changes: 10 additions & 1 deletion onnxslim/core/slim.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


def init_logging(verbose=False):
# Remove all handlers associated with the root logger object.
"""Configure the logging settings for the application, setting verbosity based on the 'verbose' parameter."""
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)

Expand Down Expand Up @@ -75,6 +75,7 @@ def summarize_model(model: onnx.ModelProto) -> Dict:
op_type_counts = {}

def get_tensor_dtype_shape(tensor):
"""Extract the data type and shape of an ONNX tensor."""
type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown")
shape = None
if tensor.type.tensor_type.HasField("shape"):
Expand Down Expand Up @@ -130,6 +131,7 @@ def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]:


def model_save_as_external_data(model: onnx.ModelProto, model_path: str):
"""Save an ONNX model with tensor data as an external file."""
location = os.path.basename(model_path) + ".data"
if os.path.exists(location):
os.remove(location)
Expand Down Expand Up @@ -204,13 +206,15 @@ def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto


def check_onnx(model: onnx.ModelProto, model_check_inputs=None):
"""Validates an ONNX model by generating input data and performing inference to check outputs."""
input_data_dict = gen_onnxruntime_input_data(model, model_check_inputs)
raw_onnx_output = onnxruntime_inference(model, input_data_dict)

return input_data_dict, raw_onnx_output


def shape_infer(model: onnx.ModelProto):
"""Infer tensor shapes in an ONNX model using onnxruntime or ONNX shape inference methods."""
logger.debug("Start shape inference.")
try:
logger.debug("try onnxruntime shape infer.")
Expand All @@ -233,6 +237,7 @@ def shape_infer(model: onnx.ModelProto):


def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None):
"""Optimize the given ONNX model by converting to GraphSurgeon format, performing constant folding, and model optimizations."""
logger.debug("Start converting model to gs.")
graph = gs.import_onnx(model).toposort()
logger.debug("Finish converting model to gs.")
Expand All @@ -249,6 +254,7 @@ def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None):


def check_point(model: onnx.ModelProto):
"""Imports an ONNX model into a graph from the specified model checkpoint."""
graph_check_point = gs.import_onnx(model)

return graph_check_point
Expand Down Expand Up @@ -292,6 +298,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:


def save(model: onnx.ModelProto, model_path: str, model_check: bool = False):
"""Save an ONNX model to a specified path, with optional model checking for validity."""
if model_check:
try:
checker.check_model(model)
Expand Down Expand Up @@ -320,6 +327,7 @@ def save(model: onnx.ModelProto, model_path: str, model_check: bool = False):


def check_result(raw_onnx_output, slimmed_onnx_output):
"""Verify the consistency of outputs between the raw and slimmed ONNX models, logging warnings if discrepancies are detected."""
if set(raw_onnx_output.keys()) != set(slimmed_onnx_output.keys()):
logger.warning("Model output mismatch after slimming.")
logger.warning("Raw model output keys: {}".format(raw_onnx_output.keys()))
Expand All @@ -341,6 +349,7 @@ def check_result(raw_onnx_output, slimmed_onnx_output):


def freeze(model: onnx.ModelProto):
"""Freezes the ONNX model by removing inputs that are also present in the initializers."""
inputs = model.graph.input
name_to_input = {}
for input in inputs:
Expand Down
Loading

0 comments on commit 6e4e32e

Please sign in to comment.