From 5437420e5163acc0dc329d27585a80f6dfe2a735 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Thu, 24 Aug 2023 11:23:33 +0200
Subject: [PATCH] Tests

---
 .../algorithms/channel_alignment/algorithm.py |  10 +-
 .../algorithms/channel_alignment/backend.py   |  11 --
 .../channel_alignment/openvino_backend.py     |  41 -----
 .../algorithms/smooth_quant/backend.py        |  11 --
 .../quantization/test_channel_alignment.py    |  56 +------
 tests/post_training/test_templates/models.py |   5 +-
 .../test_templates/test_channel_alignment.py  | 141 ++++++++++++++----
 7 files changed, 125 insertions(+), 150 deletions(-)

diff --git a/nncf/quantization/algorithms/channel_alignment/algorithm.py b/nncf/quantization/algorithms/channel_alignment/algorithm.py
index 086163d6fcc..cae43d8cbbd 100644
--- a/nncf/quantization/algorithms/channel_alignment/algorithm.py
+++ b/nncf/quantization/algorithms/channel_alignment/algorithm.py
@@ -24,13 +24,13 @@
 from nncf.common.graph.transformations.commands import TargetPoint
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.common.logging import nncf_logger
 from nncf.common.tensor_statistics.statistic_point import StatisticPoint
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.openvino.graph.model_utils import create_bias_constant_value
 from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_shape
-from nncf.openvino.graph.node_utils import get_weight_channel_axes
 from nncf.quantization.algorithms.algorithm import Algorithm
 from nncf.quantization.algorithms.channel_alignment.backend import ALGO_BACKENDS
 from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend
@@ -116,6 +116,10 @@ def filter_func(point: StatisticPoint) -> bool:
             assert len(tensor_collectors) == 1
             stat = tensor_collectors[0].get_statistics()
             if stat.min_values is None or stat.max_values is None:
+                nncf_logger.debug(
+                    f"Skipping channel alignment for pair {conv_in.node_name}, {conv_out.node_name} "
+                    "because statistics were not collected for this pair."
+                )
                 continue
 
             conv_in_cont = ConvParamsContainer(conv_in, model, graph, self._backend_entity)
@@ -124,6 +128,10 @@
                 conv_in_cont.dims.conv_weight_out_channels_dim is None
                 or conv_out_cont.dims.conv_weight_out_channels_dim is None
             ):
+                nncf_logger.debug(
+                    f"Skipping channel alignment for pair {conv_in.node_name}, {conv_out.node_name} "
+                    "because one of the nodes is a 1D MatMul; 1D MatMuls are not supported by the CA algorithm yet."
+                )
                 continue
 
             amean = (stat.max_values + stat.min_values) * 0.5
