Skip to content

Commit

Permalink
fix dtype compare bug
Browse files Browse the repository at this point in the history
  • Loading branch information
initialencounter authored Nov 12, 2024
1 parent 2e05a2a commit 25f3eb6
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion onnxslim/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
for node in graph.nodes:
if node.op == "Cast":
inp_dtype = [input.dtype for input in node.inputs][0]
if inp_dtype in {np.float16, np.float32}:
if inp_dtype in [np.float16, np.float32]:
delete_node(node)

for tensor in graph.tensors().values():
Expand Down

0 comments on commit 25f3eb6

Please sign in to comment.