diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 4c7aac19b8..eb99abc41a 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -80,6 +80,7 @@ void DeepPotPT::init(const std::string& model, device = torch::Device(torch::kCPU); std::cout << "load model from: " << model << " to cpu " << std::endl; } else { + c10::cuda::set_device(gpu_id); std::cout << "load model from: " << model << " to gpu " << gpu_id << std::endl; }