diff --git a/examples/post_training_quantization/torch/ssd300_vgg16/main.py b/examples/post_training_quantization/torch/ssd300_vgg16/main.py
index 6d6b9365a34..3bed9cfee45 100644
--- a/examples/post_training_quantization/torch/ssd300_vgg16/main.py
+++ b/examples/post_training_quantization/torch/ssd300_vgg16/main.py
@@ -29,6 +29,7 @@
 from torchvision.models.detection.ssd import SSD
 from torchvision.models.detection.ssd import GeneralizedRCNNTransform
 from nncf.common.logging.track_progress import track
+from functools import partial
 
 ROOT = Path(__file__).parent.resolve()
 DATASET_URL = "https://ultralytics.com/assets/coco128.zip"
@@ -125,10 +126,10 @@ def validate(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.devi
     return computed_metrics["map_50"]
 
 
-def transform_fn(data_item: Tuple[torch.Tensor, Dict]) -> torch.Tensor:
+def transform_fn(data_item: Tuple[torch.Tensor, Dict], device: torch.device) -> torch.Tensor:
     # Skip label and add a batch dimension to an image tensor
     images, _ = data_item
-    return images[None]
+    return images[None].to(device)
 
 
 def main():
@@ -149,7 +150,7 @@ def main():
     disable_tracing(SSD.postprocess_detections)
 
     # Quantize model
-    calibration_dataset = nncf.Dataset(dataset, transform_fn)
+    calibration_dataset = nncf.Dataset(dataset, partial(transform_fn, device=device))
     quantized_model = nncf.quantize(model, calibration_dataset)
 
     # Convert to OpenVINO
diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
index a55ff9c861d..0a0d4f09fa6 100644
--- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py
@@ -77,8 +77,8 @@ def do_compression(
         if isinstance(reduction_axes, tuple) and len(reduction_axes) != 1:
             nncf_logger.warning(
                 f"Weight compression expects a single reduction axes, but given {len(reduction_axes)}. "
-                f"Weight shape: {const_shape}, reduction axes: {reduction_axes}, node name: {nncf_node.name}. "
-                "The node won't be quantized."
+                f"Weight shape: {const_shape}, reduction axes: {reduction_axes}, "
+                f"node name: {nncf_node.node_name}. The node won't be quantized."
             )
             continue
         reduction_axis = reduction_axes[0] if isinstance(reduction_axes, tuple) else reduction_axes
diff --git a/nncf/torch/engine.py b/nncf/torch/engine.py
index 63b4e93f114..44271123d6b 100644
--- a/nncf/torch/engine.py
+++ b/nncf/torch/engine.py
@@ -15,9 +15,6 @@
 from torch import nn
 
 from nncf.common.engine import Engine
-from nncf.torch.nested_objects_traversal import objwalk
-from nncf.torch.utils import get_model_device
-from nncf.torch.utils import is_tensor
 
 
 class PTEngine(Engine):
@@ -34,7 +31,6 @@ def __init__(self, model: nn.Module):
 
         self._model = model
         self._model.eval()
-        self._device = get_model_device(model)
 
     def infer(
         self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]
@@ -46,11 +42,6 @@ def infer(
         :return: Model outputs.
         """
 
-        def send_to_device(tensor):
-            return tensor.to(self._device)
-
-        input_data = objwalk(input_data, is_tensor, send_to_device)
-
         if isinstance(input_data, dict):
             return self._model(**input_data)
         if isinstance(input_data, tuple):
diff --git a/tests/openvino/native/models.py b/tests/openvino/native/models.py
index ca563218df3..860efbdba69 100644
--- a/tests/openvino/native/models.py
+++ b/tests/openvino/native/models.py
@@ -786,3 +786,17 @@ def _create_ov_model(self):
         result.get_output_tensor(0).set_names(set(["Result"]))
         model = ov.Model([result], [input_node])
         return model
+
+
+class GatherWithTwoReductionAxes(OVReferenceModel):
+    def _create_ov_model(self):
+        input_1 = opset.parameter([2, 3], name="Input")
+        convert_1 = opset.convert(input_1, destination_type="i64", name="Convert_1")
+
+        gather_2_data = opset.constant(self._rng.random((3, 2, 1)), dtype=np.float32, name="gather_2_data")
+        gather_2 = opset.gather(gather_2_data, convert_1, axis=0, batch_dims=0)
+        gather_2.set_friendly_name("Gather_2")
+
+        result = opset.result(gather_2, name="Result")
+        model = ov.Model([result], [input_1])
+        return model
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index 561d1927509..b83cf386462 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -25,6 +25,7 @@
 from nncf.quantization.algorithms.weight_compression.openvino_backend import _reshape_weights_for_grouped_quantization
 from nncf.scopes import IgnoredScope
 from tests.openvino.native.common import get_openvino_version
+from tests.openvino.native.models import GatherWithTwoReductionAxes
 from tests.openvino.native.models import IntegerModel
 from tests.openvino.native.models import SequentialMatmulModel
 from tests.openvino.native.models import WeightsModel
@@ -202,6 +203,14 @@ def test_mixed_precision(ratio, group_size, ref_nf4_nodes):
         assert op.get_element_type() == ov.Type.nf4
 
 
+def test_not_quantize_with_multiple_reduction_axes():
+    model = GatherWithTwoReductionAxes().ov_model
+    compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8)
+    for op in compressed_model.get_ordered_ops():
+        if op.get_type_name() == "Constant" and op.get_friendly_name() == "gather_2_data":
+            assert op.get_element_type() == ov.Type(np.float32)
+
+
 @dataclass
 class QuantErrorDesc:
     weight: List[float]