diff --git a/nncf/quantization/algorithms/channel_alignment/backend.py b/nncf/quantization/algorithms/channel_alignment/backend.py
index cf431604b7b..3c2721954c4 100644
--- a/nncf/quantization/algorithms/channel_alignment/backend.py
+++ b/nncf/quantization/algorithms/channel_alignment/backend.py
@@ -122,17 +122,6 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
         (bias is added to the output tensor of that operation), `False` otherwise.
         """
 
-    @staticmethod
-    @abstractmethod
-    def get_dims_descriptor(node: NNCFNode) -> LayoutDescriptor:
-        """
-        Return weights layout descriptor of the given node if it is possible and None otherwise.
-        Only convolutional and linear nodes are supported.
-
-        :param node: NNCFNode to get layout descriptor from.
-        :return: Weights layout descriptor of the given node if it is possible and None otherwise.
-        """
-
     @staticmethod
     @abstractmethod
     def get_conv_layer_attributes(node: NNCFNode) -> Optional[ConvolutionLayerAttributes]:
diff --git a/nncf/quantization/algorithms/channel_alignment/openvino_backend.py b/nncf/quantization/algorithms/channel_alignment/openvino_backend.py
index 3ac0c2e34ee..b6bbde472ea 100644
--- a/nncf/quantization/algorithms/channel_alignment/openvino_backend.py
+++ b/nncf/quantization/algorithms/channel_alignment/openvino_backend.py
@@ -99,44 +99,3 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
 
         bias_constant = get_node_with_bias_value(add_node, nncf_graph)
         return bias_constant is not None
-
-    @staticmethod
-    def get_dims_descriptor(node: NNCFNode):
-        if node.metatype == OVConvolutionMetatype:
-            return LayoutDescriptor(
-                conv_weight_out_channels_dim=0,
-                conv_weight_in_channels_dim=1,
-                bias_channels_dim=node.metatype.output_channel_axis,
-            )
-        if node.metatype in [OVGroupConvolutionMetatype, OVDepthwiseConvolutionMetatype]:
-            # Using groups dim as output channels dim for ChannelAlignment algorithm
-            # TODO(dlyakhov) support group convolutions with groups number not in [1, out_channels]
-            return LayoutDescriptor(
-                conv_weight_out_channels_dim=0,
-                conv_weight_in_channels_dim=2,
-                bias_channels_dim=node.metatype.output_channel_axis,
-            )
-        if node.metatype == OVMatMulMetatype:
-            if node.layer_attributes is None:
-                raise RuntimeError(f"Attempt to align matmul node {node.node_name} that have no any constant inputs")
-            layer_attributes: OVLayerAttributes = node.layer_attributes
-            key = layer_attributes.get_const_port_ids()
-            assert len(key) == 1
-            key = key[0]
-            const_attr = layer_attributes.constant_attributes[key]
-            a, b = list(range(len(const_attr["shape"])))[-2:]
-            assert key in [a, b]
-            if key == a:
-                out_ch_dim = a
-                in_ch_dim = b
-            else:
-                out_ch_dim = b
-                in_ch_dim = a
-            if const_attr.get("transpose", False):
-                out_ch_dim, in_ch_dim = in_ch_dim, out_ch_dim
-            return LayoutDescriptor(
-                conv_weight_in_channels_dim=in_ch_dim,
-                conv_weight_out_channels_dim=out_ch_dim,
-                bias_channels_dim=node.metatype.output_channel_axis,
-            )
-        raise RuntimeError(f"Could not retrieve dims description for node {node} with metatype {node.metatype}")
diff --git a/nncf/quantization/algorithms/smooth_quant/backend.py b/nncf/quantization/algorithms/smooth_quant/backend.py
index dda7fc44d2d..379a9aee508 100644
--- a/nncf/quantization/algorithms/smooth_quant/backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/backend.py
@@ -215,17 +215,6 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
         :return: Channel axis number.
         """
 
