EAMxx: add vert_contraction field utility #6889
base: master
Looks good. I have one doubt regarding GPU testing, plus a few suggestions that are up for grabs.
}
field20.sync_to_dev();
manual_result.sync_to_dev();
REQUIRE(views_are_equal(result, manual_result));
My guess is that this won't pass on GPU, since the order of the parallel reduction is different from how you hand-rolled it.
You could use a tolerance based test:
auto diff = result.clone("diff");
diff.update(manual_result,1,-1); // diff = manual_result - result
auto manual_norm = frobenius_norm<Real>(manual_result);
REQUIRE (frobenius_norm<Real>(diff) <= tol*manual_norm);
where tol should be chosen depending on precision, so something like std::numeric_limits<Real>::epsilon()*10.
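To illustrate the concern, here is a minimal, self-contained sketch in plain C++ (no EAMxx or Kokkos). The names `sum_serial`, `sum_tree`, and `close_rel` are hypothetical helpers, and `double` stands in for `Real`. It sums the same values in two different groupings, as a serial loop and a tree-style reduction would, and compares the results with a relative tolerance rather than bitwise equality.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Left-to-right sum, like a hand-rolled serial loop.
double sum_serial(const std::vector<double>& v) {
  double s = 0.0;
  for (double x : v) s += x;
  return s;
}

// Pairwise (tree) sum over [lo, hi), mimicking the grouping a parallel
// reduction might use. The grouping differs from sum_serial, so the
// rounding errors can differ too.
double sum_tree(const std::vector<double>& v, std::size_t lo, std::size_t hi) {
  if (hi - lo == 0) return 0.0;
  if (hi - lo == 1) return v[lo];
  const std::size_t mid = lo + (hi - lo) / 2;
  return sum_tree(v, lo, mid) + sum_tree(v, mid, hi);
}

// True if |a - b| <= tol * |b|, i.e. a matches b up to a relative tolerance.
bool close_rel(double a, double b, double tol) {
  return std::fabs(a - b) <= tol * std::fabs(b);
}
```

The two sums are mathematically identical but need not agree bit for bit, which is why a check like `close_rel(sum_serial(v), sum_tree(v, 0, v.size()), tol)` is more robust than `==`.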
Also, instead of hand rolling a manual reduction, you could consider using existing Field ops/utils (we can trust them, since they are already tested). So, e.g.,
for (int i=0;i<dim0;++i) {
  auto f_i = f.subfield(0,i);      // ith col
  f_i.scale(w);                    // rescale ith col by w
  auto sum = field_sum<Real>(f_i); // Sum all entries of the col
  manual_result.subfield(0,i).deep_copy(sum);
}
It may save some lines of code, keeping the test at a more "math" level, without view syntax polluting the test.
"The weight field has rank "
<< l_w.rank() << ".\n");
EKAT_REQUIRE_MSG(
l_w.tags().back() == LEV,
Do you foresee ILEV fields never being contracted?
Oversight; yes, we should readily support ILEV as well.
"The input field layout is "
<< l_in.to_string() << ".\n");
EKAT_REQUIRE_MSG(
l_in.dim(l_in.rank() - 1) == l_w.dim(l_w.rank() - 1),
Note: here (and below) you can prob do l_in.dims().back(), which may be more readable than l_in.dim(l_in.rank()-1).
// Sanity checks before handing off to the implementation
EKAT_REQUIRE_MSG(
l_w.rank() >= 1 && l_w.rank() <= 2,
Probably subjective, but since there are only 2 options, we may do a more explicit check like l_w.rank()==1 or l_w.rank()==2, which has the same code length anyway. Inequalities take a bit longer to visually parse, and it may take some time for the reader to figure out that 1 and 2 are the only valid values.
<< l_in.dim(0) << ".\n");
}
EKAT_REQUIRE_MSG(
l_out == l_in.clone().strip_dim(l_in.rank() - 1),
You can also call strip_dim(LEV), which may be more verbose.
<< l_in.to_string() << ".\n");
if(l_w.rank() == 2) {
EKAT_REQUIRE_MSG(
l_w.tags()[0] == COL && l_in.tags()[0] == COL,
I'm assuming the only possible l_in types are <COL,LEV> and <COL,CMP,LEV> (possibly with CMP repeated). If so, you could maybe consolidate this check and the next via something like
EKAT_REQUIRE_MSG (l_w.congruent(l_in.clone().strip_dim(CMP,false)),
    "Error! Incompatible layouts\n"
    " field in: " + l_in.to_string() + "\n"
    " weight: " + l_w.to_string() + "\n");
The false arg instructs strip_dim not to throw if the dim is not found.
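The idea behind the consolidated check can be sketched in plain C++ as follows. Everything here is a hypothetical stand-in: `Tag`, `Layout`, `strip_tag`, and `congruent` are not the EAMxx types or methods, and the sketch assumes that stripping removes every occurrence of the tag and that congruence means equal tag sequences.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical stand-in for field tags; not the real EAMxx enum.
enum class Tag { COL, CMP, LEV, ILEV };
using Layout = std::vector<Tag>;

// Return a copy of `layout` with every occurrence of `tag` removed.
// Stripping a tag that is absent is a no-op (the "do not throw" behaviour).
Layout strip_tag(Layout layout, Tag tag) {
  layout.erase(std::remove(layout.begin(), layout.end(), tag), layout.end());
  return layout;
}

// In this sketch, two layouts are congruent if their tag sequences match.
bool congruent(const Layout& a, const Layout& b) { return a == b; }

// The weight is compatible when it matches the input with CMP dims dropped.
bool weight_matches_input(const Layout& w, const Layout& in) {
  return congruent(w, strip_tag(in, Tag::CMP));
}
```

Under these assumptions a `<COL,LEV>` weight is accepted for both `<COL,LEV>` and `<COL,CMP,LEV>` inputs, while a 1d `<LEV>` weight is not, so the rank-1 case would still need its own branch.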
auto v_out = f_out.get_view<ST *>();
const int d0 = l_in.dim(0);
auto p = ESU::get_default_team_policy(d0, nlevs);
if(l_w.rank() == 1) {
One option, feel free to dismiss it. Instead of duplicating the whole code, you could create BOTH 1d and 2d views, and do the if inside the loop (the branch predictor will make the if cost negligible). Of course, only one of the views contains anything. So something like this:
typename Field::get_view_type<const ST*,Device> w1d;
typename Field::get_view_type<const ST**,Device> w2d;
auto one_d_w = l_w.rank()==1;
if (one_d_w)
w1d = weight.get_view<const ST*>();
else
w2d = weight.get_view<const ST**>();
then, inside the KOKKOS_LAMBDA, simply do
if (one_d_w)
ac += w1d(j) * ...
else
ac += w2d(i,j) * ...
It may lead to a tiny bit less code. In the end, what you did is ok. But if you dislike the code duplication, the above can be a viable alternative.
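A serial, plain C++ sketch of that single-loop structure, with no Kokkos involved: `contract_lev` is a hypothetical name, the data are flat row-major `std::vector<double>` arrays instead of views, and only one of the two weight interpretations is used per call.

```cpp
#include <vector>

// Contract the last (level) dimension of a row-major (ncol x nlev) array:
//   out(i) = sum_j w(j)   * in(i,j)   if one_d_w is true
//   out(i) = sum_j w(i,j) * in(i,j)   otherwise
// The same loop body serves both cases; the branch depends only on a flag
// that is constant for the whole call.
std::vector<double> contract_lev(const std::vector<double>& in,
                                 const std::vector<double>& w,
                                 int ncol, int nlev, bool one_d_w) {
  std::vector<double> out(ncol, 0.0);
  for (int i = 0; i < ncol; ++i) {
    double ac = 0.0;
    for (int j = 0; j < nlev; ++j) {
      const double wij = one_d_w ? w[j] : w[i * nlev + j];
      ac += wij * in[i * nlev + j];
    }
    out[i] = ac;
  }
  return out;
}
```

For example, with `in = {1,2,3,4,5,6}`, `ncol = 2`, `nlev = 3`, a 1d weight of all ones sums each row, while a 2d weight can pick a different level per column.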
That is precisely why I was asking you about the numpy repeat feature :)
Adds a vertical contraction utility equivalent to einsum('k,...k->...', weight, field) or einsum('ik,i...k->i...', weight, field).
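Written out in index notation, the two einsum expressions amount to the following, where $w$ is the weight, $f$ the input field, $k$ the level index, $i$ the leading (column) index, and the dots stand for any remaining dimensions. The symbol names are chosen here for illustration only.

```latex
\mathrm{out}_{\dots} = \sum_{k} w_{k}\, f_{\dots k}
\qquad\text{and}\qquad
\mathrm{out}_{i \dots} = \sum_{k} w_{ik}\, f_{i \dots k}
```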
Following #6776 but for the vertical, except with some added complexity and more caveats. Bare-bones impl for now. Design notes regarding two broad cases supported: