Skip to content

Commit

Permalink
Revert "Docstrings and code refactor (#22)"
Browse files Browse the repository at this point in the history
This reverts commit 942d5c3.
  • Loading branch information
glenn-jocher authored Sep 7, 2024
1 parent 942d5c3 commit 016b5d6
Show file tree
Hide file tree
Showing 33 changed files with 188 additions and 239 deletions.
4 changes: 1 addition & 3 deletions examples/common_subexpression_elimination/cse_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@


class Model(torch.nn.Module):
"""A PyTorch model applying LayerNorm to input tensors for normalization in neural network layers."""

def __init__(self):
"""Initializes the Model class with a single LayerNorm layer of embedding dimension 10."""
super().__init__()
super(Model, self).__init__()
embedding_dim = 10
self.layer_norm = nn.LayerNorm(embedding_dim)

Expand Down
5 changes: 0 additions & 5 deletions onnxslim/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,7 @@ class CheckerArguments:


class ArgumentParser:
"""Parses command-line arguments into specified dataclasses for ONNX model optimization and modification tasks."""

def __init__(self, *argument_dataclasses: Type):
"""Initializes the ArgumentParser with dataclass types for parsing ONNX model optimization arguments."""
self.argument_dataclasses = argument_dataclasses
self.parser = argparse.ArgumentParser(
description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model",
Expand All @@ -122,7 +119,6 @@ def __init__(self, *argument_dataclasses: Type):
self._add_arguments()

def _add_arguments(self):
"""Adds command-line arguments to the parser based on provided dataclass fields and their metadata."""
for dataclass_type in self.argument_dataclasses:
for field_name, field_def in dataclass_type.__dataclass_fields__.items():
arg_type = field_def.type
Expand Down Expand Up @@ -154,7 +150,6 @@ def _add_arguments(self):
self.parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__)

def parse_args_into_dataclasses(self):
"""Parses command-line arguments into specified dataclass instances for structured configuration."""
args = self.parser.parse_args()
args_dict = vars(args)

Expand Down
31 changes: 15 additions & 16 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@


def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
"""Slims an ONNX model by optimizing and modifying its structure, inputs, and outputs for improved performance."""
import os
import time
from pathlib import Path

from onnxslim.core import (
convert_data_format,
freeze,
input_modification,
input_shape_modification,
optimize,
input_modification,
output_modification,
shape_infer,
)
Expand All @@ -30,20 +29,20 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
summarize_model,
)

output_model = args[0] if args else kwargs.get("output_model", None)
model_check = kwargs.get("model_check", False)
input_shapes = kwargs.get("input_shapes", None)
inputs = kwargs.get("inputs", None)
outputs = kwargs.get("outputs", None)
no_shape_infer = kwargs.get("no_shape_infer", False)
no_constant_folding = kwargs.get("no_constant_folding", False)
dtype = kwargs.get("dtype", None)
skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
inspect = kwargs.get("inspect", False)
dump_to_disk = kwargs.get("dump_to_disk", False)
save_as_external_data = kwargs.get("save_as_external_data", False)
model_check_inputs = kwargs.get("model_check_inputs", None)
verbose = kwargs.get("verbose", False)
output_model = args[0] if len(args) > 0 else kwargs.get('output_model', None)
model_check = kwargs.get('model_check', False)
input_shapes = kwargs.get('input_shapes', None)
inputs = kwargs.get('inputs', None)
outputs = kwargs.get('outputs', None)
no_shape_infer = kwargs.get('no_shape_infer', False)
no_constant_folding = kwargs.get('no_constant_folding', False)
dtype = kwargs.get('dtype', None)
skip_fusion_patterns = kwargs.get('skip_fusion_patterns', None)
inspect = kwargs.get('inspect', False)
dump_to_disk = kwargs.get('dump_to_disk', False)
save_as_external_data = kwargs.get('save_as_external_data', False)
model_check_inputs = kwargs.get('model_check_inputs', None)
verbose = kwargs.get('verbose', False)