-    @staticmethod
-    @abstractmethod
-    def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int:
-        """
-        Returns axis number of the weight tensor which correspond to it channel.
-
-        :param node: NNCFNode instance.
-        :param port_id: Specified input port id.
-        :return: Channel axis number.
- """ - @staticmethod @abstractmethod def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int: diff --git a/tests/openvino/native/quantization/test_channel_alignment.py b/tests/openvino/native/quantization/test_channel_alignment.py index 432aa89a536..ff81d656a7d 100644 --- a/tests/openvino/native/quantization/test_channel_alignment.py +++ b/tests/openvino/native/quantization/test_channel_alignment.py @@ -11,37 +11,20 @@ from typing import Type -import pytest - -from nncf.common.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetType from nncf.openvino.graph.layer_attributes import OVLayerAttributes from nncf.openvino.graph.metatypes.openvino_metatypes import OVAddMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVConstantMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype -from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionMetatype -from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype from nncf.openvino.graph.transformations.command_creation import OVCommandCreator from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand from nncf.openvino.graph.transformations.commands import OVTargetPoint from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand -from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor from nncf.quantization.algorithms.channel_alignment.openvino_backend import OVChannelAlignmentAlgoBackend from tests.post_training.test_templates.test_channel_alignment import TemplateTestChannelAlignment -def _get_nncf_node(metatype, layer_attrs): - return NNCFNode( - { - NNCFNode.ID_NODE_ATTR: 0, - NNCFNode.NODE_NAME_ATTR: "test", - NNCFNode.METATYPE_ATTR: metatype, - NNCFNode.LAYER_ATTRIBUTES: layer_attrs, - } - ) - - class TestOVChannelAlignment(TemplateTestChannelAlignment): def get_backend_cls(self) -> Type[OVChannelAlignmentAlgoBackend]: return OVChannelAlignmentAlgoBackend @@ -50,7 +33,7 @@ def target_point(self, target_type: TargetType, target_node_name: str, port_id: return OVTargetPoint(target_type, target_node_name, port_id) def convert_conv_layer_attrs(self, layer_attributes): - return OVLayerAttributes({}, {1: layer_attributes}) + return OVLayerAttributes({}, layer_attributes) def get_conv_metatype(self): return OVConvolutionMetatype @@ -69,40 +52,3 @@ def get_transformation_commands(self): def mock_command_creation_factory(self, mocker) -> None: mocker.patch("nncf.common.factory.CommandCreatorFactory.create", return_value=OVCommandCreator) - - @pytest.mark.parametrize("transpose", [False, True]) - @pytest.mark.parametrize("shape", [[3, 4], [1, 2, 3, 4]]) - @pytest.mark.parametrize("port_id", [-1, -2]) - def test_get_dims_descriptor_matmul(self, transpose, shape, port_id): - _port_id = len(shape) + port_id - node = _get_nncf_node(OVMatMulMetatype, OVLayerAttributes({_port_id: {"transpose": transpose, "shape": shape}})) - dims_descr = OVChannelAlignmentAlgoBackend.get_dims_descriptor(node) - - in_dims, out_dims = (0, 1) if port_id == -1 else (1, 0) - if len(shape) > 2: - in_dims += 2 - out_dims += 2 - if transpose: - in_dims, out_dims = out_dims, in_dims - - assert dims_descr.conv_weight_in_channels_dim == in_dims - assert dims_descr.conv_weight_out_channels_dim == out_dims - assert dims_descr.bias_channels_dim == OVMatMulMetatype.output_channel_axis - - def 
test_get_dims_descriptor_mm_no_layer_attrs(self): - node = _get_nncf_node(OVMatMulMetatype, None) - with pytest.raises(RuntimeError): - OVChannelAlignmentAlgoBackend.get_dims_descriptor(node) - - @pytest.mark.parametrize( - "metatype,ref_desc", - [ - (OVConvolutionMetatype, LayoutDescriptor(0, 1, 1)), - (OVGroupConvolutionMetatype, LayoutDescriptor(0, 2, 1)), - (OVGroupConvolutionMetatype, LayoutDescriptor(0, 2, 1)), - ], - ) - def test_get_dims_descriptor_convs(self, metatype, ref_desc): - node = _get_nncf_node(metatype, None) - dims_descr = OVChannelAlignmentAlgoBackend.get_dims_descriptor(node) - assert dims_descr.__dict__ == ref_desc.__dict__ diff --git a/tests/post_training/test_templates/models.py b/tests/post_training/test_templates/models.py index 546a4104318..43246691048 100644 --- a/tests/post_training/test_templates/models.py +++ b/tests/post_training/test_templates/models.py @@ -171,6 +171,7 @@ def __init__( conv_metatype, add_metatype, conv_layer_attrs=None, + conv_layer_attrs_1=None, both_biases=True, add_layer_attrs=None, constant_metatype=ConstantTestMetatype, @@ -187,6 +188,8 @@ def __init__( # | # Add_2 # Output_1 + if conv_layer_attrs_1 is None: + conv_layer_attrs_1 = conv_layer_attrs nodes = [ NodeWithType("Input_1", InputNoopMetatype), NodeWithType("Conv_1_W", constant_metatype), @@ -194,7 +197,7 @@ def __init__( NodeWithType("Add_1_W", constant_metatype), NodeWithType("Add_1", add_metatype, layer_attributes=add_layer_attrs), NodeWithType("Conv_2_W", constant_metatype), - NodeWithType("Conv_2", conv_metatype, layer_attributes=conv_layer_attrs), + NodeWithType("Conv_2", conv_metatype, layer_attributes=conv_layer_attrs_1), NodeWithType("Output_1", OutputNoopMetatype), ] if both_biases: diff --git a/tests/post_training/test_templates/test_channel_alignment.py b/tests/post_training/test_templates/test_channel_alignment.py index d3b6dd045e5..9c2a24cd781 100644 --- a/tests/post_training/test_templates/test_channel_alignment.py +++ b/tests/post_training/test_templates/test_channel_alignment.py @@ -17,6 +17,8 @@ from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.graph.layer_attributes import LayoutElem +from nncf.common.graph.layer_attributes import LinearLayerAttributes from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationType @@ -27,6 +29,7 @@ from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment +from nncf.quantization.algorithms.channel_alignment.algorithm import ConvParamsContainer from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor from tests.post_training.test_templates.models import NNCFGraphCA @@ -46,9 +49,47 @@ groups=1, transpose=False, padding_values=(0, 0, 0, 0), + weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL), ) +DEPTHWISE_CONV_LAYER_ATTR = ConvolutionLayerAttributes( + weight_requires_grad=False, + in_channels=5, + out_channels=1, + kernel_size=(5, 5), + stride=(1, 1), + dilations=(1, 1), + groups=5, + transpose=False, + padding_values=(0, 0, 0, 0), + 
+    weights_layout=(LayoutElem.GROUPS, LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL),
+)
+
+MATMUL_LAYER_METATYPES = [
+    # 2D
+    LinearLayerAttributes(
+        weight_requires_grad=False,
+        in_features=5,
+        out_features=10,
+        with_bias=False,
+        weights_layout=[LayoutElem.C_IN, LayoutElem.C_OUT],
+    ),
+    # 1D
+    LinearLayerAttributes(
+        weight_requires_grad=False, in_features=5, out_features=None, with_bias=False, weights_layout=[LayoutElem.C_IN]
+    ),
+    # 5D
+    LinearLayerAttributes(
+        weight_requires_grad=False,
+        in_features=5,
+        out_features=None,
+        with_bias=False,
+        weights_layout=[LayoutElem.SPATIAL, LayoutElem.SPATIAL, LayoutElem.SPATIAL, LayoutElem.C_IN, LayoutElem.C_OUT],
+    ),
+]
+
+
 INVALID_CONSUMER_CONV_LAYER_ATTRS = [
     ConvolutionLayerAttributes(
         weight_requires_grad=False,
@@ -60,6 +101,7 @@
         groups=1,
         transpose=False,
         padding_values=(0, 0, 0, 0),
+        weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL),
     ),
     ConvolutionLayerAttributes(
         weight_requires_grad=False,
@@ -71,6 +113,7 @@
         groups=1,
         transpose=False,
         padding_values=(0, 0, 0, 0),
+        weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL),
     ),
     ConvolutionLayerAttributes(
         weight_requires_grad=False,
@@ -82,6 +125,7 @@
         groups=1,
         transpose=False,
         padding_values=(0, 0, 0, 0),
+        weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL),
     ),
     ConvolutionLayerAttributes(
         weight_requires_grad=False,
@@ -93,6 +137,7 @@
         groups=1,
         transpose=False,
         padding_values=(1, 0, 0, 0),
+        weights_layout=(LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL),
     ),
 ]
 
