From b6ed9609ac1dc858c36d0c196d7a4c9abf54d437 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Wed, 13 Mar 2024 14:29:50 +0000 Subject: [PATCH] expose NDCuVec.strides() --- cuvec/include/cuvec.cuh | 7 +++++++ cuvec/include/cuvec_pybind11.cuh | 6 +----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cuvec/include/cuvec.cuh b/cuvec/include/cuvec.cuh index 721f82e..9b69804 100644 --- a/cuvec/include/cuvec.cuh +++ b/cuvec/include/cuvec.cuh @@ -112,6 +112,13 @@ template struct NDCuVec { if (size != vec.size()) throw std::length_error("reshape: size mismatch"); this->shape = shape; } + std::vector strides() const { + const size_t ndim = this->shape.size(); + std::vector s(ndim); + s[ndim - 1] = sizeof(T); + for (int i = ndim - 2; i >= 0; i--) s[i] = this->shape[i + 1] * s[i + 1]; + return s; + } }; #endif // _CUVEC_H_ diff --git a/cuvec/include/cuvec_pybind11.cuh b/cuvec/include/cuvec_pybind11.cuh index 9aedf47..999be18 100644 --- a/cuvec/include/cuvec_pybind11.cuh +++ b/cuvec/include/cuvec_pybind11.cuh @@ -32,13 +32,9 @@ PYBIND11_MAKE_OPAQUE(NDCuVec); pybind11::class_>(m, PYBIND11_TOSTRING(NDCuVec_##typechar), \ pybind11::buffer_protocol()) \ .def_buffer([](NDCuVec &v) -> pybind11::buffer_info { \ - size_t ndim = v.shape.size(); \ - std::vector strides(ndim); \ - strides[ndim - 1] = sizeof(T); \ - for (int i = ndim - 2; i >= 0; i--) strides[i] = v.shape[i + 1] * strides[i + 1]; \ return pybind11::buffer_info(v.vec.data(), sizeof(T), \ pybind11::format_descriptor::format(), v.shape.size(), \ - v.shape, strides); \ + v.shape, v.strides()); \ }) \ .def(pybind11::init<>()) \ .def(pybind11::init>()) \