diff --git a/source/api_cc/src/DeepPotJAX.cc b/source/api_cc/src/DeepPotJAX.cc index 2182676bd1..86ff973b57 100644 --- a/source/api_cc/src/DeepPotJAX.cc +++ b/source/api_cc/src/DeepPotJAX.cc @@ -320,6 +320,17 @@ void deepmd::DeepPotJAX::compute(std::vector& ener, nall_real, nloc_real, dcoord, datype, aparam_, nghost, ntypes, nframes, daparam, nall, false); + if (nloc_real == 0) { + // no real atoms, fill 0 for all outputs + // this can prevent a Xla error + ener.resize(nframes, 0.0); + force_.resize(static_cast(nframes) * nall * 3, 0.0); + virial.resize(static_cast(nframes) * 9, 0.0); + atom_energy_.resize(static_cast(nframes) * nall, 0.0); + atom_virial_.resize(static_cast(nframes) * nall * 9, 0.0); + return; + } + // cast coord, fparam, and aparam to double - I think it's useless to have a // float model interface std::vector coord_double(coord.begin(), coord.end());