@@ -107,6 +152,7 @@
     groups=5,
     transpose=False,
     padding_values=(0, 0, 0, 0),
+    weights_layout=(LayoutElem.GROUPS, LayoutElem.C_OUT, LayoutElem.C_IN, LayoutElem.SPATIAL, LayoutElem.SPATIAL),
 )
 
 
@@ -232,9 +278,8 @@ def check_updated_values(updated_conv_in, updated_conv_out, updated_bias_in):
             (INVALID_CONV_LAYER_ATTR, INVALID_CONV_LAYER_ATTR, False),
         ]
     )
-    GET_NODES_TEST_CASES.extend(
-        [(VALID_CONV_LAYER_ATTR, None, False), (None, VALID_CONV_LAYER_ATTR, False), (None, None, False)]
-    )
+    GET_NODES_TEST_CASES.extend([(attr, VALID_CONV_LAYER_ATTR, True) for attr in MATMUL_LAYER_METATYPES])
+    GET_NODES_TEST_CASES.append((None, VALID_CONV_LAYER_ATTR, False))
 
     @pytest.mark.parametrize("first_conv_attrs,second_conv_attrs,ref_match", GET_NODES_TEST_CASES)
     def test_get_node_pairs(self, first_conv_attrs, second_conv_attrs, ref_match):
