Skip to content

Commit

Permalink
fix bug: device init
Browse files Browse the repository at this point in the history
  • Loading branch information
CaRoLZhangxy committed Feb 3, 2024
1 parent 6750312 commit 3de36b0
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
const std::vector<VALUETYPE>& box,
const InputNlist& lmp_list,
const int& ago) {
torch::Device device;
if(cpu_enabled)
torch::Device device(torch::kCPU);
device = torch::Device(torch::kCPU);
else
torch::Device device(torch::kCUDA, gpu_id);
device = torch::Device(torch::kCUDA, gpu_id);
std::vector<VALUETYPE> coord_wrapped = coord;
int natoms = atype.size();
auto options = torch::TensorOptions().dtype(torch::kFloat64);
Expand Down Expand Up @@ -190,10 +191,11 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box) {
torch::Device device;
if(cpu_enabled)
torch::Device device(torch::kCPU);
device = torch::Device(torch::kCPU);
else
torch::Device device(torch::kCUDA, gpu_id);
device = torch::Device(torch::kCUDA, gpu_id);
std::vector<VALUETYPE> coord_wrapped = coord;
int natoms = atype.size();
auto options = torch::TensorOptions().dtype(torch::kFloat64);
Expand Down

0 comments on commit 3de36b0

Please sign in to comment.