Skip to content

Commit

Permalink
Merge branch 'devel' into rf_finetune
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored May 23, 2024
2 parents 21b77d6 + 591b94b commit 5850a2f
Show file tree
Hide file tree
Showing 5 changed files with 1,014 additions and 1,006 deletions.
12 changes: 10 additions & 2 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,13 @@ def _compute_weight(
extended_coord, masked_nlist
)
numerator = torch.sum(
pairwise_rr * torch.exp(-pairwise_rr / self.smin_alpha), dim=-1
) # masked nnei will be zero, no need to handle
torch.where(
nlist_larger != -1,
pairwise_rr * torch.exp(-pairwise_rr / self.smin_alpha),
torch.zeros_like(nlist_larger),
),
dim=-1,
)
denominator = torch.sum(
torch.where(
nlist_larger != -1,
Expand All @@ -556,5 +561,8 @@ def _compute_weight(
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
coef[mid_mask] = smooth[mid_mask]
coef[right_mask] = 0

# to handle masked atoms
coef = torch.where(sigma != 0, coef, torch.zeros_like(coef))
self.zbl_weight = coef # nframes, nloc
return [1 - coef.unsqueeze(-1), coef.unsqueeze(-1)] # to match the model order.
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ForceTest:
def test(
self,
):
places = 8
places = 5
delta = 1e-5
natoms = 5
cell = torch.rand([3, 3], dtype=dtype, device="cpu")
Expand Down Expand Up @@ -126,7 +126,7 @@ class VirialTest:
def test(
self,
):
places = 8
places = 5
delta = 1e-4
natoms = 5
cell = torch.rand([3, 3], dtype=dtype, device="cpu")
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"use_srtab": f"{CUR_DIR}/water/data/zbl_tab_potential/H2O_tab_potential.txt",
"smin_alpha": 0.1,
"sw_rmin": 0.2,
"sw_rmax": 1.0,
"sw_rmax": 4.0,
"descriptor": {
"type": "se_atten",
"sel": 40,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def setUp(self):
model_params = copy.deepcopy(model_zbl)
self.type_split = False
self.model = get_model(model_params).to(env.DEVICE)
self.epsilon, self.aprec = 1e-10, None
self.epsilon, self.aprec = None, 5e-2


class TestEnergyModelSpinSeA(unittest.TestCase, SmoothTest):
Expand Down
Loading

0 comments on commit 5850a2f

Please sign in to comment.