Skip to content

Commit

Permalink
Remove virial support
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 5, 2024
1 parent b7554b9 commit 8a72e01
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 53 deletions.
12 changes: 9 additions & 3 deletions deepmd/dpmodel/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def is_aparam_nall(self) -> bool:
"""
return self.backbone_model.is_aparam_nall()

Check warning on line 219 in deepmd/dpmodel/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/spin_model.py#L219

Added line #L219 was not covered by tests

def model_output_type(self) -> str:
def model_output_type(self) -> List[str]:
"""Get the output type for the model."""
return self.backbone_model.model_output_type()

Check warning on line 223 in deepmd/dpmodel/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/spin_model.py#L223

Added line #L223 was not covered by tests

Expand Down Expand Up @@ -310,7 +310,10 @@ def call(
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
var_name = self.backbone_model.fitting.var_name
model_output_type = self.backbone_model.model_output_type()
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
var_name = model_output_type[0]
model_predict[f"{var_name}"] = np.split(

Check warning on line 317 in deepmd/dpmodel/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/spin_model.py#L313-L317

Added lines #L313 - L317 were not covered by tests
model_predict[f"{var_name}"], [nloc], axis=1
)[0]
Expand Down Expand Up @@ -376,7 +379,10 @@ def call_lower(
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
var_name = self.backbone_model.fitting.var_name
model_output_type = self.backbone_model.model_output_type()
if "mask" in model_output_type:
model_output_type.pop(model_output_type.index("mask"))
var_name = model_output_type[0]
model_predict[f"{var_name}"] = np.split(

Check warning on line 386 in deepmd/dpmodel/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/spin_model.py#L382-L386

Added lines #L382 - L386 were not covered by tests
model_predict[f"{var_name}"], [nloc], axis=1
)[0]
Expand Down
43 changes: 11 additions & 32 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,15 +491,7 @@ def forward(
if self.backbone_model.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2)

Check warning on line 493 in deepmd/pt/model/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/spin_model.py#L487-L493

Added lines #L487 - L493 were not covered by tests
if self.backbone_model.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
else:
assert model_ret["dforce_real"] is not None
assert model_ret["dforce_mag"] is not None
model_predict["force"] = model_ret["dforce_real"]
model_predict["force_mag"] = model_ret["dforce_mag"]
# not support virial by far
return model_predict

Check warning on line 495 in deepmd/pt/model/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/spin_model.py#L495

Added line #L495 was not covered by tests

@torch.jit.export
Expand All @@ -524,27 +516,14 @@ def forward_lower(
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.backbone_model.fitting_net is not None:
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
model_predict["mask_mag"] = model_ret["mask_mag"]
if self.backbone_model.do_grad_r("energy"):
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["extended_force_mag"] = model_ret[
"energy_derv_r_mag"
].squeeze(-2)
if self.backbone_model.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret[
"energy_derv_c"
].squeeze(-3)
else:
assert model_ret["dforce_real"] is not None
assert model_ret["dforce_mag"] is not None
model_predict["extended_force"] = model_ret["dforce_real"]
model_predict["extended_force_mag"] = model_ret["dforce_mag"]
else:
model_predict = model_ret
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
model_predict["mask_mag"] = model_ret["mask_mag"]
if self.backbone_model.do_grad_r("energy"):
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["extended_force_mag"] = model_ret[

Check warning on line 525 in deepmd/pt/model/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/spin_model.py#L519-L525

Added lines #L519 - L525 were not covered by tests
"energy_derv_r_mag"
].squeeze(-2)
# not support virial by far
return model_predict

Check warning on line 529 in deepmd/pt/model/model/spin_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/spin_model.py#L529

Added line #L529 was not covered by tests
19 changes: 1 addition & 18 deletions source/tests/pt/model/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,11 @@ def test(
cell = (cell) + 5.0 * torch.eye(3, device="cpu")
coord = torch.rand([natoms, 3], dtype=dtype, device="cpu")
coord = torch.matmul(coord, cell)
spin = torch.rand([natoms, 3], dtype=dtype, device="cpu")
atype = torch.IntTensor([0, 0, 0, 1, 1])
# assumes input to be numpy tensor
coord = coord.numpy()
cell = cell.numpy()
spin = spin.numpy()
test_spin = getattr(self, "test_spin", False)
if not test_spin:
test_keys = ["energy", "force", "virial"]
else:
test_keys = ["energy", "force", "force_mag", "virial"]
test_keys = ["energy", "force", "virial"]

def np_infer(
new_cell,
Expand All @@ -156,9 +150,6 @@ def np_infer(
).unsqueeze(0),
torch.tensor(new_cell, device="cpu").unsqueeze(0),
atype,
spins=torch.tensor(
stretch_box(spin, cell, new_cell), device="cpu"
).unsqueeze(0),
)
# detach
ret = {
Expand Down Expand Up @@ -269,11 +260,3 @@ def setUp(self):
self.type_split = False
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)


class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest):
def setUp(self):
model_params = copy.deepcopy(model_spin)
self.type_split = False
self.test_spin = True
self.model = get_model(model_params).to(env.DEVICE)

0 comments on commit 8a72e01

Please sign in to comment.