Skip to content

Commit

Permalink
Update Method::adjoint_interpolate_* function to use multiply-add.
Browse files Browse the repository at this point in the history
  • Loading branch information
l90lpa authored and wdeconinck committed Nov 22, 2024
1 parent a435f75 commit e639a06
Showing 1 changed file with 4 additions and 39 deletions.
43 changes: 4 additions & 39 deletions src/atlas/interpolation/method/Method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,65 +168,30 @@ void Method::interpolate_field_rank3(const Field& src, Field& tgt, const Matrix&

// Adjoint interpolation for a rank-1 field: updates `src` in place from `tgt`
// using the sparse matrix `W`.
//
// Template parameter:
//   Value - element type used for the array views of `src` and `tgt`.
// Parameters:
//   src - field that is accumulated into (viewed as a rank-1 array of Value).
//   tgt - read-only input field (viewed as a rank-1 array of Value).
//   W   - sparse matrix passed to the sparse multiply routines.
//
// NOTE(review): this listing appears to be a flattened diff view (the page
// header reports 4 additions and 39 deletions), so the body below presumably
// interleaves the removed `tmp`-based implementation with the added
// `sparse_matrix_multiply_add` call. Read literally, and assuming
// sparse_matrix_multiply_add accumulates into its output as its name suggests,
// the contribution of W and tgt would be added to `src` twice. Confirm against
// the repository before relying on this text.
template <typename Value>
void Method::adjoint_interpolate_field_rank1(Field& src, const Field& tgt, const Matrix& W) const {
// Scratch array with the same shape as `src`.
array::ArrayT<Value> tmp(src.shape());
// Backend choice: OpenMP backend when Value is float, otherwise a Backend
// built from the member `linalg_backend_`.
auto backend = std::is_same<Value, float>::value ? sparse::backend::openmp() : sparse::Backend{linalg_backend_};

auto tmp_v = array::make_view<Value, 1>(tmp);
auto src_v = array::make_view<Value, 1>(src);
auto tgt_v = array::make_view<Value, 1>(tgt);

// Zero the scratch view before it is used as the multiply output.
tmp_v.assign(0.);

// Same float / non-float backend distinction as for `backend` above, spelled
// out as a branch; the result of the multiply is written to `tmp_v`.
if (std::is_same<Value, float>::value) {
sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());
}
else {
sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::Backend{linalg_backend_});
}


// Element-wise accumulation of the scratch result into `src`.
for (idx_t t = 0; t < tmp.shape(0); ++t) {
src_v(t) += tmp_v(t);
}
// Multiply-add call operating directly on the `src` view with the backend
// selected above (presumably the replacement for the scratch-array path).
sparse_matrix_multiply_add(W, tgt_v, src_v, backend);
}

// Adjoint interpolation for a rank-2 field: updates `src` in place from `tgt`
// using the sparse matrix `W`. All sparse operations here use the OpenMP
// backend.
//
// Template parameter:
//   Value - element type used for the array views of `src` and `tgt`.
// Parameters:
//   src - field that is accumulated into (viewed as a rank-2 array of Value).
//   tgt - read-only input field (viewed as a rank-2 array of Value).
//   W   - sparse matrix passed to the sparse multiply routines.
//
// NOTE(review): this listing appears to be a flattened diff view (the page
// header reports 4 additions and 39 deletions), so the body below presumably
// interleaves the removed `tmp`-based implementation with the added
// `sparse_matrix_multiply_add` call. Read literally, and assuming
// sparse_matrix_multiply_add accumulates into its output as its name suggests,
// the contribution of W and tgt would be added to `src` twice. Confirm against
// the repository before relying on this text.
template <typename Value>
void Method::adjoint_interpolate_field_rank2(Field& src, const Field& tgt, const Matrix& W) const {
// Scratch array with the same shape as `src`.
array::ArrayT<Value> tmp(src.shape());

auto tmp_v = array::make_view<Value, 2>(tmp);
auto src_v = array::make_view<Value, 2>(src);
auto tgt_v = array::make_view<Value, 2>(tgt);

// Zero the scratch view before it is used as the multiply output.
tmp_v.assign(0.);

// Multiply with the result written to the scratch view.
sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());

// Element-wise accumulation of the scratch result into `src` over both
// dimensions.
for (idx_t t = 0; t < tmp.shape(0); ++t) {
for (idx_t k = 0; k < tmp.shape(1); ++k) {
src_v(t, k) += tmp_v(t, k);
}
}
// Multiply-add call operating directly on the `src` view (presumably the
// replacement for the scratch-array path).
sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp());
}

// Adjoint interpolation for a rank-3 field: updates `src` in place from `tgt`
// using the sparse matrix `W`. All sparse operations here use the OpenMP
// backend.
//
// Template parameter:
//   Value - element type used for the array views of `src` and `tgt`.
// Parameters:
//   src - field that is accumulated into (viewed as a rank-3 array of Value).
//   tgt - read-only input field (viewed as a rank-3 array of Value).
//   W   - sparse matrix passed to the sparse multiply routines.
//
// NOTE(review): this listing appears to be a flattened diff view (the page
// header reports 4 additions and 39 deletions), so the body below presumably
// interleaves the removed `tmp`-based implementation with the added
// `sparse_matrix_multiply_add` call. Read literally, and assuming
// sparse_matrix_multiply_add accumulates into its output as its name suggests,
// the contribution of W and tgt would be added to `src` twice. Confirm against
// the repository before relying on this text.
template <typename Value>
void Method::adjoint_interpolate_field_rank3(Field& src, const Field& tgt, const Matrix& W) const {
// Scratch array with the same shape as `src`.
array::ArrayT<Value> tmp(src.shape());

auto tmp_v = array::make_view<Value, 3>(tmp);
auto src_v = array::make_view<Value, 3>(src);
auto tgt_v = array::make_view<Value, 3>(tgt);

// Zero the scratch view before it is used as the multiply output.
tmp_v.assign(0.);

// Multiply with the result written to the scratch view.
sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());

// Element-wise accumulation of the scratch result into `src` over all three
// dimensions.
for (idx_t t = 0; t < tmp.shape(0); ++t) {
for (idx_t j = 0; j < tmp.shape(1); ++j) {
for (idx_t k = 0; k < tmp.shape(2); ++k) {
src_v(t, j, k) += tmp_v(t, j, k);
}
}
}
// Multiply-add call operating directly on the `src` view (presumably the
// replacement for the scratch-array path).
sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp());
}

void Method::check_compatibility(const Field& src, const Field& tgt, const Matrix& W) const {
Expand Down

0 comments on commit e639a06

Please sign in to comment.