Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added function to create std::variant for multiple array views. #220

Merged
merged 26 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
436b36b
Added array view variant class and tests.
odlomax Aug 22, 2024
44d7f2b
Merge branch 'develop' into feature/array_view_variant
odlomax Aug 23, 2024
97eadff
Merge branch 'develop' into feature/array_view_variant
odlomax Sep 13, 2024
30a2bc3
Moved make_view_variant into array namespace.
odlomax Sep 6, 2024
95d9977
Refactored ArrayViewVariant methods.
odlomax Sep 11, 2024
c3638d3
More refactoring.
odlomax Sep 12, 2024
cbf7818
Updated test.
odlomax Sep 13, 2024
b6ff4af
Attempting to address gnu 7.3 compiler errors.
odlomax Sep 13, 2024
6db4f83
Typos in comments.
odlomax Sep 13, 2024
73d990f
Merge branch 'develop' into feature/array_view_variant
odlomax Sep 16, 2024
f4d3697
Added missing EXPECTs in test.
odlomax Sep 16, 2024
b608818
Merge branch 'develop' into feature/array_view_variant
odlomax Sep 16, 2024
ed15c84
Refactored detial::VariantHelper template.
odlomax Sep 16, 2024
529d362
Merge branch 'develop' into feature/array_view_variant
wdeconinck Sep 19, 2024
aeb352f
Merged in ArrayViewVariant refactor.
odlomax Sep 25, 2024
3590f0c
Merge branch 'develop' into feature/array_view_variant
odlomax Sep 25, 2024
d7743ee
Refactored introspection helpers.
odlomax Oct 2, 2024
61db7f8
Refactor helper function signatures. Removed SFINAE test.
odlomax Oct 2, 2024
df696e6
Cleaned up some garbage in test.
odlomax Oct 2, 2024
229b05b
Removed reference qualifier on visitor template parameter.
odlomax Oct 3, 2024
1d645ff
Merge branch 'develop' into feature/array_view_variant
odlomax Oct 7, 2024
7b9bdc9
Moved ValuesTypes and Ranks structs into array::detail namespace.
odlomax Oct 7, 2024
152e8cf
Tidied up naming consistency.
odlomax Oct 7, 2024
b901194
Renamed ValueType and Ranks structs.
odlomax Oct 7, 2024
2ea7833
Revert parameter names in ArrayViewVariant.h
odlomax Oct 7, 2024
2e1ac8e
Revert parameter names in ArrayViewVariant.h
odlomax Oct 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/atlas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,8 @@ array/Range.h
array/Vector.h
array/Vector.cc
array/SVector.h
array/ArrayViewVariant.h
array/ArrayViewVariant.cc
array/helpers/ArrayInitializer.h
array/helpers/ArrayAssigner.h
array/helpers/ArrayWriter.h
Expand Down
1 change: 1 addition & 0 deletions src/atlas/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "atlas/array/ArraySpec.h"
#include "atlas/array/ArrayStrides.h"
#include "atlas/array/ArrayView.h"
#include "atlas/array/ArrayViewVariant.h"
#include "atlas/array/DataType.h"
#include "atlas/array/LocalView.h"
#include "atlas/array/MakeView.h"
Expand Down
110 changes: 110 additions & 0 deletions src/atlas/array/ArrayViewVariant.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* (C) Crown Copyright 2024 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#include "atlas/array/ArrayViewVariant.h"

#include <string>
#include <type_traits>

#include "atlas/runtime/Exception.h"

namespace atlas {
namespace array {

using namespace detail;

namespace {

template <bool IsConst>
struct VariantTypeHelper {
using type = ArrayViewVariant;
};

template <>
struct VariantTypeHelper<true> {
using type = ConstArrayViewVariant;
};

template <typename ArrayType>
using VariantType =
typename VariantTypeHelper<std::is_const_v<ArrayType>>::type;

// Match array.rank() and array.datatype() to variant types. Return result of
// makeView on a successful pattern match.
template <size_t TypeIndex = 0, typename ArrayType, typename MakeView>
VariantType<ArrayType> executeMakeView(ArrayType& array,
const MakeView& makeView) {
using View = std::variant_alternative_t<TypeIndex, VariantType<ArrayType>>;
using Value = typename View::non_const_value_type;
constexpr auto Rank = View::rank();

if (array.datatype() == DataType::kind<Value>() && array.rank() == Rank) {
return makeView(array, Value{}, std::integral_constant<int, Rank>{});
}

if constexpr (TypeIndex < std::variant_size_v<VariantType<ArrayType>> - 1) {
return executeMakeView<TypeIndex + 1>(array, makeView);
} else {
throw_Exception("ArrayView<" + array.datatype().str() + ", " +
std::to_string(array.rank()) +
"> is not an alternative in ArrayViewVariant.",
Here());
}
}

template <typename ArrayType>
VariantType<ArrayType> makeViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_view<decltype(value), decltype(rank)::value>(array);
};
return executeMakeView<>(array, makeView);
}

template <typename ArrayType>
VariantType<ArrayType> makeHostViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_host_view<decltype(value), decltype(rank)::value>(array);
};
return executeMakeView<>(array, makeView);
}

