Skip to content

Commit

Permalink
Comments are applied for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 20, 2023
1 parent b197294 commit cdfffa0
Show file tree
Hide file tree
Showing 9 changed files with 310 additions and 196 deletions.
40 changes: 27 additions & 13 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __eq__(self, __o: object) -> bool:
def __hash__(self) -> int:
return hash((self.__class__.__name__, self.inplace, self._reduction_axes))

def _get_reduction_shape(self, tensor: NNCFTensor) -> Union[int, Tuple[int, ...]]:
def _get_reduction_axes(self, tensor: NNCFTensor) -> Union[int, Tuple[int, ...]]:
if self._reduction_axes is not None:
return self._reduction_axes
return tuple(range(len(tensor.shape)))
Expand Down Expand Up @@ -443,29 +443,29 @@ def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
class MinReducer(TensorReducerBase):
def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
x = x[0]
reduction_shape = self._get_reduction_shape(x)
return [self._tensor_processor.reduce_min(x, reduction_shape, keepdims=self._keepdims)]
reduction_axes = self._get_reduction_axes(x)
return [self._tensor_processor.reduce_min(x, reduction_axes, keepdims=self._keepdims)]


class MaxReducer(TensorReducerBase):
def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
x = x[0]
reduction_shape = self._get_reduction_shape(x)
return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=self._keepdims)]
reduction_axes = self._get_reduction_axes(x)
return [self._tensor_processor.reduce_max(x, reduction_axes, keepdims=self._keepdims)]


class AbsMaxReducer(TensorReducerBase):
def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
x = self._tensor_processor.abs(x[0])
reduction_shape = self._get_reduction_shape(x)
return [self._tensor_processor.reduce_max(x, reduction_shape, keepdims=self._keepdims)]
reduction_axes = self._get_reduction_axes(x)
return [self._tensor_processor.reduce_max(x, reduction_axes, keepdims=self._keepdims)]


class MeanReducer(TensorReducerBase):
def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
x = x[0]
reduction_shape = self._get_reduction_shape(x)
return [self._tensor_processor.mean(x, reduction_shape, keepdims=self._keepdims)]
reduction_axes = self._get_reduction_axes(x)
return [self._tensor_processor.mean(x, reduction_axes, keepdims=self._keepdims)]


class QuantileReducerBase(TensorReducerBase):
Expand All @@ -488,8 +488,8 @@ def __hash__(self) -> int:
class QuantileReducer(QuantileReducerBase):
def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
x = x[0]
reduction_shape = self._get_reduction_shape(x)
return self._tensor_processor.quantile(x, self._quantile, reduction_shape, keepdims=self._keepdims)
reduction_axes = self._get_reduction_axes(x)
return self._tensor_processor.quantile(x, self._quantile, reduction_axes, keepdims=self._keepdims)


