Skip to content

Commit

Permalink
[TorchFX] Bias fusing is removed from default transformations
Browse files Browse the repository at this point in the history
Code comments

Constant folding empty input test case
  • Loading branch information
daniil-lyakhov committed Nov 28, 2024
1 parent 7ea17f2 commit 13f125f
Show file tree
Hide file tree
Showing 29 changed files with 19,504 additions and 19,576 deletions.
17 changes: 16 additions & 1 deletion nncf/experimental/torch/fx/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,19 @@ def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
env[n] = self.unknown_value # type: ignore[assignment]


def _is_impure(node: torch.fx.Node) -> bool:
"""
Returns True if the node call affects the model outputs even in case
the node have zero users, False otherwise.
:param node: A node to check.
:return: True if the node call affects the model outputs even in case
the node have zero users, False otherwise.
"""
return node.op in {"placeholder", "output"}


def constant_fold(
gm: torch.fx.GraphModule,
constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
Expand Down Expand Up @@ -252,6 +265,8 @@ def constant_fold(
for node in erased_params:
gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()
# Custom _is_impure function allows to eliminate all layers with zero
# users including inplace ops like relu_ besides output and placeholders.
gm.graph.eliminate_dead_code(_is_impure)
gm.graph.lint()
gm.recompile()
35 changes: 32 additions & 3 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,14 @@ def _apply_model_extraction(
def remap_fn(node: torch.fx.Node):
return value_remap.get(node) # noqa F821

visited_outputs_names = []
for node in model.graph.nodes:
if node.name not in visited or node.op == "output":
if node.name not in visited:
continue
if node.op == "output":
visited_outputs_names.append(node.name)
continue
value_remap[node] = extracted_graph.node_copy(node, remap_fn)
del value_remap

for input_name in transformation.input_node_names:
node_with_input = get_graph_node_by_name(extracted_graph, input_name)
Expand All @@ -149,7 +152,33 @@ def remap_fn(node: torch.fx.Node):
args[0] = graph_input
node_with_input.args = tuple(args)

nodes_with_output = [get_graph_node_by_name(extracted_graph, name) for name in transformation.output_node_names]
# Merge new output with the original output in case
# the original output is requested in the extracted graph.
nodes_with_output = []
for name in transformation.output_node_names:
nodes_with_output.append(
name if name in visited_outputs_names else get_graph_node_by_name(extracted_graph, name)
)

for idx, node in enumerate(nodes_with_output):
if isinstance(node, torch.fx.Node):
continue
# Current node is the original graph output.
# Should be replaced by its arguments.
output_node = get_graph_node_by_name(model.graph, node)
args = output_node.args[0]
if isinstance(args, torch.fx.Node):
# Case of non tuple output.
args = value_remap[args]
else:
# Case of tuple output.
args = [value_remap[n] for n in args]
# Unpack target output args in case
# the only one arg is presented.
if len(args) == 1:
args = args[0]
nodes_with_output[idx] = args

last_node = list(extracted_graph.nodes)[-1]
with extracted_graph.inserting_after(last_node):
graph_output_name = "output"
Expand Down
15 changes: 7 additions & 8 deletions nncf/experimental/torch/fx/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
:return: True if the node has a bias, False otherwise.
"""
# Assumes that all biases were unfused
if node.metatype in FX_OPERATORS_WITH_BIAS_METATYPES:
next_nodes = nncf_graph.get_next_nodes(node)
if len(next_nodes) != 1:
return False
return next_nodes[0].metatype in (om.PTAddMetatype,)
if node.metatype not in FX_OPERATORS_WITH_BIAS_METATYPES or len(nncf_graph.get_input_edges(node)) != 3:
return False
const_node = nncf_graph.get_input_edge_by_port_id(node, 2).from_node
return const_node.metatype is om.PTConstNoopMetatype


def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
Expand All @@ -82,7 +81,7 @@ def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphM
:param model: Target GraphModule.
:return: Bias value of the given node.
"""
bias_node = nncf_graph.get_next_nodes(node)[0]
bias_node = nncf_graph.get_input_edge_by_port_id(node, 2).from_node
# TODO(dlyakhov): make a node_name_vs_node map to speed up the process
graph_bias_node = get_graph_node_by_name(model.graph, bias_node.node_name)
return Tensor(get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model))
graph_bias_const = get_graph_node_by_name(model.graph, bias_node.node_name)
return Tensor(get_tensor_constant_from_node(graph_bias_const, model))
8 changes: 1 addition & 7 deletions nncf/experimental/torch/fx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
from nncf.experimental.torch.fx.transformations import fq_weights_transformation
from nncf.experimental.torch.fx.transformations import revert_quantization_transformations
from nncf.parameters import BackupMode
from nncf.parameters import CompressWeightsMode
from nncf.parameters import ModelType
Expand Down Expand Up @@ -85,17 +84,12 @@ def quantize_impl(
advanced_parameters=advanced_parameters,
)

# To make it easier for bias correction algorithms,
# biases are being separated by the followng calls.
# To make it easier for bias correction algorithms.
apply_quantization_transformations(copied_model)

nncf_graph = NNCFGraphFactory.create(copied_model)
quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)

# Revert applied transformation to keep original model
# bias configuration.
revert_quantization_transformations(quantized_model)

if is_weight_compression_needed(advanced_parameters):
compress_post_quantize_transformation(quantized_model)
else:
Expand Down
228 changes: 0 additions & 228 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,33 +160,6 @@ def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor, input_port_id: int) -> TransformationFNType:
"""
Return transformation which updates constant of the given node with bias to the given value.
:param node: Node with bias which requires bias constant update.
:param value: New value to use as the bias constant.
:param input_port_id: Input port id to get constant node from.
:return: Transformation which updates constant of the given node with bias to the given value.
"""

