diff --git a/VERSION b/VERSION index 0e7400f..7db2672 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.25 +0.1.26 diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py index 9d69061..069a56f 100644 --- a/onnxslim/cli/_main.py +++ b/onnxslim/cli/_main.py @@ -51,6 +51,7 @@ def slim( onnx.ModelProto/None: If `output_model` is None, return slimmed model else return None. """ import os + import time from pathlib import Path from onnxslim.core.slim import ( @@ -92,6 +93,8 @@ def slim( freeze(model) + start_time = time.time() + if save_as_external_data: model_save_as_external_data(model, output_model) return None @@ -147,10 +150,13 @@ def slim( if slimmed_info["model_size"] >= onnx.checker.MAXIMUM_PROTOBUF: model_size = model.ByteSize() slimmed_info["model_size"] = [model_size, slimmed_info["model_size"]] + end_time = time.time() + elapsed_time = end_time - start_time print_model_info_as_table( model_name, [float_info, slimmed_info], + elapsed_time, ) diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py index a895a4a..3b9f84c 100644 --- a/onnxslim/core/optimizer.py +++ b/onnxslim/core/optimizer.py @@ -116,18 +116,23 @@ def graph_constant_fold_inplace(graph): if inp_dtype == node.attrs["to"]: delete_node(node) elif node.op == "Reshape": - node_output_shape = node.outputs[0].shape - if node_output_shape and check_shape(node_output_shape): - shapes = [ - shape if isinstance(shape, int) else -1 - for shape in node_output_shape - ] - reshape_const = gs.Constant( - node.inputs[1].name + "_", - values=np.array(shapes, dtype=np.int64), - ) - node.inputs.pop(1) - node.inputs.insert(1, reshape_const) + if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and ( + node.outputs[0].shape and len(node.outputs[0].shape) == 1 + ): + delete_node(node) + else: + node_output_shape = node.outputs[0].shape + if node_output_shape and check_shape(node_output_shape): + shapes = [ + shape if isinstance(shape, int) else -1 + for shape in node_output_shape + ] + reshape_const = gs.Constant( + node.inputs[1].name + "_", + values=np.array(shapes, dtype=np.int64), + ) + node.inputs.pop(1) + node.inputs.insert(1, reshape_const) elif node.op == "Mul": if ( isinstance(node.inputs[1], Constant) diff --git a/onnxslim/utils/utils.py b/onnxslim/utils/utils.py index c7d0c3a..85b185c 100644 --- a/onnxslim/utils/utils.py +++ b/onnxslim/utils/utils.py @@ -111,7 +111,9 @@ def onnxruntime_inference( return onnx_output -def print_model_info_as_table(model_name: str, model_info_list: List[Dict]): +def print_model_info_as_table( + model_name: str, model_info_list: List[Dict], elapsed_time: float = 0.0 +): assert ( len(model_info_list) > 0 ), "model_info_list must contain more than one model info" @@ -185,6 +187,8 @@ def print_model_info_as_table(model_name: str, model_info_list: List[Dict]): ["Model Size"] + [format_bytes(model_info["model_size"]) for model_info in model_info_list] ) + final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1)) + final_op_info.append(["Elapsed Time"] + [f"{elapsed_time:.2f} s"]) lines = tabulate( final_op_info, headers=[], @@ -192,6 +196,14 @@ def print_model_info_as_table(model_name: str, model_info_list: List[Dict]): maxcolwidths=[None] + [40] * len(model_info_list), ).split("\n") + time_row = lines[-2].split("|") + time_row[-3] = ( + time_row[-2][: len(time_row[-2]) // 2 + 1] + + time_row[-3] + + time_row[-2][len(time_row[-2]) // 2 :] + ) + time_row.pop(-2) + lines[-2] = "|".join(time_row) output = "\n".join([line if line != "| \x01 |" else lines[0] for line in lines]) print(output)