Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into dl/conv_layer_attr…
Browse files Browse the repository at this point in the history
…s_update
  • Loading branch information
daniil-lyakhov committed Nov 15, 2023
2 parents b5b023e + 8efd04a commit 608b16e
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torchvision.models.detection.ssd import SSD
from torchvision.models.detection.ssd import GeneralizedRCNNTransform
from nncf.common.logging.track_progress import track
from functools import partial

ROOT = Path(__file__).parent.resolve()
DATASET_URL = "https://ultralytics.com/assets/coco128.zip"
Expand Down Expand Up @@ -125,10 +126,10 @@ def validate(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.devi
return computed_metrics["map_50"]


def transform_fn(data_item: Tuple[torch.Tensor, Dict]) -> torch.Tensor:
def transform_fn(data_item: Tuple[torch.Tensor, Dict], device: torch.device) -> torch.Tensor:
# Skip label and add a batch dimension to an image tensor
images, _ = data_item
return images[None]
return images[None].to(device)


def main():
Expand All @@ -149,7 +150,7 @@ def main():
disable_tracing(SSD.postprocess_detections)

# Quantize model
calibration_dataset = nncf.Dataset(dataset, transform_fn)
calibration_dataset = nncf.Dataset(dataset, partial(transform_fn, device=device))
quantized_model = nncf.quantize(model, calibration_dataset)

# Convert to OpenVINO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def do_compression(
if isinstance(reduction_axes, tuple) and len(reduction_axes) != 1:
nncf_logger.warning(
f"Weight compression expects a single reduction axes, but given {len(reduction_axes)}. "
f"Weight shape: {const_shape}, reduction axes: {reduction_axes}, node name: {nncf_node.name}. "
"The node won't be quantized."
f"Weight shape: {const_shape}, reduction axes: {reduction_axes}, "
f"node name: {nncf_node.node_name}. The node won't be quantized."
)
continue
reduction_axis = reduction_axes[0] if isinstance(reduction_axes, tuple) else reduction_axes
Expand Down
9 changes: 0 additions & 9 deletions nncf/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
from torch import nn

from nncf.common.engine import Engine
from nncf.torch.nested_objects_traversal import objwalk
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_tensor


class PTEngine(Engine):
Expand All @@ -34,7 +31,6 @@ def __init__(self, model: nn.Module):

self._model = model
self._model.eval()
self._device = get_model_device(model)

def infer(
self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]
Expand All @@ -46,11 +42,6 @@ def infer(
:return: Model outputs.
"""

def send_to_device(tensor):
return tensor.to(self._device)

input_data = objwalk(input_data, is_tensor, send_to_device)

if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
Expand Down
14 changes: 14 additions & 0 deletions tests/openvino/native/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,3 +786,17 @@ def _create_ov_model(self):
result.get_output_tensor(0).set_names(set(["Result"]))
model = ov.Model([result], [input_node])
return model


class GatherWithTwoReductionAxes(OVReferenceModel):
def _create_ov_model(self):
input_1 = opset.parameter([2, 3], name="Input")
convert_1 = opset.convert(input_1, destination_type="i64", name="Convert_1")

gather_2_data = opset.constant(self._rng.random((3, 2, 1)), dtype=np.float32, name="gather_2_data")
gather_2 = opset.gather(gather_2_data, convert_1, axis=0, batch_dims=0)
gather_2.set_friendly_name("Gather_2")

result = opset.result(gather_2, name="Result")
model = ov.Model([result], [input_1])
return model
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nncf.quantization.algorithms.weight_compression.openvino_backend import _reshape_weights_for_grouped_quantization
from nncf.scopes import IgnoredScope
from tests.openvino.native.common import get_openvino_version
from tests.openvino.native.models import GatherWithTwoReductionAxes
from tests.openvino.native.models import IntegerModel
from tests.openvino.native.models import SequentialMatmulModel
from tests.openvino.native.models import WeightsModel
Expand Down Expand Up @@ -202,6 +203,14 @@ def test_mixed_precision(ratio, group_size, ref_nf4_nodes):
assert op.get_element_type() == ov.Type.nf4


def test_not_quantize_with_multiple_reduction_axes():
model = GatherWithTwoReductionAxes().ov_model
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8)
for op in compressed_model.get_ordered_ops():
if op.get_type_name() == "Constant" and op.get_friendly_name() == "gather_2_data":
assert op.get_element_type() == ov.Type(np.float32)


@dataclass
class QuantErrorDesc:
weight: List[float]
Expand Down

0 comments on commit 608b16e

Please sign in to comment.