Skip to content

Commit

Permalink
Refactored pack_vector_fields to use ArrayViewVariant.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Sep 25, 2024
1 parent 55cfbe4 commit 1f4b191
Showing 1 changed file with 17 additions and 41 deletions.
58 changes: 17 additions & 41 deletions src/atlas/util/PackVectorFields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,51 +84,30 @@ void checkFieldCompatibility(const Field& componentField,

template <typename ComponentField, typename VectorField, typename Functor>
void copyFieldData(ComponentField& componentField, VectorField& vectorField,
const Functor& copier) {
const Functor& copier) {
checkFieldCompatibility(componentField, vectorField);

const auto copyArrayData = [&](auto value, auto rank) {
// Resolve value-type and rank from arguments.
using Value = decltype(value);
constexpr auto Rank = decltype(rank)::value;

// Iterate over fields.
auto vectorView = array::make_view<Value, Rank>(vectorField);
auto componentView = array::make_view<Value, Rank - 1>(componentField);
constexpr auto Dims = std::make_integer_sequence<int, Rank - 1>{};
arrayForEachDim(Dims, execution::par, std::tie(componentView, vectorView),
copier);
};
auto componentViewVariant = array::make_view_variant(componentField);

const auto selectRank = [&](auto value) {
switch (vectorField.rank()) {
case 2:
return copyArrayData(value, std::integral_constant<int, 2>{});
case 3:
return copyArrayData(value, std::integral_constant<int, 3>{});
default:
ATLAS_THROW_EXCEPTION("Unsupported vector field rank: " +
std::to_string(vectorField.rank()));
}
};
const auto componentVisitor = [&](auto&& componentView) {
using ComponentView = std::decay_t<decltype(componentView)>;

if constexpr (array::RankIs<ComponentView, 1, 2>()) {
constexpr auto ComponentRank = ComponentView::rank();
using Value = typename ComponentView::non_const_value_type;

auto vectorView = array::make_view<Value, ComponentRank + 1>(vectorField);
constexpr auto Dims = std::make_integer_sequence<int, ComponentRank>{};
arrayForEachDim(Dims, execution::par, std::tie(componentView, vectorView),
copier);

const auto selectType = [&]() {
switch (vectorField.datatype().kind()) {
case DataType::kind<double>():
return selectRank(double{});
case DataType::kind<float>():
return selectRank(float{});
case DataType::kind<long>():
return selectRank(long{});
case DataType::kind<int>():
return selectRank(int{});
default:
ATLAS_THROW_EXCEPTION("Unknown datatype: " +
std::to_string(vectorField.datatype().kind()));
} else {
ATLAS_THROW_EXCEPTION("Unsupported vector field rank: " +
std::to_string(componentView.rank()));
}
};

selectType();
std::visit(componentVisitor, componentViewVariant);
}

} // namespace
Expand Down Expand Up @@ -187,8 +166,6 @@ FieldSet pack_vector_fields(const FieldSet& fields, FieldSet packedFields) {
} else {
vectorField.set_dirty(vectorField.dirty() || componentField.dirty());
}


}
return packedFields;
}
Expand All @@ -208,7 +185,6 @@ FieldSet unpack_vector_fields(const FieldSet& fields, FieldSet unpackedFields) {

auto vectorIndex = 0;
for (const auto& componentFieldMetadata : componentFieldMetadataVector) {

// Get or create field.
auto componentFieldName = std::string{};
componentFieldMetadata.get("name", componentFieldName);
Expand Down

0 comments on commit 1f4b191

Please sign in to comment.