Skip to content

Commit

Permalink
adapt to new api
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Nov 22, 2024
1 parent 8b1420d commit ad8d6ec
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import onnx


def slim(model: List[Union[str, onnx.ModelProto]], *args, **kwargs):
def slim(model: Union[str, onnx.ModelProto, List[Union[str, onnx.ModelProto]]], *args, **kwargs):
import os
import time
from pathlib import Path
Expand Down Expand Up @@ -64,15 +64,15 @@ def get_info(model, inspect=False):

model_info = summarize_model(model, model_name)

return model_name, model_info
return model_info

if isinstance(model, list):
model_name_list, model_info_list = zip(*[get_info(m, inspect=True) for m in model])
model_info_list = (get_info(m, inspect=True) for m in model)

if dump_to_disk:
[dump_model_info_to_disk(name, info) for name, info in zip(model_name_list, model_info_list)]
[dump_model_info_to_disk(info) for info in model_info_list]

print_model_info_as_table(model_name_list[0], model_info_list)
print_model_info_as_table(model_info_list)

return
else:
Expand Down Expand Up @@ -127,7 +127,6 @@ def get_info(model, inspect=False):
end_time = time.time()
elapsed_time = end_time - start_time
print_model_info_as_table(
model_name,
[original_info, slimmed_info],
elapsed_time,
)
Expand Down

0 comments on commit ad8d6ec

Please sign in to comment.