From 8b49cf1c8839b4f894f329f98fcb83edca9c30dd Mon Sep 17 00:00:00 2001 From: CaRoLZhangxy Date: Mon, 26 Feb 2024 02:35:01 +0000 Subject: [PATCH] pass cpu test --- source/api_cc/src/DeepPotPT.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 34f54fa8be..89226ff4ee 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -124,7 +124,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, .toGenericDict(); c10::IValue energy_ = outputs.at("energy"); c10::IValue force_ = outputs.at("extended_force"); - c10::IValue virial_ = outputs.at("reduced_virial"); + c10::IValue virial_ = outputs.at("virial"); c10::IValue atom_virial_ = outputs.at("extended_virial"); c10::IValue atom_energy_ = outputs.at("atom_energy"); torch::Tensor flat_energy_ = energy_.toTensor().view({-1}); @@ -238,7 +238,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, c10::IValue energy_ = outputs.at("energy"); c10::IValue force_ = outputs.at("force"); c10::IValue virial_ = outputs.at("virial"); - c10::IValue atom_virial_ = outputs.at("atomic_virial"); + c10::IValue atom_virial_ = outputs.at("atom_virial"); c10::IValue atom_energy_ = outputs.at("atom_energy"); torch::Tensor flat_energy_ = energy_.toTensor().view({-1}); torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);