Skip to content

Commit

Permalink
temporarily remove unified functions, fix typing for _backend
Browse files Browse the repository at this point in the history
  • Loading branch information
syurkevi committed Apr 12, 2024
1 parent 93b88b6 commit d6f1261
Showing 1 changed file with 6 additions and 42 deletions.
48 changes: 6 additions & 42 deletions arrayfire_wrapper/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def is_cygwin(cls, name: str) -> bool:
class _BackendPathConfig:
lib_prefix: str
lib_postfix: str
af_path: Path
af_path: Path | None
af_is_user_path: bool
cuda_found: bool

Expand Down Expand Up @@ -175,7 +175,7 @@ def __iter__(self) -> Iterator:


class Backend:
_backend_type: BackendType
_backend_type: BackendType | None
_clibs: dict[BackendType, ctypes.CDLL]

def __init__(self) -> None:
Expand Down Expand Up @@ -297,51 +297,15 @@ def _find_nvrtc_builtins_lib_name(self, search_path: Path) -> str | None:
return f.name
return None

# unified backend functions
def get_active_backend(self) -> str:
if self._backend_type == BackendType.unified:
from arrayfire_wrapper.lib.unified_api_functions import get_active_backend as unified_get_active_backend

return unified_get_active_backend()
raise RuntimeError("Using unified function on non-unified backend")

def get_available_backends(self) -> list[int]:
if self._backend_type == BackendType.unified:
from arrayfire_wrapper.lib.unified_api_functions import (
get_available_backends as unified_get_available_backends,
)

return unified_get_available_backends()
raise RuntimeError("Using unified function on non-unified backend")

def get_backend_count(self) -> int:
if self._backend_type == BackendType.unified:
from arrayfire_wrapper.lib.unified_api_functions import get_backend_count as unified_get_backend_count

return unified_get_backend_count()
raise RuntimeError("Using unified function on non-unified backend")

def get_backend_id(self, arr: AFArray, /) -> int:
if self._backend_type == BackendType.unified:
from arrayfire_wrapper.lib.unified_api_functions import get_backend_id as unified_get_backend_id

return unified_get_backend_id(arr)
raise RuntimeError("Using unified function on non-unified backend")

def get_device_id(self, arr: AFArray, /) -> int:
if self._backend_type == BackendType.unified:
from arrayfire_wrapper.lib.unified_api_functions import get_device_id as unified_get_device_id

return unified_get_device_id(arr)
raise RuntimeError("Using unified function on non-unified backend")

@property
def backend_type(self) -> BackendType:
def backend_type(self) -> BackendType | None:
return self._backend_type

@property
def clib(self) -> ctypes.CDLL:
return self._clibs[self._backend_type]
if self._backend_type:
return self._clibs[self._backend_type]
raise RuntimeError("No valid _backend_type")


# Initialize the backend
Expand Down

0 comments on commit d6f1261

Please sign in to comment.