From 1373e9c45b6025492797de1e86612e3e470d9025 Mon Sep 17 00:00:00 2001
From: mahf708
Date: Sat, 23 Nov 2024 16:12:53 -0800
Subject: [PATCH] EAMxx: add column reduction utility to fields

---
 .../eamxx/src/share/field/field_utils.hpp     |  16 ++
 .../share/field/field_utils_impl_colred.hpp   | 148 ++++++++++++++++++
 .../eamxx/src/share/tests/field_utils.cpp     |  90 +++++++++++
 3 files changed, 254 insertions(+)
 create mode 100644 components/eamxx/src/share/field/field_utils_impl_colred.hpp

diff --git a/components/eamxx/src/share/field/field_utils.hpp b/components/eamxx/src/share/field/field_utils.hpp
index 8f977a7caa1b..b75effa93598 100644
--- a/components/eamxx/src/share/field/field_utils.hpp
+++ b/components/eamxx/src/share/field/field_utils.hpp
@@ -2,6 +2,7 @@
 #define SCREAM_FIELD_UTILS_HPP
 
 #include "share/field/field_utils_impl.hpp"
+#include "share/field/field_utils_impl_colred.hpp"
 
 namespace scream {
 
@@ -111,6 +112,21 @@ void perturb (const Field& f,
   impl::perturb(f, engine, pdf, base_seed, level_mask, dof_gids);
 }
 
+// Contract the leading (column) dimension of f2 against the rank-1 field f1.
+// Both fields must be allocated and hold the same data type; the layout
+// checks are done by impl::column_reduction. If comm is non-null, the
+// per-rank result is also summed across that communicator.
+template <typename ST>
+Field column_reduction(const Field &f1, const Field &f2,
+                       const ekat::Comm *comm = nullptr) {
+  EKAT_REQUIRE_MSG(f1.is_allocated() && f2.is_allocated(),
+                   "Error! Input fields must be allocated.");
+  EKAT_REQUIRE_MSG(f1.data_type() == f2.data_type(),
+                   "Error! Input fields must have matching data types.");
+
+  return impl::column_reduction<ST>(f1, f2, comm);
+}
+
 template<typename ST>
 ST frobenius_norm(const Field& f, const ekat::Comm* comm = nullptr)
 {
diff --git a/components/eamxx/src/share/field/field_utils_impl_colred.hpp b/components/eamxx/src/share/field/field_utils_impl_colred.hpp
new file mode 100644
index 000000000000..5a18b42b7a2f
--- /dev/null
+++ b/components/eamxx/src/share/field/field_utils_impl_colred.hpp
@@ -0,0 +1,148 @@
+#ifndef SCREAM_FIELD_UTILS_IMPL_COLRED_HPP
+#define SCREAM_FIELD_UTILS_IMPL_COLRED_HPP
+
+#include "ekat/kokkos/ekat_kokkos_utils.hpp"
+#include "ekat/mpi/ekat_comm.hpp"
+#include "share/field/field.hpp"
+
+namespace scream {
+namespace impl {
+
+// Utility to compute the reduction of a field along its column dimension.
+// This is equivalent to einsum('i,i...k->...k', f1, f2); i is the column.
+// The layouts are such that:
+// - The first dimension is for the columns (COL)
+// - There can be only up to 3 dimensions
+//
+// f1 must be a rank-1 (COL) field; f2 must have rank 1 to 3 with COL first
+// and the same column extent as f1. The returned field is named
+// "<f2 name>_colred" and uses f2's layout with the leading dimension
+// stripped, plus f2's units and grid name. If co is non-null, the local
+// result is summed over that communicator before returning.
+template <typename ST>
+Field column_reduction(const Field &f1, const Field &f2, const ekat::Comm *co) {
+  // NOTE(review): the template argument lists in this function were not
+  // visible in the reviewed text and are restored from usage; confirm them.
+  using KT          = ekat::KokkosTypes<DefaultDevice>;
+  using RangePolicy = Kokkos::RangePolicy<Field::device_t::execution_space>;
+  using TeamPolicy  = Kokkos::TeamPolicy<Field::device_t::execution_space>;
+  using TeamMember  = typename TeamPolicy::member_type;
+  using ESU         = ekat::ExeSpaceUtils<typename KT::ExeSpace>;
+  using namespace ShortFieldTagsNames;
+
+  const auto &l1 = f1.get_header().get_identifier().get_layout();
+
+  EKAT_REQUIRE_MSG(l1.rank() == 1,
+                   "Error! First field f1 must be rank-1.\n"
+                   "The input has rank "
+                       << l1.rank() << ".\n");
+  EKAT_REQUIRE_MSG(l1.tags() == std::vector<FieldTag>({COL}),
+                   "Error! First field f1 must have a column dimension.\n"
+                   "The input f1 layout is "
+                       << l1.tags() << ".\n");
+
+  const auto &n2 = f2.get_header().get_identifier().name();
+  const auto &l2 = f2.get_header().get_identifier().get_layout();
+  const auto &u2 = f2.get_header().get_identifier().get_units();
+  const auto &g2 = f2.get_header().get_identifier().get_grid_name();
+
+  // Reject rank-0 input first: the COL check below indexes tags()[0].
+  EKAT_REQUIRE_MSG(l2.rank() >= 1,
+                   "Error! Second field f2 must be at least rank-1.\n"
+                   "The input f2 rank is "
+                       << l2.rank() << ".\n");
+  EKAT_REQUIRE_MSG(l2.rank() <= 3,
+                   "Error! Second field f2 must be at most rank-3.\n"
+                   "The input f2 rank is "
+                       << l2.rank() << ".\n");
+  EKAT_REQUIRE_MSG(l2.tags()[0] == COL,
+                   "Error! Second field f2 must have a column dimension.\n"
+                   "The input f2 layout is "
+                       << l2.tags() << ".\n");
+  EKAT_REQUIRE_MSG(
+      l1.dim(0) == l2.dim(0),
+      "Error! The two input fields must have the same dimension along "
+      "which we are reducing the field.\n"
+      "The first field f1 has dimension "
+          << l1.dim(0)
+          << " while "
+             "the second field f2 has dimension "
+          << l2.dim(0) << ".\n");
+
+  auto v1 = f1.get_view<const ST *>();
+
+  // Output field: f2's identifier with the column dimension removed.
+  FieldIdentifier fo_id(n2 + "_colred", l2.clone().strip_dim(0), u2, g2);
+  Field fo(fo_id);
+  fo.allocate_view();
+  fo.deep_copy(0);
+
+  const int d0 = l2.dim(0);
+
+  switch(l2.rank()) {
+    case 1: {
+      // Scalar result: a single reduction over all columns.
+      auto v2 = f2.get_view<const ST *>();
+      auto vo = fo.get_view<ST>();
+      Kokkos::parallel_reduce(
+          fo.name(), RangePolicy(0, d0),
+          KOKKOS_LAMBDA(const int i, ST &ls) { ls += v1(i) * v2(i); }, vo);
+    } break;
+    case 2: {
+      // Each league rank handles one output entry j and reduces over columns.
+      auto v2      = f2.get_view<const ST **>();
+      auto vo      = fo.get_view<ST *>();
+      const int d1 = l2.dim(1);
+      auto p       = ESU::get_default_team_policy(d1, d0);
+      Kokkos::parallel_for(
+          fo.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
+            const int j = tm.league_rank();
+            Kokkos::parallel_reduce(
+                Kokkos::TeamVectorRange(tm, d0),
+                [&](int i, ST &ac) { ac += v1(i) * v2(i, j); }, vo(j));
+          });
+    } break;
+    case 3: {
+      // Each league rank handles one output entry (j, k), unpacked from idx.
+      auto v2      = f2.get_view<const ST ***>();
+      auto vo      = fo.get_view<ST **>();
+      const int d1 = l2.dim(1);
+      const int d2 = l2.dim(2);
+      auto p       = ESU::get_default_team_policy(d1 * d2, d0);
+      Kokkos::parallel_for(
+          fo.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
+            const int idx = tm.league_rank();
+            const int j   = idx / d2;
+            const int k   = idx % d2;
+            Kokkos::parallel_reduce(
+                Kokkos::TeamVectorRange(tm, d0),
+                [&](int i, ST &ac) { ac += v1(i) * v2(i, j, k); }, vo(j, k));
+          });
+    } break;
+    default:
+      EKAT_ERROR_MSG("Error! Unsupported field rank.\n");
+  }
+
+  if(co) {
+    // Number of entries in the reduced field: the product of f2's trailing
+    // extents. Computed directly (rather than size() / dim(0)) so that a
+    // rank with zero local columns does not divide by zero.
+    int n_out = 1;
+    for(int d = 1; d < l2.rank(); ++d) {
+      n_out *= l2.dim(d);
+    }
+
+    // Sum the per-rank results on the host, then push back to the device.
+    Kokkos::fence();
+    fo.sync_to_host();
+    co->all_reduce(fo.template get_internal_view_data<ST, Host>(), n_out,
+                   MPI_SUM);
+    fo.sync_to_dev();
+  }
+  return fo;
+}
+
+} // namespace impl
+} // namespace scream
+
+#endif // SCREAM_FIELD_UTILS_IMPL_COLRED_HPP
diff --git a/components/eamxx/src/share/tests/field_utils.cpp b/components/eamxx/src/share/tests/field_utils.cpp
index f444ec75d52c..052eaa07826a 100644
--- a/components/eamxx/src/share/tests/field_utils.cpp
+++ b/components/eamxx/src/share/tests/field_utils.cpp
@@ -126,6 +126,96 @@ TEST_CASE("utils") {
     REQUIRE(field_sum<Real>(f1,&comm)==gsum);
   }
 
