diff --git a/src/atlas/interpolation/method/Method.cc b/src/atlas/interpolation/method/Method.cc index 3a2349a1d..1195a7d53 100644 --- a/src/atlas/interpolation/method/Method.cc +++ b/src/atlas/interpolation/method/Method.cc @@ -168,65 +168,30 @@ void Method::interpolate_field_rank3(const Field& src, Field& tgt, const Matrix& template void Method::adjoint_interpolate_field_rank1(Field& src, const Field& tgt, const Matrix& W) const { - array::ArrayT tmp(src.shape()); + auto backend = std::is_same::value ? sparse::backend::openmp() : sparse::Backend{linalg_backend_}; - auto tmp_v = array::make_view(tmp); auto src_v = array::make_view(src); auto tgt_v = array::make_view(tgt); - tmp_v.assign(0.); - - if (std::is_same::value) { - sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp()); - } - else { - sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::Backend{linalg_backend_}); - } - - - for (idx_t t = 0; t < tmp.shape(0); ++t) { - src_v(t) += tmp_v(t); - } + sparse_matrix_multiply_add(W, tgt_v, src_v, backend); } template void Method::adjoint_interpolate_field_rank2(Field& src, const Field& tgt, const Matrix& W) const { - array::ArrayT tmp(src.shape()); - auto tmp_v = array::make_view(tmp); auto src_v = array::make_view(src); auto tgt_v = array::make_view(tgt); - tmp_v.assign(0.); - - sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp()); - - for (idx_t t = 0; t < tmp.shape(0); ++t) { - for (idx_t k = 0; k < tmp.shape(1); ++k) { - src_v(t, k) += tmp_v(t, k); - } - } + sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp()); } template void Method::adjoint_interpolate_field_rank3(Field& src, const Field& tgt, const Matrix& W) const { - array::ArrayT tmp(src.shape()); - auto tmp_v = array::make_view(tmp); auto src_v = array::make_view(src); auto tgt_v = array::make_view(tgt); - tmp_v.assign(0.); - - sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp()); - - for (idx_t t = 0; t < tmp.shape(0); ++t) { - for (idx_t j = 0; j < tmp.shape(1); ++j) { - for (idx_t k = 0; k < tmp.shape(2); ++k) { - src_v(t, j, k) += tmp_v(t, j, k); - } - } - } + sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp()); } void Method::check_compatibility(const Field& src, const Field& tgt, const Matrix& W) const {