Commit

reformat spin model
iProzd committed Feb 22, 2024
1 parent c7b6c42 commit 2915121
Showing 10 changed files with 301 additions and 46 deletions.
8 changes: 8 additions & 0 deletions deepmd/dpmodel/utils/spin.py
@@ -45,6 +45,7 @@ def __init__(
self.ntypes_real = len(use_spin)
self.ntypes_spin = use_spin.count(True)
self.use_spin = np.array(use_spin)
self.spin_mask = self.use_spin.astype(np.int64)
self.ntypes_real_and_spin = self.ntypes_real + self.ntypes_spin
self.ntypes_placeholder = self.ntypes_real - self.ntypes_spin
self.ntypes_input = 2 * self.ntypes_real # with placeholder for input types
@@ -185,6 +186,10 @@ def deserialize(
def get_virtual_scale_mask(self):
pass

@abstractmethod
def get_spin_mask(self):
pass


class Spin(BaseSpin):
def __init__(self, *args, **kwargs):
@@ -194,6 +199,9 @@ def __init__(self, *args, **kwargs):
def get_virtual_scale_mask(self):
return self.virtual_scale_mask

def get_spin_mask(self):
return self.spin_mask

def serialize(
self,
) -> dict:
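
For orientation, here is a minimal numpy sketch of how a per-type mask is expanded to per-atom values by fancy indexing, which is the pattern the PyTorch changes below switch to. The construction of `virtual_scale_mask` is not shown in this hunk, so all values here are purely illustrative.

import numpy as np

# hypothetical system: 2 real types, only type 0 carries spin
use_spin = [True, False]
spin_mask = np.array(use_spin).astype(np.int64)  # [1, 0], as built in __init__ above
virtual_scale_mask = np.array([0.3, 0.0])        # illustrative per-type virtual-atom scale

# per-frame atom types, shape [nframes, natom]
atype = np.array([[0, 1, 0]])

# advanced indexing broadcasts the per-type value to every atom
per_atom_scale = virtual_scale_mask[atype]       # shape [1, 3] -> [[0.3, 0.0, 0.3]]
print(per_atom_scale)
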
252 changes: 229 additions & 23 deletions deepmd/pt/model/model/dp_spin_model.py
@@ -6,6 +6,13 @@

import torch

from deepmd.pt.utils.utils import (
dict_to_device,
)
from deepmd.utils.path import (
DPPath,
)


class SpinModel(torch.nn.Module):
"""A spin model wrapper, with spin input preprocess and output split."""
@@ -21,29 +28,117 @@ def __init__(
self.backbone_model = backbone_model
self.spin = spin

def preprocess_spin_input(self, coord, atype, spin):
nframes, nloc = coord.shape[:-1]
def process_spin_input(self, coord, atype, spin):
"""Generate virtual coordinates and types, concat into the input."""
nframes, natom = coord.shape[:-1]
atype_spin = torch.concat([atype, atype + self.spin.ntypes_real], dim=-1)
virtual_scale_mask = self.spin.get_virtual_scale_mask()
virtual_coord = coord + spin * torch.gather(
virtual_scale_mask, -1, index=atype.view(-1)
).reshape([nframes, nloc, 1])
virtual_coord = coord + spin * virtual_scale_mask[atype].reshape(
[nframes, natom, 1]
)
coord_spin = torch.concat([coord, virtual_coord], dim=-2)
return coord_spin, atype_spin

def preprocess_spin_output(self, atype, force):
nframes, nloc_double = force.shape[:2]
nloc = nloc_double // 2
def process_spin_output(self, atype, force):
"""Split the output gradient of both real and virtual atoms, and scale the latter."""
nframes, natom_double = force.shape[:2]
natom = natom_double // 2
virtual_scale_mask = self.spin.get_virtual_scale_mask()
atmoic_mask = torch.gather(
virtual_scale_mask, -1, index=atype.view(-1)
).reshape([nframes, nloc, 1])
force_real, force_mag = torch.split(force, [nloc, nloc], dim=1)
force_mag = (force_mag.view([nframes, nloc, -1]) * atmoic_mask).view(
atmoic_mask = virtual_scale_mask[atype].reshape([nframes, natom, 1])
force_real, force_mag = torch.split(force, [natom, natom], dim=1)
force_mag = (force_mag.view([nframes, natom, -1]) * atmoic_mask).view(
force_mag.shape
)
return force_real, force_mag, atmoic_mask > 0.0
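
Taken together, these two helpers double the atom axis: `process_spin_input` appends one virtual atom per real atom at `coord + spin * scale(type)` and shifts its type by `ntypes_real`, while `process_spin_output` splits the backbone force on the doubled system back into a real part and a magnetic part rescaled by the same per-type factor. A self-contained sketch of that round trip, with a made-up scale table standing in for `spin.get_virtual_scale_mask()`:

import torch

nframes, natom, ntypes_real = 1, 3, 2
coord = torch.rand(nframes, natom, 3)
spin = torch.rand(nframes, natom, 3)
atype = torch.tensor([[0, 1, 0]])
virtual_scale_mask = torch.tensor([0.3, 0.0])  # illustrative per-type scale

# input side (as in process_spin_input): append one virtual atom per real atom
atype_spin = torch.concat([atype, atype + ntypes_real], dim=-1)  # shape [1, 6]
virtual_coord = coord + spin * virtual_scale_mask[atype].reshape(nframes, natom, 1)
coord_spin = torch.concat([coord, virtual_coord], dim=-2)        # shape [1, 6, 3]

# output side (as in process_spin_output): split the doubled force back
force = torch.rand(nframes, 2 * natom, 3)  # stand-in for the backbone output
atomic_mask = virtual_scale_mask[atype].reshape(nframes, natom, 1)
force_real, force_mag = torch.split(force, [natom, natom], dim=1)
force_mag = force_mag * atomic_mask        # zero for types without spin, scaled otherwise
print(coord_spin.shape, atype_spin.shape, force_real.shape, force_mag.shape)
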

@staticmethod
def extend_nlist(extended_atype, nlist):
nframes, nloc, nnei = nlist.shape
nall = extended_atype.shape[1]
nlist_mask = nlist != -1
nlist[nlist == -1] = 0
nlist_shift = nlist + nall
nlist[~nlist_mask] = -1
nlist_shift[~nlist_mask] = -1
self_spin = torch.arange(0, nloc, dtype=nlist.dtype, device=nlist.device) + nall
self_spin = self_spin.view(1, -1, 1).expand(nframes, -1, -1)
# self spin + real neighbor + virtual neighbor
# nf x nloc x (1 + nnei + nnei)
extended_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1)
# nf x (nloc + nloc) x (1 + nnei + nnei)
extended_nlist = torch.cat(
[extended_nlist, -1 * torch.ones_like(extended_nlist)], dim=-2
)
return extended_nlist
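
The extended neighbor list therefore contains, for every local real atom, its own virtual atom first, followed by the original real neighbors and then the corresponding virtual neighbors (indices shifted by nall); the rows belonging to the virtual local atoms are padded with -1. A tiny worked example, assuming nloc = nall = 2 and one neighbor slot (a clone is used here so the input nlist is not mutated):

import torch

extended_atype = torch.tensor([[0, 1]])  # nall = 2
nlist = torch.tensor([[[1], [-1]]])      # atom 0 sees atom 1; atom 1 has no neighbor

nframes, nloc, nnei = nlist.shape
nall = extended_atype.shape[1]
nlist_mask = nlist != -1
nlist0 = nlist.clone()
nlist0[~nlist_mask] = 0
nlist_shift = nlist0 + nall
nlist_shift[~nlist_mask] = -1
self_spin = (torch.arange(nloc) + nall).view(1, -1, 1).expand(nframes, -1, -1)
extended_nlist = torch.cat([self_spin, nlist, nlist_shift], dim=-1)
extended_nlist = torch.cat([extended_nlist, -torch.ones_like(extended_nlist)], dim=-2)
print(extended_nlist)
# tensor([[[ 2,  1,  3],   # atom 0: own virtual (2), real neighbor (1), virtual neighbor (3)
#          [ 3, -1, -1],   # atom 1: own virtual (3), no neighbors
#          [-1, -1, -1],   # rows for the virtual local atoms stay empty
#          [-1, -1, -1]]])
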

@staticmethod
def extend_mapping(mapping, nloc: int):
return torch.cat([mapping, mapping + nloc], dim=-1)

@staticmethod
def switch_virtual_loc(extended_tensor, nloc: int):
"""
Move the virtual atoms corresponding to the nloc local atoms from [nall: nall + nloc]
to [nloc: nloc + nloc], so that the first nloc * 2 entries of the extended (nall * 2) arrays
carry the local real and virtual atoms in the expected order.
"""
nframes, nall_double = extended_tensor.shape[:2]
nall = nall_double // 2
swithed_tensor = torch.zeros_like(extended_tensor)
swithed_tensor[:, :nloc] = extended_tensor[:, :nloc]
swithed_tensor[:, nloc : nloc + nloc] = extended_tensor[:, nall : nall + nloc]
swithed_tensor[:, nloc + nloc : nloc + nall] = extended_tensor[:, nloc:nall]
swithed_tensor[:, nloc + nall :] = extended_tensor[:, nloc + nall :]
return swithed_tensor

@staticmethod
def switch_nlist(nlist_updated, nall: int):
nframes, nloc_double = nlist_updated.shape[:2]
nloc = nloc_double // 2
first_part_index = (nloc <= nlist_updated) & (nlist_updated < nall)
second_part_index = (nall <= nlist_updated) & (nlist_updated < (nall + nloc))
nlist_updated[first_part_index] += nloc
nlist_updated[second_part_index] -= nall - nloc
return nlist_updated
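
Because `process_spin_input` appends all virtual atoms after all real ones, the doubled extended arrays come out ordered as [local real | ghost real | local virtual | ghost virtual]; these two switches reorder that to [local real | local virtual | ghost real | ghost virtual] and remap the neighbor indices to match. A small illustration of the reordering on an atom-type tensor, assuming nloc = 2, nall = 3, and ntypes_real = 2:

import torch

nloc, nall = 2, 3
# layout after doubling: [local real | ghost real | local virtual | ghost virtual]
extended_atype = torch.tensor([[0, 1, 0, 2, 3, 2]])  # virtual types = real types + 2

switched = torch.zeros_like(extended_atype)
switched[:, :nloc] = extended_atype[:, :nloc]                          # local real stay first
switched[:, nloc : 2 * nloc] = extended_atype[:, nall : nall + nloc]   # local virtual move up
switched[:, 2 * nloc : nloc + nall] = extended_atype[:, nloc:nall]     # ghost real move back
switched[:, nloc + nall :] = extended_atype[:, nloc + nall :]          # ghost virtual stay last
print(switched)  # tensor([[0, 1, 2, 3, 0, 2]])
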

def extend_switch_input(
self,
extended_coord,
extended_atype,
extended_spin,
nlist,
mapping: Optional[torch.Tensor] = None,
):
"""
Add `extended_spin` into `extended_coord` to generate virtual atoms, and extend `nlist` and `mapping`.
Note that the final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order:
- [:, :nloc]: original nloc real atoms.
- [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms.
- [:, nloc + nloc: nloc + nall]: ghost real atoms.
- [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms.
"""
nframes, nall = extended_coord.shape[:2]
nloc = nlist.shape[1]
# add spin but ignore the index switch
extended_coord_updated, extended_atype_updated = self.process_spin_input(
extended_coord, extended_atype, extended_spin
)
# extend the nlist and mapping but ignore the index switch
nlist_updated = self.extend_nlist(extended_atype, nlist)
mapping_updated = None
if mapping is not None:
mapping_updated = self.extend_mapping(mapping, nloc)
# process the index switch
extended_coord_updated = self.switch_virtual_loc(extended_coord_updated, nloc)
extended_atype_updated = self.switch_virtual_loc(extended_atype_updated, nloc)
if mapping_updated is not None:
mapping_updated = self.switch_virtual_loc(mapping_updated, nloc)
nlist_updated = self.switch_nlist(nlist_updated, nall)
return (
extended_coord_updated,
extended_atype_updated,
nlist_updated,
mapping_updated,
)

def __getattr__(self, name):
"""Get attribute from the wrapped model."""
if (
@@ -55,6 +150,47 @@ def __getattr__(self, name):
else:
return getattr(self.backbone_model, name)

def compute_or_load_stat(
self,
sampled,
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled` is provided, all the statistics parameters will be calculated (or recalculated for an update)
and saved to the `stat_file_path`(s).
When `sampled` is not provided, the method checks for the existence of the `stat_file_path`(s)
and loads the previously calculated statistics parameters.

Parameters
----------
sampled
The sampled data frames from different data systems.
stat_file_path
The path to the statistics file(s).
"""
spin_sampled = []
for sys in sampled:
dict_to_device(sys)
coord_updated, atype_updated = self.process_spin_input(
sys["coord"], sys["atype"], sys["spin"]
)
tmp_dict = {
"coord": coord_updated,
"atype": atype_updated,
}
if "natoms" in sys:
natoms = sys["natoms"]
tmp_dict["natoms"] = torch.cat(
[2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1
)
for item_key in sys.keys():
if item_key not in ["coord", "atype", "spin", "natoms"]:
tmp_dict[item_key] = sys[item_key]
spin_sampled.append(tmp_dict)
self.backbone_model.compute_or_load_stat(spin_sampled, stat_file_path)
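
The first two entries of `natoms` are overall atom counts and the remaining entries are per-type counts; adding one virtual atom per real atom doubles the totals and repeats each per-type count for the corresponding virtual type, which is what the concatenation above does. A sketch with a hypothetical two-type frame:

import torch

# hypothetical frame: 4 atoms total/local, 3 of type 0 and 1 of type 1
natoms = torch.tensor([[4, 4, 3, 1]])
natoms_spin = torch.cat([2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1)
print(natoms_spin)  # tensor([[8, 8, 3, 1, 3, 1]])
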

def forward_common(
self,
coord,
@@ -65,7 +201,7 @@ def forward_common(
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
coord_updated, atype_updated = self.preprocess_spin_input(coord, atype, spin)
coord_updated, atype_updated = self.process_spin_input(coord, atype, spin)
model_ret = self.backbone_model.forward_common(
coord_updated,
atype_updated,
@@ -76,34 +212,67 @@ def forward_common(
)
if self.fitting_net is not None:
var_name = self.fitting_net.var_name
if self.do_grad(var_name):
if self.do_grad_r(var_name):
force_all = model_ret[f"{var_name}_derv_r"]
(
model_ret[f"{var_name}_derv_r_real"],
model_ret[f"{var_name}_derv_r_mag"],
model_ret["atmoic_mask"],
) = self.preprocess_spin_output(atype, force_all)
) = self.process_spin_output(atype, force_all)
else:
force_all = model_ret["dforce"]
(
model_ret["dforce_real"],
model_ret["dforce_mag"],
model_ret["atmoic_mask"],
) = self.preprocess_spin_output(atype, force_all)
) = self.process_spin_output(atype, force_all)
return model_ret

def forward_common_lower(
self,
extended_coord,
extended_atype,
extended_spin,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
## TODO preprocess
raise NotImplementedError("Not implemented forward_common_lower for spin")
(
extended_coord_updated,
extended_atype_updated,
nlist_updated,
mapping_updated,
) = self.extend_switch_input(
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
)
model_ret = self.backbone_model.forward_common_lower(
extended_coord_updated,
extended_atype_updated,
nlist_updated,
mapping=mapping_updated,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
var_name = self.fitting_net.var_name
if self.do_grad_r(var_name):
force_all = model_ret[f"{var_name}_derv_r"]
(
model_ret[f"{var_name}_derv_r_real"],
model_ret[f"{var_name}_derv_r_mag"],
model_ret["atmoic_mask"],
) = self.process_spin_output(extended_atype, force_all)
else:
force_all = model_ret["dforce"]
(
model_ret["dforce_real"],
model_ret["dforce_mag"],
model_ret["atmoic_mask"],
) = self.process_spin_output(extended_atype, force_all)
return model_ret


class SpinEnergyModel(SpinModel):
@@ -141,13 +310,16 @@ def forward(
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
model_predict["atmoic_mask"] = model_ret["atmoic_mask"]
if self.do_grad("energy"):
if self.do_grad_r("energy"):
model_predict["force_real"] = model_ret["energy_derv_r_real"].squeeze(-2)
model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
else:
assert model_ret["dforce_real"] is not None
assert model_ret["dforce_mag"] is not None
model_predict["force_real"] = model_ret["dforce_real"]
model_predict["force_mag"] = model_ret["dforce_mag"]
return model_predict
@@ -156,11 +328,45 @@ def forward_lower(
self,
extended_coord,
extended_atype,
extended_spin,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
## TODO preprocess
raise NotImplementedError("Not implemented forward_lower for spin")
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
extended_spin,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
if self.fitting_net is not None:
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy"):
model_predict["extended_force_real"] = model_ret[
"energy_derv_r_real"
].squeeze(-2)
model_predict["extended_force_mag"] = model_ret[
"energy_derv_r_mag"
].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret[
"energy_derv_c"
].squeeze(-3)
else:
assert model_ret["dforce_real"] is not None
assert model_ret["dforce_mag"] is not None
model_predict["extended_force_real"] = model_ret["dforce_real"]
model_predict["extended_force_mag"] = model_ret["dforce_mag"]
else:
model_predict = model_ret
return model_predict
14 changes: 10 additions & 4 deletions deepmd/pt/model/task/ener.py
@@ -137,18 +137,24 @@ def data_stat_key(self):
return ["bias_atom_e"]

def compute_output_stats(self, merged, stat_file_path: Optional[DPPath] = None):
energy = [item["energy"] for item in merged]
energy = torch.cat([item["energy"] for item in merged])
data_mixed_type = "real_natoms_vec" in merged[0]
if data_mixed_type:
input_natoms = [item["real_natoms_vec"] for item in merged]
input_natoms = torch.cat([item["real_natoms_vec"] for item in merged])
else:
input_natoms = [item["natoms"] for item in merged]
input_natoms = torch.cat([item["natoms"] for item in merged])
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_atom_e"
if stat_file_path is not None and stat_file_path.is_file():
bias_atom_e = stat_file_path.load_numpy()
else:
bias_atom_e = compute_output_bias(energy, input_natoms, rcond=self.rcond)
if hasattr(self, "emask"):
type_mask = self.emask(torch.arange(0, self.ntypes).unsqueeze(0))
else:
type_mask = None
bias_atom_e = compute_output_bias(
energy, input_natoms, rcond=self.rcond, type_mask=type_mask
)
if stat_file_path is not None:
stat_file_path.save_numpy(bias_atom_e)
assert all(x is not None for x in [bias_atom_e])
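
`compute_output_bias` itself is not shown in this diff; conceptually it fits per-type energy biases by least squares on the per-frame type counts, and the new `type_mask` argument presumably lets types excluded via `emask` drop out of the fit. A rough numpy sketch of that idea, under the assumption that the mask simply zeroes out the excluded columns before solving:

import numpy as np

# per-frame type counts (columns: type 0, type 1) and total energies (illustrative)
input_natoms = np.array([[3.0, 1.0], [2.0, 2.0], [4.0, 0.0]])
energy = np.array([[-3.2], [-2.8], [-4.1]])

# keep type 0, exclude type 1 from the bias fit
type_mask = np.array([[1.0, 0.0]])

masked = input_natoms * type_mask            # zero the excluded column
bias_atom_e, *_ = np.linalg.lstsq(masked, energy, rcond=None)
print(bias_atom_e)                           # one bias per type; the excluded type stays ~0
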