+  SECTION("column_reduction") {
+    // Covers argument validation, the layout of each reduced field, and a
+    // rank-3 reduction compared against a serial host-side reference.
+    using RPDF  = std::uniform_real_distribution<Real>;
+    auto engine = setup_random_test();
+    RPDF pdf(0, 1);
+
+    int dim0 = 3;
+    int dim1 = 9;
+    int dim2 = 2;
+    FieldIdentifier f00("f", {{COL}, {dim0}}, m / s, "g");
+    Field field00(f00);
+    field00.allocate_view();
+    field00.sync_to_host();
+    auto v00 = field00.get_strided_view<Real *, Host>();
+    for(int i = 0; i < dim0; ++i) {
+      v00(i) = (i + 1) / sp(6);
+    }
+    field00.sync_to_dev();
+
+    FieldIdentifier f10("f", {{COL, CMP}, {dim0, dim1}}, m / s, "g");
+    FieldIdentifier f11("f", {{COL, LEV}, {dim0, dim2}}, m / s, "g");
+    FieldIdentifier f20("f", {{COL, CMP, LEV}, {dim0, dim1, dim2}}, m / s, "g");
+
+    Field field10(f10);
+    Field field11(f11);
+    Field field20(f20);
+    field10.allocate_view();
+    field11.allocate_view();
+    field20.allocate_view();
+
+    randomize(field10, engine, pdf);
+    randomize(field11, engine, pdf);
+    randomize(field20, engine, pdf);
+
+    FieldIdentifier F_x("fx", {{COL}, {dim1}}, m/s, "g");
+    FieldIdentifier F_y("fy", {{LEV}, {dim2}}, m/s, "g");
+
+    Field field_x(F_x);
+    Field field_y(F_y);
+
+    REQUIRE_THROWS(column_reduction<Real>(field00, field_x)); // x not allocated
+
+    field_x.allocate_view();
+    field_y.allocate_view();
+
+    REQUIRE_THROWS(column_reduction<Real>(field_x, field_y)); // unmatching layout
+    REQUIRE_THROWS(column_reduction<Real>(field11, field11)); // wrong f1 layout
+
+    Field result;
+
+    result = column_reduction<Real>(field00, field00);
+    result.sync_to_host();
+    auto v = result.get_view<Real, Host>();
+    REQUIRE(v() == (1 / sp(36) + 4 / sp(36) + 9 / sp(36)));
+
+    result = column_reduction<Real>(field00, field10);
+    REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
+            std::vector<FieldTag>({CMP}));
+    REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim1);
+
+    result = column_reduction<Real>(field00, field11);
+    REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
+            std::vector<FieldTag>({LEV}));
+    REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim2);
+
+    result = column_reduction<Real>(field00, field20);
+    REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
+            std::vector<FieldTag>({CMP, LEV}));
+    REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim1);
+    REQUIRE(result.get_header().get_identifier().get_layout().dim(1) == dim2);
+
+    field20.sync_to_host();
+    auto manual_result = result.clone();
+    manual_result.deep_copy(0);
+    manual_result.sync_to_host();
+    auto v2 = field20.get_strided_view<Real ***, Host>();
+    auto mr = manual_result.get_strided_view<Real **, Host>();
+    for(int i = 0; i < dim0; ++i) {
+      for(int j = 0; j < dim1; ++j) {
+        for(int k = 0; k < dim2; ++k) {
+          mr(j, k) += v00(i) * v2(i, j, k);
+        }
+      }
+    }
+    field20.sync_to_dev();
+    manual_result.sync_to_dev();
+    REQUIRE(views_are_equal(result, manual_result));
+  }
+
   SECTION ("frobenius") {
 
     auto v1 = f1.get_strided_view<Real**>();