diff --git a/source/api_cc/include/DeepPotJAX.h b/source/api_cc/include/DeepPotJAX.h index 613d994fbc..606836de7e 100644 --- a/source/api_cc/include/DeepPotJAX.h +++ b/source/api_cc/include/DeepPotJAX.h @@ -18,7 +18,7 @@ class DeepPotJAX : public DeepPotBase { * @brief DP constructor without initialization. **/ DeepPotJAX(); - ~DeepPotJAX(); + virtual ~DeepPotJAX(); /** * @brief DP constructor with initialization. * @param[in] model The name of the frozen model file. diff --git a/source/api_cc/src/DeepPotJAX.cc b/source/api_cc/src/DeepPotJAX.cc index 8d4d0e61ed..c3deec2bcd 100644 --- a/source/api_cc/src/DeepPotJAX.cc +++ b/source/api_cc/src/DeepPotJAX.cc @@ -101,6 +101,8 @@ inline T get_scalar(TFE_Context* ctx, T result = *data; // deallocate TFE_DeleteOp(op); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retval); return result; } @@ -152,6 +154,8 @@ inline std::vector get_vector_string( // deallocate TFE_DeleteOp(op); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retval); return result; } @@ -191,6 +195,8 @@ inline void tensor_to_vector(std::vector& result, for (int i = 0; i < TF_TensorElementCount(tensor); i++) { result[i] = data[i]; } + // Delete the tensor to free memory + TF_DeleteTensor(tensor); } deepmd::DeepPotJAX::DeepPotJAX() : inited(false) {} @@ -262,6 +268,7 @@ void deepmd::DeepPotJAX::init(const std::string& model, ntypes = type_map_.size(); sel = get_vector(ctx, "get_sel", func_vector, device, status); nnei = std::accumulate(sel.begin(), sel.end(), decltype(sel)::value_type(0)); + inited = true; } deepmd::DeepPotJAX::~DeepPotJAX() { @@ -302,7 +309,6 @@ void deepmd::DeepPotJAX::compute(std::vector& ener, select_real_atoms_coord(coord, atype, aparam, nghost_real, fwd_map, bkw_map, nall_real, nloc_real, dcoord, datype, aparam_, nghost, ntypes, nframes, daparam, nall, false); - int nloc = nall_real - nghost_real; // cast coord, fparam, and aparam to double - I think it's useless to have a // float model interface @@ -331,7 +337,7 @@ void deepmd::DeepPotJAX::compute(std::vector& ener, nlist_data.shuffle_exclude_empty(fwd_map); } std::vector nlist_shape = {nframes, nloc_real, nnei}; - std::vector nlist(nframes * nloc_real * nnei); + std::vector nlist(static_cast(nframes) * nloc_real * nnei); // pass nlist_data.jlist to nlist for (int ii = 0; ii < nloc_real; ii++) { for (int jj = 0; jj < nnei; jj++) { diff --git a/source/api_cc/tests/test_deeppot_jax.cc b/source/api_cc/tests/test_deeppot_jax.cc index 361ba5d759..0514cf3017 100644 --- a/source/api_cc/tests/test_deeppot_jax.cc +++ b/source/api_cc/tests/test_deeppot_jax.cc @@ -89,9 +89,9 @@ class TestInferDeepPotAJAX : public ::testing::Test { expected_tot_v[dd] += expected_v[ii * 9 + dd]; } } - }; + } - void TearDown() override {}; + void TearDown() override {} }; TYPED_TEST_SUITE(TestInferDeepPotAJAX, ValueTypes);