Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 24, 2023
1 parent 2f5d4ac commit 5437420
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 150 deletions.
10 changes: 9 additions & 1 deletion nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.openvino.graph.model_utils import create_bias_constant_value
from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_shape
from nncf.openvino.graph.node_utils import get_weight_channel_axes
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.channel_alignment.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend
Expand Down Expand Up @@ -116,6 +116,10 @@ def filter_func(point: StatisticPoint) -> bool:
assert len(tensor_collectors) == 1
stat = tensor_collectors[0].get_statistics()
if stat.min_values is None or stat.max_values is None:
nncf_logger.debug(
f"Skipping channel alignment for pairs {conv_in.node_name}, {conv_out.node_name} "
"because statistics were not collected for this pair."
)
continue

conv_in_cont = ConvParamsContainer(conv_in, model, graph, self._backend_entity)
Expand All @@ -124,6 +128,10 @@ def filter_func(point: StatisticPoint) -> bool:
conv_in_cont.dims.conv_weight_out_channels_dim is None
or conv_out_cont.dims.conv_weight_out_channels_dim is None
):
nncf_logger.debug(
f"Skipping channel alignment for pairs {conv_in.node_name}, {conv_out.node_name} "
" because one of the node is 1D MatMul, 1D Matmuls are not supported by CA algortihm yet."
)
continue

amean = (stat.max_values + stat.min_values) * 0.5
Expand Down
11 changes: 0 additions & 11 deletions nncf/quantization/algorithms/channel_alignment/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,6 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
(bias is added to the output tensor of that operation), `False` otherwise.
"""

@staticmethod
@abstractmethod
def get_dims_descriptor(node: NNCFNode) -> LayoutDescriptor:
"""
Return weights layout descriptor of the given node if it is possible and None otherwise.
Only convolutional and linear nodes are supported.
:param node: NNCFNode to get layout descriptor from.
:return: Weights layout descriptor of the given node if it is possible and None otherwise.
"""

@staticmethod
@abstractmethod
def get_conv_layer_attributes(node: NNCFNode) -> Optional[ConvolutionLayerAttributes]:
Expand Down
41 changes: 0 additions & 41 deletions nncf/quantization/algorithms/channel_alignment/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,44 +99,3 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:

bias_constant = get_node_with_bias_value(add_node, nncf_graph)
return bias_constant is not None

@staticmethod
def get_dims_descriptor(node: NNCFNode):
if node.metatype == OVConvolutionMetatype:
return LayoutDescriptor(
conv_weight_out_channels_dim=0,
conv_weight_in_channels_dim=1,
bias_channels_dim=node.metatype.output_channel_axis,
)
if node.metatype in [OVGroupConvolutionMetatype, OVDepthwiseConvolutionMetatype]:
# Using groups dim as output channels dim for ChannelAlignment algorithm
# TODO(dlyakhov) support group convolutions with groups number not in [1, out_channels]
return LayoutDescriptor(
conv_weight_out_channels_dim=0,
conv_weight_in_channels_dim=2,
bias_channels_dim=node.metatype.output_channel_axis,
)
if node.metatype == OVMatMulMetatype:
if node.layer_attributes is None:
raise RuntimeError(f"Attempt to align matmul node {node.node_name} that have no any constant inputs")
layer_attributes: OVLayerAttributes = node.layer_attributes
key = layer_attributes.get_const_port_ids()
assert len(key) == 1
key = key[0]
const_attr = layer_attributes.constant_attributes[key]
a, b = list(range(len(const_attr["shape"])))[-2:]
assert key in [a, b]
if key == a:
out_ch_dim = a
in_ch_dim = b
else:
out_ch_dim = b
in_ch_dim = a
if const_attr.get("transpose", False):
out_ch_dim, in_ch_dim = in_ch_dim, out_ch_dim
return LayoutDescriptor(
conv_weight_in_channels_dim=in_ch_dim,
conv_weight_out_channels_dim=out_ch_dim,
bias_channels_dim=node.metatype.output_channel_axis,
)
raise RuntimeError(f"Could not retrieve dims description for node {node} with metatype {node.metatype}")
11 changes: 0 additions & 11 deletions nncf/quantization/algorithms/smooth_quant/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,6 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
:return: Channel axis number.
"""

@staticmethod
@abstractmethod
def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int:
"""
Returns axis number of the weight tensor which correspond to it channel.
:param node: NNCFNode instance.
:param port_id: Specified input port id.
:return: Channel axis number.
"""

@staticmethod
@abstractmethod
def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
Expand Down
56 changes: 1 addition & 55 deletions tests/openvino/native/quantization/test_channel_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,20 @@

from typing import Type

import pytest

