Skip to content

Commit

Permalink
Merge branch 'dl/conv_layer_attrs_update' into dl/quantization/passes…
Browse files Browse the repository at this point in the history
…_for_splitted_graphs
  • Loading branch information
daniil-lyakhov committed Nov 16, 2023
2 parents 6bc2f86 + 3fb5f9a commit 03e3faa
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
11 changes: 4 additions & 7 deletions nncf/openvino/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,7 @@ def get_backend_agnostic_attributes(self):
]


OVConvLayout = List[OVLayoutElem]


def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
def get_conv_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]:
"""
Calculates weights layout for a target convolution node.
Expand All @@ -97,7 +94,7 @@ def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
)


def get_linear_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
def get_linear_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]:
"""
Calculates weights layout for a target linear node.
Expand Down Expand Up @@ -126,7 +123,7 @@ def _get_constant_port_id_from_layer_attributes(layer_attributes: OVLayerAttribu
return port_ids[0]


def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> OVConvLayout:
def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> List[OVLayoutElem]:
"""
Calculates weights layout for a target convolution node.
Expand All @@ -140,7 +137,7 @@ def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int,
return tuple(weights_layout)


def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> OVConvLayout:
def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> List[OVLayoutElem]:
"""
Calculates weights layout for a target linear node.
Expand Down
16 changes: 11 additions & 5 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.common.hardware.opset import HWConfigOpName
from nncf.openvino.graph.layout import OVLayoutElem

OV_OPERATOR_METATYPES = OperatorMetatypeRegistry("openvino_operator_metatypes")

Expand Down Expand Up @@ -58,7 +59,8 @@ class OVConvolutionMetatype(OVOpMetatype):
name = "ConvOp"
op_names = ["Convolution"]
hw_config_names = [HWConfigOpName.CONVOLUTION]
const_channel_axis = [0] # const layout: [C_OUT, C_IN, Z, Y, X]
const_channel_axis = [0]
const_layout = [OVLayoutElem.C_OUT, OVLayoutElem.C_IN]
output_channel_axis = 1


Expand All @@ -67,7 +69,8 @@ class OVConvolutionBackpropDataMetatype(OVOpMetatype):
name = "ConvBackpropDataOp"
op_names = ["ConvolutionBackpropData"]
hw_config_names = [HWConfigOpName.CONVOLUTION]
const_channel_axis = [1] # const layout: [C_IN, C_OUT, Z, Y, X]
const_channel_axis = [1]
const_layout = [OVLayoutElem.C_IN, OVLayoutElem.C_OUT]
output_channel_axis = 1


Expand All @@ -76,7 +79,8 @@ class OVDepthwiseConvolutionMetatype(OVOpMetatype):
name = "DepthwiseConvolutionOp"
op_names = ["GroupConvolution"]
hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION]
const_channel_axis = [0, 1] # const layout: [GROUPS, C_OUT / GROUPS, C_IN / GROUPS, Z, Y, X]
const_channel_axis = [0, 1]
const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_OUT, OVLayoutElem.C_IN]
output_channel_axis = 1

@classmethod
Expand All @@ -90,7 +94,8 @@ class OVGroupConvolutionMetatype(OVOpMetatype):
op_names = ["GroupConvolution"]
hw_config_names = [HWConfigOpName.CONVOLUTION]
subtypes = [OVDepthwiseConvolutionMetatype]
const_channel_axis = [0, 1] # const layout: [GROUPS, C_OUT / GROUPS, C_IN / GROUPS, Z, Y, X]
const_channel_axis = [0, 1]
const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_OUT, OVLayoutElem.C_IN]
output_channel_axis = 1


Expand All @@ -99,7 +104,8 @@ class OVGroupConvolutionBackpropDataMetatype(OVOpMetatype):
name = "GroupConvolutionBackpropDataOp"
op_names = ["GroupConvolutionBackpropData"]
hw_config_names = [HWConfigOpName.CONVOLUTION]
const_channel_axis = [0, 2] # const layout: [GROUPS, C_IN / GROUPS, C_OUT / GROUPS, Z, Y, X]
const_channel_axis = [0, 2]
const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_IN, OVLayoutElem.C_OUT]
output_channel_axis = 1


Expand Down

0 comments on commit 03e3faa

Please sign in to comment.