Skip to content

Commit

Permalink
Refactored (un)pack_vector_fields.
Browse files Browse the repository at this point in the history
  • Loading branch information
odlomax committed Oct 14, 2024
1 parent 96edef9 commit b3f83ad
Showing 1 changed file with 16 additions and 43 deletions.
59 changes: 16 additions & 43 deletions src/atlas/util/PackVectorFields.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ namespace util {
namespace {

using eckit::LocalConfiguration;

using array::DataType;
using array::helpers::arrayForEachDim;

void addOrReplaceField(FieldSet& fieldSet, const Field& field) {
Expand Down Expand Up @@ -84,51 +82,29 @@ void checkFieldCompatibility(const Field& componentField,

template <typename ComponentField, typename VectorField, typename Functor>
void copyFieldData(ComponentField& componentField, VectorField& vectorField,
const Functor& copier) {
const Functor& copier) {
checkFieldCompatibility(componentField, vectorField);

const auto copyArrayData = [&](auto value, auto rank) {
// Resolve value-type and rank from arguments.
using Value = decltype(value);
constexpr auto Rank = decltype(rank)::value;

// Iterate over fields.
auto vectorView = array::make_view<Value, Rank>(vectorField);
auto componentView = array::make_view<Value, Rank - 1>(componentField);
constexpr auto Dims = std::make_integer_sequence<int, Rank - 1>{};
arrayForEachDim(Dims, execution::par, std::tie(componentView, vectorView),
copier);
};
auto componentViewVariant = array::make_view_variant(componentField);

const auto selectRank = [&](auto value) {
switch (vectorField.rank()) {
case 2:
return copyArrayData(value, std::integral_constant<int, 2>{});
case 3:
return copyArrayData(value, std::integral_constant<int, 3>{});
default:
ATLAS_THROW_EXCEPTION("Unsupported vector field rank: " +
std::to_string(vectorField.rank()));
}
};
const auto componentVisitor = [&](auto componentView) {
if constexpr (array::is_rank<1, 2>(componentView)) {
using ComponentView = std::decay_t<decltype(componentView)>;
constexpr auto ComponentRank = ComponentView::rank();
using Value = typename ComponentView::non_const_value_type;

auto vectorView = array::make_view<Value, ComponentRank + 1>(vectorField);
constexpr auto Dims = std::make_integer_sequence<int, ComponentRank>{};
arrayForEachDim(Dims, execution::par, std::tie(componentView, vectorView),
copier);

const auto selectType = [&]() {
switch (vectorField.datatype().kind()) {
case DataType::kind<double>():
return selectRank(double{});
case DataType::kind<float>():
return selectRank(float{});
case DataType::kind<long>():
return selectRank(long{});
case DataType::kind<int>():
return selectRank(int{});
default:
ATLAS_THROW_EXCEPTION("Unknown datatype: " +
std::to_string(vectorField.datatype().kind()));
} else {
ATLAS_THROW_EXCEPTION("Unsupported component field rank: " +
std::to_string(componentView.rank()));
}
};

selectType();
std::visit(componentVisitor, componentViewVariant);
}

} // namespace
Expand Down Expand Up @@ -187,8 +163,6 @@ FieldSet pack_vector_fields(const FieldSet& fields, FieldSet packedFields) {
} else {
vectorField.set_dirty(vectorField.dirty() || componentField.dirty());
}


}
return packedFields;
}
Expand All @@ -208,7 +182,6 @@ FieldSet unpack_vector_fields(const FieldSet& fields, FieldSet unpackedFields) {

auto vectorIndex = 0;
for (const auto& componentFieldMetadata : componentFieldMetadataVector) {

// Get or create field.
auto componentFieldName = std::string{};
componentFieldMetadata.get("name", componentFieldName);
Expand Down

0 comments on commit b3f83ad

Please sign in to comment.