Skip to content

Commit

Permalink
code refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
nahso committed Nov 22, 2023
1 parent 4ccc78b commit cc9c470
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 74 deletions.
7 changes: 7 additions & 0 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from deepmd.utils.graph import (
get_pattern_nodes_from_graph_def,
get_tensor_by_name_from_graph,
get_extra_embedding_net_variables_from_graph_def,
)
from deepmd.utils.network import (
embedding_net,
Expand Down Expand Up @@ -1327,6 +1328,12 @@ def init_variables(
self.dstd = new_dstd
if self.original_sel is None:
self.original_sel = sel
if self.stripped_type_embedding:
if self.type_one_side:
extra_suffix = "_one_side_ebd"

Check warning on line 1333 in deepmd/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se_a.py#L1333

Added line #L1333 was not covered by tests
else:
extra_suffix = "_two_side_ebd"
self.extra_embedding_net_variables = get_extra_embedding_net_variables_from_graph_def(graph_def, suffix, extra_suffix, self.layer_size)

@property
def explicit_ntypes(self) -> bool:
Expand Down
49 changes: 0 additions & 49 deletions deepmd/descriptor/se_a_ebd_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,6 @@
from .descriptor import (
Descriptor,
)
from deepmd.env import (
tf,
)
import numpy as np
from deepmd.utils.compress import (
get_extra_side_embedding_net_variable,
get_two_side_type_embedding,
make_data,
)
from deepmd.utils.graph import (
get_attention_layer_variables_from_graph_def,
get_pattern_nodes_from_graph_def,
get_tensor_by_name_from_graph,
get_tensor_by_type,
)
from .se_a import (
DescrptSeA,
)
Expand Down Expand Up @@ -83,37 +68,3 @@ def __init__(
stripped_type_embedding=True,
**kwargs,
)
def init_variables(
    self,
    graph: tf.Graph,
    graph_def: tf.GraphDef,
    suffix: str = "",
) -> None:
    """Restore the descriptor variables from a frozen graph.

    Beyond the base-class restore, this also collects the variables of the
    extra embedding net (type-embedding-only input) from the graph def.

    Parameters
    ----------
    graph : tf.Graph
        The input frozen model graph
    graph_def : tf.GraphDef
        The input frozen model graph_def
    suffix : str, optional
        The suffix of the scope
    """
    super().init_variables(graph=graph, graph_def=graph_def, suffix=suffix)
    # Node-name suffix depends on whether the embedding is one- or two-sided.
    extra_suffix = "_one_side_ebd" if self.type_one_side else "_two_side_ebd"
    self.extra_embedding_net_variables = {}
    for layer in range(1, self.layer_size + 1):
        # Each embedding-net layer stores a weight matrix and a bias vector.
        for kind in ("matrix", "bias"):
            pattern = f"filter_type_all{suffix}/{kind}_{layer}{extra_suffix}"
            self.extra_embedding_net_variables[
                pattern
            ] = self._get_two_embed_variables(graph_def, pattern)

def _get_two_embed_variables(self, graph_def, pattern: str):
    """Extract one frozen variable from *graph_def* as a numpy array.

    Parameters
    ----------
    graph_def : tf.GraphDef
        The graph definition holding the frozen variables.
    pattern : str
        The exact node name of the variable to extract.

    Returns
    -------
    np.ndarray
        The variable's value, reshaped to its stored tensor shape.
    """
    node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern]
    dtype = tf.as_dtype(node.dtype).as_numpy_dtype
    tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
    if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
        # Multi-element tensors serialize raw bytes in tensor_content;
        # reuse the numpy dtype computed above instead of re-deriving it.
        tensor_value = np.frombuffer(node.tensor_content, dtype=dtype)
    else:
        # Shape-[1] tensors store their value in a typed proto field.
        tensor_value = get_tensor_by_type(node, dtype)
    return np.reshape(tensor_value, tensor_shape)
28 changes: 3 additions & 25 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
get_attention_layer_variables_from_graph_def,
get_pattern_nodes_from_graph_def,
get_tensor_by_name_from_graph,
get_tensor_by_type,
get_extra_embedding_net_variables_from_graph_def,
)
from deepmd.utils.network import (
embedding_net,
Expand Down Expand Up @@ -1292,18 +1292,6 @@ def init_variables(
"""
super().init_variables(graph=graph, graph_def=graph_def, suffix=suffix)

if self.stripped_type_embedding:
self.two_side_embeeding_net_variables = {}
for i in range(1, self.layer_size + 1):
matrix_pattern = f"filter_type_all{suffix}/matrix_{i}_two_side_ebd"
self.two_side_embeeding_net_variables[
matrix_pattern
] = self._get_two_embed_variables(graph_def, matrix_pattern)
bias_pattern = f"filter_type_all{suffix}/bias_{i}_two_side_ebd"
self.two_side_embeeding_net_variables[
bias_pattern
] = self._get_two_embed_variables(graph_def, bias_pattern)

self.attention_layer_variables = get_attention_layer_variables_from_graph_def(
graph_def, suffix=suffix
)
Expand All @@ -1322,18 +1310,8 @@ def init_variables(
f"attention_layer_{i}{suffix}/layer_normalization_{i}/gamma"
]

def _get_two_embed_variables(self, graph_def, pattern: str):
    """Extract one frozen variable from *graph_def* as a numpy array.

    Parameters
    ----------
    graph_def : tf.GraphDef
        The graph definition holding the frozen variables.
    pattern : str
        The exact node name of the variable to extract.

    Returns
    -------
    np.ndarray
        The variable's value, reshaped to its stored tensor shape.
    """
    node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern]
    dtype = tf.as_dtype(node.dtype).as_numpy_dtype
    tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
    if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
        # Multi-element tensors serialize raw bytes in tensor_content;
        # reuse the numpy dtype computed above instead of re-deriving it.
        tensor_value = np.frombuffer(node.tensor_content, dtype=dtype)
    else:
        # Shape-[1] tensors store their value in a typed proto field.
        tensor_value = get_tensor_by_type(node, dtype)
    return np.reshape(tensor_value, tensor_shape)
if self.stripped_type_embedding:
self.two_side_embeeding_net_variables = get_extra_embedding_net_variables_from_graph_def(graph_def, suffix, "_two_side_ebd", self.layer_size)

def build_type_exclude_mask(
self,
Expand Down
60 changes: 60 additions & 0 deletions deepmd/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,66 @@ def get_embedding_net_variables_from_graph_def(
embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape)
return embedding_net_variables

def get_variables_from_graph_def_as_numpy_array(
    graph_def: tf.GraphDef, pattern: str
) -> np.ndarray:
    """Get one variable from the given tf.GraphDef object as a numpy array.

    Parameters
    ----------
    graph_def : tf.GraphDef
        The input tf.GraphDef object
    pattern : str
        The name of the variable node to extract

    Returns
    -------
    np.ndarray
        The numpy array of the variable
    """
    node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern]
    dtype = tf.as_dtype(node.dtype).as_numpy_dtype
    tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
    if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
        # Multi-element tensors serialize raw bytes in tensor_content;
        # reuse the numpy dtype computed above instead of re-deriving it.
        tensor_value = np.frombuffer(node.tensor_content, dtype=dtype)
    else:
        # Shape-[1] tensors store their value in a typed proto field.
        tensor_value = get_tensor_by_type(node, dtype)
    return np.reshape(tensor_value, tensor_shape)

def get_extra_embedding_net_variables_from_graph_def(
    graph_def: tf.GraphDef, suffix: str, extra_suffix: str, layer_size: int
) -> Dict:
    """Get extra embedding net variables from the given tf.GraphDef object.

    The "extra embedding net" means the embedding net with only type embeddings
    as input, which occurs in the "se_atten_v2" and "se_a_ebd_v2" descriptors.

    Parameters
    ----------
    graph_def
        The input tf.GraphDef object
    suffix : str
        The "common" suffix in the descriptor
    extra_suffix : str
        This value depends on the value of "type_one_side".
        It should always be "_one_side_ebd" or "_two_side_ebd"
    layer_size : int
        The layer size of the embedding net

    Returns
    -------
    Dict
        The extra embedding net variables within the given tf.GraphDef object
    """
    extra_embedding_net_variables = {}
    for i in range(1, layer_size + 1):
        # Each embedding-net layer stores a weight matrix and a bias vector.
        for prefix in ("matrix", "bias"):
            pattern = f"filter_type_all{suffix}/{prefix}_{i}{extra_suffix}"
            extra_embedding_net_variables[
                pattern
            ] = get_variables_from_graph_def_as_numpy_array(graph_def, pattern)
    return extra_embedding_net_variables

def get_embedding_net_variables(model_file: str, suffix: str = "") -> Dict:
"""Get the embedding net variables with the given frozen model(model_file).
Expand Down

0 comments on commit cc9c470

Please sign in to comment.