From 04e1159b3f9b3bccd82ab91f0204f65c86cda914 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 28 Oct 2024 15:39:25 -0400 Subject: [PATCH] fix(pt): set device for PT C++ (#4261) Fix #4171. ## Summary by CodeRabbit - **New Features** - Improved GPU initialization to ensure the correct device is utilized. - Enhanced error handling for clearer context on exceptions. - **Bug Fixes** - Updated error handling in multiple methods to catch and rethrow specific exceptions. - Added logic to handle communication-related tensors during computation. --------- Signed-off-by: Jinzhe Zeng --- source/api_cc/src/DeepPotPT.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 4c7aac19b8..780a8007f3 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -80,6 +80,9 @@ void DeepPotPT::init(const std::string& model, device = torch::Device(torch::kCPU); std::cout << "load model from: " << model << " to cpu " << std::endl; } else { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + DPErrcheck(DPSetDevice(gpu_id)); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM std::cout << "load model from: " << model << " to gpu " << gpu_id << std::endl; }