def bias_update_transformation(model: torch.fx.GraphModule):
graph = model.graph
target_node_name = node.node_name
graph_node = get_graph_node_by_name(graph, target_node_name)
add_nodes = []
for user in graph_node.users:
if _is_add(user):
add_nodes.append(user)
if len(add_nodes) != 1:
raise nncf.InternalError(f"Node {graph_node.name} has {len(add_nodes)} outputs with adds, 1 expected")

bias_node = add_nodes[0]
constant_update_fn(model, bias_node, value, input_port_id=input_port_id)

return bias_update_transformation


def constant_update_transformation_builder(
node: NNCFNode, value: torch.Tensor, input_port_id: int = 1
) -> TransformationFNType:
Expand Down Expand Up @@ -794,8 +767,6 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
# are being fused
fold_constant_except_qdq(model)
fuse_conv_bn(model)
separate_conv_and_bias(model)
separate_linear_and_bias(model)


def fold_constant_except_qdq(model: torch.fx.GraphModule):
Expand All @@ -811,26 +782,6 @@ def constraint_fn(node: torch.fx.Node):
constant_fold(model, constraint_fn=constraint_fn)


def revert_quantization_transformations(model: torch.fx.GraphModule) -> None:
"""
Reverts quantization transformations from the model.
:param model: Model to revert transformations from.
"""
merge_conv_and_bias(model)
merge_linear_and_bias(model)


def _is_linear(n: torch.fx.Node) -> bool:
"""
Return whether the node refers to an aten linear op.
:param n: The given node.
:return: True if given node is a linear node, else False.
"""
return n.op == "call_function" and n.target in (torch.ops.aten.linear.default,)


def _is_conv(n: torch.fx.Node):
"""
Return whether the node refers to an aten conv op.
Expand All @@ -840,182 +791,3 @@ def _is_conv(n: torch.fx.Node):
torch.ops.aten.conv2d.default,
torch.ops.aten.conv_transpose2d.input,
)


def _is_add(n: torch.fx.Node):
"""
Return whether the node refers to an aten add op.
"""
return n.op == "call_function" and n.target in (
torch.ops.aten.add_.Tensor,
torch.ops.aten.add.Tensor,
)


def separate_linear_and_bias(model: torch.fx.GraphModule):
"""
Separates one joined linear+bias node to two nodes: conv and bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
add_node_target = torch.ops.aten.add.Tensor
for n in model.graph.nodes:
if not _is_linear(n):
continue
# This check also makes sure to ignore linear nodes which might already
# have quantization applied to the weights.
if len(n.args) < 3 or n.args[2] is None or n.args[1].op != "get_attr":
continue
linear_node = n
linear_bias_node = linear_node.args[2]
while linear_bias_node.op != "get_attr":
# Assume zero argument is on a path to the constant
linear_bias_node = linear_bias_node.args[0]
linear_bias_value = get_tensor_constant_from_node(linear_bias_node, model)
args = list(n.args)
args[2] = None
linear_node.args = tuple(args)
with model.graph.inserting_after(linear_node):
new_linear_bias_node = create_getattr_from_value(
model,
model.graph,
linear_bias_node.name + "_",
linear_bias_value,
)
with model.graph.inserting_after(new_linear_bias_node):
add_node = model.graph.create_node(
"call_function", add_node_target, (linear_node, new_linear_bias_node), {}
)
for user in list(linear_node.users):
if user is add_node:
continue
user.replace_input_with(linear_node, add_node)
if "val" in linear_node.meta:
add_node.meta["val"] = linear_node.meta["val"]
model.graph.eliminate_dead_code()
model.recompile()


