Commit
simple optimizations
Kayzwer committed Jun 10, 2024
1 parent 554d454 commit 5f4b815
Showing 9 changed files with 73 additions and 75 deletions.
13 changes: 6 additions & 7 deletions onnxslim/core/optimizer.py
@@ -58,11 +58,10 @@ def get_previous_node_by_type(node, op_type, trajectory=None):
trajectory = []
node_feeds = get_node_feeds(node)
for node_feed in node_feeds:
trajectory.append(node_feed)
if node_feed.op == op_type:
trajectory.append(node_feed)
return trajectory
else:
trajectory.append(node_feed)
return get_previous_node_by_type(node_feed, op_type, trajectory)


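For context, the hunk above hoists trajectory.append(node_feed) out of both branches so each visited feed is recorded exactly once before the op-type check, which is what lets the function shrink by one line without changing behavior. Paraphrased from the hunk (not a standalone snippet), the updated loop reads:

    for node_feed in node_feeds:
        trajectory.append(node_feed)  # record the feed once, before branching
        if node_feed.op == op_type:
            return trajectory  # found the requested op type
        else:
            return get_previous_node_by_type(node_feed, op_type, trajectory)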
@@ -112,7 +111,7 @@ def graph_constant_fold_inplace(graph):
graph_constant_fold_inplace(subgraph)

for node in graph.nodes:
if node.op in ["Identity", "Dropout"]:
if node.op in {"Identity", "Dropout"}:
delete_node(node)

elif node.op == "Pad":
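Most of this commit repeats one pattern: membership tests against list literals become set literals. A small, self-contained illustration of what that buys (not repository code):

    import timeit

    # CPython folds the list literal into a tuple constant and the set literal into
    # a frozenset constant, so neither container is rebuilt per call; the set form
    # additionally turns the linear scan into a hash lookup. With only two elements
    # the speedup is tiny, so the change is mostly about consistency.
    print(timeit.timeit('op in ["Identity", "Dropout"]', setup='op = "Dropout"'))
    print(timeit.timeit('op in {"Identity", "Dropout"}', setup='op = "Dropout"'))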
@@ -229,10 +228,10 @@ def find_conv_transpose_nodes(node, opset):
"""X | Conv/ConvTranspose | BatchNormalization."""
# fmt: on
match = {}
if node.op == "BatchNormalization" and node.i(0).op in [
if node.op == "BatchNormalization" and node.i(0).op in {
"ConvTranspose",
"Conv",
]:
}:
conv_transpose_node = node.i(0)
conv_transpose_node_users = get_node_users(conv_transpose_node)
if len(conv_transpose_node_users) == 1:
@@ -802,7 +801,7 @@ def replace_node_references(existing_node, to_be_removed_node):
if keep_nodes[i]:
for j in range(i + 1, len(bucketed_nodes)):
if keep_nodes[j]:
logger.debug(f"node.op {bucketed_nodes[0].op} idx i: {i}, idx j: {j}")
logger.debug(f"node.op {bucketed_nodes[i].op} idx i: {i}, idx j: {j}")
if can_be_replaced(node, bucketed_nodes[j]):
keep_nodes[j] = False
existing_node = node
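Besides the container rewrites, this hunk fixes the debug message: it previously always printed bucketed_nodes[0].op even when the comparison started at a later index i. A minimal sketch of the surrounding de-duplication pattern (the function name and return value are illustrative; only the can_be_replaced check comes from the hunk):

    def deduplicate(bucketed_nodes, can_be_replaced):
        keep_nodes = [True] * len(bucketed_nodes)
        for i, node in enumerate(bucketed_nodes):
            if not keep_nodes[i]:
                continue
            for j in range(i + 1, len(bucketed_nodes)):
                # Compare (and log) the pair actually under inspection: i vs j.
                if keep_nodes[j] and can_be_replaced(node, bucketed_nodes[j]):
                    keep_nodes[j] = False  # node j is redundant; keep node i
        return [n for n, keep in zip(bucketed_nodes, keep_nodes) if keep]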
@@ -846,7 +845,7 @@ def optimize_model(model: Union[onnx.ModelProto, gs.Graph], skip_fusion_patterns
graph = model if isinstance(model, gs.Graph) else gs.import_onnx(model)
fusion_patterns = get_fusion_patterns(skip_fusion_patterns)
fusion_pairs = find_matches(graph, fusion_patterns)
for _, match in fusion_pairs.items():
for match in fusion_pairs.values():
graph.replace_custom_layer(**match)
graph.cleanup(remove_unused_graph_inputs=True).toposort()
graph_constant_fold_inplace(graph)
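Here the loop body only needs the match, so iterating fusion_pairs.values() avoids unpacking and discarding the key. A generic illustration with invented dictionary contents:

    fusion_pairs = {"GemmPattern": {"node_name": "gemm_0"}, "PadConvPattern": {"node_name": "conv_3"}}

    # Before: unpack (key, value) pairs and throw the key away.
    for _, match in fusion_pairs.items():
        print(match["node_name"])

    # After: ask the dictionary for its values directly.
    for match in fusion_pairs.values():
        print(match["node_name"])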
4 changes: 2 additions & 2 deletions onnxslim/core/slim.py
@@ -30,7 +30,7 @@ def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.
raise Exception(f"Input name {key} not found in model, available keys: {' '.join(input_names)}")
tensors[key].shape = values_list

for _, tensor in tensors.items():
for tensor in tensors.values():
if tensor.name not in input_names:
if isinstance(tensor, Constant):
continue
@@ -127,7 +127,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
for node in graph.nodes:
if node.op == "Cast":
inp_dtype = [input.dtype for input in node.inputs][0]
if inp_dtype in [np.float16, np.float32]:
if inp_dtype in {np.float16, np.float32}:
delete_node(node)

for tensor in graph.tensors().values():
66 changes: 33 additions & 33 deletions onnxslim/core/symbolic_shape_infer.py
@@ -30,7 +30,7 @@ def get_dim_from_proto(dim):
def is_sequence(type_proto):
"""Check if the given ONNX proto type is a sequence."""
cls_type = type_proto.WhichOneof("value")
assert cls_type in ["tensor_type", "sequence_type"]
assert cls_type in {"tensor_type", "sequence_type"}
return cls_type == "sequence_type"


@@ -80,7 +80,7 @@ def is_literal(dim):
"""Check if a dimension is a literal number (int, np.int64, np.int32, sympy.Integer) or has an 'is_number'
attribute.
"""
return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number)
return type(dim) in {int, np.int64, np.int32, sympy.Integer} or (hasattr(dim, "is_number") and dim.is_number)


def handle_negative_axis(axis, rank):
@@ -475,7 +475,7 @@ def _update_computed_dims(self, new_sympy_shape):

def _onnx_infer_single_node(self, node):
"""Performs ONNX shape inference for a single node, skipping inference for specified operation types."""
skip_infer = node.op_type in [
skip_infer = node.op_type in {
"If",
"Loop",
"Scan",
@@ -507,7 +507,7 @@ def _onnx_infer_single_node(self, node):
"NhwcConv",
"QuickGelu",
"RotaryEmbedding",
]
}

if not skip_infer:
# Only pass initializers that satisfy the following condition:
@@ -516,7 +516,7 @@
# (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec.
# (3) The initializer is not in graph input. The means the node input is "constant" in inference.
initializers = []
if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]:
if (get_opset(self.out_mp_) >= 9) and node.op_type == "Unsqueeze":
initializers = [
self.initializers_[name]
for name in node.input
@@ -525,7 +525,7 @@

if (
node.op_type
in [
in {
"Add",
"Sub",
"Mul",
@@ -535,13 +535,13 @@
"MatMulInteger16",
"Where",
"Sum",
]
}
and node.output[0] in self.known_vi_
):
vi = self.known_vi_[node.output[0]]
out_rank = len(get_shape_from_type_proto(vi.type))
in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)):
for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
if len(in_dims) > 1:
self._check_merged_dims(in_dims, allow_broadcast=True)
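The subtraction of 2 in the loop bound reflects MatMul broadcasting: only the leading batch dimensions broadcast across inputs, while the trailing two dimensions follow matrix-multiplication rules and are skipped when merging symbolic dims. A quick NumPy illustration, independent of the repository:

    import numpy as np

    a = np.zeros((8, 1, 4, 5))  # batch dims (8, 1), matrix dims (4, 5)
    b = np.zeros((1, 3, 5, 6))  # batch dims (1, 3), matrix dims (5, 6)

    # The batch dims broadcast elementwise to (8, 3); the last two do not.
    print(np.matmul(a, b).shape)  # (8, 3, 4, 6)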
@@ -675,7 +675,7 @@ def _compute_on_sympy_data(self, node, op_func):
# Before mul & div operations
# cast inputs into integer might lose decimal part and reduce precision
# keep them as float, finish the operation, then cast the result into integer
if node.op_type in ["Mul", "Div"]:
if node.op_type in {"Mul", "Div"}:
values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True)
else:
values = self._get_int_or_float_values(node, broadcast=True)
@@ -692,11 +692,11 @@ def _pass_on_sympy_data(self, node):
"""Pass Sympy data through a node, validating input length or node operation type 'Reshape', 'Unsqueeze',
'Squeeze'.
"""
assert len(node.input) == 1 or node.op_type in [
assert len(node.input) == 1 or node.op_type in {
"Reshape",
"Unsqueeze",
"Squeeze",
]
}
self._compute_on_sympy_data(node, lambda x: x[0])

def _pass_on_shape_and_type(self, node):
@@ -772,7 +772,7 @@ def _compute_conv_pool_shape(self, node, channels_last=False):
if pads is None:
pads = [0] * (2 * rank)
auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
if auto_pad not in ["VALID", "NOTSET"]:
if auto_pad not in {"VALID", "NOTSET"}:
try:
residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)]
total_pads = [
@@ -1147,7 +1147,7 @@ def _infer_Einsum(self, node): # noqa: N802
for i in range(num_ellipsis_indices):
new_sympy_shape.append(shape[i])
for c in left_equation:
if c not in [44, 46]: # c != b',' and c != b'.':
if c not in {44, 46}: # c != b',' and c != b'.':
if c in num_letter_occurrences:
num_letter_occurrences[c] = num_letter_occurrences[c] + 1
else:
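The integer constants work because the einsum equation arrives from the node attribute as bytes, so iterating over it yields code points rather than one-character strings: 44 is the comma and 46 is the period. For example:

    left_equation = b"ij,jk"  # bytes, as ONNX string attributes are stored

    # ord(",") == 44 and ord(".") == 46, so separators are skipped and letters kept.
    letters = [chr(c) for c in left_equation if c not in {44, 46}]
    print(letters)  # ['i', 'j', 'j', 'k']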
@@ -1201,7 +1201,7 @@ def _infer_Gather(self, node): # noqa: N802
else:
self.sympy_data_[node.output[0]] = data[int(idx)]
else:
assert idx in [0, -1]
assert idx in {0, -1}
self.sympy_data_[node.output[0]] = data

def _infer_GatherElements(self, node): # noqa: N802
@@ -1434,7 +1434,7 @@ def _infer_aten_diagonal(self, node):
dim1 = handle_negative_axis(dim1, rank)
dim2 = handle_negative_axis(dim2, rank)

new_shape = [val for dim, val in enumerate(sympy_shape) if dim not in [dim1, dim2]]
new_shape = [val for dim, val in enumerate(sympy_shape) if dim not in {dim1, dim2}]
shape1 = sympy_shape[dim1]
shape2 = sympy_shape[dim2]
if offset >= 0:
@@ -1475,7 +1475,7 @@ def _infer_aten_pool2d(self, node):
"""Infer the output shape of a 2D pooling operation in an ONNX graph node."""
sympy_shape = self._get_sympy_shape(node, 0)
assert len(sympy_shape) == 4
sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]]
sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in {2, 3}]
self._update_computed_dims(sympy_shape)
for i, o in enumerate(node.output):
if not o:
@@ -1581,7 +1581,7 @@ def _infer_aten_group_norm(self, node):
N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None # noqa: N806
group = self._try_get_value(node, 6)
output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
for i in [1, 2]:
for i in {1, 2}:
if node.output[i]:
vi = self.known_vi_[node.output[i]]
vi.CopyFrom(
@@ -1621,7 +1621,7 @@ def _infer_BatchNormalization(self, node): # noqa: N802
self._propagate_shape_and_type(node)

# this works for opsets < 14 and 14 since we check i < len(node.output) in the loop
for i in [1, 2, 3, 4]:
for i in {1, 2, 3, 4}:
if i < len(node.output) and node.output[i]:
# all of these parameters have the same shape as the 1st input
self._propagate_shape_and_type(node, input_index=1, output_index=i)
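One distinction from the membership rewrites above: in this hunk and the two before it the braces are iterated over, and set iteration order is an implementation detail rather than a language guarantee (CPython happens to yield these small integer literals in ascending order). Where a fixed visiting order matters, a tuple states it explicitly; a one-line illustration, not repository code:

    for i in (1, 2, 3, 4):  # same elements, with the order written down explicitly
        print(i)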
@@ -2198,7 +2198,7 @@ def _infer_TopK(self, node): # noqa: N802
k = self._get_int_or_float_values(node)[1]

k = self._new_symbolic_dim_from_output(node) if k is None else as_scalar(k)
if type(k) in [int, str]:
if type(k) in {int, str}:
new_shape[axis] = k
else:
new_sympy_shape = self._get_sympy_shape(node, 0)
@@ -2582,10 +2582,10 @@ def _infer_LayerNormalization(self, node): # noqa: N802
axis = handle_negative_axis(axis, rank)
mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)]
mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
if mean_dtype in [
if mean_dtype in {
onnx.TensorProto.FLOAT16,
onnx.TensorProto.BFLOAT16,
]:
}:
mean_dtype = onnx.TensorProto.FLOAT
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape))
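For reference, the branch being rewritten here widens the dtype recorded for the optional mean output when the input is half precision, so the inferred value info describes those statistics as float32. Condensed, using constants from the onnx package:

    import onnx

    mean_dtype = onnx.TensorProto.FLOAT16  # e.g. elem_type of an fp16 input
    if mean_dtype in {onnx.TensorProto.FLOAT16, onnx.TensorProto.BFLOAT16}:
        mean_dtype = onnx.TensorProto.FLOAT
    assert mean_dtype == onnx.TensorProto.FLOAT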
@@ -2788,7 +2788,7 @@ def get_prereq(node):
get_attribute(node, "then_branch"),
get_attribute(node, "else_branch"),
]
elif node.op_type in ["Loop", "Scan"]:
elif node.op_type in {"Loop", "Scan"}:
subgraphs = [get_attribute(node, "body")]
for g in subgraphs:
g_outputs_and_initializers = {i.name for i in g.initializer}
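Background for the two branches above: an If node carries its subgraphs in two graph-typed attributes (then_branch, else_branch), while Loop and Scan carry a single body graph, which is why the collected subgraph lists differ. A minimal sketch with the onnx helpers (empty subgraphs, for illustration only):

    from onnx import helper

    then_g = helper.make_graph([], "then_branch", [], [])
    else_g = helper.make_graph([], "else_branch", [], [])
    if_node = helper.make_node("If", ["cond"], ["out"], then_branch=then_g, else_branch=else_g)

    # The subgraphs are reachable through the node's graph attributes.
    print(sorted(a.name for a in if_node.attribute))  # ['else_branch', 'then_branch']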
@@ -2833,7 +2833,7 @@ def get_prereq(node):
known_aten_op = False
if node.op_type in self.dispatcher_:
self.dispatcher_[node.op_type](node)
elif node.op_type in ["ConvTranspose"]:
elif node.op_type == "ConvTranspose":
# onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
# before adding symbolic compute for them
# mark the output type as UNDEFINED to allow guessing of rank
@@ -2859,7 +2859,7 @@ def get_prereq(node):

# onnx automatically merge dims with value, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb']
# symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case
if node.op_type in [
if node.op_type in {
"Add",
"Sub",
"Mul",
Expand All @@ -2869,11 +2869,11 @@ def get_prereq(node):
"MatMulInteger16",
"Where",
"Sum",
]:
}:
vi = self.known_vi_[node.output[0]]
out_rank = len(get_shape_from_type_proto(vi.type))
in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)):
for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
if len(in_dims) > 1:
self._check_merged_dims(in_dims, allow_broadcast=True)
@@ -2884,10 +2884,10 @@ def get_prereq(node):
# 2) We do not care about the extraneous constant outputs in RotaryEmbedding because
# the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding
# contrib op
if node.op_type in [
if node.op_type in {
"SkipLayerNormalization",
"SkipSimplifiedLayerNormalization",
] and i_o in [1, 2]:
} and i_o in {1, 2}:
continue
if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
# Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs
@@ -2899,7 +2899,7 @@ def get_prereq(node):
out_type_kind = out_type.WhichOneof("value")

# do not process shape for non-tensors
if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]:
if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}:
if self.verbose_ > 2:
if out_type_kind == "sequence_type":
seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
@@ -2937,7 +2937,7 @@ def get_prereq(node):
out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))
) or out_type_undefined:
if self.auto_merge_:
if node.op_type in [
if node.op_type in {
"Add",
"Sub",
"Mul",
@@ -2955,13 +2955,13 @@ def get_prereq(node):
"GreaterOrEqual",
"Min",
"Max",
]:
}:
shapes = [self._get_shape(node, i) for i in range(len(node.input))]
if node.op_type in [
if node.op_type in {
"MatMul",
"MatMulInteger",
"MatMulInteger16",
] and (None in out_shape or self._is_shape_contains_none_dim(out_shape)):
} and (None in out_shape or self._is_shape_contains_none_dim(out_shape)):
if None in out_shape:
idx = out_shape.index(None)
else: