diff --git a/nncf/openvino/graph/layer_attributes.py b/nncf/openvino/graph/layer_attributes.py index eb5425104a4..537cc0fc071 100644 --- a/nncf/openvino/graph/layer_attributes.py +++ b/nncf/openvino/graph/layer_attributes.py @@ -84,6 +84,12 @@ def get_backend_agnostic_attributes(self): def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout: + """ + Calculates weights layout for a target convolution node. + + :param node: Target convolution node. + :return: Target convolution Node weights layout. + """ layer_attributes = node.layer_attributes port_id = _get_port_id_from_layer_attributes(layer_attributes) return get_conv_weights_layout( @@ -92,6 +98,12 @@ def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout: def get_linear_weights_layout_from_node(node: NNCFNode) -> OVConvLayout: + """ + Calculates weights layout for a target linear node. + + :param node: Target linear node. + :return: Target linear Node weight layout. + """ layer_attributes = node.layer_attributes port_id = _get_port_id_from_layer_attributes(layer_attributes) constant_layer_attrs = layer_attributes.constant_attributes[port_id] @@ -102,13 +114,20 @@ def get_linear_weights_layout_from_node(node: NNCFNode) -> OVConvLayout: ) -def _get_port_id_from_layer_attributes(layer_attributes): +def _get_port_id_from_layer_attributes(layer_attributes) -> int: port_ids = list(layer_attributes.constant_attributes.keys()) assert len(port_ids) == 1 return port_ids[0] def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> OVConvLayout: + """ + Calculates weights layout for a target convolution node. + + :param ov_metatype: Target convolution node OpenVINO metatype. + :param weights_shape: Shape of the target convolution node weight. + :return: Target convolution node weights layout. + """ weights_layout = ov_metatype.const_layout kernel_size = weights_shape[len(weights_layout) :] weights_layout += [OVConvLayoutElem.SPATIAL] * len(kernel_size) @@ -116,6 +135,13 @@ def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> OVConvLayout: + """ + Calculates weights layout for a target linear node. + + :param weights_shape: Shape of the target linear node weight. + :param port_id: Port id of the target liner node weights. + :return: Target linear node weight layout. + """ weights_layout = [OVConvLayoutElem.SPATIAL] * (len(weights_shape) - 2) if len(weights_shape) > 1: if (transpose and port_id == 0) or (not transpose and port_id == 1): diff --git a/nncf/quantization/algorithms/channel_alignment/algorithm.py b/nncf/quantization/algorithms/channel_alignment/algorithm.py index 82ea653cf6c..1678c70a9d0 100644 --- a/nncf/quantization/algorithms/channel_alignment/algorithm.py +++ b/nncf/quantization/algorithms/channel_alignment/algorithm.py @@ -390,10 +390,10 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin channel_axis = conv_in.metatype.output_channel_axis activation_shape = list(range(len(graph.get_output_edges(node_in)[0].tensor_shape))) - reduction_shape = self._backend_entity.get_channel_agnostic_reduction_axes([channel_axis], activation_shape) + reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes([channel_axis], activation_shape) statistic_collector = self._backend_entity.get_statistic_collector( - tuple(reduction_shape), self._quantile, self.subset_size, self.inplace_statistics + reduction_axes, self._quantile, self.subset_size, self.inplace_statistics ) statistic_container.add_statistic_point( StatisticPoint( diff --git a/nncf/quantization/algorithms/channel_alignment/backend.py b/nncf/quantization/algorithms/channel_alignment/backend.py index 977788fb4b9..158bb4aa157 100644 --- a/nncf/quantization/algorithms/channel_alignment/backend.py +++ b/nncf/quantization/algorithms/channel_alignment/backend.py @@ -11,13 +11,12 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Optional, Tuple, TypeVar +from typing import Any, Tuple, TypeVar import numpy as np from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode -from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.commands import TargetType from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase