
Commit

add size-threshold argument (#69)
* add ort benchmark

* add size-threshold argument and better optimization
inisis authored Jan 7, 2025
1 parent b2c84db commit f58424b
Showing 5 changed files with 30 additions and 3 deletions.
6 changes: 6 additions & 0 deletions onnxslim/argparser.py
@@ -48,6 +48,12 @@ class OptimizationArguments:
"choices": list(onnxslim.DEFAULT_FUSION_PATTERNS.keys()),
},
)
size_threshold: int = field(
default=None,
metadata={
"help": "size threshold in bytes, size larger than this value will not be folded, default None, which means fold all constants",
},
)


@dataclass
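The new field plugs into onnxslim's dataclass-driven argument parser, so it should surface as a command-line option derived from the field name. A hedged usage sketch (the exact flag spelling, --size-threshold here, is assumed rather than confirmed):

    onnxslim model.onnx model_slim.onnx --size-threshold 1048576

With a threshold of 1048576, constants whose folded value would exceed roughly 1 MiB are left unfolded instead of being baked into the model as initializers.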
4 changes: 3 additions & 1 deletion onnxslim/cli/_main.py
@@ -38,6 +38,8 @@ def slim(model: Union[str, onnx.ModelProto, List[Union[str, onnx.ModelProto]]],
no_constant_folding = kwargs.get("no_constant_folding", False)
dtype = kwargs.get("dtype", None)
skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
size_threshold = kwargs.get("size_threshold", None)
size_threshold = int(size_threshold) if size_threshold else None
kwargs.get("inspect", False)
dump_to_disk = kwargs.get("dump_to_disk", False)
save_as_external_data = kwargs.get("save_as_external_data", False)
@@ -99,7 +101,7 @@ def get_info(model, inspect=False):
graph_check_point = check_point(model)
while MAX_ITER > 0:
logger.debug(f"iter: {MAX_ITER}")
model = optimize(model, skip_fusion_patterns)
model = optimize(model, skip_fusion_patterns, size_threshold)
if not no_shape_infer:
model = shape_infer(model)
graph = check_point(model)
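Because slim() pulls the option out of **kwargs, the same control is available from the Python API. A minimal sketch, assuming a local file model.onnx (hypothetical path):

    import onnxslim

    # Fold constants only when the resulting tensor is at most 1 MiB;
    # larger computed constants stay as subgraphs in the model.
    slimmed = onnxslim.slim("model.onnx", size_threshold=1024 * 1024)

Note that the coercion int(size_threshold) if size_threshold else None also maps an explicit 0 to None, i.e. back to the default fold-everything behaviour.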
4 changes: 2 additions & 2 deletions onnxslim/core/__init__.py
@@ -142,13 +142,13 @@ def shape_infer(model: onnx.ModelProto):
return model


def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None):
def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None, size_threshold: int = None):
"""Optimize the given ONNX model with options to skip specific fusion patterns and return the optimized model."""
logger.debug("Start converting model to gs.")
graph = gs.import_onnx(model).toposort()
logger.debug("Finish converting model to gs.")
logger.debug("Start constant folding.")
graph.fold_constants().cleanup().toposort()
graph.fold_constants(size_threshold=size_threshold).cleanup().toposort()
logger.debug("Finish constant folding.")
logger.debug("Start optimize model.")
model = optimize_model(graph, skip_fusion_patterns)
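size_threshold is simply forwarded to the graph-surgeon-style fold_constants, which is expected to skip any folded result larger than the byte budget. An illustrative sketch of that gate (should_fold is a made-up name, not the actual internals):

    import numpy as np

    def should_fold(folded_value: np.ndarray, size_threshold: int | None) -> bool:
        # None preserves the old behaviour: every computable constant is folded.
        if size_threshold is None:
            return True
        # Otherwise only materialize results that fit under the byte budget.
        return folded_value.nbytes <= size_threshold

This keeps folding from bloating the serialized model when an expression such as a large Tile or Expand would collapse into a huge initializer.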
4 changes: 4 additions & 0 deletions onnxslim/core/optimization/dead_node_elimination.py
@@ -114,6 +114,10 @@ def dead_node_elimination(graph, is_subgraph=False):
elif np.all(value == 1) and (node.inputs[0].shape == node.outputs[0].shape):
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.op == "Split":
if len(node.outputs) == 1 and node.outputs[0].shape and node.inputs[0].shape and node.outputs[0].shape == node.inputs[0].shape:
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")


def check_shape(shapes):
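The new Split branch covers a degenerate case: a Split with a single output and an unchanged shape is an identity, so the node can be dropped and its consumers rewired to its input. A minimal reproduction under assumed opset and shapes:

    import onnx
    from onnx import TensorProto, helper

    # Split with one output and no 'split' sizes: the single output is the whole
    # input, so output shape == input shape and the node does nothing.
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3])
    node = helper.make_node("Split", inputs=["x"], outputs=["y"], axis=0)
    graph = helper.make_graph([node], "identity_split", [x], [y])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

    # With this change, onnxslim.slim(model) should eliminate the Split and
    # connect consumers of "y" directly to "x".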
15 changes: 15 additions & 0 deletions tests/test_benchmark.py
@@ -35,6 +35,21 @@ def bench_polygraphy(input, output):
result = bench_main(command)
return result

def bench_onnxruntime(input, output):
try:
import onnxruntime as rt
sess_options = rt.SessionOptions()
# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# To enable model serialization after graph optimization, set this path
sess_options.optimized_model_filepath = output
session = rt.InferenceSession(input, sess_options)
return True

except Exception as e:
print(e)
return None


def bench_onnxruntime(input, output):
try:
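For the benchmark helper, constructing the InferenceSession is what triggers onnxruntime's offline graph optimization; because optimized_model_filepath is set, the optimized graph is serialized to that path as a side effect. A usage sketch (file names are placeholders):

    # Writes the ORT-optimized model to the given path; returns True on success,
    # None if session creation fails.
    ok = bench_onnxruntime("model.onnx", "model.ort.onnx")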
