diff --git a/python/rmm/_lib/device_buffer.pxd b/python/rmm/_lib/device_buffer.pxd index 3d5f29f9a..b48df21e7 100644 --- a/python/rmm/_lib/device_buffer.pxd +++ b/python/rmm/_lib/device_buffer.pxd @@ -17,17 +17,31 @@ from libcpp.memory cimport unique_ptr from rmm._cuda.stream cimport Stream from rmm._lib.cuda_stream_view cimport cuda_stream_view -from rmm._lib.memory_resource cimport DeviceMemoryResource +from rmm._lib.memory_resource cimport ( + DeviceMemoryResource, + device_memory_resource, +) cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil: cdef cppclass device_buffer: device_buffer() - device_buffer(size_t size, cuda_stream_view stream) except + - device_buffer(const void* source_data, - size_t size, cuda_stream_view stream) except + - device_buffer(const device_buffer buf, - cuda_stream_view stream) except + + device_buffer( + size_t size, + cuda_stream_view stream, + device_memory_resource * + ) except + + device_buffer( + const void* source_data, + size_t size, + cuda_stream_view stream, + device_memory_resource * + ) except + + device_buffer( + const device_buffer buf, + cuda_stream_view stream, + device_memory_resource * + ) except + void reserve(size_t new_capacity, cuda_stream_view stream) except + void resize(size_t new_size, cuda_stream_view stream) except + void shrink_to_fit(cuda_stream_view stream) except + diff --git a/python/rmm/_lib/device_buffer.pyx b/python/rmm/_lib/device_buffer.pyx index d248d01ab..3ce10c5f6 100644 --- a/python/rmm/_lib/device_buffer.pyx +++ b/python/rmm/_lib/device_buffer.pyx @@ -32,7 +32,10 @@ from cuda.ccudart cimport ( cudaStream_t, ) -from rmm._lib.memory_resource cimport get_current_device_resource +from rmm._lib.memory_resource cimport ( + device_memory_resource, + get_current_device_resource, +) # The DeviceMemoryResource attribute could be released prematurely @@ -75,24 +78,23 @@ cdef class DeviceBuffer: >>> db = rmm.DeviceBuffer(size=5) """ cdef const void* c_ptr + cdef device_memory_resource * mr_ptr + # Save a reference to the MR and stream used for allocation + self.mr = get_current_device_resource() + self.stream = stream + mr_ptr = self.mr.get_mr() with nogil: c_ptr = ptr - if size == 0: - self.c_obj.reset(new device_buffer()) - elif c_ptr == NULL: - self.c_obj.reset(new device_buffer(size, stream.view())) + if c_ptr == NULL or size == 0: + self.c_obj.reset(new device_buffer(size, stream.view(), mr_ptr)) else: - self.c_obj.reset(new device_buffer(c_ptr, size, stream.view())) + self.c_obj.reset(new device_buffer(c_ptr, size, stream.view(), mr_ptr)) if stream.c_is_default(): stream.c_synchronize() - # Save a reference to the MR and stream used for allocation - self.mr = get_current_device_resource() - self.stream = stream - def __len__(self): return self.size diff --git a/python/rmm/_lib/memory_resource.pxd b/python/rmm/_lib/memory_resource.pxd index 0770fb8ed..f9c2e91de 100644 --- a/python/rmm/_lib/memory_resource.pxd +++ b/python/rmm/_lib/memory_resource.pxd @@ -34,7 +34,7 @@ cdef extern from "rmm/mr/device/device_memory_resource.hpp" \ cdef class DeviceMemoryResource: cdef shared_ptr[device_memory_resource] c_obj - cdef device_memory_resource* get_mr(self) + cdef device_memory_resource* get_mr(self) noexcept nogil cdef class UpstreamResourceAdaptor(DeviceMemoryResource): cdef readonly DeviceMemoryResource upstream_mr diff --git a/python/rmm/_lib/memory_resource.pyx b/python/rmm/_lib/memory_resource.pyx index 7458ca025..100d18b56 100644 --- a/python/rmm/_lib/memory_resource.pyx +++ b/python/rmm/_lib/memory_resource.pyx @@ -218,7 +218,7 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \ cdef class DeviceMemoryResource: - cdef device_memory_resource* get_mr(self): + cdef device_memory_resource* get_mr(self) noexcept nogil: """Get the underlying C++ memory resource object.""" return self.c_obj.get()