diff --git a/include/cppflow/span.h b/include/cppflow/span.h new file mode 100644 index 0000000..4248966 --- /dev/null +++ b/include/cppflow/span.h @@ -0,0 +1,570 @@ +// + +/* + * This is an implementation of C++20's std::span + * http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/n4820.pdf + */ + +// Copyright Tristan Brindle 2018. +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file ../../LICENSE_1_0.txt or copy at +// https://www.boost.org/LICENSE_1_0.txt) + +#pragma once + +#include +#include +#include +#include + +#ifndef TCB_SPAN_NO_EXCEPTIONS +// Attempt to discover whether we're being compiled with exception support +#if !(defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) +#define TCB_SPAN_NO_EXCEPTIONS +#endif +#endif + +#ifndef TCB_SPAN_NO_EXCEPTIONS +#include +#include +#endif + +// Various feature test macros + +#ifndef TCB_SPAN_NAMESPACE_NAME +#define TCB_SPAN_NAMESPACE_NAME cppflow +#endif + +#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) +#define TCB_SPAN_HAVE_CPP17 +#endif + +#if __cplusplus >= 201402L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201402L) +#define TCB_SPAN_HAVE_CPP14 +#endif + +namespace TCB_SPAN_NAMESPACE_NAME { + +// Establish default contract checking behavior +#if !defined(TCB_SPAN_THROW_ON_CONTRACT_VIOLATION) && \ + !defined(TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION) && \ + !defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#if defined(NDEBUG) || !defined(TCB_SPAN_HAVE_CPP14) +#define TCB_SPAN_NO_CONTRACT_CHECKING +#else +#define TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION +#endif +#endif + +#if defined(TCB_SPAN_THROW_ON_CONTRACT_VIOLATION) + struct contract_violation_error : std::logic_error { + explicit contract_violation_error(const char* msg) : std::logic_error(msg) + {} + }; + + inline void contract_violation(const char* msg) + { + throw contract_violation_error(msg); + } + +#elif defined(TCB_SPAN_TERMINATE_ON_CONTRACT_VIOLATION) + [[noreturn]] inline void contract_violation(const char* /*unused*/) + { + std::terminate(); + } +#endif + +#if !defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#define TCB_SPAN_STRINGIFY(cond) #cond +#define TCB_SPAN_EXPECT(cond) \ + cond ? (void) 0 : contract_violation("Expected " TCB_SPAN_STRINGIFY(cond)) +#else +#define TCB_SPAN_EXPECT(cond) +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_inline_variables) +#define TCB_SPAN_INLINE_VAR inline +#else +#define TCB_SPAN_INLINE_VAR +#endif + +#if defined(TCB_SPAN_HAVE_CPP14) || \ + (defined(__cpp_constexpr) && __cpp_constexpr >= 201304) +#define TCB_SPAN_HAVE_CPP14_CONSTEXPR +#endif + +#if defined(TCB_SPAN_HAVE_CPP14_CONSTEXPR) +#define TCB_SPAN_CONSTEXPR14 constexpr +#else +#define TCB_SPAN_CONSTEXPR14 +#endif + +#if defined(TCB_SPAN_HAVE_CPP14_CONSTEXPR) && \ + (!defined(_MSC_VER) || _MSC_VER > 1900) +#define TCB_SPAN_CONSTEXPR_ASSIGN constexpr +#else +#define TCB_SPAN_CONSTEXPR_ASSIGN +#endif + +#if defined(TCB_SPAN_NO_CONTRACT_CHECKING) +#define TCB_SPAN_CONSTEXPR11 constexpr +#else +#define TCB_SPAN_CONSTEXPR11 TCB_SPAN_CONSTEXPR14 +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_deduction_guides) +#define TCB_SPAN_HAVE_DEDUCTION_GUIDES +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_byte) +#define TCB_SPAN_HAVE_STD_BYTE +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_array_constexpr) +#define TCB_SPAN_HAVE_CONSTEXPR_STD_ARRAY_ETC +#endif + +#if defined(TCB_SPAN_HAVE_CONSTEXPR_STD_ARRAY_ETC) +#define TCB_SPAN_ARRAY_CONSTEXPR constexpr +#else +#define TCB_SPAN_ARRAY_CONSTEXPR +#endif + +#ifdef TCB_SPAN_HAVE_STD_BYTE + using byte = std::byte; +#else + using byte = unsigned char; +#endif + +#if defined(TCB_SPAN_HAVE_CPP17) +#define TCB_SPAN_NODISCARD [[nodiscard]] +#else +#define TCB_SPAN_NODISCARD +#endif + +TCB_SPAN_INLINE_VAR constexpr std::size_t dynamic_extent = SIZE_MAX; + +template +class span; + +namespace detail { + +template +struct span_storage { + constexpr span_storage() noexcept = default; + + constexpr span_storage(E* p_ptr, std::size_t /*unused*/) noexcept + : ptr(p_ptr) + {} + + E* ptr = nullptr; + static constexpr std::size_t size = S; +}; + +template +struct span_storage { + constexpr span_storage() noexcept = default; + + constexpr span_storage(E* p_ptr, std::size_t p_size) noexcept + : ptr(p_ptr), size(p_size) + {} + + E* ptr = nullptr; + std::size_t size = 0; +}; + +// Reimplementation of C++17 std::size() and std::data() +#if defined(TCB_SPAN_HAVE_CPP17) || \ +defined(__cpp_lib_nonmember_container_access) +using std::data; +using std::size; +#else +template +constexpr auto size(const C& c) -> decltype(c.size()) { + return c.size(); +} + +template +constexpr std::size_t size(const T (&)[N]) noexcept { + return N; +} + +template +constexpr auto data(C& c) -> decltype(c.data()) { + return c.data(); +} + +template +constexpr auto data(const C& c) -> decltype(c.data()) { + return c.data(); +} + +template +constexpr T* data(T (&array)[N]) noexcept { + return array; +} + +template +constexpr const E* data(std::initializer_list il) noexcept { + return il.begin(); +} +#endif // TCB_SPAN_HAVE_CPP17 + +#if defined(TCB_SPAN_HAVE_CPP17) || defined(__cpp_lib_void_t) +using std::void_t; +#else +template +using void_t = void; +#endif + +template +using uncvref_t = +typename std::remove_cv::type>::type; + +template +struct is_span : std::false_type {}; + +template +struct is_span> : std::true_type {}; + +template +struct is_std_array : std::false_type {}; + +template +struct is_std_array> : std::true_type {}; + +template +struct has_size_and_data : std::false_type {}; + +template +struct has_size_and_data())), +decltype(detail::data(std::declval()))>> + : std::true_type {}; + +template > +struct is_container { + static constexpr bool value = + !is_span::value && !is_std_array::value && + !std::is_array::value && has_size_and_data::value; +}; + +template +using remove_pointer_t = typename std::remove_pointer::type; + +template +struct is_container_element_type_compatible : std::false_type {}; + +template +struct is_container_element_type_compatible< +T, E, +typename std::enable_if< + !std::is_same()))>::type, +void>::value>::type> + : std::is_convertible< + remove_pointer_t()))> (*)[], +E (*)[]> {}; + +template +struct is_complete : std::false_type {}; + +template +struct is_complete : std::true_type {}; + +} // namespace detail + +template +class span { + static_assert(std::is_object::value, + "A span's ElementType must be an object type (not a " + "reference type or void)"); + static_assert(detail::is_complete::value, + "A span's ElementType must be a complete type (not a forward " + "declaration)"); + static_assert(!std::is_abstract::value, + "A span's ElementType cannot be an abstract class type"); + + using storage_type = detail::span_storage; + +public: + // constants and types + using element_type = ElementType; + using value_type = typename std::remove_cv::type; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = element_type*; + using const_pointer = const element_type*; + using reference = element_type&; + using const_reference = const element_type&; + using iterator = pointer; + using reverse_iterator = std::reverse_iterator; + + static constexpr size_type extent = Extent; + + // [span.cons], span constructors, copy, assignment, and destructor + template < + std::size_t E = Extent, + typename std::enable_if<(E == dynamic_extent || E <= 0), int>::type = 0> + constexpr span() noexcept {} + + TCB_SPAN_CONSTEXPR11 span(pointer ptr, size_type count) + : storage_(ptr, count) { + TCB_SPAN_EXPECT(extent == dynamic_extent || count == extent); + } + + TCB_SPAN_CONSTEXPR11 span(pointer first_elem, pointer last_elem) + : storage_(first_elem, last_elem - first_elem) { + TCB_SPAN_EXPECT(extent == dynamic_extent || + last_elem - first_elem == + static_cast(extent)); + } + + template ::value, + int>::type = 0> + constexpr span(element_type (&arr)[N]) noexcept : storage_(arr, N) {} + + template &, ElementType>::value, + int>::type = 0> + TCB_SPAN_ARRAY_CONSTEXPR span(std::array& arr) noexcept + : storage_(arr.data(), N) {} + + template &, ElementType>::value, + int>::type = 0> + TCB_SPAN_ARRAY_CONSTEXPR span(const std::array& arr) noexcept + : storage_(arr.data(), N) {} + + template ::value && + detail::is_container_element_type_compatible< + Container&, ElementType>::value, + int>::type = 0> + constexpr span(Container& cont) + : storage_(detail::data(cont), detail::size(cont)) {} + + template ::value && + detail::is_container_element_type_compatible< + const Container&, ElementType>::value, + int>::type = 0> + constexpr span(const Container& cont) + : storage_(detail::data(cont), detail::size(cont)) {} + + constexpr span(const span& other) noexcept = default; + + template ::value, + int>::type = 0> + constexpr span(const span& other) noexcept + : storage_(other.data(), other.size()) {} + + ~span() noexcept = default; + + TCB_SPAN_CONSTEXPR_ASSIGN span& + operator=(const span& other) noexcept = default; + + // [span.sub], span subviews + template + TCB_SPAN_CONSTEXPR11 span first() const { + TCB_SPAN_EXPECT(Count <= size()); + return {data(), Count}; + } + + template + TCB_SPAN_CONSTEXPR11 span last() const { + TCB_SPAN_EXPECT(Count <= size()); + return {data() + (size() - Count), Count}; + } + + template + using subspan_return_t = + span; + + template + TCB_SPAN_CONSTEXPR11 subspan_return_t subspan() const { + TCB_SPAN_EXPECT(Offset <= size() && + (Count == dynamic_extent || Offset + Count <= size())); + return {data() + Offset, + Count != dynamic_extent ? Count : size() - Offset}; + } + + TCB_SPAN_CONSTEXPR11 span + first(size_type count) const { + TCB_SPAN_EXPECT(count <= size()); + return {data(), count}; + } + + TCB_SPAN_CONSTEXPR11 span + last(size_type count) const { + TCB_SPAN_EXPECT(count <= size()); + return {data() + (size() - count), count}; + } + + TCB_SPAN_CONSTEXPR11 span + subspan(size_type offset, size_type count = dynamic_extent) const { + TCB_SPAN_EXPECT(offset <= size() && + (count == dynamic_extent || offset + count <= size())); + return {data() + offset, + count == dynamic_extent ? size() - offset : count}; + } + + // [span.obs], span observers + constexpr size_type size() const noexcept { return storage_.size; } + + constexpr size_type size_bytes() const noexcept { + return size() * sizeof(element_type); + } + + TCB_SPAN_NODISCARD constexpr bool empty() const noexcept { + return size() == 0; + } + + // [span.elem], span element access + TCB_SPAN_CONSTEXPR11 reference operator[](size_type idx) const { + TCB_SPAN_EXPECT(idx < size()); + return *(data() + idx); + } + + TCB_SPAN_CONSTEXPR11 reference front() const { + TCB_SPAN_EXPECT(!empty()); + return *data(); + } + + TCB_SPAN_CONSTEXPR11 reference back() const { + TCB_SPAN_EXPECT(!empty()); + return *(data() + (size() - 1)); + } + + constexpr pointer data() const noexcept { return storage_.ptr; } + + // [span.iterators], span iterator support + constexpr iterator begin() const noexcept { return data(); } + + constexpr iterator end() const noexcept { return data() + size(); } + + TCB_SPAN_ARRAY_CONSTEXPR reverse_iterator rbegin() const noexcept { + return reverse_iterator(end()); + } + + TCB_SPAN_ARRAY_CONSTEXPR reverse_iterator rend() const noexcept { + return reverse_iterator(begin()); + } + +private: + storage_type storage_{}; +}; + +#ifdef TCB_SPAN_HAVE_DEDUCTION_GUIDES + +/* Deduction Guides */ +template + span(T (&)[N])->span; + +template + span(std::array&)->span; + +template + span(const std::array&)->span; + +template + span(Container&)->span; + +template + span(const Container&)->span; + +#endif // TCB_HAVE_DEDUCTION_GUIDES + +template + constexpr span +make_span(span s) noexcept { + return s; +} + +template +constexpr span make_span(T (&arr)[N]) noexcept { + return {arr}; +} + +template +TCB_SPAN_ARRAY_CONSTEXPR span make_span(std::array& arr) noexcept { + return {arr}; +} + +template + TCB_SPAN_ARRAY_CONSTEXPR span +make_span(const std::array& arr) noexcept { + return {arr}; +} + +template +constexpr span make_span(Container& cont) { + return {cont}; +} + +template + constexpr span +make_span(const Container& cont) { + return {cont}; +} + +template + span +as_bytes(span s) noexcept { + return {reinterpret_cast(s.data()), s.size_bytes()}; +} + +template ::value, int>::type = 0> + span +as_writable_bytes(span s) noexcept { + return {reinterpret_cast(s.data()), s.size_bytes()}; +} + +template +constexpr auto get(span s) -> decltype(s[N]) { + return s[N]; +} + +} // namespace TCB_SPAN_NAMESPACE_NAME + +namespace std { + +template +class tuple_size> +: public integral_constant {}; + +template +class tuple_size>; // not defined + +template +class tuple_element> { +public: + static_assert(Extent != TCB_SPAN_NAMESPACE_NAME::dynamic_extent && + I < Extent, + ""); + using type = ElementType; +}; + +} // end namespace std \ No newline at end of file diff --git a/include/cppflow/tensor.h b/include/cppflow/tensor.h index 62a0bbd..0bb48c4 100644 --- a/include/cppflow/tensor.h +++ b/include/cppflow/tensor.h @@ -14,6 +14,7 @@ #include "context.h" #include "datatype.h" +#include "span.h" namespace cppflow { @@ -70,12 +71,12 @@ namespace cppflow { datatype dtype() const; /** - * Converts the tensor into a C++ vector + * Converts the tensor into a C++ span. There is no data copy on span for perfomance. * @tparam T The c++ type (must be equivalent to the tensor type) - * @return A vector representing the flat tensor + * @return A span representing the flat tensor */ template - std::vector get_data() const; + cppflow::span get_data() const; ~tensor() = default; @@ -193,7 +194,7 @@ namespace cppflow { } template - std::vector tensor::get_data() const { + cppflow::span tensor::get_data() const { auto res_tensor = TFE_TensorHandleResolve(this->tfe_handle.get(), context::get_status()); status_check(context::get_status()); @@ -205,7 +206,7 @@ namespace cppflow { // Convert to correct type const auto T_data = static_cast(raw_data); - return std::vector(T_data, T_data + size); + return cppflow::span(T_data, size); } datatype tensor::dtype() const {