diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 549ec83ee3..920619593c 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -43,6 +43,7 @@ from deepmd.utils.graph import ( get_pattern_nodes_from_graph_def, get_tensor_by_name_from_graph, + get_extra_embedding_net_variables_from_graph_def, ) from deepmd.utils.network import ( embedding_net, @@ -1327,6 +1328,12 @@ def init_variables( self.dstd = new_dstd if self.original_sel is None: self.original_sel = sel + if self.stripped_type_embedding: + if self.type_one_side: + extra_suffix = "_one_side_ebd" + else: + extra_suffix = "_two_side_ebd" + self.extra_embedding_net_variables = get_extra_embedding_net_variables_from_graph_def(graph_def, suffix, extra_suffix, self.layer_size) @property def explicit_ntypes(self) -> bool: diff --git a/deepmd/descriptor/se_a_ebd_v2.py b/deepmd/descriptor/se_a_ebd_v2.py index 1425c94b37..c6e3cebc71 100644 --- a/deepmd/descriptor/se_a_ebd_v2.py +++ b/deepmd/descriptor/se_a_ebd_v2.py @@ -12,21 +12,6 @@ from .descriptor import ( Descriptor, ) -from deepmd.env import ( - tf, -) -import numpy as np -from deepmd.utils.compress import ( - get_extra_side_embedding_net_variable, - get_two_side_type_embedding, - make_data, -) -from deepmd.utils.graph import ( - get_attention_layer_variables_from_graph_def, - get_pattern_nodes_from_graph_def, - get_tensor_by_name_from_graph, - get_tensor_by_type, -) from .se_a import ( DescrptSeA, ) @@ -83,37 +68,3 @@ def __init__( stripped_type_embedding=True, **kwargs, ) - def init_variables( - self, - graph: tf.Graph, - graph_def: tf.GraphDef, - suffix: str = "", - ) -> None: - super().init_variables(graph=graph, graph_def=graph_def, suffix=suffix) - self.extra_embedding_net_variables = {} - if self.type_one_side: - extra_suffix = "_one_side_ebd" - else: - extra_suffix = "_two_side_ebd" - for i in range(1, self.layer_size + 1): - matrix_pattern = f"filter_type_all{suffix}/matrix_{i}{extra_suffix}" - 
self.extra_embedding_net_variables[ - matrix_pattern - ] = self._get_two_embed_variables(graph_def, matrix_pattern) - bias_pattern = f"filter_type_all{suffix}/bias_{i}{extra_suffix}" - self.extra_embedding_net_variables[ - bias_pattern - ] = self._get_two_embed_variables(graph_def, bias_pattern) - - def _get_two_embed_variables(self, graph_def, pattern: str): - node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern] - dtype = tf.as_dtype(node.dtype).as_numpy_dtype - tensor_shape = tf.TensorShape(node.tensor_shape).as_list() - if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor_value = np.frombuffer( - node.tensor_content, - dtype=tf.as_dtype(node.dtype).as_numpy_dtype, - ) - else: - tensor_value = get_tensor_by_type(node, dtype) - return np.reshape(tensor_value, tensor_shape) diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index 8e4c3c3ef6..4cef308ff1 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -44,7 +44,7 @@ get_attention_layer_variables_from_graph_def, get_pattern_nodes_from_graph_def, get_tensor_by_name_from_graph, - get_tensor_by_type, + get_extra_embedding_net_variables_from_graph_def, ) from deepmd.utils.network import ( embedding_net, @@ -1292,18 +1292,6 @@ def init_variables( """ super().init_variables(graph=graph, graph_def=graph_def, suffix=suffix) - if self.stripped_type_embedding: - self.two_side_embeeding_net_variables = {} - for i in range(1, self.layer_size + 1): - matrix_pattern = f"filter_type_all{suffix}/matrix_{i}_two_side_ebd" - self.two_side_embeeding_net_variables[ - matrix_pattern - ] = self._get_two_embed_variables(graph_def, matrix_pattern) - bias_pattern = f"filter_type_all{suffix}/bias_{i}_two_side_ebd" - self.two_side_embeeding_net_variables[ - bias_pattern - ] = self._get_two_embed_variables(graph_def, bias_pattern) - self.attention_layer_variables = get_attention_layer_variables_from_graph_def( graph_def, suffix=suffix ) @@ -1322,18 +1310,8 @@ def 
init_variables( f"attention_layer_{i}{suffix}/layer_normalization_{i}/gamma" ] - def _get_two_embed_variables(self, graph_def, pattern: str): - node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern] - dtype = tf.as_dtype(node.dtype).as_numpy_dtype - tensor_shape = tf.TensorShape(node.tensor_shape).as_list() - if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor_value = np.frombuffer( - node.tensor_content, - dtype=tf.as_dtype(node.dtype).as_numpy_dtype, - ) - else: - tensor_value = get_tensor_by_type(node, dtype) - return np.reshape(tensor_value, tensor_shape) + if self.stripped_type_embedding: + self.two_side_embeeding_net_variables = get_extra_embedding_net_variables_from_graph_def(graph_def, suffix, "_two_side_ebd", self.layer_size) def build_type_exclude_mask( self, diff --git a/deepmd/utils/graph.py b/deepmd/utils/graph.py index 2a795a45a2..20fb7b16c5 100644 --- a/deepmd/utils/graph.py +++ b/deepmd/utils/graph.py @@ -236,6 +236,66 @@ def get_embedding_net_variables_from_graph_def( embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape) return embedding_net_variables +def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern: str): + """Get variables from the given tf.GraphDef object, with numpy array returns. 
+ + Parameters + ---------- + graph_def + The input tf.GraphDef object + pattern : str + The name (pattern) of the variable + + Returns + ------- + np.ndarray + The numpy array of the variable + """ + node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern] + dtype = tf.as_dtype(node.dtype).as_numpy_dtype + tensor_shape = tf.TensorShape(node.tensor_shape).as_list() + if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): + tensor_value = np.frombuffer( + node.tensor_content, + dtype=tf.as_dtype(node.dtype).as_numpy_dtype, + ) + else: + tensor_value = get_tensor_by_type(node, dtype) + return np.reshape(tensor_value, tensor_shape) + +def get_extra_embedding_net_variables_from_graph_def(graph_def: tf.GraphDef, suffix: str, extra_suffix: str, layer_size: int): + """Get extra embedding net variables from the given tf.GraphDef object. + The "extra embedding net" means the embedding net with only type embeddings input, + which occurs in "se_atten_v2" and "se_a_ebd_v2" descriptors. + + Parameters + ---------- + graph_def + The input tf.GraphDef object + suffix : str + The "common" suffix in the descriptor + extra_suffix: str + This value depends on the value of "type_one_side". 
+ It should always be "_one_side_ebd" or "_two_side_ebd" + layer_size: int + The layer size of the embedding net + + Returns + ------- + Dict + The extra embedding net variables within the given tf.GraphDef object + """ + extra_embedding_net_variables = {} + for i in range(1, layer_size + 1): + matrix_pattern = f"filter_type_all{suffix}/matrix_{i}{extra_suffix}" + extra_embedding_net_variables[ + matrix_pattern + ] = get_variables_from_graph_def_as_numpy_array(graph_def, matrix_pattern) + bias_pattern = f"filter_type_all{suffix}/bias_{i}{extra_suffix}" + extra_embedding_net_variables[ + bias_pattern + ] = get_variables_from_graph_def_as_numpy_array(graph_def, bias_pattern) + return extra_embedding_net_variables def get_embedding_net_variables(model_file: str, suffix: str = "") -> Dict: """Get the embedding net variables with the given frozen model(model_file).