NNCFOVQuantizer and NNCFFXQuantizer #31

Open · wants to merge 5 commits into base: develop
268 changes: 268 additions & 0 deletions nncf/experimental/torch/fx/constant_folding.py
@@ -0,0 +1,268 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten


def _replace_node_with_constant(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    constant: torch.Tensor,
    name: Optional[str] = None,
) -> None:
    g = gm.graph

    if name:
        qualname = name
    else:
        if not hasattr(gm, "_frozen_param_count"):
            gm._frozen_param_count = 0  # type: ignore[assignment]
        i = gm._frozen_param_count

        while True:
            qualname = f"_frozen_param{i}"
            if not hasattr(gm, qualname):
                break
            i += 1

        gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)


def _is_const_source(node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]) -> bool:
    return node.op == "get_attr" or (
        node.op == "placeholder" and lifted_constants is not None and node.name in lifted_constants
    )


class _ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm: torch.fx.GraphModule,
        skip_constructors: bool = False,
        lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
        skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    ) -> None:
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()
        self.lifted_constants = lifted_constants

    def _support_dynamic_shape(self) -> bool:
        # ConstantFolder does not support dynamic shapes yet
        return False

    def _deduce_value(self, node: torch.fx.Node) -> Any:
        return super().run_node(node)

    def is_impure(self, node: torch.fx.node.Node) -> bool:
        def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
            return (
                node.target == torch.ops.prims.convert_element_type.default  # type: ignore[return-value]
                and isinstance(node.args[0], torch.fx.Node)
                and "val" in node.args[0].meta
                and node.args[0].meta["val"].dtype == torch.int8  # type: ignore[union-attr]
                and node.args[1] == torch.bfloat16
            )

        if (
            is_woq_int8_pattern(node)
            or (
                node.target == torch.ops.aten.permute.default
                and len(node.users) == 1
                and is_woq_int8_pattern(next(iter(node.users)))
            )
        ) and _is_const_source(
            node.args[0], self.lifted_constants  # type: ignore[arg-type]
        ):
            # Case 1: int8_weight -> dq -> bf16_weight
            # Case 2: int8_weight -> permute -> dq -> bf16_weight
            return True

        quant_registered = getattr(torch.ops.quantized_decomposed, "dequantize_per_channel", None) is not None
        if quant_registered and node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq we only fold
            # fp32_weight -> q into an int8_weight and leave dq in the graph
            # so that it can be fused later.
            return True
        return False

    def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp: torch.fx.Node) -> None:
                if inp in seen_uses:
                    return

                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node: torch.fx.Node) -> Any:
        if node.target == "output":
            # because we remove nodes from env on their last non-output use,
            # we need to re-define them here, or the interpreter will raise an error
            def set_env(arg: torch.fx.Node) -> None:
                self.env[arg] = self.unknown_value

            # In-place is fine since we don't mutate
            pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        # We need to do this weird thing because in cases where flattened_inputs
        # contains a ScriptObject, equality checking results in a type error if
        # the types are different.
        if any(
            type(self.unknown_value) is type(input_) and self.unknown_value == input_ for input_ in flattened_inputs
        ):
            return self.unknown_value

        # TODO - fix errors with this
        if node.op == "call_function" and node.target == aten._efficientzerotensor.default:
            return self.unknown_value

        # TODO - constant folding triton kernel returns the inputs -- fix this
        if node.op == "call_function" and node.name == "triton_kernel_wrapper_functional_proxy":
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning into tensor would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and not _is_const_source(node, self.lifted_constants)
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or on inputs which we did not make constant
        if isinstance(node.target, torch._ops.OpOverload) and torch.Tag.nondeterministic_seeded in node.target.tags:
            return self.unknown_value

        out = self._deduce_value(node)
        if out == self.unknown_value:
            return self.unknown_value

        if not _is_const_source(node, self.lifted_constants) and isinstance(out, torch.Tensor):
            if out.device.type == "meta":
                return out

            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue

                self.replaced_uses[n] += 1

            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self) -> Any:  # type: ignore[override]
        env: Dict[torch.fx.Node, Any] = {}
        self.insert_placeholder_values(env)
        return super().run(initial_env=env)

    def insert_placeholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
        for n in self.module.graph.find_nodes(op="placeholder"):
            if self.lifted_constants is not None and n.name in self.lifted_constants:
                env[n] = self.lifted_constants[n.name]
            else:
                env[n] = self.unknown_value  # type: ignore[assignment]


def constant_fold(
    gm: torch.fx.GraphModule,
) -> None:
    """
    Calculates the values of constant subgraphs and replaces them with constant nodes in place.

    :param gm: Given graph model.
    """
    with torch.utils._python_dispatch._disable_current_modes():
        cf = _ConstantFolder(gm, skip_constructors=True)
        cf.run()

        for node, constant in cf.node_replacements.items():
            _replace_node_with_constant(gm, node, constant)

        erased_params = []
        for node in gm.graph.find_nodes(op="get_attr"):
            if len(node.users) == 0:
                if hasattr(gm, node.target):
                    delattr(gm, node.target)
                erased_params.append(node)

        for node in erased_params:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()
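
For reference, a minimal sketch of how constant_fold could be exercised on a captured FX graph. The toy module M and the torch.export capture path below are illustrative assumptions, not part of this PR, and whether a particular node gets folded depends on the PyTorch version and capture mode:

import torch

from nncf.experimental.torch.fx.constant_folding import constant_fold


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        # `self.w.t() * 2.0` depends only on constants, so it is a foldable subgraph
        return x @ (self.w.t() * 2.0)


gm = torch.export.export(M(), (torch.randn(2, 4),)).module()
constant_fold(gm)
# After folding, the precomputed weight is expected to show up as a `_frozen_param*` get_attr node.
print(gm.graph)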
31 changes: 28 additions & 3 deletions nncf/experimental/torch/fx/model_transformer.py
@@ -126,11 +126,14 @@ def _apply_model_extraction(
         def remap_fn(node: torch.fx.Node):
             return value_remap.get(node)  # noqa F821
 
+        visited_outputs_names = []
         for node in model.graph.nodes:
-            if node.name not in visited or node.op == "output":
+            if node.name not in visited:
                 continue
+            if node.op == "output":
+                visited_outputs_names.append(node.name)
+                continue
             value_remap[node] = extracted_graph.node_copy(node, remap_fn)
-        del value_remap
 
         for input_name in transformation.input_node_names:
             node_with_input = get_graph_node_by_name(extracted_graph, input_name)
@@ -146,7 +149,29 @@ def remap_fn(node: torch.fx.Node):
             args[0] = graph_input
             node_with_input.args = tuple(args)
 
-        nodes_with_output = [get_graph_node_by_name(extracted_graph, name) for name in transformation.output_node_names]
+        # Merge the original model output into the new outputs in case
+        # the original output is requested in the extracted graph.
+        nodes_with_output = []
+        for name in transformation.output_node_names:
+            nodes_with_output.append(
+                name if name in visited_outputs_names else get_graph_node_by_name(extracted_graph, name)
+            )
+
+        for idx, node in enumerate(nodes_with_output):
+            if isinstance(node, torch.fx.Node):
+                continue
+            output_node = get_graph_node_by_name(model.graph, node)
+            args = output_node.args[0]
+            if isinstance(args, torch.fx.Node):
+                args = value_remap[args]
+            else:
+                args = [value_remap[n] for n in args]
+            # Unpack the target output args in case
+            # only one arg is present.
+            if len(args) == 1:
+                args = args[0]
+            nodes_with_output[idx] = args
+
         last_node = list(extracted_graph.nodes)[-1]
         with extracted_graph.inserting_after(last_node):
             graph_output_name = "output"
15 changes: 7 additions & 8 deletions nncf/experimental/torch/fx/node_utils.py
@@ -66,11 +66,10 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
     :return: True if the node has a bias, False otherwise.
     """
     # Assumes that all biases were unfused
-    if node.metatype in FX_OPERATORS_WITH_BIAS_METATYPES:
-        next_nodes = nncf_graph.get_next_nodes(node)
-        if len(next_nodes) != 1:
-            return False
-        return next_nodes[0].metatype in (om.PTAddMetatype,)
+    if node.metatype not in FX_OPERATORS_WITH_BIAS_METATYPES or len(nncf_graph.get_input_edges(node)) != 3:
+        return False
+    const_node = nncf_graph.get_input_edge_by_port_id(node, 2).from_node
+    return const_node.metatype is om.PTConstNoopMetatype


def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
@@ -82,7 +81,7 @@ def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphM
     :param model: Target GraphModule.
     :return: Bias value of the given node.
     """
-    bias_node = nncf_graph.get_next_nodes(node)[0]
+    bias_node = nncf_graph.get_input_edge_by_port_id(node, 2).from_node
     # TODO(dlyakhov): make a node_name_vs_node map to speed up the process
-    graph_bias_node = get_graph_node_by_name(model.graph, bias_node.node_name)
-    return Tensor(get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model))
+    graph_bias_const = get_graph_node_by_name(model.graph, bias_node.node_name)
+    return Tensor(get_tensor_constant_from_node(graph_bias_const, model))
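
The updated helpers assume that, for FX graphs captured via torch.export, an op with an unfused bias receives the bias constant directly on its input port 2. A small hypothetical sketch (not part of this PR) illustrating that layout for a linear layer; note that depending on the PyTorch version and capture mode the linear call may instead be decomposed (e.g. into addmm):

import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)  # has a bias, so the aten op gets 3 inputs

    def forward(self, x):
        return self.fc(x)


gm = torch.export.export(M(), (torch.randn(1, 4),)).module()
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.linear.default:
        # args are (activation, weight, bias): the bias constant sits on input port 2
        print([a.name for a in node.args])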