def separate_conv_and_bias(model: torch.fx.GraphModule):
"""
Separates one joined conv+bias node to two nodes: conv and bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
add_node_target = torch.ops.aten.add_.Tensor
for n in model.graph.nodes:
if not _is_conv(n):
continue
# This check also makes sure to ignore convolution nodes which might
# already have quantization applied to the weights.
if len(n.args) < 3 or n.args[2] is None or n.args[1].op != "get_attr":
continue
conv_node = n
dims = len(get_tensor_constant_from_node(conv_node.args[1], model).shape)
conv_bias_node = conv_node.args[2]
conv_bias_value = get_tensor_constant_from_node(conv_bias_node, model)
args = list(n.args)
args[2] = None
conv_node.args = tuple(args)
with model.graph.inserting_after(conv_node):
new_conv_bias_node = create_getattr_from_value(
model, model.graph, conv_bias_node.name + "_", conv_bias_value.reshape((1, -1) + (1,) * (dims - 2))
)
with model.graph.inserting_after(new_conv_bias_node):
add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {})
for user in list(conv_node.users):
if user is add_node:
continue
user.replace_input_with(conv_node, add_node)

if "val" in conv_node.meta:
add_node.meta["val"] = conv_node.meta["val"]
model.graph.eliminate_dead_code()
model.recompile()


def merge_conv_and_bias(model: torch.fx.GraphModule):
"""
Merges two separate conv and bias nodes to a one node: conv+bias.
Needed as nncf does not expect joined conv
:param model: Target model.
"""
_merge_node_and_bias(model, _is_conv)


def merge_linear_and_bias(model: torch.fx.GraphModule):
"""
Merges two separate linear and bias nodes to a one node: linear+bias.
:param model: Target model.
"""
_merge_node_and_bias(model, _is_linear)


def _get_connected_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]:
"""
Returns the List of nodes which are directly or indirectly connected
to the output node.
:param graph: The torch FX graph to get nodes from.
"""
output_nodes = [node for node in graph.nodes if node.op == "output"]
assert len(output_nodes) == 1
output_node = output_nodes[0]
connected_nodes = set() # Every node is unique in the graph
nodes_to_visit = [output_node]
while nodes_to_visit:
current_node = nodes_to_visit.pop()
if current_node in connected_nodes:
continue
connected_nodes.add(current_node)
nodes_to_visit.extend(current_node.all_input_nodes)
return list(connected_nodes)


def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[torch.fx.Node], bool]):
"""
Merges two separate node and bias node to a one node: node+bias.
Check which node should be merged by the given `is_target_node` predicate.
:param model: Target model.
:param is_target_node: Predicate to specify nodes which should be merged with the bias
"""
add_node_targets = (torch.ops.aten.add.Tensor, torch.ops.aten.add_.Tensor)
for n in model.graph.nodes:
if not is_target_node(n):
continue
if len(n.args) > 2 and n.args[2] is not None:
continue
bias_node = next(iter(n.users))
if len(n.users) > 1 or bias_node.target not in add_node_targets:
continue
conv_node = n
const_node = None
for node in bias_node.all_input_nodes:
if node is not conv_node:
const_node = node
break
assert const_node is not None
bias_value = get_tensor_constant_from_node(const_node, model).squeeze()
with model.graph.inserting_before(conv_node):
new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
args = list(conv_node.args)
args[2] = new_bias_node
conv_node.args = tuple(args)
for user in list(bias_node.users):
user.replace_input_with(bias_node, conv_node)

# Remove nodes which are not connected to output. This removes dead nodes and dead subgraphs in the model graph.
nodes_connected_to_output = _get_connected_nodes(model.graph)
is_impure = lambda node: node in nodes_connected_to_output

for node in reversed(model.graph.nodes):
if not is_impure(node) and len(node.users) == 0:
model.graph.erase_node(node)

model.graph.eliminate_dead_code()
model.recompile()
Loading

0 comments on commit 13f125f

Please sign in to comment.