-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(pt): support spin virial #4545
base: devel
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -54,11 +54,14 @@ def process_spin_input(self, coord, atype, spin): | |||||||||||||||||||||
coord = coord.reshape(nframes, nloc, 3) | ||||||||||||||||||||||
spin = spin.reshape(nframes, nloc, 3) | ||||||||||||||||||||||
atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1) | ||||||||||||||||||||||
virtual_coord = coord + spin * (self.virtual_scale_mask.to(atype.device))[ | ||||||||||||||||||||||
atype | ||||||||||||||||||||||
].reshape([nframes, nloc, 1]) | ||||||||||||||||||||||
spin_dist = spin * (self.virtual_scale_mask.to(atype.device))[atype].reshape( | ||||||||||||||||||||||
[nframes, nloc, 1] | ||||||||||||||||||||||
) | ||||||||||||||||||||||
virtual_coord = coord + spin_dist | ||||||||||||||||||||||
coord_spin = torch.concat([coord, virtual_coord], dim=-2) | ||||||||||||||||||||||
return coord_spin, atype_spin | ||||||||||||||||||||||
# for spin virial corr | ||||||||||||||||||||||
coord_corr = torch.concat([torch.zeros_like(coord), -spin_dist], dim=-2) | ||||||||||||||||||||||
return coord_spin, atype_spin, coord_corr | ||||||||||||||||||||||
|
||||||||||||||||||||||
def process_spin_input_lower( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
|
@@ -78,13 +81,18 @@ def process_spin_input_lower( | |||||||||||||||||||||
""" | ||||||||||||||||||||||
nframes, nall = extended_coord.shape[:2] | ||||||||||||||||||||||
nloc = nlist.shape[1] | ||||||||||||||||||||||
virtual_extended_coord = extended_coord + extended_spin * ( | ||||||||||||||||||||||
extended_spin_dist = extended_spin * ( | ||||||||||||||||||||||
self.virtual_scale_mask.to(extended_atype.device) | ||||||||||||||||||||||
)[extended_atype].reshape([nframes, nall, 1]) | ||||||||||||||||||||||
virtual_extended_coord = extended_coord + extended_spin_dist | ||||||||||||||||||||||
virtual_extended_atype = extended_atype + self.ntypes_real | ||||||||||||||||||||||
extended_coord_updated = concat_switch_virtual( | ||||||||||||||||||||||
extended_coord, virtual_extended_coord, nloc | ||||||||||||||||||||||
) | ||||||||||||||||||||||
# for spin virial corr | ||||||||||||||||||||||
extended_coord_corr = concat_switch_virtual( | ||||||||||||||||||||||
torch.zeros_like(extended_coord), -extended_spin_dist, nloc | ||||||||||||||||||||||
) | ||||||||||||||||||||||
extended_atype_updated = concat_switch_virtual( | ||||||||||||||||||||||
extended_atype, virtual_extended_atype, nloc | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
@@ -100,6 +108,7 @@ def process_spin_input_lower( | |||||||||||||||||||||
extended_atype_updated, | ||||||||||||||||||||||
nlist_updated, | ||||||||||||||||||||||
mapping_updated, | ||||||||||||||||||||||
extended_coord_corr, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def process_spin_output( | ||||||||||||||||||||||
|
@@ -367,7 +376,7 @@ def spin_sampled_func(): | |||||||||||||||||||||
sampled = sampled_func() | ||||||||||||||||||||||
spin_sampled = [] | ||||||||||||||||||||||
for sys in sampled: | ||||||||||||||||||||||
coord_updated, atype_updated = self.process_spin_input( | ||||||||||||||||||||||
coord_updated, atype_updated, _ = self.process_spin_input( | ||||||||||||||||||||||
sys["coord"], sys["atype"], sys["spin"] | ||||||||||||||||||||||
) | ||||||||||||||||||||||
tmp_dict = { | ||||||||||||||||||||||
|
@@ -398,7 +407,9 @@ def forward_common( | |||||||||||||||||||||
do_atomic_virial: bool = False, | ||||||||||||||||||||||
) -> dict[str, torch.Tensor]: | ||||||||||||||||||||||
nframes, nloc = atype.shape | ||||||||||||||||||||||
coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) | ||||||||||||||||||||||
coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input( | ||||||||||||||||||||||
coord, atype, spin | ||||||||||||||||||||||
) | ||||||||||||||||||||||
if aparam is not None: | ||||||||||||||||||||||
aparam = self.expand_aparam(aparam, nloc * 2) | ||||||||||||||||||||||
model_ret = self.backbone_model.forward_common( | ||||||||||||||||||||||
|
@@ -408,6 +419,7 @@ def forward_common( | |||||||||||||||||||||
fparam=fparam, | ||||||||||||||||||||||
aparam=aparam, | ||||||||||||||||||||||
do_atomic_virial=do_atomic_virial, | ||||||||||||||||||||||
coord_corr_for_virial=coord_corr_for_virial, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
model_output_type = self.backbone_model.model_output_type() | ||||||||||||||||||||||
if "mask" in model_output_type: | ||||||||||||||||||||||
|
@@ -454,6 +466,7 @@ def forward_common_lower( | |||||||||||||||||||||
extended_atype_updated, | ||||||||||||||||||||||
nlist_updated, | ||||||||||||||||||||||
mapping_updated, | ||||||||||||||||||||||
extended_coord_corr_for_virial, | ||||||||||||||||||||||
) = self.process_spin_input_lower( | ||||||||||||||||||||||
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
@@ -469,6 +482,7 @@ def forward_common_lower( | |||||||||||||||||||||
do_atomic_virial=do_atomic_virial, | ||||||||||||||||||||||
comm_dict=comm_dict, | ||||||||||||||||||||||
extra_nlist_sort=extra_nlist_sort, | ||||||||||||||||||||||
extended_coord_corr=extended_coord_corr_for_virial, | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure Similar to the previous comment, verify that |
||||||||||||||||||||||
) | ||||||||||||||||||||||
model_output_type = self.backbone_model.model_output_type() | ||||||||||||||||||||||
if "mask" in model_output_type: | ||||||||||||||||||||||
|
@@ -541,6 +555,11 @@ def translated_output_def(self): | |||||||||||||||||||||
output_def["force"].squeeze(-2) | ||||||||||||||||||||||
output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"]) | ||||||||||||||||||||||
output_def["force_mag"].squeeze(-2) | ||||||||||||||||||||||
if self.do_grad_c("energy"): | ||||||||||||||||||||||
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) | ||||||||||||||||||||||
output_def["virial"].squeeze(-2) | ||||||||||||||||||||||
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) | ||||||||||||||||||||||
output_def["atom_virial"].squeeze(-3) | ||||||||||||||||||||||
Comment on lines
+558
to
+562
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Assign the result of The Apply this diff to fix the issue: - output_def["virial"].squeeze(-2)
+ output_def["virial"] = output_def["virial"].squeeze(-2)
- output_def["atom_virial"].squeeze(-3)
+ output_def["atom_virial"] = output_def["atom_virial"].squeeze(-3) 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||
return output_def | ||||||||||||||||||||||
|
||||||||||||||||||||||
def forward( | ||||||||||||||||||||||
|
@@ -569,7 +588,10 @@ def forward( | |||||||||||||||||||||
if self.backbone_model.do_grad_r("energy"): | ||||||||||||||||||||||
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) | ||||||||||||||||||||||
model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2) | ||||||||||||||||||||||
# not support virial by far | ||||||||||||||||||||||
if self.backbone_model.do_grad_c("energy"): | ||||||||||||||||||||||
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) | ||||||||||||||||||||||
if do_atomic_virial: | ||||||||||||||||||||||
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) | ||||||||||||||||||||||
return model_predict | ||||||||||||||||||||||
|
||||||||||||||||||||||
@torch.jit.export | ||||||||||||||||||||||
|
@@ -606,5 +628,10 @@ def forward_lower( | |||||||||||||||||||||
model_predict["extended_force_mag"] = model_ret[ | ||||||||||||||||||||||
"energy_derv_r_mag" | ||||||||||||||||||||||
].squeeze(-2) | ||||||||||||||||||||||
# not support virial by far | ||||||||||||||||||||||
if self.backbone_model.do_grad_c("energy"): | ||||||||||||||||||||||
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) | ||||||||||||||||||||||
if do_atomic_virial: | ||||||||||||||||||||||
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( | ||||||||||||||||||||||
-3 | ||||||||||||||||||||||
) | ||||||||||||||||||||||
return model_predict |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -141,11 +141,17 @@ def test( | |
cell = (cell) + 5.0 * torch.eye(3, device="cpu") | ||
coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) | ||
coord = torch.matmul(coord, cell) | ||
spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) | ||
atype = torch.IntTensor([0, 0, 0, 1, 1]) | ||
# assumes input to be numpy tensor | ||
coord = coord.numpy() | ||
spin = spin.numpy() | ||
cell = cell.numpy() | ||
test_keys = ["energy", "force", "virial"] | ||
test_spin = getattr(self, "test_spin", False) | ||
if not test_spin: | ||
test_keys = ["energy", "force", "virial"] | ||
else: | ||
test_keys = ["energy", "force", "force_mag", "virial"] | ||
|
||
def np_infer( | ||
new_cell, | ||
|
@@ -157,6 +163,7 @@ def np_infer( | |
).unsqueeze(0), | ||
torch.tensor(new_cell, device="cpu").unsqueeze(0), | ||
atype, | ||
spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0), | ||
) | ||
# detach | ||
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} | ||
Comment on lines
+166
to
169
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure compatibility of tensor devices When creating tensors within the Apply this diff to correct the device assignment: - spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
+ spins=torch.tensor(spin, device=new_cell.device).unsqueeze(0),
|
||
|
@@ -251,3 +258,11 @@ def setUp(self) -> None: | |
self.type_split = False | ||
self.test_spin = True | ||
self.model = get_model(model_params).to(env.DEVICE) | ||
|
||
|
||
class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest): | ||
def setUp(self) -> None: | ||
model_params = copy.deepcopy(model_spin) | ||
self.type_split = False | ||
self.test_spin = True | ||
self.model = get_model(model_params).to(env.DEVICE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Confirm the compatibility of the new argument
coord_corr_for_virial
Check if the backbone model's
forward_common
method is designed to acceptcoord_corr_for_virial
. If not, update the backbone model accordingly or modify the call to prevent runtime errors.