Skip to content

Commit

Permalink
pass cpu test
Browse files Browse the repository at this point in the history
  • Loading branch information
CaRoLZhangxy committed Feb 26, 2024
1 parent eda885a commit 8b49cf1
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
.toGenericDict();
c10::IValue energy_ = outputs.at("energy");
c10::IValue force_ = outputs.at("extended_force");
c10::IValue virial_ = outputs.at("reduced_virial");
c10::IValue virial_ = outputs.at("virial");
c10::IValue atom_virial_ = outputs.at("extended_virial");
c10::IValue atom_energy_ = outputs.at("atom_energy");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
Expand Down Expand Up @@ -238,7 +238,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
c10::IValue energy_ = outputs.at("energy");
c10::IValue force_ = outputs.at("force");
c10::IValue virial_ = outputs.at("virial");
c10::IValue atom_virial_ = outputs.at("atomic_virial");
c10::IValue atom_virial_ = outputs.at("atom_virial");
c10::IValue atom_energy_ = outputs.at("atom_energy");
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
Expand Down

0 comments on commit 8b49cf1

Please sign in to comment.