Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper handling of repeated fp16 conversion. #310

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions onnxconverter_common/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
###########################################################################

import itertools
import uuid
import warnings
from typing import Optional

import numpy as np
import onnx
import packaging.version as pv
import warnings
from onnx import helper, numpy_helper
from onnx import onnx_pb as onnx_proto


FLOAT32 = 1
FLOAT16 = 10

Expand Down Expand Up @@ -239,32 +241,32 @@ def process_node_in_block_list(graph: onnx_proto.GraphProto, global_input_name_d

# Todo: global_input_name_dict still not fill value
def insert_cast32_before_node(graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict):
for i in range(len(node.input)):
for i, input_name in enumerate(node.input):
input_name = node.input[i]
for value_info in itertools.chain(graph.value_info, graph.input):
if input_name == value_info.name:
if value_info.type.tensor_type.elem_type != onnx_proto.TensorProto.FLOAT16:
break
cast_output_name = node.name + "_input_cast_" + str(i)
cast_node_name = f"onnxconverter_inserted_cast_{str(uuid.uuid4())}"
cast_output_name = f"{cast_node_name}_output"
add_new_value_info(graph, value_info, cast_output_name, onnx_proto.TensorProto.FLOAT)
cast_node_name = node.name + "_input_cast" + str(i)
add_cast_node(graph, [input_name], [cast_output_name], cast_node_name, onnx_proto.TensorProto.FLOAT)
node.input[i] = cast_output_name
break


# Todo: global_input_name_dict still not fill value
def insert_cast16_after_node(graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict):
for i in range(len(node.output)):
for i, output_name in enumerate(node.output):
output_name = node.output[i]
for value_info in itertools.chain(graph.value_info, graph.output):
if output_name == value_info.name:
if value_info.type.tensor_type.elem_type != onnx_proto.TensorProto.FLOAT:
break
cast_input_name = node.name + "_output_cast_" + str(i)
cast_node_name = f"onnxconverter_inserted_cast_{str(uuid.uuid4())}"
cast_input_name = f"{cast_node_name}_input"
add_new_value_info(graph, value_info, cast_input_name, onnx_proto.TensorProto.FLOAT)
value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
cast_node_name = node.name + "_output_cast" + str(i)
add_cast_node(graph, [cast_input_name], [output_name], cast_node_name, onnx_proto.TensorProto.FLOAT16)
node.output[i] = cast_input_name
break
Expand All @@ -274,7 +276,8 @@ def insert_cast16_after_node(graph: onnx_proto.GraphProto, node: onnx_proto.Node
def process_tensor_in_node(graph: onnx_proto.GraphProto, op_block_list: list, node_block_list: list, min_positive_val, max_finite_val):
value_info_block_list = set() # This is for later use, not in this step
for node in graph.node:
if (node.op_type in op_block_list) or (node.name in node_block_list):
# NOTE: "Cast" operation cannot change its output type because it is strongly typed.
if (node.op_type in op_block_list) or (node.name in node_block_list) or (node.op_type == "Cast"):
# Only need to block the output value_info changing
for output_name in node.output:
value_info_block_list.add(output_name)
Expand Down Expand Up @@ -519,10 +522,31 @@ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
if upstream_node.op_type == 'Constant':
cast_node_list.remove(cast_node)

# 4. find the cast(to16) node which downstream is Cast(to32)
# 4. find (cast_to_fp16, cast_to_fp32) pairs where --fp32--> cast_to_fp16 --fp16--> cast_to_fp32.
remove_candidate = []

name_to_value_info = {
value_info.name: value_info for value_info in itertools.chain(graph_proto.value_info, graph_proto.input)
}

def get_type(name: str) -> Optional[int]:
if name in name_to_value_info:
return name_to_value_info[name].type
else:
# `name` has no value info.
return None

for cast_node_name, downstream_node in cast_node_downstream_dict.items():
cast_node = name_to_node_dict[cast_node_name]
if len(cast_node.input) != 1:
raise RuntimeError(
f"Cast node {cast_node_name} should have only one input, but has {len(cast_node.input)}."
)

input_type = get_type(cast_node.input[0])
if input_type != onnx_proto.TensorProto.FLOAT:
continue

if isinstance(downstream_node, list):
for dn in downstream_node:
if dn.op_type == 'Cast' and \
Expand All @@ -539,7 +563,8 @@ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
cast_node in cast_node_list:
remove_candidate.append((cast_node, downstream_node))

# 5. change the connection of "upstream->cast16->cast32->downstream" to "upstream->downstream"
# 5. change "upstream --fp32--> cast_to_fp16 --fp16--> cast_to_fp32 --fp32--> downstream" to
# "upstream --fp32--> downstream".
for cast_node_pair in remove_candidate:
first_cast_node = cast_node_pair[0]
second_cast_node = cast_node_pair[1]
Expand Down