From f455f251ba5e4fd18a9051b1d092ef27af3e88a0 Mon Sep 17 00:00:00 2001 From: inisis Date: Wed, 12 Jun 2024 16:05:17 +0000 Subject: [PATCH] fix an underlying bug --- onnxslim/core/optimizer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py index aa84bfe..7439d24 100644 --- a/onnxslim/core/optimizer.py +++ b/onnxslim/core/optimizer.py @@ -113,22 +113,24 @@ def graph_constant_fold_inplace(graph): for node in graph.nodes: if node.op in {"Identity", "Dropout"}: delete_node(node) - elif node.op == "Pad": if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant): pad_value = node.inputs[1].values.tolist() pad_value = pad_value if isinstance(pad_value, list) else [pad_value] if all(value == 0 for value in pad_value): delete_node(node) + logger.debug(f"removing Pad op: {node.name}") elif node.op == "Cast": inp_dtype = [dtype_to_onnx(input.dtype) for input in node.inputs][0] if inp_dtype == node.attrs["to"]: delete_node(node) + logger.debug(f"removing Cast op: {node.name}") elif node.op == "Reshape": if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and ( node.outputs[0].shape and len(node.outputs[0].shape) == 1 ): delete_node(node) + logger.debug(f"removing Reshape op: {node.name}") else: node_output_shape = node.outputs[0].shape if node_output_shape and check_shape(node_output_shape): @@ -147,21 +149,25 @@ def graph_constant_fold_inplace(graph): if np.all(constant_variable.values == 1): var_idx = 0 if idx == 1 else 1 delete_node(node, var_idx) + logger.debug(f"removing Mul op: {node.name}") elif node.op == "Add": if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or ( isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable) ): idx, constant_variable = get_constant_variable(node, return_idx=True) - if np.all(constant_variable.values == 0): + if np.all(constant_variable.values == 0) and (node.inputs[0].shape == node.inputs[1].shape): idx = 0 if idx == 1 else 1 delete_node(node, idx) + logger.debug(f"removing Add op: {node.name}") elif node.op == "Expand": if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant) and np.all(node.inputs[1].values == 1): idx = 0 if idx == 1 else 1 delete_node(node, idx) + logger.debug(f"removing Expand op: {node.name}") elif node.op == "Concat": if len(node.inputs) == 1: delete_node(node) + logger.debug(f"removing Concat op: {node.name}") else: for input in node.inputs: if isinstance(input, Constant) and input.values.size == 0: