From 7dab89634d87444717326272a97e3a91c09152f4 Mon Sep 17 00:00:00 2001 From: Levon Ghukasyan Date: Thu, 31 Aug 2023 18:22:40 +0400 Subject: [PATCH 1/3] added torch batch_sampler --- deeplake/enterprise/dataloader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deeplake/enterprise/dataloader.py b/deeplake/enterprise/dataloader.py index 0d87419da1..8692cb74c6 100644 --- a/deeplake/enterprise/dataloader.py +++ b/deeplake/enterprise/dataloader.py @@ -25,10 +25,13 @@ try: from torch.utils.data.dataloader import DataLoader, _InfiniteConstantSampler from torch.utils.data.distributed import DistributedSampler + from torch.utils.data import BatchSampler + except ImportError: DataLoader = object # type: ignore _InfiniteConstantSampler = None # type: ignore DistributedSampler = None # type: ignore + BatchSampler = None # type: ignore import numpy as np @@ -196,7 +199,7 @@ def sampler(self): @property def batch_sampler(self): - return DistributedSampler(self.dataset) if self._distributed else None + return BatchSampler(self.sampler, self.batch_size, self.drop_last) if BatchSampler else None @property def generator(self): From f2b6fc792ce184f2b3c98833ba24867ca950ab7f Mon Sep 17 00:00:00 2001 From: Levon Ghukasyan Date: Thu, 31 Aug 2023 23:40:50 +0400 Subject: [PATCH 2/3] check batch sampler --- deeplake/enterprise/test_pytorch.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/deeplake/enterprise/test_pytorch.py b/deeplake/enterprise/test_pytorch.py index 774bdd72fa..c21608c0f9 100644 --- a/deeplake/enterprise/test_pytorch.py +++ b/deeplake/enterprise/test_pytorch.py @@ -690,10 +690,22 @@ def test_pytorch_error_handling(hub_cloud_ds): pass +@requires_libdeeplake +@requires_torch +def test_batch_sampler_attribute(local_auth_ds): + ld = local_auth_ds.dataloader().pytorch() + + from torch.utils.data import BatchSampler + + assert isinstance(ld.batch_sampler, BatchSampler) + assert ld.batch_sampler.sampler is not None + + @requires_libdeeplake @requires_torch def test_pil_decode_method(hub_cloud_ds): from indra.pytorch.exceptions import CollateExceptionWrapper + with hub_cloud_ds as ds: ds.create_tensor("x", htype="image", sample_compression="jpeg") ds.x.extend(np.random.randint(0, 255, (10, 10, 10, 3), np.uint8)) From 666a4dc0a90a07f0f75ded910388fbbc5492682e Mon Sep 17 00:00:00 2001 From: Levon Ghukasyan Date: Fri, 1 Sep 2023 23:17:07 +0400 Subject: [PATCH 3/3] mark --- deeplake/enterprise/test_pytorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deeplake/enterprise/test_pytorch.py b/deeplake/enterprise/test_pytorch.py index ebb36d4b1a..63598e4634 100644 --- a/deeplake/enterprise/test_pytorch.py +++ b/deeplake/enterprise/test_pytorch.py @@ -733,6 +733,8 @@ def test_batch_sampler_attribute(local_auth_ds): assert ld.batch_sampler.sampler is not None +@requires_libdeeplake +@requires_torch @pytest.mark.slow @pytest.mark.flaky def test_pil_decode_method(local_auth_ds):