diff --git a/nncf/openvino/graph/layer_attributes.py b/nncf/openvino/graph/layer_attributes.py index 5609c3f34e0..49b2f64f479 100644 --- a/nncf/openvino/graph/layer_attributes.py +++ b/nncf/openvino/graph/layer_attributes.py @@ -80,10 +80,7 @@ def get_backend_agnostic_attributes(self): ] -OVConvLayout = List[OVLayoutElem] - - -def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout: +def get_conv_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]: """ Calculates weights layout for a target convolution node. @@ -97,7 +94,7 @@ def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout: ) -def get_linear_weights_layout_from_node(node: NNCFNode) -> OVConvLayout: +def get_linear_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]: """ Calculates weights layout for a target linear node. @@ -126,7 +123,7 @@ def _get_constant_port_id_from_layer_attributes(layer_attributes: OVLayerAttribu return port_ids[0] -def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> OVConvLayout: +def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> List[OVLayoutElem]: """ Calculates weights layout for a target convolution node. @@ -140,7 +137,7 @@ def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, return tuple(weights_layout) -def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> OVConvLayout: +def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> List[OVLayoutElem]: """ Calculates weights layout for a target linear node. diff --git a/nncf/openvino/graph/metatypes/openvino_metatypes.py b/nncf/openvino/graph/metatypes/openvino_metatypes.py index 4da63cda226..ab1f99a014d 100644 --- a/nncf/openvino/graph/metatypes/openvino_metatypes.py +++ b/nncf/openvino/graph/metatypes/openvino_metatypes.py @@ -20,6 +20,7 @@ from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry from nncf.common.graph.operator_metatypes import UnknownMetatype from nncf.common.hardware.opset import HWConfigOpName +from nncf.openvino.graph.layout import OVLayoutElem OV_OPERATOR_METATYPES = OperatorMetatypeRegistry("openvino_operator_metatypes") @@ -58,7 +59,8 @@ class OVConvolutionMetatype(OVOpMetatype): name = "ConvOp" op_names = ["Convolution"] hw_config_names = [HWConfigOpName.CONVOLUTION] - const_channel_axis = [0] # const layout: [C_OUT, C_IN, Z, Y, X] + const_channel_axis = [0] + const_layout = [OVLayoutElem.C_OUT, OVLayoutElem.C_IN] output_channel_axis = 1 @@ -67,7 +69,8 @@ class OVConvolutionBackpropDataMetatype(OVOpMetatype): name = "ConvBackpropDataOp" op_names = ["ConvolutionBackpropData"] hw_config_names = [HWConfigOpName.CONVOLUTION] - const_channel_axis = [1] # const layout: [C_IN, C_OUT, Z, Y, X] + const_channel_axis = [1] + const_layout = [OVLayoutElem.C_IN, OVLayoutElem.C_OUT] output_channel_axis = 1 @@ -76,7 +79,8 @@ class OVDepthwiseConvolutionMetatype(OVOpMetatype): name = "DepthwiseConvolutionOp" op_names = ["GroupConvolution"] hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION] - const_channel_axis = [0, 1] # const layout: [GROUPS, C_OUT / GROUPS, C_IN / GROUPS, Z, Y, X] + const_channel_axis = [0, 1] + const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_OUT, OVLayoutElem.C_IN] output_channel_axis = 1 @classmethod @@ -90,7 +94,8 @@ class OVGroupConvolutionMetatype(OVOpMetatype): op_names = ["GroupConvolution"] hw_config_names = [HWConfigOpName.CONVOLUTION] subtypes = [OVDepthwiseConvolutionMetatype] - const_channel_axis = [0, 1] # const layout: [GROUPS, C_OUT / GROUPS, C_IN / GROUPS, Z, Y, X] + const_channel_axis = [0, 1] + const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_OUT, OVLayoutElem.C_IN] output_channel_axis = 1 @@ -99,7 +104,8 @@ class OVGroupConvolutionBackpropDataMetatype(OVOpMetatype): name = "GroupConvolutionBackpropDataOp" op_names = ["GroupConvolutionBackpropData"] hw_config_names = [HWConfigOpName.CONVOLUTION] - const_channel_axis = [0, 2] # const layout: [GROUPS, C_IN / GROUPS, C_OUT / GROUPS, Z, Y, X] + const_channel_axis = [0, 2] + const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_IN, OVLayoutElem.C_OUT] output_channel_axis = 1