Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 1, 2024
1 parent 0ce23f4 commit 5190903
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions deepmd/dpmodel/model/pair_tab.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Dict,
List,
Expand All @@ -11,14 +12,14 @@
FittingOutputDef,
OutputVariableDef,
)
from deepmd.utils.pair_tab import (
PairTab,
)

from .base_atomic_model import (
BaseAtomicModel,
)

from deepmd.utils.pair_tab import (
PairTab,
)

class PairTabModel(BaseAtomicModel):
"""Pairwise tabulation energy model.
Expand Down Expand Up @@ -54,8 +55,8 @@ def __init__(
self.tab = PairTab(self.tab_file, rcut=rcut)
self.ntypes = self.tab.ntypes

self.tab_info, self.tab_data = self.tab.get()
self.tab_info, self.tab_data = self.tab.get()

if isinstance(sel, int):
self.sel = sel
elif isinstance(sel, list):
Expand Down Expand Up @@ -126,14 +127,12 @@ def forward_atomic(
raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr)

atomic_energy = 0.5 * np.sum(
np.where(
nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)
),
np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
dim=-1,
)

return {"energy": atomic_energy}

def _pair_tabulated_inter(
self,
nlist: np.array,
Expand Down Expand Up @@ -184,7 +183,7 @@ def _pair_tabulated_inter(
if np.any(uu < 0):
raise Exception("coord go beyond table lower boundary")

idx = uu.astype(np.int)
idx = uu.astype(int)

uu -= idx

Expand Down Expand Up @@ -248,17 +247,24 @@ def _extract_spline_coefficient(
The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
"""
# (nframes, nloc, nnei)
expanded_i_type = np.broadcast_to(i_type[:, :, np.newaxis], (i_type.shape[0], i_type.shape[1], j_type.shape[-1]))
expanded_i_type = np.broadcast_to(
i_type[:, :, np.newaxis],
(i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
)

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = np.broadcast_to(idx[..., np.newaxis, np.newaxis], idx.shape + (1, 4))
expanded_idx = np.broadcast_to(
idx[..., np.newaxis, np.newaxis], idx.shape + (1, 4)
)
clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)

# (nframes, nloc, nnei, 4)
final_coef = np.squeeze(np.take_along_axis(expanded_tab_data, clipped_indices, 3))
final_coef = np.squeeze(
np.take_along_axis(expanded_tab_data, clipped_indices, 3)
)

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() > nspline] = 0
Expand Down

0 comments on commit 5190903

Please sign in to comment.