diff --git a/source/api_cc/src/DataModifierTF.cc b/source/api_cc/src/DataModifierTF.cc
index 324cb14098..a416b51280 100644
--- a/source/api_cc/src/DataModifierTF.cc
+++ b/source/api_cc/src/DataModifierTF.cc
@@ -49,8 +49,10 @@ void DipoleChargeModifierTF::init(const std::string& model,
         0.9);
     options.config.mutable_gpu_options()->set_allow_growth(true);
     DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
-    std::string str = "/gpu:";
-    str += std::to_string(gpu_rank % gpu_num);
+    std::string str = "/gpu:0";
+    // See
+    // https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80
+    options.config.mutable_gpu_options()->set_visible_device_list(std::to_string(gpu_rank % gpu_num));
     graph::SetDefaultDevice(str, graph_def);
   }
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/source/api_cc/src/DeepPotTF.cc b/source/api_cc/src/DeepPotTF.cc
index 2c09c17a69..5e7a1f24a5 100644
--- a/source/api_cc/src/DeepPotTF.cc
+++ b/source/api_cc/src/DeepPotTF.cc
@@ -447,8 +447,10 @@ void DeepPotTF::init(const std::string& model,
         0.9);
     options.config.mutable_gpu_options()->set_allow_growth(true);
     DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
-    std::string str = "/gpu:";
-    str += std::to_string(gpu_rank % gpu_num);
+    std::string str = "/gpu:0";
+    // See
+    // https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80
+    options.config.mutable_gpu_options()->set_visible_device_list(std::to_string(gpu_rank % gpu_num));
     graph::SetDefaultDevice(str, graph_def);
   }
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/source/api_cc/src/DeepTensorTF.cc b/source/api_cc/src/DeepTensorTF.cc
index 34a47bc6f3..9dddde037a 100644
--- a/source/api_cc/src/DeepTensorTF.cc
+++ b/source/api_cc/src/DeepTensorTF.cc
@@ -46,8 +46,10 @@ void DeepTensorTF::init(const std::string &model,
         0.9);
     options.config.mutable_gpu_options()->set_allow_growth(true);
     DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
-    std::string str = "/gpu:";
-    str += std::to_string(gpu_rank % gpu_num);
+    std::string str = "/gpu:0";
+    // See
+    // https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80
+    options.config.mutable_gpu_options()->set_visible_device_list(std::to_string(gpu_rank % gpu_num));
     graph::SetDefaultDevice(str, graph_def);
   }
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
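
Note (not part of the patch): a minimal standalone sketch of the device-selection pattern the three hunks above switch to, assuming TensorFlow's C++ SessionOptions/ConfigProto and graph::SetDefaultDevice APIs. The helper name SelectGpu is hypothetical; the point is that restricting GPUOptions::visible_device_list to one physical GPU makes that GPU appear as logical device "/gpu:0", which the graph is then pinned to.

// Sketch only; SelectGpu is a hypothetical helper, not DeePMD-kit code.
#include <string>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/public/session_options.h"

// Expose only the physical GPU (gpu_rank % gpu_num) to TensorFlow via
// visible_device_list; the single remaining visible device is renumbered
// by TensorFlow as "/gpu:0", so the graph is pinned to that name.
void SelectGpu(tensorflow::SessionOptions& options,
               tensorflow::GraphDef* graph_def,
               int gpu_rank,
               int gpu_num) {
  options.config.mutable_gpu_options()->set_visible_device_list(
      std::to_string(gpu_rank % gpu_num));
  tensorflow::graph::SetDefaultDevice("/gpu:0", graph_def);
}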