From 5f4b8159965d3e1d9f7d98b8591c121dd3e4896f Mon Sep 17 00:00:00 2001 From: Kayzwer Date: Mon, 10 Jun 2024 22:31:18 +0800 Subject: [PATCH] simple optimizations --- onnxslim/core/optimizer.py | 13 ++-- onnxslim/core/slim.py | 4 +- onnxslim/core/symbolic_shape_infer.py | 66 +++++++++---------- onnxslim/misc/tabulate.py | 31 +++++---- .../importers/onnx_importer.py | 6 +- onnxslim/onnx_graphsurgeon/ir/graph.py | 16 ++--- onnxslim/onnx_graphsurgeon/ir/node.py | 2 +- onnxslim/onnx_graphsurgeon/ir/tensor.py | 6 +- onnxslim/utils.py | 4 +- 9 files changed, 73 insertions(+), 75 deletions(-) diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py index f22a963..aa84bfe 100644 --- a/onnxslim/core/optimizer.py +++ b/onnxslim/core/optimizer.py @@ -58,11 +58,10 @@ def get_previous_node_by_type(node, op_type, trajectory=None): trajectory = [] node_feeds = get_node_feeds(node) for node_feed in node_feeds: + trajectory.append(node_feed) if node_feed.op == op_type: - trajectory.append(node_feed) return trajectory else: - trajectory.append(node_feed) return get_previous_node_by_type(node_feed, op_type, trajectory) @@ -112,7 +111,7 @@ def graph_constant_fold_inplace(graph): graph_constant_fold_inplace(subgraph) for node in graph.nodes: - if node.op in ["Identity", "Dropout"]: + if node.op in {"Identity", "Dropout"}: delete_node(node) elif node.op == "Pad": @@ -229,10 +228,10 @@ def find_conv_transpose_nodes(node, opset): """X | Conv/ConvTranspose | BatchNormalization.""" # fmt: on match = {} - if node.op == "BatchNormalization" and node.i(0).op in [ + if node.op == "BatchNormalization" and node.i(0).op in { "ConvTranspose", "Conv", - ]: + }: conv_transpose_node = node.i(0) conv_transpose_node_users = get_node_users(conv_transpose_node) if len(conv_transpose_node_users) == 1: @@ -802,7 +801,7 @@ def replace_node_references(existing_node, to_be_removed_node): if keep_nodes[i]: for j in range(i + 1, len(bucketed_nodes)): if keep_nodes[j]: - logger.debug(f"node.op {bucketed_nodes[0].op} idx i: {i}, idx j: {j}") + logger.debug(f"node.op {bucketed_nodes[i].op} idx i: {i}, idx j: {j}") if can_be_replaced(node, bucketed_nodes[j]): keep_nodes[j] = False existing_node = node @@ -846,7 +845,7 @@ def optimize_model(model: Union[onnx.ModelProto, gs.Graph], skip_fusion_patterns graph = model if isinstance(model, gs.Graph) else gs.import_onnx(model) fusion_patterns = get_fusion_patterns(skip_fusion_patterns) fusion_pairs = find_matches(graph, fusion_patterns) - for _, match in fusion_pairs.items(): + for match in fusion_pairs.values(): graph.replace_custom_layer(**match) graph.cleanup(remove_unused_graph_inputs=True).toposort() graph_constant_fold_inplace(graph) diff --git a/onnxslim/core/slim.py b/onnxslim/core/slim.py index 7789a87..6183e33 100644 --- a/onnxslim/core/slim.py +++ b/onnxslim/core/slim.py @@ -30,7 +30,7 @@ def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx. raise Exception(f"Input name {key} not found in model, available keys: {' '.join(input_names)}") tensors[key].shape = values_list - for _, tensor in tensors.items(): + for tensor in tensors.values(): if tensor.name not in input_names: if isinstance(tensor, Constant): continue @@ -127,7 +127,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto: for node in graph.nodes: if node.op == "Cast": inp_dtype = [input.dtype for input in node.inputs][0] - if inp_dtype in [np.float16, np.float32]: + if inp_dtype in {np.float16, np.float32}: delete_node(node) for tensor in graph.tensors().values(): diff --git a/onnxslim/core/symbolic_shape_infer.py b/onnxslim/core/symbolic_shape_infer.py index 057b2a1..3b0d4d9 100644 --- a/onnxslim/core/symbolic_shape_infer.py +++ b/onnxslim/core/symbolic_shape_infer.py @@ -30,7 +30,7 @@ def get_dim_from_proto(dim): def is_sequence(type_proto): """Check if the given ONNX proto type is a sequence.""" cls_type = type_proto.WhichOneof("value") - assert cls_type in ["tensor_type", "sequence_type"] + assert cls_type in {"tensor_type", "sequence_type"} return cls_type == "sequence_type" @@ -80,7 +80,7 @@ def is_literal(dim): """Check if a dimension is a literal number (int, np.int64, np.int32, sympy.Integer) or has an 'is_number' attribute. """ - return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number) + return type(dim) in {int, np.int64, np.int32, sympy.Integer} or (hasattr(dim, "is_number") and dim.is_number) def handle_negative_axis(axis, rank): @@ -475,7 +475,7 @@ def _update_computed_dims(self, new_sympy_shape): def _onnx_infer_single_node(self, node): """Performs ONNX shape inference for a single node, skipping inference for specified operation types.""" - skip_infer = node.op_type in [ + skip_infer = node.op_type in { "If", "Loop", "Scan", @@ -507,7 +507,7 @@ def _onnx_infer_single_node(self, node): "NhwcConv", "QuickGelu", "RotaryEmbedding", - ] + } if not skip_infer: # Only pass initializers that satisfy the following condition: @@ -516,7 +516,7 @@ def _onnx_infer_single_node(self, node): # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec. # (3) The initializer is not in graph input. The means the node input is "constant" in inference. initializers = [] - if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]: + if (get_opset(self.out_mp_) >= 9) and node.op_type == "Unsqueeze": initializers = [ self.initializers_[name] for name in node.input @@ -525,7 +525,7 @@ def _onnx_infer_single_node(self, node): if ( node.op_type - in [ + in { "Add", "Sub", "Mul", @@ -535,13 +535,13 @@ def _onnx_infer_single_node(self, node): "MatMulInteger16", "Where", "Sum", - ] + } and node.output[0] in self.known_vi_ ): vi = self.known_vi_[node.output[0]] out_rank = len(get_shape_from_type_proto(vi.type)) in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] - for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): + for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)): in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] if len(in_dims) > 1: self._check_merged_dims(in_dims, allow_broadcast=True) @@ -675,7 +675,7 @@ def _compute_on_sympy_data(self, node, op_func): # Before mul & div operations # cast inputs into integer might lose decimal part and reduce precision # keep them as float, finish the operation, then cast the result into integer - if node.op_type in ["Mul", "Div"]: + if node.op_type in {"Mul", "Div"}: values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True) else: values = self._get_int_or_float_values(node, broadcast=True) @@ -692,11 +692,11 @@ def _pass_on_sympy_data(self, node): """Pass Sympy data through a node, validating input length or node operation type 'Reshape', 'Unsqueeze', 'Squeeze'. """ - assert len(node.input) == 1 or node.op_type in [ + assert len(node.input) == 1 or node.op_type in { "Reshape", "Unsqueeze", "Squeeze", - ] + } self._compute_on_sympy_data(node, lambda x: x[0]) def _pass_on_shape_and_type(self, node): @@ -772,7 +772,7 @@ def _compute_conv_pool_shape(self, node, channels_last=False): if pads is None: pads = [0] * (2 * rank) auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") - if auto_pad not in ["VALID", "NOTSET"]: + if auto_pad not in {"VALID", "NOTSET"}: try: residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] total_pads = [ @@ -1147,7 +1147,7 @@ def _infer_Einsum(self, node): # noqa: N802 for i in range(num_ellipsis_indices): new_sympy_shape.append(shape[i]) for c in left_equation: - if c not in [44, 46]: # c != b',' and c != b'.': + if c not in {44, 46}: # c != b',' and c != b'.': if c in num_letter_occurrences: num_letter_occurrences[c] = num_letter_occurrences[c] + 1 else: @@ -1201,7 +1201,7 @@ def _infer_Gather(self, node): # noqa: N802 else: self.sympy_data_[node.output[0]] = data[int(idx)] else: - assert idx in [0, -1] + assert idx in {0, -1} self.sympy_data_[node.output[0]] = data def _infer_GatherElements(self, node): # noqa: N802 @@ -1434,7 +1434,7 @@ def _infer_aten_diagonal(self, node): dim1 = handle_negative_axis(dim1, rank) dim2 = handle_negative_axis(dim2, rank) - new_shape = [val for dim, val in enumerate(sympy_shape) if dim not in [dim1, dim2]] + new_shape = [val for dim, val in enumerate(sympy_shape) if dim not in {dim1, dim2}] shape1 = sympy_shape[dim1] shape2 = sympy_shape[dim2] if offset >= 0: @@ -1475,7 +1475,7 @@ def _infer_aten_pool2d(self, node): """Infer the output shape of a 2D pooling operation in an ONNX graph node.""" sympy_shape = self._get_sympy_shape(node, 0) assert len(sympy_shape) == 4 - sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]] + sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in {2, 3}] self._update_computed_dims(sympy_shape) for i, o in enumerate(node.output): if not o: @@ -1581,7 +1581,7 @@ def _infer_aten_group_norm(self, node): N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None # noqa: N806 group = self._try_get_value(node, 6) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - for i in [1, 2]: + for i in {1, 2}: if node.output[i]: vi = self.known_vi_[node.output[i]] vi.CopyFrom( @@ -1621,7 +1621,7 @@ def _infer_BatchNormalization(self, node): # noqa: N802 self._propagate_shape_and_type(node) # this works for opsets < 14 and 14 since we check i < len(node.output) in the loop - for i in [1, 2, 3, 4]: + for i in {1, 2, 3, 4}: if i < len(node.output) and node.output[i]: # all of these parameters have the same shape as the 1st input self._propagate_shape_and_type(node, input_index=1, output_index=i) @@ -2198,7 +2198,7 @@ def _infer_TopK(self, node): # noqa: N802 k = self._get_int_or_float_values(node)[1] k = self._new_symbolic_dim_from_output(node) if k is None else as_scalar(k) - if type(k) in [int, str]: + if type(k) in {int, str}: new_shape[axis] = k else: new_sympy_shape = self._get_sympy_shape(node, 0) @@ -2582,10 +2582,10 @@ def _infer_LayerNormalization(self, node): # noqa: N802 axis = handle_negative_axis(axis, rank) mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)] mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - if mean_dtype in [ + if mean_dtype in { onnx.TensorProto.FLOAT16, onnx.TensorProto.BFLOAT16, - ]: + }: mean_dtype = onnx.TensorProto.FLOAT vi = self.known_vi_[node.output[1]] vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape)) @@ -2788,7 +2788,7 @@ def get_prereq(node): get_attribute(node, "then_branch"), get_attribute(node, "else_branch"), ] - elif node.op_type in ["Loop", "Scan"]: + elif node.op_type in {"Loop", "Scan"}: subgraphs = [get_attribute(node, "body")] for g in subgraphs: g_outputs_and_initializers = {i.name for i in g.initializer} @@ -2833,7 +2833,7 @@ def get_prereq(node): known_aten_op = False if node.op_type in self.dispatcher_: self.dispatcher_[node.op_type](node) - elif node.op_type in ["ConvTranspose"]: + elif node.op_type == "ConvTranspose": # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input # before adding symbolic compute for them # mark the output type as UNDEFINED to allow guessing of rank @@ -2859,7 +2859,7 @@ def get_prereq(node): # onnx automatically merge dims with value, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case - if node.op_type in [ + if node.op_type in { "Add", "Sub", "Mul", @@ -2869,11 +2869,11 @@ def get_prereq(node): "MatMulInteger16", "Where", "Sum", - ]: + }: vi = self.known_vi_[node.output[0]] out_rank = len(get_shape_from_type_proto(vi.type)) in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] - for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): + for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)): in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] if len(in_dims) > 1: self._check_merged_dims(in_dims, allow_broadcast=True) @@ -2884,10 +2884,10 @@ def get_prereq(node): # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding # contrib op - if node.op_type in [ + if node.op_type in { "SkipLayerNormalization", "SkipSimplifiedLayerNormalization", - ] and i_o in [1, 2]: + } and i_o in {1, 2}: continue if node.op_type == "RotaryEmbedding" and len(node.output) > 1: # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs @@ -2899,7 +2899,7 @@ def get_prereq(node): out_type_kind = out_type.WhichOneof("value") # do not process shape for non-tensors - if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]: + if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}: if self.verbose_ > 2: if out_type_kind == "sequence_type": seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") @@ -2937,7 +2937,7 @@ def get_prereq(node): out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) ) or out_type_undefined: if self.auto_merge_: - if node.op_type in [ + if node.op_type in { "Add", "Sub", "Mul", @@ -2955,13 +2955,13 @@ def get_prereq(node): "GreaterOrEqual", "Min", "Max", - ]: + }: shapes = [self._get_shape(node, i) for i in range(len(node.input))] - if node.op_type in [ + if node.op_type in { "MatMul", "MatMulInteger", "MatMulInteger16", - ] and (None in out_shape or self._is_shape_contains_none_dim(out_shape)): + } and (None in out_shape or self._is_shape_contains_none_dim(out_shape)): if None in out_shape: idx = out_shape.index(None) else: diff --git a/onnxslim/misc/tabulate.py b/onnxslim/misc/tabulate.py index 9f43b1c..f2f7778 100644 --- a/onnxslim/misc/tabulate.py +++ b/onnxslim/misc/tabulate.py @@ -106,8 +106,7 @@ def _is_file(f): def _is_separating_line(row): """Determine if a row is a separating line based on its type and specific content conditions.""" - row_type = type(row) - return row_type in [list, str] and ( + return type(row) in {list, str} and ( (len(row) >= 1 and row[0] == SEPARATING_LINE) or (len(row) >= 2 and row[1] == SEPARATING_LINE) ) @@ -117,7 +116,7 @@ def _pipe_segment_with_colons(align, colwidth): format). """ w = colwidth - if align in ["right", "decimal"]: + if align in {"right", "decimal"}: return ("-" * (w - 1)) + ":" elif align == "center": return ":" + ("-" * (w - 2)) + ":" @@ -854,7 +853,7 @@ def _isnumber(string): if not _isconvertible(float, string): return False elif isinstance(string, (str, bytes)) and (math.isinf(float(string)) or math.isnan(float(string))): - return string.lower() in ["inf", "-inf", "nan"] + return string.lower() in {"inf", "-inf", "nan"} return True @@ -884,7 +883,7 @@ def _isbool(string): >>> _isbool(1) False """ - return type(string) is bool or (isinstance(string, (bytes, str)) and string in ("True", "False")) + return type(string) is bool or (isinstance(string, (bytes, str)) and string in {"True", "False"}) def _type(string, has_invisible=True, numparse=True): @@ -1343,7 +1342,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): elif hasattr(tabular_data, "index"): # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0) keys = list(tabular_data) - if showindex in ["default", "always", True] and tabular_data.index.name is not None: + if showindex in {"default", "always", True} and tabular_data.index.name is not None: if isinstance(tabular_data.index.name, list): keys[:0] = tabular_data.index.name else: @@ -1440,7 +1439,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows)) # add or remove an index column - showindex_is_a_str = type(showindex) in [str, bytes] + showindex_is_a_str = type(showindex) in {str, bytes} if showindex == "default" and index is not None: rows = _prepend_row_index(rows, index) elif isinstance(showindex, Sized) and not showindex_is_a_str: @@ -2129,7 +2128,7 @@ def tabulate( if colglobalalign is not None: # if global alignment provided aligns = [colglobalalign] * len(cols) else: # default - aligns = [numalign if ct in [int, float] else stralign for ct in coltypes] + aligns = [numalign if ct in {int, float} else stralign for ct in coltypes] # then specific alignments if colalign is not None: assert isinstance(colalign, Iterable) @@ -2611,25 +2610,25 @@ def _main(): sep = r"\s+" outfile = "-" for opt, value in opts: - if opt in ["-1", "--header"]: + if opt in {"-1", "--header"}: headers = "firstrow" - elif opt in ["-o", "--output"]: + elif opt in {"-o", "--output"}: outfile = value - elif opt in ["-F", "--float"]: + elif opt in {"-F", "--float"}: floatfmt = value - elif opt in ["-I", "--int"]: + elif opt in {"-I", "--int"}: intfmt = value - elif opt in ["-C", "--colalign"]: + elif opt in {"-C", "--colalign"}: colalign = value.split() - elif opt in ["-f", "--format"]: + elif opt in {"-f", "--format"}: if value not in tabulate_formats: print(f"{value} is not a supported table format") print(usage) sys.exit(3) tablefmt = value - elif opt in ["-s", "--sep"]: + elif opt in {"-s", "--sep"}: sep = value - elif opt in ["-h", "--help"]: + elif opt in {"-h", "--help"}: print(usage) sys.exit(0) files = args or [sys.stdin] diff --git a/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py b/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py index 5d969e6..7edc7e2 100644 --- a/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py +++ b/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py @@ -83,12 +83,12 @@ def get_itemsize(dtype): if dtype == onnx.TensorProto.BFLOAT16: return 2 - if dtype in [ + if dtype in { onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FNUZ, onnx.TensorProto.FLOAT8E5M2, onnx.TensorProto.FLOAT8E5M2FNUZ, - ]: + }: return 1 G_LOGGER.critical(f"Unsupported type: {dtype}") @@ -193,7 +193,7 @@ def get_opset(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]): class_name = "Function" if isinstance(model_or_func, onnx.FunctionProto) else "Model" try: for importer in OnnxImporter.get_import_domains(model_or_func): - if importer.domain in ["", "ai.onnx"]: + if importer.domain in {"", "ai.onnx"}: return importer.version G_LOGGER.warning(f"{class_name} does not contain ONNX domain opset information! Using default opset.") return None diff --git a/onnxslim/onnx_graphsurgeon/ir/graph.py b/onnxslim/onnx_graphsurgeon/ir/graph.py index d39d820..64ab7c2 100644 --- a/onnxslim/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/onnx_graphsurgeon/ir/graph.py @@ -194,7 +194,7 @@ def __getattr__(self, name): def __setattr__(self, name, value): """Sets an attribute to the given value, converting 'inputs' and 'outputs' to lists if applicable.""" - if name in ["inputs", "outputs"]: + if name in {"inputs", "outputs"}: value = list(value) return super().__setattr__(name, value) @@ -502,8 +502,8 @@ def toposort( if mode not in ALLOWED_MODES: G_LOGGER.critical(f'Mode "{mode}" not in {ALLOWED_MODES}') - sort_nodes = mode in ["full", "nodes"] - sort_functions = mode in ["full", "functions"] + sort_nodes = mode in {"full", "nodes"} + sort_functions = mode in {"full", "functions"} if sort_nodes and recurse_functions: for func in self.functions: @@ -812,7 +812,7 @@ def run_cast_elision(node): # Search for Cast(s) (from int -> float) -> intermediate operator (with float constants) -> Cast(s) (back to int) # This pattern is problematic for TensorRT since these operations may be performed on Shape Tensors, which # are not allowed to be floating point type. Attempt to fold the pattern here - VALID_CAST_ELISION_OPS = [ + VALID_CAST_ELISION_OPS = { "Add", "Sub", "Mul", @@ -823,7 +823,7 @@ def run_cast_elision(node): "Greater", "Less", "Concat", - ] + } if node.op not in VALID_CAST_ELISION_OPS: return @@ -862,7 +862,7 @@ def run_cast_elision(node): for out_tensor in node.outputs for out_node in out_tensor.outputs if out_node.op == "Cast" - and out_node.attrs["to"] in [onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64] + and out_node.attrs["to"] in {onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64} ] # No cast node found on outputs, return early @@ -1029,7 +1029,7 @@ def fold_shape_slice(tensor): if len(slice.inputs) >= 3: starts, ends = slice.inputs[1:3] - if any(not isinstance(t, Constant) for t in [starts, ends]): + if any(not isinstance(t, Constant) for t in {starts, ends}): return None starts, ends = get_scalar_value(starts), get_scalar_value(ends) elif "starts" in slice.attrs and "ends" in slice.attrs: @@ -1070,7 +1070,7 @@ def fold_shape_slice(tensor): if fold_shapes: # NOTE: The order of shape folding passes is important to maximize how much we fold (phase-ordering problem). - SHAPE_FOLD_FUNCS = [fold_shape_gather, fold_shape_slice, fold_shape] + SHAPE_FOLD_FUNCS = {fold_shape_gather, fold_shape_slice, fold_shape} for shape_fold_func in SHAPE_FOLD_FUNCS: try: for tensor in clone_tensors.values(): diff --git a/onnxslim/onnx_graphsurgeon/ir/node.py b/onnxslim/onnx_graphsurgeon/ir/node.py index 3339cfb..804ed20 100644 --- a/onnxslim/onnx_graphsurgeon/ir/node.py +++ b/onnxslim/onnx_graphsurgeon/ir/node.py @@ -135,7 +135,7 @@ def subgraphs(self, recursive=False): def __setattr__(self, name, value): """Sets the attribute 'name' to 'value', handling special cases for 'inputs' and 'outputs' attributes.""" - if name in ["inputs", "outputs"]: + if name in {"inputs", "outputs"}: try: attr = getattr(self, name) if value is attr: diff --git a/onnxslim/onnx_graphsurgeon/ir/tensor.py b/onnxslim/onnx_graphsurgeon/ir/tensor.py index e702c24..8b858aa 100644 --- a/onnxslim/onnx_graphsurgeon/ir/tensor.py +++ b/onnxslim/onnx_graphsurgeon/ir/tensor.py @@ -34,7 +34,7 @@ def __init__(self): def __setattr__(self, name, value): """Set an attribute, ensuring special handling for "inputs" and "outputs" properties.""" - if name in ["inputs", "outputs"]: + if name in {"inputs", "outputs"}: try: attr = getattr(self, name) if value is attr: @@ -344,8 +344,8 @@ def load(self): values = np.zeros(self.tensor.dims) indices_data = np.asarray(indices_data).reshape(self.tensor.indices.dims) - for i in range(len(values_data)): - values[tuple(indices_data[i])] = values_data[i] + for value_data, index_data in zip(values_data, indices_data): + values[tuple(index_data)] = value_data else: G_LOGGER.critical(f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}") diff --git a/onnxslim/utils.py b/onnxslim/utils.py index dda9ab5..60653c2 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -126,7 +126,7 @@ def gen_onnxruntime_input_data( shapes = shapes or [1] dtype = info["dtype"] - if dtype in [np.int32, np.int64]: + if dtype in {np.int32, np.int64}: random_data = np.random.randint(10, size=shapes).astype(dtype) else: random_data = np.random.rand(*shapes).astype(dtype) @@ -300,7 +300,7 @@ def dump_model_info_to_disk(model_name: str, model_info: Dict): def get_opset(model: onnx.ModelProto) -> int: try: for importer in model.opset_import: - if importer.domain in ["", "ai.onnx"]: + if importer.domain in {"", "ai.onnx"}: return importer.version return None