Skip to content

Commit

Permalink
EAMxx: add column reduction utility to fields
Browse files Browse the repository at this point in the history
  • Loading branch information
mahf708 committed Nov 24, 2024
1 parent 49fdbe3 commit 1373e9c
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 0 deletions.
12 changes: 12 additions & 0 deletions components/eamxx/src/share/field/field_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define SCREAM_FIELD_UTILS_HPP

#include "share/field/field_utils_impl.hpp"
#include "share/field/field_utils_impl_colred.hpp"

namespace scream {

Expand Down Expand Up @@ -111,6 +112,17 @@ void perturb (const Field& f,
impl::perturb<ST>(f, engine, pdf, base_seed, level_mask, dof_gids);
}

template <typename ST>
Field column_reduction(const Field &f1, const Field &f2,
const ekat::Comm *comm = nullptr) {
EKAT_REQUIRE_MSG(f1.is_allocated() && f2.is_allocated(),
"Error! Input fields must be allocated.");
EKAT_REQUIRE_MSG(f1.data_type() == f2.data_type(),
"Error! Input fields must have matching data types.");

return impl::column_reduction<ST>(f1, f2, comm);
}

template<typename ST>
ST frobenius_norm(const Field& f, const ekat::Comm* comm = nullptr)
{
Expand Down
123 changes: 123 additions & 0 deletions components/eamxx/src/share/field/field_utils_impl_colred.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#ifndef SCREAM_FIELD_UTILS_IMPL_COLRED_HPP
#define SCREAM_FIELD_UTILS_IMPL_COLRED_HPP

#include "ekat/kokkos/ekat_kokkos_utils.hpp"
#include "ekat/mpi/ekat_comm.hpp"
#include "share/field/field.hpp"

namespace scream {
namespace impl {

// Utility to compute the reduction of a field along its column dimension.
// This is equivalent to einsum('i,i...k->...k', f1, f2); i is the column.
// The layouts are such that:
// - The first dimension is for the columns (COL)
// - There can be only up to 3 dimensions

template <typename ST>
Field column_reduction(const Field &f1, const Field &f2, const ekat::Comm *co) {
using KT = ekat::KokkosTypes<DefaultDevice>;
using RangePolicy = Kokkos::RangePolicy<Field::device_t::execution_space>;
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>;
using TeamMember = typename TeamPolicy::member_type;
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;
using namespace ShortFieldTagsNames;

const auto &l1 = f1.get_header().get_identifier().get_layout();

EKAT_REQUIRE_MSG(l1.rank() == 1,
"Error! First field f1 must be rank-1.\n"
"The input has rank "
<< l1.rank() << ".\n");
EKAT_REQUIRE_MSG(l1.tags() == std::vector<FieldTag>({COL}),
"Error! First field f1 must have a column dimension.\n"
"The input f1 layout is "
<< l1.tags() << ".\n");

const auto &n2 = f2.get_header().get_identifier().name();
const auto &l2 = f2.get_header().get_identifier().get_layout();
const auto &u2 = f2.get_header().get_identifier().get_units();
const auto &g2 = f2.get_header().get_identifier().get_grid_name();

EKAT_REQUIRE_MSG(l2.rank() <= 3,
"Error! Second field f2 must be at most rank-3.\n"
"The input f2 rank is "
<< l2.rank() << ".\n");
EKAT_REQUIRE_MSG(l2.tags()[0] == COL,
"Error! Second field f2 must have a column dimension.\n"
"The input f2 layout is "
<< l2.tags() << ".\n");
EKAT_REQUIRE_MSG(
l1.dim(0) == l2.dim(0),
"Error! The two input fields must have the same dimension along "
"which we are taking the reducing the field.\n"
"The first field f1 has dimension "
<< l1.dim(0)
<< " while "
"the second field f2 has dimension "
<< l2.dim(0) << ".\n");

auto v1 = f1.get_view<const ST *>();

FieldIdentifier fo_id(n2 + "_colred", l2.clone().strip_dim(0), u2, g2);
Field fo(fo_id);
fo.allocate_view();
fo.deep_copy(0);

const int d0 = l2.dim(0);

switch(l2.rank()) {
case 1: {
auto v2 = f2.get_view<const ST *>();
auto vo = fo.get_view<ST>();
Kokkos::parallel_reduce(
fo.name(), RangePolicy(0, d0),
KOKKOS_LAMBDA(const int i, ST &ls) { ls += v1(i) * v2(i); }, vo);
} break;
case 2: {
auto v2 = f2.get_view<const ST **>();
auto vo = fo.get_view<ST *>();
const int d1 = l2.dim(1);
auto p = ESU::get_default_team_policy(d1, d0);
Kokkos::parallel_for(
fo.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
const int j = tm.league_rank();
Kokkos::parallel_reduce(
Kokkos::TeamVectorRange(tm, d0),
[&](int i, ST &ac) { ac += v1(i) * v2(i, j); }, vo(j));
});
} break;
case 3: {
auto v2 = f2.get_view<const ST ***>();
auto vo = fo.get_view<ST **>();
const int d1 = l2.dim(1);
const int d2 = l2.dim(2);
auto p = ESU::get_default_team_policy(d1 * d2, d0);
Kokkos::parallel_for(
fo.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) {
const int idx = tm.league_rank();
const int j = idx / d2;
const int k = idx % d2;
Kokkos::parallel_reduce(
Kokkos::TeamVectorRange(tm, d0),
[&](int i, ST &ac) { ac += v1(i) * v2(i, j, k); }, vo(j, k));
});
} break;
default:
EKAT_ERROR_MSG("Error! Unsupported field rank.\n");
}

if(co) {
Kokkos::fence();
fo.sync_to_host();
co->all_reduce(fo.template get_internal_view_data<ST, Host>(),
l2.size() / l2.dim(0), MPI_SUM);
fo.sync_to_dev();
}
return fo;
}

} // namespace impl
} // namespace scream

