Skip to content

Commit

Permalink
SQ model with shared nodes used in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 24, 2024
1 parent ad5e134 commit 50b28a0
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/openvino/native/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend
from tests.post_training.test_templates.helpers import ConvTestModel
from tests.post_training.test_templates.helpers import LinearMultiShapeModel
from tests.post_training.test_templates.helpers import ShareWeghtsConvAndShareLinearModel
from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm

OV_LINEAR_MODEL_MM_OP_MAP = {
Expand Down Expand Up @@ -80,6 +81,8 @@ def get_node_name_map(self, model_cls) -> Dict[str, str]:
return OV_LINEAR_MODEL_MM_OP_MAP
if model_cls is ConvTestModel:
return OV_CONV_MODEL_MM_OP_MAP
if model_cls is ShareWeghtsConvAndShareLinearModel:
return {}
raise NotImplementedError

@staticmethod
Expand Down
18 changes: 18 additions & 0 deletions tests/post_training/test_templates/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,3 +338,21 @@ def forward(self, x):
x = self.embedding(x)
x = self.matmul(x)
return x


class ShareWeghtsConvAndShareLinearModel(nn.Module):
INPUT_SIZE = [1, 1, 4, 4]

def __init__(self):
super().__init__()
with set_torch_seed():
self.conv = create_conv(1, 1, 1)
self.linear = nn.Linear(4, 4)
self.linear.weight.data = torch.randn((4, 4), dtype=torch.float32)
self.linear.bias.data = torch.randn((1, 4), dtype=torch.float32)

def forward(self, x):
for _ in range(2):
x = self.conv(x)
x = self.linear(x)
return x
6 changes: 6 additions & 0 deletions tests/post_training/test_templates/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant
from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend
from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend
from tests.post_training.test_templates.helpers import ConvTestModel
from tests.post_training.test_templates.helpers import LinearMultiShapeModel
from tests.post_training.test_templates.helpers import NonZeroLinearModel
from tests.post_training.test_templates.helpers import ShareWeghtsConvAndShareLinearModel
from tests.post_training.test_templates.helpers import get_static_dataset

TModel = TypeVar("TModel")
Expand Down Expand Up @@ -203,12 +205,16 @@ def test_get_abs_max_channel_collector(self, inplace_statistics: bool):
],
),
(ConvTestModel, [("Conv1", 0)]),
(ShareWeghtsConvAndShareLinearModel, []),
),
)
def test__get_nodes_to_smooth_data(self, model_cls, references, tmpdir):
model = self.backend_specific_model(model_cls(), tmpdir)
nncf_graph = NNCFGraphFactory.create(model)

if isinstance(self.get_backend(), OVSmoothQuantAlgoBackend) and model_cls is ShareWeghtsConvAndShareLinearModel:
pytest.xfail("Matmuls don't share one weight in OV ir for some reason")

algo = SmoothQuant()
algo._set_backend_entity(model)
alpha_map = algo._get_alpha_map()
Expand Down
3 changes: 3 additions & 0 deletions tests/torch/ptq/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nncf.torch.nncf_network import ExtraCompressionModuleType
from tests.post_training.test_templates.helpers import ConvTestModel
from tests.post_training.test_templates.helpers import LinearMultiShapeModel
from tests.post_training.test_templates.helpers import ShareWeghtsConvAndShareLinearModel
from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm

PT_LINEAR_MODEL_SQ_MAP = {
Expand Down Expand Up @@ -60,6 +61,8 @@ def get_node_name_map(self, model_cls) -> Dict[str, str]:
return PT_LINEAR_MODEL_MM_MAP
if model_cls is ConvTestModel:
return PT_CONV_MODEL_MM_MAP
if model_cls is ShareWeghtsConvAndShareLinearModel:
return {}
raise NotImplementedError

@staticmethod
Expand Down

0 comments on commit 50b28a0

Please sign in to comment.