From 5f7707ac9b5226bcbcc1b351e8a479cd0dc3d419 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 26 Oct 2024 13:57:02 -0400 Subject: [PATCH] fix(pt): set device for PT C++ Fix #4171. Signed-off-by: Jinzhe Zeng --- source/api_cc/src/DeepPotPT.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 4c7aac19b8..eb99abc41a 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -80,6 +80,7 @@ void DeepPotPT::init(const std::string& model, device = torch::Device(torch::kCPU); std::cout << "load model from: " << model << " to cpu " << std::endl; } else { + c10::cuda::set_device(gpu_id); std::cout << "load model from: " << model << " to gpu " << gpu_id << std::endl; }