Skip to content

Commit

Permalink
[Release] 0.1.26
Browse files Browse the repository at this point in the history
1. add elapse time print; 2. ready for 0.1.26 (#26)
  • Loading branch information
inisis authored May 14, 2024
1 parent 8c1ec8e commit 73021c4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 14 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.25
0.1.26
6 changes: 6 additions & 0 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def slim(
onnx.ModelProto/None: If `output_model` is None, return slimmed model else return None.
"""
import os
import time
from pathlib import Path

from onnxslim.core.slim import (
Expand Down Expand Up @@ -92,6 +93,8 @@ def slim(

freeze(model)

start_time = time.time()

if save_as_external_data:
model_save_as_external_data(model, output_model)
return None
Expand Down Expand Up @@ -147,10 +150,13 @@ def slim(
if slimmed_info["model_size"] >= onnx.checker.MAXIMUM_PROTOBUF:
model_size = model.ByteSize()
slimmed_info["model_size"] = [model_size, slimmed_info["model_size"]]
end_time = time.time()
elapsed_time = end_time - start_time

print_model_info_as_table(
model_name,
[float_info, slimmed_info],
elapsed_time,
)


Expand Down
29 changes: 17 additions & 12 deletions onnxslim/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,23 @@ def graph_constant_fold_inplace(graph):
if inp_dtype == node.attrs["to"]:
delete_node(node)
elif node.op == "Reshape":
node_output_shape = node.outputs[0].shape
if node_output_shape and check_shape(node_output_shape):
shapes = [
shape if isinstance(shape, int) else -1
for shape in node_output_shape
]
reshape_const = gs.Constant(
node.inputs[1].name + "_",
values=np.array(shapes, dtype=np.int64),
)
node.inputs.pop(1)
node.inputs.insert(1, reshape_const)
if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and (
node.outputs[0].shape and len(node.outputs[0].shape) == 1
):
delete_node(node)
else:
node_output_shape = node.outputs[0].shape
if node_output_shape and check_shape(node_output_shape):
shapes = [
shape if isinstance(shape, int) else -1
for shape in node_output_shape
]
reshape_const = gs.Constant(
node.inputs[1].name + "_",
values=np.array(shapes, dtype=np.int64),
)
node.inputs.pop(1)
node.inputs.insert(1, reshape_const)
elif node.op == "Mul":
if (
isinstance(node.inputs[1], Constant)
Expand Down
14 changes: 13 additions & 1 deletion onnxslim/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def onnxruntime_inference(
return onnx_output


def print_model_info_as_table(model_name: str, model_info_list: List[Dict]):
def print_model_info_as_table(
model_name: str, model_info_list: List[Dict], elapsed_time: float = 0.0
):
assert (
len(model_info_list) > 0
), "model_info_list must contain more than one model info"
Expand Down Expand Up @@ -185,13 +187,23 @@ def print_model_info_as_table(model_name: str, model_info_list: List[Dict]):
["Model Size"]
+ [format_bytes(model_info["model_size"]) for model_info in model_info_list]
)
final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
final_op_info.append(["Elapsed Time"] + [f"{elapsed_time:.2f} s"])
lines = tabulate(
final_op_info,
headers=[],
tablefmt="pretty",
maxcolwidths=[None] + [40] * len(model_info_list),
).split("\n")

time_row = lines[-2].split("|")
time_row[-3] = (
time_row[-2][: len(time_row[-2]) // 2 + 1]
+ time_row[-3]
+ time_row[-2][len(time_row[-2]) // 2 :]
)
time_row.pop(-2)
lines[-2] = "|".join(time_row)
output = "\n".join([line if line != "| \x01 |" else lines[0] for line in lines])

print(output)
Expand Down

0 comments on commit 73021c4

Please sign in to comment.