Skip to content

Commit

Permalink
Apply comments
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Oct 23, 2023
1 parent eedee0a commit 37a4489
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
28 changes: 27 additions & 1 deletion nncf/openvino/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def get_backend_agnostic_attributes(self):


def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
"""
Calculates weights layout for a target convolution node.
:param node: Target convolution node.
:return: Target convolution Node weights layout.
"""
layer_attributes = node.layer_attributes
port_id = _get_port_id_from_layer_attributes(layer_attributes)
return get_conv_weights_layout(
Expand All @@ -92,6 +98,12 @@ def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:


def get_linear_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
"""
Calculates weights layout for a target linear node.
:param node: Target linear node.
:return: Target linear Node weight layout.
"""
layer_attributes = node.layer_attributes
port_id = _get_port_id_from_layer_attributes(layer_attributes)
constant_layer_attrs = layer_attributes.constant_attributes[port_id]
Expand All @@ -102,20 +114,34 @@ def get_linear_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
)


def _get_port_id_from_layer_attributes(layer_attributes):
def _get_port_id_from_layer_attributes(layer_attributes) -> int:
port_ids = list(layer_attributes.constant_attributes.keys())
assert len(port_ids) == 1
return port_ids[0]


def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> OVConvLayout:
"""
Calculates weights layout for a target convolution node.
:param ov_metatype: Target convolution node OpenVINO metatype.
:param weights_shape: Shape of the target convolution node weight.
:return: Target convolution node weights layout.
"""
weights_layout = ov_metatype.const_layout
kernel_size = weights_shape[len(weights_layout) :]
weights_layout += [OVConvLayoutElem.SPATIAL] * len(kernel_size)
return tuple(weights_layout)


def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> OVConvLayout:
"""
Calculates weights layout for a target linear node.
:param weights_shape: Shape of the target linear node weight.
:param port_id: Port id of the target liner node weights.
:return: Target linear node weight layout.
"""
weights_layout = [OVConvLayoutElem.SPATIAL] * (len(weights_shape) - 2)
if len(weights_shape) > 1:
if (transpose and port_id == 0) or (not transpose and port_id == 1):
Expand Down
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,10 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin

channel_axis = conv_in.metatype.output_channel_axis
activation_shape = list(range(len(graph.get_output_edges(node_in)[0].tensor_shape)))
reduction_shape = self._backend_entity.get_channel_agnostic_reduction_axes([channel_axis], activation_shape)
reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes([channel_axis], activation_shape)

statistic_collector = self._backend_entity.get_statistic_collector(
tuple(reduction_shape), self._quantile, self.subset_size, self.inplace_statistics
reduction_axes, self._quantile, self.subset_size, self.inplace_statistics
)
statistic_container.add_statistic_point(
StatisticPoint(
Expand Down
3 changes: 1 addition & 2 deletions nncf/quantization/algorithms/channel_alignment/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@

from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Optional, Tuple, TypeVar
from typing import Any, Tuple, TypeVar

import numpy as np

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
Expand Down

0 comments on commit 37a4489

Please sign in to comment.