#endif // SCREAM_FIELD_UTILS_IMPL_COLRED_HPP
88 changes: 88 additions & 0 deletions components/eamxx/src/share/tests/field_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,94 @@ TEST_CASE("utils") {
REQUIRE(field_sum<Real>(f1,&comm)==gsum);
}

SECTION("column_reduction") {
using RPDF = std::uniform_real_distribution<Real>;
auto engine = setup_random_test();
RPDF pdf(0, 1);

int dim0 = 3;
int dim1 = 9;
int dim2 = 2;
FieldIdentifier f00("f", {{COL}, {dim0}}, m / s, "g");
Field field00(f00);
field00.allocate_view();
field00.sync_to_host();
auto v00 = field00.get_strided_view<Real *, Host>();
for(int i = 0; i < dim0; ++i) {
v00(i) = (i + 1) / sp(6);
}
field00.sync_to_dev();

FieldIdentifier f10("f", {{COL, CMP}, {dim0, dim1}}, m / s, "g");
FieldIdentifier f11("f", {{COL, LEV}, {dim0, dim2}}, m / s, "g");
FieldIdentifier f20("f", {{COL, CMP, LEV}, {dim0, dim1, dim2}}, m / s, "g");

Field field10(f10);
Field field11(f11);
Field field20(f20);
field10.allocate_view();
field11.allocate_view();
field20.allocate_view();

randomize(field10, engine, pdf);
randomize(field11, engine, pdf);
randomize(field20, engine, pdf);

FieldIdentifier F_x("fx", {{COL}, {dim1}}, m/s, "g");
FieldIdentifier F_y("fy", {{LEV}, {dim2}}, m/s, "g");

Field field_x(F_x);
Field field_y(F_y);

REQUIRE_THROWS(column_reduction<Real>(field00, field_x)); // x not allocated

field_x.allocate_view();
field_y.allocate_view();

REQUIRE_THROWS(column_reduction<Real>(field_x, field_y)); // unmatching layout
REQUIRE_THROWS(column_reduction<Real>(field11, field11)); // wrong f1 layout

Field result;

result = column_reduction<Real>(field00, field00);
result.sync_to_host();
auto v = result.get_view<Real, Host>();
REQUIRE(v() == (1 / sp(36) + 4 / sp(36) + 9 / sp(36)));

result = column_reduction<Real>(field00, field10);
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
std::vector<FieldTag>({CMP}));
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim1);

result = column_reduction<Real>(field00, field11);
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
std::vector<FieldTag>({LEV}));
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim2);

result = column_reduction<Real>(field00, field20);
REQUIRE(result.get_header().get_identifier().get_layout().tags() ==
std::vector<FieldTag>({CMP, LEV}));
REQUIRE(result.get_header().get_identifier().get_layout().dim(0) == dim1);
REQUIRE(result.get_header().get_identifier().get_layout().dim(1) == dim2);

field20.sync_to_host();
auto manual_result = result.clone();
manual_result.deep_copy(0);
manual_result.sync_to_host();
auto v2 = field20.get_strided_view<Real ***, Host>();
auto mr = manual_result.get_strided_view<Real **, Host>();
for(int i = 0; i < dim0; ++i) {
for(int j = 0; j < dim1; ++j) {
for(int k = 0; k < dim2; ++k) {
mr(j, k) += v00(i) * v2(i, j, k);
}
}
}
field20.sync_to_dev();
manual_result.sync_to_dev();
REQUIRE(views_are_equal(result, manual_result));
}

SECTION ("frobenius") {

auto v1 = f1.get_strided_view<Real**>();
Expand Down

0 comments on commit 1373e9c

Please sign in to comment.