diff --git a/.gitignore b/.gitignore index 4b5f7a4..b29a67f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ tensor_bridge/tensor_bridge.cpp *.so *.o .python-version +tensor_bridge.egg-info diff --git a/dev.requirements.txt b/dev.requirements.txt index 64ccdda..ff50437 100644 --- a/dev.requirements.txt +++ b/dev.requirements.txt @@ -4,3 +4,4 @@ isort==5.13.2 docformatter==1.7.5 torch==2.5.0 jax==0.4.25 +numpy==1.26.4 diff --git a/mypy.ini b/mypy.ini index aa6ab49..4a381d4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -10,6 +10,7 @@ warn_redundant_casts = True warn_unused_ignores = True warn_return_any = True warn_unused_configs = True +plugins = numpy.typing.mypy_plugin [mypy-torch.*] ignore_missing_imports = True diff --git a/tensor_bridge/__init__.py b/tensor_bridge/__init__.py index 23eb941..fca471c 100644 --- a/tensor_bridge/__init__.py +++ b/tensor_bridge/__init__.py @@ -1,3 +1,14 @@ +import numpy as np + from .tensor_bridge import copy_tensor +from .types import Array +from .utils import get_numpy_data + +__all__ = ["copy_tensor", "copy_tensor_with_assertion"] + -__all__ = ["copy_tensor"] +def copy_tensor_with_assertion(src: Array, dst: Array) -> None: + copy_tensor(src, dst) + assert np.all( + get_numpy_data(src) == get_numpy_data(dst) + ), "Copied tensor doesn't match the source tensor. Layout of tensors can be different." diff --git a/tensor_bridge/tensor_bridge.pyi b/tensor_bridge/tensor_bridge.pyi index d0f82dc..b445abc 100644 --- a/tensor_bridge/tensor_bridge.pyi +++ b/tensor_bridge/tensor_bridge.pyi @@ -1,8 +1,3 @@ -from typing import Union - -import jax -import torch - -Array = Union[torch.Tensor, jax.Array] +from .types import Array def copy_tensor(src: Array, dst: Array) -> None: ... diff --git a/tensor_bridge/types.py b/tensor_bridge/types.py new file mode 100644 index 0000000..08fa370 --- /dev/null +++ b/tensor_bridge/types.py @@ -0,0 +1,11 @@ +from typing import Any, Union + +import jax +import numpy as np +import torch + +__all__ = ["NumpyArray", "Array"] + + +NumpyArray = np.ndarray[Any, Any] +Array = Union[torch.Tensor, jax.Array] diff --git a/tensor_bridge/utils.py b/tensor_bridge/utils.py new file mode 100644 index 0000000..cda2cdc --- /dev/null +++ b/tensor_bridge/utils.py @@ -0,0 +1,16 @@ +import jax +import numpy as np +import torch + +from .types import Array, NumpyArray + +__all__ = ["get_numpy_data"] + + +def get_numpy_data(tensor: Array) -> NumpyArray: + if isinstance(tensor, torch.Tensor): + return tensor.cpu().detach().numpy() # type: ignore + elif isinstance(tensor, jax.Array): + return np.array(tensor) + else: + raise ValueError(f"Unsupported tensor type: {type(tensor)}") diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..7459819 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,15 @@ +import jax +import numpy as np +import torch + +from tensor_bridge.utils import get_numpy_data + + +def test_get_numpy_data_with_torch() -> None: + tensor = torch.rand(2, 3, 4) + assert np.all(tensor.numpy() == get_numpy_data(tensor)) + + +def test_get_numpy_data_with_jax() -> None: + tensor = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4)) + assert np.all(np.array(tensor) == get_numpy_data(tensor))