From 681c3575cbbb90b4691bc263647ab4b8fed893e3 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Thu, 24 Aug 2023 17:08:49 +0200
Subject: [PATCH] Refactor Torch/Torch PTQ to use experimental TensorCollector

---
 nncf/common/tensor_statistics/collectors.py  |  38 +-
 .../tensor_statistics/statistic_point.py     |   2 +-
 nncf/common/tensor_statistics/statistics.py  |   7 +
 .../common/tensor_statistics/collectors.py   | 224 ++++++---
 .../tensorflow/quantization/init_range.py    |   2 +-
 nncf/openvino/statistics/collectors.py       |  45 +-
 nncf/openvino/statistics/statistics.py       |   9 +
 .../fast_bias_correction/torch_backend.py    |   7 +-
 .../algorithms/min_max/openvino_backend.py   |   2 +-
 .../algorithms/min_max/torch_backend.py      | 124 ++---
 nncf/tensorflow/quantization/init_range.py   |   4 +-
 nncf/torch/quantization/init_range.py        |  98 ++--
 nncf/torch/quantization/layers.py            |   2 +
 nncf/torch/statistics/aggregator.py          |  10 +-
 nncf/torch/tensor_statistics/algo.py         |  13 +-
 nncf/torch/tensor_statistics/collectors.py   | 466 ++++++++++++------
 nncf/torch/tensor_statistics/statistics.py   |  34 +-
 tests/common/test_statistics_aggregator.py   |   8 +-
 .../common/test_reducers_and_aggregators.py  | 212 ++++++--
 .../test_reducers_and_aggregators.py         |  14 +-
 .../test_templates/test_channel_alignment.py |   2 +-
 .../test_templates/test_quantizer_config.py  |   4 +-
 .../test_templates/test_smooth_quant.py      |   2 +-
 .../test_tensor_statistics.py                |  16 +-
 tests/torch/ptq/test_ptq_params.py           |  22 +-
 tests/torch/ptq/test_quantizer_config.py     |  20 +-
 .../ptq/test_reducers_and_aggregators.py     |  83 ++++
 tests/torch/quantization/test_range_init.py  |  98 ++--
 .../test_tensor_statistics.py                | 230 +++++----
 tests/torch/test_statistics_aggregator.py    |   2 +-
 30 files changed, 1255 insertions(+), 545 deletions(-)
 create mode 100644 tests/torch/ptq/test_reducers_and_aggregators.py

diff --git a/nncf/common/tensor_statistics/collectors.py b/nncf/common/tensor_statistics/collectors.py
index 907ae30fec8..84dc8746b8b 100644
--- a/nncf/common/tensor_statistics/collectors.py
+++ b/nncf/common/tensor_statistics/collectors.py
@@ -44,7 +44,7 @@ def __init__(self, reduction_shape: Optional[ReductionShape] = None, num_samples
     def num_samples(self) -> int:
         return self._num_samples

-    def register_input(self, x: TensorType) -> TensorType:
+    def register_inputs(self, x: TensorType) -> TensorType:
         """Registers input tensor"""
         if not self._enabled:
             return x
@@ -251,6 +251,11 @@ def unstack(x: NNCFTensor, axis: int = 0) -> List[NNCFTensor]:
         :return: List of NNCFTensor.
         """

+    @staticmethod
+    @abstractmethod
+    def squeeze(x: NNCFTensor, dim: Optional[int] = None) -> NNCFTensor:
+        """Removes dimensions of size one from x (only the given dimension, if dim is specified)."""
+
     @staticmethod
     @abstractmethod
     def sum(tensor: NNCFTensor) -> TensorElementsType:
@@ -278,6 +283,17 @@ def quantile(
         :returns: List of the quantile-th percentile(s) of the tensor elements.
""" + @classmethod + @abstractmethod + def precentile( + cls, + tensor: NNCFTensor, + precentile: Union[float, List[float]], + axis: Union[int, tuple, list], + keepdims: bool = False, + ) -> List[TensorElementsType]: + """""" + @staticmethod @abstractmethod def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor: @@ -291,7 +307,9 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor: @classmethod @abstractmethod - def no_outliers_map(cls, x: NNCFTensor, fn: MaskedReduceFN, axis: int = 0, alpha: float = 0.01) -> NNCFTensor: + def no_outliers_map( + cls, x: NNCFTensor, fn: MaskedReduceFN, axis: Union[int, Tuple[int, ...]] = 0, alpha: float = 0.01 + ) -> NNCFTensor: """ Computes quantiles [alpha, 1 - alpha] on given tensor, masks all elements that are smaller that alpha and bigger than 1 - alpha quantile and applies @@ -305,6 +323,22 @@ def no_outliers_map(cls, x: NNCFTensor, fn: MaskedReduceFN, axis: int = 0, alpha :returns: Result of given masked reduction function on filtered from outliers NNCFTensor. """ + @classmethod + def masked_map(cls, x: NNCFTensor, fn: MaskedReduceFN, filter_fn) -> NNCFTensor: + """ """ + + @classmethod + def sub(cls, a: NNCFTensor, b: NNCFTensor) -> NNCFTensor: + """""" + + @classmethod + def filter_by_fn(cls, x: NNCFTensor, filter_fn) -> NNCFTensor: + """ """ + + @classmethod + def non_zero_elements(cls, x: NNCFTensor) -> NNCFTensor: + """ """ + class MinMaxStatisticCollector(OnlineTensorStatisticCollector): """Collector estimates min of minimum values and max of maximum values.""" diff --git a/nncf/common/tensor_statistics/statistic_point.py b/nncf/common/tensor_statistics/statistic_point.py index 3cf533ac7b6..735a56e6fef 100644 --- a/nncf/common/tensor_statistics/statistic_point.py +++ b/nncf/common/tensor_statistics/statistic_point.py @@ -38,7 +38,7 @@ def __eq__(self, other): def register_tensor(self, x: TensorType): for tensor_collectors in self.algorithm_to_tensor_collectors.values(): for tensor_collector in tensor_collectors: - tensor_collector.register_input(x) + tensor_collector.register_unnamed_inputs(x) class StatisticPointsContainer(UserDict): diff --git a/nncf/common/tensor_statistics/statistics.py b/nncf/common/tensor_statistics/statistics.py index 0f6d0d1aad3..e269b2dc42c 100644 --- a/nncf/common/tensor_statistics/statistics.py +++ b/nncf/common/tensor_statistics/statistics.py @@ -20,6 +20,8 @@ class TensorStatistic(ABC): """Base class that stores statistic data""" + TENSOR_STATISTIC_OUTPUT_KEY = "tensor_statistic_output" + @staticmethod @abstractmethod def tensor_eq(tensor1: TensorType, tensor2: TensorType, rtol=1e-6) -> bool: @@ -63,6 +65,9 @@ def __eq__(self, other: "MeanTensorStatistic") -> bool: class MedianMADTensorStatistic(TensorStatistic): + MEDIAN_VALUES_STAT = "median_values" + MAD_VALUES_STAT = "mad_values" + def __init__(self, median_values, mad_values): self.median_values = median_values self.mad_values = mad_values @@ -74,6 +79,8 @@ def __eq__(self, other: "MedianMADTensorStatistic") -> bool: class PercentileTensorStatistic(TensorStatistic): + PRECENTILE_VS_VALUE_DICT = "percentile_vs_values_dict" + def __init__(self, percentile_vs_values_dict): self.percentile_vs_values_dict = percentile_vs_values_dict diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index 3655fffe5d6..94f447732d8 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -13,6 +13,7 @@ from abc 
import abstractmethod
 from collections import defaultdict
 from collections import deque
+from functools import partial
 from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union

 from nncf.common.tensor import TensorType
@@ -31,16 +32,18 @@ class TensorReducerBase(ABC):
     the specified rule. Could handle tensors inplace or out of place.
     """

-    def __init__(self, reduction_shape: Optional[ReductionShape] = None, inplace: bool = False):
+    def __init__(self, reduction_axes: Optional[ReductionShape] = None, inplace: bool = False, keepdims: bool = True):
         """
         :param reduction_shape: Reduction shape for reduction calculation. Equal to list(range(len(input.shape)))
             if empty.
         :param inplace: Whether should be calculated inplace or out of place.
-
+        :param keepdims: Whether the reduced axes are left in the result
+            as dimensions with size one.
         """
-        self._reduction_shape = reduction_shape
+        self._reduction_axes = reduction_axes
         self._tensor_processor: NNCFCollectorTensorProcessor = self._get_processor()
         self._inplace = inplace
+        self._keepdims = keepdims

     @property
     def inplace(self):
@@ -95,16 +98,16 @@ def __call__(self, x: List[NNCFTensor]):
     def __eq__(self, __o: object) -> bool:
         return (
             isinstance(__o, self.__class__)
-            and self._reduction_shape == __o._reduction_shape
+            and self._reduction_axes == __o._reduction_axes
             and self._inplace == __o.inplace
         )

     def __hash__(self) -> int:
-        return hash((self.__class__.__name__, self.inplace, self._reduction_shape))
+        return hash((self.__class__.__name__, self.inplace, self._reduction_axes))

     def _get_reduction_shape(self, tensor: NNCFTensor) -> Union[int, Tuple[int, ...]]:
-        if self._reduction_shape is not None:
-            return self._reduction_shape
+        if self._reduction_axes is not None:
+            return self._reduction_axes
         return tuple(range(len(tensor.shape)))


@@ -114,7 +117,13 @@ class TensorAggregatorBase:
     aggregate them in terms of NNCFCollectorTensorProcessor operations.
     """

-    def __init__(self, tensor_processor: NNCFCollectorTensorProcessor, num_samples: Optional[int] = None):
+    def __init__(
+        self,
+        tensor_processor: NNCFCollectorTensorProcessor,
+        aggregation_axes: Optional[Tuple[int, ...]] = None,
+        keepdims: bool = False,
+        num_samples: Optional[int] = None,
+    ):
         """
         :param tensor_processor: Backend-specific tensor processor.
         :param num_samples: Maximum number of samples to collect.
Aggregator
@@ -123,6 +132,8 @@ def __init__(self, tensor_processor: NNCFCollectorTensorProcessor, num_samples:
         """
         self._tensor_processor = tensor_processor
+        self._aggregation_axes = (0,) if aggregation_axes is None else aggregation_axes
+        self._keepdims = keepdims
         self._num_samples = num_samples
         self._collected_samples = 0
         self._container = []
@@ -187,8 +198,8 @@ class TensorCollector:

     def __init__(self, statistic_container: Optional[TensorStatistic] = None) -> None:
         self._reducers: Set[TensorReducerBase] = set()
-        self._aggregators: Dict[Tuple[int, int], TensorAggregatorBase] = {}
-        self._stat_container_kwargs_map: Dict[str, Tuple[int, int]] = {}
+        self._aggregators: Dict[Tuple[int, int, int], TensorAggregatorBase] = {}
+        self._stat_container_kwargs_map: Dict[str, Tuple[int, int, int]] = {}
         self._stat_container = statistic_container
         self._enabled = True

@@ -290,6 +301,12 @@ def register_inputs(self, inputs: Dict[int, List[NNCFTensor]]) -> None:
                 if reducer_hash in reduced_inputs:
                     aggregator.register_reduced_input(reduced_inputs[reducer_hash][reducer_port_id])

+    def register_unnamed_inputs(self, inputs: NNCFTensor):
+        formatted_inputs = {}
+        for reducer in self._reducers:
+            formatted_inputs[hash(reducer)] = [inputs]
+        self.register_inputs(formatted_inputs)
+
     def _aggregate(self) -> None:
         result = {}
         for (
@@ -315,7 +332,7 @@ def get_statistics(self) -> Union[TensorStatistic, Dict[str, Any]]:
         if not self._stat_container:
             return kwargs

-        return self._stat_container(**kwargs)
+        return self._stat_container(kwargs)

     def get_inplace_fn_info(self) -> List[Tuple[Any, int]]:
         """
@@ -425,67 +442,70 @@ class MinReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
         x = x[0]
         reduction_shape = self._get_reduction_shape(x)
-        return [self._tensor_processor.reduce_min(x, reduction_shape, keepdims=True)]
+        return [self._tensor_processor.reduce_min(x, reduction_shape, keepdims=self._keepdims)]


 class MaxReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
         x = x[0]
         reduction_shape = self._get_reduction_shape(x)
-        return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=True)]
+        return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=self._keepdims)]


 class AbsMaxReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
         x = self._tensor_processor.abs(x[0])
         reduction_shape = self._get_reduction_shape(x)
-        return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=True)]
+        return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=self._keepdims)]


 class MeanReducer(TensorReducerBase):
     def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
         x = x[0]
         reduction_shape = self._get_reduction_shape(x)
-        return [self._tensor_processor.mean(x, reduction_shape, keepdims=True)]
+        return [self._tensor_processor.mean(x, reduction_shape, keepdims=self._keepdims)]


 class QuantileReducerBase(TensorReducerBase):
     def __init__(
         self,
-        reduction_shape: Optional[ReductionShape] = None,
+        reduction_axes: Optional[ReductionShape] = None,
         quantile: Optional[Union[float, Tuple[float]]] = None,
         inplace: bool = False,
+        keepdims: bool = True,
     ):
-        super().__init__(reduction_shape, False)
+        super().__init__(reduction_axes, False, keepdims)
         self._quantile = (0.01, 0.99) if quantile is None else quantile

     def __eq__(self, __o: object) -> bool:
         return super().__eq__(__o) and self._quantile == __o._quantile

     def __hash__(self) -> int:
-        return
hash((self.__class__.__name__, self.inplace, self._reduction_shape, tuple(self._quantile))) + return hash((self.__class__.__name__, self.inplace, self._reduction_axes, tuple(self._quantile))) class QuantileReducer(QuantileReducerBase): def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: x = x[0] reduction_shape = self._get_reduction_shape(x) - return self._tensor_processor.quantile(x, self._quantile, reduction_shape, keepdims=True) + return self._tensor_processor.quantile(x, self._quantile, reduction_shape, keepdims=self._keepdims) class AbsQuantileReducer(QuantileReducerBase): def __init__( self, - reduction_shape: Optional[ReductionShape] = None, - quantile: Union[float, List[float]] = 0.99, + reduction_axes: Optional[ReductionShape] = None, + quantile: Optional[Union[float, List[float]]] = None, inplace: bool = False, + keepdims: bool = True, ): - super().__init__(reduction_shape, quantile, False) + quantile = (0.99,) if quantile is None else quantile + super().__init__(reduction_axes, quantile, False, keepdims) def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: x = self._tensor_processor.abs(x[0]) reduction_shape = self._get_reduction_shape(x) - return self._tensor_processor.quantile(x, [self._quantile], reduction_shape, keepdims=True) + return self._tensor_processor.quantile(x, self._quantile, reduction_shape, keepdims=self._keepdims) class BatchMeanReducer(TensorReducerBase): @@ -501,7 +521,7 @@ def __init__(self, channel_dim: int = 1, inplace: bool = False): super().__init__(channel_dim, inplace) def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - return [self._tensor_processor.mean_per_channel(x[0], self._reduction_shape)] + return [self._tensor_processor.mean_per_channel(x[0], self._reduction_axes)] ##################################################Aggregators################################################## @@ -509,7 +529,7 @@ def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: class NoopAggregator(TensorAggregatorBase): def __init__(self, num_samples: Optional[int]): - super().__init__(None, num_samples) + super().__init__(None, num_samples=num_samples) def _register_reduced_input_impl(self, x: TensorType) -> None: self._container.append(x.tensor) @@ -520,7 +540,7 @@ def _aggregate_impl(self): class ShapeAggregator(TensorAggregatorBase): def __init__(self): - super().__init__(None, 1) + super().__init__(None, num_samples=1) def _register_reduced_input_impl(self, x: TensorType) -> None: self._container = x @@ -529,46 +549,61 @@ def _aggregate_impl(self): return self._container.shape -class MinAggregator(TensorAggregatorBase): - def _register_reduced_input_impl(self, x: TensorType) -> None: - if not self._container: - self._container = x +class OnlineOfflineAggregatorBase(TensorAggregatorBase): + def __init__( + self, + tensor_processor: NNCFCollectorTensorProcessor, + aggregation_axes: Optional[Tuple[int, ...]] = None, + keepdims: bool = False, + num_samples: Optional[int] = None, + window_size=None, + ): + super().__init__( + tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims, num_samples=num_samples + ) + self._window_size = window_size + self._container = deque(maxlen=window_size) + + +class OnlineAggregatorBase(OnlineOfflineAggregatorBase, ABC): + def _online_register_reduced_input_impl(self, x: TensorType, fn) -> None: + online_aggregation_axes = tuple([dim - 1 for dim in self._aggregation_axes if dim != 0]) + if online_aggregation_axes: + reduced = fn(x, 
axis=online_aggregation_axes, keepdims=self._keepdims) else: - self._container = self._tensor_processor.min(x, self._container) + reduced = x + if 0 in self._aggregation_axes: + if self._container: + reduced = fn(self._tensor_processor.stack([reduced, self._container]), axis=0, keepdims=False) + self._container = reduced + else: + self._container.append(reduced) def _aggregate_impl(self): - return self._container.tensor + if 0 in self._aggregation_axes: + if self._keepdims: + return self._tensor_processor.stack([self._container]).tensor + return self._container.tensor + return self._tensor_processor.stack(self._container).tensor -class MaxAggregator(TensorAggregatorBase): +class MinAggregator(OnlineAggregatorBase): def _register_reduced_input_impl(self, x: TensorType) -> None: - if not self._container: - self._container = x - else: - self._container = self._tensor_processor.max(x, self._container) + return self._online_register_reduced_input_impl(x, self._tensor_processor.reduce_min) - def _aggregate_impl(self): - return self._container.tensor +class MaxAggregator(OnlineAggregatorBase): + def _register_reduced_input_impl(self, x: TensorType) -> None: + return self._online_register_reduced_input_impl(x, self._tensor_processor.reduce_max) -class OfflineAggregatorBase(TensorAggregatorBase, ABC): - def __init__( - self, tensor_processor, use_per_sample_stats: bool = False, num_samples: Optional[int] = None, window_size=None - ): - super().__init__(tensor_processor, num_samples) - self._window_size = window_size - self._container = deque(maxlen=window_size) - self._use_per_sample_stats = use_per_sample_stats +class OfflineAggregatorBase(OnlineOfflineAggregatorBase, ABC): def _register_reduced_input_impl(self, x: TensorType) -> None: - if self._use_per_sample_stats: - self._container.extend(self._tensor_processor.unstack(x)) - else: - self._container.append(x) + self._container.append(x) def _offline_aggregation_impl(self, fn): stacked_val = self._tensor_processor.stack(self._container) - return fn(stacked_val, axis=0, keepdims=False).tensor + return fn(stacked_val, axis=self._aggregation_axes, keepdims=self._keepdims).tensor class MeanAggregator(OfflineAggregatorBase): @@ -584,18 +619,25 @@ def _aggregate_impl(self): class NoOutliersAggregatorBase(OfflineAggregatorBase, ABC): def __init__( self, - tensor_processor, - use_per_sample_stats: bool = False, + tensor_processor: NNCFCollectorTensorProcessor, + aggregation_axes: Optional[Tuple[int, ...]] = None, + keepdims: bool = False, num_samples: Optional[int] = None, window_size=None, quantile: float = 0.01, ): - super().__init__(tensor_processor, use_per_sample_stats, num_samples, window_size) + super().__init__( + tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims, num_samples=num_samples + ) + self._window_size = window_size + self._container = deque(maxlen=window_size) self._quantile = quantile def _offline_aggregation_impl(self, fn) -> List[NNCFTensor]: stacked_val = self._tensor_processor.stack(self._container) - result = self._tensor_processor.no_outliers_map(stacked_val, fn, axis=0, alpha=self._quantile) + result = self._tensor_processor.no_outliers_map( + stacked_val, fn, axis=self._aggregation_axes, alpha=self._quantile, keepdims=self._keepdims + ) return result.tensor def __eq__(self, __o: object) -> bool: @@ -607,12 +649,78 @@ def __hash__(self) -> int: class MeanNoOutliersAggregator(NoOutliersAggregatorBase): def _aggregate_impl(self) -> Any: - return 
self._offline_aggregation_impl(self._tensor_processor.masked_mean) + return self._offline_aggregation_impl(partial(self._tensor_processor.masked_mean, keepdims=self._keepdims)) class MedianNoOutliersAggregator(NoOutliersAggregatorBase): def _aggregate_impl(self) -> Any: - return self._offline_aggregation_impl(self._tensor_processor.masked_median) + return self._offline_aggregation_impl(partial(self._tensor_processor.masked_median, keepdims=self._keepdims)) + + +class MedianAbsoluteDeviationAggregator(OnlineOfflineAggregatorBase): + def _register_reduced_input_impl(self, x: TensorType) -> None: + return self._container.append(x) + + def _aggregate_impl(self) -> Any: + stacked_val = self._tensor_processor.stack(self._container) + median_fn = partial(self._tensor_processor.masked_median, axis=self._aggregation_axes, keepdims=True) + filter_fn = self._tensor_processor.non_zero_elements + median_per_ch = self._tensor_processor.masked_map(stacked_val, median_fn, filter_fn) + + mad_values = self._tensor_processor.median( + self._tensor_processor.abs(self._tensor_processor.sub(stacked_val, median_per_ch)), + axis=self._aggregation_axes, + keepdims=self._keepdims, + ) + if not self._keepdims: + median_per_ch = self._tensor_processor.squeeze(median_per_ch, self._aggregation_axes) + return {"median_values": median_per_ch.tensor, "mad_values": mad_values.tensor} + + +class PrecentileAggregator(OnlineOfflineAggregatorBase): + def __init__( + self, + tensor_processor: NNCFCollectorTensorProcessor, + percentiles_to_collect: List[float], + aggregation_axes: Optional[Tuple[int, ...]] = None, + keepdims: bool = False, + num_samples: Optional[int] = None, + window_size=None, + ): + super().__init__( + tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims, num_samples=num_samples + ) + self._precentiles_to_collect = percentiles_to_collect + self._window_size = window_size + self._container = deque(maxlen=window_size) + + def _register_reduced_input_impl(self, x: TensorType) -> None: + return self._container.append(x) + + def _aggregate_impl(self) -> Any: + stacked_val = self._tensor_processor.stack(self._container) + + precentiles = self._tensor_processor.precentile( + stacked_val, self._precentiles_to_collect, axis=self._aggregation_axes, keepdims=self._keepdims + ) + retval = {} + for idx, precentile in enumerate(self._precentiles_to_collect): + retval[precentile] = precentiles[idx].tensor + return retval + + +class PostAggregateHook(TensorAggregatorBase, ABC): + def __init__(self, aggregator: TensorAggregatorBase, post_aggregation_hook): + super().__init__(None) + self._aggregator = aggregator + self._post_aggregation_hook = post_aggregation_hook + + def _register_reduced_input_impl(self, x: TensorType) -> None: + return self._aggregator.register_reduced_input(x) + + def _aggregate_impl(self) -> Any: + retval = self._aggregator.aggregate() + return self._post_aggregation_hook(retval) AGGREGATORS_MAP = { diff --git a/nncf/experimental/tensorflow/quantization/init_range.py b/nncf/experimental/tensorflow/quantization/init_range.py index f514177c717..7899dfe35c6 100644 --- a/nncf/experimental/tensorflow/quantization/init_range.py +++ b/nncf/experimental/tensorflow/quantization/init_range.py @@ -65,7 +65,7 @@ def _register_op_collector(self, op, collectors, handles, op_weights): collector = RangeInitializerV2.generate_stat_collector( reduction_shape, collector_params, init_config, num_batches ) - handles.append(op.register_hook_pre_call(collector.register_input)) + 
handles.append(op.register_hook_pre_call(collector.register_inputs)) op.enabled = False collectors.append((op, collector, op_weights)) diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py index 61ef776fab7..94c3dcecc52 100644 --- a/nncf/openvino/statistics/collectors.py +++ b/nncf/openvino/statistics/collectors.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Deque, List, Optional, Union +from typing import Any, Callable, Deque, List, Optional, Tuple, Union import numpy as np @@ -49,11 +49,11 @@ class OVNNCFCollectorTensorProcessor(NNCFCollectorTensorProcessor): """ @staticmethod - def reduce_min(x: NNCFTensor, axis: Union[int, tuple], keepdims: bool = True) -> NNCFTensor: + def reduce_min(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = True) -> NNCFTensor: return OVNNCFTensor(np.amin(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod - def reduce_max(x: NNCFTensor, axis: Union[int, tuple], keepdims: bool = True) -> NNCFTensor: + def reduce_max(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = True) -> NNCFTensor: return OVNNCFTensor(np.amax(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod @@ -69,25 +69,25 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor: return OVNNCFTensor(np.maximum(x1.tensor, x2.tensor)) @staticmethod - def mean(x: NNCFTensor, axis: Union[int, tuple], keepdims: bool = False) -> NNCFTensor: + def mean(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = False) -> NNCFTensor: return OVNNCFTensor(np.mean(x.tensor, axis=axis, keepdims=keepdims)) @staticmethod - def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor: + def median(x: NNCFTensor, axis: Union[int, Tuple, list], keepdims: bool = False) -> NNCFTensor: return OVNNCFTensor(np.median(x.tensor, axis=axis, keepdims=keepdims)) @classmethod def masked_mean( - cls, x: NNCFTensor, axis: Optional[Union[int, tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False + cls, x: NNCFTensor, axis: Optional[Union[int, Tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False ) -> NNCFTensor: if mask is None: return cls.mean(x, axis=axis, keepdims=keepdims) masked_x = np.ma.array(x.tensor, mask=mask.tensor) - return OVNNCFTensor(np.ma.mean(masked_x, axis=axis, keepdims=False).data) + return OVNNCFTensor(np.ma.mean(masked_x, axis=axis, keepdims=keepdims).data) @classmethod def masked_median( - cls, x: NNCFTensor, axis: Optional[Union[int, tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False + cls, x: NNCFTensor, axis: Optional[Union[int, Tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False ) -> NNCFTensor: if mask is None: return cls.median(x, axis=axis, keepdims=keepdims) @@ -107,20 +107,20 @@ def no_outliers_map( cls, x: NNCFTensor, fn: Callable[[NNCFTensor, int, NNCFTensor], Any], - axis: int = 0, + axis: Union[int, Tuple[int, ...]] = 0, alpha: float = 0.01, keepdims: bool = False, ) -> NNCFTensor: - if len(x.shape) == 1: - return fn(x, axis=None, mask=None, keepdims=keepdims) + if isinstance(axis, int): + axis = (axis,) - x = x.tensor - if axis: - x = np.moveaxis(x, axis, 0) + if len(axis) == len(x.shape): + return fn(x, axis=axis, mask=None, keepdims=keepdims) - low_values, high_values = np.quantile(x, [alpha, 1 - alpha], 0) + x = x.tensor + low_values, high_values = np.quantile(x, [alpha, 1 - alpha], axis=axis) outliers_mask = np.logical_or(x < low_values, 
high_values < x) - return fn(OVNNCFTensor(x), axis=0, mask=OVNNCFTensor(outliers_mask), keepdims=keepdims) + return fn(OVNNCFTensor(x), axis=axis, mask=OVNNCFTensor(outliers_mask), keepdims=keepdims) @staticmethod def batch_mean(x: NNCFTensor) -> NNCFTensor: @@ -141,7 +141,7 @@ def sum(tensor: NNCFTensor) -> TensorElementsType: @staticmethod def quantile( - tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False + tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, Tuple, list], keepdims: bool = False ) -> List[NNCFTensor]: result = np.quantile(tensor.tensor, quantile, axis, keepdims=keepdims) return [OVNNCFTensor(x) for x in result] @@ -157,7 +157,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_min_op(self.name, self._reduction_shape) + return get_inplace_min_op(self.name, self._reduction_axes) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -168,7 +168,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_max_op(self.name, self._reduction_shape, False) + return get_inplace_max_op(self.name, self._reduction_axes, False) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -179,7 +179,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_max_op(self.name, self._reduction_shape, True) + return get_inplace_max_op(self.name, self._reduction_axes, True) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -190,7 +190,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_mean_op(self.name, self._reduction_shape) + return get_inplace_mean_op(self.name, self._reduction_axes) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -212,7 +212,7 @@ def _get_processor(self): return OVNNCFCollectorTensorProcessor def get_inplace_fn(self): - return get_inplace_mean_per_ch(self.name, self._reduction_shape) + return get_inplace_mean_per_ch(self.name, self._reduction_axes) def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace) @@ -252,7 +252,6 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace kwargs = { "tensor_processor": OVNNCFCollectorTensorProcessor, - "use_per_sample_stats": False, "num_samples": num_samples, "window_size": window_size, } diff --git a/nncf/openvino/statistics/statistics.py b/nncf/openvino/statistics/statistics.py index 12a0c82af9c..de2cf3e23e2 100644 --- a/nncf/openvino/statistics/statistics.py +++ b/nncf/openvino/statistics/statistics.py @@ -17,18 +17,27 @@ class OVMinMaxTensorStatistic(MinMaxTensorStatistic): + def __init__(self, tensor_collector_output): + super().__init__(tensor_collector_output[self.MIN_STAT], 
tensor_collector_output[self.MAX_STAT]) + @staticmethod def tensor_eq(tensor1: np.ndarray, tensor2: np.ndarray, rtol=1e-6) -> bool: return bool(np.allclose(tensor1, tensor2, rtol=rtol)) class OVMeanTensorStatistic(MeanTensorStatistic): + def __init__(self, tensor_collector_output): + super().__init__(tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT]) + @staticmethod def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool: return bool(np.all(tensor, rtol=rtol)) class OVRawTensorStatistic(RawTensorStatistic): + def __init__(self, tensor_collector_output): + super().__init__(tensor_collector_output[self.VALUES_STATS]) + @staticmethod def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool: return bool(np.all(tensor, rtol=rtol)) diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py index b7316724db0..173acaf05e7 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py @@ -20,6 +20,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.tensor_statistics.collectors import ReductionShape from nncf.common.utils.backend import BackendType +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.quantization.algorithms.fast_bias_correction.backend import ALGO_BACKENDS from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend from nncf.torch.graph.transformations.command_creation import create_bias_correction_command @@ -32,8 +33,8 @@ from nncf.torch.model_analyzer import is_quantized_weights from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.tensor import PTNNCFTensor -from nncf.torch.tensor_statistics.collectors import PTMeanStatisticCollector from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor +from nncf.torch.tensor_statistics.collectors import get_mean_stat_collector @ALGO_BACKENDS.register(BackendType.TORCH) @@ -71,8 +72,8 @@ def mean_statistic_collector( inplace: bool, num_samples: Optional[int] = None, window_size: Optional[int] = None, - ) -> PTMeanStatisticCollector: - return PTMeanStatisticCollector(reduction_shape, num_samples, window_size) + ) -> TensorCollector: + return get_mean_stat_collector(num_samples, reduction_shape, window_size) @staticmethod def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index 42d89b76051..4e2c6d05ee5 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -181,7 +181,7 @@ def get_statistic_collector( f"Aggregator type: {params.aggregator_type} is not supported for OpenVino PTQ backend yet." 
            )

-        kwargs = {"reduction_shape": reduction_shape, "inplace": inplace}
+        kwargs = {"reduction_axes": reduction_shape, "inplace": inplace}
         if params.statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
             if container_key == OVMinMaxTensorStatistic.MIN_STAT:
                 quantile = params.quantile_outlier_prob
diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py
index 59c410274f8..80701b57fd0 100644
--- a/nncf/quantization/algorithms/min_max/torch_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_backend.py
@@ -28,6 +28,9 @@
 from nncf.common.quantization.structs import QuantizationMode
 from nncf.common.quantization.structs import QuantizerConfig
 from nncf.common.utils.backend import BackendType
+from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
+from nncf.experimental.common.tensor_statistics.collectors import PostAggregateHook
+from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import AggregatorType
@@ -49,8 +52,9 @@
 from nncf.torch.quantization.layers import BaseQuantizer
 from nncf.torch.quantization.layers import PTQuantizerSpec
 from nncf.torch.quantization.layers import get_scale_shape
-from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector
-from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector
+from nncf.torch.tensor import PTNNCFTensor
+from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP
+from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
 from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic


@@ -155,32 +159,61 @@ def get_statistic_collector(
         quantizer_config: QuantizerConfig,
         inplace: bool,
         num_samples: int = None,
-    ) -> Union[PTMinMaxStatisticCollector, PTMeanMinMaxStatisticCollector]:
-        if (
-            range_estimator_params.min.statistics_type == StatisticsType.MIN
-            and range_estimator_params.min.aggregator_type == AggregatorType.MIN
-            and range_estimator_params.max.statistics_type == StatisticsType.MAX
-            and range_estimator_params.max.aggregator_type == AggregatorType.MAX
+    ) -> TensorCollector:
+        collector_params = PTMinMaxAlgoBackend._default_collector_params(nncf_graph, target_point, quantizer_config)
+        collector_kwargs = collector_params.convert_statistic_params(per_sample_stats=False)
+
+        collector = TensorCollector(PTMinMaxTensorStatistic)
+        for params, container_key in zip(
+            [range_estimator_params.min, range_estimator_params.max],
+            [PTMinMaxTensorStatistic.MIN_STAT, PTMinMaxTensorStatistic.MAX_STAT],
         ):
-            collector_name = "min_max"
-
-        elif (
-            range_estimator_params.min.statistics_type == StatisticsType.MIN
-            and range_estimator_params.min.aggregator_type == AggregatorType.MEAN
-            and range_estimator_params.max.statistics_type == StatisticsType.MAX
-            and range_estimator_params.max.aggregator_type == AggregatorType.MEAN
-        ):
-            collector_name = "mean_min_max"
-
-        else:
-            raise RuntimeError(
-                "The following range estimator parameters are not supported by PyTorch backend by now: "
-                f"{str(range_estimator_params)}"
-            )
-
-        return PTMinMaxAlgoBackend._statistic_collector_builder(
-            collector_name, nncf_graph, target_point, quantizer_config, num_samples
-        )
+            if params.statistics_type not in PT_REDUCERS_MAP:
+                raise RuntimeError(
+                    f"Statistic type: {params.statistics_type} is not supported for Torch PTQ backend yet."
+                )
+
+            if params.aggregator_type not in AGGREGATORS_MAP:
+                raise RuntimeError(
+                    f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet."
+                )
+
+            kwargs = {
+                "reduction_axes": collector_kwargs["reducers_axes"],
+                "keepdims": collector_kwargs["reducers_keepdims"],
+            }
+            if params.statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
+                if container_key == PTMinMaxTensorStatistic.MIN_STAT:
+                    quantile = params.quantile_outlier_prob
+                else:
+                    quantile = 1 - params.quantile_outlier_prob
+                kwargs.update({"quantile": [quantile]})
+            # TODO(dlyakhov): merge two quantile aggregators into one
+
+            statistic_type = params.statistics_type
+            if collector_params.use_abs_max and statistic_type == StatisticsType.MAX:
+                statistic_type = StatisticsType.ABS_MAX
+            reducer = PT_REDUCERS_MAP[statistic_type](**kwargs)
+
+            kwargs = {
+                "aggregation_axes": collector_kwargs["aggregators_axes"],
+                "keepdims": collector_kwargs["aggregators_keepdims"],
+                "num_samples": num_samples,
+                "tensor_processor": PTNNCFCollectorTensorProcessor,
+            }
+            aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs)
+
+            if collector_kwargs["squeeze_dims"] is not None:
+
+                def post_aggregation_hook(aggregated_value):
+                    return PTNNCFCollectorTensorProcessor.squeeze(
+                        PTNNCFTensor(aggregated_value), dim=collector_kwargs["squeeze_dims"]
+                    ).tensor
+
+                aggregator = PostAggregateHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook)
+
+            collector.register_statistic_branch(container_key, reducer, aggregator)
+        return collector

     @staticmethod
     def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]:
@@ -223,37 +256,18 @@ def _get_input_scale_shape(
         return input_shape, scale_shape, channel_idx

     @staticmethod
-    def _default_collector_params_and_scale_shape(
+    def _default_collector_params(
         nncf_graph: NNCFGraph, target_point: PTTargetPoint, quantizer_config: QuantizerConfig
-    ) -> Tuple[PTRangeInitCollectorParams, Tuple[int, ...]]:
-        input_shape, scale_shape, channel_idx = PTMinMaxAlgoBackend._get_input_scale_shape(
-            nncf_graph, target_point, quantizer_config
-        )
-        return (
-            PTRangeInitCollectorParams(
-                is_weights=target_point.is_weight_target_point(),
-                mode=quantizer_config.mode,
-                per_channel=quantizer_config.per_channel,
-                input_shape=input_shape,
-                channel_idx=channel_idx,
-            ),
-            scale_shape,
-        )
-
-    @staticmethod
-    def _statistic_collector_builder(
-        collector_name: str,
-        nncf_graph: NNCFGraph,
-        target_point: PTTargetPoint,
-        quantizer_config: QuantizerConfig,
-        num_samples: int = None,
-    ) -> PTMeanMinMaxStatisticCollector:
-        collector_params, scale_shape = PTMinMaxAlgoBackend._default_collector_params_and_scale_shape(
+    ) -> PTRangeInitCollectorParams:
+        input_shape, _, channel_idx = PTMinMaxAlgoBackend._get_input_scale_shape(
             nncf_graph, target_point, quantizer_config
         )
-        init_config = RangeInitConfig(collector_name, num_samples)
-        return StatCollectorGenerator.generate_stat_collector_for_range_init_config(
-            init_config, scale_shape, collector_params, num_samples
+        return PTRangeInitCollectorParams(
+            is_weights=target_point.is_weight_target_point(),
+            mode=quantizer_config.mode,
+            per_channel=quantizer_config.per_channel,
+            input_shape=input_shape,
+            channel_idx=channel_idx,
         )

     @staticmethod
diff --git a/nncf/tensorflow/quantization/init_range.py b/nncf/tensorflow/quantization/init_range.py
index 2d6bff0a405..2f54ddeb61e 100644
--- a/nncf/tensorflow/quantization/init_range.py
+++ b/nncf/tensorflow/quantization/init_range.py
@@ -156,7 +156,7 @@ def _register_layer_statistics(self, layer: tf.keras.layers.Layer, layer_statist collector = RangeInitializer.generate_stat_collector( reduction_shape, collector_params, init_config, num_batches ) - handles.append(layer.register_hook_pre_quantizer(collector.register_input)) + handles.append(layer.register_hook_pre_quantizer(collector.register_inputs)) layer.enabled = False layer_statistics.append((layer, collector)) @@ -180,7 +180,7 @@ def _register_op_statistics(self, layer: tf.keras.layers.Layer, op_statistics: l collector = RangeInitializer.generate_stat_collector( reduction_shape, collector_params, init_config, num_batches ) - handles.append(op.register_hook_pre_call(collector.register_input)) + handles.append(op.register_hook_pre_call(collector.register_inputs)) op.enabled = False op_statistics.append((layer, op_name, op, collector)) diff --git a/nncf/torch/quantization/init_range.py b/nncf/torch/quantization/init_range.py index 259e7ec1ada..270c406c7cf 100644 --- a/nncf/torch/quantization/init_range.py +++ b/nncf/torch/quantization/init_range.py @@ -37,13 +37,13 @@ from nncf.torch.quantization.layers import SymmetricQuantizer from nncf.torch.quantization.layers import get_scale_shape from nncf.torch.quantization.translator import PTTargetPointTranslator +from nncf.torch.tensor import PTNNCFTensor from nncf.torch.tensor_statistics.algo import TensorStatisticObservationPoint -from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMeanPercentileStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMedianMADStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMixedMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTPercentileStatisticCollector +from nncf.torch.tensor_statistics.collectors import get_mean_percentile_statistic_collector +from nncf.torch.tensor_statistics.collectors import get_median_mad_statistic_collector +from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector +from nncf.torch.tensor_statistics.collectors import get_mixed_min_max_statistic_collector +from nncf.torch.tensor_statistics.collectors import get_precentile_tensor_collector from nncf.torch.tensor_statistics.statistics import pt_convert_stat_to_min_max_tensor_stat @@ -103,7 +103,7 @@ def __init__( self._input_shape = input_shape self._channel_idx = channel_idx - def convert_reduction_shape(self, per_sample_stats) -> ReductionShape: + def convert_reduction_axes(self, per_sample_stats) -> ReductionShape: """ Calculates the reduction shape of the tensor. 
@@ -115,10 +115,30 @@ def convert_reduction_shape(self, per_sample_stats) -> ReductionShape:
         if self._per_channel:
             val = (ndims + self._channel_idx) % ndims
             reduction_shape.remove(val)
+            if not val and self.use_per_sample_stats(per_sample_stats):
+                raise RuntimeError("Per-sample statistics are not supported when the channel axis is the batch axis (0)")
         if self.use_per_sample_stats(per_sample_stats):
             reduction_shape = reduction_shape[1:]  # Assumes batch is the first dimension
         return tuple(reduction_shape)

+    def convert_statistic_params(self, per_sample_stats):
+        reducer_axes = self.convert_reduction_axes(per_sample_stats)
+        reducer_keep_dims = self._per_channel
+        aggregator_axes = [0]
+        aggregator_keep_dims = not self._per_channel
+        squeeze_dims = None
+        if self.use_per_sample_stats(per_sample_stats):
+            aggregator_axes += [1]
+            aggregator_keep_dims = True
+            squeeze_dims = (0,)
+        return {
+            "reducers_axes": reducer_axes,
+            "reducers_keepdims": reducer_keep_dims,
+            "aggregators_axes": tuple(aggregator_axes),
+            "aggregators_keepdims": aggregator_keep_dims,
+            "squeeze_dims": squeeze_dims,
+        }
+

 class StatCollectorGenerator:
     @staticmethod
@@ -154,8 +174,8 @@ def generate_collectors_for_range_init_statistics_collection(

     @staticmethod
     def generate_stat_collector_for_range_init_config(
         init_config: RangeInitConfig,
-        reduction_shape: ReductionShape = None,
-        collector_params=None,
+        scale_shape: ReductionShape = None,
+        collector_params: PTRangeInitCollectorParams = None,
         num_samples_to_collect_override: int = None,
     ) -> TensorStatisticCollectorBase:
         num_samples = init_config.num_init_samples
@@ -163,41 +183,54 @@ def generate_stat_collector_for_range_init_config(
             num_samples = num_samples_to_collect_override
         if init_config.init_type not in RANGE_INIT_TYPES_VS_DESCRIPTIONS:
             raise RuntimeError("Unknown range init type: {}".format(init_config.init_type))
+
+        use_per_sample_stats = collector_params.use_per_sample_stats(init_config.init_type == "mixed_min_max")
+        collector_kwargs = collector_params.convert_statistic_params(use_per_sample_stats)
+
         if init_config.init_type == "min_max":
-            reduction_shape_converted = collector_params.convert_reduction_shape(per_sample_stats=False)
-            return PTMinMaxStatisticCollector(
-                collector_params.use_abs_max, reduction_shape_converted, reduction_shape, num_samples
+            return get_min_max_statistic_collector(
+                use_abs_max=collector_params.use_abs_max,
+                num_samples=num_samples,
+                **collector_kwargs,
             )
         if init_config.init_type == "mixed_min_max":
-            reduction_shape_converted = collector_params.convert_reduction_shape(per_sample_stats=True)
-            return PTMixedMinMaxStatisticCollector(
-                collector_params.use_per_sample_stats(per_sample_stats=True),
-                collector_params.use_abs_max,
-                collector_params.use_means_of_mins,
-                collector_params.use_means_of_maxs,
-                reduction_shape_converted,
-                reduction_shape,
-                num_samples,
+            return get_mixed_min_max_statistic_collector(
+                use_abs_max=collector_params.use_abs_max,
+                use_means_of_mins=collector_params.use_means_of_mins,
+                use_means_of_maxs=collector_params.use_means_of_maxs,
+                num_samples=num_samples,
+                **collector_kwargs,
             )
         if init_config.init_type == "mean_min_max":
-            reduction_shape_converted = collector_params.convert_reduction_shape(per_sample_stats=False)
-            return PTMeanMinMaxStatisticCollector(
-                collector_params.use_per_sample_stats(per_sample_stats=False),
-                collector_params.use_abs_max,
-                reduction_shape_converted,
-                reduction_shape,
-                num_samples,
+            return get_mixed_min_max_statistic_collector(
+                use_abs_max=collector_params.use_abs_max,
+                use_means_of_mins=True,
+                use_means_of_maxs=True,
+                num_samples=num_samples,
+                **collector_kwargs,
             )
         if init_config.init_type == "threesigma":
-            return PTMedianMADStatisticCollector(reduction_shape, num_samples)
+            return get_median_mad_statistic_collector(
+                num_samples=num_samples,
+                **collector_kwargs,
+            )
         if init_config.init_type == "percentile":
             min_percentile = init_config.init_type_specific_params.get("min_percentile", 0.1)
             max_percentile = init_config.init_type_specific_params.get("max_percentile", 99.9)
-            return PTPercentileStatisticCollector([min_percentile, max_percentile], reduction_shape, num_samples)
+            return get_precentile_tensor_collector(
+                percentiles_to_collect=[min_percentile, max_percentile],
+                num_samples=num_samples,
+                **collector_kwargs,
+            )
+
         if init_config.init_type == "mean_percentile":
             min_percentile = init_config.init_type_specific_params.get("min_percentile", 0.1)
             max_percentile = init_config.init_type_specific_params.get("max_percentile", 99.9)
-            return PTMeanPercentileStatisticCollector([min_percentile, max_percentile], reduction_shape, num_samples)
+            return get_mean_percentile_statistic_collector(
+                percentiles_to_collect=[min_percentile, max_percentile],
+                num_samples=num_samples,
+                **collector_kwargs,
+            )
         raise ValueError("Range init type not handled!")

     @classmethod
@@ -251,7 +284,8 @@ def __init__(

     def _get_fwd_hook(self, collector: TensorStatisticCollectorBase) -> Callable:
         def fwd_hook(module, input_, output):
-            collector.register_input(input_[0])
+            collector.register_unnamed_inputs(PTNNCFTensor(input_[0]))
+            return input_[0]

         return fwd_hook

diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py
index 4a588ddaa8d..cde70c59af3 100644
--- a/nncf/torch/quantization/layers.py
+++ b/nncf/torch/quantization/layers.py
@@ -997,6 +997,8 @@ def get_quantizer_config(self) -> QuantizerConfig:


 def get_per_channel_scale_shape(input_shape, is_weights, channel_idx: int = None):
+    # TODO: case channel_idx=0, is_weights=True and per_sample_stats=True
+    # leads to dimension error in statistic calculation
     scale_shape = [1 for _ in input_shape]
     if channel_idx is None:
         if is_weights:
diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py
index 6c2c48256c6..84cb2b63a73 100644
--- a/nncf/torch/statistics/aggregator.py
+++ b/nncf/torch/statistics/aggregator.py
@@ -41,6 +41,14 @@ def _get_transformation_layout_extra_outputs(
     ) -> TransformationLayout:
         transformation_layout = TransformationLayout()
         transformation_commands = []
+
+        def register_inputs_fn(fn):
+            def register_inputs(input_: torch.Tensor):
+                fn(PTNNCFTensor(input_))
+                return input_
+
+            return register_inputs
+
         for _statistic_points in statistic_points.values():
             for _statistic_point in _statistic_points:
                 for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
@@ -48,7 +56,7 @@ def _get_transformation_layout_extra_outputs(
                     transformation_commands.append(
                         PTInsertionCommand(
                             _statistic_point.target_point,
-                            collector.register_input,
+                            register_inputs_fn(collector.register_unnamed_inputs),
                             TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
                         )
                     )
diff --git a/nncf/torch/tensor_statistics/algo.py b/nncf/torch/tensor_statistics/algo.py
index 0f9c56f14e5..3b63056f3b0 100644
--- a/nncf/torch/tensor_statistics/algo.py
+++ b/nncf/torch/tensor_statistics/algo.py
@@ -8,6 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Dict, Set

 from nncf.api.compression import CompressionStage
@@ -24,6 +25,7 @@
 from nncf.torch.graph.transformations.commands import TransformationPriority
 from nncf.torch.graph.transformations.layout import PTTransformationLayout
 from nncf.torch.nncf_network import NNCFNetwork
+from nncf.torch.tensor import PTNNCFTensor


 class TensorStatisticObservationPoint:
@@ -54,9 +56,15 @@ def _get_transformation_layout(self, target_model: NNCFNetwork) -> PTTransformat
         layout = PTTransformationLayout()
         for op, rs_vs_collector in self._observation_points_vs_collectors.items():
             for collector in rs_vs_collector.values():
-                hook_obj = collector.register_input
+
+                def hook_obj(x, collector):
+                    collector.register_unnamed_inputs(PTNNCFTensor(x))
+                    return x
+
                 command = PTInsertionCommand(
-                    op.target_point, hook_obj, TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION
+                    op.target_point,
+                    partial(hook_obj, collector=collector),
+                    TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
                 )
                 layout.register(command)
         return layout
diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py
index 186bbb18393..a60a6516ab4 100644
--- a/nncf/torch/tensor_statistics/collectors.py
+++ b/nncf/torch/tensor_statistics/collectors.py
@@ -9,12 +9,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Callable, Deque, List, Optional, Union
+from functools import partial
+from typing import Any, Callable, Deque, List, Optional, Tuple, Union

+import numpy as np
 import torch
+import torch.masked as mt

 from nncf.common.tensor import NNCFTensor
 from nncf.common.tensor import TensorElementsType
+from nncf.common.tensor_statistics.collectors import MaskedReduceFN
 from nncf.common.tensor_statistics.collectors import MeanMinMaxStatisticCollector
 from nncf.common.tensor_statistics.collectors import MeanPercentileStatisticCollector
 from nncf.common.tensor_statistics.collectors import MeanStatisticCollector
@@ -22,9 +26,28 @@
 from nncf.common.tensor_statistics.collectors import MinMaxStatisticCollector
 from nncf.common.tensor_statistics.collectors import MixedMinMaxStatisticCollector
 from nncf.common.tensor_statistics.collectors import NNCFCollectorTensorProcessor
+from nncf.common.tensor_statistics.collectors import NNCFTensor
 from nncf.common.tensor_statistics.collectors import PercentileStatisticCollector
 from nncf.common.tensor_statistics.collectors import ReductionShape
 from nncf.common.tensor_statistics.reduction import np_percentile_reduce_like
+from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
+from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer
+from nncf.experimental.common.tensor_statistics.collectors import BatchMeanReducer
+from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
+from nncf.experimental.common.tensor_statistics.collectors import MaxReducer
+from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
+from nncf.experimental.common.tensor_statistics.collectors import MeanPerChReducer
+from nncf.experimental.common.tensor_statistics.collectors import MeanReducer
+from nncf.experimental.common.tensor_statistics.collectors import MedianAbsoluteDeviationAggregator
+from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
+from nncf.experimental.common.tensor_statistics.collectors
import MinReducer
+from nncf.experimental.common.tensor_statistics.collectors import NoopReducer
+from nncf.experimental.common.tensor_statistics.collectors import PostAggregateHook
+from nncf.experimental.common.tensor_statistics.collectors import PrecentileAggregator
+from nncf.experimental.common.tensor_statistics.collectors import QuantileReducer
+from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
+from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
+from nncf.quantization.advanced_parameters import StatisticsType
 from nncf.torch.dynamic_graph.context import no_nncf_trace
 from nncf.torch.tensor import PTNNCFTensor
 from nncf.torch.tensor_statistics.reduction import expand_like
@@ -51,13 +74,15 @@ def reduce_max(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = Fa
     def abs(x: NNCFTensor) -> NNCFTensor:
         return PTNNCFTensor(torch.abs(x.tensor))

-    @staticmethod
-    def min(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:
-        return PTNNCFTensor(torch.min(x1.tensor, x2.tensor))
+    @classmethod
+    def min(cls, *args) -> NNCFTensor:
+        stacked = cls.stack(args)
+        return cls.reduce_min(stacked, axis=0, keepdims=False)

-    @staticmethod
-    def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:
-        return PTNNCFTensor(torch.max(x1.tensor, x2.tensor))
+    @classmethod
+    def max(cls, *args) -> NNCFTensor:
+        stacked = cls.stack(args)
+        return cls.reduce_max(stacked, axis=0, keepdims=False)

     @staticmethod
     def mean(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor:
@@ -65,15 +90,33 @@ def mean(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTe
     @staticmethod
     def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCFTensor:
-        return PTNNCFTensor(x.tensor.median(dim=axis, keepdim=keepdims))
+        # See https://github.com/pytorch/pytorch/issues/61582
+        if not isinstance(axis, int):
+            return PTNNCFTensor(torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims)))
+        return PTNNCFTensor(x.tensor.median(dim=axis, keepdim=keepdims).values)

-    @staticmethod
-    def masked_mean(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
-        raise NotImplementedError()
+    @classmethod
+    def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
+        if mask is None:
+            return cls.mean(x, axis=axis, keepdims=keepdims)
+        masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy())
+        result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype)
+        if result.size <= 1:
+            return PTNNCFTensor(torch.tensor(result))
+        return PTNNCFTensor(torch.tensor(result.data))

-    @staticmethod
-    def masked_median(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
-        raise NotImplementedError()
+    @classmethod
+    def masked_median(
+        cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False
+    ) -> NNCFTensor:
+        # Implemented in numpy, as torch.masked.median is not implemented yet
+        if mask is None:
+            return cls.median(x, axis=axis, keepdims=keepdims)
+        masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy())
+        result = np.ma.median(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype)
+        if result.size <= 1:
+            return PTNNCFTensor(torch.tensor(result))
+        return PTNNCFTensor(torch.tensor(result.data))

     @staticmethod
     def mean_per_channel(x: NNCFTensor, axis: int) ->
NNCFTensor: @@ -100,6 +143,10 @@ def unstack(x: NNCFTensor, axis: int = 0) -> List[NNCFTensor]: tensor_list = torch.unbind(tensor, dim=axis) return [PTNNCFTensor(t) for t in tensor_list] + @staticmethod + def squeeze(x: NNCFTensor, dim: Optional[int] = None) -> NNCFTensor: + return PTNNCFTensor(torch.squeeze(x.tensor, dim=dim)) + @staticmethod def sum(tensor: NNCFTensor) -> TensorElementsType: return torch.sum(tensor.tensor).item() @@ -108,153 +155,296 @@ def sum(tensor: NNCFTensor) -> TensorElementsType: def quantile( tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False ) -> List[NNCFTensor]: - raise NotImplementedError() + # See https://github.com/pytorch/pytorch/issues/61582 + if not isinstance(axis, int): + result = torch.tensor( + np.quantile(tensor.tensor.detach().cpu().numpy(), q=quantile, axis=axis, keepdims=keepdims) + ) + else: + result = torch.quantile(tensor.tensor, torch.tensor(quantile).type(tensor.tensor.dtype), axis, keepdims) + result = result.type(tensor.tensor.dtype) + return [PTNNCFTensor(x) for x in result] + + @classmethod + def precentile( + cls, + tensor: NNCFTensor, + precentile: Union[float, List[float]], + axis: Union[int, tuple, list], + keepdims: bool = False, + ) -> List[TensorElementsType]: + quantile = np.true_divide(precentile, 100) + return cls.quantile(tensor, quantile=quantile, axis=axis, keepdims=keepdims) @classmethod def no_outliers_map( - cls, x: NNCFTensor, fn: Callable[[NNCFTensor, Optional[int]], Any], axis: int = 0, alpha: float = 0.01 + cls, + x: NNCFTensor, + fn: Callable[[NNCFTensor, int, NNCFTensor], Any], + axis: Union[int, Tuple[int, ...]] = 0, + alpha: float = 0.01, + keepdims: bool = False, ): - raise NotImplementedError() + if isinstance(axis, int): + axis = (axis,) + if len(x.shape) == len(axis): + return fn(x, axis=axis, mask=None, keepdims=keepdims) -class PTMinMaxStatisticCollector(MinMaxStatisticCollector): - def __init__( - self, use_abs_max: bool, reduction_shape: ReductionShape, output_shape: ReductionShape, num_samples: int = None - ): - super().__init__(use_abs_max, reduction_shape, num_samples) - self._output_shape = output_shape + low_values, high_values = cls.quantile(x, [alpha, 1 - alpha], axis=axis) + outliers_mask = torch.logical_or(x.tensor < low_values.tensor, high_values.tensor < x.tensor) + return fn(x, axis=axis, mask=PTNNCFTensor(outliers_mask), keepdims=keepdims) - @staticmethod - def _get_processor() -> NNCFCollectorTensorProcessor: - return PTNNCFCollectorTensorProcessor() - - def _register_input(self, x: torch.Tensor): - with no_nncf_trace(): - self._register_input_common(PTNNCFTensor(x)) - - def _get_statistics(self) -> PTMinMaxTensorStatistic: - min_values = self._min_values.tensor.view(self._output_shape) - max_values = self._max_values.tensor.view(self._output_shape) - return PTMinMaxTensorStatistic(min_values, max_values) - - -class PTMixedMinMaxStatisticCollector(MixedMinMaxStatisticCollector): - def __init__( - self, - use_per_sample_stats: bool, - use_abs_max: bool, - use_means_of_mins: bool, - use_means_of_maxs: bool, - reduction_shape: ReductionShape, - output_shape: ReductionShape, - num_samples: int = None, - window_size: int = None, - ): - super().__init__( - use_per_sample_stats, - use_abs_max, - use_means_of_mins, - use_means_of_maxs, - reduction_shape, - num_samples, - window_size, - ) - self._output_shape = output_shape + @classmethod + def masked_map(cls, x: NNCFTensor, fn: MaskedReduceFN, filter_fn) -> NNCFTensor: + return fn(x, 
mask=filter_fn(x)) - @staticmethod - def _get_processor() -> NNCFCollectorTensorProcessor: - return PTNNCFCollectorTensorProcessor() - - def _register_input(self, x: torch.Tensor): - with no_nncf_trace(): - self._register_input_common(PTNNCFTensor(x)) - - def _get_statistics(self) -> PTMinMaxTensorStatistic: - min_values = self._min_aggregate().tensor.view(self._output_shape) - max_values = self._max_aggregate().tensor.view(self._output_shape) - return PTMinMaxTensorStatistic(min_values, max_values) - - -class PTMeanMinMaxStatisticCollector(MeanMinMaxStatisticCollector): - def __init__( - self, - use_per_sample_stats: bool, - use_abs_max: bool, - reduction_shape: ReductionShape, - output_shape: ReductionShape, - num_samples: int = None, - window_size: int = None, - ): - super().__init__(use_per_sample_stats, use_abs_max, reduction_shape, num_samples, window_size) - self._output_shape = output_shape + @classmethod + def sub(cls, a: NNCFTensor, b: NNCFTensor) -> NNCFTensor: + return NNCFTensor(a.tensor - b.tensor) - @staticmethod - def _get_processor() -> NNCFCollectorTensorProcessor: - return PTNNCFCollectorTensorProcessor() + @classmethod + def non_zero_elements(cls, x: NNCFTensor) -> NNCFTensor: + pt_tensor = x.tensor + eps = torch.finfo(pt_tensor.dtype).eps + return NNCFTensor(pt_tensor.abs() > eps) - def _register_input(self, x: torch.Tensor): - with no_nncf_trace(): - self._register_input_common(PTNNCFTensor(x)) - def _get_statistics(self) -> PTMinMaxTensorStatistic: - min_values = self._min_aggregate().tensor.view(self._output_shape) - max_values = self._max_aggregate().tensor.view(self._output_shape) - return PTMinMaxTensorStatistic(min_values, max_values) +class PTReducerMixIn: + def _get_processor(self): + return PTNNCFCollectorTensorProcessor + def get_inplace_fn(self): + return None -class PTMedianMADStatisticCollector(MedianMADStatisticCollector): - def _register_input(self, x: torch.Tensor): - with no_nncf_trace(): - self._samples.append(x.detach().cpu().numpy()) + def get_output_names(self, target_node_name: str, port_id: int) -> List[str]: + return [] - def _get_statistics(self) -> PTMedianMADTensorStatistic: - numpy_median, numpy_mad = self._prepare_statistics() - median_tensor = torch.from_numpy(numpy_median).to(dtype=torch.float) - mad_tensor = torch.from_numpy(numpy_mad).to(dtype=torch.float) - median_tensor = expand_like(median_tensor, list(self._reduction_shape)) - mad_tensor = expand_like(mad_tensor, list(self._reduction_shape)) +class PTNoopReducer(PTReducerMixIn, NoopReducer): + pass - return PTMedianMADTensorStatistic(median_tensor, mad_tensor) +class PTMinReducer(PTReducerMixIn, MinReducer): + pass -class PTPercentileStatisticCollector(PercentileStatisticCollector): - def _register_input(self, x: torch.Tensor): - with no_nncf_trace(): - self._samples.append(x.detach().cpu().numpy()) - def _get_statistics(self) -> PTPercentileTensorStatistic: - percentile_vs_values_dict = self._prepare_statistics() - for key, val in percentile_vs_values_dict.items(): - torch_percentiles = torch.from_numpy(val).to(dtype=torch.float) - percentile_vs_values_dict[key] = expand_like(torch_percentiles, list(self._reduction_shape)) - return PTPercentileTensorStatistic(percentile_vs_values_dict) +class PTMaxReducer(PTReducerMixIn, MaxReducer): + pass -class PTMeanPercentileStatisticCollector(MeanPercentileStatisticCollector): - def _register_input(self, x: torch.Tensor): - with no_nncf_trace(): - for pct, val in self._all_pct_values.items(): - np_vals = 
np_percentile_reduce_like(x.cpu().numpy(), self._reduction_shape, pct) - torch_vals = torch.from_numpy(np_vals).to(dtype=torch.float) - val.append(torch_vals) +class PTAbsMaxReducer(PTReducerMixIn, AbsMaxReducer): + pass - def _get_statistics(self) -> PTPercentileTensorStatistic: - mean_percentile_values = {} - for pct, val in self._all_pct_values.items(): - stacked_pct_vals = torch.stack(list(val)) - mean_percentile_values[pct] = stacked_pct_vals.mean(dim=0).view(self._reduction_shape) - return PTPercentileTensorStatistic(mean_percentile_values) +class PTMeanReducer(PTReducerMixIn, MeanReducer): + pass + + +class PTQuantileReducer(PTReducerMixIn, QuantileReducer): + pass + + +class PTAbsQuantileReducer(PTReducerMixIn, AbsQuantileReducer): + pass + + +class PTBatchMeanReducer(PTReducerMixIn, BatchMeanReducer): + pass + + +class PTMeanPerChanelReducer(PTReducerMixIn, MeanPerChReducer): + pass -class PTMeanStatisticCollector(MeanStatisticCollector): - @staticmethod - def _get_processor() -> NNCFCollectorTensorProcessor: - return PTNNCFCollectorTensorProcessor() - def _register_input(self, x: torch.Tensor): - with no_nncf_trace(): - self._register_input_common(PTNNCFTensor(x)) +def maybe_add_squeeze(aggregator, squeeze_dims): + if not squeeze_dims: + return aggregator - def _get_statistics(self) -> PTMeanTensorStatistic: - return PTMeanTensorStatistic(self._mean_aggregate().tensor, self._shape()) + def post_aggregation_hook(aggregated_value): + return PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(aggregated_value), dim=squeeze_dims).tensor + + return PostAggregateHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook) + + +def get_min_max_statistic_collector( + use_abs_max, + reducers_axes, + reducers_keepdims: bool, + aggregators_axes, + aggregators_keepdims, + num_samples: int, + squeeze_dims, +): + tensor_collector = TensorCollector(PTMinMaxTensorStatistic) + + aggregator_kwargs = { + "tensor_processor": PTNNCFCollectorTensorProcessor, + "num_samples": num_samples, + "aggregation_axes": aggregators_axes, + "keepdims": aggregators_keepdims, + } + min_reducer = PTMinReducer(reducers_axes, keepdims=reducers_keepdims) + min_aggregator = MinAggregator(**aggregator_kwargs) + min_aggregator = maybe_add_squeeze(min_aggregator, squeeze_dims) + tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator) + + max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer + max_reducer = max_reducer_cls(reducers_axes, keepdims=reducers_keepdims) + max_aggregator = MaxAggregator(**aggregator_kwargs) + max_aggregator = maybe_add_squeeze(max_aggregator, squeeze_dims) + tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator) + return tensor_collector + + +def get_mixed_min_max_statistic_collector( + reducers_axes, + reducers_keepdims: bool, + aggregators_axes, + aggregators_keepdims, + use_abs_max: bool, + use_means_of_mins: bool, + use_means_of_maxs: bool, + squeeze_dims, + num_samples: int = None, + window_size: int = None, +): + tensor_collector = TensorCollector(PTMinMaxTensorStatistic) + min_reducer = PTMinReducer(reducers_axes, keepdims=reducers_keepdims) + + kwargs = { + "tensor_processor": PTNNCFCollectorTensorProcessor, + "num_samples": num_samples, + "aggregation_axes": aggregators_axes, + "keepdims": aggregators_keepdims, + "window_size": window_size, + } + min_aggregator_cls = MeanAggregator if use_means_of_mins else MinAggregator + min_aggregator = 
min_aggregator_cls(**kwargs)
+    min_aggregator = maybe_add_squeeze(min_aggregator, squeeze_dims)
+    tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)
+
+    max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
+    max_reducer = max_reducer_cls(reducers_axes, keepdims=reducers_keepdims)
+    max_aggregator_cls = MeanAggregator if use_means_of_maxs else MaxAggregator
+    max_aggregator = max_aggregator_cls(**kwargs)
+    max_aggregator = maybe_add_squeeze(max_aggregator, squeeze_dims)
+    tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator)
+
+    return tensor_collector
+
+
+def get_median_mad_statistic_collector(
+    reducers_axes,
+    reducers_keepdims: bool,
+    aggregators_axes,
+    aggregators_keepdims,
+    num_samples: int,
+    squeeze_dims,
+    window_size: int = None,
+):
+    return _get_collection_without_reduction(
+        MedianAbsoluteDeviationAggregator,
+        PTMedianMADTensorStatistic,
+        reducers_axes=reducers_axes,
+        reducers_keepdims=reducers_keepdims,
+        aggregators_axes=aggregators_axes,
+        aggregators_keepdims=aggregators_keepdims,
+        num_samples=num_samples,
+        squeeze_dims=squeeze_dims,
+        window_size=window_size,
+    )
+
+
+def get_precentile_tensor_collector(
+    percentiles_to_collect,
+    reducers_axes,
+    reducers_keepdims: bool,
+    aggregators_axes,
+    aggregators_keepdims,
+    num_samples: int,
+    squeeze_dims,
+    window_size: int = None,
+):
+    return _get_collection_without_reduction(
+        partial(PrecentileAggregator, percentiles_to_collect=percentiles_to_collect),
+        PTPercentileTensorStatistic,
+        reducers_axes=reducers_axes,
+        reducers_keepdims=reducers_keepdims,
+        aggregators_axes=aggregators_axes,
+        aggregators_keepdims=aggregators_keepdims,
+        num_samples=num_samples,
+        squeeze_dims=squeeze_dims,
+        window_size=window_size,
+    )
+
+
+def _get_collection_without_reduction(
+    aggregator_cls,
+    statistic_class,
+    reducers_axes,
+    reducers_keepdims: bool,
+    aggregators_axes,
+    aggregators_keepdims,
+    num_samples: int,
+    squeeze_dims,
+    window_size: int = None,
+):
+    tensor_collector = TensorCollector(statistic_class)
+    reducer = PTNoopReducer()
+    aggregation_axes = list(set(list(aggregators_axes) + [dim + 1 for dim in reducers_axes]))
+    aggregator = aggregator_cls(
+        PTNNCFCollectorTensorProcessor,
+        aggregation_axes=aggregation_axes,
+        window_size=window_size,
+        num_samples=num_samples,
+        keepdims=True,
+    )
+    dims_to_squeeze = [0] if squeeze_dims else []
+    dims_to_squeeze += [axis + 1 for axis in reducers_axes] if not reducers_keepdims else []
+    dims_to_squeeze += aggregators_axes if not aggregators_keepdims else []
+    if dims_to_squeeze:
+
+        def post_aggregation_hook(aggregated_value):
+            retval = {}
+            for key, value in aggregated_value.items():
+                retval[key] = PTNNCFCollectorTensorProcessor.squeeze(PTNNCFTensor(value), dim=dims_to_squeeze).tensor
+            return retval
+
+        aggregator = PostAggregateHook(aggregator=aggregator, post_aggregation_hook=post_aggregation_hook)
+
+    tensor_collector.register_statistic_branch(
+        PTMedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY, reducer, aggregator
+    )
+    return tensor_collector
+
+
+def get_mean_stat_collector(num_samples, channel_axis, window_size=None):
+    if channel_axis == 0:
+        reducer = PTBatchMeanReducer()
+    else:
+        reducer = PTMeanPerChanelReducer(channel_axis)
+    noop_reducer = PTNoopReducer()
+
+    kwargs = {
+        "tensor_processor": PTNNCFCollectorTensorProcessor,
+        "num_samples": num_samples,
+        "window_size": window_size,
+    }
+    aggregate_mean = MeanAggregator(**kwargs)
+    aggregate_shape = ShapeAggregator()
+
+    collector = TensorCollector(PTMeanTensorStatistic)
+    collector.register_statistic_branch(PTMeanTensorStatistic.MEAN_STAT, reducer, aggregate_mean)
+    collector.register_statistic_branch(PTMeanTensorStatistic.SHAPE_STAT, noop_reducer, aggregate_shape)
+    return collector
+
+
+PT_REDUCERS_MAP = {
+    StatisticsType.MIN: PTMinReducer,
+    StatisticsType.MAX: PTMaxReducer,
+    StatisticsType.ABS_MAX: PTAbsMaxReducer,
+    StatisticsType.MEAN: PTMeanReducer,
+    StatisticsType.QUANTILE: PTQuantileReducer,
+    StatisticsType.ABS_QUANTILE: PTAbsQuantileReducer,
+}
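A hedged usage sketch of the getters above (the axis values and shapes are illustrative only; it assumes inputs are fed through register_unnamed_inputs, as the observation-point hook in this patch does):

    import torch

    from nncf.torch.tensor import PTNNCFTensor
    from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector

    # Per-tensor min/max: reduce both data axes of each sample, then fold the
    # sample (stacking) axis 0 inside the aggregators.
    collector = get_min_max_statistic_collector(
        use_abs_max=False,
        reducers_axes=(0, 1),
        reducers_keepdims=False,
        aggregators_axes=(0,),
        aggregators_keepdims=False,
        num_samples=None,
        squeeze_dims=None,
    )
    for step in range(3):
        collector.register_unnamed_inputs(PTNNCFTensor(torch.randn(4, 8) * (step + 1)))
    statistic = collector.get_statistics()  # a PTMinMaxTensorStatistic
    print(statistic.min_values, statistic.max_values)

The masked reductions used by the no-outliers aggregators fall back to numpy.ma, where mask=True marks an element as excluded; for example:

    import numpy as np

    x = np.array([[1.0, 2.0, 100.0], [3.0, 4.0, 200.0]])
    outliers = np.array([[False, False, True], [False, False, True]])  # True = drop
    print(np.ma.median(np.ma.array(x, mask=outliers), axis=1))  # [1.5 3.5]

diff --git a/nncf/torch/tensor_statistics/statistics.py b/nncf/torch/tensor_statistics/statistics.py
index 7a251b19207..d57f135c778 100644
--- a/nncf/torch/tensor_statistics/statistics.py
+++ b/nncf/torch/tensor_statistics/statistics.py
@@ -9,6 +9,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Dict, Optional
+
 import torch
 
 from nncf.common.tensor_statistics.statistics import MeanTensorStatistic
@@ -19,24 +21,45 @@
 
 
 class PTMinMaxTensorStatistic(MinMaxTensorStatistic):
+    def __init__(self, tensor_collector_output):
+        super().__init__(tensor_collector_output[self.MIN_STAT], tensor_collector_output[self.MAX_STAT])
+
     @staticmethod
     def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool:
         return bool(torch.allclose(tensor1, tensor2, rtol=rtol))
 
 
 class PTMedianMADTensorStatistic(MedianMADTensorStatistic):
+    def __init__(self, tensor_collector_output):
+        super().__init__(
+            tensor_collector_output[self.TENSOR_STATISTIC_OUTPUT_KEY][self.MEDIAN_VALUES_STAT],
+            tensor_collector_output[self.TENSOR_STATISTIC_OUTPUT_KEY][self.MAD_VALUES_STAT],
+        )
+
     @staticmethod
     def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool:
         return bool(torch.allclose(tensor1, tensor2, rtol=rtol))
 
 
 class PTPercentileTensorStatistic(PercentileTensorStatistic):
+    def __init__(self, tensor_collector_output):
+        if self.TENSOR_STATISTIC_OUTPUT_KEY in tensor_collector_output:
+            super().__init__(tensor_collector_output[self.TENSOR_STATISTIC_OUTPUT_KEY])
+        else:
+            percentile_vs_values_dict = {}
+            for (_, percentile), value in tensor_collector_output.items():
+                percentile_vs_values_dict[percentile] = value
+            super().__init__(percentile_vs_values_dict)
+
     @staticmethod
     def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool:
         return bool(torch.allclose(tensor1, tensor2, rtol=rtol))
 
 
 class PTMeanTensorStatistic(MeanTensorStatistic):
+    def __init__(self, tensor_collector_output):
+        super().__init__(tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT])
+
     @staticmethod
     def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool:
         return bool(torch.allclose(tensor1, tensor2, rtol=rtol))
@@ -49,8 +72,10 @@ def pt_convert_stat_to_min_max_tensor_stat(statistic: TensorStatistic) -> PTMinMaxTensorStatistic:
     # Using three-sigma approach to estimate min and max
     # Constant factor depends on the distribution form - assuming normal and the factor is 1.4826
     return PTMinMaxTensorStatistic(
-        statistic.median_values - 3 * 1.4826230 * statistic.mad_values,
-        statistic.median_values + 3 * 1.4826230 * statistic.mad_values,
+        {
+            PTMinMaxTensorStatistic.MIN_STAT: statistic.median_values - 3 * 1.4826230 * statistic.mad_values,
+            PTMinMaxTensorStatistic.MAX_STAT: statistic.median_values + 3 * 1.4826230 * statistic.mad_values,
+        }
     )
     if isinstance(statistic, PTPercentileTensorStatistic):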
         if len(statistic.percentile_vs_values_dict.keys()) < 2:
@@ -58,6 +83,9 @@ def pt_convert_stat_to_min_max_tensor_stat(statistic: TensorStatistic) -> PTMinMaxTensorStatistic:
         min_pct = min(statistic.percentile_vs_values_dict.keys())
         max_pct = max(statistic.percentile_vs_values_dict.keys())
         return PTMinMaxTensorStatistic(
-            statistic.percentile_vs_values_dict[min_pct], statistic.percentile_vs_values_dict[max_pct]
+            {
+                PTMinMaxTensorStatistic.MIN_STAT: statistic.percentile_vs_values_dict[min_pct],
+                PTMinMaxTensorStatistic.MAX_STAT: statistic.percentile_vs_values_dict[max_pct],
+            }
         )
     raise ValueError("Unknown TensorStatistic to generate min-max stat from!")
diff --git a/tests/common/test_statistics_aggregator.py b/tests/common/test_statistics_aggregator.py
index cd0545a4580..da100dc5cfb 100644
--- a/tests/common/test_statistics_aggregator.py
+++ b/tests/common/test_statistics_aggregator.py
@@ -811,10 +811,10 @@ def test_same_collectors_different_attrs_dont_merge(self, statistics_type, test_
         model = params["model"](dataset_samples)
         params = {}
         if statistics_type in [StatisticsType.MIN, StatisticsType.MAX, StatisticsType.ABS_MAX, StatisticsType.MEAN]:
-            params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)]
+            params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
             params["inplace"] = [False, True]
         elif statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
-            params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)]
+            params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
             params["quantile"] = [[0.01, 0.99], [0.001, 0.999]]
         elif statistics_type == "batch_mean":
             pytest.skip("Inplace statistic woun't work until openvino==2023.0.0 release")
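Before the updated template below, a brief note on the aggregator semantics it exercises: reduced samples are stacked along a fresh axis 0, and aggregation_axes selects which stacked dimensions to fold, with keepdims preserving folded axes as size 1. A minimal numpy sketch of that convention (illustrative, not part of the tests):

    import numpy as np

    # Ten "reduced" samples are stacked along a new axis 0 by the aggregator.
    samples = [np.arange(9).reshape((1, 3, 3)) * i for i in range(-5, 5)]
    stacked = np.stack(samples)          # shape (10, 1, 3, 3)
    per_tensor = stacked.min(axis=(0,))  # shape (1, 3, 3): one stat over all samples
    folded = stacked.min(axis=(0, 2), keepdims=True)  # shape (1, 1, 1, 3)
    print(per_tensor.shape, folded.shape)

diff --git a/tests/experimental/common/test_reducers_and_aggregators.py b/tests/experimental/common/test_reducers_and_aggregators.py
index cc54fe987ac..8567cfb7a03 100644
--- a/tests/experimental/common/test_reducers_and_aggregators.py
+++ b/tests/experimental/common/test_reducers_and_aggregators.py
@@ -11,10 +11,12 @@
 
 from abc import abstractmethod
 from itertools import product
+from typing import Any, List, Optional, Tuple
 
 import numpy as np
 import pytest
 
+from nncf.common.graph.layer_attributes import Dtype
 from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MeanNoOutliersAggregator
@@ -43,17 +45,123 @@
 default_test_quantile = 0.1
 
 
-def default_test_mean_no_outlier(tp, ps):
-    return MeanNoOutliersAggregator(tp, ps, quantile=default_test_quantile)
+OFFLINE_AGGREGATORS_TEST_CASES = [
+    (
+        None,
+        False,
+        [[[-50000, -4, 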
-8], [-12, -16, -20], [-24, -28, -32]]], + [[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]], + ), + ( + (0, 2), + False, + [[-50000, -28, -32]], + [[50000, 28, 32]], + ), + ( + (2,), + False, + [ + [[-50000, 5, 10]], + [[-40000, 4, 8]], + [[-30000, 3, 6]], + [[-20000, 2, 4]], + [[-10000, 1, 2]], + [[0, 0, 0]], + [[-6, -7, -8]], + [[-12, -14, -16]], + [[-18, -21, -24]], + [[-24, -28, -32]], + ], + [ + [[50000, -5, -10]], + [[40000, -4, -8]], + [[30000, -3, -6]], + [[20000, -2, -4]], + [[10000, -1, -2]], + [[0, 0, 0]], + [[6, 7, 8]], + [[12, 14, 16]], + [[18, 21, 24]], + [[24, 28, 32]], + ], + ), + ( + None, + True, + [[[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]], + [[[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]], + ), + ( + (0,), + True, + [[[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]], + [[[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]], + ), + ( + (0, 2), + True, + [[[[-50000, -28, -32]]]], + [[[[50000, 28, 32]]]], + ), + ( + (2,), + True, + [ + [[[-50000, 5, 10]]], + [[[-40000, 4, 8]]], + [[[-30000, 3, 6]]], + [[[-20000, 2, 4]]], + [[[-10000, 1, 2]]], + [[[0, 0, 0]]], + [[[-6, -7, -8]]], + [[[-12, -14, -16]]], + [[[-18, -21, -24]]], + [[[-24, -28, -32]]], + ], + [ + [[[50000, -5, -10]]], + [[[40000, -4, -8]]], + [[[30000, -3, -6]]], + [[[20000, -2, -4]]], + [[[10000, -1, -2]]], + [[[0, 0, 0]]], + [[[6, 7, 8]]], + [[[12, 14, 16]]], + [[[18, 21, 24]]], + [[[24, 28, 32]]], + ], + ), +] + + +def default_test_mean_no_outlier(tensor_processor, aggregation_axes, keepdims): + return MeanNoOutliersAggregator( + tensor_processor=tensor_processor, + aggregation_axes=aggregation_axes, + quantile=default_test_quantile, + keepdims=keepdims, + ) -def default_test_median_no_outlier(tp, ps): - return MedianNoOutliersAggregator(tp, ps, quantile=default_test_quantile) +def default_test_median_no_outlier(tensor_processor, aggregation_axes, keepdims): + return MedianNoOutliersAggregator( + tensor_processor=tensor_processor, + aggregation_axes=aggregation_axes, + quantile=default_test_quantile, + keepdims=keepdims, + ) class TemplateTestReducersAggreagtors: @abstractmethod - def get_nncf_tensor(self, x: np.array): + def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None): pass @pytest.fixture @@ -70,6 +178,18 @@ def reducers(self): def all_close(self, val, ref) -> bool: pass + @abstractmethod + def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = None): + pass + + @abstractmethod + def cast_tensor(self, tensor, dtype: Dtype): + pass + + @abstractmethod + def expand_dims(self, tensor, dims: Tuple[int, ...]): + pass + def test_noop_reducer(self, reducers): reducer = reducers["noop"]() input_ = np.arange(24).reshape((1, 2, 3, 4)) @@ -87,27 +207,31 @@ def test_noop_reducer(self, reducers): ], ) def test_min_max_mean_reducers(self, reducer_name, ref, reducers): - reduction_shape = (1, 2) + reduction_axes = (1, 2) input_ = np.arange(-26, 10).reshape((4, 3, 3)) - for i, red_shape in enumerate([reduction_shape, None]): - reducer = reducers[reducer_name](red_shape, False) - val = reducer([self.get_nncf_tensor(input_)]) - assert len(val) == 1 - assert self.all_close(val[0].tensor, ref[i]) + for i, red_axes in enumerate([reduction_axes, None]): + for keepdims in [True, False]: + reducer = reducers[reducer_name](reduction_axes=red_axes, inplace=False, keepdims=keepdims) + val = reducer([self.get_nncf_tensor(input_, Dtype.FLOAT)]) + assert len(val) == 1 + ref_ = ref[i] if keepdims else self.squeeze_tensor(ref[i]) + assert self.all_close(val[0].tensor, 
self.cast_tensor(ref_, Dtype.FLOAT)) @pytest.mark.parametrize( "reducer_name,ref", [("quantile", ([[[[-20000]]]], [[[[10000]]]])), ("abs_quantile", ([[[[20000]]]],))] ) def test_quantile_reducers(self, reducer_name, ref, reducers): - reduction_shape = (1, 2, 3) + reduction_axes = (1, 2, 3) input_ = np.arange(-26, 10).reshape((1, 4, 3, 3)) input_[0][0][0] = -20000 input_[0][0][1] = 10000 - reducer = reducers[reducer_name](reduction_shape, inplace=False) - val = reducer([self.get_nncf_tensor(input_)]) - assert len(val) == len(ref) - for i, ref_ in enumerate(ref): - assert self.all_close(val[i].tensor, ref_) + for keepdims in [True, False]: + reducer = reducers[reducer_name](reduction_axes=reduction_axes, inplace=False, keepdims=keepdims) + val = reducer([self.get_nncf_tensor(input_, dtype=Dtype.FLOAT)]) + assert len(val) == len(ref) + for i, ref_ in enumerate(ref): + ref_ = ref[i] if keepdims else self.squeeze_tensor(ref[i], (1, 2, 3)) + assert self.all_close(val[i].tensor, self.cast_tensor(ref_, Dtype.FLOAT)) @pytest.mark.parametrize( "reducer_name,ref", @@ -116,9 +240,9 @@ def test_quantile_reducers(self, reducer_name, ref, reducers): def test_batch_mean_mean_per_ch_reducers(self, reducer_name, ref, reducers): input_ = np.arange(-26, 10).reshape((4, 1, 3, 3)) reducer = reducers[reducer_name](inplace=False) - val = reducer([self.get_nncf_tensor(input_)]) + val = reducer([self.get_nncf_tensor(input_, Dtype.FLOAT)]) assert len(val) == 1 - assert self.all_close(val[0].tensor, ref) + assert self.all_close(val[0].tensor, self.cast_tensor(ref, Dtype.FLOAT)) def test_noop_aggregator(self): aggregator = NoopAggregator(None) @@ -146,20 +270,28 @@ def test_shape_aggregator(self): assert aggregator._collected_samples == 1 assert ref_shape == aggregator.aggregate() - def test_min_max_aggregators(self, tensor_processor): - min_aggregator = MinAggregator(tensor_processor) - max_aggregator = MaxAggregator(tensor_processor) + @pytest.mark.parametrize( + "aggregation_axes,keepdims,min_ref,max_ref", + OFFLINE_AGGREGATORS_TEST_CASES, + ) + def test_min_max_aggregators(self, aggregation_axes, keepdims, min_ref, max_ref, tensor_processor): + min_aggregator = MinAggregator( + tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims + ) + max_aggregator = MaxAggregator( + tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims + ) input_ = np.arange(3 * 3).reshape((1, 3, 3)) input_[0, 0, 0] = -10000 for i in range(-5, 5): min_aggregator.register_reduced_input(self.get_nncf_tensor(input_ * (-i))) max_aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i)) - min_ref = [[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]] - assert self.all_close(min_ref, min_aggregator.aggregate()) - - max_ref = [[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]] - assert self.all_close(max_ref, max_aggregator.aggregate()) + assert self.all_close( + min_aggregator.aggregate(), + min_ref, + ) + assert self.all_close(max_aggregator.aggregate(), max_ref) NO_OUTLIERS_TEST_PARAMS = [ (MeanAggregator, True, 1, 1404.5138888888905), @@ -199,7 +331,10 @@ def test_min_max_aggregators(self, tensor_processor): ] @pytest.mark.parametrize("aggregator_cls,use_per_sample_stats,dims,refs", NO_OUTLIERS_TEST_PARAMS) - def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor, dims, use_per_sample_stats): + @pytest.mark.parametrize("keepdims", [True, False]) + def test_mean_median_agggregators( + self, aggregator_cls, refs, tensor_processor, dims, 
use_per_sample_stats, keepdims + ): input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) input_with_outliers = np.array( [100_000, -100_000, 200_000, -200_000, 300_000, -300_000, 400_000, -400_000, 500_000] @@ -211,19 +346,26 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor, input_ = input_.reshape((1, 3, 3)) input_with_outliers = input_with_outliers.reshape((1, 3, 3)) - aggregator = aggregator_cls(tensor_processor, use_per_sample_stats) + aggregation_axes = (0, 1) if use_per_sample_stats else (0,) + aggregator = aggregator_cls( + tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims + ) for i in range(1, 6): - aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i)) + aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i, Dtype.FLOAT)) # this registration is to make diff between mean and median bigger - aggregator.register_reduced_input(self.get_nncf_tensor(input_ * 10)) + aggregator.register_reduced_input(self.get_nncf_tensor(input_ * 10, Dtype.FLOAT)) is_median = isinstance(aggregator, (MedianAggregator, MedianNoOutliersAggregator)) # Outliers registration for i in range(2): # mult is needed to make outlier and no outlier aggreagators differs mult = 2.2 * i - 1 if not is_median else 1 - aggregator.register_reduced_input(self.get_nncf_tensor(input_with_outliers * mult)) + aggregator.register_reduced_input(self.get_nncf_tensor(input_with_outliers * mult, Dtype.FLOAT)) ret_val = aggregator.aggregate() - assert self.all_close(ret_val, refs) + + if keepdims: + refs = self.expand_dims(refs, (0, 1) if use_per_sample_stats else (0,)) + + assert self.all_close(ret_val, self.cast_tensor(refs, Dtype.FLOAT)) @pytest.mark.parametrize( "reducer_name", @@ -240,10 +382,10 @@ def test_reducers_name_hash_equal(self, reducer_name, reducers): params = {} if reducer_name in ["min", "max", "abs_max", "mean"]: - params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)] + params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)] params["inplace"] = [False, True] elif reducer_name in ["quantile", "abs_quantile"]: - params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)] + params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)] params["quantile"] = [[0.01, 0.99], [0.001, 0.999]] elif reducer_name == "batch_mean": params["inplace"] = [False, True] diff --git a/tests/openvino/native/quantization/test_reducers_and_aggregators.py b/tests/openvino/native/quantization/test_reducers_and_aggregators.py index 4726ac61194..213a64b2e84 100644 --- a/tests/openvino/native/quantization/test_reducers_and_aggregators.py +++ b/tests/openvino/native/quantization/test_reducers_and_aggregators.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any, List, Optional, Tuple + import numpy as np import pytest +from nncf.common.graph.layer_attributes import Dtype from nncf.openvino.statistics.collectors import OVAbsMaxReducer from nncf.openvino.statistics.collectors import OVAbsQuantileReducer from nncf.openvino.statistics.collectors import OVBatchMeanReducer @@ -31,7 +34,7 @@ class TestReducersAggregators(TemplateTestReducersAggreagtors): def tensor_processor(self): return OVNNCFCollectorTensorProcessor - def get_nncf_tensor(self, x: np.array): + def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None): return OVNNCFTensor(x) @pytest.fixture(scope="module") @@ -52,3 +55,12 @@ def all_close(self, val, ref) -> bool: val_ = np.array(val) ref_ = np.array(ref) return np.allclose(val_, ref_) and val_.shape == ref_.shape + + def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = None): + return np.squeeze(np.array(ref_tensor), axes) + + def cast_tensor(self, tensor, dtype: Dtype): + return tensor + + def expand_dims(self, tensor, dims: Tuple[int, ...]): + return np.expand_dims(np.array(tensor), dims) diff --git a/tests/post_training/test_templates/test_channel_alignment.py b/tests/post_training/test_templates/test_channel_alignment.py index d3b6dd045e5..b15884445d8 100644 --- a/tests/post_training/test_templates/test_channel_alignment.py +++ b/tests/post_training/test_templates/test_channel_alignment.py @@ -489,7 +489,7 @@ def test_statistic_collectors(self, inplace_ref, q_ref): assert len(statistic_collector.reducers) == 1 reducer = statistic_collector.reducers.pop() assert isinstance(reducer, QuantileReducer) - assert reducer._reduction_shape == reduction_shape_ref + assert reducer._reduction_axes == reduction_shape_ref assert np.allclose(reducer._quantile, (q_ref, 1 - q_ref)) assert len(statistic_collector.aggregators) == 2 diff --git a/tests/post_training/test_templates/test_quantizer_config.py b/tests/post_training/test_templates/test_quantizer_config.py index e614138d0a9..afd2accc285 100644 --- a/tests/post_training/test_templates/test_quantizer_config.py +++ b/tests/post_training/test_templates/test_quantizer_config.py @@ -278,8 +278,8 @@ def test_get_stat_collector( for reducer in reducers: if q_config_per_channel: - assert reducer._reduction_shape == params.ref_per_ch_reduction_shape + assert reducer._reduction_axes == params.ref_per_ch_reduction_shape else: - assert reducer._reduction_shape == params.ref_per_tensor_reduction_shape + assert reducer._reduction_axes == params.ref_per_tensor_reduction_shape assert tensor_collector.num_samples == num_samples diff --git a/tests/post_training/test_templates/test_smooth_quant.py b/tests/post_training/test_templates/test_smooth_quant.py index 42fe17e01b0..4292e228657 100644 --- a/tests/post_training/test_templates/test_smooth_quant.py +++ b/tests/post_training/test_templates/test_smooth_quant.py @@ -145,7 +145,7 @@ def test_get_abs_max_channel_collector(self): for reducer in backend_tensor_collector.reducers: assert isinstance(reducer, AbsMaxReducer) assert reducer.inplace == inplace_type - assert reducer._reduction_shape == reduction_shape + assert reducer._reduction_axes == reduction_shape @pytest.mark.parametrize( "model_cls, references", diff --git a/tests/tensorflow/tensor_statistics/test_tensor_statistics.py b/tests/tensorflow/tensor_statistics/test_tensor_statistics.py index e99487c69d9..7a32f2af2ab 100644 --- a/tests/tensorflow/tensor_statistics/test_tensor_statistics.py +++ 
b/tests/tensorflow/tensor_statistics/test_tensor_statistics.py @@ -106,7 +106,7 @@ def test_collected_statistics_with_shape_convert( for reduction_shape in reduction_shapes_vs_ref_statistic.keys(): collector_obj = collector(use_abs_max=True, reduction_shape=reduction_shape) for input_ in TestCollectedStatistics.REF_INPUTS: - collector_obj.register_input(input_) + collector_obj.register_inputs(input_) test_stats = collector_obj.get_statistics() assert reduction_shapes_vs_ref_statistic[reduction_shape] == test_stats @@ -184,7 +184,7 @@ def test_collected_statistics( for reduction_shape in reduction_shapes_vs_ref_statistic.keys(): collector_obj = collector(reduction_shape=reduction_shape) for input_ in TestCollectedStatistics.REF_INPUTS: - collector_obj.register_input(input_) + collector_obj.register_inputs(input_) test_stats = collector_obj.get_statistics() assert reduction_shapes_vs_ref_statistic[reduction_shape] == test_stats @@ -210,12 +210,12 @@ def collector_for_interface_test(self, request): def test_collected_samples(self, collector_for_interface_test: TensorStatisticCollectorBase): for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_input(input_) + collector_for_interface_test.register_inputs(input_) assert collector_for_interface_test.collected_samples() == len(TestCollectedStatistics.REF_INPUTS) def test_reset(self, collector_for_interface_test: TensorStatisticCollectorBase): for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_input(input_) + collector_for_interface_test.register_inputs(input_) collector_for_interface_test.reset() assert collector_for_interface_test.collected_samples() == 0 with pytest.raises(StatisticsNotCollectedError): @@ -223,16 +223,16 @@ def test_reset(self, collector_for_interface_test: TensorStatisticCollectorBase) def test_enable_disable(self, collector_for_interface_test: TensorStatisticCollectorBase): for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_input(input_) + collector_for_interface_test.register_inputs(input_) collector_for_interface_test.disable() for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_input(input_) + collector_for_interface_test.register_inputs(input_) assert collector_for_interface_test.collected_samples() == len(TestCollectedStatistics.REF_INPUTS) collector_for_interface_test.enable() for input_ in TestCollectedStatistics.REF_INPUTS: - collector_for_interface_test.register_input(input_) + collector_for_interface_test.register_inputs(input_) assert collector_for_interface_test.collected_samples() == 2 * len(TestCollectedStatistics.REF_INPUTS) OFFLINE_COLLECTORS = [ @@ -258,7 +258,7 @@ def collector_for_num_samples_test(self, request): def test_num_samples(self, collector_for_num_samples_test: OfflineTensorStatisticCollector): for input_ in TestCollectedStatistics.REF_INPUTS * 10: - collector_for_num_samples_test.register_input(input_) + collector_for_num_samples_test.register_inputs(input_) assert collector_for_num_samples_test.collected_samples() == TestCollectedStatistics.REF_NUM_SAMPLES diff --git a/tests/torch/ptq/test_ptq_params.py b/tests/torch/ptq/test_ptq_params.py index c174ec8b322..35ddfe3128e 100644 --- a/tests/torch/ptq/test_ptq_params.py +++ b/tests/torch/ptq/test_ptq_params.py @@ -18,6 +18,10 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.quantization.structs import QuantizationPreset from nncf.common.utils.backend import 
BackendType
+from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
+from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
+from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
+from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
@@ -33,8 +37,6 @@
 from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype
 from nncf.torch.graph.operator_metatypes import PTSoftmaxMetatype
 from nncf.torch.quantization.quantize_model import _create_nncf_config
-from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector
-from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector
 from tests.common.quantization.metatypes import Conv2dTestMetatype
 from tests.common.quantization.metatypes import LinearTestMetatype
 from tests.common.quantization.metatypes import SoftmaxTestMetatype
@@ -104,11 +106,17 @@ class TestPTQParams(TemplateTestPTQParams):
     def get_algo_backend(self):
         return PTMinMaxAlgoBackend()
 
-    def check_is_min_max_statistic_collector(self, tensor_collector):
-        assert isinstance(tensor_collector, PTMinMaxStatisticCollector)
-
-    def check_is_mean_min_max_statistic_collector(self, tensor_collector):
-        assert isinstance(tensor_collector, PTMeanMinMaxStatisticCollector)
+    def check_is_min_max_statistic_collector(self, tensor_collector: TensorCollector):
+        aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()]
+        assert len(aggrs) == 2
+        assert MinAggregator in aggrs
+        assert MaxAggregator in aggrs
+
+    def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorCollector):
+        aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()]
+        assert len(aggrs) == 2
+        assert MeanAggregator in aggrs
+        assert aggrs[0] == aggrs[1]
 
     def check_quantize_outputs_fq_num(self, quantize_outputs, act_num_q, weight_num_q):
         if quantize_outputs:
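One caveat for aggregator-based checks like the ones above (an illustrative sketch, not part of the patch): when squeeze_dims is set, maybe_add_squeeze wraps the aggregator in a PostAggregateHook, so type-sensitive introspection has to unwrap it first, as the range-init test later in this patch also does:

    from nncf.experimental.common.tensor_statistics.collectors import PostAggregateHook

    def unwrap_aggregator(aggregator):
        # PostAggregateHook stores the wrapped aggregator in `_aggregator`.
        if isinstance(aggregator, PostAggregateHook):
            return aggregator._aggregator
        return aggregator

diff --git a/tests/torch/ptq/test_quantizer_config.py b/tests/torch/ptq/test_quantizer_config.py
index 41cab6438b5..152503c802b 100644
--- a/tests/torch/ptq/test_quantizer_config.py
+++ b/tests/torch/ptq/test_quantizer_config.py
@@ -12,9 +12,11 @@
 import pytest
 
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
+from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
+from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
+from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend
-from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector
-from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector
 from tests.post_training.test_templates.models import NNCFGraphToTest
 from tests.post_training.test_templates.models import NNCFGraphToTestDepthwiseConv
 from tests.post_training.test_templates.models import NNCFGraphToTestSumAggregation
@@ -30,11 +32,17 @@ class TestQuantizerConfig(TemplateTestQuantizerConfig):
     def get_algo_backend(self):
         return PTMinMaxAlgoBackend()
 
-    def check_is_min_max_statistic_collector(self, tensor_collector):
-        assert isinstance(tensor_collector, PTMinMaxStatisticCollector)
+    def check_is_min_max_statistic_collector(self, tensor_collector: TensorCollector):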
+        aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()]
+        assert len(aggrs) == 2
+        assert MinAggregator in aggrs
+        assert MaxAggregator in aggrs
 
-    def check_is_mean_min_max_statistic_collector(self, tensor_collector):
-        assert isinstance(tensor_collector, PTMeanMinMaxStatisticCollector)
+    def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorCollector):
+        aggrs = [aggr.__class__ for aggr in tensor_collector.aggregators.values()]
+        assert len(aggrs) == 2
+        assert MeanAggregator in aggrs
+        assert aggrs[0] == aggrs[1]
 
     @pytest.fixture(
         params=[
diff --git a/tests/torch/ptq/test_reducers_and_aggregators.py b/tests/torch/ptq/test_reducers_and_aggregators.py
new file mode 100644
index 00000000000..c6a97696b00
--- /dev/null
+++ b/tests/torch/ptq/test_reducers_and_aggregators.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2023 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Optional, Tuple
+
+import numpy as np
+import pytest
+import torch
+
+from nncf.common.graph.layer_attributes import Dtype
+from nncf.torch.tensor import PTNNCFTensor
+from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
+from nncf.torch.tensor_statistics.collectors import PTAbsQuantileReducer
+from nncf.torch.tensor_statistics.collectors import PTBatchMeanReducer
+from nncf.torch.tensor_statistics.collectors import PTMaxReducer
+from nncf.torch.tensor_statistics.collectors import PTMeanPerChanelReducer
+from nncf.torch.tensor_statistics.collectors import PTMeanReducer
+from nncf.torch.tensor_statistics.collectors import PTMinReducer
+from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor
+from nncf.torch.tensor_statistics.collectors import PTNoopReducer
+from nncf.torch.tensor_statistics.collectors import PTQuantileReducer
+from tests.experimental.common.test_reducers_and_aggregators import TemplateTestReducersAggreagtors
+
+
+class TestReducersAggregators(TemplateTestReducersAggreagtors):
+    @pytest.fixture
+    def tensor_processor(self):
+        return PTNNCFCollectorTensorProcessor
+
+    def get_nncf_tensor(self, x: np.ndarray, dtype: Optional[Dtype] = None):
+        torch_tensor = torch.tensor(x)
+        if dtype == Dtype.FLOAT:
+            torch_tensor = torch_tensor.float()
+        elif dtype == Dtype.INTEGER:
+            torch_tensor = torch_tensor.int()
+        return PTNNCFTensor(torch_tensor)
+
+    @pytest.fixture(scope="module")
+    def reducers(self):
+        return {
+            "noop": PTNoopReducer,
+            "min": PTMinReducer,
+            "max": PTMaxReducer,
+            "abs_max": PTAbsMaxReducer,
+            "mean": PTMeanReducer,
+            "quantile": PTQuantileReducer,
+            "abs_quantile": PTAbsQuantileReducer,
+            "batch_mean": PTBatchMeanReducer,
+            "mean_per_ch": PTMeanPerChanelReducer,
+        }
+
+    def all_close(self, val, ref) -> bool:
+        val_ = torch.tensor(val)
+        ref_ = torch.tensor(ref)
+        return torch.allclose(val_, ref_) and val_.shape == ref_.shape
+
+    def squeeze_tensor(self, ref_tensor: List[Any], axes: 
Optional[Tuple[int]] = None): + if axes is None: + return torch.tensor(ref_tensor).squeeze() + return torch.tensor(ref_tensor).squeeze(axes) + + def cast_tensor(self, tensor, dtype: Dtype): + tensor = torch.tensor(tensor) + if dtype == Dtype.FLOAT: + return tensor.float() + if dtype == Dtype.INTEGER: + return tensor.int() + raise RuntimeError() + + def expand_dims(self, tensor, dims: Tuple[int, ...]): + tensor_ = torch.tensor(tensor) + shape = list(tensor_.shape) + for dim in dims: + shape.insert(dim, 1) + return tensor_.view(shape) diff --git a/tests/torch/quantization/test_range_init.py b/tests/torch/quantization/test_range_init.py index 84975de7bd4..f3be4ef2c0d 100644 --- a/tests/torch/quantization/test_range_init.py +++ b/tests/torch/quantization/test_range_init.py @@ -22,6 +22,7 @@ from torch.utils.data import DataLoader from torchvision.models import squeezenet1_1 +import nncf.torch.tensor_statistics.collectors as pt_collectors from nncf.common.graph import NNCFNodeName from nncf.common.quantization.initialization.range import PerLayerRangeInitConfig from nncf.common.quantization.initialization.range import RangeInitConfig @@ -33,6 +34,8 @@ from nncf.common.quantization.structs import QuantizerGroup from nncf.config import NNCFConfig from nncf.config.structures import QuantizationRangeInitArgs +from nncf.experimental.common.tensor_statistics.collectors import PostAggregateHook +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.torch import utils from nncf.torch.checkpoint_loading import load_state from nncf.torch.initialization import DefaultInitializingDataLoader @@ -46,9 +49,7 @@ from nncf.torch.quantization.layers import BaseQuantizer from nncf.torch.quantization.layers import PTQuantizerSpec from nncf.torch.quantization.layers import SymmetricQuantizer -from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMedianMADStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector +from nncf.torch.tensor import PTNNCFTensor from nncf.torch.tensor_statistics.statistics import pt_convert_stat_to_min_max_tensor_stat from nncf.torch.utils import get_all_modules_by_type from nncf.torch.utils import safe_thread_call @@ -528,6 +529,7 @@ def init_idfn(val): ("mean_min_max", 9999, 0, 9999), ("threesigma", 16119.5, -6119.5, 22239), ("percentile", 6789, 3210, 3578), + ("mean_percentile", 9989.0010, 9.9990, 9979.0020), ] ), ids=init_idfn, @@ -671,53 +673,69 @@ def range_init_call_count_test_struct(request): return request.param +class CustomSpy: + def __init__(self, fn) -> None: + self._fn = fn + self.call_count = 0 + self.return_values_list = [] + + def __call__(self, *args, **kwargs): + self.call_count += 1 + retval = self._fn(*args, **kwargs) + self.return_values_list.append(retval) + return retval + + # pylint:disable=redefined-outer-name def test_per_layer_range_init_collectors_are_called_the_required_number_of_times( range_init_call_count_test_struct, mocker ): + range_minmax_init_create_spy = CustomSpy(pt_collectors.get_min_max_statistic_collector) + mocker.patch("nncf.torch.quantization.init_range.get_min_max_statistic_collector", new=range_minmax_init_create_spy) + range_meanminmax_init_create_spy = CustomSpy(pt_collectors.get_mixed_min_max_statistic_collector) + mocker.patch( + "nncf.torch.quantization.init_range.get_mixed_min_max_statistic_collector", new=range_meanminmax_init_create_spy + ) + range_threesigma_init_create_spy 
= CustomSpy(pt_collectors.get_median_mad_statistic_collector) + mocker.patch( + "nncf.torch.quantization.init_range.get_median_mad_statistic_collector", new=range_threesigma_init_create_spy + ) + config = create_config() config["compression"]["initializer"]["range"] = range_init_call_count_test_struct.range_init_config data_loader = TestRangeInit.create_dataloader(True, config, 10) config.register_extra_structs([QuantizationRangeInitArgs(data_loader)]) - range_minmax_init_create_spy = mocker.spy(PTMinMaxStatisticCollector, "__init__") - range_meanminmax_init_create_spy = mocker.spy(PTMeanMinMaxStatisticCollector, "__init__") - range_threesigma_init_create_spy = mocker.spy(PTMedianMADStatisticCollector, "__init__") - - range_minmax_init_register_input_spy = mocker.spy(PTMinMaxStatisticCollector, "_register_input") - range_meanminmax_init_register_input_spy = mocker.spy(PTMeanMinMaxStatisticCollector, "_register_input") - range_threesigma_init_register_input_spy = mocker.spy(PTMedianMADStatisticCollector, "_register_input") - TestRangeInit.create_algo_and_compressed_model(config) - assert ( - range_minmax_init_create_spy.call_count - == range_init_call_count_test_struct.expected_call_count_initializer_create["min_max"] - ) - assert ( - range_meanminmax_init_create_spy.call_count - == range_init_call_count_test_struct.expected_call_count_initializer_create["mean_min_max"] - ) - assert ( - range_threesigma_init_create_spy.call_count - == range_init_call_count_test_struct.expected_call_count_initializer_create["three_sigma"] - ) - - assert ( - range_minmax_init_register_input_spy.call_count - == range_init_call_count_test_struct.expected_call_count_register_input["min_max"] - ) - assert ( - range_meanminmax_init_register_input_spy.call_count - == range_init_call_count_test_struct.expected_call_count_register_input["mean_min_max"] - ) - assert ( - range_threesigma_init_register_input_spy.call_count - == range_init_call_count_test_struct.expected_call_count_register_input["three_sigma"] - ) - - -QUANTIZER_RANGE_INITIALIZERS = ["min_max", "threesigma", "mean_min_max", "percentile", "mixed_min_max"] + for stat_type, spy in [ + ("min_max", range_minmax_init_create_spy), + ("mean_min_max", range_meanminmax_init_create_spy), + ("three_sigma", range_threesigma_init_create_spy), + ]: + assert spy.call_count == range_init_call_count_test_struct.expected_call_count_initializer_create[stat_type] + collected_samples = 0 + for tensor_collector in spy.return_values_list: + cur_values = set() + for aggr in tensor_collector.aggregators.values(): + if isinstance(aggr, PostAggregateHook): + cur_values.add(aggr._aggregator._collected_samples) + else: + cur_values.add(aggr._collected_samples) + assert len(cur_values) == 1 + collected_samples += cur_values.pop() + + assert collected_samples == range_init_call_count_test_struct.expected_call_count_register_input[stat_type] + + +QUANTIZER_RANGE_INITIALIZERS = [ + "min_max", + "threesigma", + "mean_min_max", + "percentile", + "mixed_min_max", + "mean_percentile", +] class QuantizeRangeInitScaleShapeTestStruct: @@ -794,7 +812,7 @@ def test_quantize_range_init_sets_correct_scale_shapes(quantizer_range_init_test collector = StatCollectorGenerator.generate_stat_collector_for_range_init_config( range_init_config, tuple(quantizer.scale_shape), collector_params ) - collector.register_input(torch.ones(test_struct.input_shape)) + collector.register_unnamed_inputs(PTNNCFTensor(torch.ones(test_struct.input_shape))) stat = collector.get_statistics() minmax_values = 
pt_convert_stat_to_min_max_tensor_stat(stat) quantizer.apply_minmax_init(min_values=minmax_values.min_values, max_values=minmax_values.max_values) diff --git a/tests/torch/tensor_statistics/test_tensor_statistics.py b/tests/torch/tensor_statistics/test_tensor_statistics.py index 5c5cccc0220..7981bd96461 100644 --- a/tests/torch/tensor_statistics/test_tensor_statistics.py +++ b/tests/torch/tensor_statistics/test_tensor_statistics.py @@ -15,19 +15,16 @@ import pytest import torch -from nncf.common.tensor_statistics.collectors import OfflineTensorStatisticCollector from nncf.common.tensor_statistics.collectors import ReductionShape -from nncf.common.tensor_statistics.collectors import StatisticsNotCollectedError from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.common.tensor_statistics.statistics import TensorStatistic from nncf.torch.tensor import PTNNCFTensor -from nncf.torch.tensor_statistics.collectors import PTMeanMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMeanPercentileStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMedianMADStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMinMaxStatisticCollector -from nncf.torch.tensor_statistics.collectors import PTMixedMinMaxStatisticCollector from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor -from nncf.torch.tensor_statistics.collectors import PTPercentileStatisticCollector +from nncf.torch.tensor_statistics.collectors import get_mean_percentile_statistic_collector +from nncf.torch.tensor_statistics.collectors import get_median_mad_statistic_collector +from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector +from nncf.torch.tensor_statistics.collectors import get_mixed_min_max_statistic_collector +from nncf.torch.tensor_statistics.collectors import get_precentile_tensor_collector from nncf.torch.tensor_statistics.statistics import PTMedianMADTensorStatistic from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic from nncf.torch.tensor_statistics.statistics import PTPercentileTensorStatistic @@ -43,16 +40,22 @@ class TestCollectedStatistics: ("collector", "reduction_shapes_vs_ref_statistic"), [ ( - PTMinMaxStatisticCollector, + get_min_max_statistic_collector, { ((1,), (0, 1)): PTMinMaxTensorStatistic( - min_values=torch.tensor([-4.0]), max_values=torch.tensor([6.1]) + {"min_values": torch.tensor([-4.0]), "max_values": torch.tensor([6.1])} ), ((3, 1), (1,)): PTMinMaxTensorStatistic( - min_values=torch.tensor([[1.0], [-4.0], [4.0]]), max_values=torch.tensor([[4.5], [4.0], [6.1]]) + { + "min_values": torch.tensor([[1.0], [-4.0], [4.0]]), + "max_values": torch.tensor([[4.5], [4.0], [6.1]]), + } ), ((1, 3), (0,)): PTMinMaxTensorStatistic( - min_values=torch.tensor([[-1.3, -4.0, -3.5]]), max_values=torch.tensor([[4.5, 5.8, 6.1]]) + { + "min_values": torch.tensor([[-1.3, -4.0, -3.5]]), + "max_values": torch.tensor([[4.5, 5.8, 6.1]]), + } ), # Not supported for now: # ((3, 3), ): PTMinMaxTensorStatistic( @@ -70,37 +73,50 @@ class TestCollectedStatistics: }, ), ( - partial(PTMeanMinMaxStatisticCollector, use_per_sample_stats=False), + partial( + get_mixed_min_max_statistic_collector, + use_means_of_mins=True, + use_means_of_maxs=True, + ), { ((1,), (0, 1)): PTMinMaxTensorStatistic( - min_values=torch.tensor([-3.5]), max_values=torch.tensor([6.05]) + {"min_values": torch.tensor([-3.5]), "max_values": torch.tensor([6.05])} ), ((3, 1), (1,)): 
PTMinMaxTensorStatistic( - min_values=torch.tensor([[1.8], [-3.5], [4.15]]), - max_values=torch.tensor([[3.75], [3.5], [6.05]]), + { + "min_values": torch.tensor([[1.8], [-3.5], [4.15]]), + "max_values": torch.tensor([[3.75], [3.5], [6.05]]), + } ), ((1, 3), (0,)): PTMinMaxTensorStatistic( - min_values=torch.tensor([[-1.15, -3, -3.25]]), max_values=torch.tensor([[4.25, 5.4, 6.05]]) + { + "min_values": torch.tensor([[-1.15, -3, -3.25]]), + "max_values": torch.tensor([[4.25, 5.4, 6.05]]), + } ), }, ), ( partial( - PTMixedMinMaxStatisticCollector, - use_per_sample_stats=False, + get_mixed_min_max_statistic_collector, use_means_of_mins=False, use_means_of_maxs=True, ), { ((1,), (0, 1)): PTMinMaxTensorStatistic( - min_values=torch.tensor([-4.0]), max_values=torch.tensor([6.05]) + {"min_values": torch.tensor([-4.0]), "max_values": torch.tensor([6.05])} ), ((3, 1), (1,)): PTMinMaxTensorStatistic( - min_values=torch.tensor([[1.0], [-4.0], [4.0]]), - max_values=torch.tensor([[3.75], [3.5], [6.05]]), + { + "min_values": torch.tensor([[1.0], [-4.0], [4.0]]), + "max_values": torch.tensor([[3.75], [3.5], [6.05]]), + } ), ((1, 3), (0,)): PTMinMaxTensorStatistic( - min_values=torch.tensor([[-1.3, -4.0, -3.5]]), max_values=torch.tensor([[4.25, 5.4, 6.05]]) + { + "min_values": torch.tensor([[-1.3, -4.0, -3.5]]), + "max_values": torch.tensor([[4.25, 5.4, 6.05]]), + } ), }, ), @@ -112,10 +128,18 @@ def test_collected_statistics_with_shape_convert( reduction_shapes_vs_ref_statistic: Dict[Tuple[ReductionShape, ReductionShape], TensorStatistic], ): for shapes in reduction_shapes_vs_ref_statistic.keys(): - output_shape, reduction_shape = shapes - collector_obj = collector(use_abs_max=True, reduction_shape=reduction_shape, output_shape=output_shape) + output_shape, reducer_axes = shapes + collector_obj = collector( + use_abs_max=True, + reducers_axes=reducer_axes, + reducers_keepdims=len(output_shape) > 1, + aggregators_axes=(0,), + aggregators_keepdims=False, + squeeze_dims=None, + num_samples=None, + ) for input_ in TestCollectedStatistics.REF_INPUTS: - collector_obj.register_input(input_) + collector_obj.register_unnamed_inputs(PTNNCFTensor(input_)) test_stats = collector_obj.get_statistics() assert reduction_shapes_vs_ref_statistic[shapes] == test_stats @@ -123,15 +147,32 @@ def test_collected_statistics_with_shape_convert( ("collector", "reduction_shapes_vs_ref_statistic"), [ ( - PTMedianMADStatisticCollector, + get_median_mad_statistic_collector, + # PTMedianMADStatisticCollector, { - (1,): PTMedianMADTensorStatistic(median_values=torch.tensor([2.8]), mad_values=torch.tensor([2.6])), + (1,): PTMedianMADTensorStatistic( + { + "tensor_statistic_output": { + "median_values": torch.tensor([2.8]), + "mad_values": torch.tensor([2.6]), + } + } + ), (3, 1): PTMedianMADTensorStatistic( - median_values=torch.tensor([[2.8], [-2.5], [5.4]]), - mad_values=torch.tensor([[0.85], [1.1], [0.65]]), + { + "tensor_statistic_output": { + "median_values": torch.tensor([[2.8], [-2.5], [5.4]]), + "mad_values": torch.tensor([[0.85], [1.1], [0.65]]), + } + } ), (1, 3): PTMedianMADTensorStatistic( - median_values=torch.tensor([[2.5, 2.3, 3.35]]), mad_values=torch.tensor([[1.9, 3.1, 2.7]]) + { + "tensor_statistic_output": { + "median_values": torch.tensor([[2.5, 2.3, 3.35]]), + "mad_values": torch.tensor([[1.9, 3.1, 2.7]]), + } + } ), # Not supported for now: # (3, 3): PTMedianMADTensorStatistic( @@ -149,11 +190,16 @@ def test_collected_statistics_with_shape_convert( }, ), ( - partial(PTPercentileStatisticCollector, 
+                partial(get_precentile_tensor_collector, percentiles_to_collect=[10.0]),
                 {
-                    (1,): PTPercentileTensorStatistic({10.0: torch.tensor([-3.15])}),
-                    (3, 1): PTPercentileTensorStatistic({10.0: torch.tensor([[1.5], [-3.75], [4.15]])}),
-                    (1, 3): PTPercentileTensorStatistic({10.0: torch.tensor([[-1.15, -3, -3.25]])}),
+                    (1,): PTPercentileTensorStatistic({"tensor_statistic_output": {10.0: torch.tensor([-3.15])}}),
+                    (3, 1): PTPercentileTensorStatistic(
+                        {"tensor_statistic_output": {10.0: torch.tensor([[1.5], [-3.75], [4.15]])}}
+                    ),
+                    (1, 3): PTPercentileTensorStatistic(
+                        {"tensor_statistic_output": {10.0: torch.tensor([[-1.15, -3, -3.25]])}}
+                    ),
                     # Not supported for now:
                     # (3, 3): PTPercentileTensorStatistic(
                     #     {
@@ -167,11 +211,15 @@
                 },
             ),
             (
-                partial(PTMeanPercentileStatisticCollector, percentiles_to_collect=[10.0]),
+                partial(get_mean_percentile_statistic_collector, percentiles_to_collect=[10.0]),
                 {
-                    (1,): PTPercentileTensorStatistic({10.0: torch.tensor([-2.9])}),
-                    (3, 1): PTPercentileTensorStatistic({10.0: torch.tensor([[2.0100], [-3.3500], [4.4000]])}),
-                    (1, 3): PTPercentileTensorStatistic({10.0: torch.tensor([[-0.3900, -1.9400, -1.9300]])}),
+                    (1,): PTPercentileTensorStatistic({"tensor_statistic_output": {10.0: torch.tensor([-2.9])}}),
+                    (3, 1): PTPercentileTensorStatistic(
+                        {"tensor_statistic_output": {10.0: torch.tensor([[2.0100], [-3.3500], [4.4000]])}}
+                    ),
+                    (1, 3): PTPercentileTensorStatistic(
+                        {"tensor_statistic_output": {10.0: torch.tensor([[-0.3900, -1.9400, -1.9300]])}}
+                    ),
                     # Not supported for now:
                     # (3, 3): PTPercentileTensorStatistic(
                     #     {
@@ -191,88 +239,29 @@ def test_collected_statistics(
         collector: Type[TensorStatisticCollectorBase],
         reduction_shapes_vs_ref_statistic: Dict[ReductionShape, TensorStatistic],
     ):
-        for shapes in reduction_shapes_vs_ref_statistic.keys():
-            reduction_shape = shapes
-            collector_obj = collector(reduction_shape=reduction_shape)
+        for reduction_shape in reduction_shapes_vs_ref_statistic.keys():
+            if len(reduction_shape) > 1:
+                reducer_axes = ([dim for dim, val in enumerate(reduction_shape) if val == 1][0],)
+                aggregator_keep_dims = False
+            else:
+                reducer_axes = (0, 1)
+                aggregator_keep_dims = True
+
+            collector_obj = collector(
+                reducers_axes=reducer_axes,
+                reducers_keepdims=len(reduction_shape) > 1,
+                aggregators_axes=(0,),
+                aggregators_keepdims=aggregator_keep_dims,
+                num_samples=None,
+                squeeze_dims=None,
+            )
             for input_ in TestCollectedStatistics.REF_INPUTS:
-                collector_obj.register_input(input_)
+                if hasattr(collector_obj, "register_unnamed_inputs"):
+                    collector_obj.register_unnamed_inputs(PTNNCFTensor(input_))
+                else:
+                    collector_obj.register_inputs(input_)
             test_stats = collector_obj.get_statistics()
-            assert reduction_shapes_vs_ref_statistic[shapes] == test_stats
-
-    COLLECTORS = [
-        partial(PTMinMaxStatisticCollector, use_abs_max=False, output_shape=(1,)),
-        partial(
-            PTMixedMinMaxStatisticCollector,
-            use_per_sample_stats=False,
-            use_abs_max=False,
-            use_means_of_mins=False,
-            use_means_of_maxs=False,
-            output_shape=(1,),
-        ),
-        partial(PTMeanMinMaxStatisticCollector, use_per_sample_stats=False, use_abs_max=False, output_shape=(1,)),
-        PTMedianMADStatisticCollector,
-        partial(PTPercentileStatisticCollector, percentiles_to_collect=[10.0]),
-        partial(PTMeanPercentileStatisticCollector, percentiles_to_collect=[10.0]),
-    ]
-
-    @pytest.fixture(params=COLLECTORS)
-    def collector_for_interface_test(self, request):
-        collector_type = request.param
-        return collector_type(reduction_shape=(1,))
-
-    def test_collected_samples(self, collector_for_interface_test: TensorStatisticCollectorBase):
-        for input_ in TestCollectedStatistics.REF_INPUTS:
-            collector_for_interface_test.register_input(input_)
-        assert collector_for_interface_test.collected_samples() == len(TestCollectedStatistics.REF_INPUTS)
-
-    def test_reset(self, collector_for_interface_test: TensorStatisticCollectorBase):
-        for input_ in TestCollectedStatistics.REF_INPUTS:
-            collector_for_interface_test.register_input(input_)
-        collector_for_interface_test.reset()
-        assert collector_for_interface_test.collected_samples() == 0
-        with pytest.raises(StatisticsNotCollectedError):
-            collector_for_interface_test.get_statistics()
-
-    def test_enable_disable(self, collector_for_interface_test: TensorStatisticCollectorBase):
-        for input_ in TestCollectedStatistics.REF_INPUTS:
-            collector_for_interface_test.register_input(input_)
-
-        collector_for_interface_test.disable()
-        for input_ in TestCollectedStatistics.REF_INPUTS:
-            collector_for_interface_test.register_input(input_)
-        assert collector_for_interface_test.collected_samples() == len(TestCollectedStatistics.REF_INPUTS)
-
-        collector_for_interface_test.enable()
-        for input_ in TestCollectedStatistics.REF_INPUTS:
-            collector_for_interface_test.register_input(input_)
-        assert collector_for_interface_test.collected_samples() == 2 * len(TestCollectedStatistics.REF_INPUTS)
-
-    OFFLINE_COLLECTORS = [
-        partial(
-            PTMixedMinMaxStatisticCollector,
-            use_per_sample_stats=False,
-            use_abs_max=False,
-            use_means_of_mins=False,
-            use_means_of_maxs=False,
-            output_shape=(1,),
-        ),
-        partial(PTMeanMinMaxStatisticCollector, use_per_sample_stats=False, use_abs_max=False, output_shape=(1,)),
-        PTMedianMADStatisticCollector,
-        partial(PTPercentileStatisticCollector, percentiles_to_collect=[10.0]),
-        partial(PTMeanPercentileStatisticCollector, percentiles_to_collect=[10.0]),
-    ]
-
-    REF_NUM_SAMPLES = 3
-
-    @pytest.fixture(params=OFFLINE_COLLECTORS)
-    def collector_for_num_samples_test(self, request):
-        collector_type = request.param
-        return collector_type(reduction_shape=(1,), num_samples=TestCollectedStatistics.REF_NUM_SAMPLES)
-
-    def test_num_samples(self, collector_for_num_samples_test: OfflineTensorStatisticCollector):
-        for input_ in TestCollectedStatistics.REF_INPUTS * 10:
-            collector_for_num_samples_test.register_input(input_)
-        assert collector_for_num_samples_test.collected_samples() == TestCollectedStatistics.REF_NUM_SAMPLES
+            assert reduction_shapes_vs_ref_statistic[reduction_shape] == test_stats
 
 
 class TestCollectorTensorProcessor:
diff --git a/tests/torch/test_statistics_aggregator.py b/tests/torch/test_statistics_aggregator.py
index 4a4b8f48914..60bfb99015a 100644
--- a/tests/torch/test_statistics_aggregator.py
+++ b/tests/torch/test_statistics_aggregator.py
@@ -62,7 +62,7 @@ def get_backend_model(self, dataset_samples):
 
     @pytest.fixture
    def is_backend_support_custom_estimators(self) -> bool:
-        return False
+        return True
 
     @pytest.fixture(scope="session")
     def test_params(self):
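
Reviewer note (not part of the patch): every test hunk above follows the same migration pattern, so a condensed sketch may save reviewers a round trip. Collector classes such as PTMinMaxStatisticCollector give way to factory functions (get_min_max_statistic_collector and friends) configured with explicit reducer and aggregator axes; inputs are registered through register_unnamed_inputs() on backend-wrapped tensors instead of register_input(); and the resulting statistics are built from dicts rather than keyword arguments. The factory signature and method names below are taken verbatim from the hunks above, while the sample tensors and the expected values in the comments are illustrative assumptions, not values from the test suite.

    import torch

    from nncf.torch.tensor import PTNNCFTensor
    from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector

    # Mirrors the (1,) reduction-shape branch of test_collected_statistics above:
    # reduce over both axes of every registered sample, then aggregate across
    # samples into a single min/max pair of shape (1,).
    collector = get_min_max_statistic_collector(
        use_abs_max=False,  # keep signed minima; True would track max(|x|) instead
        reducers_axes=(0, 1),
        reducers_keepdims=False,
        aggregators_axes=(0,),
        aggregators_keepdims=True,
        squeeze_dims=None,
        num_samples=None,  # no cap on the number of registered samples
    )

    for sample in [torch.tensor([[1.0, -4.0, 4.0]]), torch.tensor([[4.5, 6.1, -3.5]])]:
        # The new collectors consume backend-wrapped tensors, not raw torch.Tensors.
        collector.register_unnamed_inputs(PTNNCFTensor(sample))

    statistic = collector.get_statistics()  # -> PTMinMaxTensorStatistic
    # Expected, assuming min/max aggregation over samples: tensor([-4.0]), tensor([6.1])
    print(statistic.min_values, statistic.max_values)

One factory thus serves per-tensor and per-channel collection purely through the axes arguments, which is what lets the two tests above drive every reduction shape through a single code path.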