Skip to content

Commit

Permalink
Fix bug of fp16 converter about the Cast node and topology in sub-graph (#291)
Browse files Browse the repository at this point in the history

* Update float16.py

* g

* temporary fix

* Update float16.py

* update-1

* update

* Update test_float16.py

* update

* update

* update

* Update float16.py

* update

* Delete float16.onnx

* del

* update

* Update test_float16.py

* update

* fix comments

* Update float16.py
  • Loading branch information
xiaowuhu authored Jun 6, 2024
1 parent efc2b67 commit 8beb9a2
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 210 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,5 @@ venv.bak/

# PyCharm
.idea/
tests/data/image_classifier16.onnx
tests/data/fp16_tensor.data
7 changes: 4 additions & 3 deletions onnxconverter_common/auto_mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def validate(res1, res2):
def run_attempt(node_block_list, return_model=False):
print(node_block_list)
model = float16.convert_float_to_float16(copy.deepcopy(model0), node_block_list=node_block_list,
keep_io_types=keep_io_types, disable_shape_infer=True)
keep_io_types=keep_io_types, disable_shape_infer=False)
res1 = get_tensor_values_using_ort(model, feed_dict)
if return_model:
return validate(res0, res1), model
Expand Down Expand Up @@ -129,15 +129,16 @@ def get_tensor_values_using_ort(model, input_feed, output_names=None, sess_optio
# Below code is for debug only, keep it for next time use
# sess_options = ort.SessionOptions()
# sess_options.optimized_model_filepath = "d:/optimized_model.onnx"
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CPUExecutionProvider'])
return sess.run(None, input_feed)
original_outputs = list(model.graph.output)
while len(model.graph.output) > 0:
model.graph.output.pop()
for n in output_names:
out = model.graph.output.add()
out.name = n
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
# NOTE: if set to 'CUDAExecutionProvider', this will fail; needs further investigation
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CPUExecutionProvider'])
try:
return sess.run(output_names, input_feed)
finally:
Expand Down
Loading

0 comments on commit 8beb9a2

Please sign in to comment.