Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into dl/torch_experimen…
Browse files Browse the repository at this point in the history
…tal_statistics
  • Loading branch information
daniil-lyakhov committed Sep 15, 2023
2 parents f58a1e0 + 8472e67 commit 5b68b4b
Show file tree
Hide file tree
Showing 62 changed files with 1,474 additions and 508 deletions.
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ ifdef DATA
DATA_ARG := --data $(DATA)
endif

ifdef WEEKLY_MODELS
WEEKLY_MODELS_ARG := --weekly-models $(WEEKLY_MODELS)
endif

install-pre-commit:
pip install pre-commit==3.2.2

Expand Down Expand Up @@ -124,7 +128,13 @@ install-torch-dev: install-torch-test install-pre-commit install-pylint
pip install -r examples/post_training_quantization/torch/ssd300_vgg16/requirements.txt

test-torch:
pytest ${COVERAGE_ARGS} tests/common tests/torch --junitxml ${JUNITXML_PATH} $(DATA_ARG)
pytest ${COVERAGE_ARGS} tests/common tests/torch -m "not weekly and not nightly" --junitxml ${JUNITXML_PATH} $(DATA_ARG)

test-torch-nightly:
pytest ${COVERAGE_ARGS} tests/torch -m nightly --junitxml ${JUNITXML_PATH} $(DATA_ARG)

test-torch-weekly:
pytest ${COVERAGE_ARGS} tests/torch -m weekly --junitxml ${JUNITXML_PATH} $(DATA_ARG) ${WEEKLY_MODELS_ARG}

