diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 624c6299f8..de2611d3e9 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -72,8 +72,9 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, const InputNlist& lmp_list, const int& ago) { torch::Device device(torch::kCUDA, gpu_id); - if(cpu_enabled) + if (cpu_enabled) { device = torch::Device(torch::kCPU); + } std::vector coord_wrapped = coord; int natoms = atype.size(); auto options = torch::TensorOptions().dtype(torch::kFloat64); @@ -194,8 +195,9 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, const std::vector& atype, const std::vector& box) { torch::Device device(torch::kCUDA, gpu_id); - if(cpu_enabled) + if (cpu_enabled) { device = torch::Device(torch::kCPU); + } std::vector coord_wrapped = coord; int natoms = atype.size(); auto options = torch::TensorOptions().dtype(torch::kFloat64);