Skip to content

Commit

Permalink
Merge pull request #2571 from activeloopai/batch_sampler
Browse files Browse the repository at this point in the history
added torch batch_sampler
  • Loading branch information
levongh authored Sep 1, 2023
2 parents b289139 + 666a4dc commit ca8a4d6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
5 changes: 4 additions & 1 deletion deeplake/enterprise/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
try:
from torch.utils.data.dataloader import DataLoader, _InfiniteConstantSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import BatchSampler

except ImportError:
DataLoader = object # type: ignore
_InfiniteConstantSampler = None # type: ignore
DistributedSampler = None # type: ignore
BatchSampler = None # type: ignore

import numpy as np

Expand Down Expand Up @@ -196,7 +199,7 @@ def sampler(self):

@property
def batch_sampler(self):
return DistributedSampler(self.dataset) if self._distributed else None
return BatchSampler(self.sampler, self.batch_size, self.drop_last) if BatchSampler else None

@property
def generator(self):
Expand Down
11 changes: 11 additions & 0 deletions deeplake/enterprise/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,17 @@ def test_pytorch_error_handling(local_auth_ds):
pass


@requires_libdeeplake
@requires_torch
def test_batch_sampler_attribute(local_auth_ds):
ld = local_auth_ds.dataloader().pytorch()

from torch.utils.data import BatchSampler

assert isinstance(ld.batch_sampler, BatchSampler)
assert ld.batch_sampler.sampler is not None


@requires_libdeeplake
@requires_torch
@pytest.mark.slow
Expand Down

0 comments on commit ca8a4d6

Please sign in to comment.