Skip to content

Commit

Permalink
Fix get_matmul_channel_axes
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 8, 2023
1 parent 0f5f147 commit 669f0a9
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 38 deletions.
38 changes: 16 additions & 22 deletions nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import LayoutElem
from nncf.common.graph.layer_attributes import LinearLayerAttributes
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS
from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_WEIGHTS
Expand Down Expand Up @@ -310,7 +312,7 @@ def get_reducer_output_node_names(
return [get_result_node_name(target_node_name, port_id)]


def get_weight_channel_axes(node: NNCFNode, weights_port_id: int) -> List[int]:
def get_weight_channel_axes(node: NNCFNode) -> List[int]:
"""
Returns axes numbers of the weight tensor which correspond to its channels.
Expand All @@ -321,35 +323,27 @@ def get_weight_channel_axes(node: NNCFNode, weights_port_id: int) -> List[int]:
if node.metatype not in OPERATIONS_WITH_WEIGHTS:
raise ValueError("Channel axis cannot be defined for operation without weights.")

channel_axes = node.metatype.const_channel_axis
if node.metatype == OVMatMulMetatype:
assert isinstance(node.layer_attributes, OVLayerAttributes)
assert len(channel_axes) == 1
const_attrs = node.layer_attributes.constant_attributes[weights_port_id]
transpose = const_attrs["transpose"]
ndims = len(const_attrs["shape"])
channel_axes = get_matmul_channel_axes(weights_port_id, ndims, transpose)
if node.metatype != OVMatMulMetatype:
return node.metatype.const_channel_axis

return channel_axes
return get_matmul_channel_axes(node)


def get_matmul_channel_axes(weights_port_id: int, ndims: int, transpose: bool) -> List[int]:
def get_matmul_channel_axes(node: ov.Node) -> List[int]:
"""
Calculate channel axes for the MatMul operation.
:param weights_port_id: Weight port id of the target node.
:param ndims: The number of MatMul dimensions.
:param transpose: Whether the transpose is applied to weights.
:param node: The target node.
:return: List of channel axes for the MatMul operation.
"""
matmul_channel_axis = OVMatMulMetatype.const_channel_axis[0]
if (weights_port_id == 1) == transpose:
matmul_channel_axis -= 1
matmul_channel_axis = max(ndims, 2) + matmul_channel_axis
channel_axes = list(range(ndims - 2))
if matmul_channel_axis < ndims:
channel_axes.append(matmul_channel_axis)
return channel_axes
assert isinstance(node.layer_attributes, OVLayerAttributes)
layer_attributes = node.layer_attributes.get_backend_agnostic_attributes()
assert isinstance(layer_attributes, LinearLayerAttributes)
return [
idx
for idx, elem in enumerate(layer_attributes.weights_layout)
if elem in [LayoutElem.SPATIAL, LayoutElem.C_OUT]
]


def get_channel_agnostic_reduction_shape(channel_axes: List[int], shape: List[int]) -> Tuple[int]:
Expand Down
10 changes: 5 additions & 5 deletions nncf/openvino/quantization/weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
import openvino.runtime as ov
from openvino.runtime import opset9 as opset

from nncf.common.graph.layer_attributes import LayoutElem
from nncf.common.graph.layer_attributes import LinearLayerAttributes
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVEmbeddingMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_operation_const_op
from nncf.openvino.graph.node_utils import get_const_value
from nncf.openvino.graph.node_utils import get_matmul_channel_axes
from nncf.quantization.fake_quantize import calculate_scale_zero_point


Expand Down Expand Up @@ -88,10 +89,9 @@ def _get_reduction_axes(metatype: Type[OperatorMetatype], node: ov.Node, weight_
:return: The reduction axes as an integer or a tuple of integers.
"""
if metatype is OVMatMulMetatype:
transpose = node.get_attributes()[f"transpose_{'a' if weight_port_id == 0 else 'b'}"]
ndims = node.input(weight_port_id).get_partial_shape().rank.get_max_length()
channel_axes = get_matmul_channel_axes(weight_port_id, ndims, transpose)
axes = tuple(i for i in range(ndims) if i not in channel_axes)
layer_attributes = node.layer_attributes.get_backend_agnostic_attributes()
assert isinstance(layer_attributes, LinearLayerAttributes)
axes = tuple(idx for idx, elem in enumerate(layer_attributes.weights_layout) if elem == LayoutElem.C_IN)
elif metatype is OVEmbeddingMetatype:
axes = (metatype.const_channel_axis[0] + 1) % 2
else:
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class AdvancedQuantizationParameters:
overflow_fix: OverflowFix = OverflowFix.FIRST_LAYER
quantize_outputs: bool = False
inplace_statistics: bool = True
disable_channel_alignment: bool = False
disable_channel_alignment: bool = True
disable_bias_correction: bool = False
smooth_quant_alpha: float = 0.95

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Tuple
from typing import Any, Tuple

import numpy as np
import openvino.runtime as ov

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.common.utils.backend import BackendType
from nncf.experimental.common.tensor_statistics.collectors import MedianAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes.openvino_metatypes import OVAddMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVDepthwiseConvolutionMetatype
Expand All @@ -39,7 +37,6 @@
from nncf.openvino.statistics.statistics import OVMinMaxTensorStatistic
from nncf.quantization.algorithms.channel_alignment.backend import ALGO_BACKENDS
from nncf.quantization.algorithms.channel_alignment.backend import ChannelAlignmentAlgoBackend
from nncf.quantization.algorithms.channel_alignment.backend import LayoutDescriptor


@ALGO_BACKENDS.register(BackendType.OPENVINO)
Expand Down Expand Up @@ -102,5 +99,5 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return bias_constant is not None

@staticmethod
def create_bias_tensor(node: NNCFNode, nncf_graph: NNCFGraph, value: Any)-> np.ndarray:
def create_bias_tensor(node: NNCFNode, nncf_graph: NNCFGraph, value: Any) -> np.ndarray:
return create_bias_tensor(node, nncf_graph, value)
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _get_reduction_shape_and_use_abs_max(
const_shape = node.layer_attributes.constant_attributes[target_point.port_id]["shape"]

if quantizer_config.per_channel:
channel_axes = get_weight_channel_axes(node, target_point.port_id)
channel_axes = get_weight_channel_axes(node)
axes = get_channel_agnostic_reduction_shape(channel_axes, const_shape)
else:
axes = tuple(range(len(const_shape)))
Expand Down
17 changes: 13 additions & 4 deletions tests/openvino/native/test_node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@

import numpy as np
import pytest
from openvino.runtime import opset9 as opset

from nncf.common.factory import NNCFGraphFactory
from nncf.common.graph.graph import NNCFNode
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.layer_attributes import get_weighted_layer_attributes
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.nncf_graph_builder import GraphConverter
from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_shape
Expand Down Expand Up @@ -60,10 +62,10 @@ def test_is_node_with_bias(model_to_create, is_with_bias, node_name):
@pytest.mark.parametrize(
"weights_port_id, transpose, shape, expected_channel_axes",
[
(0, False, (1,), [0]),
(0, False, (1,), []),
(0, True, (1,), []),
(1, False, (1,), []),
(1, True, (1,), [0]),
(1, True, (1,), []),
(0, False, (1, 1), [0]),
(0, True, (1, 1), [1]),
(1, False, (1, 1), [1]),
Expand All @@ -75,16 +77,23 @@ def test_is_node_with_bias(model_to_create, is_with_bias, node_name):
],
)
def test_get_weight_channel_axes_for_matmul(weights_port_id, transpose, shape, expected_channel_axes):
input_1 = opset.parameter([1, 1], name="Input", dtype=np.float32)
constant_1 = opset.constant(np.ones(shape).astype(np.float32))
inputs_ = (input_1, constant_1) if weights_port_id == 1 else (constant_1, input_1)
matmul_1 = opset.matmul(*inputs_, transpose_a=transpose, transpose_b=transpose, name="MatMul")

constant_attrs = {weights_port_id: {"transpose": transpose, "shape": shape}}
attributes = {
NNCFNode.ID_NODE_ATTR: 0,
NNCFNode.NODE_NAME_ATTR: "test",
NNCFNode.METATYPE_ATTR: OVMatMulMetatype,
NNCFNode.LAYER_ATTRIBUTES: OVLayerAttributes(
constant_attributes={weights_port_id: {"transpose": transpose, "shape": shape}}
layer_attributes=get_weighted_layer_attributes(matmul_1, OVMatMulMetatype, constant_attrs),
constant_attributes=constant_attrs,
),
}
node = NNCFNode(attributes)
actual_channel_axes = get_weight_channel_axes(node, weights_port_id)
actual_channel_axes = get_weight_channel_axes(node)

assert len(actual_channel_axes) == len(expected_channel_axes)
assert all(a == b for a, b in zip(actual_channel_axes, expected_channel_axes))
Expand Down

0 comments on commit 669f0a9

Please sign in to comment.