Skip to content

Commit

Permalink
Merge branch 'devel' of github.com:Chengqian-Zhang/deepmd-kit into devel
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang committed Mar 18, 2024
2 parents 3f95a82 + 0d58b71 commit 0dc11d6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
28 changes: 18 additions & 10 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,20 +415,28 @@ def _extract_spline_coefficient(
# (nframes, nloc, nnei)
expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1])

# (nframes, nloc, nnei, nspline, 4)
expanded_tab_data = tab_data[expanded_i_type, j_type]

# (nframes, nloc, nnei, 1, 4)
expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4)

# handle the case where idx is beyond the number of splines
clipped_indices = torch.clamp(expanded_idx, 0, nspline - 1).to(torch.int64)

clipped_indices = torch.clamp(idx, 0, nspline - 1).to(torch.int64)

nframes = i_type.shape[0]
nloc = i_type.shape[1]
nnei = j_type.shape[2]
ntypes = tab_data.shape[0]
# tab_data_idx: (nframes, nloc, nnei)
tab_data_idx = (
expanded_i_type * ntypes * nspline + j_type * nspline + clipped_indices
)
# tab_data: (ntype, ntype, nspline, 4)
tab_data = tab_data.view(ntypes * ntypes * nspline, 4)
# tab_data_idx: (nframes * nloc * nnei, 4)
tab_data_idx = tab_data_idx.view(nframes * nloc * nnei, 1).expand(-1, 4)
# (nframes, nloc, nnei, 4)
final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze()
final_coef = torch.gather(tab_data, 0, tab_data_idx).view(
nframes, nloc, nnei, 4
)

# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
final_coef[expanded_idx.squeeze() > nspline] = 0
final_coef[idx > nspline] = 0
return final_coef

@staticmethod
Expand Down
6 changes: 1 addition & 5 deletions deepmd/pt/model/task/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,7 @@ def compute_output_stats(
bias_atom_e = compute_output_stats(
merged, self.ntypes, stat_file_path, self.rcond, self.atom_ener
)
self.bias_atom_e.copy_(
torch.tensor(bias_atom_e, device=env.DEVICE).view(
[self.ntypes, self.dim_out]
)
)
self.bias_atom_e.copy_(bias_atom_e.view([self.ntypes, self.dim_out]))

def output_def(self) -> FittingOutputDef:
return FittingOutputDef(
Expand Down
2 changes: 1 addition & 1 deletion source/lmp/plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ if(DEFINED LAMMPS_SOURCE_ROOT OR DEFINED LAMMPS_VERSION)
install(
CODE "execute_process( \
COMMAND ${CMAKE_COMMAND} -E create_symlink \
../${CMAKE_SHARED_LIBRARY_PREFIX}${libname}${CMAKE_SHARED_LIBRARY_SUFFIX} \
../${CMAKE_SHARED_MODULE_PREFIX}${libname}${CMAKE_SHARED_MODULE_SUFFIX} \
${CMAKE_INSTALL_PREFIX}/lib/${libname}/${PLUGINNAME} \
)")
endif()
Expand Down

0 comments on commit 0dc11d6

Please sign in to comment.