-
Notifications
You must be signed in to change notification settings - Fork 374
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
EAMxx: add column reduction utility to fields
- Loading branch information
Showing
3 changed files
with
199 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
components/eamxx/src/share/field/field_utils_impl_colred.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#ifndef SCREAM_FIELD_UTILS_IMPL_COLRED_HPP | ||
#define SCREAM_FIELD_UTILS_IMPL_COLRED_HPP | ||
|
||
#include "ekat/kokkos/ekat_kokkos_utils.hpp" | ||
#include "ekat/mpi/ekat_comm.hpp" | ||
#include "share/field/field.hpp" | ||
|
||
namespace scream { | ||
namespace impl { | ||
|
||
// utility to compute the column reduction of a field | ||
// this is equivalent to einsum('i...k,i->...k', f1, f2) | ||
// where we only support layouts such that: | ||
// - the first dimension is for the columns (col, i) | ||
// - the last dimension is for the levels (lev, k) | ||
// - at most, one dimension in between (cmp, ...) | ||
|
||
template <typename ST> | ||
Field column_reduction(const Field &f1, const Field &f2, const ekat::Comm *co) { | ||
using KT = ekat::KokkosTypes<DefaultDevice>; | ||
using RangePolicy = Kokkos::RangePolicy<Field::device_t::execution_space>; | ||
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>; | ||
using TeamMember = typename TeamPolicy::member_type; | ||
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>; | ||
using namespace ShortFieldTagsNames; | ||
|
||
const auto &l1 = f1.get_header().get_identifier().get_layout(); | ||
|
||
const auto &n2 = f2.get_header().get_identifier().name(); | ||
const auto &l2 = f2.get_header().get_identifier().get_layout(); | ||
const auto &u2 = f2.get_header().get_identifier().get_units(); | ||
const auto &g2 = f2.get_header().get_identifier().get_grid_name(); | ||
|
||
EKAT_REQUIRE_MSG(l1.rank() == 1, | ||
"Error! First field f1 must be rank-1.\n" | ||
"The input f1 rank is " | ||
<< l1.rank() << ", which is not accepted.\n"); | ||
EKAT_REQUIRE_MSG(l2.rank() <= 3, | ||
"Error! Third argument f2 must be at most rank-3.\n" | ||
"The input f2 rank is " | ||
<< l2.rank() << ", which is not accepted.\n"); | ||
EKAT_REQUIRE_MSG(l1.tags() == std::vector<FieldTag>({COL}), | ||
"Error! The first field f1 must have a column dimension.\n" | ||
"The input f1 layout is " | ||
<< l1.tags() << ", which is not accepted.\n"); | ||
EKAT_REQUIRE_MSG( | ||
l1.dim(0) == l2.dim(0), | ||
"Error! The two input fields must have the same dimension along " | ||
"which we are taking the dot product.\n" | ||
"The first field f1 has dimension " | ||
<< l1.dim(0) | ||
<< " while " | ||
"the second field f2 has dimension " | ||
<< l2.dim(0) << ".\n"); | ||
|
||
auto v1 = f1.get_view<const ST *>(); | ||
|
||
FieldIdentifier fo_id(n2, l2.clone().strip_dim(0), u2, g2); | ||
Field fo(fo_id); | ||
fo.allocate_view(); | ||
fo.deep_copy(0); | ||
|
||
const int d0 = l2.dim(0); | ||
|
||
switch(l2.rank()) { | ||
case 1: { | ||
auto v2 = f2.get_view<ST *>(); | ||
auto vo = fo.get_view<ST>(); | ||
Kokkos::parallel_reduce( | ||
fo.name(), Kokkos::RangePolicy<>(0, d0), | ||
KOKKOS_LAMBDA(const int i, Real &ls) { ls += v1(i) * v2(i); }, vo); | ||
} break; | ||
case 2: { | ||
auto v2 = f2.get_view<const ST **>(); | ||
auto vo = fo.get_view<ST *>(); | ||
const int d1 = l2.dim(1); | ||
auto p = ESU::get_default_team_policy(d1, d0); | ||
Kokkos::parallel_for( | ||
fo.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) { | ||
const int j = tm.league_rank(); | ||
Kokkos::parallel_reduce( | ||
Kokkos::TeamVectorRange(tm, d0), | ||
[&](int i, ST &ac) { ac += v1(i) * v2(i, j); }, vo(j)); | ||
}); | ||
} break; | ||
case 3: { | ||
auto v2 = f2.get_view<const ST ***>(); | ||
auto vo = fo.get_view<ST **>(); | ||
const int d1 = l2.dim(1); | ||
const int d2 = l2.dim(2); | ||
auto p = ESU::get_default_team_policy(d1 * d2, d0); | ||
Kokkos::parallel_for( | ||
fo.name(), p, KOKKOS_LAMBDA(const TeamMember &tm) { | ||
const int idx = tm.league_rank(); | ||
const int j = idx / d2; | ||
const int k = idx % d2; | ||
Kokkos::parallel_reduce( | ||
Kokkos::TeamVectorRange(tm, d0), | ||
[&](int i, ST &ac) { ac += v1(i) * v2(i, j, k); }, vo(j, k)); | ||
}); | ||
} break; | ||
default: | ||
EKAT_ERROR_MSG("Error! Unsupported field rank.\n"); | ||
} | ||
Kokkos::fence(); | ||
if(co) { | ||
fo.sync_to_host(); | ||
co->all_reduce(fo.template get_internal_view_data<ST, Host>(), | ||
l2.size() / l2.dim(0), MPI_SUM); | ||
fo.sync_to_dev(); | ||
} | ||
return fo; | ||
} | ||
|
||
} // namespace impl | ||
} // namespace scream | ||
|
||
#endif // SCREAM_FIELD_UTILS_IMPL_COLRED_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters