Skip to content

Commit

Permalink
add support for subgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Jun 5, 2024
1 parent e96d82a commit c095a5d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
3 changes: 3 additions & 0 deletions onnxslim/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def graph_constant_fold_inplace(graph):
"""Perform in-place constant folding optimizations on the given computational graph by eliminating redundant
nodes.
"""
for subgraph in graph.subgraphs():
graph_constant_fold_inplace(subgraph)

for node in graph.nodes:
if node.op == "Identity" or node.op == "Dropout":
delete_node(node)
Expand Down
33 changes: 21 additions & 12 deletions onnxslim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,18 +316,27 @@ def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]:

value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info}

for node in model.graph.node:
op_type = node.op_type
op_type_counts[op_type] += 1

for output in node.output:
shapes = []
if output in value_info_dict:
tensor = value_info_dict[output]
type_str, shape = get_tensor_dtype_shape(tensor)
shapes.append([type_str, shape])

op_info[node.name] = [node.op_type, shapes]
def get_graph_node_info(graph: onnx.GraphProto) -> Dict[str, List[str]]:
for node in graph.node:
op_type = node.op_type
op_type_counts[op_type] += 1
for output in node.output:
shapes = []
if output in value_info_dict:
tensor = value_info_dict[output]
type_str, shape = get_tensor_dtype_shape(tensor)
shapes.append([type_str, shape])

op_info[node.name] = [node.op_type, shapes]

for attr in node.attribute:
ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()}
if attr.type in ATTR_TYPE_MAPPING:
attr_str = ATTR_TYPE_MAPPING[attr.type]
if attr_str == "GRAPH":
get_graph_node_info(attr.g)

get_graph_node_info(model.graph)

model_info["op_set"] = str(get_opset(model))
model_info["op_info"] = op_info
Expand Down

0 comments on commit c095a5d

Please sign in to comment.