class AbsQuantileReducer(QuantileReducerBase):
Expand All @@ -504,8 +504,8 @@ def __init__(

def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
x = self._tensor_processor.abs(x[0])
reduction_shape = self._get_reduction_shape(x)
return self._tensor_processor.quantile(x, self._quantile, reduction_shape, keepdims=self._keepdims)
reduction_axes = self._get_reduction_axes(x)
return self._tensor_processor.quantile(x, self._quantile, reduction_axes, keepdims=self._keepdims)


class BatchMeanReducer(TensorReducerBase):
Expand Down Expand Up @@ -563,6 +563,13 @@ def __init__(


class OnlineAggregatorBase(TensorAggregatorBase, ABC):
"""
Base class for aggregators which are using aggregation function fn with following property:
fn([x1, x2, x3]) == fn([fn([x1, x2]), x3]) where x1, x2, x3 are samples to aggregate.
Online aggregation fn([fn([x1, x2]), x3]) allows to keep memory stamp low as only
one sample is stored during statistic collection.
"""

def _online_register_reduced_input_impl(self, x: TensorType, fn) -> None:
online_aggregation_axes = tuple(dim - 1 for dim in self._aggregation_axes if dim != 0)
if online_aggregation_axes:
Expand Down Expand Up @@ -595,6 +602,13 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:


class OfflineAggregatorBase(TensorAggregatorBase, ABC):
"""
Base class for aggregators which are using aggregation function fn which
does not fulfill property fn([x1, x2, x3]) == fn([fn([x1, x2]), x3])
where x1, x2, x3 are samples to aggregate. Child aggregators collects
all samples to a container and aggregates them in one step.
"""

def _register_reduced_input_impl(self, x: TensorType) -> None:
self._container.append(x)

Expand Down
4 changes: 2 additions & 2 deletions nncf/tensorflow/quantization/init_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _register_layer_statistics(self, layer: tf.keras.layers.Layer, layer_statist
collector = RangeInitializer.generate_stat_collector(
reduction_shape, collector_params, init_config, num_batches
)
handles.append(layer.register_hook_pre_quantizer(collector.register_inputs))
handles.append(layer.register_hook_pre_quantizer(collector.register_input))
layer.enabled = False
layer_statistics.append((layer, collector))

Expand All @@ -180,7 +180,7 @@ def _register_op_statistics(self, layer: tf.keras.layers.Layer, op_statistics: l
collector = RangeInitializer.generate_stat_collector(
reduction_shape, collector_params, init_config, num_batches
)
handles.append(op.register_hook_pre_call(collector.register_inputs))
handles.append(op.register_hook_pre_call(collector.register_input))
op.enabled = False
op_statistics.append((layer, op_name, op, collector))

Expand Down
4 changes: 2 additions & 2 deletions nncf/torch/quantization/init_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from nncf.torch.quantization.layers import get_scale_shape
from nncf.torch.quantization.translator import PTTargetPointTranslator
from nncf.torch.tensor_statistics.algo import TensorStatisticObservationPoint
from nncf.torch.tensor_statistics.algo import register_inputs_hook_factory
from nncf.torch.tensor_statistics.algo import create_register_input_hook
from nncf.torch.tensor_statistics.collectors import get_mean_percentile_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_median_mad_statistic_collector
from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector
Expand Down Expand Up @@ -281,7 +281,7 @@ def __init__(
def _get_fwd_hook(
self, collector: TensorStatisticCollectorBase
) -> Callable[["torch.Module", torch.Tensor, torch.Tensor], torch.Tensor]:
hook = register_inputs_hook_factory(collector=collector)
hook = create_register_input_hook(collector=collector)

def fwd_hook(module, input_, output):
hook(input_[0])
Expand Down
4 changes: 2 additions & 2 deletions nncf/torch/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.algo import register_inputs_hook_factory
from nncf.torch.tensor_statistics.algo import create_register_input_hook


class PTStatisticsAggregator(StatisticsAggregator):
Expand All @@ -50,7 +50,7 @@ def _get_transformation_layout_extra_outputs(
transformation_commands.append(
PTInsertionCommand(
_statistic_point.target_point,
register_inputs_hook_factory(collector=collector),
create_register_input_hook(collector=collector),
TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
)
)
Expand Down
6 changes: 3 additions & 3 deletions nncf/torch/tensor_statistics/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def __eq__(self, other: "TensorStatisticObservationPoint"):
return self.target_point == other.target_point


def register_inputs_hook_factory(collector: TensorCollector) -> Callable[[torch.Tensor], torch.Tensor]:
def create_register_input_hook(collector: TensorCollector) -> Callable[[torch.Tensor], torch.Tensor]:
"""
Factory function to get regiter inputs hook function.
Function to create regiter inputs hook function.
:param collector: Collector to use in resulting hook.
:return: Register inputs hook function.
Expand Down Expand Up @@ -83,7 +83,7 @@ def _get_transformation_layout(self, target_model: NNCFNetwork) -> PTTransformat
for collector in rs_vs_collector.values():
command = PTInsertionCommand(
op.target_point,
register_inputs_hook_factory(collector=collector),
create_register_input_hook(collector=collector),
TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
)
layout.register(command)
Expand Down
76 changes: 50 additions & 26 deletions tests/common/experimental/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
# limitations under the License.

from abc import abstractmethod
from dataclasses import dataclass
from functools import partial
from itertools import product
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -48,25 +49,35 @@
default_test_quantile = 0.1


REF_TYPE = List[Union[float, "REF_TYPE"]]


@dataclass
class OfflineAggregatorTestCase:
aggregation_axes: Optional[Tuple[int, ...]]
min_ref: REF_TYPE
max_ref: REF_TYPE


OFFLINE_AGGREGATORS_TEST_CASES = [
(
None,
[[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]],
[[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]],
OfflineAggregatorTestCase(
aggregation_axes=None,
min_ref=[[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]],
max_ref=[[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]],
),
(
(0,),
[[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]],
[[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]],
OfflineAggregatorTestCase(
aggregation_axes=(0,),
min_ref=[[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]],
max_ref=[[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]],
),
(
(0, 2),
[[-50000, -28, -32]],
[[50000, 28, 32]],
OfflineAggregatorTestCase(
aggregation_axes=(0, 2),
min_ref=[[-50000, -28, -32]],
max_ref=[[50000, 28, 32]],
),
(
(2,),
[
OfflineAggregatorTestCase(
aggregation_axes=(2,),
min_ref=[
[[-50000, 5, 10]],
[[-40000, 4, 8]],
[[-30000, 3, 6]],
Expand All @@ -78,7 +89,7 @@
[[-18, -21, -24]],
[[-24, -28, -32]],
],
[
max_ref=[
[[50000, -5, -10]],
[[40000, -4, -8]],
[[30000, -3, -6]],
Expand Down Expand Up @@ -160,8 +171,8 @@ def test_noop_reducer(self, reducers):
def test_min_max_mean_reducers(self, reducer_name, ref, reducers):
reduction_axes = (1, 2)
input_ = np.arange(-26, 10).reshape((4, 3, 3))
for i, red_axes in enumerate([reduction_axes, None]):
reducer = reducers[reducer_name](reduction_axes=red_axes, inplace=False)
for i, reduction_axes_ in enumerate([reduction_axes, None]):
reducer = reducers[reducer_name](reduction_axes=reduction_axes_, inplace=False)
val = reducer([self.get_nncf_tensor(input_, Dtype.FLOAT)])
assert len(val) == 1
assert self.all_close(val[0].tensor, self.cast_tensor(ref[i], Dtype.FLOAT))
Expand Down Expand Up @@ -218,10 +229,11 @@ def test_shape_aggregator(self):
assert ref_shape == aggregator.aggregate()

@pytest.mark.parametrize(
"aggregation_axes,min_ref,max_ref",
"offline_aggregators_test_desc",
OFFLINE_AGGREGATORS_TEST_CASES,
)
def test_min_max_aggregators(self, aggregation_axes, min_ref, max_ref, tensor_processor):
def test_min_max_aggregators(self, offline_aggregators_test_desc: OfflineAggregatorTestCase, tensor_processor):
aggregation_axes = offline_aggregators_test_desc.aggregation_axes
min_aggregator = MinAggregator(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes)
max_aggregator = MaxAggregator(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes)
input_ = np.arange(3 * 3).reshape((1, 3, 3))
Expand All @@ -230,6 +242,8 @@ def test_min_max_aggregators(self, aggregation_axes, min_ref, max_ref, tensor_pr
min_aggregator.register_reduced_input(self.get_nncf_tensor(input_ * (-i)))
max_aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i))

min_ref = offline_aggregators_test_desc.min_ref
max_ref = offline_aggregators_test_desc.max_ref
assert self.all_close(
min_aggregator.aggregate(),
min_ref,
Expand Down Expand Up @@ -309,23 +323,33 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor,

REF_MAD_PERCENTILE_REF_VALUES = {
MedianAbsoluteDeviationAggregator: {
1: {
None: {
"median_values": np.array([4.5, 9.0, 13.5, 18.0, 22.5, 27.0, 31.5, 36.0, 40.5]),
"mad_values": np.array([2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5]),
},
(0,): {
"median_values": np.array([4.5, 9.0, 13.5, 18.0, 22.5, 27.0, 31.5, 36.0, 40.5]),
"mad_values": np.array([2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5]),
},
2: {
(0, 1): {
"median_values": np.array(18.0),
"mad_values": np.array(12.0),
},
},
percentileAggregator: {
1: {
None: {
5: np.array([0.4, 0.8, 1.2, 1.6, 2.0, 2.4, 2.8, 3.2, 3.6]),
10: np.array([0.8, 1.6, 2.4, 3.2, 4.0, 4.8, 5.6, 6.4, 7.2]),
90: np.array([7.2, 14.4, 21.6, 28.8, 36.0, 43.2, 50.4, 57.6, 64.8]),
95: np.array([7.6, 15.2, 22.8, 30.4, 38.0, 45.6, 53.2, 60.8, 68.4]),
},
(0,): {
5: np.array([0.4, 0.8, 1.2, 1.6, 2.0, 2.4, 2.8, 3.2, 3.6]),
10: np.array([0.8, 1.6, 2.4, 3.2, 4.0, 4.8, 5.6, 6.4, 7.2]),
90: np.array([7.2, 14.4, 21.6, 28.8, 36.0, 43.2, 50.4, 57.6, 64.8]),
95: np.array([7.6, 15.2, 22.8, 30.4, 38.0, 45.6, 53.2, 60.8, 68.4]),
},
2: {
(0, 1): {
5: np.array(0.0),
10: np.array(0.0),
90: np.array(48.0),
Expand All @@ -352,7 +376,7 @@ def test_mad_percentile_aggregators(self, aggregator_cls, tensor_processor, aggr
aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i, Dtype.FLOAT))

ret_val = aggregator.aggregate()
ref_values = self.REF_MAD_PERCENTILE_REF_VALUES[aggregator.__class__][len(aggregation_axes or (0,))]
ref_values = self.REF_MAD_PERCENTILE_REF_VALUES[aggregator.__class__][aggregation_axes]
assert len(ret_val) == len(ref_values)
for k, v in ref_values.items():
assert self.all_close(ret_val[k], self.cast_tensor(v, Dtype.FLOAT))
Expand Down
5 changes: 4 additions & 1 deletion tests/common/test_statistics_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def inplace_statistics(self) -> bool:
@abstractmethod
@pytest.fixture
def is_backend_support_custom_estimators(self) -> bool:
pass
"""
False if backend can initialize only following tensor collectors:
MinMax, MeanMinMax.
"""

@abstractmethod
def reducers_map(self) -> List[TensorReducerBase]:
Expand Down
6 changes: 3 additions & 3 deletions tests/post_training/test_templates/test_channel_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,16 +480,16 @@ class MockBackend(backend_cls):
@pytest.mark.parametrize("inplace_ref", [False, True])
@pytest.mark.parametrize("q_ref", [1e-4, 0.3])
def test_statistic_collectors(self, inplace_ref, q_ref):
reduction_shape_ref = (0, 2, 3)
reduction_axes_ref = (0, 2, 3)
num_samples_ref = 123
statistic_collector: TensorCollector = self.get_backend_cls().get_statistic_collector(
reduction_shape=reduction_shape_ref, q=q_ref, num_samples=num_samples_ref, inplace=inplace_ref
reduction_shape=reduction_axes_ref, q=q_ref, num_samples=num_samples_ref, inplace=inplace_ref
)

assert len(statistic_collector.reducers) == 1
reducer = statistic_collector.reducers.pop()
assert isinstance(reducer, QuantileReducer)
assert reducer._reduction_axes == reduction_shape_ref
assert reducer._reduction_axes == reduction_axes_ref
assert np.allclose(reducer._quantile, (q_ref, 1 - q_ref))

assert len(statistic_collector.aggregators) == 2
Expand Down
Loading

0 comments on commit cdfffa0

Please sign in to comment.