Skip to content

Commit

Permalink
support old nlist api
Browse files · Browse the repository at this point in history
  • Loading branch information
Committed by CaRoLZhangxy on Oct 23, 2024
1 parent 389b914 commit 8e82b9f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
10 changes: 8 additions & 2 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,14 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
torch::Tensor sendnum_tensor =
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
torch::Tensor communicator_tensor;
communicator_tensor = torch::from_blob(const_cast<void*>(lmp_list.world),
{1}, torch::kInt64);
if(lmp_list.world == 0)
{
communicator_tensor = torch::empty({1}, torch::kInt64);
}
else{
communicator_tensor = torch::from_blob(
const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
}

torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
int total_send =
Expand Down
26 changes: 16 additions & 10 deletions source/op/pt/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,19 @@ class Border : public torch::autograd::Function<Border> {
int mpi_init = 0;
MPI_Initialized(&mpi_init);
int cuda_aware = 1;
int me;
int me = 0;
MPI_Comm world;
int world_size = 0;
unpack_communicator(communicator_tensor, world);
MPI_Comm_rank(world, &me);
MPI_Comm_size(world, &world_size);
if(mpi_init)
{
unpack_communicator(communicator_tensor, world);
MPI_Comm_rank(world, &me);
MPI_Comm_size(world, &world_size);
}
MPI_Datatype mpi_type = get_mpi_type<FPTYPE>();
MPI_Request request;
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
if (world_size != 1) {
if (world_size >= 1) {
int version, subversion;
MPI_Get_version(&version, &subversion);
if (version >= 4) {
Expand Down Expand Up @@ -211,15 +214,18 @@ class Border : public torch::autograd::Function<Border> {
MPI_Initialized(&mpi_init);
int world_size = 0;
int cuda_aware = 1;
int me = 0;
MPI_Comm world;
unpack_communicator(communicator_tensor, world);
int me;
MPI_Comm_rank(world, &me);
MPI_Comm_size(world, &world_size);
if(mpi_init)
{
unpack_communicator(communicator_tensor, world);
MPI_Comm_rank(world, &me);
MPI_Comm_size(world, &world_size);
}
MPI_Datatype mpi_type = get_mpi_type<FPTYPE>();
MPI_Request request;
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
if (world_size != 1) {
if (world_size >= 1) {
int version, subversion;
MPI_Get_version(&version, &subversion);
if (version >= 4) {
Expand Down

0 comments on commit 8e82b9f

Please sign in to comment.