Skip to content

Commit

Permalink
update CMAKE
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Sep 18, 2024
1 parent 482d588 commit a3c4663
Show file tree
Hide file tree
Showing 5 changed files with 669 additions and 172 deletions.
5 changes: 1 addition & 4 deletions source/api_cc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ if(ENABLE_PYTORCH
target_link_libraries(${libname} PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(${libname} PRIVATE BUILD_PYTORCH)
endif()
if(ENABLE_PADDLE
AND "${OP_CXX_ABI_PT}" EQUAL "${OP_CXX_ABI}"
# LAMMPS and i-PI in the Python package are not ready - needs more work
AND NOT BUILD_PY_IF)
if(ENABLE_PADDLE AND NOT BUILD_PY_IF)
target_link_libraries(${libname} PRIVATE "${PADDLE_LIBRARIES}")
target_compile_definitions(${libname} PRIVATE BUILD_PADDLE)
endif()
Expand Down
34 changes: 20 additions & 14 deletions source/api_cc/include/DeepPotPD.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include "paddle/include/paddle_inference_api.h"
// #include "paddle/include/paddle_inference_api.h"
// #include "paddle/extension.h"
// #include "paddle/phi/backends/all_context.h"

#include "DeepPot.h"
#include "common.h"
#include "commonPD.h"
#include "neighbor_list.h"

namespace deepmd {
Expand Down Expand Up @@ -177,19 +178,19 @@ class DeepPotPD : public DeepPotBase {
*same aparam.
* @param[in] atomic Whether to compute the atomic energy and virial.
**/
template <typename VALUETYPE, typename ENERGYVTYPE>
void compute_mixed_type(ENERGYVTYPE& ener,
std::vector<VALUETYPE>& force,
std::vector<VALUETYPE>& virial,
std::vector<VALUETYPE>& atom_energy,
std::vector<VALUETYPE>& atom_virial,
const int& nframes,
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box,
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam,
const bool atomic);
// template <typename VALUETYPE, typename ENERGYVTYPE>
// void compute_mixed_type(ENERGYVTYPE& ener,
// std::vector<VALUETYPE>& force,
// std::vector<VALUETYPE>& virial,
// std::vector<VALUETYPE>& atom_energy,
// std::vector<VALUETYPE>& atom_virial,
// const int& nframes,
// const std::vector<VALUETYPE>& coord,
// const std::vector<int>& atype,
// const std::vector<VALUETYPE>& box,
// const std::vector<VALUETYPE>& fparam,
// const std::vector<VALUETYPE>& aparam,
// const bool atomic);

public:
/**
Expand Down Expand Up @@ -327,6 +328,10 @@ class DeepPotPD : public DeepPotBase {
private:
int num_intra_nthreads, num_inter_nthreads;
bool inited;

template <class VT>
VT get_scalar(const std::string& name) const;

int ntypes;
int ntypes_spin;
int dfparam;
Expand All @@ -336,6 +341,7 @@ class DeepPotPD : public DeepPotBase {
std::shared_ptr<paddle_infer::Predictor> predictor = nullptr;
std::shared_ptr<paddle_infer::Config> config = nullptr;
double rcut;
double cell_size;
NeighborListData nlist_data;
int max_num_neighbors;
InputNlist nlist;
Expand Down
1 change: 1 addition & 0 deletions source/api_cc/include/version.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ const std::string global_git_branch="@GIT_BRANCH@";
const std::string global_tf_include_dir="@TensorFlow_INCLUDE_DIRS@";
const std::string global_tf_lib="@TensorFlow_LIBRARY@";
const std::string global_pt_lib="@TORCH_LIBRARIES@";
const std::string global_pd_lib="@PADDLE_LIBRARIES@";
const std::string global_model_version="@MODEL_VERSION@";
Loading

0 comments on commit a3c4663

Please sign in to comment.