logger = init_logging(verbose)

Expand Down
11 changes: 7 additions & 4 deletions onnxslim/core/optimization/weight_tying.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def replace_constant_references(existing_constant, to_be_removed_constant):
for i, constant_tensor in enumerate(constant_tensors):
if keep_constants[i]:
for j in range(i + 1, len(constant_tensors)):
if keep_constants[j] and constant_tensor == constant_tensors[j]:
keep_constants[j] = False
replace_constant_references(constant_tensor, constant_tensors[j])
logger.debug(f"Constant {constant_tensors[j].name} can be replaced by {constant_tensor.name}")
if keep_constants[j]:
if constant_tensor == constant_tensors[j]:
keep_constants[j] = False
replace_constant_references(constant_tensor, constant_tensors[j])
logger.debug(
f"Constant {constant_tensors[j].name} can be replaced by {constant_tensor.name}"
)
8 changes: 0 additions & 8 deletions onnxslim/core/pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def get_name(name):


class NodeDescriptor:
"""Represents a node in a computational graph, detailing its operation type, inputs, and outputs."""

def __init__(self, node_spec):
"""Initialize NodeDescriptor with node_spec list requiring at least 4 elements."""
if not isinstance(node_spec, list):
Expand Down Expand Up @@ -89,8 +87,6 @@ def __dict__(self):


class Pattern:
"""Parses and matches ONNX graph patterns into NodeDescriptor objects for model optimization tasks."""

def __init__(self, pattern):
"""Initialize the Pattern class with a given pattern and parse its nodes."""
self.pattern = pattern
Expand All @@ -113,8 +109,6 @@ def __repr__(self):


class PatternMatcher:
"""Matches computational graph nodes to predefined patterns for optimization and transformation tasks."""

def __init__(self, pattern, priority):
"""Initialize the PatternMatcher with a given pattern and priority, and prepare node references and output
names.
Expand Down Expand Up @@ -190,8 +184,6 @@ def parameter_check(self):


class PatternGenerator:
"""Generates pattern templates from an ONNX model by processing its graph structure and node connections."""

def __init__(self, onnx_model):
"""Initialize the PatternGenerator class with an ONNX model and process its graph."""
self.graph = gs.import_onnx(onnx_model)
Expand Down
2 changes: 0 additions & 2 deletions onnxslim/core/pattern/elimination/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@


class ReshapePatternMatcher(PatternMatcher):
"""Matches and optimizes nested reshape operations in computational graphs to eliminate redundancy."""

def __init__(self, priority):
"""Initializes the ReshapePatternMatcher with a priority and a specific pattern for detecting nested reshape
operations.
Expand Down
2 changes: 0 additions & 2 deletions onnxslim/core/pattern/elimination/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@


class SlicePatternMatcher(PatternMatcher):
"""Matches and optimizes nested slice operations in ONNX graphs to improve computational efficiency."""

def __init__(self, priority):
"""Initializes the SlicePatternMatcher with a specified priority using a predefined graph pattern."""
pattern = Pattern(
Expand Down
89 changes: 40 additions & 49 deletions onnxslim/core/pattern/elimination/unsqueeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@


class UnsqueezePatternMatcher(PatternMatcher):
"""Matches and optimizes nested unsqueeze patterns in ONNX graphs to improve computational efficiency."""

def __init__(self, priority):
"""Initializes the UnsqueezePatternMatcher with a specified priority using a predefined graph pattern."""
pattern = Pattern(
Expand All @@ -31,60 +29,53 @@ def rewrite(self, opset=11):
node_unsqueeze_0 = self.unsqueeze_0
users_node_unsqueeze_0 = get_node_users(node_unsqueeze_0)
node_unsqueeze_1 = self.unsqueeze_1
if (
len(users_node_unsqueeze_0) == 1
and node_unsqueeze_0.inputs[0].shape
and node_unsqueeze_1.inputs[0].shape
and (
opset < 13
or (
isinstance(node_unsqueeze_0.inputs[1], gs.Constant)
and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
)
)
):
if len(users_node_unsqueeze_0) == 1 and node_unsqueeze_0.inputs[0].shape and node_unsqueeze_1.inputs[0].shape:
if opset < 13 or (
isinstance(node_unsqueeze_0.inputs[1], gs.Constant)
and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
):

def get_unsqueeze_axes(unsqueeze_node, opset):
dim = len(unsqueeze_node.inputs[0].shape)
if opset < 13:
axes = unsqueeze_node.attrs["axes"]
else:
axes = unsqueeze_node.inputs[1].values
return [axis + dim + len(axes) if axis < 0 else axis for axis in axes]
def get_unsqueeze_axes(unsqueeze_node, opset):
dim = len(unsqueeze_node.inputs[0].shape)
if opset < 13:
axes = unsqueeze_node.attrs["axes"]
else:
axes = unsqueeze_node.inputs[1].values
return [axis + dim + len(axes) if axis < 0 else axis for axis in axes]

axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset)
axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset)
axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset)
axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset)

