diff --git a/deepmd/dpmodel/model/linear_atomic_model.py b/deepmd/dpmodel/model/linear_atomic_model.py index d1c726bc6b..dc7e9996c8 100644 --- a/deepmd/dpmodel/model/linear_atomic_model.py +++ b/deepmd/dpmodel/model/linear_atomic_model.py @@ -186,7 +186,12 @@ def deserialize(data) -> List[BaseAtomicModel]: return models @abstractmethod - def _compute_weight(self, extended_coord: np.ndarray, extended_atype: np.ndarray, nlists_: List[np.ndarray]) -> np.ndarray: + def _compute_weight( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlists_: List[np.ndarray], + ) -> np.ndarray: """This should be a list of user defined weights that matches the number of models to be combined.""" raise NotImplementedError @@ -242,7 +247,12 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel": smin_alpha=smin_alpha, ) - def _compute_weight(self, extended_coord: np.ndarray, extended_atype: np.ndarray, nlists_: List[np.ndarray]) -> List[np.ndarray]: + def _compute_weight( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlists_: List[np.ndarray], + ) -> List[np.ndarray]: """ZBL weight. Returns diff --git a/deepmd/dpmodel/model/pairtab_atomic_model.py b/deepmd/dpmodel/model/pairtab_atomic_model.py index 2b679166c2..d4feb970fb 100644 --- a/deepmd/dpmodel/model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/model/pairtab_atomic_model.py @@ -122,7 +122,7 @@ def forward_atomic( mask = nlist >= 0 masked_nlist = nlist * mask - atype = extended_atype[:, : nloc] # (nframes, nloc) + atype = extended_atype[:, :nloc] # (nframes, nloc) pairwise_rr = self._get_pairwise_dist( extended_coord, masked_nlist ) # (nframes, nloc, nnei) diff --git a/deepmd/pt/model/model/linear_atomic_model.py b/deepmd/pt/model/model/linear_atomic_model.py index 0a2b2ef569..8b50f5e4f5 100644 --- a/deepmd/pt/model/model/linear_atomic_model.py +++ b/deepmd/pt/model/model/linear_atomic_model.py @@ -160,7 +160,7 @@ def forward_atomic( ) weights = self._compute_weight(extended_coord, extended_atype, nlists_) - + if self.atomic_bias is not None: raise NotImplementedError("Need to add bias in a future PR.") else: @@ -201,7 +201,9 @@ def deserialize(data) -> List[BaseAtomicModel]: return models @abstractmethod - def _compute_weight(self, extended_coord, extended_atype, nlists_) -> List[torch.Tensor]: + def _compute_weight( + self, extended_coord, extended_atype, nlists_ + ) -> List[torch.Tensor]: """This should be a list of user defined weights that matches the number of models to be combined.""" raise NotImplementedError @@ -234,7 +236,7 @@ def __init__( self.smin_alpha = smin_alpha # this is a placeholder being updated in _compute_weight, to handle Jit attribute init error. - self.zbl_weight = torch.empty(0, dtype=torch.float64) + self.zbl_weight = torch.empty(0, dtype=torch.float64) def serialize(self) -> dict: return { @@ -260,7 +262,12 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel": smin_alpha=smin_alpha, ) - def _compute_weight(self, extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlists_: List[torch.Tensor]) -> List[torch.Tensor]: + def _compute_weight( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlists_: List[torch.Tensor], + ) -> List[torch.Tensor]: """ZBL weight. Returns diff --git a/deepmd/pt/model/model/pairtab_atomic_model.py b/deepmd/pt/model/model/pairtab_atomic_model.py index 373cf3ab51..98215191c1 100644 --- a/deepmd/pt/model/model/pairtab_atomic_model.py +++ b/deepmd/pt/model/model/pairtab_atomic_model.py @@ -77,7 +77,7 @@ def __init__( self.sel = sum(sel) else: raise TypeError("sel must be int or list[int]") - + @torch.jit.ignore def _set_pairtab(self, tab_file: str, rcut: float) -> PairTab: return PairTab(tab_file, rcut) @@ -141,7 +141,7 @@ def forward_atomic( mask = nlist >= 0 masked_nlist = nlist * mask - atype = extended_atype[:, : nloc] # (nframes, nloc) + atype = extended_atype[:, :nloc] # (nframes, nloc) pairwise_rr = self._get_pairwise_dist( extended_coord, masked_nlist ) # (nframes, nloc, nnei) diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py index 656f54412c..211b1f8215 100644 --- a/source/tests/pt/model/test_linear_atomic_model.py +++ b/source/tests/pt/model/test_linear_atomic_model.py @@ -10,13 +10,15 @@ from deepmd.dpmodel.model.linear_atomic_model import ( DPZBLLinearAtomicModel as DPDPZBLLinearAtomicModel, ) -from deepmd.pt.model.model.ener import ZBLModel from deepmd.pt.model.descriptor.se_a import ( DescrptSeA, ) from deepmd.pt.model.model.dp_atomic_model import ( DPAtomicModel, ) +from deepmd.pt.model.model.ener import ( + ZBLModel, +) from deepmd.pt.model.model.linear_atomic_model import ( DPZBLLinearAtomicModel, )