Skip to content

Commit

Permalink
Update core header files to support C++17 as well (NVIDIA#27)
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Howe <[email protected]>
  • Loading branch information
bmhowe23 authored Dec 12, 2024
1 parent 6ea31be commit 0c7d96d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
1 change: 1 addition & 0 deletions libs/core/include/cuda-qx/core/extension_point.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <functional>
#include <memory>
#include <stdexcept>
#include <unordered_map>

namespace cudaqx {
Expand Down
17 changes: 11 additions & 6 deletions libs/core/include/cuda-qx/core/heterogeneous_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,16 @@ class heterogeneous_map {
// we have a value of type int, but request here is std::size_t.
// Handle that case, by getting T's map of related types, and checking
// if any of them are valid.
using RelatedTypes =
typename RelatedTypesMap<std::remove_cvref_t<T>>::types;
using RelatedTypes = typename RelatedTypesMap<
std::remove_cv_t<std::remove_reference_t<T>>>::types;
std::optional<T> opt;
cudaqx::tuple_for_each(RelatedTypes(), [&](auto &&el) {
if (!opt.has_value() &&
isCastable<std::remove_cvref_t<decltype(el)>>(iter->second))
opt = std::any_cast<std::remove_cvref_t<decltype(el)>>(iter->second);
isCastable<std::remove_cv_t<std::remove_reference_t<decltype(el)>>>(
iter->second))
opt = std::any_cast<
std::remove_cv_t<std::remove_reference_t<decltype(el)>>>(
iter->second);
});

if (opt.has_value())
Expand Down Expand Up @@ -185,10 +188,12 @@ class heterogeneous_map {
/// @brief Check if the map contains a key
/// @param key The key to check
/// @return true if the key exists, false otherwise
bool contains(const std::string &key) const { return items.contains(key); }
bool contains(const std::string &key) const {
return items.find(key) != items.end();
}
bool contains(const std::vector<std::string> &keys) const {
for (auto &key : keys)
if (items.contains(key))
if (items.find(key) != items.end())
return true;

return false;
Expand Down
8 changes: 5 additions & 3 deletions libs/core/include/cuda-qx/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

namespace cudaqx {

/// @brief A tensor class implementing the PIMPL idiom.
/// @brief A tensor class implementing the PIMPL idiom. The flattened data is
/// stored where the strides grow from right to left (similar to a
/// multi-dimensional C array).
template <typename Scalar = std::complex<double>>
class tensor {
private:
Expand All @@ -35,7 +37,7 @@ class tensor {

public:
/// @brief Type alias for the scalar type used in the tensor
using scalar_type = details::tensor_impl<Scalar>::scalar_type;
using scalar_type = typename details::tensor_impl<Scalar>::scalar_type;
static constexpr auto ScalarAsString = type_to_string<Scalar>();

/// @brief Construct an empty tensor
Expand All @@ -54,7 +56,7 @@ class tensor {
.release())) {}

/// @brief Construct a tensor with the given data and shape
/// @param data Pointer to the tensor data
/// @param data Pointer to the tensor data. This takes ownership of the data.
/// @param shape The shape of the tensor
tensor(const scalar_type *data, const std::vector<std::size_t> &shape)
: pimpl(std::shared_ptr<details::tensor_impl<Scalar>>(
Expand Down

0 comments on commit 0c7d96d

Please sign in to comment.