fix compile issues
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed Sep 13, 2023
1 parent 2882072 commit d319bb3
Showing 6 changed files with 66 additions and 30 deletions.
20 changes: 20 additions & 0 deletions source/api_c/include/c_api.h
@@ -717,6 +717,16 @@ int DP_DeepPotGetDimFParam(DP_DeepPot* dp);
*/
int DP_DeepPotGetDimAParam(DP_DeepPot* dp);

+ /**
+ * @brief Check whether the atomic dimension of atomic parameters is nall
+ * instead of nloc.
+ *
+ * @param[in] dp The DP to use.
+ * @return true the atomic dimension of atomic parameters is nall
+ * @return false the atomic dimension of atomic parameters is nloc
+ */
+ bool DP_DeepPotIsAParamNAll(DP_DeepPot* dp);

/**
* @brief Get the type map of a DP.
* @param[in] dp The DP to use.
@@ -737,6 +747,16 @@ int DP_DeepPotModelDeviGetDimFParam(DP_DeepPotModelDevi* dp);
*/
int DP_DeepPotModelDeviGetDimAParam(DP_DeepPotModelDevi* dp);

+ /**
+ * @brief Check whether the atomic dimension of atomic parameters is nall
+ * instead of nloc.
+ *
+ * @param[in] dp The DP Model Deviation to use.
+ * @return true the atomic dimension of atomic parameters is nall
+ * @return false the atomic dimension of atomic parameters is nloc
+ */
+ bool DP_DeepPotModelDeviIsAParamNAll(DP_DeepPotModelDevi* dp);

/**
* @brief The deep tensor.
**/
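For context, a minimal sketch of how a caller might query the new flag through the C API (the model file name, include path, and main.cc file are placeholders, and error handling via DP_DeepPotCheckOK is omitted):

// main.cc — hypothetical usage sketch, not part of this commit
#include <cstdio>
#include "c_api.h"

int main() {
  DP_DeepPot* dp = DP_NewDeepPot("graph.pb");  // placeholder model file
  // When true, per-atom parameters (aparam) are expected for all atoms
  // (nall = local + ghost); when false, only for local atoms (nloc).
  const bool nall = DP_DeepPotIsAParamNAll(dp);
  const int daparam = DP_DeepPotGetDimAParam(dp);
  std::printf("aparam layout: per-%s, dim = %d\n", nall ? "nall" : "nloc", daparam);
  return 0;
}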
40 changes: 26 additions & 14 deletions source/api_c/include/deepmd.hpp
@@ -597,6 +597,7 @@ class DeepPot {
DP_CHECK_OK(DP_DeepPotCheckOK, dp);
dfparam = DP_DeepPotGetDimFParam(dp);
daparam = DP_DeepPotGetDimAParam(dp);
+ aparam_nall = DP_DeepPotIsAParamNAll(dp);
};

/**
@@ -771,9 +772,12 @@ class DeepPot {
VALUETYPE *force_ = &force[0];
VALUETYPE *virial_ = &virial[0];
std::vector<VALUETYPE> fparam_, aparam_;
- validate_fparam_aparam(nframes, natoms - nghost, fparam, aparam);
+ validate_fparam_aparam(nframes, (aparam_nall ? natoms : (natoms - nghost)),
+ fparam, aparam);
tile_fparam_aparam(fparam_, nframes, dfparam, fparam);
- tile_fparam_aparam(aparam_, nframes, (natoms - nghost) * daparam, aparam);
+ tile_fparam_aparam(aparam_, nframes,
+ (aparam_nall ? natoms : (natoms - nghost)) * daparam,
+ aparam);
const VALUETYPE *fparam__ = !fparam_.empty() ? &fparam_[0] : nullptr;
const VALUETYPE *aparam__ = !aparam_.empty() ? &aparam_[0] : nullptr;

@@ -842,9 +846,12 @@ class DeepPot {
VALUETYPE *atomic_ener_ = &atom_energy[0];
VALUETYPE *atomic_virial_ = &atom_virial[0];
std::vector<VALUETYPE> fparam_, aparam_;
- validate_fparam_aparam(nframes, natoms - nghost, fparam, aparam);
+ validate_fparam_aparam(nframes, (aparam_nall ? natoms : (natoms - nghost)),
+ fparam, aparam);
tile_fparam_aparam(fparam_, nframes, dfparam, fparam);
- tile_fparam_aparam(aparam_, nframes, (natoms - nghost) * daparam, aparam);
+ tile_fparam_aparam(aparam_, nframes,
+ (aparam_nall ? natoms : (natoms - nghost)) * daparam,
+ aparam);
const VALUETYPE *fparam__ = !fparam_.empty() ? &fparam_[0] : nullptr;
const VALUETYPE *aparam__ = !aparam_.empty() ? &aparam_[0] : nullptr;

@@ -1039,6 +1046,7 @@ class DeepPot {
DP_DeepPot *dp;
int dfparam;
int daparam;
+ bool aparam_nall;
template <typename VALUETYPE>
void validate_fparam_aparam(const int &nframes,
const int &nloc,
@@ -1051,9 +1059,7 @@ class DeepPot {
}

if (aparam.size() != daparam * nloc &&
- aparam.size() != nframes * daparam * nloc &&
- aparam.size() != daparam * nall &&
- aparam.size() != nframes * daparam * nall) {
+ aparam.size() != nframes * daparam * nloc) {
throw deepmd::hpp::deepmd_exception(
"the dim of atom parameter provided is not consistent with what the "
"model uses");
@@ -1130,6 +1136,7 @@ class DeepPotModelDevi {
numb_models = models.size();
dfparam = DP_DeepPotModelDeviGetDimFParam(dp);
daparam = DP_DeepPotModelDeviGetDimAParam(dp);
+ aparam_nall = DP_DeepPotModelDeviIsAParamNAll(dp);
};

/**
@@ -1175,9 +1182,12 @@ class DeepPotModelDevi {
VALUETYPE *force_ = &force_flat[0];
VALUETYPE *virial_ = &virial_flat[0];
std::vector<VALUETYPE> fparam_, aparam_;
- validate_fparam_aparam(nframes, natoms - nghost, fparam, aparam);
+ validate_fparam_aparam(nframes, (aparam_nall ? natoms : (natoms - nghost)),
+ fparam, aparam);
tile_fparam_aparam(fparam_, nframes, dfparam, fparam);
- tile_fparam_aparam(aparam_, nframes, (natoms - nghost) * daparam, aparam);
+ tile_fparam_aparam(aparam_, nframes,
+ (aparam_nall ? natoms : (natoms - nghost)) * daparam,
+ aparam);
const VALUETYPE *fparam__ = !fparam_.empty() ? &fparam_[0] : nullptr;
const VALUETYPE *aparam__ = !aparam_.empty() ? &aparam_[0] : nullptr;

@@ -1252,9 +1262,12 @@ class DeepPotModelDevi {
VALUETYPE *atomic_ener_ = &atom_energy_flat[0];
VALUETYPE *atomic_virial_ = &atom_virial_flat[0];
std::vector<VALUETYPE> fparam_, aparam_;
- validate_fparam_aparam(nframes, natoms - nghost, fparam, aparam);
+ validate_fparam_aparam(nframes, (aparam_nall ? natoms : (natoms - nghost)),
+ fparam, aparam);
tile_fparam_aparam(fparam_, nframes, dfparam, fparam);
- tile_fparam_aparam(aparam_, nframes, (natoms - nghost) * daparam, aparam);
+ tile_fparam_aparam(aparam_, nframes,
+ (aparam_nall ? natoms : (natoms - nghost)) * daparam,
+ aparam);
const VALUETYPE *fparam__ = !fparam_.empty() ? &fparam_[0] : nullptr;
const VALUETYPE *aparam__ = !aparam_.empty() ? &aparam_[0] : nullptr;

@@ -1450,6 +1463,7 @@ class DeepPotModelDevi {
int numb_models;
int dfparam;
int daparam;
+ bool aparam_nall;
template <typename VALUETYPE>
void validate_fparam_aparam(const int &nframes,
const int &nloc,
@@ -1462,9 +1476,7 @@ class DeepPotModelDevi {
}

if (aparam.size() != daparam * nloc &&
- aparam.size() != nframes * daparam * nloc &&
- aparam.size() != daparam * nall &&
- aparam.size() != nframes * daparam * nall) {
+ aparam.size() != nframes * daparam * nloc) {
throw deepmd::hpp::deepmd_exception(
"the dim of atom parameter provided is not consistent with what the "
"model uses");
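The hpp wrapper changes above make the expected aparam length depend on the new aparam_nall flag. A small stand-alone helper (hypothetical, not part of the library) stating that sizing rule:

#include <cstddef>

// Expected flattened aparam length for one call under the rule the wrapper now
// applies: natoms covers nloc local plus nghost ghost atoms, and daparam is the
// per-atom parameter dimension.
inline std::size_t expected_aparam_size(bool aparam_nall, int nframes,
                                        int natoms, int nghost, int daparam) {
  const int n = aparam_nall ? natoms : (natoms - nghost);
  return static_cast<std::size_t>(nframes) * n * daparam;
}

As the diff shows, validate_fparam_aparam also accepts a single-frame aparam of length daparam * n, which tile_fparam_aparam then repeats over nframes.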
10 changes: 8 additions & 2 deletions source/api_c/src/c_api.cc
@@ -253,7 +253,7 @@ inline void DP_DeepPotComputeNList_variant(DP_DeepPot* dp,
if (aparam) {
aparam_.assign(aparam,
aparam + nframes *
- (dp->aparam_all ? natoms : (natoms - nghost)) *
+ (dp->aparam_nall ? natoms : (natoms - nghost)) *
dp->daparam);
}
std::vector<double> e;
@@ -440,7 +440,7 @@ void DP_DeepPotModelDeviComputeNList_variant(DP_DeepPotModelDevi* dp,
if (aparam) {
aparam_.assign(
aparam,
- aparam + (dp->aparam_all ? natoms : (natoms - nghost)) * dp->daparam);
+ aparam + (dp->aparam_nall ? natoms : (natoms - nghost)) * dp->daparam);
}
// different from DeepPot
std::vector<double> e;
@@ -1038,6 +1038,8 @@ int DP_DeepPotGetDimFParam(DP_DeepPot* dp) { return dp->dfparam; }

int DP_DeepPotGetDimAParam(DP_DeepPot* dp) { return dp->daparam; }

+ bool DP_DeepPotIsAParamNAll(DP_DeepPot* dp) { return dp->aparam_nall; }

const char* DP_DeepPotCheckOK(DP_DeepPot* dp) {
return string_to_char(dp->exception);
}
@@ -1140,6 +1142,10 @@ int DP_DeepPotModelDeviGetDimAParam(DP_DeepPotModelDevi* dp) {
return dp->daparam;
}

+ bool DP_DeepPotModelDeviIsAParamNAll(DP_DeepPotModelDevi* dp) {
+ return dp->aparam_nall;
+ }

const char* DP_DeepPotModelDeviCheckOK(DP_DeepPotModelDevi* dp) {
return string_to_char(dp->exception);
}
4 changes: 2 additions & 2 deletions source/api_cc/include/common.h
@@ -242,8 +242,8 @@ int session_input_tensors(
const std::vector<VALUETYPE>& fparam_,
const std::vector<VALUETYPE>& aparam_,
const deepmd::AtomMap& atommap,
- const bool aparam_nall = false,
- const std::string scope = "");
+ const std::string scope = "",
+ const bool aparam_nall = false);

/**
* @brief Get input tensors.
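The reordering in common.h is presumably part of what the commit title refers to: keeping the defaulted scope parameter in its original position lets existing call sites that pass scope positionally keep compiling, since a std::string argument does not implicitly convert to bool. A stand-alone sketch of the difference (function names are hypothetical):

#include <string>

// Old ordering: the new flag was inserted before the existing defaulted scope.
void f_before(const std::string& graph, bool aparam_nall = false,
              const std::string scope = "") {}
// New ordering: the new flag is appended after scope.
void f_after(const std::string& graph, const std::string scope = "",
             bool aparam_nall = false) {}

void existing_caller(const std::string& scope) {
  // f_before("graph", scope);  // error: no conversion from std::string to bool
  f_after("graph", scope);      // unchanged positional call still compiles
}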
18 changes: 6 additions & 12 deletions source/api_cc/src/DeepPot.cc
@@ -846,11 +846,9 @@ void DeepPot::compute(ENERGYVTYPE& dener,
int nloc = datype_.size();
std::vector<VALUETYPE> fparam;
std::vector<VALUETYPE> aparam;
- validate_fparam_aparam(nframes, (aparam_nall ? nall : nloc), fparam_,
- aparam_);
+ validate_fparam_aparam(nframes, nloc, fparam_, aparam_);
tile_fparam_aparam(fparam, nframes, dfparam, fparam_);
- tile_fparam_aparam(aparam, nframes, (aparam_nall ? nall : nloc) * daparam,
- aparam_);
+ tile_fparam_aparam(aparam, nframes, nloc * daparam, aparam_);

std::vector<std::pair<std::string, Tensor>> input_tensors;

@@ -1064,11 +1062,9 @@ void DeepPot::compute_mixed_type(ENERGYVTYPE& dener,
atommap = deepmd::AtomMap(datype_.begin(), datype_.begin() + nloc);
std::vector<VALUETYPE> fparam;
std::vector<VALUETYPE> aparam;
- validate_fparam_aparam(nframes, (aparam_nall ? nall : nloc), fparam_,
- aparam_);
+ validate_fparam_aparam(nframes, nloc, fparam_, aparam_);
tile_fparam_aparam(fparam, nframes, dfparam, fparam_);
- tile_fparam_aparam(aparam, nframes, (aparam_nall ? nall : nloc) * daparam,
- aparam_);
+ tile_fparam_aparam(aparam, nframes, nloc * daparam, aparam_);

std::vector<std::pair<std::string, Tensor>> input_tensors;

@@ -1150,11 +1146,9 @@ void DeepPot::compute_mixed_type(ENERGYVTYPE& dener,
atommap = deepmd::AtomMap(datype_.begin(), datype_.begin() + nloc);
std::vector<VALUETYPE> fparam;
std::vector<VALUETYPE> aparam;
- validate_fparam_aparam(nframes, (aparam_nall ? nall : nloc), fparam_,
- aparam_);
+ validate_fparam_aparam(nframes, nloc, fparam_, aparam_);
tile_fparam_aparam(fparam, nframes, dfparam, fparam_);
- tile_fparam_aparam(aparam, nframes, (aparam_nall ? nall : nloc) * daparam,
- aparam_);
+ tile_fparam_aparam(aparam, nframes, nloc * daparam, aparam_);

std::vector<std::pair<std::string, Tensor>> input_tensors;

4 changes: 4 additions & 0 deletions source/api_cc/src/common.cc
@@ -928,6 +928,10 @@ template int deepmd::session_get_scalar<int>(Session*,
const std::string,
const std::string);

+ template bool deepmd::session_get_scalar<bool>(Session*,
+ const std::string,
+ const std::string);

template void deepmd::session_get_vector<int>(std::vector<int>&,
Session*,
const std::string,
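The added instantiation in common.cc follows the usual pattern for a template whose definition lives in a single translation unit: every type requested from other files needs an explicit instantiation there, and reading the new aparam_nall flag from the graph presumably requires session_get_scalar<bool>. A self-contained illustration with hypothetical names:

// Stand-in for a template defined only in one .cc file (like
// deepmd::session_get_scalar in common.cc).
template <typename T>
T get_scalar_example(int raw) {
  return static_cast<T>(raw);
}

// Explicit instantiations emitted into this translation unit so that other
// files can link against them; <bool> is the one this commit adds.
template int get_scalar_example<int>(int);
template bool get_scalar_example<bool>(int);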
