diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 69d73c18e2..e353af9051 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -124,8 +124,9 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU); ener.assign(cpu_energy_.data_ptr(), cpu_energy_.data_ptr() + cpu_energy_.numel()); + torch::Tensor resize_atom_energy = atom_energy_.toTensor().detach().resize_({1,natoms}); torch::Tensor flat_atom_energy_ = - atom_energy_.toTensor().view({-1}).to(floatType); + resize_atom_energy.view({-1}).to(floatType); torch::Tensor cpu_atom_energy_ = flat_atom_energy_.to(torch::kCPU); atom_energy.assign( cpu_atom_energy_.data_ptr(),