Skip to content

Commit

Permalink
Add uts and reformat SpinModel
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 25, 2024
1 parent 65d6162 commit 07e3fca
Show file tree
Hide file tree
Showing 20 changed files with 1,167 additions and 387 deletions.
6 changes: 0 additions & 6 deletions deepmd/dpmodel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@
phys2inter,
to_face_distance,
)
from .spin import (
BaseSpin,
Spin,
)

__all__ = [
"EnvMat",
Expand Down Expand Up @@ -63,6 +59,4 @@
"to_face_distance",
"AtomExcludeMask",
"PairExcludeMask",
"BaseSpin",
"Spin",
]
79 changes: 69 additions & 10 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def eval_model(
coords: Union[np.ndarray, torch.Tensor],
cells: Optional[Union[np.ndarray, torch.Tensor]],
atom_types: Union[np.ndarray, torch.Tensor, List[int]],
spins: Optional[Union[np.ndarray, torch.Tensor]] = None,
atomic: bool = False,
infer_batch_size: int = 2,
denoise: bool = False,
Expand All @@ -414,6 +415,8 @@ def eval_model(
energy_out = []
atomic_energy_out = []
force_out = []
force_real_out = []
force_mag_out = []
virial_out = []
atomic_virial_out = []
updated_coord_out = []
Expand All @@ -426,11 +429,15 @@ def eval_model(
if isinstance(coords, torch.Tensor):
if cells is not None:
assert isinstance(cells, torch.Tensor), err_msg
if spins is not None:
assert isinstance(spins, torch.Tensor), err_msg
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
atom_types = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
elif isinstance(coords, np.ndarray):
if cells is not None:
assert isinstance(cells, np.ndarray), err_msg
if spins is not None:
assert isinstance(spins, np.ndarray), err_msg
assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list)
atom_types = np.array(atom_types, dtype=np.int32)
return_tensor = False
Expand All @@ -450,6 +457,13 @@ def eval_model(
coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
spin_input = None
if spins is not None:
spin_input = torch.tensor(
spins.reshape([-1, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
box_input = None
if cells is None:
Expand All @@ -465,9 +479,19 @@ def eval_model(
batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
batch_box = None
batch_spin = None
if spin_input is not None:
batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
if pbc:
batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
batch_output = model(batch_coord, batch_atype, box=batch_box)
input_dict = {
"coord": batch_coord,
"atype": batch_atype,
"box": batch_box,
}
if getattr(model, "__USE_SPIN_INPUT__", False):
input_dict["spin"] = batch_spin
batch_output = model(**input_dict)
if isinstance(batch_output, tuple):
batch_output = batch_output[0]
if not return_tensor:
Expand All @@ -479,6 +503,10 @@ def eval_model(
)
if "force" in batch_output:
force_out.append(batch_output["force"].detach().cpu().numpy())
if "force_real" in batch_output:
force_real_out.append(batch_output["force_real"].detach().cpu().numpy())
if "force_mag" in batch_output:
force_mag_out.append(batch_output["force_mag"].detach().cpu().numpy())
if "virial" in batch_output:
virial_out.append(batch_output["virial"].detach().cpu().numpy())
if "atom_virial" in batch_output:
Expand All @@ -498,6 +526,10 @@ def eval_model(
atomic_energy_out.append(batch_output["atom_energy"])
if "force" in batch_output:
force_out.append(batch_output["force"])
if "force_real" in batch_output:
force_real_out.append(batch_output["force_real"])
if "force_mag" in batch_output:
force_mag_out.append(batch_output["force_mag"])
if "virial" in batch_output:
virial_out.append(batch_output["virial"])
if "atom_virial" in batch_output:
Expand All @@ -518,6 +550,16 @@ def eval_model(
force_out = (
np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3])
)
force_real_out = (
np.concatenate(force_real_out)
if force_real_out
else np.zeros([nframes, natoms, 3])
)
force_mag_out = (
np.concatenate(force_mag_out)
if force_mag_out
else np.zeros([nframes, natoms, 3])
)
virial_out = (
np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3])
)
Expand Down Expand Up @@ -552,6 +594,20 @@ def eval_model(
[nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
force_real_out = (
torch.cat(force_real_out)
if force_real_out
else torch.zeros(
[nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
force_mag_out = (
torch.cat(force_mag_out)
if force_mag_out
else torch.zeros(
[nframes, natoms, 3], dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
)
virial_out = (
torch.cat(virial_out)
if virial_out
Expand All @@ -571,13 +627,16 @@ def eval_model(
if denoise:
return updated_coord_out, logits_out
else:
if not atomic:
return energy_out, force_out, virial_out
results_dict = {
"energy": energy_out,
"virial": virial_out,
}
if not getattr(model, "__USE_SPIN_INPUT__", False):
results_dict["force"] = force_out
else:
return (
energy_out,
force_out,
virial_out,
atomic_energy_out,
atomic_virial_out,
)
results_dict["force_real"] = force_real_out
results_dict["force_mag"] = force_mag_out
if atomic:
results_dict["atom_energy"] = atomic_energy_out
results_dict["atom_virial"] = atomic_virial_out
return results_dict
6 changes: 3 additions & 3 deletions deepmd/pt/loss/ener_spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):

if self.has_fm and "force_mag" in model_pred and "force_mag" in label:
nframes = model_pred["force_mag"].shape[0]
atmoic_mask = model_pred["atmoic_mask"].expand([-1, -1, 3])
label_force_mag = label["force_mag"][atmoic_mask].view(nframes, -1, 3)
model_pred_force_mag = model_pred["force_mag"][atmoic_mask].view(
atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3])
label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3)
model_pred_force_mag = model_pred["force_mag"][atomic_mask].view(
nframes, -1, 3
)
if not self.use_l1_all:
Expand Down
14 changes: 14 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,20 @@ def get_stats(self) -> Dict[str, StatItem]:
"""Get the statistics of the descriptor."""
raise NotImplementedError

def get_emask(self, nlist: torch.Tensor, atype: torch.Tensor) -> torch.Tensor:
"""
Compute the pair-wise type mask for given nlist and atype,
with shape same as nlist.
1 for include and 0 for exclude.
"""
if hasattr(self, "emask"):
exclude_mask = self.emask(nlist, atype)
else:
exclude_mask = torch.ones_like(
nlist, dtype=torch.int32, device=nlist.device
)
return exclude_mask

def share_params(self, base_class, shared_level, resume=False):
assert (
self.__class__ == base_class.__class__
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def serialize(self) -> dict:
"embeddings": obj.filter_layers.serialize(),
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": obj["davg"].detach().cpu().numpy(),
"dstd": obj["dstd"].detach().cpu().numpy(),
Expand Down
10 changes: 5 additions & 5 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,13 @@
from deepmd.pt.model.task import (
Fitting,
)
from deepmd.pt.utils.spin import (
from deepmd.utils.spin import (
Spin,
)

from .dp_model import (
DPModel,
)
from .dp_spin_model import (
SpinEnergyModel,
SpinModel,
)
from .dp_zbl_model import (
DPZBLModel,
)
Expand All @@ -47,6 +43,10 @@
from .model import (
BaseModel,
)
from .spin_model import (
SpinEnergyModel,
SpinModel,
)


def get_zbl_model(model_params):
Expand Down
7 changes: 4 additions & 3 deletions deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ def forward_lower(
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
else:
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
model_predict = model_ret
return model_predict
Loading

0 comments on commit 07e3fca

Please sign in to comment.