From d3944bbfaefff1ffa7132b4d7b6d566a2e6e68d4 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 30 Sep 2024 18:00:42 -0400 Subject: [PATCH] fix(tf): set visible_device_list for TF C++ Fix #4171. Signed-off-by: Jinzhe Zeng --- source/api_cc/src/DataModifierTF.cc | 6 ++++-- source/api_cc/src/DeepPotTF.cc | 6 ++++-- source/api_cc/src/DeepTensorTF.cc | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/source/api_cc/src/DataModifierTF.cc b/source/api_cc/src/DataModifierTF.cc index 324cb14098..a416b51280 100644 --- a/source/api_cc/src/DataModifierTF.cc +++ b/source/api_cc/src/DataModifierTF.cc @@ -49,8 +49,10 @@ void DipoleChargeModifierTF::init(const std::string& model, 0.9); options.config.mutable_gpu_options()->set_allow_growth(true); DPErrcheck(DPSetDevice(gpu_rank % gpu_num)); - std::string str = "/gpu:"; - str += std::to_string(gpu_rank % gpu_num); + std::string str = "/gpu:0"; + // See + // https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80 + options.config.visible_device_list = std::to_string(gpu_rank % gpu_num); graph::SetDefaultDevice(str, graph_def); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/source/api_cc/src/DeepPotTF.cc b/source/api_cc/src/DeepPotTF.cc index 2c09c17a69..5e7a1f24a5 100644 --- a/source/api_cc/src/DeepPotTF.cc +++ b/source/api_cc/src/DeepPotTF.cc @@ -447,8 +447,10 @@ void DeepPotTF::init(const std::string& model, 0.9); options.config.mutable_gpu_options()->set_allow_growth(true); DPErrcheck(DPSetDevice(gpu_rank % gpu_num)); - std::string str = "/gpu:"; - str += std::to_string(gpu_rank % gpu_num); + std::string str = "/gpu:0"; + // See + // https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80 + options.config.visible_device_list = std::to_string(gpu_rank % gpu_num); graph::SetDefaultDevice(str, graph_def); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/source/api_cc/src/DeepTensorTF.cc b/source/api_cc/src/DeepTensorTF.cc index 34a47bc6f3..9dddde037a 100644 --- a/source/api_cc/src/DeepTensorTF.cc +++ b/source/api_cc/src/DeepTensorTF.cc @@ -46,8 +46,10 @@ void DeepTensorTF::init(const std::string &model, 0.9); options.config.mutable_gpu_options()->set_allow_growth(true); DPErrcheck(DPSetDevice(gpu_rank % gpu_num)); - std::string str = "/gpu:"; - str += std::to_string(gpu_rank % gpu_num); + std::string str = "/gpu:0"; + // See + // https://github.com/tensorflow/tensorflow/blame/8fac27b486939f40bc8e362b94a16a4a8bb51869/tensorflow/core/protobuf/config.proto#L80 + options.config.visible_device_list = std::to_string(gpu_rank % gpu_num); graph::SetDefaultDevice(str, graph_def); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM