From 239dc9958ebc1de33dec796f957196231bb4eed6 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Mon, 1 Feb 2021 12:31:22 +0000 Subject: [PATCH] add `__cuda_array_interface__` and tests - fixes #4 --- .github/workflows/test.yml | 7 +++++++ cuvec/helpers.py | 6 ++++++ tests/test_helpers.py | 14 +++++++++++++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 27155e5..5f7d953 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,6 +52,13 @@ jobs: - name: Run setup-python run: setup-python -p${{ matrix.python }} - run: pip install -U -e .[dev] + - name: pip install cupy + run: | + CUVER=$(ls /usr/local/cuda-* -d | sed -r 's/\/usr\/local\/cuda-([0-9]+)\.([0-9]+)/\1\2/' | sort -nr | head -n1) + echo CUDA Tookit: $CUVER + [[ $CUVER -gt 111 ]] && CUVER=111 + echo Installing: cupy-cuda$CUVER + pip install cupy-cuda$CUVER - run: pytest - run: codecov - name: Post Run setup-python diff --git a/cuvec/helpers.py b/cuvec/helpers.py index 0c8a526..96e7f86 100644 --- a/cuvec/helpers.py +++ b/cuvec/helpers.py @@ -49,6 +49,12 @@ def __new__(cls, arr): (do not do `cuvec.CuVec((42, 1337))`; instead use `cuvec.zeros((42, 137))`""")) + @property + def __cuda_array_interface__(self): + res = self.__array_interface__ + return { + 'shape': res['shape'], 'typestr': res['typestr'], 'data': res['data'], 'version': 3} + def zeros(shape, dtype="float32"): """ diff --git a/tests/test_helpers.py b/tests/test_helpers.py index ae6f399..b2dea58 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,7 @@ import logging import numpy as np -from pytest import mark, raises +from pytest import importorskip, mark, raises import cuvec as cu @@ -81,3 +81,15 @@ def test_asarray(): assert s.cuvec != v.cuvec assert (s == v[1:]).all() assert np.asarray(s.cuvec).data != np.asarray(v.cuvec).data + + +def test_cuda_array_interface(): + cupy = importorskip("cupy") + v = cu.asarray(np.random.random(shape)) + c = cupy.asarray(v) + + assert (c == v).all() + c[0, 0, 0] = 1 + assert c[0, 0, 0] == v[0, 0, 0] + c[0, 0, 0] = 0 + assert c[0, 0, 0] == v[0, 0, 0]