Skip to content

Commit

Permalink
refactor: merge the implementation of calc_fit_gradients_impl and cal…
Browse files Browse the repository at this point in the history
…c_fit_forces_impl
  • Loading branch information
HanatoK committed Sep 17, 2024
1 parent be05982 commit 3519440
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 84 deletions.
100 changes: 31 additions & 69 deletions src/colvaratoms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1217,75 +1217,33 @@ void cvm::atom_group::calc_fit_gradients()
if (cvm::debug())
cvm::log("Calculating fit gradients.\n");

cvm::atom_group *group_for_fit = fitting_group ? fitting_group : this;
auto accessor = [](const std::vector<cvm::atom>& v, size_t i){return v[i].grad;};
if (is_enabled(f_ag_center) && is_enabled(f_ag_rotate))
calc_fit_gradients_impl<true, true>();
calc_fit_forces_impl<true, true>(atoms, group_for_fit->fit_gradients, accessor);
if (is_enabled(f_ag_center) && !is_enabled(f_ag_rotate))
calc_fit_gradients_impl<true, false>();
calc_fit_forces_impl<true, false>(atoms, group_for_fit->fit_gradients, accessor);
if (!is_enabled(f_ag_center) && is_enabled(f_ag_rotate))
calc_fit_gradients_impl<false, true>();
calc_fit_forces_impl<false, true>(atoms, group_for_fit->fit_gradients, accessor);
if (!is_enabled(f_ag_center) && !is_enabled(f_ag_rotate))
calc_fit_gradients_impl<false, false>();
calc_fit_forces_impl<false, false>(atoms, group_for_fit->fit_gradients, accessor);

if (cvm::debug())
cvm::log("Done calculating fit gradients.\n");
}


template <bool B_ag_center, bool B_ag_rotate>
void cvm::atom_group::calc_fit_gradients_impl() {
cvm::atom_group *group_for_fit = fitting_group ? fitting_group : this;
// the center of geometry contribution to the gradients
cvm::rvector atom_grad;
// the rotation matrix contribution to the gradients
const auto rot_inv = rot.inverse().matrix();
// temporary variables for computing and summing derivatives
cvm::real sum_dxdq[4] = {0, 0, 0, 0};
cvm::vector1d<cvm::rvector> dq0_1(4);
// loop 1: iterate over the current atom group
for (size_t i = 0; i < size(); i++) {
cvm::atom_pos pos_orig;
if (B_ag_center) {
atom_grad += atoms[i].grad;
if (B_ag_rotate) pos_orig = rot_inv * (atoms[i].pos - ref_pos_cog);
} else {
if (B_ag_rotate) pos_orig = atoms[i].pos;
}
if (B_ag_rotate) {
// calculate \partial(R(q) \vec{x}_i)/\partial q) \cdot \partial\xi/\partial\vec{x}_i
cvm::quaternion const dxdq =
rot.q.position_derivative_inner(pos_orig, atoms[i].grad);
sum_dxdq[0] += dxdq[0];
sum_dxdq[1] += dxdq[1];
sum_dxdq[2] += dxdq[2];
sum_dxdq[3] += dxdq[3];
}
}
if (B_ag_center) {
if (B_ag_rotate) atom_grad = rot.inverse().matrix() * atom_grad;
atom_grad *= (-1.0)/(cvm::real(group_for_fit->size()));
}
// loop 2: iterate over the fitting group
if (B_ag_rotate) rot_deriv->prepare_derivative(rotation_derivative_dldq::use_dq);
for (size_t j = 0; j < group_for_fit->size(); j++) {
if (B_ag_center) {
group_for_fit->fit_gradients[j] = atom_grad;
}
if (B_ag_rotate) {
rot_deriv->calc_derivative_wrt_group1(j, nullptr, &dq0_1);
// multiply by {\partial q}/\partial\vec{x}_j and add it to the fit gradients
group_for_fit->fit_gradients[j] += sum_dxdq[0] * dq0_1[0] +
sum_dxdq[1] * dq0_1[1] +
sum_dxdq[2] * dq0_1[2] +
sum_dxdq[3] * dq0_1[3];
}
}
}


