Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz and coderabbitai[bot] authored Nov 4, 2024
1 parent 6c10e8e commit 2aa6deb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion source/api_cc/include/DeepPotJAX.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class DeepPotJAX : public DeepPotBase {
* @brief DP constructor without initialization.
**/
DeepPotJAX();
~DeepPotJAX();
virtual ~DeepPotJAX();
/**
* @brief DP constructor with initialization.
* @param[in] model The name of the frozen model file.
Expand Down
10 changes: 8 additions & 2 deletions source/api_cc/src/DeepPotJAX.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ inline T get_scalar(TFE_Context* ctx,
T result = *data;
// deallocate
TFE_DeleteOp(op);
TF_DeleteTensor(tensor);
TFE_DeleteTensorHandle(retval);
return result;
}

Expand Down Expand Up @@ -152,6 +154,8 @@ inline std::vector<std::string> get_vector_string(

// deallocate
TFE_DeleteOp(op);
TF_DeleteTensor(tensor);
TFE_DeleteTensorHandle(retval);
return result;
}

Expand Down Expand Up @@ -191,6 +195,8 @@ inline void tensor_to_vector(std::vector<T>& result,
for (int i = 0; i < TF_TensorElementCount(tensor); i++) {
result[i] = data[i];
}
// Delete the tensor to free memory
TF_DeleteTensor(tensor);
}

deepmd::DeepPotJAX::DeepPotJAX() : inited(false) {}
Expand Down Expand Up @@ -262,6 +268,7 @@ void deepmd::DeepPotJAX::init(const std::string& model,
ntypes = type_map_.size();
sel = get_vector<int64_t>(ctx, "get_sel", func_vector, device, status);
nnei = std::accumulate(sel.begin(), sel.end(), decltype(sel)::value_type(0));
inited = true;
}

deepmd::DeepPotJAX::~DeepPotJAX() {
Expand Down Expand Up @@ -302,7 +309,6 @@ void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
select_real_atoms_coord(coord, atype, aparam, nghost_real, fwd_map, bkw_map,
nall_real, nloc_real, dcoord, datype, aparam_, nghost,
ntypes, nframes, daparam, nall, false);
int nloc = nall_real - nghost_real;

// cast coord, fparam, and aparam to double - I think it's useless to have a
// float model interface
Expand Down Expand Up @@ -331,7 +337,7 @@ void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
nlist_data.shuffle_exclude_empty(fwd_map);
}
std::vector<int64_t> nlist_shape = {nframes, nloc_real, nnei};
std::vector<int64_t> nlist(nframes * nloc_real * nnei);
std::vector<int64_t> nlist(static_cast<size_t>(nframes) * nloc_real * nnei);
// pass nlist_data.jlist to nlist
for (int ii = 0; ii < nloc_real; ii++) {
for (int jj = 0; jj < nnei; jj++) {
Expand Down
4 changes: 2 additions & 2 deletions source/api_cc/tests/test_deeppot_jax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ class TestInferDeepPotAJAX : public ::testing::Test {
expected_tot_v[dd] += expected_v[ii * 9 + dd];
}
}
};
}

void TearDown() override {};
void TearDown() override {}
};

TYPED_TEST_SUITE(TestInferDeepPotAJAX, ValueTypes);
Expand Down

0 comments on commit 2aa6deb

Please sign in to comment.