
Commit

Merge branch 'devel' into train_rf
iProzd authored Feb 29, 2024
2 parents cce52da + 665d716 commit 18cbf9e
Showing 4 changed files with 448 additions and 22 deletions.
12 changes: 6 additions & 6 deletions source/api_cc/include/DeepPotPT.h
@@ -69,9 +69,9 @@ class DeepPotPT : public DeepPotBase {
  std::vector<VALUETYPE>& atom_virial,
  const std::vector<VALUETYPE>& coord,
  const std::vector<int>& atype,
- const std::vector<VALUETYPE>& box);
- // const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
- // const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
+ const std::vector<VALUETYPE>& box,
+ const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
+ const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  /**
   * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial
   *by using this DP.
@@ -108,9 +108,9 @@ class DeepPotPT : public DeepPotBase {
  const std::vector<VALUETYPE>& box,
  // const int nghost,
  const InputNlist& lmp_list,
- const int& ago);
- // const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
- // const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
+ const int& ago,
+ const std::vector<VALUETYPE>& fparam = std::vector<VALUETYPE>(),
+ const std::vector<VALUETYPE>& aparam = std::vector<VALUETYPE>());
  /**
   * @brief Evaluate the energy, force, and virial with the mixed type
   *by using this DP.
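In DeepPotPT.h, both compute() overloads gain optional fparam (frame parameters) and aparam (atomic parameters) arguments that default to empty vectors, so existing callers compile unchanged. Below is a minimal sketch of passing them through the computew() path declared later in the DeepPotPT.cc diff; the model file name, the init() defaults, and the parameter sizes are assumptions for illustration and must match the model's get_dim_fparam()/get_dim_aparam():

// Sketch only, not part of this commit.
#include <vector>
#include "DeepPotPT.h"

void run_with_params() {
  deepmd::DeepPotPT dp;
  dp.init("frozen_model.pth");  // hypothetical model trained with fparam/aparam support

  std::vector<double> ener, force, virial, atom_ener, atom_virial;
  std::vector<double> coord = {0.0, 0.0, 0.0, 0.0, 0.0, 1.5};  // two atoms
  std::vector<int> atype = {0, 1};
  std::vector<double> box = {10., 0., 0., 0., 10., 0., 0., 0., 10.};

  std::vector<double> fparam = {330.0};     // one value per frame (dfparam = 1 assumed)
  std::vector<double> aparam = {0.1, 0.2};  // one value per atom (daparam = 1 assumed)

  // computew() now forwards fparam/aparam to the templated compute(),
  // as shown in the DeepPotPT.cc hunks below.
  dp.computew(ener, force, virial, atom_ener, atom_virial, coord, atype, box,
              fparam, aparam);
}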
74 changes: 58 additions & 16 deletions source/api_cc/src/DeepPotPT.cc
@@ -62,9 +62,9 @@ void DeepPotPT::init(const std::string& model,
  rcut = static_cast<double>(rcut_);
  ntypes = 0;
  ntypes_spin = 0;
- dfparam = 0;
- daparam = 0;
- aparam_nall = false;
+ dfparam = module.run_method("get_dim_fparam").toInt();
+ daparam = module.run_method("get_dim_aparam").toInt();
+ aparam_nall = module.run_method("is_aparam_nall").toBool();
  inited = true;
  }
  DeepPotPT::~DeepPotPT() {}
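The init() hunk above replaces the hard-coded dfparam = 0 / daparam = 0 with values read from the frozen model itself. A standalone sketch of that metadata query follows, under the assumption that the TorchScript model exposes the same get_dim_fparam / get_dim_aparam / is_aparam_nall methods; the model path and the size check are illustrative only:

// Sketch only: read fparam/aparam metadata from a TorchScript model and
// size-check a user-supplied fparam vector for a single frame.
#include <torch/script.h>
#include <stdexcept>
#include <string>
#include <vector>

void check_fparam(const std::string& model_path,
                  const std::vector<double>& fparam) {
  torch::jit::Module module = torch::jit::load(model_path);
  int64_t dfparam = module.run_method("get_dim_fparam").toInt();
  int64_t daparam = module.run_method("get_dim_aparam").toInt();
  bool aparam_nall = module.run_method("is_aparam_nall").toBool();
  (void)daparam;      // expected per-atom parameter dimension
  (void)aparam_nall;  // whether aparam is expected for ghost atoms as well
  // fparam is laid out as {nframes, dfparam}; one frame is assumed here.
  if (!fparam.empty() && static_cast<int64_t>(fparam.size()) != dfparam) {
    throw std::runtime_error("fparam size does not match get_dim_fparam()");
  }
}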
@@ -79,7 +79,9 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
  const std::vector<int>& atype,
  const std::vector<VALUETYPE>& box,
  const InputNlist& lmp_list,
- const int& ago) {
+ const int& ago,
+ const std::vector<VALUETYPE>& fparam,
+ const std::vector<VALUETYPE>& aparam) {
  torch::Device device(torch::kCUDA, gpu_id);
  if (!gpu_enabled) {
    device = torch::Device(torch::kCPU);
@@ -109,11 +111,27 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
  firstneigh_tensor = firstneigh.to(torch::kInt64).to(device);
  bool do_atom_virial_tensor = true;
  c10::optional<torch::Tensor> optional_tensor;
+ c10::optional<torch::Tensor> fparam_tensor;
+ if (!fparam.empty()) {
+   fparam_tensor =
+       torch::from_blob(const_cast<VALUETYPE*>(fparam.data()),
+                        {1, static_cast<long int>(fparam.size())}, options)
+           .to(device);
+ }
+ c10::optional<torch::Tensor> aparam_tensor;
+ if (!aparam.empty()) {
+   aparam_tensor =
+       torch::from_blob(const_cast<VALUETYPE*>(aparam.data()),
+                        {1, lmp_list.inum,
+                         static_cast<long int>(aparam.size()) / lmp_list.inum},
+                        options)
+           .to(device);
+ }
  c10::Dict<c10::IValue, c10::IValue> outputs =
      module
          .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
-                     firstneigh_tensor, optional_tensor, optional_tensor,
-                     optional_tensor, do_atom_virial_tensor)
+                     firstneigh_tensor, optional_tensor, fparam_tensor,
+                     aparam_tensor, do_atom_virial_tensor)
          .toGenericDict();
  c10::IValue energy_ = outputs.at("energy");
  c10::IValue force_ = outputs.at("extended_force");
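In the forward_lower path above, the same wrapping pattern handles both new inputs: an empty std::vector leaves the c10::optional unset, so the model falls back to its own defaults, while a non-empty vector is wrapped with torch::from_blob and moved to the target device. A standalone sketch follows, with VALUETYPE fixed to double and the {1, dfparam} fparam shape taken from the hunk; the helper name and the double-only dtype are assumptions:

// Sketch only: turn an optional per-frame parameter vector into the tensor
// argument expected by the model.
#include <torch/torch.h>
#include <vector>

c10::optional<torch::Tensor> wrap_fparam(const std::vector<double>& fparam,
                                         const torch::Device& device) {
  c10::optional<torch::Tensor> fparam_tensor;  // stays unset for an empty input
  if (!fparam.empty()) {
    auto options = torch::TensorOptions().dtype(torch::kFloat64);
    fparam_tensor =
        torch::from_blob(const_cast<double*>(fparam.data()),
                         {1, static_cast<long int>(fparam.size())}, options)
            .to(device);  // as in the diff; on CPU this may alias the caller's buffer
  }
  return fparam_tensor;
}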
@@ -156,7 +174,9 @@ template void DeepPotPT::compute<double, std::vector<ENERGYTYPE>>(
  const std::vector<int>& atype,
  const std::vector<double>& box,
  const InputNlist& lmp_list,
- const int& ago);
+ const int& ago,
+ const std::vector<double>& fparam,
+ const std::vector<double>& aparam);
  template void DeepPotPT::compute<float, std::vector<ENERGYTYPE>>(
  std::vector<ENERGYTYPE>& ener,
  std::vector<float>& force,
@@ -167,7 +187,9 @@ template void DeepPotPT::compute<float, std::vector<ENERGYTYPE>>(
  const std::vector<int>& atype,
  const std::vector<float>& box,
  const InputNlist& lmp_list,
- const int& ago);
+ const int& ago,
+ const std::vector<float>& fparam,
+ const std::vector<float>& aparam);
  template <typename VALUETYPE, typename ENERGYVTYPE>
  void DeepPotPT::compute(ENERGYVTYPE& ener,
  std::vector<VALUETYPE>& force,
@@ -176,7 +198,9 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
  std::vector<VALUETYPE>& atom_virial,
  const std::vector<VALUETYPE>& coord,
  const std::vector<int>& atype,
- const std::vector<VALUETYPE>& box) {
+ const std::vector<VALUETYPE>& box,
+ const std::vector<VALUETYPE>& fparam,
+ const std::vector<VALUETYPE>& aparam) {
  torch::Device device(torch::kCUDA, gpu_id);
  if (!gpu_enabled) {
    device = torch::Device(torch::kCPU);
@@ -207,8 +231,21 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
  }
  inputs.push_back(box_Tensor);
  c10::optional<torch::Tensor> fparam_tensor;
+ if (!fparam.empty()) {
+   fparam_tensor =
+       torch::from_blob(const_cast<VALUETYPE*>(fparam.data()),
+                        {1, static_cast<long int>(fparam.size())}, options)
+           .to(device);
+ }
  inputs.push_back(fparam_tensor);
  c10::optional<torch::Tensor> aparam_tensor;
+ if (!aparam.empty()) {
+   aparam_tensor =
+       torch::from_blob(
+           const_cast<VALUETYPE*>(aparam.data()),
+           {1, natoms, static_cast<long int>(aparam.size()) / natoms}, options)
+           .to(device);
+ }
  inputs.push_back(aparam_tensor);
  bool do_atom_virial_tensor = true;
  inputs.push_back(do_atom_virial_tensor);
@@ -253,7 +290,9 @@ template void DeepPotPT::compute<double, std::vector<ENERGYTYPE>>(
  std::vector<double>& atom_virial,
  const std::vector<double>& coord,
  const std::vector<int>& atype,
- const std::vector<double>& box);
+ const std::vector<double>& box,
+ const std::vector<double>& fparam,
+ const std::vector<double>& aparam);
  template void DeepPotPT::compute<float, std::vector<ENERGYTYPE>>(
  std::vector<ENERGYTYPE>& ener,
  std::vector<float>& force,
@@ -262,7 +301,9 @@ template void DeepPotPT::compute<float, std::vector<ENERGYTYPE>>(
  std::vector<float>& atom_virial,
  const std::vector<float>& coord,
  const std::vector<int>& atype,
- const std::vector<float>& box);
+ const std::vector<float>& box,
+ const std::vector<float>& fparam,
+ const std::vector<float>& aparam);
  void DeepPotPT::get_type_map(std::string& type_map) {
  auto ret = module.run_method("get_type_map").toList();
  for (const torch::IValue& element : ret) {
@@ -282,7 +323,8 @@ void DeepPotPT::computew(std::vector<double>& ener,
  const std::vector<double>& box,
  const std::vector<double>& fparam,
  const std::vector<double>& aparam) {
- compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box);
+ compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
+         fparam, aparam);
  }
  void DeepPotPT::computew(std::vector<double>& ener,
  std::vector<float>& force,
@@ -294,7 +336,8 @@ void DeepPotPT::computew(std::vector<double>& ener,
  const std::vector<float>& box,
  const std::vector<float>& fparam,
  const std::vector<float>& aparam) {
- compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box);
+ compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
+         fparam, aparam);
  }
  void DeepPotPT::computew(std::vector<double>& ener,
  std::vector<double>& force,
@@ -309,9 +352,8 @@ void DeepPotPT::computew(std::vector<double>& ener,
  const int& ago,
  const std::vector<double>& fparam,
  const std::vector<double>& aparam) {
- // TODO: atomic compute unsupported
  compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
-         inlist, ago);
+         inlist, ago, fparam, aparam);
  }
  void DeepPotPT::computew(std::vector<double>& ener,
  std::vector<float>& force,
@@ -327,7 +369,7 @@ void DeepPotPT::computew(std::vector<double>& ener,
  const std::vector<float>& fparam,
  const std::vector<float>& aparam) {
  compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
-         inlist, ago);
+         inlist, ago, fparam, aparam);
  }
  void DeepPotPT::computew_mixed_type(std::vector<double>& ener,
  std::vector<double>& force,
