Skip to content

Commit

Permalink
resolve conversation
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Sep 5, 2024
1 parent c96d77a commit 315a174
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 64 deletions.
48 changes: 16 additions & 32 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def init_subclass_params(sub_data, sub_class):
seed=child_seed(seed, 0),
)
self.use_three_body = self.repinit_args.use_three_body
if self.repinit_args.use_three_body:
if self.use_three_body:
self.repinit_three_body = DescrptBlockSeTTebd(
self.repinit_args.three_body_rcut,
self.repinit_args.three_body_rcut_smth,
Expand Down Expand Up @@ -521,37 +521,21 @@ def init_subclass_params(sub_data, sub_class):
ln_eps=self.repformer_args.ln_eps,
seed=child_seed(seed, 1),
)
if not self.use_three_body:
self.rcut_list = [self.repformers.get_rcut(), self.repinit.get_rcut()]
self.nsel_list = [self.repformers.get_nsel(), self.repinit.get_nsel()]
else:
if (
self.repinit_three_body.get_rcut() >= self.repformers.get_rcut()
and self.repinit_three_body.get_nsel() >= self.repformers.get_nsel()
):
self.rcut_list = [
self.repformers.get_rcut(),
self.repinit_three_body.get_rcut(),
self.repinit.get_rcut(),
]
self.nsel_list = [
self.repformers.get_nsel(),
self.repinit_three_body.get_nsel(),
self.repinit.get_nsel(),
]
else:
self.rcut_list = [
self.repinit_three_body.get_rcut(),
self.repformers.get_rcut(),
self.repinit.get_rcut(),
]
self.nsel_list = [
self.repinit_three_body.get_nsel(),
self.repformers.get_nsel(),
self.repinit.get_nsel(),
]
self.rcut_list = sorted(self.rcut_list)
self.nsel_list = sorted(self.nsel_list)
self.rcsl_list = [
(self.repformers.get_rcut(), self.repformers.get_nsel()),
(self.repinit.get_rcut(), self.repinit.get_nsel()),
]
if self.use_three_body:
self.rcsl_list.append(
(self.repinit_three_body.get_rcut(), self.repinit_three_body.get_sel())
)
self.rcsl_list.sort()
for ii in range(1, len(self.rcsl_list)):
assert (
self.rcsl_list[ii - 1][1] <= self.rcsl_list[ii][1]
), "rcut and sel are not in the same order"
self.rcut_list = [ii[0] for ii in self.rcsl_list]
self.nsel_list = [ii[1] for ii in self.rcsl_list]
self.use_econf_tebd = use_econf_tebd
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
Expand Down
46 changes: 16 additions & 30 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def init_subclass_params(sub_data, sub_class):
seed=child_seed(seed, 0),
)
self.use_three_body = self.repinit_args.use_three_body
if self.repinit_args.use_three_body:
if self.use_three_body:
self.repinit_three_body = DescrptBlockSeTTebd(
self.repinit_args.three_body_rcut,
self.repinit_args.three_body_rcut_smth,
Expand Down Expand Up @@ -240,35 +240,21 @@ def init_subclass_params(sub_data, sub_class):
seed=child_seed(seed, 1),
old_impl=old_impl,
)
if not self.use_three_body:
self.rcut_list = [self.repformers.get_rcut(), self.repinit.get_rcut()]
self.nsel_list = [self.repformers.get_nsel(), self.repinit.get_nsel()]
else:
if (
self.repinit_three_body.get_rcut() >= self.repformers.get_rcut()
and self.repinit_three_body.get_nsel() >= self.repformers.get_nsel()
):
self.rcut_list = [
self.repformers.get_rcut(),
self.repinit_three_body.get_rcut(),
self.repinit.get_rcut(),
]
self.nsel_list = [
self.repformers.get_nsel(),
self.repinit_three_body.get_nsel(),
self.repinit.get_nsel(),
]
else:
self.rcut_list = [
self.repinit_three_body.get_rcut(),
self.repformers.get_rcut(),
self.repinit.get_rcut(),
]
self.nsel_list = [
self.repinit_three_body.get_nsel(),
self.repformers.get_nsel(),
self.repinit.get_nsel(),
]
self.rcsl_list = [
(self.repformers.get_rcut(), self.repformers.get_nsel()),
(self.repinit.get_rcut(), self.repinit.get_nsel()),
]
if self.use_three_body:
self.rcsl_list.append(
(self.repinit_three_body.get_rcut(), self.repinit_three_body.get_sel())
)
self.rcsl_list.sort()
for ii in range(1, len(self.rcsl_list)):
assert (
self.rcsl_list[ii - 1][1] <= self.rcsl_list[ii][1]
), "rcut and sel are not in the same order"
self.rcut_list = [ii[0] for ii in self.rcsl_list]
self.nsel_list = [ii[1] for ii in self.rcsl_list]
self.use_econf_tebd = use_econf_tebd
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
Expand Down
5 changes: 3 additions & 2 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,8 +976,9 @@ def _cal_hg(
(nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device
)
else:
invnnei = (1.0 / (float(nnei) ** 0.5)) * torch.ones(
(nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device
invnnei = torch.rsqrt(
float(nnei)
* torch.ones((nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device)
)
# nb x nloc x 3 x ng2
h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei
Expand Down

0 comments on commit 315a174

Please sign in to comment.