Skip to content

Commit

Permalink
Use cuda.bindings layout in tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdice committed Dec 17, 2024
1 parent 43b32eb commit 8afcc13
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions python/rmm/rmm/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import warnings
from itertools import product

import cuda.cudart as cudart
import numpy as np
import pytest
from cuda.bindings import runtime
from numba import cuda

import rmm
Expand All @@ -34,7 +34,7 @@
cuda.set_memory_manager(RMMNumbaManager)

# True when the current device can access pageable (system) memory —
# queried once at import time so system-memory tests can be skipped
# on unsupported hardware.
# NOTE(review): the rendered diff showed both the old `cudart.` and new
# `runtime.` attribute lines; only the post-commit `runtime` form is kept.
_SYSTEM_MEMORY_SUPPORTED = rmm._cuda.gpu.getDeviceAttribute(
    runtime.cudaDeviceAttr.cudaDevAttrPageableMemoryAccess,
    rmm._cuda.gpu.getDevice(),
)

Expand Down Expand Up @@ -319,13 +319,13 @@ def test_rmm_device_buffer_pickle_roundtrip(hb):


def assert_prefetched(buffer, device_id):
    """Assert that ``buffer`` was last prefetched to ``device_id``.

    Queries the last-prefetch location attribute for the buffer's memory
    range via the CUDA runtime and checks both that the call succeeded
    and that the reported device matches.
    """
    # 4 = size in bytes of the attribute result (a 32-bit device id).
    err, dev = runtime.cudaMemRangeGetAttribute(
        4,
        runtime.cudaMemRangeAttribute.cudaMemRangeAttributeLastPrefetchLocation,
        buffer.ptr,
        buffer.size,
    )
    assert err == runtime.cudaError_t.cudaSuccess
    assert dev == device_id


# NOTE(review): the original is presumably decorated with
# pytest.mark.parametrize for `pool` and `managed` — hidden above this
# diff hunk; confirm against the full file.
def test_rmm_device_buffer_prefetch(pool, managed):
    """DeviceBuffer.prefetch() moves a managed allocation to the current device."""
    rmm.reinitialize(pool_allocator=pool, managed_memory=managed)
    db = rmm.DeviceBuffer.to_device(np.zeros(256, dtype="u1"))
    if managed:
        # Freshly allocated managed memory has not been prefetched anywhere.
        assert_prefetched(db, runtime.cudaInvalidDeviceId)
    db.prefetch()  # just test that it doesn't throw
    if managed:
        err, device = runtime.cudaGetDevice()
        assert err == runtime.cudaError_t.cudaSuccess
        assert_prefetched(db, device)


# NOTE(review): this hunk starts mid-function — the decorator(s) and the
# setup that installs the prefetch resource adaptor are hidden above it;
# confirm against the full file.
def test_prefetch_resource_adaptor(managed):
    """Allocations made under the prefetch adaptor are prefetched to the device."""
    # This allocation should be prefetched
    db = rmm.DeviceBuffer.to_device(np.zeros(256, dtype="u1"))

    err, device = runtime.cudaGetDevice()
    assert err == runtime.cudaError_t.cudaSuccess

    if managed:
        assert_prefetched(db, device)
    db.prefetch()  # just test that it doesn't throw
    if managed:
        err, device = runtime.cudaGetDevice()
        assert err == runtime.cudaError_t.cudaSuccess
        assert_prefetched(db, device)


Expand Down

0 comments on commit 8afcc13

Please sign in to comment.