diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py
index 433b6e3..3b2cb83 100644
--- a/onnxslim/core/optimizer.py
+++ b/onnxslim/core/optimizer.py
@@ -108,6 +108,9 @@ def graph_constant_fold_inplace(graph):
     """Perform in-place constant folding optimizations on the given computational graph by eliminating redundant
     nodes.
     """
+    for subgraph in graph.subgraphs():
+        graph_constant_fold_inplace(subgraph)
+
     for node in graph.nodes:
         if node.op == "Identity" or node.op == "Dropout":
             delete_node(node)
diff --git a/onnxslim/utils.py b/onnxslim/utils.py
index aba5af5..1a0251a 100644
--- a/onnxslim/utils.py
+++ b/onnxslim/utils.py
@@ -316,18 +316,27 @@ def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]:
 
     value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info}
 
-    for node in model.graph.node:
-        op_type = node.op_type
-        op_type_counts[op_type] += 1
-
-        for output in node.output:
-            shapes = []
-            if output in value_info_dict:
-                tensor = value_info_dict[output]
-                type_str, shape = get_tensor_dtype_shape(tensor)
-                shapes.append([type_str, shape])
-
-        op_info[node.name] = [node.op_type, shapes]
+    def get_graph_node_info(graph: onnx.GraphProto) -> Dict[str, List[str]]:
+        for node in graph.node:
+            op_type = node.op_type
+            op_type_counts[op_type] += 1
+            for output in node.output:
+                shapes = []
+                if output in value_info_dict:
+                    tensor = value_info_dict[output]
+                    type_str, shape = get_tensor_dtype_shape(tensor)
+                    shapes.append([type_str, shape])
+
+            op_info[node.name] = [node.op_type, shapes]
+
+            for attr in node.attribute:
+                ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()}
+                if attr.type in ATTR_TYPE_MAPPING:
+                    attr_str = ATTR_TYPE_MAPPING[attr.type]
+                    if attr_str == "GRAPH":
+                        get_graph_node_info(attr.g)
+
+    get_graph_node_info(model.graph)
 
     model_info["op_set"] = str(get_opset(model))
     model_info["op_info"] = op_info
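For context on the recursion added above: ONNX stores control-flow bodies (the `then_branch`/`else_branch` of `If`, the `body` of `Loop`/`Scan`) as node attributes of type GRAPH, which is why both changes walk into nested graphs. Below is a minimal standalone sketch of that traversal, assuming only the official `onnx` package; the `count_ops` helper and the "model.onnx" path are illustrative and not part of this PR, which matches GRAPH attributes via a name mapping instead of comparing the enum directly.

    import onnx
    from collections import defaultdict


    def count_ops(graph: onnx.GraphProto, counts=None):
        # Count op types in this graph and recurse into any GRAPH-valued
        # attributes (If branches, Loop/Scan bodies), mirroring the
        # subgraph-aware traversal introduced in the diff above.
        counts = defaultdict(int) if counts is None else counts
        for node in graph.node:
            counts[node.op_type] += 1
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    count_ops(attr.g, counts)
        return counts


    # model = onnx.load("model.onnx")  # hypothetical path
    # print(dict(count_ops(model.graph)))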