Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): pass mapping from LAMMPS to PT #4351

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extern "C" {
/** C API version. Bumped whenever the API is changed.
* @since API version 22
*/
#define DP_C_API_VERSION 24
#define DP_C_API_VERSION 25

/**
* @brief Neighbor list.
Expand All @@ -31,7 +31,7 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_);
/*
/**
* @brief Create a new neighbor list with communication capabilities.
* @details This function extends DP_NewNlist by adding support for parallel
* communication, allowing the neighbor list to be used in distributed
Expand Down Expand Up @@ -68,7 +68,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* recvproc,
void* world);

/*
/**
* @brief Set mask for a neighbor list.
*
* @param nl Neighbor list.
Expand All @@ -78,6 +78,16 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
**/
extern void DP_NlistSetMask(DP_Nlist* nl, int mask);

/**
* @brief Set mapping for a neighbor list.
*
* @param nl Neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
* @since API version 25
*
**/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);

/**
* @brief Delete a neighbor list.
*
Expand Down
5 changes: 5 additions & 0 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,11 @@ struct InputNlist {
* @brief Set mask for this neighbor list.
*/
void set_mask(int mask) { DP_NlistSetMask(nl, mask); };
/**
* @brief Set mapping for this neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
*/
void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };
};

/**
Expand Down
3 changes: 3 additions & 0 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
return new_nl;
}
void DP_NlistSetMask(DP_Nlist* nl, int mask) { nl->nl.set_mask(mask); }
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}
njzjz marked this conversation as resolved.
Show resolved Hide resolved
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

// DP Base Model
Expand Down
9 changes: 9 additions & 0 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
comm_dict.insert("recv_num", recvnum_tensor);
comm_dict.insert("communicator", communicator_tensor);
}
if (lmp_list.mapping) {
std::vector<std::int64_t> mapping(nall_real);
for (size_t ii = 0; ii < nall_real; ii++) {
mapping[ii] = lmp_list.mapping[fwd_map[ii]];
}
mapping_tensor =
torch::from_blob(mapping.data(), {1, nall_real}, int_option)
.to(device);
}
}
at::Tensor firstneigh = createNlistTensor(nlist_data.jlist);
firstneigh_tensor = firstneigh.to(torch::kInt64).to(device);
Expand Down
6 changes: 6 additions & 0 deletions source/lib/include/neighbor_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct InputNlist {
void* world;
/// mask to the neighbor index
int mask = 0xFFFFFFFF;
/// mapping from all atoms to real atoms, in the size of nall
int* mapping = nullptr;
InputNlist()
: inum(0),
ilist(NULL),
Expand Down Expand Up @@ -99,6 +101,10 @@ struct InputNlist {
* @brief Set mask for this neighbor list.
*/
void set_mask(int mask_) { mask = mask_; };
/**
* @brief Set mapping for this neighbor list.
*/
void set_mapping(int* mapping_) { mapping = mapping_; };
njzjz marked this conversation as resolved.
Show resolved Hide resolved
};

/**
Expand Down
11 changes: 11 additions & 0 deletions source/lmp/fix_dplr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,14 @@ void FixDPLR::pre_force(int vflag) {
int nghost = atom->nghost;
int nall = nlocal + nghost;

// mapping (for DPA-2 JAX)
std::vector<int> mapping_vec(nall, -1);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
for (size_t ii = 0; ii < nall; ++ii) {
mapping_vec[ii] = atom->map(atom->tag[ii]);
}
}
njzjz marked this conversation as resolved.
Show resolved Hide resolved

// if (eflag_atom) {
// error->all(FLERR,"atomic energy calculation is not supported by this
// fix\n");
Expand Down Expand Up @@ -499,6 +507,9 @@ void FixDPLR::pre_force(int vflag) {
deepmd_compat::InputNlist lmp_list(list->inum, list->ilist, list->numneigh,
list->firstneigh);
lmp_list.set_mask(NEIGHMASK);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
lmp_list.set_mapping(mapping_vec.data());
}
njzjz marked this conversation as resolved.
Show resolved Hide resolved
// declear output
vector<FLOAT_PREC> tensor;
// compute
Expand Down
11 changes: 11 additions & 0 deletions source/lmp/pair_deepmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ void PairDeepMD::compute(int eflag, int vflag) {
}
}

// mapping (for DPA-2 JAX)
std::vector<int> mapping_vec(nall, -1);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
for (size_t ii = 0; ii < nall; ++ii) {
mapping_vec[ii] = atom->map(atom->tag[ii]);
}
}

if (do_compute_aparam) {
make_aparam_from_compute(daparam);
} else if (aparam.size() > 0) {
Expand Down Expand Up @@ -198,6 +206,9 @@ void PairDeepMD::compute(int eflag, int vflag) {
commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc,
commdata_->recvproc, &world);
lmp_list.set_mask(NEIGHMASK);
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
lmp_list.set_mapping(mapping_vec.data());
}
deepmd_compat::InputNlist extend_lmp_list;
if (single_model || multi_models_no_mod_devi) {
// cvflag_atom is the right flag for the cvatom matrix
Expand Down
Loading