template <typename ArrayType>
VariantType<ArrayType> makeDeviceViewVariantImpl(ArrayType& array) {
const auto makeView = [](auto& array, auto value, auto rank) {
return make_device_view<decltype(value), decltype(rank)::value>(array);
};
return executeMakeView<>(array, makeView);
}

} // namespace

ArrayViewVariant make_view_variant(Array& array) {
return makeViewVariantImpl(array);
}

ConstArrayViewVariant make_view_variant(const Array& array) {
return makeViewVariantImpl(array);
}

ArrayViewVariant make_host_view_variant(Array& array) {
return makeHostViewVariantImpl(array);
}

ConstArrayViewVariant make_host_view_variant(const Array& array) {
return makeHostViewVariantImpl(array);
}

ArrayViewVariant make_device_view_variant(Array& array) {
return makeDeviceViewVariantImpl(array);
}

ConstArrayViewVariant make_device_view_variant(const Array& array) {
return makeDeviceViewVariantImpl(array);
}

} // namespace array
} // namespace atlas
102 changes: 102 additions & 0 deletions src/atlas/array/ArrayViewVariant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* (C) Crown Copyright 2024 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#pragma once

#include <variant>

#include "atlas/array.h"

namespace atlas {
namespace array {

namespace detail {

using namespace array;

// Container struct for a list of types.
template <typename... Ts>
struct Types {
using add_const = Types<std::add_const_t<Ts>...>;
};

// Container struct for a list of integers.
template <int... Is>
struct Ints {};

template <typename Values, typename Ranks, typename... ArrayViews>
struct VariantHelper;

// Recursively construct ArrayView std::variant from types Ts and ranks Is.
template <typename T, typename... Ts, int... Is, typename... ArrayViews>
struct VariantHelper<Types<T, Ts...>, Ints<Is...>, ArrayViews...> {
using type = typename VariantHelper<Types<Ts...>, Ints<Is...>, ArrayViews...,
ArrayView<T, Is>...>::type;
};

// End recursion.
template <int... Is, typename... ArrayViews>
struct VariantHelper<Types<>, Ints<Is...>, ArrayViews...> {
using type = std::variant<ArrayViews...>;
};

template <typename Values, typename Ranks>
using Variant = typename VariantHelper<Values, Ranks>::type;

} // namespace detail

/// @brief Supported ArrayView value types.
using ValueTypes = detail::Types<float, double, int, long, unsigned long>;

/// @brief Supported ArrayView ranks.
using Ranks = detail::Ints<1, 2, 3, 4, 5, 6, 7, 8, 9>;
odlomax marked this conversation as resolved.
Show resolved Hide resolved

/// @brief Variant containing all supported non-const ArrayView alternatives.
using ArrayViewVariant = detail::Variant<ValueTypes, Ranks>;

/// @brief Variant containing all supported const ArrayView alternatives.
using ConstArrayViewVariant = detail::Variant<ValueTypes::add_const, Ranks>;

odlomax marked this conversation as resolved.
Show resolved Hide resolved
/// @brief Create an ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_view_variant(Array& array);

/// @brief Create a const ArrayView and assign to an ArrayViewVariant.
ConstArrayViewVariant make_view_variant(const Array& array);

/// @brief Create a host ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_host_view_variant(Array& array);

/// @brief Create a const host ArrayView and assign to an ArrayViewVariant.
ConstArrayViewVariant make_host_view_variant(const Array& array);

/// @brief Create a device ArrayView and assign to an ArrayViewVariant.
ArrayViewVariant make_device_view_variant(Array& array);

/// @brief Create a const device ArrayView and assign to an ArrayViewVariant.
ConstArrayViewVariant make_device_view_variant(const Array& array);

/// @brief Return true if View::rank() is any of Ranks...
template <int... Ranks, typename View>
constexpr bool is_rank(const View&) {
return ((std::decay_t<View>::rank() == Ranks) || ...);
}
/// @brief Return true if View::value_type is any of ValuesTypes...
template <typename... ValueTypes, typename View>
constexpr bool is_value_type(const View&) {
using ValueType = typename std::decay_t<View>::value_type;
return ((std::is_same_v<ValueType, ValueTypes>) || ...);
}

/// @brief Return true if View::non_const_value_type is any of ValuesTypes...
template <typename... ValueTypes, typename View>
constexpr bool is_non_const_value_type(const View&) {
using ValueType = typename std::decay_t<View>::non_const_value_type;
return ((std::is_same_v<ValueType, ValueTypes>) || ...);
}

} // namespace array
} // namespace atlas
5 changes: 5 additions & 0 deletions src/tests/array/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,8 @@ atlas_add_hic_test(
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)

