Skip to content

Commit

Permalink
Torch cat layer attributes update
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 15, 2023
1 parent 13ca0a6 commit 6c4d598
Show file tree
Hide file tree
Showing 4 changed files with 384 additions and 334 deletions.
30 changes: 13 additions & 17 deletions nncf/torch/dynamic_graph/layer_attributes_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@
from nncf.common.graph.layer_attributes import PermuteLayerAttributes
from nncf.common.graph.layer_attributes import ReshapeLayerAttributes
from nncf.common.graph.layer_attributes import TransposeLayerAttributes
from nncf.common.graph.utils import get_concat_axis
from nncf.common.graph.utils import get_split_axis
from nncf.torch.graph.operator_metatypes import PTCatMetatype
from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype
from nncf.torch.graph.operator_metatypes import PTPadMetatype
from nncf.torch.graph.operator_metatypes import PTReshapeMetatype
Expand All @@ -49,8 +47,9 @@
PERMUTE_OP_NAMES = ["permute"]
GETITEM_OP_NAMES = ["__getitem__"]
PAD_OP_NAMES = PTPadMetatype.get_all_aliases()
CONCAT_OP_NAMES = ["cat"]
OP_NAMES_REQUIRING_ATTRS_FROM_ARGS_KWARGS = list(
TRANSPOSE_OP_NAMES + PERMUTE_OP_NAMES + GETITEM_OP_NAMES + PAD_OP_NAMES
TRANSPOSE_OP_NAMES + PERMUTE_OP_NAMES + GETITEM_OP_NAMES + PAD_OP_NAMES + CONCAT_OP_NAMES
)


Expand Down Expand Up @@ -119,25 +118,13 @@ def get_layer_attributes_from_args_and_kwargs(op_name: str, args, kwargs) -> Bas
layer_attrs = _get_getitem_attrs_from_args_kwargs(args, kwargs)
elif op_name in PAD_OP_NAMES:
layer_attrs = _get_pad_attrs_from_args_kwargs(args, kwargs)
elif op_name in CONCAT_OP_NAMES:
layer_attrs = _get_concat_attrs_from_args_kwargs(args, kwargs)
return layer_attrs


def set_nodes_attributes_in_nncf_graph(graph: NNCFGraph) -> None:
for node in graph.get_all_nodes():
if node.metatype is PTCatMetatype:
input_edges = graph.get_input_edges(node)
output_edges = graph.get_output_edges(node)
# Case of intermediate node
if input_edges and output_edges:
input_shapes = [edge.tensor_shape for edge in input_edges]
output_shapes = [edge.tensor_shape for edge in output_edges]
# Case node is stack
if len(input_shapes[0]) != len(output_shapes[0]):
continue
axis = get_concat_axis(input_shapes, output_shapes)
layer_attributes = MultipleInputLayerAttributes(axis)
node.layer_attributes = layer_attributes

if node.metatype in [PTReshapeMetatype, PTSqueezeMetatype]:
input_nodes = graph.get_input_edges(node)
output_nodes = graph.get_output_edges(node)
Expand Down Expand Up @@ -178,6 +165,15 @@ def _get_pad_attrs_from_args_kwargs(args, kwargs) -> PadLayerAttributes:
return PadLayerAttributes(mode, value)


def _get_concat_attrs_from_args_kwargs(args, kwargs) -> PadLayerAttributes:
if "tensors" in kwargs:
tensors = kwargs["tensors"]
else:
tensors = args[0]
axis = kwargs.get("dim", 0 if len(args) < 2 else args[1])
return MultipleInputLayerAttributes(axis=axis, num_inputs=len(tensors))


def _get_kwargs_shifted(args_names, args, kwargs, shift=1):
res_kwargs = {}
for idx, arg_name in enumerate(args_names):
Expand Down
Loading

0 comments on commit 6c4d598

Please sign in to comment.