@@ -260,16 +305,21 @@ def test_get_node_pairs(self, first_conv_attrs, second_conv_attrs, ref_match):
         else:
             assert len(pairs) == 0
 
-    def _get_nncf_graph(self, num_biases: int) -> NNCFGraph:
-        cla = self.convert_conv_layer_attrs(VALID_CONV_LAYER_ATTR)
+    def _get_nncf_graph(
+        self, num_biases: int, conv_layer_attrs=DEPTHWISE_CONV_LAYER_ATTR, conv_layer_attrs_1=VALID_CONV_LAYER_ATTR
+    ) -> NNCFGraph:
+        cla = self.convert_conv_layer_attrs(conv_layer_attrs)
+        cla_1 = self.convert_conv_layer_attrs(conv_layer_attrs_1)
+
         if num_biases == 0:
-            return NNCFGraphCA(self.get_conv_metatype(), cla).nncf_graph
+            return NNCFGraphCA(self.get_conv_metatype(), cla, cla_1).nncf_graph
         bla = self.get_add_layer_attrs()
         if num_biases == 1:
             return NNCFGraphCAWithBias(
                 self.get_conv_metatype(),
                 self.get_add_metatype(),
                 cla,
+                cla_1,
                 both_biases=False,
                 constant_metatype=self.get_constant_metatype(),
                 add_layer_attrs=bla,
@@ -278,20 +328,37 @@ def _get_nncf_graph(self, num_biases: int) -> NNCFGraph:
             self.get_conv_metatype(),
             self.get_add_metatype(),
             cla,
+            cla_1,
             both_biases=True,
             add_layer_attrs=bla,
             constant_metatype=self.get_constant_metatype(),
         ).nncf_graph
 
+    @staticmethod
+    def _get_constant_lambda(value, counter=False):
+        if counter:
+            _state = 0
+
+        def f(*args, **kwargs):
+            if not counter:
+                return value
+            nonlocal _state
+            _state += 1
+            return value + str(_state)
+
+        return f
+
+    @pytest.mark.parametrize("one_dim_mm", [False, True])
     @pytest.mark.parametrize("empty_statistics", [False, True])
     @pytest.mark.parametrize("num_biases", [0, 1, 2])
     # pylint: disable=too-many-statements
     # pylint: disable=too-many-branches
-    def test_transformation_layout(self, empty_statistics, num_biases, mocker):
+    def test_transformation_layout(self, one_dim_mm, empty_statistics, num_biases, mocker):
         mocked_transformer = mocker.MagicMock()
         self.mock_model_transformer_factory(mocker, mocked_transformer)
 
-        nncf_graph = self._get_nncf_graph(num_biases)
+        first_conv_layer_attrs = DEPTHWISE_CONV_LAYER_ATTR if not one_dim_mm else MATMUL_LAYER_METATYPES[1]
+        nncf_graph = self._get_nncf_graph(num_biases, first_conv_layer_attrs)
         self.mock_nncf_graph_factory(mocker, nncf_graph)
 
         self.mock_command_creation_factory(mocker)
@@ -308,19 +375,6 @@ class TestTensorStats(MinMaxTensorStatistic):
             def tensor_eq(*args, **kwargs):
                 return True
 
-        def get_constant_lambda(value, counter=False):
-            if counter:
-                _state = 0
-
-            def f(*args, **kwargs):
-                if not counter:
-                    return value
-                nonlocal _state
-                _state += 1
-                return value + str(_state)
-
-            return f
-
         algorithm = ChannelAlignment()
         tensor_collector = TensorCollector()
         if empty_statistics:
@@ -328,18 +382,16 @@ def f(*args, **kwargs):
         else:
             stat_value = (np.array([-1], dtype=np.int32), np.array([2], dtype=np.int32))
 
-        tensor_collector.get_statistics = get_constant_lambda(TestTensorStats(*stat_value))
+        tensor_collector.get_statistics = self._get_constant_lambda(TestTensorStats(*stat_value))
         statistic_points.add_statistic_point(StatisticPoint(target_point, tensor_collector, algorithm._algorithm_key))
 
         class MockBackend(backend_cls):
             pass
 
         ref_weights_val = "ref_weights_val"
