diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py
index 1e35bd4d49..982a4eb834 100644
--- a/deepmd/dpmodel/common.py
+++ b/deepmd/dpmodel/common.py
@@ -6,6 +6,10 @@
 import numpy as np
 
+from deepmd.common import (
+    GLOBAL_NP_FLOAT_PRECISION,
+)
+
 PRECISION_DICT = {
     "float16": np.float16,
     "float32": np.float32,
@@ -15,6 +19,7 @@
     "double": np.float64,
     "int32": np.int32,
     "int64": np.int64,
+    "default": GLOBAL_NP_FLOAT_PRECISION,
 }
 
 DEFAULT_PRECISION = "float64"
diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py
index 1cbaf69c49..78ff83a056 100644
--- a/deepmd/dpmodel/descriptor/se_e2_a.py
+++ b/deepmd/dpmodel/descriptor/se_e2_a.py
@@ -15,6 +15,7 @@
 from deepmd.dpmodel import (
     DEFAULT_PRECISION,
+    PRECISION_DICT,
     NativeOP,
 )
 from deepmd.dpmodel.utils import (
@@ -133,6 +134,8 @@ def __init__(
         activation_function: str = "tanh",
         precision: str = DEFAULT_PRECISION,
         spin: Optional[Any] = None,
+        # consistent with argcheck, not used though
+        seed: Optional[int] = None,
     ) -> None:
         ## seed, uniform_seed, multi_task, not included.
         if not type_one_side:
@@ -163,6 +166,8 @@ def __init__(
             ndim=(1 if self.type_one_side else 2),
             network_type="embedding_network",
         )
+        if not self.type_one_side:
+            raise NotImplementedError("type_one_side == False not implemented")
         for ii in range(self.ntypes):
             self.embeddings[(ii,)] = EmbeddingNet(
                 in_dim,
@@ -316,7 +321,8 @@ def serialize(self) -> dict:
             "exclude_types": self.exclude_types,
             "set_davg_zero": self.set_davg_zero,
             "activation_function": self.activation_function,
-            "precision": self.precision,
+            # make deterministic
+            "precision": np.dtype(PRECISION_DICT[self.precision]).name,
             "spin": self.spin,
             "env_mat": self.env_mat.serialize(),
             "embeddings": self.embeddings.serialize(),
diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py
index 17b3043612..8c826c8771 100644
--- a/deepmd/dpmodel/utils/network.py
+++ b/deepmd/dpmodel/utils/network.py
@@ -192,7 +192,8 @@ def serialize(self) -> dict:
             "use_timestep": self.idt is not None,
             "activation_function": self.activation_function,
             "resnet": self.resnet,
-            "precision": self.precision,
+            # make deterministic
+            "precision": np.dtype(PRECISION_DICT[self.precision]).name,
             "@variables": data,
         }
 
@@ -464,7 +465,8 @@ def serialize(self) -> dict:
             "neuron": self.neuron.copy(),
             "activation_function": self.activation_function,
             "resnet_dt": self.resnet_dt,
-            "precision": self.precision,
+            # make deterministic
+            "precision": np.dtype(PRECISION_DICT[self.precision]).name,
             "layers": [layer.serialize() for layer in self.layers],
         }
diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py
index 700bf6d59b..c722c2dc02 100644
--- a/deepmd/pt/model/descriptor/se_a.py
+++ b/deepmd/pt/model/descriptor/se_a.py
@@ -19,6 +19,7 @@
 )
 from deepmd.pt.utils.env import (
     PRECISION_DICT,
+    RESERVED_PRECISON_DICT,
 )
 
 try:
@@ -207,7 +208,8 @@ def serialize(self) -> dict:
             "resnet_dt": obj.resnet_dt,
             "set_davg_zero": obj.set_davg_zero,
             "activation_function": obj.activation_function,
-            "precision": obj.precision,
+            # make deterministic
+            "precision": RESERVED_PRECISON_DICT[obj.prec],
             "embeddings": obj.filter_layers.serialize(),
             "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
             "@variables": {
@@ -223,6 +225,7 @@ def serialize(self) -> dict:
 
     @classmethod
     def deserialize(cls, data: dict) -> "DescrptSeA":
+        data = data.copy()
         variables = data.pop("@variables")
         embeddings = data.pop("embeddings")
         env_mat = data.pop("env_mat")
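A note on the "make deterministic" changes above: routing the stored string through `np.dtype(...).name` collapses precision aliases to one canonical spelling, so backends that accept different aliases still serialize identical data. A minimal standalone sketch (the alias table below is illustrative, not the full deepmd `PRECISION_DICT`):

    import numpy as np

    PRECISION_DICT = {
        "float64": np.float64,
        "double": np.float64,   # alias accepted in user input
        "default": np.float64,  # resolved from GLOBAL_NP_FLOAT_PRECISION
    }

    # Every alias serializes to the same canonical dtype name.
    for alias in ("float64", "double", "default"):
        assert np.dtype(PRECISION_DICT[alias]).name == "float64"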
diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py
index 81499b5063..7383cf5c49 100644
--- a/deepmd/pt/utils/env.py
+++ b/deepmd/pt/utils/env.py
@@ -42,6 +42,15 @@
     "int64": torch.int64,
 }
 GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
+PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION
+# cannot be generated automatically
+RESERVED_PRECISON_DICT = {
+    torch.float16: "float16",
+    torch.float32: "float32",
+    torch.float64: "float64",
+    torch.int32: "int32",
+    torch.int64: "int64",
+}
 DEFAULT_PRECISION = "float64"
 
 # throw warnings if threads not set
@@ -58,6 +67,7 @@
     "GLOBAL_PT_FLOAT_PRECISION",
     "DEFAULT_PRECISION",
     "PRECISION_DICT",
+    "RESERVED_PRECISON_DICT",
     "SAMPLER_RECORD",
     "NUM_WORKERS",
     "DEVICE",
diff --git a/deepmd/tf/descriptor/descriptor.py b/deepmd/tf/descriptor/descriptor.py
index fe49fe11fe..1a73d3c273 100644
--- a/deepmd/tf/descriptor/descriptor.py
+++ b/deepmd/tf/descriptor/descriptor.py
@@ -509,3 +509,41 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
         # call subprocess
         cls = cls.get_class_by_input(local_jdata)
         return cls.update_sel(global_jdata, local_jdata)
+
+    @classmethod
+    def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
+        """Deserialize the model.
+
+        There is no suffix in a native DP model, but it is important
+        for the TF backend.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized data
+        suffix : str, optional
+            Name suffix to identify this descriptor
+
+        Returns
+        -------
+        Descriptor
+            The deserialized descriptor
+        """
+        if cls is Descriptor:
+            return Descriptor.get_class_by_input(data).deserialize(data)
+        raise NotImplementedError("Not implemented in class %s" % cls.__name__)
+
+    def serialize(self, suffix: str = "") -> dict:
+        """Serialize the model.
+
+        There is no suffix in a native DP model, but it is important
+        for the TF backend.
+
+        Parameters
+        ----------
+        suffix : str, optional
+            Name suffix to identify this descriptor
+
+        Returns
+        -------
+        dict
+            The serialized data
+        """
+        raise NotImplementedError("Not implemented in class %s" % type(self).__name__)
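Why `RESERVED_PRECISON_DICT` "cannot be generated automatically": once aliases such as `"double"` and `"default"` are added, `PRECISION_DICT` is many-to-one, so a naive inversion would pick an arbitrary alias as the canonical name. A standalone sketch of the failure mode (assuming `torch` is available; the dict below is abridged):

    import torch

    PRECISION_DICT = {
        "float64": torch.float64,
        "double": torch.float64,
        "default": torch.float64,
    }
    # A naive inversion keeps whichever alias was inserted last:
    naive = {v: k for k, v in PRECISION_DICT.items()}
    assert naive[torch.float64] == "default"  # not the canonical name

    # Hence the hand-written reverse map with canonical names only:
    RESERVED_PRECISON_DICT = {torch.float64: "float64"}
    assert RESERVED_PRECISON_DICT[torch.float64] == "float64"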
diff --git a/deepmd/tf/descriptor/se.py b/deepmd/tf/descriptor/se.py
index 4f49a8800f..98d98cd467 100644
--- a/deepmd/tf/descriptor/se.py
+++ b/deepmd/tf/descriptor/se.py
@@ -1,9 +1,17 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import re
 from typing import (
+    List,
+    Set,
     Tuple,
 )
 
+from deepmd.dpmodel.utils.network import (
+    EmbeddingNet,
+    NetworkCollection,
+)
 from deepmd.tf.env import (
+    EMBEDDING_NET_PATTERN,
     tf,
 )
 from deepmd.tf.utils.graph import (
@@ -160,3 +168,166 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
         # default behavior is to update sel which is a list
         local_jdata_cpy = local_jdata.copy()
         return update_one_sel(global_jdata, local_jdata_cpy, False)
+
+    def serialize_network(
+        self,
+        ntypes: int,
+        ndim: int,
+        in_dim: int,
+        neuron: List[int],
+        activation_function: str,
+        resnet_dt: bool,
+        variables: dict,
+        excluded_types: Set[Tuple[int, int]] = set(),
+        suffix: str = "",
+    ) -> dict:
+        """Serialize network.
+
+        Parameters
+        ----------
+        ntypes : int
+            The number of types
+        ndim : int
+            The dimension of elements
+        in_dim : int
+            The input dimension
+        neuron : List[int]
+            The neuron list
+        activation_function : str
+            The activation function
+        resnet_dt : bool
+            Whether to use resnet
+        variables : dict
+            The input variables
+        excluded_types : Set[Tuple[int, int]], optional
+            The excluded types
+        suffix : str, optional
+            The suffix of the scope
+
+        Returns
+        -------
+        dict
+            The converted network data
+        """
+        embeddings = NetworkCollection(
+            ntypes=ntypes,
+            ndim=ndim,
+            network_type="embedding_network",
+        )
+        if ndim == 2:
+            for type_i, type_j in excluded_types:
+                # initialize an empty network for the excluded types
+                embeddings[(type_i, type_j)] = EmbeddingNet(
+                    in_dim=in_dim,
+                    neuron=neuron,
+                    activation_function=activation_function,
+                    resnet_dt=resnet_dt,
+                    precision=self.precision.name,
+                )
+                embeddings[(type_j, type_i)] = EmbeddingNet(
+                    in_dim=in_dim,
+                    neuron=neuron,
+                    activation_function=activation_function,
+                    resnet_dt=resnet_dt,
+                    precision=self.precision.name,
+                )
+                for layer in range(len(neuron)):
+                    embeddings[(type_i, type_j)][layer]["w"][:] = 0.0
+                    embeddings[(type_i, type_j)][layer]["b"][:] = 0.0
+                    if embeddings[(type_i, type_j)][layer]["idt"] is not None:
+                        embeddings[(type_i, type_j)][layer]["idt"][:] = 0.0
+                    embeddings[(type_j, type_i)][layer]["w"][:] = 0.0
+                    embeddings[(type_j, type_i)][layer]["b"][:] = 0.0
+                    if embeddings[(type_j, type_i)][layer]["idt"] is not None:
+                        embeddings[(type_j, type_i)][layer]["idt"][:] = 0.0
+
+        if suffix != "":
+            embedding_net_pattern = (
+                EMBEDDING_NET_PATTERN.replace("/(idt)", suffix + "/(idt)")
+                .replace("/(bias)", suffix + "/(bias)")
+                .replace("/(matrix)", suffix + "/(matrix)")
+            )
+        else:
+            embedding_net_pattern = EMBEDDING_NET_PATTERN
+        for key, value in variables.items():
+            m = re.search(embedding_net_pattern, key)
+            m = [mm for mm in m.groups() if mm is not None]
+            typei = m[0]
+            typej = "_".join(m[3:]) if len(m[3:]) else "all"
+            layer_idx = int(m[2]) - 1
+            weight_name = m[1]
+            if ndim == 0:
+                network_idx = ()
+            elif ndim == 1:
+                network_idx = (int(typej),)
+            elif ndim == 2:
+                network_idx = (int(typei), int(typej))
+            else:
+                raise ValueError(f"Invalid ndim: {ndim}")
+            if embeddings[network_idx] is None:
+                # initialize the network if it is not initialized
+                embeddings[network_idx] = EmbeddingNet(
+                    in_dim=in_dim,
+                    neuron=neuron,
+                    activation_function=activation_function,
+                    resnet_dt=resnet_dt,
+                    precision=self.precision.name,
+                )
+            assert embeddings[network_idx] is not None
+            if weight_name == "idt":
+                value = value.ravel()
+            embeddings[network_idx][layer_idx][weight_name] = value
+        return embeddings.serialize()
+
+    @classmethod
+    def deserialize_network(cls, data: dict, suffix: str = "") -> dict:
+        """Deserialize network.
+
+        Parameters
+        ----------
+        data : dict
+            The input network data
+        suffix : str, optional
+            The suffix of the scope
+
+        Returns
+        -------
+        variables : dict
+            The recovered network variables
+        """
+        embedding_net_variables = {}
+        embeddings = NetworkCollection.deserialize(data)
+        for ii in range(embeddings.ntypes**embeddings.ndim):
+            net_idx = []
+            rest_ii = ii
+            for _ in range(embeddings.ndim):
+                net_idx.append(rest_ii % embeddings.ntypes)
+                rest_ii //= embeddings.ntypes
+            net_idx = tuple(net_idx)
+            if embeddings.ndim in (0, 1):
+                key0 = "all"
+                key1 = f"_{ii}"
+            elif embeddings.ndim == 2:
+                key0 = f"{net_idx[0]}"
+                key1 = f"_{net_idx[1]}"
+            else:
+                raise ValueError(f"Invalid ndim: {embeddings.ndim}")
+            network = embeddings[net_idx]
+            assert network is not None
+            for layer_idx, layer in enumerate(network.layers):
+                embedding_net_variables[
+                    f"filter_type_{key0}{suffix}/matrix_{layer_idx + 1}{key1}"
+                ] = layer.w
+                embedding_net_variables[
+                    f"filter_type_{key0}{suffix}/bias_{layer_idx + 1}{key1}"
+                ] = layer.b
+                if layer.idt is not None:
+                    embedding_net_variables[
+                        f"filter_type_{key0}{suffix}/idt_{layer_idx + 1}{key1}"
+                    ] = layer.idt.reshape(1, -1)
+                else:
+                    # prevent KeyError
+                    embedding_net_variables[
+                        f"filter_type_{key0}{suffix}/idt_{layer_idx + 1}{key1}"
+                    ] = 0.0
+        return embedding_net_variables
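The loop in `deserialize_network` above decodes the flat counter `ii` into a per-network type index by treating `ii` as a base-`ntypes` number with `ndim` digits. A standalone worked example for `ntypes = 2, ndim = 2`:

    ntypes, ndim = 2, 2
    for ii in range(ntypes**ndim):
        net_idx = []
        rest_ii = ii
        for _ in range(ndim):
            net_idx.append(rest_ii % ntypes)
            rest_ii //= ntypes
        print(ii, tuple(net_idx))
    # 0 (0, 0)
    # 1 (1, 0)
    # 2 (0, 1)
    # 3 (1, 1)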
diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py
index 01c4ee8844..986328479b 100644
--- a/deepmd/tf/descriptor/se_a.py
+++ b/deepmd/tf/descriptor/se_a.py
@@ -7,6 +7,9 @@
 import numpy as np
 
+from deepmd.dpmodel.utils.env_mat import (
+    EnvMat,
+)
 from deepmd.tf.common import (
     cast_precision,
     get_activation_func,
@@ -195,6 +198,7 @@ def __init__(
         self.trainable = trainable
         self.compress_activation_fn = get_activation_func(activation_function)
         self.filter_activation_fn = get_activation_func(activation_function)
+        self.activation_function_name = activation_function
         self.filter_precision = get_precision(precision)
         self.filter_np_precision = get_np_precision(precision)
         self.exclude_types = set()
@@ -1345,3 +1349,101 @@ def explicit_ntypes(self) -> bool:
         if self.stripped_type_embedding:
             return True
         return False
+
+    @classmethod
+    def deserialize(cls, data: dict, suffix: str = ""):
+        """Deserialize the model.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized data
+        suffix : str, optional
+            The suffix of the scope
+
+        Returns
+        -------
+        Model
+            The deserialized model
+        """
+        if cls is not DescrptSeA:
+            raise NotImplementedError("Not implemented in class %s" % cls.__name__)
+        data = data.copy()
+        embedding_net_variables = cls.deserialize_network(
+            data.pop("embeddings"), suffix=suffix
+        )
+        data.pop("env_mat")
+        variables = data.pop("@variables")
+        descriptor = cls(**data)
+        descriptor.embedding_net_variables = embedding_net_variables
+        descriptor.davg = variables["davg"].reshape(
+            descriptor.ntypes, descriptor.ndescrpt
+        )
+        descriptor.dstd = variables["dstd"].reshape(
+            descriptor.ntypes, descriptor.ndescrpt
+        )
+        return descriptor
+
+    def serialize(self, suffix: str = "") -> dict:
+        """Serialize the model.
+
+        Parameters
+        ----------
+        suffix : str, optional
+            The suffix of the scope
+
+        Returns
+        -------
+        dict
+            The serialized data
+        """
+        if type(self) is not DescrptSeA:
+            raise NotImplementedError(
+                "Not implemented in class %s" % self.__class__.__name__
+            )
+        if self.stripped_type_embedding:
+            raise NotImplementedError(
+                "stripped_type_embedding is unsupported by the native model"
+            )
+        if (self.original_sel != self.sel_a).any():
+            raise NotImplementedError(
+                "Adjusting sel is unsupported by the native model"
+            )
+        if self.embedding_net_variables is None:
+            raise RuntimeError("init_variables must be called before serialize")
+        if self.spin is not None:
+            raise NotImplementedError("spin is unsupported")
+        assert self.davg is not None
+        assert self.dstd is not None
+        # TODO: not sure how to handle type embedding - type embedding is not a model parameter,
+        # but instead a part of the input data. Maybe the interface should be refactored...
+
+        return {
+            "rcut": self.rcut_r,
+            "rcut_smth": self.rcut_r_smth,
+            "sel": self.sel_a,
+            "neuron": self.filter_neuron,
+            "axis_neuron": self.n_axis_neuron,
+            "resnet_dt": self.filter_resnet_dt,
+            "trainable": self.trainable,
+            "type_one_side": self.type_one_side,
+            "exclude_types": list(self.exclude_types),
+            "set_davg_zero": self.set_davg_zero,
+            "activation_function": self.activation_function_name,
+            "precision": self.filter_precision.name,
+            "embeddings": self.serialize_network(
+                ntypes=self.ntypes,
+                ndim=(1 if self.type_one_side else 2),
+                in_dim=1,
+                neuron=self.filter_neuron,
+                activation_function=self.activation_function_name,
+                resnet_dt=self.filter_resnet_dt,
+                variables=self.embedding_net_variables,
+                excluded_types=self.exclude_types,
+                suffix=suffix,
+            ),
+            "env_mat": EnvMat(self.rcut_r, self.rcut_r_smth).serialize(),
+            "@variables": {
+                "davg": self.davg.reshape(self.ntypes, self.nnei_a, 4),
+                "dstd": self.dstd.reshape(self.ntypes, self.nnei_a, 4),
+            },
+            "spin": self.spin,
+        }
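The intended calling sequence for the new TF `serialize`/`deserialize` pair, as exercised by the consistency tests added below (a hedged sketch: the constructor arguments and suffixes are illustrative, and the graph-handling steps are elided):

    from deepmd.tf.descriptor.se_a import DescrptSeA

    descrpt = DescrptSeA(rcut=6.0, rcut_smth=5.8, sel=[10, 10])
    # ... build the graph and load trained weights, then:
    # descrpt.init_variables(graph, graph_def, suffix="_a")  # required first,
    # data = descrpt.serialize(suffix="_a")                  # else RuntimeError
    # descrpt2 = DescrptSeA.deserialize(data, suffix="_b")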
diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py
index e94c052f55..fe5bb81bae 100644
--- a/deepmd/tf/env.py
+++ b/deepmd/tf/env.py
@@ -120,19 +120,25 @@ def dlopen_library(module: str, filename: str):
 except AttributeError:
     tf_py_version = tf.__version__
 
+# subpatterns:
+# \1: type of central atom
+# \2: weight name
+# \3: layer index
+# The rest: types of neighbor atoms
+# IMPORTANT: the order is critical to match the pattern
 EMBEDDING_NET_PATTERN = str(
-    r"filter_type_\d+/matrix_\d+_\d+|"
-    r"filter_type_\d+/bias_\d+_\d+|"
-    r"filter_type_\d+/idt_\d+_\d+|"
-    r"filter_type_all/matrix_\d+|"
-    r"filter_type_all/matrix_\d+_\d+|"
-    r"filter_type_all/matrix_\d+_\d+_\d+|"
-    r"filter_type_all/bias_\d+|"
-    r"filter_type_all/bias_\d+_\d+|"
-    r"filter_type_all/bias_\d+_\d+_\d+|"
-    r"filter_type_all/idt_\d+|"
-    r"filter_type_all/idt_\d+_\d+|"
-)
+    r"filter_type_(\d+)/(matrix)_(\d+)_(\d+)|"
+    r"filter_type_(\d+)/(bias)_(\d+)_(\d+)|"
+    r"filter_type_(\d+)/(idt)_(\d+)_(\d+)|"
+    r"filter_type_(all)/(matrix)_(\d+)_(\d+)_(\d+)|"
+    r"filter_type_(all)/(matrix)_(\d+)_(\d+)|"
+    r"filter_type_(all)/(matrix)_(\d+)|"
+    r"filter_type_(all)/(bias)_(\d+)_(\d+)_(\d+)|"
+    r"filter_type_(all)/(bias)_(\d+)_(\d+)|"
+    r"filter_type_(all)/(bias)_(\d+)|"
+    r"filter_type_(all)/(idt)_(\d+)_(\d+)|"
+    r"filter_type_(all)/(idt)_(\d+)|"
+)[:-1]
 
 FITTING_NET_PATTERN = str(
     r"layer_\d+/matrix|"
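How the new capture groups are consumed (cf. `serialize_network` above): `re.search` tries the alternatives in order, so the longer three-index `filter_type_all` patterns must precede the shorter ones, and the final `[:-1]` trims the trailing `"|"`, which would otherwise match the empty string. A standalone sketch with an abridged pattern:

    import re

    pattern = (
        r"filter_type_(all)/(matrix)_(\d+)_(\d+)_(\d+)|"
        r"filter_type_(all)/(matrix)_(\d+)_(\d+)|"
        r"filter_type_(all)/(matrix)_(\d+)|"
    )[:-1]  # drop the trailing "|"

    m = re.search(pattern, "filter_type_all/matrix_2_0")
    groups = [g for g in m.groups() if g is not None]
    # central-atom type, weight name, layer index, neighbor type:
    assert groups == ["all", "matrix", "2", "0"]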
diff --git a/deepmd/tf/utils/graph.py b/deepmd/tf/utils/graph.py
index 9d2608e34a..7e67cf27a6 100644
--- a/deepmd/tf/utils/graph.py
+++ b/deepmd/tf/utils/graph.py
@@ -166,9 +166,9 @@ def get_embedding_net_nodes_from_graph_def(
     # embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+"
     if suffix != "":
         embedding_net_pattern = (
-            EMBEDDING_NET_PATTERN.replace("/idt", suffix + "/idt")
-            .replace("/bias", suffix + "/bias")
-            .replace("/matrix", suffix + "/matrix")
+            EMBEDDING_NET_PATTERN.replace("/(idt)", suffix + "/(idt)")
+            .replace("/(bias)", suffix + "/(bias)")
+            .replace("/(matrix)", suffix + "/(matrix)")
         )
     else:
         embedding_net_pattern = EMBEDDING_NET_PATTERN
@@ -176,10 +176,6 @@ def get_embedding_net_nodes_from_graph_def(
     embedding_net_nodes = get_pattern_nodes_from_graph_def(
         graph_def, embedding_net_pattern
     )
-    for key in embedding_net_nodes.keys():
-        assert (
-            key.find("bias") > 0 or key.find("matrix") > 0
-        ), "currently, only support weight matrix and bias matrix at the tabulation op!"
     return embedding_net_nodes
diff --git a/deepmd/tf/utils/tabulate.py b/deepmd/tf/utils/tabulate.py
index ff5e2b9e09..958e08dd86 100644
--- a/deepmd/tf/utils/tabulate.py
+++ b/deepmd/tf/utils/tabulate.py
@@ -133,6 +133,10 @@ def __init__(
         self.embedding_net_nodes = get_embedding_net_nodes_from_graph_def(
             self.graph_def, suffix=self.suffix
         )
+        for key in self.embedding_net_nodes.keys():
+            assert (
+                key.find("bias") > 0 or key.find("matrix") > 0
+            ), "currently, only support weight matrix and bias matrix at the tabulation op!"
 
         # move it to the descriptor class
         # for tt in self.exclude_types:
diff --git a/source/tests/consistent/__init__.py b/source/tests/consistent/__init__.py
new file mode 100644
index 0000000000..50b8b8bdc5
--- /dev/null
+++ b/source/tests/consistent/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Test whether DP native, TF, and PT models are consistent."""
model.""" + skip_tf: ClassVar[bool] = not INSTALLED_TF + """Whether to skip the TensorFlow model.""" + skip_pt: ClassVar[bool] = not INSTALLED_PT + """Whether to skip the PyTorch model.""" + + def setUp(self): + self.unique_id = uuid4().hex + + def reset_unique_id(self): + self.unique_id = uuid4().hex + + def init_backend_cls(self, cls) -> Any: + """Initialize a backend model.""" + assert self.data is not None + if self.args is None: + data = self.data + else: + base = Argument("arg", dict, sub_fields=self.args) + data = base.normalize_value(self.data, trim_pattern="_*") + base.check_value(data, strict=True) + return cls(**data) + + @abstractmethod + def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]: + """Build the TF graph. + + Parameters + ---------- + obj : Any + The object of TF + suffix : str + The suffix of the scope + + Returns + ------- + list of tf.Tensor + The list of tensors + dict + The feed_dict + """ + + @abstractmethod + def eval_dp(self, dp_obj: Any) -> Any: + """Evaluate the return value of DP. + + Parameters + ---------- + dp_obj : Any + The object of DP + """ + + @abstractmethod + def eval_pt(self, pt_obj: Any) -> Any: + """Evaluate the return value of PT. + + Parameters + ---------- + pt_obj : Any + The object of PT + """ + + class RefBackend(Enum): + """Reference backend.""" + + TF = 1 + DP = 2 + PT = 3 + + @abstractmethod + def extract_ret(self, ret: Any, backend: RefBackend) -> Tuple[np.ndarray, ...]: + """Extract the return value when comparing with other backends. + + Parameters + ---------- + ret : Any + The return value + backend : RefBackend + The backend + + Returns + ------- + tuple[np.ndarray, ...] + The extracted return value + """ + + def build_eval_tf( + self, sess: "tf.Session", obj: Any, suffix: str + ) -> List[np.ndarray]: + """Build and evaluate the TF graph.""" + t_out, feed_dict = self.build_tf(obj, suffix) + + t_out_indentity = [ + tf.identity(tt, name=f"o_{ii}_{suffix}") for ii, tt in enumerate(t_out) + ] + run_sess(sess, tf.global_variables_initializer()) + return run_sess( + sess, + t_out_indentity, + feed_dict=feed_dict, + ) + + def get_tf_ret_serialization_from_cls(self, obj): + with tf.Session(config=default_tf_session_config) as sess: + graph = tf.get_default_graph() + ret = self.build_eval_tf(sess, obj, suffix=self.unique_id) + output_graph_def = tf.graph_util.convert_variables_to_constants( + sess, + graph.as_graph_def(), + [f"o_{ii}_{self.unique_id}" for ii, _ in enumerate(ret)], + ) + with tf.Graph().as_default() as new_graph: + tf.import_graph_def(output_graph_def, name="") + obj.init_variables(new_graph, output_graph_def, suffix=self.unique_id) + + data = obj.serialize(suffix=self.unique_id) + return ret, data + + def get_pt_ret_serialization_from_cls(self, obj): + ret = self.eval_pt(obj) + data = obj.serialize() + return ret, data + + def get_dp_ret_serialization_from_cls(self, obj): + ret = self.eval_dp(obj) + data = obj.serialize() + return ret, data + + def get_reference_backend(self): + """Get the reference backend. + + Order of checking for ref: DP, TF, PT. 
+ """ + if not self.skip_dp: + return self.RefBackend.DP + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_pt: + return self.RefBackend.PT + raise ValueError("No available reference") + + def get_reference_ret_serialization(self, ref: RefBackend): + if ref == self.RefBackend.DP: + obj = self.init_backend_cls(self.dp_class) + return self.get_dp_ret_serialization_from_cls(obj) + if ref == self.RefBackend.TF: + obj = self.init_backend_cls(self.tf_class) + self.reset_unique_id() + return self.get_tf_ret_serialization_from_cls(obj) + if ref == self.RefBackend.PT: + obj = self.init_backend_cls(self.pt_class) + return self.get_pt_ret_serialization_from_cls(obj) + raise ValueError("No available reference") + + def test_tf_consistent_with_ref(self): + """Test whether TF and reference are consistent.""" + if self.skip_tf: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.TF: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + self.reset_unique_id() + tf_obj = self.tf_class.deserialize(data1, suffix=self.unique_id) + ret2, data2 = self.get_tf_ret_serialization_from_cls(tf_obj) + ret2 = self.extract_ret(ret2, self.RefBackend.TF) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + np.testing.assert_allclose(rr1, rr2) + + def test_tf_self_consistent(self): + """Test whether TF is self consistent.""" + if self.skip_tf: + self.skipTest("Unsupported backend") + obj1 = self.init_backend_cls(self.tf_class) + self.reset_unique_id() + ret1, data1 = self.get_tf_ret_serialization_from_cls(obj1) + self.reset_unique_id() + obj2 = self.tf_class.deserialize(data1, suffix=self.unique_id) + ret2, data2 = self.get_tf_ret_serialization_from_cls(obj2) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + np.testing.assert_allclose(rr1, rr2) + + def test_dp_consistent_with_ref(self): + """Test whether DP and reference are consistent.""" + if self.skip_dp: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.DP: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + dp_obj = self.dp_class.deserialize(data1) + ret2 = self.eval_dp(dp_obj) + ret2 = self.extract_ret(ret2, self.RefBackend.DP) + data2 = dp_obj.serialize() + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + np.testing.assert_allclose(rr1, rr2) + + def test_dp_self_consistent(self): + """Test whether DP is self consistent.""" + if self.skip_dp: + self.skipTest("Unsupported backend") + obj1 = self.init_backend_cls(self.dp_class) + ret1, data1 = self.get_dp_ret_serialization_from_cls(obj1) + obj1 = self.dp_class.deserialize(data1) + ret2, data2 = self.get_dp_ret_serialization_from_cls(obj1) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray): + np.testing.assert_allclose(rr1, rr2) + else: + self.assertEqual(rr1, rr2) + + def test_pt_consistent_with_ref(self): + """Test whether PT and reference are consistent.""" + if self.skip_pt: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.PT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = 
diff --git a/source/tests/consistent/descriptor/__init__.py b/source/tests/consistent/descriptor/__init__.py
new file mode 100644
index 0000000000..6ceb116d85
--- /dev/null
+++ b/source/tests/consistent/descriptor/__init__.py
@@ -0,0 +1 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py
new file mode 100644
index 0000000000..ef7b39b52e
--- /dev/null
+++ b/source/tests/consistent/descriptor/common.py
@@ -0,0 +1,95 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+)
+
+from deepmd.common import (
+    make_default_mesh,
+)
+from deepmd.dpmodel.utils.nlist import (
+    build_neighbor_list,
+    extend_coord_with_ghosts,
+)
+
+from ..common import (
+    INSTALLED_PT,
+    INSTALLED_TF,
+)
+
+if INSTALLED_PT:
+    import torch
+
+    from deepmd.pt.utils.env import DEVICE as PT_DEVICE
+    from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt
+    from deepmd.pt.utils.nlist import (
+        extend_coord_with_ghosts as extend_coord_with_ghosts_pt,
+    )
+if INSTALLED_TF:
+    from deepmd.tf.env import (
+        GLOBAL_TF_FLOAT_PRECISION,
+        tf,
+    )
+
+
+class DescriptorTest:
+    """Useful utilities for descriptor tests."""
+
+    def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix):
+        t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord")
+        t_type = tf.placeholder(tf.int32, [None], name="i_type")
+        t_natoms = tf.placeholder(tf.int32, natoms.shape, name="i_natoms")
+        t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [9], name="i_box")
+        t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh")
+        t_des = obj.build(
+            t_coord,
+            t_type,
+            t_natoms,
+            t_box,
+            t_mesh,
+            {},
+            suffix=suffix,
+        )
+        return [t_des], {
+            t_coord: coords,
+            t_type: atype,
+            t_natoms: natoms,
+            t_box: box,
+            t_mesh: make_default_mesh(True, False),
+        }
+
+    def eval_dp_descriptor(self, dp_obj: Any, natoms, coords, atype, box) -> Any:
+        ext_coords, ext_atype, mapping = extend_coord_with_ghosts(
+            coords.reshape(1, -1, 3),
+            atype.reshape(1, -1),
+            box.reshape(1, 3, 3),
+            dp_obj.get_rcut(),
+        )
+        nlist = build_neighbor_list(
+            ext_coords,
+            ext_atype,
+            natoms[0],
+            dp_obj.get_rcut(),
+            dp_obj.get_sel(),
+            distinguish_types=True,
+        )
+        return dp_obj(ext_coords, ext_atype, nlist=nlist)
+
+    def eval_pt_descriptor(self, pt_obj: Any, natoms, coords, atype, box) -> Any:
+        ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt(
+            torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3),
+            torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1),
+            torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3),
+            pt_obj.get_rcut(),
+        )
+        nlist = build_neighbor_list_pt(
+            ext_coords,
+            ext_atype,
+            natoms[0],
+            pt_obj.get_rcut(),
+            pt_obj.get_sel(),
+            distinguish_types=True,
+        )
+        return [
+            x.detach().cpu().numpy() if torch.is_tensor(x) else x
+            for x in pt_obj(ext_coords, ext_atype, nlist=nlist)
+        ]
diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py
new file mode 100644
index 0000000000..a694a2a20c
--- /dev/null
+++ b/source/tests/consistent/descriptor/test_se_e2_a.py
@@ -0,0 +1,149 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+from typing import (
+    Any,
+    Tuple,
+)
+
+import numpy as np
+
+from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeA as DescrptSeADP
+from deepmd.env import (
+    GLOBAL_NP_FLOAT_PRECISION,
+)
+
+from ..common import (
+    INSTALLED_PT,
+    INSTALLED_TF,
+    CommonTest,
+    parameterized,
+)
+from .common import (
+    DescriptorTest,
+)
+
+if INSTALLED_PT:
+    from deepmd.pt.model.descriptor.se_a import DescrptSeA as DescrptSeAPT
+else:
+    DescrptSeAPT = None
+if INSTALLED_TF:
+    from deepmd.tf.descriptor.se_a import DescrptSeA as DescrptSeATF
+else:
+    DescrptSeATF = None
+from deepmd.utils.argcheck import (
+    descrpt_se_a_args,
+)
+
+
+@parameterized(
+    (True, False),  # resnet_dt
+    (True, False),  # type_one_side
+    ([], [[0, 1]]),  # excluded_types
+)
+class TestSeA(CommonTest, DescriptorTest, unittest.TestCase):
+    @property
+    def data(self) -> dict:
+        (
+            resnet_dt,
+            type_one_side,
+            excluded_types,
+        ) = self.param
+        return {
+            "sel": [10, 10],
+            "rcut_smth": 5.80,
+            "rcut": 6.00,
+            "neuron": [6, 12, 24],
+            "axis_neuron": 3,
+            "resnet_dt": resnet_dt,
+            "type_one_side": type_one_side,
+            "exclude_types": excluded_types,
+            "seed": 1145141919810,
+        }
+
+    @property
+    def skip_pt(self) -> bool:
+        (
+            resnet_dt,
+            type_one_side,
+            excluded_types,
+        ) = self.param
+        return not type_one_side or excluded_types != [] or CommonTest.skip_pt
+
+    @property
+    def skip_dp(self) -> bool:
+        (
+            resnet_dt,
+            type_one_side,
+            excluded_types,
+        ) = self.param
+        return not type_one_side or excluded_types != [] or CommonTest.skip_dp
+
+    tf_class = DescrptSeATF
+    dp_class = DescrptSeADP
+    pt_class = DescrptSeAPT
+    args = descrpt_se_a_args()
+
+    def setUp(self):
+        CommonTest.setUp(self)
+
+        self.ntypes = 2
+        self.coords = np.array(
+            [
+                12.83,
+                2.56,
+                2.18,
+                12.09,
+                2.87,
+                2.74,
+                00.25,
+                3.32,
+                1.68,
+                3.36,
+                3.00,
+                1.81,
+                3.51,
+                2.51,
+                2.60,
+                4.27,
+                3.22,
+                1.56,
+            ],
+            dtype=GLOBAL_NP_FLOAT_PRECISION,
+        )
+        self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32)
+        self.box = np.array(
+            [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
+            dtype=GLOBAL_NP_FLOAT_PRECISION,
+        )
+        self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)
+
+    def build_tf(self, obj: Any, suffix: str) -> Tuple[list, dict]:
+        return self.build_tf_descriptor(
+            obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+            suffix,
+        )
+
+    def eval_dp(self, dp_obj: Any) -> Any:
+        return self.eval_dp_descriptor(
+            dp_obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+        )
+
+    def eval_pt(self, pt_obj: Any) -> Any:
+        return self.eval_pt_descriptor(
+            pt_obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+        )
+
+    def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]:
+        return (ret[0],)
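A hedged usage note: with both backends installed, the generated classes can be collected directly, e.g. `pytest source/tests/consistent/descriptor/test_se_e2_a.py`. Each parameterization runs the six checks defined in `CommonTest` (self-consistency plus cross-backend comparison for TF, DP, and PT), with the `skip_dp`/`skip_pt` properties pruning the combinations that `type_one_side=False` or non-empty `exclude_types` do not yet support.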