diff --git a/src/atlas/CMakeLists.txt b/src/atlas/CMakeLists.txt index ebecad3bb..60b93b9a1 100644 --- a/src/atlas/CMakeLists.txt +++ b/src/atlas/CMakeLists.txt @@ -759,6 +759,8 @@ array/Range.h array/Vector.h array/Vector.cc array/SVector.h +array/ArrayViewVariant.h +array/ArrayViewVariant.cc array/helpers/ArrayInitializer.h array/helpers/ArrayAssigner.h array/helpers/ArrayWriter.h diff --git a/src/atlas/array.h b/src/atlas/array.h index c2cf7f720..a7ac48d07 100644 --- a/src/atlas/array.h +++ b/src/atlas/array.h @@ -23,6 +23,7 @@ #include "atlas/array/ArraySpec.h" #include "atlas/array/ArrayStrides.h" #include "atlas/array/ArrayView.h" +#include "atlas/array/ArrayViewVariant.h" #include "atlas/array/DataType.h" #include "atlas/array/LocalView.h" #include "atlas/array/MakeView.h" diff --git a/src/atlas/array/ArrayViewVariant.cc b/src/atlas/array/ArrayViewVariant.cc new file mode 100644 index 000000000..f62efd2d8 --- /dev/null +++ b/src/atlas/array/ArrayViewVariant.cc @@ -0,0 +1,110 @@ +/* + * (C) Crown Copyright 2024 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#include "atlas/array/ArrayViewVariant.h" + +#include +#include + +#include "atlas/runtime/Exception.h" + +namespace atlas { +namespace array { + +using namespace detail; + +namespace { + +template +struct VariantTypeHelper { + using type = ArrayViewVariant; +}; + +template <> +struct VariantTypeHelper { + using type = ConstArrayViewVariant; +}; + +template +using VariantType = + typename VariantTypeHelper>::type; + +// Match array.rank() and array.datatype() to variant types. Return result of +// makeView on a successful pattern match. +template +VariantType executeMakeView(ArrayType& array, + const MakeView& makeView) { + using View = std::variant_alternative_t>; + using Value = typename View::non_const_value_type; + constexpr auto Rank = View::rank(); + + if (array.datatype() == DataType::kind() && array.rank() == Rank) { + return makeView(array, Value{}, std::integral_constant{}); + } + + if constexpr (TypeIndex < std::variant_size_v> - 1) { + return executeMakeView(array, makeView); + } else { + throw_Exception("ArrayView<" + array.datatype().str() + ", " + + std::to_string(array.rank()) + + "> is not an alternative in ArrayViewVariant.", + Here()); + } +} + +template +VariantType makeViewVariantImpl(ArrayType& array) { + const auto makeView = [](auto& array, auto value, auto rank) { + return make_view(array); + }; + return executeMakeView<>(array, makeView); +} + +template +VariantType makeHostViewVariantImpl(ArrayType& array) { + const auto makeView = [](auto& array, auto value, auto rank) { + return make_host_view(array); + }; + return executeMakeView<>(array, makeView); +} + +template +VariantType makeDeviceViewVariantImpl(ArrayType& array) { + const auto makeView = [](auto& array, auto value, auto rank) { + return make_device_view(array); + }; + return executeMakeView<>(array, makeView); +} + +} // namespace + +ArrayViewVariant make_view_variant(Array& array) { + return makeViewVariantImpl(array); +} + +ConstArrayViewVariant make_view_variant(const Array& array) { + return makeViewVariantImpl(array); +} + +ArrayViewVariant make_host_view_variant(Array& array) { + return makeHostViewVariantImpl(array); +} + +ConstArrayViewVariant make_host_view_variant(const Array& array) { + return makeHostViewVariantImpl(array); +} + +ArrayViewVariant make_device_view_variant(Array& array) { + return makeDeviceViewVariantImpl(array); +} + +ConstArrayViewVariant make_device_view_variant(const Array& array) { + return makeDeviceViewVariantImpl(array); +} + +} // namespace array +} // namespace atlas diff --git a/src/atlas/array/ArrayViewVariant.h b/src/atlas/array/ArrayViewVariant.h new file mode 100644 index 000000000..94dff8115 --- /dev/null +++ b/src/atlas/array/ArrayViewVariant.h @@ -0,0 +1,103 @@ +/* + * (C) Crown Copyright 2024 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#pragma once + +#include + +#include "atlas/array.h" + +namespace atlas { +namespace array { + +namespace detail { + +using namespace array; + +// Container struct for a list of types. +template +struct Types { + using add_const = Types...>; +}; + +// Container struct for a list of integers. +template +struct Ints {}; + +template +struct VariantHelper; + +// Recursively construct ArrayView std::variant from types Ts and ranks Is. +template +struct VariantHelper, Ints, ArrayViews...> { + using type = typename VariantHelper, Ints, ArrayViews..., + ArrayView...>::type; +}; + +// End recursion. +template +struct VariantHelper, Ints, ArrayViews...> { + using type = std::variant; +}; + +template +using Variant = typename VariantHelper::type; + +using VariantValueTypes = + detail::Types; + +using VariantRanks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>; + +} // namespace detail + +/// @brief Variant containing all supported non-const ArrayView alternatives. +using ArrayViewVariant = + detail::Variant; + +/// @brief Variant containing all supported const ArrayView alternatives. +using ConstArrayViewVariant = + detail::Variant; + +/// @brief Create an ArrayView and assign to an ArrayViewVariant. +ArrayViewVariant make_view_variant(Array& array); + +/// @brief Create a const ArrayView and assign to an ArrayViewVariant. +ConstArrayViewVariant make_view_variant(const Array& array); + +/// @brief Create a host ArrayView and assign to an ArrayViewVariant. +ArrayViewVariant make_host_view_variant(Array& array); + +/// @brief Create a const host ArrayView and assign to an ArrayViewVariant. +ConstArrayViewVariant make_host_view_variant(const Array& array); + +/// @brief Create a device ArrayView and assign to an ArrayViewVariant. +ArrayViewVariant make_device_view_variant(Array& array); + +/// @brief Create a const device ArrayView and assign to an ArrayViewVariant. +ConstArrayViewVariant make_device_view_variant(const Array& array); + +/// @brief Return true if View::rank() is any of Ranks... +template +constexpr bool is_rank(const View&) { + return ((std::decay_t::rank() == Ranks) || ...); +} +/// @brief Return true if View::value_type is any of ValuesTypes... +template +constexpr bool is_value_type(const View&) { + using ValueType = typename std::decay_t::value_type; + return ((std::is_same_v) || ...); +} + +/// @brief Return true if View::non_const_value_type is any of ValuesTypes... +template +constexpr bool is_non_const_value_type(const View&) { + using ValueType = typename std::decay_t::non_const_value_type; + return ((std::is_same_v) || ...); +} + +} // namespace array +} // namespace atlas diff --git a/src/tests/array/CMakeLists.txt b/src/tests/array/CMakeLists.txt index 3915caca7..880e46ba6 100644 --- a/src/tests/array/CMakeLists.txt +++ b/src/tests/array/CMakeLists.txt @@ -81,3 +81,8 @@ atlas_add_hic_test( ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} ) +ecbuild_add_test( TARGET atlas_test_array_view_variant + SOURCES test_array_view_variant.cc + LIBS atlas + ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT} +) diff --git a/src/tests/array/test_array_view_variant.cc b/src/tests/array/test_array_view_variant.cc new file mode 100644 index 000000000..f6eaeeacf --- /dev/null +++ b/src/tests/array/test_array_view_variant.cc @@ -0,0 +1,131 @@ +/* + * (C) Crown Copyright 2024 Met Office + * + * This software is licensed under the terms of the Apache Licence Version 2.0 + * which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + */ + +#include +#include + +#include "atlas/array.h" +#include "atlas/array/ArrayViewVariant.h" +#include "eckit/utils/Overloaded.h" +#include "tests/AtlasTestEnvironment.h" + +namespace atlas { +namespace test { + +using namespace array; + +CASE("test variant assignment") { + auto array1 = array::ArrayT(2); + auto array2 = array::ArrayT(2, 3); + auto array3 = array::ArrayT(2, 3, 4); + const auto& arrayRef = array1; + + array1.allocateDevice(); + array2.allocateDevice(); + array3.allocateDevice(); + + auto view1 = make_view_variant(array1); + auto view2 = make_view_variant(array2); + auto view3 = make_view_variant(array3); + auto view4 = make_view_variant(arrayRef); + + const auto hostView1 = make_host_view_variant(array1); + const auto hostView2 = make_host_view_variant(array2); + const auto hostView3 = make_host_view_variant(array3); + const auto hostView4 = make_host_view_variant(arrayRef); + + auto deviceView1 = make_device_view_variant(array1); + auto deviceView2 = make_device_view_variant(array2); + auto deviceView3 = make_device_view_variant(array3); + auto deviceView4 = make_device_view_variant(arrayRef); + + const auto visitVariants = [](auto& var1, auto& var2, auto var3, auto var4) { + std::visit( + [](auto view) { + EXPECT((is_rank<1>(view))); + EXPECT((is_value_type(view))); + EXPECT((is_non_const_value_type(view))); + }, + var1); + + std::visit( + [](auto view) { + EXPECT((is_rank<2>(view))); + EXPECT((is_value_type(view))); + EXPECT((is_non_const_value_type(view))); + }, + var2); + + std::visit( + [](auto view) { + EXPECT((is_rank<3>(view))); + EXPECT((is_value_type(view))); + EXPECT((is_non_const_value_type(view))); + }, + var3); + + std::visit( + [](auto view) { + EXPECT((is_rank<1>(view))); + EXPECT((is_value_type(view))); + EXPECT((is_non_const_value_type(view))); + }, + var4); + }; + + visitVariants(view1, view2, view3, view4); + visitVariants(hostView1, hostView2, hostView3, hostView4); + visitVariants(deviceView1, deviceView2, deviceView3, deviceView4); +} + +CASE("test std::visit") { + auto array1 = ArrayT(10); + make_view(array1).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + auto array2 = ArrayT(5, 2); + make_view(array2).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + + const auto var1 = make_view_variant(array1); + const auto var2 = make_view_variant(array2); + auto rank1Tested = false; + auto rank2Tested = false; + + const auto visitor = [&](auto view) { + if constexpr (is_rank<1>(view)) { + EXPECT((is_value_type(view))); + auto testValue = int{0}; + for (auto i = size_t{0}; i < view.size(); ++i) { + const auto value = view(i); + EXPECT_EQ(value, static_cast(testValue++)); + } + rank1Tested = true; + } else if constexpr (is_rank<2>(view)) { + EXPECT((is_value_type(view))); + auto testValue = int{0}; + for (auto i = idx_t{0}; i < view.shape(0); ++i) { + for (auto j = idx_t{0}; j < view.shape(1); ++j) { + const auto value = view(i, j); + EXPECT_EQ(value, static_cast(testValue++)); + } + } + rank2Tested = true; + } else { + // Test should not reach here. + EXPECT(false); + } + }; + + std::visit(visitor, var1); + EXPECT(rank1Tested); + std::visit(visitor, var2); + EXPECT(rank2Tested); +} + +} // namespace test +} // namespace atlas + +int main(int argc, char** argv) { return atlas::test::run(argc, argv); }