diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 30641348c7..4fcebfb5d9 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -38,10 +38,11 @@ void DeepPotPT::init(const std::string& model, return; } int gpu_num = torch::cuda::device_count(); - if(gpu_num > 0) + if (gpu_num > 0) { gpu_id = gpu_rank % gpu_num; - else + } else { gpu_id = 0; + } torch::Device device(torch::kCUDA, gpu_id); gpu_enabled = torch::cuda::is_available(); if (!gpu_enabled) {