From b37bd30391eea44ae081d5a6184be8b9a4919e91 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 16 Nov 2023 11:00:27 +0100 Subject: [PATCH] Constant filtering and shape of removal passes are updated --- nncf/common/graph/operator_metatypes.py | 1 + nncf/common/utils/dot_file_rw.py | 2 + nncf/onnx/graph/nncf_graph_builder.py | 5 +- .../graph/metatypes/openvino_metatypes.py | 56 ++++++++++++++++ .../algorithms/min_max/algorithm.py | 2 +- .../algorithms/min_max/backend.py | 5 -- .../algorithms/min_max/onnx_backend.py | 4 -- nncf/quantization/passes.py | 42 +++++++----- nncf/torch/graph/operator_metatypes.py | 22 +++++++ .../test_constant_filtering_model_after.dot | 12 ---- .../test_constant_filtering_model_after0.dot | 16 +++++ .../test_constant_filtering_model_after1.dot | 16 +++++ .../test_constant_filtering_model_before.dot | 18 ------ .../test_constant_filtering_model_before0.dot | 42 ++++++++++++ .../test_constant_filtering_model_before1.dot | 64 +++++++++++++++++++ tests/common/quantization/test_passes.py | 32 +++++++--- tests/post_training/test_templates/models.py | 60 +++++++++++++++-- 17 files changed, 326 insertions(+), 73 deletions(-) delete mode 100644 tests/common/data/reference_graphs/passes/test_constant_filtering_model_after.dot create mode 100644 tests/common/data/reference_graphs/passes/test_constant_filtering_model_after0.dot create mode 100644 tests/common/data/reference_graphs/passes/test_constant_filtering_model_after1.dot delete mode 100644 tests/common/data/reference_graphs/passes/test_constant_filtering_model_before.dot create mode 100644 tests/common/data/reference_graphs/passes/test_constant_filtering_model_before0.dot create mode 100644 tests/common/data/reference_graphs/passes/test_constant_filtering_model_before1.dot diff --git a/nncf/common/graph/operator_metatypes.py b/nncf/common/graph/operator_metatypes.py index fc9d89f5f27..49869cfbb2e 100644 --- a/nncf/common/graph/operator_metatypes.py +++ b/nncf/common/graph/operator_metatypes.py @@ -29,6 +29,7 @@ class OperatorMetatype: hw_config_names: List[str] = [] output_channel_axis: Optional[int] = None ignored_input_ports: List[int] = [] + input_edges_num_expected = None @classmethod def get_all_aliases(cls) -> List[str]: diff --git a/nncf/common/utils/dot_file_rw.py b/nncf/common/utils/dot_file_rw.py index f956b22c8ac..bf7b667c361 100644 --- a/nncf/common/utils/dot_file_rw.py +++ b/nncf/common/utils/dot_file_rw.py @@ -50,6 +50,8 @@ def read_dot_graph(path: pathlib.Path) -> nx.DiGraph: def _maybe_escape_colons_in_attrs(data: Dict): for attr_name in data: attr_val = data[attr_name] + if not isinstance(attr_val, str): + continue if RESERVED_CHAR in attr_val and not (attr_val[0] == '"' or attr_val[-1] == '"'): data[attr_name] = '"' + data[attr_name] + '"' # escaped colons are allowed diff --git a/nncf/onnx/graph/nncf_graph_builder.py b/nncf/onnx/graph/nncf_graph_builder.py index 177e895ef75..6c0c7291572 100644 --- a/nncf/onnx/graph/nncf_graph_builder.py +++ b/nncf/onnx/graph/nncf_graph_builder.py @@ -67,7 +67,7 @@ def __init__( self.weight_attrs = weight_attrs if weight_attrs is not None else {} self.bias_attrs = bias_attrs if bias_attrs is not None else {} self.node_attrs = node_attrs if node_attrs is not None else {} - self.layer_attributes = layer_attributes + self._layer_attributes = layer_attributes def has_weight(self) -> bool: return bool(self.weight_attrs) @@ -78,6 +78,9 @@ def has_bias(self) -> bool: def has_node_attrs(self) -> bool: return bool(self.node_attrs) + def get_backend_agnostic_attributes(self) -> BaseLayerAttributes: + return self._layer_attributes + def get_constant_weight_port_ids(metatype: ONNXOpMetatype) -> List[int]: """ diff --git a/nncf/openvino/graph/metatypes/openvino_metatypes.py b/nncf/openvino/graph/metatypes/openvino_metatypes.py index ab1f99a014d..1646665ee35 100644 --- a/nncf/openvino/graph/metatypes/openvino_metatypes.py +++ b/nncf/openvino/graph/metatypes/openvino_metatypes.py @@ -62,6 +62,7 @@ class OVConvolutionMetatype(OVOpMetatype): const_channel_axis = [0] const_layout = [OVLayoutElem.C_OUT, OVLayoutElem.C_IN] output_channel_axis = 1 + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -72,6 +73,7 @@ class OVConvolutionBackpropDataMetatype(OVOpMetatype): const_channel_axis = [1] const_layout = [OVLayoutElem.C_IN, OVLayoutElem.C_OUT] output_channel_axis = 1 + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -82,6 +84,7 @@ class OVDepthwiseConvolutionMetatype(OVOpMetatype): const_channel_axis = [0, 1] const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_OUT, OVLayoutElem.C_IN] output_channel_axis = 1 + input_edges_num_expected = 2 @classmethod def matches(cls, node: ov.Node) -> bool: @@ -97,6 +100,7 @@ class OVGroupConvolutionMetatype(OVOpMetatype): const_channel_axis = [0, 1] const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_OUT, OVLayoutElem.C_IN] output_channel_axis = 1 + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -107,6 +111,7 @@ class OVGroupConvolutionBackpropDataMetatype(OVOpMetatype): const_channel_axis = [0, 2] const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_IN, OVLayoutElem.C_OUT] output_channel_axis = 1 + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -118,6 +123,7 @@ class OVMatMulMetatype(OVOpMetatype): -1 ] # const layout: [B, ..., Y, X], where const is the second operand of matrix multiplication output_channel_axis = -1 + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -188,6 +194,7 @@ class OVAddMetatype(OVOpMetatype): name = "AddOp" op_names = ["Add"] hw_config_names = [HWConfigOpName.ADD] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -195,6 +202,7 @@ class OVSubtractMetatype(OVOpMetatype): name = "SubtractOp" op_names = ["Subtract"] hw_config_names = [HWConfigOpName.SUBTRACT] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -202,6 +210,7 @@ class OVMultiplyMetatype(OVOpMetatype): name = "MultiplyOp" op_names = ["Multiply"] hw_config_names = [HWConfigOpName.MULTIPLY] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -209,6 +218,7 @@ class OVDivideMetatype(OVOpMetatype): name = "DivideOp" op_names = ["Divide"] hw_config_names = [HWConfigOpName.DIVIDE] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -229,6 +239,7 @@ class OVConcatMetatype(OVOpMetatype): class OVBatchNormMetatype(OVOpMetatype): name = "BatchNormalizationOp" op_names = ["BatchNormInference"] + input_edges_num_expected = 5 @OV_OPERATOR_METATYPES.register() @@ -236,6 +247,7 @@ class OVInterpolateMetatype(OVOpMetatype): name = "InterpolateOp" op_names = ["Interpolate"] hw_config_names = [HWConfigOpName.INTERPOLATE] + input_edges_num_expected = 4 @OV_OPERATOR_METATYPES.register() @@ -256,6 +268,7 @@ class OVReshapeMetatype(OVOpMetatype): name = "ReshapeOp" op_names = ["Reshape"] hw_config_names = [HWConfigOpName.RESHAPE] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -275,12 +288,14 @@ class OVSplitMetatype(OVOpMetatype): name = "SplitOp" op_names = ["Split"] hw_config_names = [HWConfigOpName.SPLIT] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() class OVVariadicSplitMetatype(OVOpMetatype): name = "VariadicSplitOp" op_names = ["VariadicSplit"] + input_edges_num_expected = 3 @OV_OPERATOR_METATYPES.register() @@ -293,6 +308,7 @@ class OVShuffleChannelsMetatype(OVOpMetatype): class OVBroadcastMetatype(OVOpMetatype): name = "BroadcastOp" op_names = ["Broadcast"] + input_edges_num_expected = 3 @OV_OPERATOR_METATYPES.register() @@ -331,6 +347,7 @@ class OVLSTMSequenceMetatype(OVOpMetatype): op_names = ["LSTMSequence"] hw_config_names = [HWConfigOpName.LSTMSEQUENCE] const_channel_axis = [1] # const layout: [num_directions, 4 \* hidden_size, input_size] + input_edges_num_expected = 7 @OV_OPERATOR_METATYPES.register() @@ -339,6 +356,7 @@ class OVGRUSequenceMetatype(OVOpMetatype): op_names = ["GRUSequence"] hw_config_names = [HWConfigOpName.GRUSEQUENCE] const_channel_axis = [1] # const layout: [num_directions, 3 \* hidden_size, input_size] + input_edges_num_expected = 5 @OV_OPERATOR_METATYPES.register() @@ -352,6 +370,7 @@ class OVLessMetatype(OVOpMetatype): name = "LessOp" op_names = ["Less"] hw_config_names = [HWConfigOpName.LESS] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -359,6 +378,7 @@ class OVLessEqualMetatype(OVOpMetatype): name = "LessEqualOp" op_names = ["LessEqual"] hw_config_names = [HWConfigOpName.LESSEQUAL] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -366,6 +386,7 @@ class OVGreaterMetatype(OVOpMetatype): name = "GreaterOp" op_names = ["Greater"] hw_config_names = [HWConfigOpName.GREATER] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -373,6 +394,7 @@ class OVGreaterEqualMetatype(OVOpMetatype): name = "GreaterEqualOp" op_names = ["GreaterEqual"] hw_config_names = [HWConfigOpName.GREATEREQUAL] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -380,6 +402,7 @@ class OVEqualMetatype(OVOpMetatype): name = "EqualOp" op_names = ["Equal"] hw_config_names = [HWConfigOpName.EQUAL] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -387,6 +410,7 @@ class OVNotEqualMetatype(OVOpMetatype): name = "NotEqualOp" op_names = ["NotEqual"] hw_config_names = [HWConfigOpName.NOTEQUAL] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -394,6 +418,7 @@ class OVLogicalNotMetatype(OVOpMetatype): name = "LogicalNotOp" op_names = ["LogicalNot"] hw_config_names = [HWConfigOpName.LOGICALNOT] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -401,6 +426,7 @@ class OVLogicalAndMetatype(OVOpMetatype): name = "LogicalAndOp" op_names = ["LogicalAnd"] hw_config_names = [HWConfigOpName.LOGICALAND] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -408,6 +434,7 @@ class OVLogicalOrMetatype(OVOpMetatype): name = "LogicalOrOp" op_names = ["LogicalOr"] hw_config_names = [HWConfigOpName.LOGICALOR] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -415,6 +442,7 @@ class OVLogicalXorMetatype(OVOpMetatype): name = "LogicalXorOp" op_names = ["LogicalXor"] hw_config_names = [HWConfigOpName.LOGICALXOR] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -432,6 +460,7 @@ def matches(cls, node: ov.Node) -> bool: class OVFloorMetatype(OVOpMetatype): name = "FloorOp" op_names = ["Floor"] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -439,6 +468,7 @@ class OVFloorModMetatype(OVOpMetatype): name = "FloorModOp" op_names = ["FloorMod"] hw_config_names = [HWConfigOpName.FLOORMOD] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -446,6 +476,7 @@ class OVMaximumMetatype(OVOpMetatype): name = "MaximumOp" op_names = ["Maximum"] hw_config_names = [HWConfigOpName.MAXIMUM] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -454,6 +485,8 @@ class OVMinimumMetatype(OVOpMetatype): op_names = ["Minimum"] hw_config_names = [HWConfigOpName.MINIMUM] + input_edges_num_expected = 2 + @OV_OPERATOR_METATYPES.register() class OVSqrtMetatype(OVOpMetatype): @@ -467,18 +500,21 @@ class OVPowerMetatype(OVOpMetatype): name = "PowerOp" op_names = ["Power"] hw_config_names = [HWConfigOpName.POWER] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() class OVLogMetatype(OVOpMetatype): name = "LogOp" op_names = ["Log"] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() class OVRoiAlignMetatype(OVOpMetatype): name = "RoiAlignOp" op_names = ["ROIAlign"] + input_edges_num_expected = 3 @OV_OPERATOR_METATYPES.register() @@ -486,18 +522,21 @@ class OVGatherMetatype(OVOpMetatype): name = "GatherOp" op_names = ["Gather"] subtypes = [OVEmbeddingMetatype] + input_edges_num_expected = 3 @OV_OPERATOR_METATYPES.register() class OVGatherNDMetatype(OVOpMetatype): name = "GatherNDOp" op_names = ["GatherND"] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() class OVGatherElementsMetatype(OVOpMetatype): name = "GatherElementsOp" op_names = ["GatherElements"] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -518,12 +557,14 @@ class OVSqueezeMetatype(OVOpMetatype): class OVNonMaxSuppressionMetatype(OVOpMetatype): name = "NonMaxSuppressionOp" op_names = ["NonMaxSuppression"] + input_edges_num_expected = 5 @OV_OPERATOR_METATYPES.register() class OVReduceMinMetatype(OVOpMetatype): name = "ReduceMinOp" op_names = ["ReduceMin"] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -531,6 +572,7 @@ class OVReduceMaxMetatype(OVOpMetatype): name = "ReduceMaxOp" op_names = ["ReduceMax"] hw_config_names = [HWConfigOpName.REDUCEMAX] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -538,12 +580,14 @@ class OVReduceMeanMetatype(OVOpMetatype): name = "ReduceMeanOp" op_names = ["ReduceMean"] hw_config_names = [HWConfigOpName.REDUCEMEAN] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() class OVReduceL1Metatype(OVOpMetatype): name = "ReduceL1Op" op_names = ["ReduceL1"] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -551,12 +595,14 @@ class OVReduceL2Metatype(OVOpMetatype): name = "ReduceL2Op" op_names = ["ReduceL2"] hw_config_names = [HWConfigOpName.REDUCEL2] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() class OVTopKMetatype(OVOpMetatype): name = "TopKOp" op_names = ["TopK"] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -564,6 +610,7 @@ class OVStridedSliceMetatype(OVOpMetatype): name = "StridedSliceOp" op_names = ["StridedSlice"] hw_config_names = [HWConfigOpName.STRIDEDSLICE] + input_edges_num_expected = 4 @OV_OPERATOR_METATYPES.register() @@ -577,6 +624,7 @@ class OVTransposeMetatype(OVOpMetatype): name = "TransposeOp" op_names = ["Transpose"] hw_config_names = [HWConfigOpName.TRANSPOSE] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() @@ -584,24 +632,28 @@ class OVTileMetatype(OVOpMetatype): name = "TileOp" op_names = ["Tile"] hw_config_names = [HWConfigOpName.TILE] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() class OVScatterElementsUpdateMetatype(OVOpMetatype): name = "ScatterElementsUpdateOp" op_names = ["ScatterElementsUpdate"] + input_edges_num_expected = 4 @OV_OPERATOR_METATYPES.register() class OVScatterNDUpdateMetatype(OVOpMetatype): name = "ScatterNDUpdateOp" op_names = ["ScatterNDUpdate"] + input_edges_num_expected = 3 @OV_OPERATOR_METATYPES.register() class OVScatterUpdateMetatype(OVOpMetatype): name = "ScatterUpdateOp" op_names = ["ScatterUpdate"] + input_edges_num_expected = 4 @OV_OPERATOR_METATYPES.register() @@ -615,6 +667,7 @@ class OVPadMetatype(OVOpMetatype): name = "PadOp" op_names = ["Pad"] hw_config_names = [HWConfigOpName.PAD] + input_edges_num_expected = 4 @OV_OPERATOR_METATYPES.register() @@ -671,12 +724,14 @@ class OVClampMetatype(OVOpMetatype): class OVSquaredDifferenceMetatype(OVOpMetatype): name = "SquaredDifferenceOp" op_names = ["SquaredDifference"] + input_edges_num_expected = 2 @OV_OPERATOR_METATYPES.register() class OVDeformableConvolutionMetatype(OVOpMetatype): name = "DeformableConvolutionOp" op_names = ["DeformableConvolution"] + input_edges_num_expected = 3 @OV_OPERATOR_METATYPES.register() @@ -696,6 +751,7 @@ class OVGroupNormalizationMetatype(OVOpMetatype): name = "GroupNormalizationOp" op_names = ["GroupNormalization"] hw_config_names = [HWConfigOpName.GROUPNORMALIZATION] + input_edges_num_expected = 3 def get_operator_metatypes() -> List[Type[OperatorMetatype]]: diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 0191ea90423..e9fc13dfafd 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -539,8 +539,8 @@ def _get_quantization_target_points( deepcopy(nncf_graph), self._backend_entity.shapeof_metatypes, self._backend_entity.dropout_metatypes, - self._backend_entity.read_variable_metatypes, self._backend_entity.constant_metatypes, + self._backend_entity.read_variable_metatypes, ) quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns) diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py index 92c6a69576b..ad3f6d8acf9 100644 --- a/nncf/quantization/algorithms/min_max/backend.py +++ b/nncf/quantization/algorithms/min_max/backend.py @@ -81,11 +81,6 @@ def constant_metatypes(self) -> List[OperatorMetatype]: Property for the backend-specific metatypes that can be interpreted as constants. """ - @property - @abstractmethod - def nodes_with_weights_metatypes(self) -> List[OperatorMetatype]: - pass - @property @abstractmethod def overflow_fix_metatypes(self) -> List[OperatorMetatype]: diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index 879f3b78d68..ed9cae28b3f 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -84,10 +84,6 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]: def constant_metatypes(self) -> List[OperatorMetatype]: return [om.ONNXConstantMetatype, om.ONNXConstantOfShapeMetatype] - @property - def nodes_with_weights_metatypes(self) -> List[OperatorMetatype]: - return [om] - @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: return {om.ONNXConcatMetatype: self.overflow_fix_metatypes} diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index 055f1f27a5b..6af7bd6ca49 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -13,6 +13,7 @@ from typing import List, Optional, TypeVar from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.common.graph.operator_metatypes import OperatorMetatype TModel = TypeVar("TModel") @@ -22,8 +23,8 @@ def transform_to_inference_graph( nncf_graph: NNCFGraph, shapeof_metatypes: List[OperatorMetatype], dropout_metatypes: List[OperatorMetatype], + constant_metatypes: List[OperatorMetatype], read_variable_metatypes: Optional[List[OperatorMetatype]] = None, - nncf_graph_contains_constants: bool = True, ) -> NNCFGraph: """ This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows. @@ -36,10 +37,9 @@ def transform_to_inference_graph( :param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not. :return: NNCFGraph in the inference style. """ + filter_constant_nodes(nncf_graph, constant_metatypes) remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes) remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes) - if nncf_graph_contains_constants: - filter_constant_nodes(nncf_graph, read_variable_metatypes) return nncf_graph @@ -50,6 +50,7 @@ def remove_shapeof_subgraphs( ) -> NNCFGraph: """ Removes the ShapeOf subgraphs from the provided NNCFGraph instance inplace. + Constant subgraph should be already removed from the given NNCFGraph. :param nncf_graph: NNCFGraph instance for the transformation. :param shapeof_metatypes: List of backend-specific ShapeOf metatypes. @@ -58,12 +59,13 @@ def remove_shapeof_subgraphs( :return: NNCFGraph without ShapeOf subgraphs. """ read_variable_metatypes = read_variable_metatypes if read_variable_metatypes else [] + nodes_without_inputs = [node for node in nncf_graph.get_all_nodes() if not nncf_graph.get_all_edges(node)] nodes_to_drop = set() shape_of_nodes = [] infer_nodes = [] similar_inputs = nncf_graph.get_nodes_by_metatypes(read_variable_metatypes) - nodes_queue = collections.deque(nncf_graph.get_input_nodes() + similar_inputs) + nodes_queue = collections.deque(nncf_graph.get_input_nodes() + similar_inputs + nodes_without_inputs) while nodes_queue: node = nodes_queue.pop() if node.metatype in shapeof_metatypes: @@ -143,7 +145,8 @@ def remove_nodes_and_reconnect_graph( def filter_constant_nodes( - nncf_graph: NNCFGraph, read_variable_metatypes: Optional[List[OperatorMetatype]] = None + nncf_graph: NNCFGraph, + constant_metatypes: Optional[List[OperatorMetatype]] = None, ) -> NNCFGraph: """ Removes all Constant nodes from NNCFGraph inplace, making it inference graph. @@ -154,23 +157,32 @@ def filter_constant_nodes( that also can be interpreted as inputs (ReadValue). :return: NNCFGraph without Constant nodes. """ - read_variable_metatypes = read_variable_metatypes if read_variable_metatypes else [] - input_nodes = nncf_graph.get_input_nodes() - similar_input_nodes = nncf_graph.get_nodes_by_metatypes(read_variable_metatypes) - - start_nodes = input_nodes + similar_input_nodes - - if not start_nodes: - return nncf_graph + constant_metatypes = constant_metatypes if constant_metatypes else [] + constant_nodes = set(nncf_graph.get_nodes_by_metatypes(constant_metatypes)) visited_nodes = set() - nodes_queue = collections.deque(start_nodes) + nodes_queue = collections.deque(constant_nodes) while nodes_queue: node = nodes_queue.pop() if node in visited_nodes: continue + input_edges_num_expected = node.metatype.input_edges_num_expected + if node.layer_attributes is not None and isinstance( + node.layer_attributes.get_backend_agnostic_attributes(), MultipleInputLayerAttributes + ): + input_edges_num_expected = node.layer_attributes.get_backend_agnostic_attributes().num_inputs + + if input_edges_num_expected: + input_edges = nncf_graph.get_input_edges(node) + # Node has missed input edges thus considered to be an inference node + if len(input_edges) < input_edges_num_expected: + continue + # Node should have all input nodes marked as a constant node to be + # a constant node. + if any(edge.from_node not in constant_nodes for edge in input_edges): + continue + constant_nodes.add(node) visited_nodes.add(node) nodes_queue.extend(nncf_graph.get_next_nodes(node)) - constant_nodes = [node for node in nncf_graph.get_all_nodes() if node not in visited_nodes] nncf_graph.remove_nodes_from(constant_nodes) return nncf_graph diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index a05b8b7d85d..d554d2ec10c 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -164,6 +164,7 @@ class PTDepthwiseConv1dSubtype(PTDepthwiseConvOperatorSubtype): hw_config_name = [HWConfigOpName.DEPTHWISECONVOLUTION] module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv1d"]} output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -173,6 +174,7 @@ class PTModuleConv1dMetatype(PTModuleOperatorSubtype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv1d"]} subtypes = [PTDepthwiseConv1dSubtype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -182,6 +184,7 @@ class PTConv1dMetatype(PTOperatorMetatype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv1d"]} subtypes = [PTModuleConv1dMetatype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -190,6 +193,7 @@ class PTDepthwiseConv2dSubtype(PTDepthwiseConvOperatorSubtype): hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION] module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"]} output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -199,6 +203,7 @@ class PTModuleConv2dMetatype(PTModuleOperatorSubtype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"]} subtypes = [PTDepthwiseConv2dSubtype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -208,6 +213,7 @@ class PTConv2dMetatype(PTOperatorMetatype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv2d"]} subtypes = [PTModuleConv2dMetatype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -216,6 +222,7 @@ class PTDepthwiseConv3dSubtype(PTDepthwiseConvOperatorSubtype): hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION] module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv3d"]} output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -225,6 +232,7 @@ class PTModuleConv3dMetatype(PTModuleOperatorSubtype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv3d"]} subtypes = [PTDepthwiseConv3dSubtype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -234,6 +242,7 @@ class PTConv3dMetatype(PTOperatorMetatype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv3d"]} subtypes = [PTModuleConv3dMetatype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -242,6 +251,7 @@ class PTModuleConvTranspose1dMetatype(PTModuleOperatorSubtype): hw_config_names = [HWConfigOpName.CONVOLUTION] module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv_transpose1d"]} output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -251,6 +261,7 @@ class PTConvTranspose1dMetatype(PTOperatorMetatype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv_transpose1d"]} subtypes = [PTModuleConvTranspose1dMetatype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -259,6 +270,7 @@ class PTModuleConvTranspose2dMetatype(PTModuleOperatorSubtype): hw_config_names = [HWConfigOpName.CONVOLUTION] module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv_transpose2d"]} output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -268,6 +280,7 @@ class PTConvTranspose2dMetatype(PTOperatorMetatype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv_transpose2d"]} subtypes = [PTModuleConvTranspose2dMetatype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -276,6 +289,7 @@ class PTModuleConvTranspose3dMetatype(PTModuleOperatorSubtype): hw_config_names = [HWConfigOpName.CONVOLUTION] module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv_transpose3d"]} output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -285,12 +299,14 @@ class PTConvTranspose3dMetatype(PTOperatorMetatype): module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["conv_transpose3d"]} subtypes = [PTModuleConvTranspose3dMetatype] output_channel_axis = 1 + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() class PTModuleDeformConv2dMetatype(PTModuleOperatorSubtype): name = "DeformConv2dOp" module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["deform_conv2d"]} + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -420,6 +436,7 @@ class PTAddMetatype(PTOperatorMetatype): NamespaceTarget.TORCH: ["add"], } hw_config_names = [HWConfigOpName.ADD] + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -430,6 +447,7 @@ class PTSubMetatype(PTOperatorMetatype): NamespaceTarget.TORCH: ["sub"], } hw_config_names = [HWConfigOpName.SUBTRACT] + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -440,6 +458,7 @@ class PTMulMetatype(PTOperatorMetatype): NamespaceTarget.TORCH: ["mul"], } hw_config_names = [HWConfigOpName.MULTIPLY] + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -450,6 +469,7 @@ class PTDivMetatype(PTOperatorMetatype): NamespaceTarget.TORCH: ["div"], } hw_config_names = [HWConfigOpName.DIVIDE] + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -500,6 +520,7 @@ class PTMatMulMetatype(PTOperatorMetatype): NamespaceTarget.TORCH: ["matmul", "bmm", "mm"], } hw_config_names = [HWConfigOpName.MATMUL] + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() @@ -511,6 +532,7 @@ class PTBaddBmmMetatype(PTOperatorMetatype): # presuming that most runtime implementations will fuse the bias addition into the matrix multiplication # and therefore won't quantize the bias input, as this would break the hardware-fused pattern. ignored_input_ports: List[int] = [0] + input_edges_num_expected = 2 @PT_OPERATOR_METATYPES.register() diff --git a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after.dot b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after.dot deleted file mode 100644 index 3dc12b5ad16..00000000000 --- a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after.dot +++ /dev/null @@ -1,12 +0,0 @@ -strict digraph { -"0 /Input_1_0" [id=0, type=Input_1]; -"1 /ReadVariable_0" [id=1, type=ReadVariable]; -"4 /Conv_0" [id=4, type=Conv]; -"6 /Conv2_0" [id=6, type=Conv2]; -"7 /Add_0" [id=7, type=Add]; -"8 /Final_node_0" [id=8, type=Final_node]; -"0 /Input_1_0" -> "4 /Conv_0"; -"1 /ReadVariable_0" -> "7 /Add_0"; -"6 /Conv2_0" -> "7 /Add_0"; -"7 /Add_0" -> "8 /Final_node_0"; -} diff --git a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after0.dot b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after0.dot new file mode 100644 index 00000000000..98034ff015e --- /dev/null +++ b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after0.dot @@ -0,0 +1,16 @@ +strict digraph { +"8 /Concat_with_missed_input_0" [id=8, type=Concat_with_missed_input]; +"9 /Input_1_0" [id=9, type=Input_1]; +"10 /Input_2_0" [id=10, type=Input_2]; +"11 /Concat_with_input_0" [id=11, type=Concat_with_input]; +"12 /ReadVariable_0" [id=12, type=ReadVariable]; +"14 /Conv_0" [id=14, type=Conv]; +"16 /Conv2_0" [id=16, type=Conv2]; +"17 /Add_0" [id=17, type=Add]; +"18 /Final_node_0" [id=18, type=Final_node]; +"9 /Input_1_0" -> "14 /Conv_0"; +"10 /Input_2_0" -> "11 /Concat_with_input_0"; +"12 /ReadVariable_0" -> "17 /Add_0"; +"16 /Conv2_0" -> "17 /Add_0"; +"17 /Add_0" -> "18 /Final_node_0"; +} diff --git a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after1.dot b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after1.dot new file mode 100644 index 00000000000..87aab101c9e --- /dev/null +++ b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_after1.dot @@ -0,0 +1,16 @@ +strict digraph { +"15 /Concat_with_missed_input_0" [id=15, type=Concat_with_missed_input]; +"16 /Input_1_0" [id=16, type=Input_1]; +"17 /Input_2_0" [id=17, type=Input_2]; +"18 /Concat_with_input_0" [id=18, type=Concat_with_input]; +"19 /ReadVariable_0" [id=19, type=ReadVariable]; +"22 /Conv_0" [id=22, type=Conv]; +"25 /Conv2_0" [id=25, type=Conv2]; +"26 /Add_0" [id=26, type=Add]; +"27 /Final_node_0" [id=27, type=Final_node]; +"16 /Input_1_0" -> "22 /Conv_0"; +"17 /Input_2_0" -> "18 /Concat_with_input_0"; +"19 /ReadVariable_0" -> "26 /Add_0"; +"25 /Conv2_0" -> "26 /Add_0"; +"26 /Add_0" -> "27 /Final_node_0"; +} diff --git a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before.dot b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before.dot deleted file mode 100644 index b4590b212a3..00000000000 --- a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before.dot +++ /dev/null @@ -1,18 +0,0 @@ -strict digraph { -"0 /Input_1_0" [id=0, type=Input_1]; -"1 /ReadVariable_0" [id=1, type=ReadVariable]; -"2 /Weights_0" [id=2, type=Weights]; -"3 /AnyNodeBetweenWeightAndConv_0" [id=3, type=AnyNodeBetweenWeightAndConv]; -"4 /Conv_0" [id=4, type=Conv]; -"5 /Weights2_0" [id=5, type=Weights2]; -"6 /Conv2_0" [id=6, type=Conv2]; -"7 /Add_0" [id=7, type=Add]; -"8 /Final_node_0" [id=8, type=Final_node]; -"0 /Input_1_0" -> "4 /Conv_0"; -"1 /ReadVariable_0" -> "7 /Add_0"; -"2 /Weights_0" -> "3 /AnyNodeBetweenWeightAndConv_0"; -"3 /AnyNodeBetweenWeightAndConv_0" -> "4 /Conv_0"; -"5 /Weights2_0" -> "6 /Conv2_0"; -"6 /Conv2_0" -> "7 /Add_0"; -"7 /Add_0" -> "8 /Final_node_0"; -} diff --git a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before0.dot b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before0.dot new file mode 100644 index 00000000000..ed66aae8628 --- /dev/null +++ b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before0.dot @@ -0,0 +1,42 @@ +strict digraph { +"0 /Const0_0" [id=0, type=Const0]; +"1 /Const1_0" [id=1, type=Const1]; +"2 /Const2_0" [id=2, type=Const2]; +"3 /Const3_0" [id=3, type=Const3]; +"4 /Const4_0" [id=4, type=Const4]; +"5 /Concat_with_constant_0" [id=5, type=Concat_with_constant]; +"6 /Const5_0" [id=6, type=Const5]; +"7 /Const6_0" [id=7, type=Const6]; +"8 /Concat_with_missed_input_0" [id=8, type=Concat_with_missed_input]; +"9 /Input_1_0" [id=9, type=Input_1]; +"10 /Input_2_0" [id=10, type=Input_2]; +"11 /Concat_with_input_0" [id=11, type=Concat_with_input]; +"12 /ReadVariable_0" [id=12, type=ReadVariable]; +"13 /Weights_0" [id=13, type=Weights]; +"14 /Conv_0" [id=14, type=Conv]; +"15 /Weights2_0" [id=15, type=Weights2]; +"16 /Conv2_0" [id=16, type=Conv2]; +"17 /Add_0" [id=17, type=Add]; +"18 /Final_node_0" [id=18, type=Final_node]; +"19 /Weights3_0" [id=19, type=Weights3]; +"20 /Weights4_0" [id=20, type=Weights4]; +"21 /Conv3_0" [id=21, type=Conv3]; +"22 /NodeAfterConstantConv_0" [id=22, type=NodeAfterConstantConv]; +"0 /Const0_0" -> "11 /Concat_with_input_0"; +"1 /Const1_0" -> "11 /Concat_with_input_0"; +"2 /Const2_0" -> "5 /Concat_with_constant_0"; +"3 /Const3_0" -> "5 /Concat_with_constant_0"; +"4 /Const4_0" -> "5 /Concat_with_constant_0"; +"6 /Const5_0" -> "8 /Concat_with_missed_input_0"; +"7 /Const6_0" -> "8 /Concat_with_missed_input_0"; +"9 /Input_1_0" -> "14 /Conv_0"; +"10 /Input_2_0" -> "11 /Concat_with_input_0"; +"12 /ReadVariable_0" -> "17 /Add_0"; +"13 /Weights_0" -> "14 /Conv_0"; +"15 /Weights2_0" -> "16 /Conv2_0"; +"16 /Conv2_0" -> "17 /Add_0"; +"17 /Add_0" -> "18 /Final_node_0"; +"19 /Weights3_0" -> "21 /Conv3_0"; +"20 /Weights4_0" -> "21 /Conv3_0"; +"21 /Conv3_0" -> "22 /NodeAfterConstantConv_0"; +} diff --git a/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before1.dot b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before1.dot new file mode 100644 index 00000000000..8b1942fe1c0 --- /dev/null +++ b/tests/common/data/reference_graphs/passes/test_constant_filtering_model_before1.dot @@ -0,0 +1,64 @@ +strict digraph { +"0 /Const0_0" [id=0, type=Const0]; +"1 /AnyAfterConst0_0" [id=1, type=AnyAfterConst0]; +"2 /Const1_0" [id=2, type=Const1]; +"3 /AnyAfterConst1_0" [id=3, type=AnyAfterConst1]; +"4 /Const2_0" [id=4, type=Const2]; +"5 /AnyAfterConst2_0" [id=5, type=AnyAfterConst2]; +"6 /Const3_0" [id=6, type=Const3]; +"7 /AnyAfterConst3_0" [id=7, type=AnyAfterConst3]; +"8 /Const4_0" [id=8, type=Const4]; +"9 /AnyAfterConst4_0" [id=9, type=AnyAfterConst4]; +"10 /Concat_with_constant_0" [id=10, type=Concat_with_constant]; +"11 /Const5_0" [id=11, type=Const5]; +"12 /AnyAfterConst5_0" [id=12, type=AnyAfterConst5]; +"13 /Const6_0" [id=13, type=Const6]; +"14 /AnyAfterConst6_0" [id=14, type=AnyAfterConst6]; +"15 /Concat_with_missed_input_0" [id=15, type=Concat_with_missed_input]; +"16 /Input_1_0" [id=16, type=Input_1]; +"17 /Input_2_0" [id=17, type=Input_2]; +"18 /Concat_with_input_0" [id=18, type=Concat_with_input]; +"19 /ReadVariable_0" [id=19, type=ReadVariable]; +"20 /Weights_0" [id=20, type=Weights]; +"21 /AnyAfterWeights_0" [id=21, type=AnyAfterWeights]; +"22 /Conv_0" [id=22, type=Conv]; +"23 /Weights2_0" [id=23, type=Weights2]; +"24 /AnyAfterWeights2_0" [id=24, type=AnyAfterWeights2]; +"25 /Conv2_0" [id=25, type=Conv2]; +"26 /Add_0" [id=26, type=Add]; +"27 /Final_node_0" [id=27, type=Final_node]; +"28 /Weights3_0" [id=28, type=Weights3]; +"29 /AnyAfterWeights3_0" [id=29, type=AnyAfterWeights3]; +"30 /Weights4_0" [id=30, type=Weights4]; +"31 /AnyAfterWeights4_0" [id=31, type=AnyAfterWeights4]; +"32 /Conv3_0" [id=32, type=Conv3]; +"33 /NodeAfterConstantConv_0" [id=33, type=NodeAfterConstantConv]; +"0 /Const0_0" -> "1 /AnyAfterConst0_0"; +"1 /AnyAfterConst0_0" -> "18 /Concat_with_input_0"; +"2 /Const1_0" -> "3 /AnyAfterConst1_0"; +"3 /AnyAfterConst1_0" -> "18 /Concat_with_input_0"; +"4 /Const2_0" -> "5 /AnyAfterConst2_0"; +"5 /AnyAfterConst2_0" -> "10 /Concat_with_constant_0"; +"6 /Const3_0" -> "7 /AnyAfterConst3_0"; +"7 /AnyAfterConst3_0" -> "10 /Concat_with_constant_0"; +"8 /Const4_0" -> "9 /AnyAfterConst4_0"; +"9 /AnyAfterConst4_0" -> "10 /Concat_with_constant_0"; +"11 /Const5_0" -> "12 /AnyAfterConst5_0"; +"12 /AnyAfterConst5_0" -> "15 /Concat_with_missed_input_0"; +"13 /Const6_0" -> "14 /AnyAfterConst6_0"; +"14 /AnyAfterConst6_0" -> "15 /Concat_with_missed_input_0"; +"16 /Input_1_0" -> "22 /Conv_0"; +"17 /Input_2_0" -> "18 /Concat_with_input_0"; +"19 /ReadVariable_0" -> "26 /Add_0"; +"20 /Weights_0" -> "21 /AnyAfterWeights_0"; +"21 /AnyAfterWeights_0" -> "22 /Conv_0"; +"23 /Weights2_0" -> "24 /AnyAfterWeights2_0"; +"24 /AnyAfterWeights2_0" -> "25 /Conv2_0"; +"25 /Conv2_0" -> "26 /Add_0"; +"26 /Add_0" -> "27 /Final_node_0"; +"28 /Weights3_0" -> "29 /AnyAfterWeights3_0"; +"29 /AnyAfterWeights3_0" -> "32 /Conv3_0"; +"30 /Weights4_0" -> "31 /AnyAfterWeights4_0"; +"31 /AnyAfterWeights4_0" -> "32 /Conv3_0"; +"32 /Conv3_0" -> "33 /NodeAfterConstantConv_0"; +} diff --git a/tests/common/quantization/test_passes.py b/tests/common/quantization/test_passes.py index 87d38d09805..df41b7f43ea 100644 --- a/tests/common/quantization/test_passes.py +++ b/tests/common/quantization/test_passes.py @@ -14,6 +14,8 @@ import pytest +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes +from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.quantization.passes import filter_constant_nodes from nncf.quantization.passes import remove_nodes_and_reconnect_graph from tests.post_training.test_templates.models import NNCFGraphDropoutRemovingCase @@ -57,17 +59,27 @@ def test_remove_nodes_and_reconnect_graph(mode: TestModes): _check_graphs(dot_reference_path_after, nncf_graph) -@pytest.mark.xfail -def test_filter_constant_nodes(): - dot_reference_path_before = Path("passes") / "test_constant_filtering_model_before.dot" - dot_reference_path_after = Path("passes") / "test_constant_filtering_model_after.dot" +@pytest.mark.parametrize("node_between_const_and_op", [False, True]) +def test_filter_constant_nodes(node_between_const_and_op): + dot_reference_path_before = ( + Path("passes") / f"test_constant_filtering_model_before{int(node_between_const_and_op)}.dot" + ) + dot_reference_path_after = ( + Path("passes") / f"test_constant_filtering_model_after{int(node_between_const_and_op)}.dot" + ) - constant_metatype = "CONSTANT_METATYPE" - read_variable_metatype = "READ_VARIABLE_METATYPE" + class ConstantMetatype(OperatorMetatype): + pass - nncf_graph = NNCFGraphToTestConstantFiltering(constant_metatype, read_variable_metatype).nncf_graph + class NodeWithWeightMetatype(OperatorMetatype): + input_edges_num_expected = 2 + + nncf_graph = NNCFGraphToTestConstantFiltering( + ConstantMetatype, + NodeWithWeightMetatype, + MultipleInputLayerAttributes(1, 3), + node_between_const_and_op, + ).nncf_graph _check_graphs(dot_reference_path_before, nncf_graph) - filter_constant_nodes( - nncf_graph, read_variable_metatypes=[read_variable_metatype], constant_nodes_metatypes=[constant_metatype] - ) + filter_constant_nodes(nncf_graph, constant_metatypes=[ConstantMetatype]) _check_graphs(dot_reference_path_after, nncf_graph) diff --git a/tests/post_training/test_templates/models.py b/tests/post_training/test_templates/models.py index be4e78a78b7..a41c05d739e 100644 --- a/tests/post_training/test_templates/models.py +++ b/tests/post_training/test_templates/models.py @@ -303,27 +303,73 @@ def __init__( class NNCFGraphToTestConstantFiltering: - def __init__(self, constant_metatype, read_variable_metatype, nncf_graph_cls=NNCFGraph) -> None: + def __init__( + self, + constant_metatype, + node_with_weights_metatype, + concat_layer_attr, + add_node_between_const_and_weight_node, + nncf_graph_cls=NNCFGraph, + ) -> None: nodes = [ NodeWithType("Input_1", InputNoopMetatype), - NodeWithType("Conv", None), + NodeWithType("Conv", node_with_weights_metatype), NodeWithType("Weights", constant_metatype), - NodeWithType("AnyNodeBetweenWeightAndConv", None), NodeWithType("Weights2", constant_metatype), - NodeWithType("Conv2", None), - NodeWithType("ReadVariable", read_variable_metatype), + NodeWithType("Conv2", node_with_weights_metatype), + NodeWithType("ReadVariable", None), NodeWithType("Add", None), + NodeWithType("Weights3", constant_metatype), + NodeWithType("Weights4", constant_metatype), + NodeWithType("Conv3", node_with_weights_metatype), + NodeWithType("NodeAfterConstantConv", None), NodeWithType("Final_node", None), + NodeWithType("Input_2", InputNoopMetatype), + NodeWithType("Const0", constant_metatype), + NodeWithType("Const1", constant_metatype), + NodeWithType("Concat_with_input", None, layer_attributes=concat_layer_attr), + NodeWithType("Const2", constant_metatype), + NodeWithType("Const3", constant_metatype), + NodeWithType("Const4", constant_metatype), + NodeWithType("Concat_with_constant", None, layer_attributes=concat_layer_attr), + NodeWithType("Const5", constant_metatype), + NodeWithType("Const6", constant_metatype), + NodeWithType("Concat_with_missed_input", None, layer_attributes=concat_layer_attr), ] edges = [ ("Input_1", "Conv"), - ("Weights", "AnyNodeBetweenWeightAndConv"), - ("AnyNodeBetweenWeightAndConv", "Conv"), + ("Weights", "Conv"), ("Weights2", "Conv2"), ("Conv2", "Add"), ("ReadVariable", "Add"), ("Add", "Final_node"), + ("Weights3", "Conv3"), + ("Weights4", "Conv3"), + ("Conv3", "NodeAfterConstantConv"), + ("Input_2", "Concat_with_input"), + ("Const0", "Concat_with_input"), + ("Const1", "Concat_with_input"), + ("Const2", "Concat_with_constant"), + ("Const3", "Concat_with_constant"), + ("Const4", "Concat_with_constant"), + ("Const5", "Concat_with_missed_input"), + ("Const6", "Concat_with_missed_input"), ] + if add_node_between_const_and_weight_node: + constant_nodes = [node for node in nodes if node.node_op_metatype is constant_metatype] + const_node_to_edge = {} + for node in constant_nodes: + for i, edge in enumerate(edges): + if node.node_name == edge[0]: + const_node_to_edge[node] = edge + break + del edges[i] + for node, edge in const_node_to_edge.items(): + any_after_node_name = f"AnyAfter{node.node_name}" + nodes.append(NodeWithType(any_after_node_name, None)) + edges.append((edge[0], any_after_node_name)) + edges.append((any_after_node_name, edge[1])) + original_mock_graph = create_mock_graph(nodes, edges) self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls)