diff --git a/tests/openvino/native/test_smooth_quant.py b/tests/openvino/native/test_smooth_quant.py index 885506b6ce9..2b6f30974a5 100644 --- a/tests/openvino/native/test_smooth_quant.py +++ b/tests/openvino/native/test_smooth_quant.py @@ -25,6 +25,7 @@ from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend from tests.post_training.test_templates.helpers import ConvTestModel from tests.post_training.test_templates.helpers import LinearMultiShapeModel +from tests.post_training.test_templates.helpers import ShareWeghtsConvAndShareLinearModel from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm OV_LINEAR_MODEL_MM_OP_MAP = { @@ -80,6 +81,8 @@ def get_node_name_map(self, model_cls) -> Dict[str, str]: return OV_LINEAR_MODEL_MM_OP_MAP if model_cls is ConvTestModel: return OV_CONV_MODEL_MM_OP_MAP + if model_cls is ShareWeghtsConvAndShareLinearModel: + return {} raise NotImplementedError @staticmethod diff --git a/tests/post_training/test_templates/helpers.py b/tests/post_training/test_templates/helpers.py index cebfbec747a..97816fcc104 100644 --- a/tests/post_training/test_templates/helpers.py +++ b/tests/post_training/test_templates/helpers.py @@ -338,3 +338,21 @@ def forward(self, x): x = self.embedding(x) x = self.matmul(x) return x + + +class ShareWeghtsConvAndShareLinearModel(nn.Module): + INPUT_SIZE = [1, 1, 4, 4] + + def __init__(self): + super().__init__() + with set_torch_seed(): + self.conv = create_conv(1, 1, 1) + self.linear = nn.Linear(4, 4) + self.linear.weight.data = torch.randn((4, 4), dtype=torch.float32) + self.linear.bias.data = torch.randn((1, 4), dtype=torch.float32) + + def forward(self, x): + for _ in range(2): + x = self.conv(x) + x = self.linear(x) + return x diff --git a/tests/post_training/test_templates/test_smooth_quant.py b/tests/post_training/test_templates/test_smooth_quant.py index 436accb7975..6f93d57e3b4 100644 --- a/tests/post_training/test_templates/test_smooth_quant.py +++ b/tests/post_training/test_templates/test_smooth_quant.py @@ -27,9 +27,11 @@ from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend +from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend from tests.post_training.test_templates.helpers import ConvTestModel from tests.post_training.test_templates.helpers import LinearMultiShapeModel from tests.post_training.test_templates.helpers import NonZeroLinearModel +from tests.post_training.test_templates.helpers import ShareWeghtsConvAndShareLinearModel from tests.post_training.test_templates.helpers import get_static_dataset TModel = TypeVar("TModel") @@ -203,12 +205,16 @@ def test_get_abs_max_channel_collector(self, inplace_statistics: bool): ], ), (ConvTestModel, [("Conv1", 0)]), + (ShareWeghtsConvAndShareLinearModel, []), ), ) def test__get_nodes_to_smooth_data(self, model_cls, references, tmpdir): model = self.backend_specific_model(model_cls(), tmpdir) nncf_graph = NNCFGraphFactory.create(model) + if isinstance(self.get_backend(), OVSmoothQuantAlgoBackend) and model_cls is ShareWeghtsConvAndShareLinearModel: + pytest.xfail("Matmuls don't share one weight in OV ir for some reason") + algo = SmoothQuant() algo._set_backend_entity(model) alpha_map = algo._get_alpha_map() diff --git a/tests/torch/ptq/test_smooth_quant.py b/tests/torch/ptq/test_smooth_quant.py index 95ba9d3d19b..fa5d0599672 100644 --- a/tests/torch/ptq/test_smooth_quant.py +++ b/tests/torch/ptq/test_smooth_quant.py @@ -26,6 +26,7 @@ from nncf.torch.nncf_network import ExtraCompressionModuleType from tests.post_training.test_templates.helpers import ConvTestModel from tests.post_training.test_templates.helpers import LinearMultiShapeModel +from tests.post_training.test_templates.helpers import ShareWeghtsConvAndShareLinearModel from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm PT_LINEAR_MODEL_SQ_MAP = { @@ -60,6 +61,8 @@ def get_node_name_map(self, model_cls) -> Dict[str, str]: return PT_LINEAR_MODEL_MM_MAP if model_cls is ConvTestModel: return PT_CONV_MODEL_MM_MAP + if model_cls is ShareWeghtsConvAndShareLinearModel: + return {} raise NotImplementedError @staticmethod