Skip to content

Commit

Permalink
Merge pull request SSAGESLabs#336 from SSAGESLabs/pz/compat
Browse files Browse the repository at this point in the history
[compat] Handle deprecation of jax.Array.device_buffer
  • Loading branch information
pabloferz authored Sep 23, 2024
2 parents f6d53b8 + dc8a737 commit 96483d3
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pysages/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numpy.ctypeslib import as_ctypes_type

from pysages.typing import JaxArray
from pysages.utils import dispatch
from pysages.utils import dispatch, unsafe_buffer_pointer


def cupy_helpers():
Expand Down Expand Up @@ -38,7 +38,7 @@ def view(array: JaxArray):
# NOTE: We need a more general strategy to handle
# `SharedDeviceArray`s and `GlobalDeviceArray`s.
ptype = ctypes.POINTER(as_ctypes_type(array.dtype))
addr = array.device_buffer.unsafe_buffer_pointer()
addr = unsafe_buffer_pointer(array)
ptr = ctypes.cast(ctypes.c_void_p(addr), ptype)
return numba.carray(ptr, array.shape)

Expand Down
1 change: 1 addition & 0 deletions pysages/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
prod,
solve_pos_def,
try_import,
unsafe_buffer_pointer,
)
from .core import (
ToCPU,
Expand Down
19 changes: 16 additions & 3 deletions pysages/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,24 @@ def prod(iterable, start=1):
return result


# Compatibility for jax >=0.4.27
# Compatibility for jax >=0.4.22

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0427-may-7-2024
# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0422-dec-13-2023
if _jax_version_tuple < (0, 4, 22):

if _jax_version_tuple < (0, 4, 27):
def unsafe_buffer_pointer(array):
return array.device_buffer.unsafe_buffer_pointer()

else:

def unsafe_buffer_pointer(array):
return array.unsafe_buffer_pointer()


# Compatibility for jax >=0.4.21

# https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0421-dec-4-2023
if _jax_version_tuple < (0, 4, 21):

def device_platform(array):
return array.device().platform
Expand Down

0 comments on commit 96483d3

Please sign in to comment.