diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 1c2e943369..445537e2c5 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -13,6 +13,7 @@ ) from .env_mat import ( prod_env_mat_se_a, + prod_env_mat_se_r, ) from .gaussian_lcc import ( DescrptGaussianLcc, @@ -27,6 +28,9 @@ DescrptBlockSeA, DescrptSeA, ) +from .se_r import( + DescrptSeR +) __all__ = [ "Descriptor", @@ -35,9 +39,11 @@ "DescrptBlockSeA", "DescrptBlockSeAtten", "DescrptSeA", + "DescrptSeR", "DescrptDPA1", "DescrptDPA2", "prod_env_mat_se_a", + "prod_env_mat_se_r", "DescrptGaussianLcc", "DescrptBlockHybrid", "DescrptBlockRepformers", diff --git a/deepmd/pt/model/descriptor/env_mat.py b/deepmd/pt/model/descriptor/env_mat.py index b3235de175..be6c2437d8 100644 --- a/deepmd/pt/model/descriptor/env_mat.py +++ b/deepmd/pt/model/descriptor/env_mat.py @@ -29,6 +29,27 @@ def _make_env_mat_se_a(nlist, coord, rcut: float, ruct_smth: float): env_mat_se_a = torch.cat([t0, t1], dim=-1) * weight return env_mat_se_a, diff * mask.unsqueeze(-1), weight +def _make_env_mat_se_r(nlist, coord, rcut: float, ruct_smth: float): + """Make smooth environment matrix.""" + bsz, natoms, nnei = nlist.shape + coord = coord.view(bsz, -1, 3) + nall = coord.shape[1] + mask = nlist >= 0 + # nlist = nlist * mask ## this impl will contribute nans in Hessian calculation. + nlist = torch.where(mask, nlist, nall - 1) + coord_l = coord[:, :natoms].view(bsz, -1, 1, 3) + index = nlist.view(bsz, -1).unsqueeze(-1).expand(-1, -1, 3) + coord_r = torch.gather(coord, 1, index) + coord_r = coord_r.view(bsz, natoms, nnei, 3) + diff = coord_r - coord_l + length = torch.linalg.norm(diff, dim=-1, keepdim=True) + # for index 0 nloc atom + length = length + ~mask.unsqueeze(-1) + t0 = 1 / length + weight = compute_smooth_weight(length, ruct_smth, rcut) + weight = weight * mask.unsqueeze(-1) + env_mat_se_r = t0 * weight + return env_mat_se_r, diff * mask.unsqueeze(-1), weight def prod_env_mat_se_a( extended_coord, nlist, atype, mean, stddev, rcut: float, rcut_smth: float @@ -58,3 +79,32 @@ def prod_env_mat_se_a( t_std = stddev[atype] # [n_atom, dim, 4] env_mat_se_a = (_env_mat_se_a - t_avg) / t_std return env_mat_se_a, diff, switch + +def prod_env_mat_se_r( + extended_coord, nlist, atype, mean, stddev, rcut: float, rcut_smth: float +): + """Generate smooth environment matrix from atom coordinates and other context. + + Args: + - extended_coord: Copied atom coordinates with shape [nframes, nall*3]. + - atype: Atom types with shape [nframes, nloc]. + - natoms: Batched atom statisics with shape [len(sec)+2]. + - box: Batched simulation box with shape [nframes, 9]. + - mean: Average value of descriptor per element type with shape [len(sec), nnei, 1]. + - stddev: Standard deviation of descriptor per element type with shape [len(sec), nnei, 1]. + - deriv_stddev: StdDev of descriptor derivative per element type with shape [len(sec), nnei, 1, 3]. + - rcut: Cut-off radius. + - rcut_smth: Smooth hyper-parameter for pair force & energy. + + Returns + ------- + - env_mat_se_r: Shape is [nframes, natoms[1]*nnei*1]. + """ + nframes = extended_coord.shape[0] + _env_mat_se_r, diff, switch = _make_env_mat_se_r( + nlist, extended_coord, rcut, rcut_smth + ) # shape [n_atom, dim, 1] + t_avg = mean[atype] # [n_atom, dim, 1] + t_std = stddev[atype] # [n_atom, dim, 1] + env_mat_se_r = (_env_mat_se_r - t_avg) / t_std + return env_mat_se_r, diff, switch diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py new file mode 100644 index 0000000000..461667364b --- /dev/null +++ b/deepmd/pt/model/descriptor/se_r.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Tuple, +) + +import numpy as np +import torch + +from deepmd.pt.model.descriptor import ( + Descriptor, + prod_env_mat_se_r, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISON_DICT, +) +from deepmd.pt.utils.env_mat_stat import ( + EnvMatStatSeR, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) +from deepmd.utils.path import ( + DPPath, +) + +try: + from typing import ( + Final, + ) +except ImportError: + from torch.jit import Final + +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.pt.model.network.mlp import ( + EmbeddingNet, + NetworkCollection, +) +from deepmd.pt.model.network.network import ( + TypeFilter, +) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) + + +@Descriptor.register("se_e2_r") +class DescrptSeR(Descriptor): + def __init__( + self, + rcut, + rcut_smth, + sel, + neuron=[25, 50, 100], + set_davg_zero: bool = False, + activation_function: str = "tanh", + precision: str = "float64", + resnet_dt: bool = False, + exclude_types: List[Tuple[int, int]] = [], + old_impl: bool = False, + **kwargs, + ): + super().__init__() + self.rcut = rcut + self.rcut_smth = rcut_smth + self.neuron = neuron + self.filter_neuron = self.neuron + self.set_davg_zero = set_davg_zero + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.resnet_dt = resnet_dt + self.old_impl = old_impl + self.exclude_types = exclude_types + self.ntypes = len(sel) + self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types) + + self.sel = sel + self.sec = torch.tensor( + np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE + ) + self.split_sel = self.sel + self.nnei = sum(sel) + self.ndescrpt = self.nnei * 1 + + wanted_shape = (self.ntypes, self.nnei, 1) + mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) + stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE) + self.register_buffer("mean", mean) + self.register_buffer("stddev", stddev) + self.filter_layers_old = None + self.filter_layers = None + + if self.old_impl: + filter_layers = [] + # TODO: remove + start_index = 0 + for type_i in range(self.ntypes): + one = TypeFilter(start_index, sel[type_i], self.filter_neuron) + filter_layers.append(one) + start_index += sel[type_i] + self.filter_layers_old = torch.nn.ModuleList(filter_layers) + else: + filter_layers = NetworkCollection( + ndim=1, ntypes=len(sel), network_type="embedding_network" + ) + # TODO: ndim=2 if type_one_side=False + for ii in range(self.ntypes): + filter_layers[(ii,)] = EmbeddingNet( + 1, + self.filter_neuron, + activation_function=self.activation_function, + precision=self.precision, + resnet_dt=self.resnet_dt, + ) + self.filter_layers = filter_layers + self.stats = None + + def get_rcut(self) -> float: + """Returns the cut-off radius.""" + return self.rcut + + def get_nsel(self) -> int: + """Returns the number of selected atoms in the cut-off radius.""" + return sum(self.sel) + + def get_sel(self) -> List[int]: + """Returns the number of selected atoms for each type.""" + return self.sel + + def get_ntypes(self) -> int: + """Returns the number of element types.""" + return self.ntypes + + def get_dim_out(self) -> int: + """Returns the output dimension.""" + return self.neuron[-1] + + def get_dim_emb(self) -> int: + """Returns the output dimension.""" + raise NotImplementedError + + def get_dim_in(self) -> int: + """Returns the input dimension.""" + return 0 + + def mixed_types(self) -> bool: + """If true, the discriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the discriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + """ + return False + + def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None): + """Update mean and stddev for descriptor elements.""" + env_mat_stat = EnvMatStatSeR(self) + if path is not None: + path = path / env_mat_stat.get_hash() + env_mat_stat.load_or_compute_stats(merged, path) + self.stats = env_mat_stat.stats + mean, stddev = env_mat_stat() + if not self.set_davg_zero: + self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + if not self.set_davg_zero: + self.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + self.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + + def get_stats(self) -> Dict[str, StatItem]: + """Get the statistics of the descriptor.""" + if self.stats is None: + raise RuntimeError( + "The statistics of the descriptor has not been computed." + ) + return self.stats + + def __setitem__(self, key, value): + if key in ("avg", "data_avg", "davg"): + self.mean = value + elif key in ("std", "data_std", "dstd"): + self.stddev = value + else: + raise KeyError(key) + + def __getitem__(self, key): + if key in ("avg", "data_avg", "davg"): + return self.mean + elif key in ("std", "data_std", "dstd"): + return self.stddev + else: + raise KeyError(key) + + @classmethod + def get_data_process_key(cls, config): + """ + Get the keys for the data preprocess. + Usually need the information of rcut and sel. + TODO Need to be deprecated when the dataloader has been cleaned up. + """ + descrpt_type = config["type"] + assert descrpt_type in ["se_e2_r"] + return {"sel": config["sel"], "rcut": config["rcut"]} + + @property + def data_stat_key(self): + """ + Get the keys for the data statistic of the descriptor. + Return a list of statistic names needed, such as "sumr", "sumr2" or "sumn". + """ + return ["sumr", "sumn", "sumr2"] + + def forward( + self, + coord_ext: torch.Tensor, + atype_ext: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + ): + """Compute the descriptor. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nallx3) + atype_ext + The extended aotm types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping, not required by this descriptor. + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant and permutationally invariant single particle + representation. shape: nf x nloc x ng x 3 + g2 + The rotationally invariant pair-partical representation. + this descriptor returns None + h2 + The rotationally equivariant pair-partical representation. + this descriptor returns None + sw + The smooth switch function. + + """ + + del mapping + nloc = nlist.shape[1] + atype = atype_ext[:, :nloc] + dmatrix, diff, sw = prod_env_mat_se_r( + coord_ext, + nlist, + atype, + self.mean, + self.stddev, + self.rcut, + self.rcut_smth, + ) + assert dmatrix.shape == (2,3,7,1) + + if self.old_impl: + assert self.filter_layers_old is not None + dmatrix = dmatrix.view( + -1, self.ndescrpt + ) # shape is [nframes*nall, self.ndescrpt] + xyz_scatter = torch.empty( + 1, + device=env.DEVICE, + ) + ret = self.filter_layers_old[0](dmatrix) + xyz_scatter = ret + for ii, transform in enumerate(self.filter_layers_old[1:]): + # shape is [nframes*nall, 1, self.filter_neuron[-1]] + ret = transform.forward(dmatrix) + xyz_scatter = xyz_scatter + ret + else: + assert self.filter_layers is not None + dmatrix = dmatrix.view(-1, self.nnei, 1) + dmatrix = dmatrix.to(dtype=self.prec) + nfnl = dmatrix.shape[0] + # pre-allocate a shape to pass jit + xyz_scatter = torch.zeros( + [nfnl, 1, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE + ) + + # nfnl x nnei + exclude_mask = self.emask(nlist, atype_ext).view(nfnl, -1) + for ii, ll in enumerate(self.filter_layers.networks): + # nfnl x nt + mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] + # nfnl x nt x 1 + rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] + rr = rr * mm[:, :, None] + ss = rr[:, :, :1] + # nfnl x nt x ng + gg = ll.forward(ss) + # nfnl x 1 x ng + gr = torch.matmul(rr.permute(0, 2, 1), gg) + xyz_scatter += gr + + xyz_scatter /= self.nnei + xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) + + result = torch.matmul( + xyz_scatter_1, xyz_scatter + ) # shape is [nframes*nall, self.filter_neuron[-1], 1] + result = result.view(-1, nloc, self.filter_neuron[-1] * 1) + return ( + result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + None, + None, + None, + sw, + ) + + def set_stat_mean_and_stddev( + self, + mean: torch.Tensor, + stddev: torch.Tensor, + ) -> None: + self.mean = mean + self.stddev = stddev + + def serialize(self) -> dict: + return { + "rcut": self.rcut, + "rcut_smth": self.rcut_smth, + "sel": self.sel, + "neuron": self.neuron, + "resnet_dt": self.resnet_dt, + "set_davg_zero": self.set_davg_zero, + "activation_function": self.activation_function, + # make deterministic + "precision": RESERVED_PRECISON_DICT[self.prec], + "embeddings": self.filter_layers.serialize(), + "env_mat": DPEnvMat(self.rcut, self.rcut_smth).serialize(), + "exclude_types": self.exclude_types, + "@variables": { + "davg": self["davg"].detach().cpu().numpy(), + "dstd": self["dstd"].detach().cpu().numpy(), + }, + ## to be updated when the options are supported. + "trainable": True, + "type_one_side": True, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "DescrptSeR": + data = data.copy() + variables = data.pop("@variables") + embeddings = data.pop("embeddings") + env_mat = data.pop("env_mat") + obj = cls(**data) + + def t_cvt(xx): + return torch.tensor(xx, dtype=obj.prec, device=env.DEVICE) + + obj["davg"] = t_cvt(variables["davg"]) + obj["dstd"] = t_cvt(variables["dstd"]) + obj.filter_layers = NetworkCollection.deserialize(embeddings) + return obj + +def analyze_descrpt(matrix, ndescrpt, natoms): + """Collect avg, square avg and count of descriptors in a batch.""" + ntypes = natoms.shape[1] - 2 + start_index = 0 + sysr = [] + sysn = [] + sysr2 = [] + for type_i in range(ntypes): + end_index = start_index + natoms[0, 2 + type_i] + dd = matrix[:, start_index:end_index] # all descriptors for this element + start_index = end_index + dd = np.reshape( + dd, [-1, 1] + ) # Shape is [nframes*natoms[2+type_id]*self.nnei, 1] + ddr = dd[:, :1] + sumr = np.sum(ddr) + sumn = dd.shape[0] # Value is nframes*natoms[2+type_id]*self.nnei + sumr2 = np.sum(np.multiply(ddr, ddr)) + sysr.append(sumr) + sysn.append(sumn) + sysr2.append(sumr2) + return sysr, sysr2, sysn diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 2f3c728c99..61f1dccb38 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -14,6 +14,7 @@ ) from deepmd.pt.model.descriptor.env_mat import ( prod_env_mat_se_a, + prod_env_mat_se_r, ) from deepmd.pt.utils import ( env, @@ -193,3 +194,138 @@ def __call__(self): mean = np.stack(all_davg) stddev = np.stack(all_dstd) return mean, stddev + + +class EnvMatStatSeR(EnvMatStat): + """Environmental matrix statistics for the se_r environemntal matrix. + + Parameters + ---------- + descriptor : DescriptorBlock + The descriptor of the model. + """ + + def __init__(self, descriptor: "DescriptorBlock"): + super().__init__() + self.descriptor = descriptor + + def iter( + self, data: List[Dict[str, torch.Tensor]] + ) -> Iterator[Dict[str, StatItem]]: + """Get the iterator of the environment matrix. + + Parameters + ---------- + data : List[Dict[str, torch.Tensor]] + The environment matrix. + + Yields + ------ + Dict[str, StatItem] + The statistics of the environment matrix. + """ + zero_mean = torch.zeros( + self.descriptor.get_ntypes(), + self.descriptor.get_nsel(), + 1, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + one_stddev = torch.ones( + self.descriptor.get_ntypes(), + self.descriptor.get_nsel(), + 1, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + for system in data: + coord, atype, box, natoms = ( + system["coord"], + system["atype"], + system["box"], + system["natoms"], + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.descriptor.get_rcut(), + self.descriptor.get_sel(), + mixed_types=self.descriptor.mixed_types(), + box=box, + ) + env_mat, _, _ = prod_env_mat_se_r( + extended_coord, + nlist, + atype, + zero_mean, + one_stddev, + self.descriptor.get_rcut(), + # TODO: export rcut_smth from DescriptorBlock + self.descriptor.rcut_smth, + ) + # reshape to nframes * nloc at the atom level, + # so nframes/mixed_type do not matter + env_mat = env_mat.view( + coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 1 + ) + atype = atype.view(coord.shape[0] * coord.shape[1]) + # (1, nloc) eq (ntypes, 1), so broadcast is possible + # shape: (ntypes, nloc) + type_idx = torch.eq( + atype.view(1, -1), + torch.arange( + self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32 + ).view(-1, 1), + ) + for type_i in range(self.descriptor.get_ntypes()): + dd = env_mat[type_idx[type_i]] + dd = dd.reshape([-1, 1]) # typen_atoms * nnei, 4 + env_mats = {} + env_mats[f"r_{type_i}"] = dd[:, :1] + yield self.compute_stat(env_mats) + + def get_hash(self) -> str: + """Get the hash of the environment matrix. + + Returns + ------- + str + The hash of the environment matrix. + """ + return get_hash( + { + "type": "se_r", + "ntypes": self.descriptor.get_ntypes(), + "rcut": round(self.descriptor.get_rcut(), 2), + "rcut_smth": round(self.descriptor.rcut_smth, 2), + "nsel": self.descriptor.get_nsel(), + "sel": self.descriptor.get_sel(), + "mixed_types": self.descriptor.mixed_types(), + } + ) + + def __call__(self): + avgs = self.get_avg() + stds = self.get_std() + + all_davg = [] + all_dstd = [] + for type_i in range(self.descriptor.get_ntypes()): + davgunit = [[avgs[f"r_{type_i}"]]] + dstdunit = [ + [ + stds[f"r_{type_i}"] + ] + ] + davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1]) + dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1]) + all_davg.append(davg) + all_dstd.append(dstd) + mean = np.stack(all_davg) + stddev = np.stack(all_dstd) + return mean, stddev \ No newline at end of file diff --git a/source/tests/pt/model/test_descriptor_se_r.py b/source/tests/pt/model/test_descriptor_se_r.py new file mode 100644 index 0000000000..abf3a050d3 --- /dev/null +++ b/source/tests/pt/model/test_descriptor_se_r.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +# from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR +from deepmd.pt.model.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from .test_mlp import ( + get_tols, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +# to be merged with the tf test case +class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_consistency( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, em in itertools.product( + [False, True], + ["float64", "float32"], + [[], [[0, 1]], [[1, 1]]], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # sea new impl + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=False, + exclude_mask=em, + ).to(env.DEVICE) + dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + # serialization + dd1 = DescrptSeR.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + # dd2 = DPDescrptSeR.deserialize(dd0.serialize()) + # rd2, gr2, _, _, sw2 = dd2.call( + # self.coord_ext, + # self.atype_ext, + # self.nlist, + # ) + # for aa, bb in zip([rd1, gr1, sw1], [rd2, gr2, sw2]): + # np.testing.assert_allclose( + # aa.detach().cpu().numpy(), + # bb, + # rtol=rtol, + # atol=atol, + # err_msg=err_msg, + # ) + # old impl + if idt is False and prec == "float64": + dd3 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=True, + ).to(env.DEVICE) + dd0_state_dict = dd0.state_dict() + dd3_state_dict = dd3.state_dict() + for i in dd3_state_dict: + dd3_state_dict[i] = ( + dd0_state_dict[ + i.replace(".deep_layers.", ".layers.").replace( + "filter_layers_old.", "filter_layers.networks." + ) + ] + .detach() + .clone() + ) + if ".bias" in i: + dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0) + dd3.load_state_dict(dd3_state_dict) + + rd3, gr3, _, _, sw3 = dd3( + torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), + torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), + torch.tensor(self.nlist, dtype=int, device=env.DEVICE), + ) + for aa, bb in zip([rd1, gr1, sw1], [rd3, gr3, sw3]): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_jit( + self, + ): + rng = np.random.default_rng() + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + # sea new impl + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + old_impl=False, + ) + dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE) + dd1 = DescrptSeR.deserialize(dd0.serialize()) + model = torch.jit.script(dd0) + model = torch.jit.script(dd1)