COMMON_PYFILES := $(shell python3 tools/collect_pylint_input_files_for_backend.py common)
pylint-torch:
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,15 @@ A collection of ready-to-run Jupyter* notebooks are available to demonstrate how
- [NNCF Post-Training Optimization of Segment Anything Model](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/237-segment-anything)
- [NNCF Post-Training Optimization of CLIP Model](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/228-clip-zero-shot-image-classification)
- [NNCF Post-Training Optimization of ImageBind Model](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/239-image-bind)
- [NNCF Post-Training Optimization of Whisper Model](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/227-whisper-subtitles-generation)
- [Quantize a Segmentation Model and Show Live Inference](https://github.com/openvinotoolkit/openvino_notebooks/blob/main/notebooks/110-ct-segmentation-quantize)
- [Training to Deployment with TensorFlow and OpenVINO](https://github.com/openvinotoolkit/openvino_notebooks/blob/main/notebooks/301-tensorflow-training-openvino)
- [Migrate quantization from POT API to NNCF API](https://github.com/openvinotoolkit/openvino_notebooks/blob/main/notebooks/111-yolov5-quantization-migration)
- [Post-Training Quantization of Pytorch model with NNCF](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/112-pytorch-post-training-quantization-nncf)
- [Optimizing PyTorch models with NNCF of OpenVINO by 8-bit quantization](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/302-pytorch-quantization-aware-training)
- [Optimizing TensorFlow models with NNCF of OpenVINO by 8-bit quantization](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/305-tensorflow-quantization-aware-training)
- [Accelerate Inference of Sparse Transformer Models with OpenVINO and 4th Gen Intel Xeon Scalable Processors](https://github.com/openvinotoolkit/openvino_notebooks/blob/main/notebooks/116-sparsity-optimization)
- [Quantization with accuracy control using NNCF](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/122-quantizing-model-with-accuracy-control)

### Post-Training Quantization Samples

Expand Down
60 changes: 60 additions & 0 deletions ReleaseNotes.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,65 @@
# Release Notes

## New in Release 2.6.0

Post-training Quantization:

- Features:
- Added `CPU_SPR` device type support.
- Added quantizers scales unification.
- Added quantization scheme for ReduceSum operation.
- Added new types (ReduceL2, ReduceSum, Maximum) to the ignored scope for `ModelType.Transformer`.
- (OpenVINO) Added SmoothQuant algorithm.
- (OpenVINO) Added ChannelAlignment algorithm.
- (OpenVINO) Added HyperparameterTuner algorithm.
- (PyTorch) Added FastBiasCorrection algorithm support.
- (OpenVINO, ONNX) Added embedding weights quantization.
- (OpenVINO, PyTorch) Added new `compress_weights` method that provides data-free [INT8 weights compression](docs/compression_algorithms/CompressWeights.md).
- Fixes:
- Fixed detection of decomposed post-processing in models.
- Multiple fixes (new patterns, bugfixes, etc.) to solve [#1936](https://github.com/openvinotoolkit/nncf/issues/1936) issue.
- Fixed model reshaping while quantization to keep original model shape.
- (OpenVINO) Added support for sequential models quanitzation.
- (OpenVINO) Fixed in-place statistics cast to support empty dimensions.
- (OpenVINO, ONNX) Fixed quantization of the MatMul operation with weights rank > 2.
- (OpenVINO, ONNX) Fixed BiasCorrection algorithm to enable [CLIP model quantization](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/228-clip-zero-shot-image-classification).
- Improvements:
- Optimized `quantize(…)` pipeline (up to 4.3x speed up in total).
- Optimized `quantize_with_accuracy_control(…)` pipelilne (up to 8x speed up for [122-quantizing-model-with-accuracy-control](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/122-quantizing-model-with-accuracy-control) notebook).
- Optimized general statistics collection (up to 1.2x speed up for ONNX backend).
- Ignored patterns separated from Fused patterns scheme (with multiple patterns addition).
- Tutorials:
- [Post-Training Optimization of Segment Anything Model](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/237-segment-anything).
- [Post-Training Optimization of CLIP Model](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/228-clip-zero-shot-image-classification).
- [Post-Training Optimization of ImageBind Model](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/239-image-bind).
- [Post-Training Optimization of Whisper Model](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/227-whisper-subtitles-generation).
- [Post-Training Optimization with accuracy control](https://github.com/openvinotoolkit/openvino_notebooks/tree/main/notebooks/122-quantizing-model-with-accuracy-control).

Compression-aware training:

- Features:
- Added shape pruning processor for BootstrapNAS algorithm.
- Added KD loss for BootstrapNAS algorithm.
- Added `validate_scopes` parameter for NNCF configuration.
- (PyTorch) Added PyTorch 2.0 support.
- (PyTorch) Added `.strip()` option to API.
- (PyTorch) Enabled bfloat data type for quantization kernels.
- (PyTorch) Quantized models can now be `torch.jit.trace`d without calling `.strip()`.
- (PyTorch) Added support for overridden `forward` instance attribute on model objects passed into `create_compressed_model`.
- (Tensorflow) Added Tensorflow 2.12 support.
- Fixes:
- (PyTorch) Fixed padding adjustment issue in the elastic kernel to work with the different active kernel sizes.
- (PyTorch) Fixed the torch graph tracing in the case the tensors belonging to parallel edges are interleaved in the order of the tensor argument.
- (PyTorch) Fixed recurrent nodes matching (LSTM, GRU cells) condition with the strict rule to avoid adding not necessary nodes to the ignored scope.
- (PyTorch) Fixed `torch.jit.script` wrapper so that user-side handling exceptions during `torch.jit.script` invocation do not cause NNCF to be permanently disabled.
- (PyTorch, Tensorflow) Adjusted quantizer propagation algorithm to check if quantizer propagation will result in output quantization.
- (PyTorch) Added redefined `__class__` method for ProxyModule that avoids causing error while calling `.super()` in forward method.
- Deprecations/Removals:
- (PyTorch) Removed deprecated `NNCFNetwork.__getattr__`, `NNCFNetwork.get_nncf_wrapped_model` methods.
- Requirements:
- Updated PyTorch version (2.0.1).
- Updated Tensorflow version (2.12.0).

## New in Release 2.5.0

Post-training Quantization:
Expand Down
22 changes: 11 additions & 11 deletions examples/post_training_quantization/openvino/yolov8/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import openvino.runtime as ov
import torch
from tqdm import tqdm
from ultralytics import YOLO
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_det_dataset
from ultralytics.yolo.engine.validator import BaseValidator as Validator
from ultralytics.yolo.utils import DATASETS_DIR
from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.metrics import ConfusionMatrix
from ultralytics.cfg import get_cfg
from ultralytics.data.converter import coco80_to_coco91_class
from ultralytics.data.utils import check_det_dataset
from ultralytics.engine.validator import BaseValidator as Validator
from ultralytics.models.yolo import YOLO
from ultralytics.utils import DATASETS_DIR
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils.metrics import ConfusionMatrix

import nncf

Expand Down Expand Up @@ -66,17 +66,17 @@ def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -


def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]:
validator = model.ValidatorClass(args)
validator = model.smart_load("validator")(args)
validator.data = check_det_dataset(args.data)
dataset = validator.data["val"]
print(f"{dataset}")

data_loader = validator.get_dataloader(f"{DATASETS_DIR}/coco128", 1)

validator = model.ValidatorClass(args)
validator = model.smart_load("validator")(args)

validator.is_coco = True
validator.class_map = ops.coco80_to_coco91_class()
validator.class_map = coco80_to_coco91_class()
validator.names = model.model.names
validator.metrics.names = validator.names
validator.nc = model.model.model[-1].nc
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
ultralytics==8.0.43
ultralytics==8.0.170
onnx>=1.12.0
openvino-dev==2023.0.1
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
import openvino.runtime as ov
import torch
from tqdm import tqdm
from ultralytics import YOLO
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.utils import check_det_dataset
from ultralytics.yolo.engine.validator import BaseValidator as Validator
from ultralytics.yolo.utils import DATASETS_DIR
from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils import ops
from ultralytics.yolo.utils.metrics import ConfusionMatrix
from ultralytics.cfg import get_cfg
from ultralytics.data.converter import coco80_to_coco91_class
from ultralytics.data.utils import check_det_dataset
from ultralytics.engine.validator import BaseValidator as Validator
from ultralytics.models.yolo import YOLO
from ultralytics.utils import DATASETS_DIR
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils import ops
from ultralytics.utils.metrics import ConfusionMatrix

import nncf

Expand Down Expand Up @@ -91,17 +92,17 @@ def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -


def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]:
validator = model.ValidatorClass(args)
validator = model.smart_load("validator")(args)
validator.data = check_det_dataset(args.data)
dataset = validator.data["val"]
print(f"{dataset}")

data_loader = validator.get_dataloader(f"{DATASETS_DIR}/coco128-seg", 1)

validator = model.ValidatorClass(args)
validator = model.smart_load("validator")(args)

validator.is_coco = True
validator.class_map = ops.coco80_to_coco91_class()
validator.class_map = coco80_to_coco91_class()
validator.names = model.model.names
validator.metrics.names = validator.names
validator.nc = model.model.model[-1].nc
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
ultralytics==8.0.43
ultralytics==8.0.170
onnx>=1.12.0
openvino-dev==2023.0.1
6 changes: 6 additions & 0 deletions nncf/common/hardware/configs/cpu.json
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@
"activations": "q8_a"
}
},
{
"type": "GroupNormalization",
"quantization": {
"activations": "q8_a"
}
},
{"type": "Flatten"},
{"type": "Squeeze"},
{"type": "Unsqueeze"},
Expand Down
1 change: 1 addition & 0 deletions nncf/common/hardware/opset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@ class HWConfigOpName:
GELU = "Gelu"
LSTMSEQUENCE = "LSTMSequence"
GRUSEQUENCE = "GRUSequence"
GROUPNORMALIZATION = "GroupNormalization"
10 changes: 8 additions & 2 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from itertools import islice
from typing import Any, Dict, TypeVar

from tqdm import tqdm
from tqdm.auto import tqdm

from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
Expand Down Expand Up @@ -54,9 +54,15 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
model_with_outputs = model_transformer.transform(transformation_layout)
engine = factory.EngineFactory.create(model_with_outputs)

dataset_length = self.dataset.get_length()
total = (
min(dataset_length or self.stat_subset_size, self.stat_subset_size)
if self.stat_subset_size is not None
else None
)
for input_data in tqdm(
islice(self.dataset.get_inference_data(), self.stat_subset_size),
total=self.stat_subset_size,
total=total,
desc="Statistics collection",
):
outputs = engine.infer(input_data)
Expand Down
2 changes: 1 addition & 1 deletion nncf/config/schemata/algo/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@
},
"export_to_onnx_standard_ops": with_attributes(
BOOLEAN,
description="Determines how should the additional quantization "
description="[Deprecated] Determines how should the additional quantization "
"operations be exported into the ONNX format. Set "
"this to true to export to ONNX "
"standard QuantizeLinear-DequantizeLinear "
Expand Down
9 changes: 9 additions & 0 deletions nncf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[Mo
"""
return DataProvider(self._data_source, self._transform_func, indices)

def get_length(self) -> Optional[int]:
"""
Tries to fetch length of the underlying dataset.
:return: The length of the data_source if __len__() is implemented for it, and None otherwise.
"""
if hasattr(self._data_source, "__len__"):
return self._data_source.__len__()
return None


class DataProvider(Generic[DataItem, ModelInput]):
def __init__(
Expand Down
1 change: 1 addition & 0 deletions nncf/openvino/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
ov_metatypes.OVSquaredDifferenceMetatype,
ov_metatypes.OVLSTMSequenceMetatype,
ov_metatypes.OVGRUSequenceMetatype,
ov_metatypes.OVGroupNormalizationMetatype,
]


Expand Down
7 changes: 7 additions & 0 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,13 @@ class OVAbsMetatype(OVOpMetatype):
op_names = ["Abs"]


@OV_OPERATOR_METATYPES.register()
class OVGroupNormalizationMetatype(OVOpMetatype):
name = "GroupNormalizationOp"
op_names = ["GroupNormalization"]
hw_config_names = [HWConfigOpName.GROUPNORMALIZATION]


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
Expand Down
5 changes: 4 additions & 1 deletion nncf/openvino/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def _add_nncf_node(node: ov.Node, graph: NNCFGraph) -> None:
metatype = get_node_metatype(node)
graph.add_nncf_node(node_name=node.get_friendly_name(), node_type=node_type, node_metatype=metatype)

# pylint: disable=too-many-branches
@staticmethod
def create_nncf_graph(model: ov.Model) -> NNCFGraph:
"""
Expand Down Expand Up @@ -174,8 +175,10 @@ def create_nncf_graph(model: ov.Model) -> NNCFGraph:
node_attributes = node.get_attributes()
const_transpose_name = attribute_names[const_port_id]
const_attrs[const_port_id]["transpose"] = node_attributes[const_transpose_name]

act_attrs["transpose"] = node_attributes[attribute_names[act_port_id]]
elif metatype == OVGRUSequenceMetatype:
node_attributes = node.get_attributes()
act_attrs["linear_before_reset"] = node_attributes["linear_before_reset"]

if const_attrs or act_attrs:
nncf_node = nncf_graph.get_node_by_name(node_name)
Expand Down
4 changes: 4 additions & 0 deletions nncf/openvino/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def native_quantize_with_accuracy_control_impl(
copied_parameters,
)

if advanced_accuracy_restorer_parameters.intermediate_model_dir:
quantized_model_path = f"{advanced_accuracy_restorer_parameters.intermediate_model_dir}/intermediate_model.xml"
ov.serialize(quantized_model, quantized_model_path)

evaluator = Evaluator(validation_fn)
evaluator.enable_iteration_count()
initial_metric_results = evaluator.collect_metric_results(model, validation_dataset, model_name="initial")
Expand Down
4 changes: 4 additions & 0 deletions nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,16 @@ class AdvancedAccuracyRestorerParameters:
:param num_ranking_processes: The number of parallel processes that are used to rank
quantization operations.
:type num_ranking_processes: Optional[int]
:param intermediate_model_dir: Path to the folder where the model, which was fully
quantized with initial parameters, should be saved.
:type intermediate_model_dir: Optional[str]
"""

max_num_iterations: int = sys.maxsize
tune_hyperparams: bool = False
ranking_subset_size: Optional[int] = None
num_ranking_processes: Optional[int] = None
intermediate_model_dir: Optional[str] = None


def changes_asdict(params: Any) -> Dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm

from nncf import Dataset
from nncf import nncf_logger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm

from nncf import Dataset
from nncf.common.factory import CommandCreatorFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union

from tqdm import tqdm
from tqdm.auto import tqdm

from nncf import Dataset
from nncf.common.factory import EngineFactory
Expand Down
5 changes: 5 additions & 0 deletions nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ def _get_ignored_names(

ignored_names = {name: IgnoreReason.AUTOGENERATED for name in autogenerated_ignored_names}

ignored_names_by_layer_attributes = self._backend_entity.get_ignored_names_by_layer_attributes(
inference_nncf_graph
)
ignored_names.update({name: IgnoreReason.AUTOGENERATED for name in ignored_names_by_layer_attributes})

# User ignored scope has higher priority
ignored_names.update({name: IgnoreReason.USER_REQUESTED for name in user_ignored_names})

Expand Down
Loading

0 comments on commit 5b68b4b

Please sign in to comment.