Skip to content

Commit

Permalink
Fix windows test (#2001)
Browse files Browse the repository at this point in the history
* lambda -> fn to fix windows test

* update libdeeplake version

* fix

* change version

* fix again
  • Loading branch information
AbhinavTuli authored Nov 11, 2022
1 parent fc10deb commit d736afb
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 54 deletions.
54 changes: 0 additions & 54 deletions deeplake/experimental/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,57 +650,3 @@ def test_pytorch_error_handling(local_ds):
ptds = dataloader(ds).pytorch(tensors=["x"])
for _ in ptds:
pass


def json_collate_fn(batch):
import torch

batch = [it["a"][0]["x"] for it in batch]
return torch.utils.data._utils.collate.default_collate(batch)


def json_transform_fn(sample):
return sample[0]["x"]


def list_collate_fn(batch):
import torch

batch = [np.array([it["a"][0], it["a"][1]]) for it in batch]
return torch.utils.data._utils.collate.default_collate(batch)


def list_transform_fn(sample):
return np.array([sample["a"][0], sample["a"][1]])


def test_pytorch_json(local_ds):
ds = local_ds
with ds:
ds.create_tensor("a", htype="json")
ds.a.append({"x": 1})
ds.a.append({"x": 2})

ptds = ds.pytorch(transform={"a": json_transform_fn}, batch_size=2)
batch = next(iter(ptds))
np.testing.assert_equal(batch["a"], np.array([1, 2]))

ptds = ds.pytorch(collate_fn=json_collate_fn, batch_size=2)
batch = next(iter(ptds))
np.testing.assert_equal(batch, np.array([1, 2]))


def test_pytorch_list(local_ds):
ds = local_ds
with ds:
ds.create_tensor("a", htype="list")
ds.a.append([1, 2])
ds.a.append([3, 4])

ptds = ds.pytorch(transform={"a": lambda x: np.array([x[0], x[1]])}, batch_size=2)
batch = next(iter(ptds))
np.testing.assert_equal(batch["a"], np.array([[1, 2], [3, 4]]))

ptds = ds.pytorch(collate_fn=list_collate_fn, batch_size=2)
batch = next(iter(ptds))
np.testing.assert_equal(batch, np.array([[1, 2], [3, 4]]))
54 changes: 54 additions & 0 deletions deeplake/integrations/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,57 @@ def test_uneven_iteration(local_ds):
np.testing.assert_equal(x, i)
target_y = i if i < 5 else []
np.testing.assert_equal(y, target_y)


def json_collate_fn(batch):
import torch

batch = [it["a"][0]["x"] for it in batch]
return torch.utils.data._utils.collate.default_collate(batch)


def json_transform_fn(sample):
return sample[0]["x"]


def list_collate_fn(batch):
import torch

batch = [np.array([it["a"][0], it["a"][1]]) for it in batch]
return torch.utils.data._utils.collate.default_collate(batch)


def list_transform_fn(sample):
return np.array([sample[0], sample[1]])


def test_pytorch_json(local_ds):
ds = local_ds
with ds:
ds.create_tensor("a", htype="json")
ds.a.append({"x": 1})
ds.a.append({"x": 2})

ptds = ds.pytorch(transform={"a": json_transform_fn}, batch_size=2)
batch = next(iter(ptds))
np.testing.assert_equal(batch["a"], np.array([1, 2]))

ptds = ds.pytorch(collate_fn=json_collate_fn, batch_size=2)
batch = next(iter(ptds))
np.testing.assert_equal(batch, np.array([1, 2]))


def test_pytorch_list(local_ds):
ds = local_ds
with ds:
ds.create_tensor("a", htype="list")
ds.a.append([1, 2])
ds.a.append([3, 4])

ptds = ds.pytorch(transform={"a": list_transform_fn}, batch_size=2)
batch = next(iter(ptds))
np.testing.assert_equal(batch["a"], np.array([[1, 2], [3, 4]]))

ptds = ds.pytorch(collate_fn=list_collate_fn, batch_size=2)
batch = next(iter(ptds))
np.testing.assert_equal(batch, np.array([[1, 2], [3, 4]]))

0 comments on commit d736afb

Please sign in to comment.