diff --git a/deeplake/experimental/test_pytorch.py b/deeplake/experimental/test_pytorch.py index 0c81f937bc..b1e148f8f8 100644 --- a/deeplake/experimental/test_pytorch.py +++ b/deeplake/experimental/test_pytorch.py @@ -650,57 +650,3 @@ def test_pytorch_error_handling(local_ds): ptds = dataloader(ds).pytorch(tensors=["x"]) for _ in ptds: pass - - -def json_collate_fn(batch): - import torch - - batch = [it["a"][0]["x"] for it in batch] - return torch.utils.data._utils.collate.default_collate(batch) - - -def json_transform_fn(sample): - return sample[0]["x"] - - -def list_collate_fn(batch): - import torch - - batch = [np.array([it["a"][0], it["a"][1]]) for it in batch] - return torch.utils.data._utils.collate.default_collate(batch) - - -def list_transform_fn(sample): - return np.array([sample["a"][0], sample["a"][1]]) - - -def test_pytorch_json(local_ds): - ds = local_ds - with ds: - ds.create_tensor("a", htype="json") - ds.a.append({"x": 1}) - ds.a.append({"x": 2}) - - ptds = ds.pytorch(transform={"a": json_transform_fn}, batch_size=2) - batch = next(iter(ptds)) - np.testing.assert_equal(batch["a"], np.array([1, 2])) - - ptds = ds.pytorch(collate_fn=json_collate_fn, batch_size=2) - batch = next(iter(ptds)) - np.testing.assert_equal(batch, np.array([1, 2])) - - -def test_pytorch_list(local_ds): - ds = local_ds - with ds: - ds.create_tensor("a", htype="list") - ds.a.append([1, 2]) - ds.a.append([3, 4]) - - ptds = ds.pytorch(transform={"a": lambda x: np.array([x[0], x[1]])}, batch_size=2) - batch = next(iter(ptds)) - np.testing.assert_equal(batch["a"], np.array([[1, 2], [3, 4]])) - - ptds = ds.pytorch(collate_fn=list_collate_fn, batch_size=2) - batch = next(iter(ptds)) - np.testing.assert_equal(batch, np.array([[1, 2], [3, 4]])) diff --git a/deeplake/integrations/tests/test_pytorch.py b/deeplake/integrations/tests/test_pytorch.py index 71feb76fe9..83a735982f 100644 --- a/deeplake/integrations/tests/test_pytorch.py +++ b/deeplake/integrations/tests/test_pytorch.py @@ -810,3 +810,57 @@ def test_uneven_iteration(local_ds): np.testing.assert_equal(x, i) target_y = i if i < 5 else [] np.testing.assert_equal(y, target_y) + + +def json_collate_fn(batch): + import torch + + batch = [it["a"][0]["x"] for it in batch] + return torch.utils.data._utils.collate.default_collate(batch) + + +def json_transform_fn(sample): + return sample[0]["x"] + + +def list_collate_fn(batch): + import torch + + batch = [np.array([it["a"][0], it["a"][1]]) for it in batch] + return torch.utils.data._utils.collate.default_collate(batch) + + +def list_transform_fn(sample): + return np.array([sample[0], sample[1]]) + + +def test_pytorch_json(local_ds): + ds = local_ds + with ds: + ds.create_tensor("a", htype="json") + ds.a.append({"x": 1}) + ds.a.append({"x": 2}) + + ptds = ds.pytorch(transform={"a": json_transform_fn}, batch_size=2) + batch = next(iter(ptds)) + np.testing.assert_equal(batch["a"], np.array([1, 2])) + + ptds = ds.pytorch(collate_fn=json_collate_fn, batch_size=2) + batch = next(iter(ptds)) + np.testing.assert_equal(batch, np.array([1, 2])) + + +def test_pytorch_list(local_ds): + ds = local_ds + with ds: + ds.create_tensor("a", htype="list") + ds.a.append([1, 2]) + ds.a.append([3, 4]) + + ptds = ds.pytorch(transform={"a": list_transform_fn}, batch_size=2) + batch = next(iter(ptds)) + np.testing.assert_equal(batch["a"], np.array([[1, 2], [3, 4]])) + + ptds = ds.pytorch(collate_fn=list_collate_fn, batch_size=2) + batch = next(iter(ptds)) + np.testing.assert_equal(batch, np.array([[1, 2], [3, 4]]))