Skip to content

Commit

Permalink
refactor: eliminate the buffer of fitting group forces
Browse files Browse the repository at this point in the history
  • Loading branch information
HanatoK committed Sep 18, 2024
1 parent cdb7f54 commit 57fa6a6
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 63 deletions.
93 changes: 41 additions & 52 deletions src/colvaratoms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,15 +921,6 @@ int cvm::atom_group::parse_fitting_options(std::string const &group_conf)

void cvm::atom_group::do_feature_side_effects(int id)
{
// If enabled features are changed upstream, the features below should be refreshed
switch (id) {
case f_ag_fit_gradients:
if (is_enabled(f_ag_center) || is_enabled(f_ag_rotate)) {
atom_group *group_for_fit = fitting_group ? fitting_group : this;
group_for_fit->fit_gradients.assign(group_for_fit->size(), cvm::atom_pos(0.0, 0.0, 0.0));
}
break;
}
}


Expand Down Expand Up @@ -1218,30 +1209,34 @@ void cvm::atom_group::calc_fit_gradients()
cvm::log("Calculating fit gradients.\n");

cvm::atom_group *group_for_fit = fitting_group ? fitting_group : this;
auto accessor = [this](size_t i){return atoms[i].grad;};
if (group_for_fit->fit_gradients.size() != group_for_fit->size()) {
group_for_fit->fit_gradients.assign(group_for_fit->size(), 0);
} else {
std::fill(group_for_fit->fit_gradients.begin(),
group_for_fit->fit_gradients.end(), 0);
}
auto accessor_main = [this](size_t i){return atoms[i].grad;};
auto accessor_fitting = [&group_for_fit](size_t j, const cvm::rvector& grad){group_for_fit->fit_gradients[j] += grad;};
if (is_enabled(f_ag_center) && is_enabled(f_ag_rotate))
calc_fit_forces_impl<true, true>(group_for_fit->fit_gradients, accessor);
calc_fit_forces_impl<true, true>(accessor_main, accessor_fitting);
if (is_enabled(f_ag_center) && !is_enabled(f_ag_rotate))
calc_fit_forces_impl<true, false>(group_for_fit->fit_gradients, accessor);
calc_fit_forces_impl<true, false>(accessor_main, accessor_fitting);
if (!is_enabled(f_ag_center) && is_enabled(f_ag_rotate))
calc_fit_forces_impl<false, true>(group_for_fit->fit_gradients, accessor);
calc_fit_forces_impl<false, true>(accessor_main, accessor_fitting);
if (!is_enabled(f_ag_center) && !is_enabled(f_ag_rotate))
calc_fit_forces_impl<false, false>(group_for_fit->fit_gradients, accessor);
calc_fit_forces_impl<false, false>(accessor_main, accessor_fitting);

if (cvm::debug())
cvm::log("Done calculating fit gradients.\n");
}


