diff --git a/deeplake/enterprise/dataloader.py b/deeplake/enterprise/dataloader.py index 0d87419da1..8692cb74c6 100644 --- a/deeplake/enterprise/dataloader.py +++ b/deeplake/enterprise/dataloader.py @@ -25,10 +25,13 @@ try: from torch.utils.data.dataloader import DataLoader, _InfiniteConstantSampler from torch.utils.data.distributed import DistributedSampler + from torch.utils.data import BatchSampler + except ImportError: DataLoader = object # type: ignore _InfiniteConstantSampler = None # type: ignore DistributedSampler = None # type: ignore + BatchSampler = None # type: ignore import numpy as np @@ -196,7 +199,7 @@ def sampler(self): @property def batch_sampler(self): - return DistributedSampler(self.dataset) if self._distributed else None + return BatchSampler(self.sampler, self.batch_size, self.drop_last) if BatchSampler else None @property def generator(self): diff --git a/deeplake/enterprise/test_pytorch.py b/deeplake/enterprise/test_pytorch.py index 135af4a735..63598e4634 100644 --- a/deeplake/enterprise/test_pytorch.py +++ b/deeplake/enterprise/test_pytorch.py @@ -722,6 +722,17 @@ def test_pytorch_error_handling(local_auth_ds): pass +@requires_libdeeplake +@requires_torch +def test_batch_sampler_attribute(local_auth_ds): + ld = local_auth_ds.dataloader().pytorch() + + from torch.utils.data import BatchSampler + + assert isinstance(ld.batch_sampler, BatchSampler) + assert ld.batch_sampler.sampler is not None + + @requires_libdeeplake @requires_torch @pytest.mark.slow