-        MockBackend.get_weight_value = get_constant_lambda(ref_weights_val, True)
+        MockBackend.get_weight_value = self._get_constant_lambda(ref_weights_val, True)
         ref_bias_val = "ref_bias_val"
-        MockBackend.get_bias_value = get_constant_lambda(ref_bias_val, True)
-        ref_dims_descr = "ref_dims_descr"
-        MockBackend.get_dims_descriptor = get_constant_lambda(ref_dims_descr, True)
+        MockBackend.get_bias_value = self._get_constant_lambda(ref_bias_val, True)
 
         algorithm._backend_entity = MockBackend
         algorithm._set_backend_entity = mocker.MagicMock()
@@ -358,7 +410,7 @@ class MockBackend(backend_cls):
         )
         algorithm.apply(None, nncf_graph, statistic_points)
 
-        if empty_statistics:
+        if empty_statistics or one_dim_mm:
             assert algorithm._align_means.call_count == 0
             assert algorithm._align_scales.call_count == 0
             mocked_transformer.transform.assert_called_once()
@@ -367,12 +419,15 @@ class MockBackend(backend_cls):
             return
 
         assert algorithm._align_means.call_count == 1
+
+        ref_dims = LayoutDescriptor(0, 2, 1)
+        ref_dims_1 = LayoutDescriptor(0, 1, 1)
         args = [
             np.zeros((1, 1, 1, 1)),
             np.zeros((1, 1, 1, 1)),
             ref_weights_val + "2",
             np.array(0.5, dtype=np.float32),
-            ref_dims_descr + "2",
+            ref_dims_1,
         ]
         for i in range(num_biases):
             args[i] = f"ref_bias_val{i + 1}"
@@ -385,8 +440,8 @@ class MockBackend(backend_cls):
         assert args[1] == ref_weights_val + "2"
         assert args[2] == ref_bias_in_after_align
         assert ((args[3] - 3) < EPS).all()
-        assert args[4] == ref_dims_descr + "1"
-        assert args[5] == ref_dims_descr + "2"
+        assert args[4] == ref_dims
+        assert args[5] == ref_dims_1
         assert args[6] < EPS
 
         mocked_transformer.transform.assert_called_once()
@@ -497,3 +552,29 @@ def test_statistic_collectors(self, inplace_ref, q_ref):
         assert isinstance(aggr, MedianAggregator)
         assert aggr.num_samples == num_samples_ref
         assert not aggr._use_per_sample_stats
+
+    @pytest.mark.parametrize(
+        "layer_attributes,ref_layout_desc",
+        [
+            (VALID_CONV_LAYER_ATTR, LayoutDescriptor(0, 1, 1)),
+            (DEPTHWISE_CONV_LAYER_ATTR, LayoutDescriptor(0, 2, 1)),
+            (MATMUL_LAYER_METATYPES[0], LayoutDescriptor(1, 0, 1)),
+            (MATMUL_LAYER_METATYPES[1], LayoutDescriptor(None, 0, 1)),
+            (MATMUL_LAYER_METATYPES[2], LayoutDescriptor(4, 3, 1)),
+        ],
+    )
+    def test_conv_params_dims(self, layer_attributes, ref_layout_desc):
+        backend_cls = self.get_backend_cls()
+
+        class MockBackend(backend_cls):
+            pass
+
+        ref_weights_val = "ref_weights_val"
+        MockBackend.get_weight_value = self._get_constant_lambda(ref_weights_val)
+        ref_bias_val = "ref_bias_val"
+        MockBackend.get_bias_value = self._get_constant_lambda(ref_bias_val)
+        nncf_graph = NNCFGraphCAWithBias(
+            self.get_conv_metatype(), self.get_add_metatype(), self.convert_conv_layer_attrs(layer_attributes)
+        ).nncf_graph
+        cont = ConvParamsContainer(nncf_graph.get_node_by_name("/Conv_1_0"), None, nncf_graph, MockBackend)
+        assert cont.dims == ref_layout_desc