template <bool B_ag_center, bool B_ag_rotate,
typename main_force_accessor_T>
typename main_force_accessor_T, typename fitting_force_accessor_T>
void cvm::atom_group::calc_fit_forces_impl(
std::vector<cvm::rvector>& forces_on_fitting_group,
main_force_accessor_T accessor) const {
main_force_accessor_T accessor_main,
fitting_force_accessor_T accessor_fitting) const {
const cvm::atom_group *group_for_fit = fitting_group ? fitting_group : this;
if (forces_on_fitting_group.size() != group_for_fit->size()) {
forces_on_fitting_group.assign(group_for_fit->size(), 0);
}
// the center of geometry contribution to the gradients
cvm::rvector atom_grad;
// the rotation matrix contribution to the gradients
Expand All @@ -1253,15 +1248,15 @@ void cvm::atom_group::calc_fit_forces_impl(
for (size_t i = 0; i < size(); i++) {
cvm::atom_pos pos_orig;
if (B_ag_center) {
atom_grad += accessor(i);
atom_grad += accessor_main(i);
if (B_ag_rotate) pos_orig = rot_inv * (atoms[i].pos - ref_pos_cog);
} else {
if (B_ag_rotate) pos_orig = atoms[i].pos;
}
if (B_ag_rotate) {
// calculate \partial(R(q) \vec{x}_i)/\partial q) \cdot \partial\xi/\partial\vec{x}_i
cvm::quaternion const dxdq =
rot.q.position_derivative_inner(pos_orig, accessor(i));
rot.q.position_derivative_inner(pos_orig, accessor_main(i));
sum_dxdq[0] += dxdq[0];
sum_dxdq[1] += dxdq[1];
sum_dxdq[2] += dxdq[2];
Expand All @@ -1275,38 +1270,37 @@ void cvm::atom_group::calc_fit_forces_impl(
// loop 2: iterate over the fitting group
if (B_ag_rotate) rot_deriv->prepare_derivative(rotation_derivative_dldq::use_dq);
for (size_t j = 0; j < group_for_fit->size(); j++) {
cvm::rvector fitting_force_grad{0, 0, 0};
if (B_ag_center) {
forces_on_fitting_group[j] = atom_grad;
fitting_force_grad += atom_grad;
}
if (B_ag_rotate) {
rot_deriv->calc_derivative_wrt_group1(j, nullptr, &dq0_1);
// multiply by {\partial q}/\partial\vec{x}_j and add it to the fit gradients
forces_on_fitting_group[j] += sum_dxdq[0] * dq0_1[0] +
sum_dxdq[1] * dq0_1[1] +
sum_dxdq[2] * dq0_1[2] +
sum_dxdq[3] * dq0_1[3];
fitting_force_grad += sum_dxdq[0] * dq0_1[0] +
sum_dxdq[1] * dq0_1[1] +
sum_dxdq[2] * dq0_1[2] +
sum_dxdq[3] * dq0_1[3];
}
if (cvm::debug()) {
cvm::log(cvm::to_str(fitting_force_grad));
}
accessor_fitting(j, fitting_force_grad);
}
}


template <typename main_force_accessor_T, typename fitting_force_accessor_T>
void cvm::atom_group::calc_fit_forces(
const std::vector<cvm::rvector>& forces_on_main_group,
std::vector<cvm::rvector>& forces_on_fitting_group) const {
if (cvm::debug())
cvm::log("Calculating fit forces.\n");
auto accessor = [&forces_on_main_group](size_t i){return forces_on_main_group[i];};
main_force_accessor_T accessor_main,
fitting_force_accessor_T accessor_fitting) const {
if (is_enabled(f_ag_center) && is_enabled(f_ag_rotate))
calc_fit_forces_impl<true, true>(forces_on_fitting_group, accessor);
calc_fit_forces_impl<true, true, main_force_accessor_T, fitting_force_accessor_T>(accessor_main, accessor_fitting);
if (is_enabled(f_ag_center) && !is_enabled(f_ag_rotate))
calc_fit_forces_impl<true, false>(forces_on_fitting_group, accessor);
calc_fit_forces_impl<true, false, main_force_accessor_T, fitting_force_accessor_T>(accessor_main, accessor_fitting);
if (!is_enabled(f_ag_center) && is_enabled(f_ag_rotate))
calc_fit_forces_impl<false, true>(forces_on_fitting_group, accessor);
calc_fit_forces_impl<false, true, main_force_accessor_T, fitting_force_accessor_T>(accessor_main, accessor_fitting);
if (!is_enabled(f_ag_center) && !is_enabled(f_ag_rotate))
calc_fit_forces_impl<false, false>(forces_on_fitting_group, accessor);

if (cvm::debug())
cvm::log("Done calculating fit forces.\n");
calc_fit_forces_impl<false, false, main_force_accessor_T, fitting_force_accessor_T>(accessor_main, accessor_fitting);
}


Expand Down Expand Up @@ -1500,12 +1494,6 @@ m_has_fitting_force(m_ag->is_enabled(f_ag_center) || m_ag->is_enabled(f_ag_rotat
std::fill(m_ag->group_forces.begin(),
m_ag->group_forces.end(), 0);
}
if (m_ag->fitting_group_forces.size() != m_group_for_fit->size()) {
m_ag->fitting_group_forces.assign(m_group_for_fit->size(), 0);
} else {
std::fill(m_ag->fitting_group_forces.begin(),
m_ag->fitting_group_forces.end(), 0);
}
}
}

Expand Down Expand Up @@ -1542,15 +1530,16 @@ void cvm::atom_group::group_force_object::apply_force_with_fitting_group() {
// group, but checking this flag can mimic results that the users expect (if
// "enableFitGradients no" then there is no force on the fitting group).
if (m_ag->is_enabled(f_ag_fit_gradients)) {
m_ag->calc_fit_forces(m_ag->group_forces, m_ag->fitting_group_forces);
auto accessor_main = [this](size_t i){return m_ag->group_forces[i];};
auto accessor_fitting = [this](size_t j, const cvm::rvector& fitting_force){
(*(m_group_for_fit))[j].apply_force(fitting_force);
};
if (cvm::debug()) {
cvm::log("Applying force on the fitting group of main group" + m_ag->name + ":\n");
}
for (size_t ia = 0; ia < m_group_for_fit->size(); ia++) {
(*(m_group_for_fit))[ia].apply_force(m_ag->fitting_group_forces[ia]);
if (cvm::debug()) {
cvm::log(cvm::to_str(m_ag->fitting_group_forces[ia]));
}
m_ag->calc_fit_forces(accessor_main, accessor_fitting);
if (cvm::debug()) {
cvm::log("Done applying force on the fitting group of main group" + m_ag->name + ":\n");
}
}
}
Expand Down
44 changes: 33 additions & 11 deletions src/colvaratoms.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ class colvarmodule::atom_group
/// \brief Index in the colvarproxy arrays (if the group is scalable)
int index;

/// \brief The forces acting on the group atoms (stored mainly used for calculating the fitting group forces)
std::vector<cvm::rvector> group_forces;
std::vector<cvm::rvector> fitting_group_forces;

public:

Expand Down Expand Up @@ -514,25 +514,47 @@ class colvarmodule::atom_group
/// \brief Calculate the derivatives of the fitting transformation
void calc_fit_gradients();

void calc_fit_forces(
const std::vector<cvm::rvector>& forces_on_main_group,
std::vector<cvm::rvector>& forces_on_fitting_group) const;

/*! @brief Actual implementation of `calc_fit_gradients`. The template is
/*! @brief Actual implementation of `calc_fit_gradients` and
* `calc_fit_forces`. The template is
* used to avoid branching inside the loops in case that the CPU
* branch prediction is broken (or further migration to GPU code).
* @tparam B_ag_center Centered the reference to origin? This should follow
* the value of `is_enabled(f_ag_center)`.
* @tparam B_ag_rotate Calculate the optimal rotation? This should follow
* the value of `is_enabled(f_ag_rotate)`.
* @tparam main_force_accessor_T The type of accessor of the main
* group forces or gradients.
* @tparam fitting_force_accessor_T The type of accessor of the fitting group
* forces or gradients.
* @param accessor_main The accessor of the main group forces or gradients.
* accessor_main(i) should return the i-th force or gradient of the
* main group.
* @param accessor_fitting The accessor of the fitting group forces or gradients.
* accessor_fitting(j, v) should store/apply the j-th atom gradient or
* force in the fitting group.
*/
// template <bool B_ag_center, bool B_ag_rotate> void calc_fit_gradients_impl();

template <bool B_ag_center, bool B_ag_rotate,
typename main_force_accessor_T>
typename main_force_accessor_T, typename fitting_force_accessor_T>
void calc_fit_forces_impl(
std::vector<cvm::rvector>& forces_on_fitting_group,
main_force_accessor_T accessor) const;
main_force_accessor_T accessor_main,
fitting_force_accessor_T accessor_fitting) const;

/*! @brief Calculate or apply the fitting group forces from the main group forces.
* @tparam main_force_accessor_T The type of accessor of the main
* group forces or gradients.
* @tparam fitting_force_accessor_T The type of accessor of the fitting group
* forces or gradients.
* @param accessor_main The accessor of the main group forces or gradients.
* accessor_main(i) should return the i-th force or gradient of the
* main group.
* @param accessor_fitting The accessor of the fitting group forces or gradients.
* accessor_fitting(j, v) should store/apply the j-th atom gradient or
* force in the fitting group.
*/
template <typename main_force_accessor_T, typename fitting_force_accessor_T>
void calc_fit_forces(
main_force_accessor_T accessor_main,
fitting_force_accessor_T accessor_fitting) const;

/// \brief Derivatives of the fitting transformation
std::vector<cvm::atom_pos> fit_gradients;
Expand Down

0 comments on commit 57fa6a6

Please sign in to comment.