Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
UltralyticsAssistant committed Sep 7, 2024
1 parent 016b5d6 commit bdf5126
Show file tree
Hide file tree
Showing 19 changed files with 104 additions and 135 deletions.
2 changes: 1 addition & 1 deletion examples/common_subexpression_elimination/cse_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class Model(torch.nn.Module):
def __init__(self):
"""Initializes the Model class with a single LayerNorm layer of embedding dimension 10."""
super(Model, self).__init__()
super().__init__()
embedding_dim = 10
self.layer_norm = nn.LayerNorm(embedding_dim)

Expand Down
30 changes: 15 additions & 15 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
from onnxslim.core import (
convert_data_format,
freeze,
input_modification,
input_shape_modification,
optimize,
input_modification,
output_modification,
shape_infer,
)
Expand All @@ -29,20 +29,20 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
summarize_model,
)

output_model = args[0] if len(args) > 0 else kwargs.get('output_model', None)
model_check = kwargs.get('model_check', False)
input_shapes = kwargs.get('input_shapes', None)
inputs = kwargs.get('inputs', None)
outputs = kwargs.get('outputs', None)
no_shape_infer = kwargs.get('no_shape_infer', False)
no_constant_folding = kwargs.get('no_constant_folding', False)
dtype = kwargs.get('dtype', None)
skip_fusion_patterns = kwargs.get('skip_fusion_patterns', None)
inspect = kwargs.get('inspect', False)
dump_to_disk = kwargs.get('dump_to_disk', False)
save_as_external_data = kwargs.get('save_as_external_data', False)
model_check_inputs = kwargs.get('model_check_inputs', None)
verbose = kwargs.get('verbose', False)
output_model = args[0] if len(args) > 0 else kwargs.get("output_model", None)
model_check = kwargs.get("model_check", False)
input_shapes = kwargs.get("input_shapes", None)
inputs = kwargs.get("inputs", None)
outputs = kwargs.get("outputs", None)
no_shape_infer = kwargs.get("no_shape_infer", False)
no_constant_folding = kwargs.get("no_constant_folding", False)
dtype = kwargs.get("dtype", None)
skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
inspect = kwargs.get("inspect", False)
dump_to_disk = kwargs.get("dump_to_disk", False)
save_as_external_data = kwargs.get("save_as_external_data", False)
model_check_inputs = kwargs.get("model_check_inputs", None)
verbose = kwargs.get("verbose", False)

logger = init_logging(verbose)

Expand Down
4 changes: 2 additions & 2 deletions onnxslim/misc/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def make_header_line(is_header, colwidths, colaligns):
alignment = {"left": "<", "right": ">", "center": "^", "decimal": ">"}
# use the column widths generated by tabulate for the asciidoc column width specifiers
asciidoc_alignments = zip(colwidths, [alignment[colalign] for colalign in colaligns])
asciidoc_column_specifiers = ["{:d}{}".format(width, align) for width, align in asciidoc_alignments]
asciidoc_column_specifiers = [f"{width:d}{align}" for width, align in asciidoc_alignments]
header_list = ['cols="' + (",".join(asciidoc_column_specifiers)) + '"']

# generate the list of options (currently only "header")
Expand Down Expand Up @@ -2484,7 +2484,7 @@ def _wrap_chunks(self, chunks):
"""
lines = []
if self.width <= 0:
raise ValueError("invalid width %r (must be > 0)" % self.width)
raise ValueError(f"invalid width {self.width!r} (must be > 0)")
if self.max_lines is not None:
indent = self.subsequent_indent if self.max_lines > 1 else self.initial_indent
if self._len(indent) + self._len(self.placeholder.lstrip()) > self.width:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph


class BaseExporter(object):
class BaseExporter:
@staticmethod
def export_graph(graph: Graph):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,7 @@ def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING):
if not node.name:
continue
if node.name in name_map:
msg = "Found distinct Nodes that share the same name:\n[id: {:}]:\n {:}---\n[id: {:}]:\n {:}\n".format(
id(name_map[node.name]),
name_map[node.name],
id(node),
node,
)
msg = f"Found distinct Nodes that share the same name:\n[id: {id(name_map[node.name])}]:\n {name_map[node.name]}---\n[id: {id(node)}]:\n {node}\n"
G_LOGGER.log(msg, level)
else:
name_map[node.name] = node
Expand Down Expand Up @@ -146,9 +141,7 @@ def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueIn
"""Creates an ONNX ValueInfoProto from a Tensor, optionally checking for dtype information."""
if do_type_check and tensor.dtype is None:
G_LOGGER.critical(
"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {:}".format(
tensor
)
f"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {tensor}"
)

if tensor.dtype is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self) -> None:
self.op = None # op (str)
self.check_func = None # callback function for single node
# pattern node name -> GraphPattern nodes(single or subpattern)
self.nodes: Dict[str, "GraphPattern"] = {}
self.nodes: Dict[str, GraphPattern] = {}
# pattern node name -> input tensors
self.node_inputs: Dict[str, List[int]] = {}
# pattern node name -> output tensors
Expand Down Expand Up @@ -239,15 +239,13 @@ def _single_node_match(self, onnx_node: Node) -> bool:
with G_LOGGER.indent():
if self.op != onnx_node.op:
G_LOGGER.info(
"No match because: Op did not match. Node op was: {:} but pattern op was: {:}.".format(
onnx_node.op, self.op
)
f"No match because: Op did not match. Node op was: {onnx_node.op} but pattern op was: {self.op}."
)
return False
if self.check_func is not None and not self.check_func(onnx_node):
G_LOGGER.info("No match because: check_func returned false.")
return False
G_LOGGER.info("Single node is matched: {:}, {:}".format(self.op, onnx_node.name))
G_LOGGER.info(f"Single node is matched: {self.op}, {onnx_node.name}")
return True

def _get_tensor_index_for_node(self, node: str, tensor_id: int, is_node_input: bool):
Expand Down Expand Up @@ -330,7 +328,7 @@ def _match_node(
) -> bool:
"""Matches ONNX nodes to the graph pattern starting from a specific node and tensor context."""
with G_LOGGER.indent():
G_LOGGER.info("Checking node: {:} against pattern node: {:}.".format(onnx_node.name, node_name))
G_LOGGER.info(f"Checking node: {onnx_node.name} against pattern node: {node_name}.")
tensor_index_for_node = self._get_tensor_index_for_node(node_name, from_tensor, is_node_input=from_inbound)
subgraph_mapping = self.nodes[node_name].match(
onnx_node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph


class BaseImporter(object):
class BaseImporter:
@staticmethod
def import_graph(graph) -> Graph:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,11 @@ def process_attr(attr_str: str):
attr_dict[attr.name] = process_attr(attr_str)
else:
G_LOGGER.warning(
"Attribute of type {:} is currently unsupported. Skipping attribute.".format(attr_str)
f"Attribute of type {attr_str} is currently unsupported. Skipping attribute."
)
else:
G_LOGGER.warning(
"Attribute type: {:} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute.".format(
attr.type
)
f"Attribute type: {attr.type} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute."
)
return attr_dict

Expand Down Expand Up @@ -315,9 +313,7 @@ def get_tensor(name: str, check_outer_graph=True):
return Variable.empty()

G_LOGGER.verbose(
"Tensor: {:} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor.".format(
name
)
f"Tensor: {name} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor."
)
subgraph_tensor_map[name] = Variable(name)
return subgraph_tensor_map[name]
Expand Down
Loading

0 comments on commit bdf5126

Please sign in to comment.