template <bool B_ag_center, bool B_ag_rotate>
std::vector<cvm::rvector> cvm::atom_group::calc_fit_forces_impl(const std::vector<cvm::rvector>& forces_on_main_group) const {
template <bool B_ag_center, bool B_ag_rotate,
typename main_force_container_T,
typename main_force_accessor_T>
void cvm::atom_group::calc_fit_forces_impl(
const main_force_container_T& forces_on_main_group,
std::vector<cvm::rvector>& forces_on_fitting_group,
main_force_accessor_T accessor) const {
const cvm::atom_group *group_for_fit = fitting_group ? fitting_group : this;
std::vector<cvm::rvector> forces_on_fitting_group(group_for_fit->size());
if (forces_on_fitting_group.size() != group_for_fit->size()) {
forces_on_fitting_group.resize(group_for_fit->size());
}
// the center of geometry contribution to the gradients
cvm::rvector atom_grad;
// the rotation matrix contribution to the gradients
Expand All @@ -1297,23 +1255,23 @@ std::vector<cvm::rvector> cvm::atom_group::calc_fit_forces_impl(const std::vecto
for (size_t i = 0; i < size(); i++) {
cvm::atom_pos pos_orig;
if (B_ag_center) {
atom_grad += forces_on_main_group[i];
atom_grad += accessor(forces_on_main_group, i);
if (B_ag_rotate) pos_orig = rot_inv * (atoms[i].pos - ref_pos_cog);
} else {
if (B_ag_rotate) pos_orig = atoms[i].pos;
}
if (B_ag_rotate) {
// calculate \partial(R(q) \vec{x}_i)/\partial q) \cdot \partial\xi/\partial\vec{x}_i
cvm::quaternion const dxdq =
rot.q.position_derivative_inner(pos_orig, forces_on_main_group[i]);
rot.q.position_derivative_inner(pos_orig, accessor(forces_on_main_group, i));
sum_dxdq[0] += dxdq[0];
sum_dxdq[1] += dxdq[1];
sum_dxdq[2] += dxdq[2];
sum_dxdq[3] += dxdq[3];
}
}
if (B_ag_center) {
if (B_ag_rotate) atom_grad = rot.inverse().matrix() * atom_grad;
if (B_ag_rotate) atom_grad = rot_inv * atom_grad;
atom_grad *= (-1.0)/(cvm::real(group_for_fit->size()));
}
// loop 2: iterate over the fitting group
Expand All @@ -1331,23 +1289,27 @@ std::vector<cvm::rvector> cvm::atom_group::calc_fit_forces_impl(const std::vecto
sum_dxdq[3] * dq0_1[3];
}
}
return forces_on_fitting_group;
// TODO
}


std::vector<cvm::rvector> cvm::atom_group::calc_fit_forces(const std::vector<cvm::rvector>& forces_on_main_group) const {
std::vector<cvm::rvector> cvm::atom_group::calc_fit_forces(
const std::vector<cvm::rvector>& forces_on_main_group,
std::vector<cvm::rvector>& forces_on_fitting_group) const {
if (cvm::debug())
cvm::log("Calculating fit forces.\n");

auto accessor = [](const std::vector<cvm::rvector>& v, size_t i){return v[i];};
if (is_enabled(f_ag_center) && is_enabled(f_ag_rotate))
return calc_fit_forces_impl<true, true>(forces_on_main_group);
calc_fit_forces_impl<true, true>(
forces_on_main_group, forces_on_fitting_group, accessor);
if (is_enabled(f_ag_center) && !is_enabled(f_ag_rotate))
return calc_fit_forces_impl<true, false>(forces_on_main_group);
calc_fit_forces_impl<true, false>(
forces_on_main_group, forces_on_fitting_group, accessor);
if (!is_enabled(f_ag_center) && is_enabled(f_ag_rotate))
return calc_fit_forces_impl<false, true>(forces_on_main_group);
calc_fit_forces_impl<false, true>(
forces_on_main_group, forces_on_fitting_group, accessor);
if (!is_enabled(f_ag_center) && !is_enabled(f_ag_rotate))
return calc_fit_forces_impl<false, false>(forces_on_main_group);
calc_fit_forces_impl<false, false>(
forces_on_main_group, forces_on_fitting_group, accessor);

if (cvm::debug())
cvm::log("Done calculating fit forces.\n");
Expand Down
16 changes: 12 additions & 4 deletions src/colvaratoms.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,9 @@ class colvarmodule::atom_group
/// \brief Calculate the derivatives of the fitting transformation
void calc_fit_gradients();

std::vector<cvm::rvector> calc_fit_forces(const std::vector<cvm::rvector>& forces_on_main_group) const;
std::vector<cvm::rvector> calc_fit_forces(
const std::vector<cvm::rvector>& forces_on_main_group,
std::vector<cvm::rvector>& forces_on_fitting_group) const;

/*! @brief Actual implementation of `calc_fit_gradients`. The template is
* used to avoid branching inside the loops in case that the CPU
Expand All @@ -507,9 +509,15 @@ class colvarmodule::atom_group
* @tparam B_ag_rotate Calculate the optimal rotation? This should follow
* the value of `is_enabled(f_ag_rotate)`.
*/
template <bool B_ag_center, bool B_ag_rotate> void calc_fit_gradients_impl();

template <bool B_ag_center, bool B_ag_rotate> std::vector<cvm::rvector> calc_fit_forces_impl(const std::vector<cvm::rvector>& forces_on_main_group) const;
// template <bool B_ag_center, bool B_ag_rotate> void calc_fit_gradients_impl();

template <bool B_ag_center, bool B_ag_rotate,
typename main_force_container_T,
typename main_force_accessor_T>
void calc_fit_forces_impl(
const main_force_container_T& forces_on_main_group,
std::vector<cvm::rvector>& forces_on_fitting_group,
main_force_accessor_T accessor) const;

/// \brief Derivatives of the fitting transformation
std::vector<cvm::atom_pos> fit_gradients;
Expand Down
5 changes: 5 additions & 0 deletions src/colvarcomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,11 @@ class colvar::orientation
struct rotation_derivative_impl_;
std::unique_ptr<rotation_derivative_impl_> rot_deriv_impl;

bool atom_rotated;
cvm::atom_group *group_for_fit;
std::vector<cvm::rvector> main_group_forces;
std::vector<cvm::rvector> fitting_group_forces;

public:

orientation();
Expand Down
43 changes: 32 additions & 11 deletions src/colvarcomp_rotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ struct colvar::orientation::rotation_derivative_impl_: public rotation_derivativ
};


colvar::orientation::orientation()
colvar::orientation::orientation():
atom_rotated(false), group_for_fit(nullptr)
{
set_function_type("orientation");
rot_deriv_impl = std::unique_ptr<rotation_derivative_impl_>(new rotation_derivative_impl_(this));
Expand All @@ -43,6 +44,12 @@ int colvar::orientation::init(std::string const &conf)
return error_code | COLVARS_INPUT_ERROR;
}
ref_pos.reserve(atoms->size());
main_group_forces.resize(atoms->size());
atom_rotated = atoms->is_enabled(f_ag_rotate);
if (atom_rotated) {
group_for_fit = atoms->fitting_group ? atoms->fitting_group : atoms;
fitting_group_forces.resize(group_for_fit->size());
}

if (get_keyval(conf, "refPositions", ref_pos, ref_pos)) {
cvm::log("Using reference positions from input file.\n");
Expand Down Expand Up @@ -137,30 +144,44 @@ void colvar::orientation::apply_force(colvarvalue const &force)
if (!atoms->noforce) {
rot_deriv_impl->prepare_derivative(rotation_derivative_dldq::use_dq);
cvm::vector1d<cvm::rvector> dq0_2;
std::vector<cvm::rvector> main_group_forces;
const bool force_on_fitting_group = atoms->fitting_group == nullptr ? false : true;
cvm::rmatrix ag_rot;
if (force_on_fitting_group) {
if (atom_rotated) {
ag_rot = atoms->rot.inverse().matrix();
}
if (cvm::debug()) {
cvm::log("Force on main group:\n");
}
for (size_t ia = 0; ia < atoms->size(); ia++) {
rot_deriv_impl->calc_derivative_wrt_group2(ia, nullptr, &dq0_2);
const auto f_ia = FQ[0] * dq0_2[0] +
FQ[1] * dq0_2[1] +
FQ[2] * dq0_2[2] +
FQ[3] * dq0_2[3];
if (force_on_fitting_group) {
main_group_forces.push_back(f_ia);
if (atom_rotated) {
main_group_forces[ia] = f_ia;
(*atoms)[ia].apply_force(ag_rot * f_ia);
if (cvm::debug()) {
cvm::log(cvm::to_str(ag_rot * f_ia));
}
} else {
(*atoms)[ia].apply_force(f_ia);
if (cvm::debug()) {
cvm::log(cvm::to_str(f_ia));
}
}
}
if (force_on_fitting_group) {
const std::vector<cvm::rvector> fitting_group_forces = atoms->calc_fit_forces(main_group_forces);
if (fitting_group_forces.empty()) return;
for (size_t ia = 0; ia < atoms->fitting_group->size(); ia++) {
(*(atoms->fitting_group))[ia].apply_force(fitting_group_forces[ia]);
if (atom_rotated) {
atoms->calc_fit_forces(main_group_forces, fitting_group_forces);
if (cvm::debug()) {
cvm::log("Force on fitting group:\n");
}
for (size_t ia = 0; ia < group_for_fit->size(); ia++) {
(*(group_for_fit))[ia].apply_force(fitting_group_forces[ia]);
if (cvm::debug()) {
cvm::log(cvm::to_str(fitting_group_forces[ia]));
}
// Clear the fitting group force
fitting_group_forces[ia] = 0;
}
}
}
Expand Down

0 comments on commit 3519440

Please sign in to comment.