from nncf.common.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes.openvino_metatypes import OVAddMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConstantMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVGroupConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.transformations.command_creation import OVCommandCreator
from nncf.openvino.graph.transformations.commands import OVBiasCorrectionCommand
from nncf.openvino.graph.transformations.commands import OVBiasInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor
from nncf.quantization.algorithms.channel_alignment.openvino_backend import OVChannelAlignmentAlgoBackend
from tests.post_training.test_templates.test_channel_alignment import TemplateTestChannelAlignment


def _get_nncf_node(metatype, layer_attrs):
return NNCFNode(
{
NNCFNode.ID_NODE_ATTR: 0,
NNCFNode.NODE_NAME_ATTR: "test",
NNCFNode.METATYPE_ATTR: metatype,
NNCFNode.LAYER_ATTRIBUTES: layer_attrs,
}
)


class TestOVChannelAlignment(TemplateTestChannelAlignment):
def get_backend_cls(self) -> Type[OVChannelAlignmentAlgoBackend]:
return OVChannelAlignmentAlgoBackend
Expand All @@ -50,7 +33,7 @@ def target_point(self, target_type: TargetType, target_node_name: str, port_id:
return OVTargetPoint(target_type, target_node_name, port_id)

def convert_conv_layer_attrs(self, layer_attributes):
return OVLayerAttributes({}, {1: layer_attributes})
return OVLayerAttributes({}, layer_attributes)

def get_conv_metatype(self):
return OVConvolutionMetatype
Expand All @@ -69,40 +52,3 @@ def get_transformation_commands(self):

def mock_command_creation_factory(self, mocker) -> None:
mocker.patch("nncf.common.factory.CommandCreatorFactory.create", return_value=OVCommandCreator)

@pytest.mark.parametrize("transpose", [False, True])
@pytest.mark.parametrize("shape", [[3, 4], [1, 2, 3, 4]])
@pytest.mark.parametrize("port_id", [-1, -2])
def test_get_dims_descriptor_matmul(self, transpose, shape, port_id):
_port_id = len(shape) + port_id
node = _get_nncf_node(OVMatMulMetatype, OVLayerAttributes({_port_id: {"transpose": transpose, "shape": shape}}))
dims_descr = OVChannelAlignmentAlgoBackend.get_dims_descriptor(node)

in_dims, out_dims = (0, 1) if port_id == -1 else (1, 0)
if len(shape) > 2:
in_dims += 2
out_dims += 2
if transpose:
in_dims, out_dims = out_dims, in_dims

assert dims_descr.conv_weight_in_channels_dim == in_dims
assert dims_descr.conv_weight_out_channels_dim == out_dims
assert dims_descr.bias_channels_dim == OVMatMulMetatype.output_channel_axis

def test_get_dims_descriptor_mm_no_layer_attrs(self):
node = _get_nncf_node(OVMatMulMetatype, None)
with pytest.raises(RuntimeError):
OVChannelAlignmentAlgoBackend.get_dims_descriptor(node)

@pytest.mark.parametrize(
"metatype,ref_desc",
[
(OVConvolutionMetatype, LayoutDescriptor(0, 1, 1)),
(OVGroupConvolutionMetatype, LayoutDescriptor(0, 2, 1)),
(OVGroupConvolutionMetatype, LayoutDescriptor(0, 2, 1)),
],
)
def test_get_dims_descriptor_convs(self, metatype, ref_desc):
node = _get_nncf_node(metatype, None)
dims_descr = OVChannelAlignmentAlgoBackend.get_dims_descriptor(node)
assert dims_descr.__dict__ == ref_desc.__dict__
5 changes: 4 additions & 1 deletion tests/post_training/test_templates/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def __init__(
conv_metatype,
add_metatype,
conv_layer_attrs=None,
conv_layer_attrs_1=None,
both_biases=True,
add_layer_attrs=None,
constant_metatype=ConstantTestMetatype,
Expand All @@ -187,14 +188,16 @@ def __init__(
# |
# Add_2
# Output_1
if conv_layer_attrs_1 is None:
conv_layer_attrs_1 = conv_layer_attrs
nodes = [
NodeWithType("Input_1", InputNoopMetatype),
NodeWithType("Conv_1_W", constant_metatype),
NodeWithType("Conv_1", conv_metatype, layer_attributes=conv_layer_attrs),
NodeWithType("Add_1_W", constant_metatype),
NodeWithType("Add_1", add_metatype, layer_attributes=add_layer_attrs),
NodeWithType("Conv_2_W", constant_metatype),
NodeWithType("Conv_2", conv_metatype, layer_attributes=conv_layer_attrs),
NodeWithType("Conv_2", conv_metatype, layer_attributes=conv_layer_attrs_1),
NodeWithType("Output_1", OutputNoopMetatype),
]
if both_biases:
Expand Down
Loading

0 comments on commit 5437420

Please sign in to comment.