EAMxx: add vert_contraction field utility #6889

Open: 1 commit to merge into base: master

@mahf708 commented Jan 10, 2025

Adds a vertical contraction utility equivalent to einsum('k,...k->...', weight, field) or einsum('ik,i...k->i...', weight, field).


This follows #6776, but for the vertical dimension, with some added complexity and more caveats. Bare-bones implementation for now. Design notes on the two broad cases supported:

  • If the weight provided is 1D, its lone dim must be LEV; this is used to calculate vertical sums or means whose weight is column-invariant, say for aodvis (see diagnostic impl).
  • If the weight provided is 2D, its dims must be COL,LEV; this is used to calculate vertical sums or means whose weight is column-variant, say for waterpath (see diagnostic waterpath).
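
As a rough illustration of the two cases above (not the PR's implementation), the following plain C++ sketch contracts a row-major [ncol x nlev] array with either a column-invariant 1D weight or a column-variant 2D weight. The helper names contract_1d and contract_2d and the flat std::vector storage are assumptions made for this example; the actual utility operates on EAMxx Field objects and also allows extra middle dimensions, which this sketch omits.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: out(i) = sum_k w(k) * f(i,k), i.e. einsum('k,ik->i').
// f is stored row-major with shape [ncol x nlev]; w has length nlev.
std::vector<double> contract_1d(const std::vector<double>& w,
                                const std::vector<double>& f,
                                std::size_t ncol, std::size_t nlev) {
  assert(w.size() == nlev && f.size() == ncol * nlev);
  std::vector<double> out(ncol, 0.0);
  for (std::size_t i = 0; i < ncol; ++i)
    for (std::size_t k = 0; k < nlev; ++k)
      out[i] += w[k] * f[i * nlev + k];
  return out;
}

// Hypothetical helper: out(i) = sum_k w(i,k) * f(i,k), i.e. einsum('ik,ik->i').
// Both w and f are stored row-major with shape [ncol x nlev].
std::vector<double> contract_2d(const std::vector<double>& w,
                                const std::vector<double>& f,
                                std::size_t ncol, std::size_t nlev) {
  assert(w.size() == ncol * nlev && f.size() == ncol * nlev);
  std::vector<double> out(ncol, 0.0);
  for (std::size_t i = 0; i < ncol; ++i)
    for (std::size_t k = 0; k < nlev; ++k)
      out[i] += w[i * nlev + k] * f[i * nlev + k];
  return out;
}
```

With the 1D weight every column shares the same nlev coefficients, while the 2D weight supplies one coefficient per (column, level) pair.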

@mahf708 added the EAMxx label (PRs focused on capabilities for EAMxx) Jan 10, 2025

@bartgol left a comment

Looks good. I have one concern regarding GPU testing, plus a few optional suggestions.

}
field20.sync_to_dev();
manual_result.sync_to_dev();
REQUIRE(views_are_equal(result, manual_result));
Contributor

My guess is that this won't pass on GPU, since the order of the parallel reduction is different from how you hand-rolled it.

You could use a tolerance based test:

auto diff = result.clone("diff"); // start from the computed result, not the input field
diff.update(manual_result,1,-1); // diff = manual_result - result
auto manual_norm = frobenius_norm<Real>(manual_result);
REQUIRE (frobenius_norm<Real>(diff) <= tol*manual_norm); // relative error bound

where tol should be chosen depending on precision, so something like std::numeric_limits<Real>::epsilon()*10.
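
To illustrate why the reduction order matters, here is a small standalone sketch (plain C++, no EAMxx or Catch2) that sums the same values left to right and pairwise, then compares the two with a relative tolerance. The function names are hypothetical, and the pairwise sum only loosely mimics what a parallel reduction might do.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Left-to-right accumulation, as a serial hand-rolled loop would do.
double sum_sequential(const std::vector<double>& x) {
  double s = 0.0;
  for (double v : x) s += v;
  return s;
}

// Pairwise (tree) accumulation over [lo, hi); the additions are grouped
// differently from the sequential loop, so rounding can differ.
double sum_pairwise(const std::vector<double>& x, std::size_t lo, std::size_t hi) {
  if (hi - lo == 0) return 0.0;
  if (hi - lo == 1) return x[lo];
  const std::size_t mid = lo + (hi - lo) / 2;
  return sum_pairwise(x, lo, mid) + sum_pairwise(x, mid, hi);
}

// True if |a - b| <= tol * |b|.
bool within_rel(double a, double b, double tol) {
  assert(tol >= 0.0);
  return std::fabs(a - b) <= tol * std::fabs(b);
}
```

For example, within_rel(sum_sequential(x), sum_pairwise(x, 0, x.size()), tol) can be checked instead of exact equality. How large tol needs to be depends on the data and on the length of the reduction, so a small multiple of epsilon may or may not suffice for long columns.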

Also, instead of hand rolling a manual reduction, you could consider using existing Field ops/utils (we can trust them, since they are already tested). So, e.g.,

for (int i=0;i<dim0;++i) {
  auto f_i = f.subfield(0,i); // ith col of f
  f_i.scale(w); // rescale ith col by w
  auto sum = field_sum<Real>(f_i); // Sum all entries of the col
  manual_result.subfield(0,i).deep_copy(sum);
}

It may save some lines of code, keeping the test at a more "math" level, without view syntax polluting the test.

"The weight field has rank "
<< l_w.rank() << ".\n");
EKAT_REQUIRE_MSG(
l_w.tags().back() == LEV,
Contributor

Do you foresee ILEV fields never being contracted?

Contributor Author

Oversight; yes, we should readily support ILEV as well.

"The input field layout is "
<< l_in.to_string() << ".\n");
EKAT_REQUIRE_MSG(
l_in.dim(l_in.rank() - 1) == l_w.dim(l_w.rank() - 1),
Contributor

Note: here (and below) you can probably do l_in.dims().back(), which may be more readable than l_in.dim(l_in.rank()-1).


// Sanity checks before handing off to the implementation
EKAT_REQUIRE_MSG(
l_w.rank() >= 1 && l_w.rank() <= 2,
Contributor

Probably subjective, but since there are only 2 options, we may do a more explicit check like l_w.rank()==1 or l_w.rank()==2, which has the same code length anyways. Inequalities take a bit longer to visually parse, and it may take some time for the reader to figure out that 1 and 2 are the only valid values.

<< l_in.dim(0) << ".\n");
}
EKAT_REQUIRE_MSG(
l_out == l_in.clone().strip_dim(l_in.rank() - 1),
Contributor

You can also call strip_dim(LEV), which may be more explicit.

<< l_in.to_string() << ".\n");
if(l_w.rank() == 2) {
EKAT_REQUIRE_MSG(
l_w.tags()[0] == COL && l_in.tags()[0] == COL,
Contributor

I'm assuming the only possible l_in layouts are <COL,LEV> and <COL,CMP,LEV> (possibly with CMP repeated). If so, you could maybe consolidate this check and the next via something like

EKAT_REQUIRE_MSG (l_w.congruent(l_in.clone().strip_dim(CMP,false)),
  "Error! Incompatible layouts\n"
  "  field in: " + l_in.to_string() + "\n"
  "  weight: " + l_w.to_string() + "\n");

The false argument tells strip_dim not to throw if the dim is not found.
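
The consolidated check can be pictured with a toy layout type. This is a sketch under stated assumptions, not the EAMxx FieldLayout API: the Layout struct, strip_tag, and congruent below are hypothetical stand-ins, and strip_tag is assumed to remove every occurrence of the tag.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

enum class Tag { COL, CMP, LEV };

// Toy stand-in for a field layout: parallel lists of tags and extents.
struct Layout {
  std::vector<Tag> tags;
  std::vector<int> dims;
};

// Hypothetical analogue of strip_dim(tag, throw_if_missing): drops every
// occurrence of `tag`; if none is found, throws only when asked to.
Layout strip_tag(const Layout& l, Tag tag, bool throw_if_missing = true) {
  assert(l.tags.size() == l.dims.size());
  Layout out;
  for (std::size_t i = 0; i < l.tags.size(); ++i) {
    if (l.tags[i] != tag) {
      out.tags.push_back(l.tags[i]);
      out.dims.push_back(l.dims[i]);
    }
  }
  if (throw_if_missing && out.tags.size() == l.tags.size())
    throw std::invalid_argument("tag not found in layout");
  return out;
}

// In this toy, two layouts are congruent when tags and extents match one by one.
bool congruent(const Layout& a, const Layout& b) {
  return a.tags == b.tags && a.dims == b.dims;
}
```

In this toy, a <COL,LEV> weight is congruent with both a <COL,LEV> input and a <COL,CMP,LEV> input once CMP is stripped, which is what lets one check replace two.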

auto v_out = f_out.get_view<ST *>();
const int d0 = l_in.dim(0);
auto p = ESU::get_default_team_policy(d0, nlevs);
if(l_w.rank() == 1) {
Contributor

One option, feel free to dismiss it. Instead of duplicating the whole code, you could create BOTH 1d and 2d views, and do the if inside the loop (the branch predictor will make the if cost negligible). Of course, only one of the views contains anything. So something like this:

typename Field::get_view_type<const ST*,Device> w1d;
typename Field::get_view_type<const ST**,Device> w2d;
auto one_d_w = l_w.rank()==1;
if (one_d_w)
  w1d = weight.get_view<const ST*>();
else
  w2d = weight.get_view<const ST**>();

then, inside the KOKKOS_LAMBDA, simply do

if (one_d_w)
  ac += w1d(j) * ...
else
  ac += w2d(i,j) * ...

It may lead to a tiny bit less code. In the end, what you did is ok. But if you dislike the code duplication, the above can be a viable alternative.
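
The same idea can be sketched outside Kokkos as a serial loop with one code path. This is an illustration only: the function name contract, the flat std::vector storage, and detecting the 1D case from the weight's size (rather than from the layout rank) are assumptions made for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical single-code-path contraction. w holds either nlev entries
// (1D, shared by all columns) or ncol*nlev entries (2D, row-major);
// f is row-major with shape [ncol x nlev].
std::vector<double> contract(const std::vector<double>& w,
                             const std::vector<double>& f,
                             std::size_t ncol, std::size_t nlev) {
  assert(f.size() == ncol * nlev);
  const bool one_d_w = (w.size() == nlev);  // decided once, outside the loops
  assert(one_d_w || w.size() == ncol * nlev);
  std::vector<double> out(ncol, 0.0);
  for (std::size_t i = 0; i < ncol; ++i) {
    double ac = 0.0;
    for (std::size_t k = 0; k < nlev; ++k) {
      // The branch is loop-invariant, so it costs little in practice.
      const double wk = one_d_w ? w[k] : w[i * nlev + k];
      ac += wk * f[i * nlev + k];
    }
    out[i] = ac;
  }
  return out;
}
```

The flag is computed once and only read inside the loop, which mirrors capturing one_d_w by value in the lambda.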

Contributor Author

That is precisely why I was asking you about the numpy repeat feature :)
