Skip to content

Commit

Permalink
Add assertion variant
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Oct 25, 2024
1 parent 8045f26 commit f665e9d
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ tensor_bridge/tensor_bridge.cpp
*.so
*.o
.python-version
tensor_bridge.egg-info
1 change: 1 addition & 0 deletions dev.requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ isort==5.13.2
docformatter==1.7.5
torch==2.5.0
jax==0.4.25
numpy==1.26.4
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ warn_redundant_casts = True
warn_unused_ignores = True
warn_return_any = True
warn_unused_configs = True
plugins = numpy.typing.mypy_plugin

[mypy-torch.*]
ignore_missing_imports = True
Expand Down
13 changes: 12 additions & 1 deletion tensor_bridge/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
import numpy as np

from .tensor_bridge import copy_tensor
from .types import Array
from .utils import get_numpy_data

__all__ = ["copy_tensor", "copy_tensor_with_assertion"]


__all__ = ["copy_tensor"]
def copy_tensor_with_assertion(src: Array, dst: Array) -> None:
copy_tensor(src, dst)
assert np.all(
get_numpy_data(src) == get_numpy_data(dst)
), "Copied tensor doesn't match the source tensor. Layout of tensors can be different."
7 changes: 1 addition & 6 deletions tensor_bridge/tensor_bridge.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
from typing import Union

import jax
import torch

Array = Union[torch.Tensor, jax.Array]
from .types import Array

def copy_tensor(src: Array, dst: Array) -> None: ...
11 changes: 11 additions & 0 deletions tensor_bridge/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import Any, Union

import jax
import numpy as np
import torch

__all__ = ["NumpyArray", "Array"]


NumpyArray = np.ndarray[Any, Any]
Array = Union[torch.Tensor, jax.Array]
16 changes: 16 additions & 0 deletions tensor_bridge/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import jax
import numpy as np
import torch

from .types import Array, NumpyArray

__all__ = ["get_numpy_data"]


def get_numpy_data(tensor: Array) -> NumpyArray:
if isinstance(tensor, torch.Tensor):
return tensor.cpu().detach().numpy() # type: ignore
elif isinstance(tensor, jax.Array):
return np.array(tensor)
else:
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
15 changes: 15 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import jax
import numpy as np
import torch

from tensor_bridge.utils import get_numpy_data


def test_get_numpy_data_with_torch() -> None:
tensor = torch.rand(2, 3, 4)
assert np.all(tensor.numpy() == get_numpy_data(tensor))


def test_get_numpy_data_with_jax() -> None:
tensor = jax.random.uniform(jax.random.key(123), shape=(2, 3, 4))
assert np.all(np.array(tensor) == get_numpy_data(tensor))

0 comments on commit f665e9d

Please sign in to comment.