diff --git a/src/ML-MACE/pair_mace.cpp b/src/ML-MACE/pair_mace.cpp index abfcba5bd76..09fa9088b44 100644 --- a/src/ML-MACE/pair_mace.cpp +++ b/src/ML-MACE/pair_mace.cpp @@ -25,7 +25,9 @@ #include "memory.h" #include "neigh_list.h" #include "neighbor.h" +#include "universe.h" +#include #include #include #include @@ -298,10 +300,12 @@ void PairMACE::coeff(int narg, char **arg) } else { std::cout << "CUDA found, setting device type to torch::kCUDA." << std::endl; MPI_Comm local; - MPI_Comm_split_type(world, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local); + MPI_Comm_split_type(universe->uworld, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local); int localrank; MPI_Comm_rank(local, &localrank); - device = c10::Device(torch::kCUDA,localrank); + int nDevices; + cudaGetDeviceCount(&nDevices); + device = c10::Device(torch::kCUDA,localrank % nDevices); } std::cout << "Loading MACE model from \"" << arg[2] << "\" ...";