Skip to content

Commit

Permalink
fix: jit
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Feb 8, 2024
1 parent 4edff13 commit 656d979
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 33 deletions.
6 changes: 3 additions & 3 deletions deepmd/dpmodel/model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def forward_atomic(
)["energy"]
for model, nl in zip(self.models, nlists_)
]
self.weights = self._compute_weight(extended_coord, nlists_)
self.weights = self._compute_weight(extended_coord, extended_atype, nlists_)
self.atomic_bias = None
if self.atomic_bias is not None:
raise NotImplementedError("Need to add bias in a future PR.")

Check warning on line 152 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L152

Added line #L152 was not covered by tests
Expand Down Expand Up @@ -186,7 +186,7 @@ def deserialize(data) -> List[BaseAtomicModel]:
return models

@abstractmethod
def _compute_weight(self, *args, **kwargs) -> np.ndarray:
def _compute_weight(self, extended_coord: np.ndarray, extended_atype: np.ndarray, nlists_: List[np.ndarray]) -> np.ndarray:
"""This should be a list of user defined weights that matches the number of models to be combined."""
raise NotImplementedError

Check warning on line 191 in deepmd/dpmodel/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/linear_atomic_model.py#L191

Added line #L191 was not covered by tests

Expand Down Expand Up @@ -242,7 +242,7 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
smin_alpha=smin_alpha,
)

def _compute_weight(self, extended_coord, nlists_) -> List[np.ndarray]:
def _compute_weight(self, extended_coord: np.ndarray, extended_atype: np.ndarray, nlists_: List[np.ndarray]) -> List[np.ndarray]:
"""ZBL weight.
Returns
Expand Down
19 changes: 10 additions & 9 deletions deepmd/dpmodel/model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,14 @@ def forward_atomic(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
self.nframes, self.nloc, self.nnei = nlist.shape
extended_coord = extended_coord.reshape(self.nframes, -1, 3)
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.reshape(nframes, -1, 3)

# this will mask all -1 in the nlist
mask = nlist >= 0
masked_nlist = nlist * mask

atype = extended_atype[:, : self.nloc] # (nframes, nloc)
atype = extended_atype[:, : nloc] # (nframes, nloc)
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
Expand All @@ -141,7 +141,7 @@ def forward_atomic(
atomic_energy = 0.5 * np.sum(
np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
axis=-1,
).reshape(self.nframes, self.nloc, 1)
).reshape(nframes, nloc, 1)

return {"energy": atomic_energy}

Expand Down Expand Up @@ -180,17 +180,18 @@ def _pair_tabulated_inter(
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""
nframes, nloc, nnei = nlist.shape
rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1.0 / hh

self.nspline = int(self.tab_info[2] + 0.1)
nspline = int(self.tab_info[2] + 0.1)

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
uu = np.where(nlist != -1, uu, self.nspline + 1)
uu = np.where(nlist != -1, uu, nspline + 1)

if np.any(uu < 0):
raise Exception("coord go beyond table lower boundary")
Expand All @@ -199,14 +200,14 @@ def _pair_tabulated_inter(

uu -= idx
table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, self.nspline
i_type, j_type, idx, self.tab_data, nspline
)
table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4)
table_coef = table_coef.reshape(nframes, nloc, nnei, 4)
ener = self._calculate_ener(table_coef, uu)
# here we need to overwrite energy to zero at rcut and beyond.
mask_beyond_rcut = rr >= self.rcut
# also overwrite values beyond extrapolation to zero
extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh
extrapolation_mask = rr >= self.tab.rmin + nspline * self.tab.hh
ener[mask_beyond_rcut] = 0
ener[extrapolation_mask] = 0

Expand Down
20 changes: 12 additions & 8 deletions deepmd/pt/model/model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
):
super().__init__()
self.models = torch.nn.ModuleList(models)
self.atomic_bias = None
self.distinguish_type_list = [
model.distinguish_types() for model in self.models
]
Expand Down Expand Up @@ -94,9 +95,9 @@ def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:

def forward_atomic(
self,
extended_coord,
extended_atype,
nlist,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -158,14 +159,14 @@ def forward_atomic(
)["energy"]
)

self.weights = self._compute_weight(extended_coord, nlists_)
self.atomic_bias = None
weights = self._compute_weight(extended_coord, extended_atype, nlists_)

if self.atomic_bias is not None:
raise NotImplementedError("Need to add bias in a future PR.")

Check warning on line 165 in deepmd/pt/model/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/linear_atomic_model.py#L165

Added line #L165 was not covered by tests
else:
fit_ret = {
"energy": torch.sum(
torch.stack(ener_list) * torch.stack(self.weights), dim=0
torch.stack(ener_list) * torch.stack(weights), dim=0
),
} # (nframes, nloc, 1)
return fit_ret
Expand Down Expand Up @@ -200,7 +201,7 @@ def deserialize(data) -> List[BaseAtomicModel]:
return models

@abstractmethod
def _compute_weight(self, *args, **kwargs) -> List[torch.Tensor]:
def _compute_weight(self, extended_coord, extended_atype, nlists_) -> List[torch.Tensor]:
"""This should be a list of user defined weights that matches the number of models to be combined."""
raise NotImplementedError

Check warning on line 206 in deepmd/pt/model/model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/linear_atomic_model.py#L206

Added line #L206 was not covered by tests

Expand Down Expand Up @@ -232,6 +233,9 @@ def __init__(
self.sw_rmax = sw_rmax
self.smin_alpha = smin_alpha

# this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
self.zbl_weight = torch.empty(0, dtype=torch.float64)

def serialize(self) -> dict:
return {
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
Expand All @@ -256,7 +260,7 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
smin_alpha=smin_alpha,
)

def _compute_weight(self, extended_coord, nlists_) -> List[torch.Tensor]:
def _compute_weight(self, extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlists_: List[torch.Tensor]) -> List[torch.Tensor]:
"""ZBL weight.
Returns
Expand Down
31 changes: 18 additions & 13 deletions deepmd/pt/model/model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
super().__init__()
self.tab_file = tab_file
self.rcut = rcut
self.tab = PairTab(self.tab_file, rcut=rcut)
self.tab = self._set_pairtab(tab_file, rcut)

