From 2c8350ef7738998e5155e965eef48a5a4744f53c Mon Sep 17 00:00:00 2001 From: HanatoK Date: Tue, 17 Sep 2024 22:00:47 -0500 Subject: [PATCH] refactor: eliminate the buffer of fitting group forces --- src/colvaratoms.cpp | 93 ++++++++++++++++++++------------------------- src/colvaratoms.h | 44 +++++++++++++++------ 2 files changed, 74 insertions(+), 63 deletions(-) diff --git a/src/colvaratoms.cpp b/src/colvaratoms.cpp index e4a507ef5..0b939cf34 100644 --- a/src/colvaratoms.cpp +++ b/src/colvaratoms.cpp @@ -921,15 +921,6 @@ int cvm::atom_group::parse_fitting_options(std::string const &group_conf) void cvm::atom_group::do_feature_side_effects(int id) { - // If enabled features are changed upstream, the features below should be refreshed - switch (id) { - case f_ag_fit_gradients: - if (is_enabled(f_ag_center) || is_enabled(f_ag_rotate)) { - atom_group *group_for_fit = fitting_group ? fitting_group : this; - group_for_fit->fit_gradients.assign(group_for_fit->size(), cvm::atom_pos(0.0, 0.0, 0.0)); - } - break; - } } @@ -1218,15 +1209,22 @@ void cvm::atom_group::calc_fit_gradients() cvm::log("Calculating fit gradients.\n"); cvm::atom_group *group_for_fit = fitting_group ? fitting_group : this; - auto accessor = [this](size_t i){return atoms[i].grad;}; + if (group_for_fit->fit_gradients.size() != group_for_fit->size()) { + group_for_fit->fit_gradients.assign(group_for_fit->size(), 0); + } else { + std::fill(group_for_fit->fit_gradients.begin(), + group_for_fit->fit_gradients.end(), 0); + } + auto accessor_main = [this](size_t i){return atoms[i].grad;}; + auto accessor_fitting = [&group_for_fit](size_t j, const cvm::rvector& grad){group_for_fit->fit_gradients[j] += grad;}; if (is_enabled(f_ag_center) && is_enabled(f_ag_rotate)) - calc_fit_forces_impl(group_for_fit->fit_gradients, accessor); + calc_fit_forces_impl(accessor_main, accessor_fitting); if (is_enabled(f_ag_center) && !is_enabled(f_ag_rotate)) - calc_fit_forces_impl(group_for_fit->fit_gradients, accessor); + calc_fit_forces_impl(accessor_main, accessor_fitting); if (!is_enabled(f_ag_center) && is_enabled(f_ag_rotate)) - calc_fit_forces_impl(group_for_fit->fit_gradients, accessor); + calc_fit_forces_impl(accessor_main, accessor_fitting); if (!is_enabled(f_ag_center) && !is_enabled(f_ag_rotate)) - calc_fit_forces_impl(group_for_fit->fit_gradients, accessor); + calc_fit_forces_impl(accessor_main, accessor_fitting); if (cvm::debug()) cvm::log("Done calculating fit gradients.\n"); @@ -1234,14 +1232,11 @@ void cvm::atom_group::calc_fit_gradients() template + typename main_force_accessor_T, typename fitting_force_accessor_T> void cvm::atom_group::calc_fit_forces_impl( - std::vector& forces_on_fitting_group, - main_force_accessor_T accessor) const { + main_force_accessor_T accessor_main, + fitting_force_accessor_T accessor_fitting) const { const cvm::atom_group *group_for_fit = fitting_group ? fitting_group : this; - if (forces_on_fitting_group.size() != group_for_fit->size()) { - forces_on_fitting_group.assign(group_for_fit->size(), 0); - } // the center of geometry contribution to the gradients cvm::rvector atom_grad; // the rotation matrix contribution to the gradients @@ -1253,7 +1248,7 @@ void cvm::atom_group::calc_fit_forces_impl( for (size_t i = 0; i < size(); i++) { cvm::atom_pos pos_orig; if (B_ag_center) { - atom_grad += accessor(i); + atom_grad += accessor_main(i); if (B_ag_rotate) pos_orig = rot_inv * (atoms[i].pos - ref_pos_cog); } else { if (B_ag_rotate) pos_orig = atoms[i].pos; @@ -1261,7 +1256,7 @@ void cvm::atom_group::calc_fit_forces_impl( if (B_ag_rotate) { // calculate \partial(R(q) \vec{x}_i)/\partial q) \cdot \partial\xi/\partial\vec{x}_i cvm::quaternion const dxdq = - rot.q.position_derivative_inner(pos_orig, accessor(i)); + rot.q.position_derivative_inner(pos_orig, accessor_main(i)); sum_dxdq[0] += dxdq[0]; sum_dxdq[1] += dxdq[1]; sum_dxdq[2] += dxdq[2]; @@ -1275,38 +1270,37 @@ void cvm::atom_group::calc_fit_forces_impl( // loop 2: iterate over the fitting group if (B_ag_rotate) rot_deriv->prepare_derivative(rotation_derivative_dldq::use_dq); for (size_t j = 0; j < group_for_fit->size(); j++) { + cvm::rvector fitting_force_grad{0, 0, 0}; if (B_ag_center) { - forces_on_fitting_group[j] = atom_grad; + fitting_force_grad += atom_grad; } if (B_ag_rotate) { rot_deriv->calc_derivative_wrt_group1(j, nullptr, &dq0_1); // multiply by {\partial q}/\partial\vec{x}_j and add it to the fit gradients - forces_on_fitting_group[j] += sum_dxdq[0] * dq0_1[0] + - sum_dxdq[1] * dq0_1[1] + - sum_dxdq[2] * dq0_1[2] + - sum_dxdq[3] * dq0_1[3]; + fitting_force_grad += sum_dxdq[0] * dq0_1[0] + + sum_dxdq[1] * dq0_1[1] + + sum_dxdq[2] * dq0_1[2] + + sum_dxdq[3] * dq0_1[3]; + } + if (cvm::debug()) { + cvm::log(cvm::to_str(fitting_force_grad)); } + accessor_fitting(j, fitting_force_grad); } } - +template void cvm::atom_group::calc_fit_forces( - const std::vector& forces_on_main_group, - std::vector& forces_on_fitting_group) const { - if (cvm::debug()) - cvm::log("Calculating fit forces.\n"); - auto accessor = [&forces_on_main_group](size_t i){return forces_on_main_group[i];}; + main_force_accessor_T accessor_main, + fitting_force_accessor_T accessor_fitting) const { if (is_enabled(f_ag_center) && is_enabled(f_ag_rotate)) - calc_fit_forces_impl(forces_on_fitting_group, accessor); + calc_fit_forces_impl(accessor_main, accessor_fitting); if (is_enabled(f_ag_center) && !is_enabled(f_ag_rotate)) - calc_fit_forces_impl(forces_on_fitting_group, accessor); + calc_fit_forces_impl(accessor_main, accessor_fitting); if (!is_enabled(f_ag_center) && is_enabled(f_ag_rotate)) - calc_fit_forces_impl(forces_on_fitting_group, accessor); + calc_fit_forces_impl(accessor_main, accessor_fitting); if (!is_enabled(f_ag_center) && !is_enabled(f_ag_rotate)) - calc_fit_forces_impl(forces_on_fitting_group, accessor); - - if (cvm::debug()) - cvm::log("Done calculating fit forces.\n"); + calc_fit_forces_impl(accessor_main, accessor_fitting); } @@ -1500,12 +1494,6 @@ m_has_fitting_force(m_ag->is_enabled(f_ag_center) || m_ag->is_enabled(f_ag_rotat std::fill(m_ag->group_forces.begin(), m_ag->group_forces.end(), 0); } - if (m_ag->fitting_group_forces.size() != m_group_for_fit->size()) { - m_ag->fitting_group_forces.assign(m_group_for_fit->size(), 0); - } else { - std::fill(m_ag->fitting_group_forces.begin(), - m_ag->fitting_group_forces.end(), 0); - } } } @@ -1542,15 +1530,16 @@ void cvm::atom_group::group_force_object::apply_force_with_fitting_group() { // group, but checking this flag can mimic results that the users expect (if // "enableFitGradients no" then there is no force on the fitting group). if (m_ag->is_enabled(f_ag_fit_gradients)) { - m_ag->calc_fit_forces(m_ag->group_forces, m_ag->fitting_group_forces); + auto accessor_main = [this](size_t i){return m_ag->group_forces[i];}; + auto accessor_fitting = [this](size_t j, const cvm::rvector& fitting_force){ + (*(m_group_for_fit))[j].apply_force(fitting_force); + }; if (cvm::debug()) { cvm::log("Applying force on the fitting group of main group" + m_ag->name + ":\n"); } - for (size_t ia = 0; ia < m_group_for_fit->size(); ia++) { - (*(m_group_for_fit))[ia].apply_force(m_ag->fitting_group_forces[ia]); - if (cvm::debug()) { - cvm::log(cvm::to_str(m_ag->fitting_group_forces[ia])); - } + m_ag->calc_fit_forces(accessor_main, accessor_fitting); + if (cvm::debug()) { + cvm::log("Done applying force on the fitting group of main group" + m_ag->name + ":\n"); } } } diff --git a/src/colvaratoms.h b/src/colvaratoms.h index 5c96b307e..bb1c10ed5 100644 --- a/src/colvaratoms.h +++ b/src/colvaratoms.h @@ -257,8 +257,8 @@ class colvarmodule::atom_group /// \brief Index in the colvarproxy arrays (if the group is scalable) int index; + /// \brief The forces acting on the group atoms (stored mainly used for calculating the fitting group forces) std::vector group_forces; - std::vector fitting_group_forces; public: @@ -514,25 +514,47 @@ class colvarmodule::atom_group /// \brief Calculate the derivatives of the fitting transformation void calc_fit_gradients(); - void calc_fit_forces( - const std::vector& forces_on_main_group, - std::vector& forces_on_fitting_group) const; - -/*! @brief Actual implementation of `calc_fit_gradients`. The template is +/*! @brief Actual implementation of `calc_fit_gradients` and + * `calc_fit_forces`. The template is * used to avoid branching inside the loops in case that the CPU * branch prediction is broken (or further migration to GPU code). * @tparam B_ag_center Centered the reference to origin? This should follow * the value of `is_enabled(f_ag_center)`. * @tparam B_ag_rotate Calculate the optimal rotation? This should follow * the value of `is_enabled(f_ag_rotate)`. + * @tparam main_force_accessor_T The type of accessor of the main + * group forces or gradients. + * @tparam fitting_force_accessor_T The type of accessor of the fitting group + * forces or gradients. + * @param accessor_main The accessor of the main group forces or gradients. + * accessor_main(i) should return the i-th force or gradient of the + * main group. + * @param accessor_fitting The accessor of the fitting group forces or gradients. + * accessor_fitting(j, v) should store/apply the j-th atom gradient or + * force in the fitting group. */ - // template void calc_fit_gradients_impl(); - template + typename main_force_accessor_T, typename fitting_force_accessor_T> void calc_fit_forces_impl( - std::vector& forces_on_fitting_group, - main_force_accessor_T accessor) const; + main_force_accessor_T accessor_main, + fitting_force_accessor_T accessor_fitting) const; + +/*! @brief Calculate or apply the fitting group forces from the main group forces. + * @tparam main_force_accessor_T The type of accessor of the main + * group forces or gradients. + * @tparam fitting_force_accessor_T The type of accessor of the fitting group + * forces or gradients. + * @param accessor_main The accessor of the main group forces or gradients. + * accessor_main(i) should return the i-th force or gradient of the + * main group. + * @param accessor_fitting The accessor of the fitting group forces or gradients. + * accessor_fitting(j, v) should store/apply the j-th atom gradient or + * force in the fitting group. + */ + template + void calc_fit_forces( + main_force_accessor_T accessor_main, + fitting_force_accessor_T accessor_fitting) const; /// \brief Derivatives of the fitting transformation std::vector fit_gradients;