diff --git a/nncf/openvino/graph/node_utils.py b/nncf/openvino/graph/node_utils.py index 0a97c826376..de7736ffe59 100644 --- a/nncf/openvino/graph/node_utils.py +++ b/nncf/openvino/graph/node_utils.py @@ -17,6 +17,8 @@ from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.layer_attributes import LayoutElem +from nncf.common.graph.layer_attributes import LinearLayerAttributes from nncf.openvino.graph.layer_attributes import OVLayerAttributes from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS @@ -310,7 +312,7 @@ def get_reducer_output_node_names( return [get_result_node_name(target_node_name, port_id)] -def get_weight_channel_axes(node: NNCFNode, weights_port_id: int) -> List[int]: +def get_weight_channel_axes(node: NNCFNode) -> List[int]: """ Returns axes numbers of the weight tensor which correspond to its channels. @@ -321,35 +323,27 @@ def get_weight_channel_axes(node: NNCFNode, weights_port_id: int) -> List[int]: if node.metatype not in OPERATIONS_WITH_WEIGHTS: raise ValueError("Channel axis cannot be defined for operation without weights.") - channel_axes = node.metatype.const_channel_axis - if node.metatype == OVMatMulMetatype: - assert isinstance(node.layer_attributes, OVLayerAttributes) - assert len(channel_axes) == 1 - const_attrs = node.layer_attributes.constant_attributes[weights_port_id] - transpose = const_attrs["transpose"] - ndims = len(const_attrs["shape"]) - channel_axes = get_matmul_channel_axes(weights_port_id, ndims, transpose) + if node.metatype != OVMatMulMetatype: + return node.metatype.const_channel_axis - return channel_axes + return get_matmul_channel_axes(node) -def get_matmul_channel_axes(weights_port_id: int, ndims: int, transpose: bool) -> List[int]: +def get_matmul_channel_axes(node: ov.Node) -> List[int]: """ Calculate channel axes for the MatMul operation. - :param weights_port_id: Weight port id of the target node. - :param ndims: The number of MatMul dimensions. - :param transpose: Whether the transpose is applied to weights. + :param node: The target node. :return: List of channel axes for the MatMul operation. """ - matmul_channel_axis = OVMatMulMetatype.const_channel_axis[0] - if (weights_port_id == 1) == transpose: - matmul_channel_axis -= 1 - matmul_channel_axis = max(ndims, 2) + matmul_channel_axis - channel_axes = list(range(ndims - 2)) - if matmul_channel_axis < ndims: - channel_axes.append(matmul_channel_axis) - return channel_axes + assert isinstance(node.layer_attributes, OVLayerAttributes) + layer_attributes = node.layer_attributes.get_backend_agnostic_attributes() + assert isinstance(layer_attributes, LinearLayerAttributes) + return [ + idx + for idx, elem in enumerate(layer_attributes.weights_layout) + if elem in [LayoutElem.SPATIAL, LayoutElem.C_OUT] + ] def get_channel_agnostic_reduction_shape(channel_axes: List[int], shape: List[int]) -> Tuple[int]: diff --git a/nncf/openvino/quantization/weights_compression.py b/nncf/openvino/quantization/weights_compression.py index 8b23834e36d..ee1792e97f7 100644 --- a/nncf/openvino/quantization/weights_compression.py +++ b/nncf/openvino/quantization/weights_compression.py @@ -15,13 +15,14 @@ import openvino.runtime as ov from openvino.runtime import opset9 as opset +from nncf.common.graph.layer_attributes import LayoutElem +from nncf.common.graph.layer_attributes import LinearLayerAttributes from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVEmbeddingMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype from nncf.openvino.graph.metatypes.openvino_metatypes import get_operation_const_op from nncf.openvino.graph.node_utils import get_const_value -from nncf.openvino.graph.node_utils import get_matmul_channel_axes from nncf.quantization.fake_quantize import calculate_scale_zero_point @@ -88,10 +89,9 @@ def _get_reduction_axes(metatype: Type[OperatorMetatype], node: ov.Node, weight_ :return: The reduction axes as an integer or a tuple of integers. """ if metatype is OVMatMulMetatype: - transpose = node.get_attributes()[f"transpose_{'a' if weight_port_id == 0 else 'b'}"] - ndims = node.input(weight_port_id).get_partial_shape().rank.get_max_length() - channel_axes = get_matmul_channel_axes(weight_port_id, ndims, transpose) - axes = tuple(i for i in range(ndims) if i not in channel_axes) + layer_attributes = node.layer_attributes.get_backend_agnostic_attributes() + assert isinstance(layer_attributes, LinearLayerAttributes) + axes = tuple(idx for idx, elem in enumerate(layer_attributes.weights_layout) if elem == LayoutElem.C_IN) elif metatype is OVEmbeddingMetatype: axes = (metatype.const_channel_axis[0] + 1) % 2 else: diff --git a/nncf/quantization/algorithms/channel_alignment/openvino_backend.py b/nncf/quantization/algorithms/channel_alignment/openvino_backend.py index 00091c89eb1..a20e198e78a 100644 --- a/nncf/quantization/algorithms/channel_alignment/openvino_backend.py +++ b/nncf/quantization/algorithms/channel_alignment/openvino_backend.py @@ -9,20 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple +from typing import Any, Tuple import numpy as np import openvino.runtime as ov from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode -from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes from nncf.common.graph.transformations.commands import TargetType from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.common.utils.backend import BackendType from nncf.experimental.common.tensor_statistics.collectors import MedianAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.openvino.graph.layer_attributes import OVLayerAttributes from nncf.openvino.graph.metatypes.openvino_metatypes import OVAddMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVDepthwiseConvolutionMetatype @@ -39,7 +37,6 @@ from nncf.openvino.statistics.statistics import OVMinMaxTensorStatistic from nncf.quantization.algorithms.channel_alignment.backend import ALGO_BACKENDS from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend -from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor @ALGO_BACKENDS.register(BackendType.OPENVINO) @@ -102,5 +99,5 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return bias_constant is not None @staticmethod - def create_bias_tensor(node: NNCFNode, nncf_graph: NNCFGraph, value: Any)-> np.ndarray: + def create_bias_tensor(node: NNCFNode, nncf_graph: NNCFGraph, value: Any) -> np.ndarray: return create_bias_tensor(node, nncf_graph, value) diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 42d89b76051..8408c10f3ad 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -147,7 +147,7 @@ def _get_reduction_shape_and_use_abs_max( const_shape = node.layer_attributes.constant_attributes[target_point.port_id]["shape"] if quantizer_config.per_channel: - channel_axes = get_weight_channel_axes(node, target_point.port_id) + channel_axes = get_weight_channel_axes(node) axes = get_channel_agnostic_reduction_shape(channel_axes, const_shape) else: axes = tuple(range(len(const_shape))) diff --git a/tests/openvino/native/test_node_utils.py b/tests/openvino/native/test_node_utils.py index 4a2a24872d4..0f9bfbe54be 100644 --- a/tests/openvino/native/test_node_utils.py +++ b/tests/openvino/native/test_node_utils.py @@ -11,10 +11,12 @@ import numpy as np import pytest +from openvino.runtime import opset9 as opset from nncf.common.factory import NNCFGraphFactory from nncf.common.graph.graph import NNCFNode from nncf.openvino.graph.layer_attributes import OVLayerAttributes +from nncf.openvino.graph.layer_attributes import get_weighted_layer_attributes from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_shape @@ -60,10 +62,10 @@ def test_is_node_with_bias(model_to_create, is_with_bias, node_name): @pytest.mark.parametrize( "weights_port_id, transpose, shape, expected_channel_axes", [ - (0, False, (1,), [0]), + (0, False, (1,), []), (0, True, (1,), []), (1, False, (1,), []), - (1, True, (1,), [0]), + (1, True, (1,), []), (0, False, (1, 1), [0]), (0, True, (1, 1), [1]), (1, False, (1, 1), [1]), @@ -75,16 +77,23 @@ def test_is_node_with_bias(model_to_create, is_with_bias, node_name): ], ) def test_get_weight_channel_axes_for_matmul(weights_port_id, transpose, shape, expected_channel_axes): + input_1 = opset.parameter([1, 1], name="Input", dtype=np.float32) + constant_1 = opset.constant(np.ones(shape).astype(np.float32)) + inputs_ = (input_1, constant_1) if weights_port_id == 1 else (constant_1, input_1) + matmul_1 = opset.matmul(*inputs_, transpose_a=transpose, transpose_b=transpose, name="MatMul") + + constant_attrs = {weights_port_id: {"transpose": transpose, "shape": shape}} attributes = { NNCFNode.ID_NODE_ATTR: 0, NNCFNode.NODE_NAME_ATTR: "test", NNCFNode.METATYPE_ATTR: OVMatMulMetatype, NNCFNode.LAYER_ATTRIBUTES: OVLayerAttributes( - constant_attributes={weights_port_id: {"transpose": transpose, "shape": shape}} + layer_attributes=get_weighted_layer_attributes(matmul_1, OVMatMulMetatype, constant_attrs), + constant_attributes=constant_attrs, ), } node = NNCFNode(attributes) - actual_channel_axes = get_weight_channel_axes(node, weights_port_id) + actual_channel_axes = get_weight_channel_axes(node) assert len(actual_channel_axes) == len(expected_channel_axes) assert all(a == b for a, b in zip(actual_channel_axes, expected_channel_axes))