Skip to content

Commit

Permalink
Refactored SphericalVector interpolation method to use ArrayViewVariant.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Sep 25, 2024
1 parent e8b38bd commit dbf88e1
Showing 1 changed file with 36 additions and 50 deletions.
86 changes: 36 additions & 50 deletions src/atlas/interpolation/method/sphericalvector/SphericalVector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

#include "atlas/interpolation/method/sphericalvector/SphericalVector.h"

#include <cmath>
#include <variant>

#include "atlas/array/ArrayView.h"
#include "atlas/field/Field.h"
#include "atlas/field/FieldSet.h"
Expand Down Expand Up @@ -203,59 +206,42 @@ template <typename MatMul>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField,
const MatMul& matMul) {
ATLAS_ASSERT_MSG(sourceField.variables() == 2 || sourceField.variables() == 3,
"Vector field can only have 2 or 3 components.");

if (sourceField.datatype().kind() == array::DataType::KIND_REAL64) {
interpolate_vector_field<double>(sourceField, targetField, matMul);
return;
}

if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) {
interpolate_vector_field<float>(sourceField, targetField, matMul);
return;
}
const auto sourceViewVariant = array::make_device_view_variant(sourceField);

const auto sourceViewVisitor = [&](auto&& sourceView) {
using SourceView = std::decay_t<decltype(sourceView)>;
if constexpr (array::RankIs<SourceView, 2, 3>() &&
array::ValueIs<SourceView, float, double>()) {
using Value = typename SourceView::non_const_value_type;
constexpr auto Rank = SourceView::rank();
auto targetView = array::make_view<Value, Rank>(targetField);

switch (sourceField.variables()) {
case 2:
return matMul.apply(sourceView, targetView, twoVector);
case 3:
return matMul.apply(sourceView, targetView, threeVector);
default:
throw_Exception("Error: no support for " +
std::to_string(sourceField.variables()) +
" variable vector fields.\n" +
" Number of variables must be 2 or 3.",
Here());
}

} else {
throw_Exception(
"Error: no support for rank = " + std::to_string(sourceField.rank()) +
" and value type = " + sourceField.datatype().str() + ".\n" +
"Vector field must have rank 2 or 3 with value type "
"float or double",
Here());
}
};

ATLAS_NOTIMPLEMENTED;
std::visit(sourceViewVisitor, sourceViewVariant);
};

template <typename Value, typename MatMul>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField,
const MatMul& matMul) {
if (sourceField.rank() == 2) {
interpolate_vector_field<Value, 2>(sourceField, targetField, matMul);
return;
}

if (sourceField.rank() == 3) {
interpolate_vector_field<Value, 3>(sourceField, targetField, matMul);
return;
}

ATLAS_NOTIMPLEMENTED;
}

template <typename Value, int Rank, typename MatMul>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField,
const MatMul& matMul) {
const auto sourceView = array::make_view<Value, Rank>(sourceField);
auto targetView = array::make_view<Value, Rank>(targetField);

if (sourceField.variables() == 2) {
matMul.apply(sourceView, targetView, twoVector);
return;
}

if (sourceField.variables() == 3) {
matMul.apply(sourceView, targetView, threeVector);
return;
}

ATLAS_NOTIMPLEMENTED;
}

} // namespace method
} // namespace interpolation
} // namespace atlas

0 comments on commit dbf88e1

Please sign in to comment.