ecbuild_add_test( TARGET atlas_test_array_view_variant
SOURCES test_array_view_variant.cc
LIBS atlas
ENVIRONMENT ${ATLAS_TEST_ENVIRONMENT}
)
131 changes: 131 additions & 0 deletions src/tests/array/test_array_view_variant.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* (C) Crown Copyright 2024 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#include <type_traits>
#include <variant>

#include "atlas/array.h"
#include "atlas/array/ArrayViewVariant.h"
#include "eckit/utils/Overloaded.h"
#include "tests/AtlasTestEnvironment.h"

namespace atlas {
namespace test {

using namespace array;

CASE("test variant assignment") {
auto array1 = array::ArrayT<float>(2);
auto array2 = array::ArrayT<double>(2, 3);
auto array3 = array::ArrayT<int>(2, 3, 4);
const auto& arrayRef = array1;

array1.allocateDevice();
array2.allocateDevice();
array3.allocateDevice();

auto view1 = make_view_variant(array1);
auto view2 = make_view_variant(array2);
auto view3 = make_view_variant(array3);
auto view4 = make_view_variant(arrayRef);

const auto hostView1 = make_host_view_variant(array1);
const auto hostView2 = make_host_view_variant(array2);
const auto hostView3 = make_host_view_variant(array3);
const auto hostView4 = make_host_view_variant(arrayRef);

auto deviceView1 = make_device_view_variant(array1);
auto deviceView2 = make_device_view_variant(array2);
auto deviceView3 = make_device_view_variant(array3);
auto deviceView4 = make_device_view_variant(arrayRef);

const auto visitVariants = [](auto& var1, auto& var2, auto var3, auto var4) {
std::visit(
[](auto view) {
EXPECT((is_rank<1>(view)));
EXPECT((is_value_type<float>(view)));
EXPECT((is_non_const_value_type<float>(view)));
},
var1);

std::visit(
[](auto view) {
EXPECT((is_rank<2>(view)));
EXPECT((is_value_type<double>(view)));
EXPECT((is_non_const_value_type<double>(view)));
},
var2);

std::visit(
[](auto view) {
EXPECT((is_rank<3>(view)));
EXPECT((is_value_type<int>(view)));
EXPECT((is_non_const_value_type<int>(view)));
},
var3);

std::visit(
[](auto view) {
EXPECT((is_rank<1>(view)));
EXPECT((is_value_type<const float>(view)));
EXPECT((is_non_const_value_type<float>(view)));
},
var4);
};

visitVariants(view1, view2, view3, view4);
visitVariants(hostView1, hostView2, hostView3, hostView4);
visitVariants(deviceView1, deviceView2, deviceView3, deviceView4);
}

CASE("test std::visit") {
auto array1 = ArrayT<int>(10);
make_view<int, 1>(array1).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

auto array2 = ArrayT<int>(5, 2);
make_view<int, 2>(array2).assign({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

const auto var1 = make_view_variant(array1);
const auto var2 = make_view_variant(array2);
auto rank1Tested = false;
auto rank2Tested = false;

const auto visitor = [&](auto view) {
if constexpr (is_rank<1>(view)) {
EXPECT((is_value_type<int>(view)));
auto testValue = int{0};
for (auto i = size_t{0}; i < view.size(); ++i) {
const auto value = view(i);
EXPECT_EQ(value, static_cast<decltype(value)>(testValue++));
}
rank1Tested = true;
} else if constexpr (is_rank<2>(view)) {
EXPECT((is_value_type<int>(view)));
auto testValue = int{0};
for (auto i = idx_t{0}; i < view.shape(0); ++i) {
for (auto j = idx_t{0}; j < view.shape(1); ++j) {
const auto value = view(i, j);
EXPECT_EQ(value, static_cast<decltype(value)>(testValue++));
}
}
rank2Tested = true;
} else {
// Test should not reach here.
EXPECT(false);
}
};

std::visit(visitor, var1);
EXPECT(rank1Tested);
std::visit(visitor, var2);
EXPECT(rank2Tested);
}

} // namespace test
} // namespace atlas

int main(int argc, char** argv) { return atlas::test::run(argc, argv); }
Loading