axes_node_unsqueeze_0 = [
axis + sum(bool(axis_ <= axis) for axis_ in axes_node_unsqueeze_1) for axis in axes_node_unsqueeze_0
]
axes_node_unsqueeze_0 = [
axis + sum(1 for axis_ in axes_node_unsqueeze_1 if axis_ <= axis) for axis in axes_node_unsqueeze_0
]

inputs = [node_unsqueeze_0.inputs[0]]
outputs = list(node_unsqueeze_1.outputs)
node_unsqueeze_0.inputs.clear()
node_unsqueeze_0.outputs.clear()
node_unsqueeze_1.inputs.clear()
node_unsqueeze_1.outputs.clear()
inputs = [node_unsqueeze_0.inputs[0]]
outputs = list(node_unsqueeze_1.outputs)
node_unsqueeze_0.inputs.clear()
node_unsqueeze_0.outputs.clear()
node_unsqueeze_1.inputs.clear()
node_unsqueeze_1.outputs.clear()

if opset < 13:
attrs = {"axes": axes_node_unsqueeze_0 + axes_node_unsqueeze_1}
else:
attrs = None
inputs.append(
gs.Constant(
name=f"{node_unsqueeze_0.name}_axes",
values=np.array(axes_node_unsqueeze_0 + axes_node_unsqueeze_1, dtype=np.int64),
if opset < 13:
attrs = {"axes": axes_node_unsqueeze_0 + axes_node_unsqueeze_1}
else:
attrs = None
inputs.append(
gs.Constant(
name=f"{node_unsqueeze_0.name}_axes",
values=np.array(axes_node_unsqueeze_0 + axes_node_unsqueeze_1, dtype=np.int64),
)
)
)

match_case[node_unsqueeze_0.name] = {
"op": "Unsqueeze",
"inputs": inputs,
"outputs": outputs,
"name": node_unsqueeze_0.name,
"attrs": attrs,
"domain": None,
}
match_case[node_unsqueeze_0.name] = {
"op": "Unsqueeze",
"inputs": inputs,
"outputs": outputs,
"name": node_unsqueeze_0.name,
"attrs": attrs,
"domain": None,
}

return match_case

Expand Down
2 changes: 0 additions & 2 deletions onnxslim/core/pattern/fusion/convbn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@


class ConvBatchNormMatcher(PatternMatcher):
"""Fuses Conv and BatchNormalization layers in an ONNX graph to optimize model performance and inference speed."""

def __init__(self, priority):
"""Initializes the ConvBatchNormMatcher for fusing Conv and BatchNormalization layers in an ONNX graph."""
pattern = Pattern(
Expand Down
2 changes: 0 additions & 2 deletions onnxslim/core/pattern/fusion/gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@


class GeluPatternMatcher(PatternMatcher):
"""Matches and fuses GELU patterns in computational graphs for optimization purposes."""

def __init__(self, priority):
"""Initializes a `GeluPatternMatcher` to identify and fuse GELU patterns in a computational graph."""
pattern = Pattern(
Expand Down
2 changes: 0 additions & 2 deletions onnxslim/core/pattern/fusion/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@


class MatMulAddPatternMatcher(PatternMatcher):
"""Matches and fuses MatMul and Add operations in ONNX graphs to optimize computational efficiency."""

def __init__(self, priority):
"""Initializes a matcher for fusing MatMul and Add operations in ONNX graph optimization."""
pattern = Pattern(
Expand Down
2 changes: 0 additions & 2 deletions onnxslim/core/pattern/fusion/padconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@


class PadConvMatcher(PatternMatcher):
"""Matches and optimizes Pad-Conv patterns in ONNX graphs by ensuring padding parameters are constants."""

def __init__(self, priority):
"""Initializes the PadConvMatcher with a specified priority and defines its matching pattern."""
pattern = Pattern(
Expand Down
2 changes: 0 additions & 2 deletions onnxslim/core/pattern/fusion/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@


class ReducePatternMatcher(PatternMatcher):
"""Optimizes ONNX graph patterns with ReduceSum and Unsqueeze operations for improved model performance."""

def __init__(self, priority):
"""Initializes the ReducePatternMatcher with a specified pattern matching priority level."""
pattern = Pattern(
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/misc/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def make_header_line(is_header, colwidths, colaligns):
alignment = {"left": "<", "right": ">", "center": "^", "decimal": ">"}
# use the column widths generated by tabulate for the asciidoc column width specifiers
asciidoc_alignments = zip(colwidths, [alignment[colalign] for colalign in colaligns])
asciidoc_column_specifiers = [f"{width:d}{align}" for width, align in asciidoc_alignments]
asciidoc_column_specifiers = ["{:d}{}".format(width, align) for width, align in asciidoc_alignments]
header_list = ['cols="' + (",".join(asciidoc_column_specifiers)) + '"']

# generate the list of options (currently only "header")
Expand Down Expand Up @@ -2484,7 +2484,7 @@ def _wrap_chunks(self, chunks):
"""
lines = []
if self.width <= 0:
raise ValueError(f"invalid width {self.width!r} (must be > 0)")
raise ValueError("invalid width %r (must be > 0)" % self.width)
if self.max_lines is not None:
indent = self.subsequent_indent if self.max_lines > 1 else self.initial_indent
if self._len(indent) + self._len(self.placeholder.lstrip()) > self.width:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph


class BaseExporter:
"""BaseExporter provides a static method to export ONNX graphs to a specified destination format."""

class BaseExporter(object):
@staticmethod
def export_graph(graph: Graph):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING):
if not node.name:
continue
if node.name in name_map:
msg = f"Found distinct Nodes that share the same name:\n[id: {id(name_map[node.name])}]:\n {name_map[node.name]}---\n[id: {id(node)}]:\n {node}\n"
msg = "Found distinct Nodes that share the same name:\n[id: {:}]:\n {:}---\n[id: {:}]:\n {:}\n".format(
id(name_map[node.name]),
name_map[node.name],
id(node),
node,
)
G_LOGGER.log(msg, level)
else:
name_map[node.name] = node
Expand Down Expand Up @@ -105,8 +110,6 @@ def np_float32_to_bf16_as_uint16(arr):


class OnnxExporter(BaseExporter):
"""Exports internal graph structures to ONNX format for model interoperability."""

@staticmethod
def export_tensor_proto(tensor: Constant) -> onnx.TensorProto:
# Do *not* load LazyValues into an intermediate numpy array - instead, use
Expand Down Expand Up @@ -143,7 +146,9 @@ def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueIn
"""Creates an ONNX ValueInfoProto from a Tensor, optionally checking for dtype information."""
if do_type_check and tensor.dtype is None:
G_LOGGER.critical(
f"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {tensor}"
"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {:}".format(
tensor
)
)

if tensor.dtype is None:
Expand Down
Loading

0 comments on commit 016b5d6

Please sign in to comment.