Skip to content

Commit

Permalink
add __cuda_array_interface__ and tests
Browse files Browse the repository at this point in the history
- fixes #4
  • Loading branch information
casperdcl committed Feb 1, 2021
1 parent 2e0e86f commit 239dc99
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
7 changes: 7 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ jobs:
- name: Run setup-python
run: setup-python -p${{ matrix.python }}
- run: pip install -U -e .[dev]
- name: pip install cupy
run: |
CUVER=$(ls /usr/local/cuda-* -d | sed -r 's/\/usr\/local\/cuda-([0-9]+)\.([0-9]+)/\1\2/' | sort -nr | head -n1)
echo CUDA Tookit: $CUVER
[[ $CUVER -gt 111 ]] && CUVER=111
echo Installing: cupy-cuda$CUVER
pip install cupy-cuda$CUVER
- run: pytest
- run: codecov
- name: Post Run setup-python
Expand Down
6 changes: 6 additions & 0 deletions cuvec/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def __new__(cls, arr):
(do not do `cuvec.CuVec((42, 1337))`;
instead use `cuvec.zeros((42, 137))`"""))

@property
def __cuda_array_interface__(self):
res = self.__array_interface__
return {
'shape': res['shape'], 'typestr': res['typestr'], 'data': res['data'], 'version': 3}


def zeros(shape, dtype="float32"):
"""
Expand Down
14 changes: 13 additions & 1 deletion tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

import numpy as np
from pytest import mark, raises
from pytest import importorskip, mark, raises

import cuvec as cu

Expand Down Expand Up @@ -81,3 +81,15 @@ def test_asarray():
assert s.cuvec != v.cuvec
assert (s == v[1:]).all()
assert np.asarray(s.cuvec).data != np.asarray(v.cuvec).data


def test_cuda_array_interface():
cupy = importorskip("cupy")
v = cu.asarray(np.random.random(shape))
c = cupy.asarray(v)

assert (c == v).all()
c[0, 0, 0] = 1
assert c[0, 0, 0] == v[0, 0, 0]
c[0, 0, 0] = 0
assert c[0, 0, 0] == v[0, 0, 0]

0 comments on commit 239dc99

Please sign in to comment.