# handle deserialization with no input file
if self.tab_file is not None:
Expand All @@ -77,6 +77,10 @@ def __init__(
self.sel = sum(sel)
else:
raise TypeError("sel must be int or list[int]")

@torch.jit.ignore
def _set_pairtab(self, tab_file: str, rcut: float) -> PairTab:
return PairTab(tab_file, rcut)

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
Expand Down Expand Up @@ -120,29 +124,29 @@ def deserialize(cls, data) -> "PairTabModel":

def forward_atomic(
self,
extended_coord,
extended_atype,
nlist,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> Dict[str, torch.Tensor]:
self.nframes, self.nloc, self.nnei = nlist.shape
extended_coord = extended_coord.view(self.nframes, -1, 3)
nframes, nloc, nnei = nlist.shape
extended_coord = extended_coord.view(nframes, -1, 3)
if self.do_grad():
extended_coord.requires_grad_(True)

# this will mask all -1 in the nlist
mask = nlist >= 0
masked_nlist = nlist * mask

atype = extended_atype[:, : self.nloc] # (nframes, nloc)
atype = extended_atype[:, : nloc] # (nframes, nloc)
pairwise_rr = self._get_pairwise_dist(
extended_coord, masked_nlist
) # (nframes, nloc, nnei)
self.tab_data = self.tab_data.view(
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
int(self.tab_info[-1]), int(self.tab_info[-1]), int(self.tab_info[2]), 4
)

# to calculate the atomic_energy, we need 3 tensors, i_type, j_type, pairwise_rr
Expand Down Expand Up @@ -200,17 +204,18 @@ def _pair_tabulated_inter(
This function is used to calculate the pairwise energy between two atoms.
It uses a table containing cubic spline coefficients calculated in PairTab.
"""
nframes, nloc, nnei = nlist.shape
rmin = self.tab_info[0]
hh = self.tab_info[1]
hi = 1.0 / hh

self.nspline = int(self.tab_info[2] + 0.1)
nspline = int(self.tab_info[2] + 0.1)

uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)

# if nnei of atom 0 has -1 in the nlist, uu would be 0.
# this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
uu = torch.where(nlist != -1, uu, self.nspline + 1)
uu = torch.where(nlist != -1, uu, nspline + 1)

if torch.any(uu < 0):
raise Exception("coord go beyond table lower boundary")
Expand All @@ -220,15 +225,15 @@ def _pair_tabulated_inter(
uu -= idx

table_coef = self._extract_spline_coefficient(
i_type, j_type, idx, self.tab_data, self.nspline
i_type, j_type, idx, self.tab_data, nspline
)
table_coef = table_coef.view(self.nframes, self.nloc, self.nnei, 4)
table_coef = table_coef.view(nframes, nloc, nnei, 4)
ener = self._calculate_ener(table_coef, uu)

# here we need to overwrite energy to zero at rcut and beyond.
mask_beyond_rcut = rr >= self.rcut
# also overwrite values beyond extrapolation to zero
extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh
extrapolation_mask = rr >= rmin + nspline * hh
ener[mask_beyond_rcut] = 0
ener[extrapolation_mask] = 0

Expand Down
3 changes: 3 additions & 0 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from deepmd.dpmodel.model.linear_atomic_model import (
DPZBLLinearAtomicModel as DPDPZBLLinearAtomicModel,
)
from deepmd.pt.model.model.ener import ZBLModel
from deepmd.pt.model.descriptor.se_a import (
DescrptSeA,
)
Expand Down Expand Up @@ -153,6 +154,7 @@ def setUp(self, mock_loadtxt):
env.DEVICE
)
self.md2 = DPDPZBLLinearAtomicModel.deserialize(self.md0.serialize())
self.md3 = ZBLModel(dp_model, zbl_model, sw_rmin=0.1, sw_rmax=0.25)

def test_self_consistency(self):
args = [
Expand All @@ -172,6 +174,7 @@ def test_self_consistency(self):

def test_jit(self):
torch.jit.script(self.md1)
torch.jit.script(self.md3)


if __name__ == "__main__":
Expand Down

0 comments on commit 656d979

Please sign in to comment.