add offset
levongh committed Sep 5, 2023
1 parent 5402e20 commit 4500621
Showing 12 changed files with 186 additions and 107 deletions.
1 change: 0 additions & 1 deletion deeplake/api/tests/test_api.py
@@ -983,7 +983,6 @@ def test_dataset_deepcopy(path, hub_token, num_workers, progressbar):
dest_path = "_".join((path, "dest1"))

src_ds = deeplake.empty(src_path, overwrite=True, token=hub_token)
# dest_ds = deeplake.empty(dest_path, overwrite=True, token=hub_token)

with src_ds:
src_ds.info.update(key=0)
14 changes: 7 additions & 7 deletions deeplake/api/tests/test_reset.py
@@ -54,10 +54,10 @@ def test_load_corrupt_dataset(path):
save_head = ds.pending_commit_id

with pytest.raises(DatasetCorruptError):
ds = deeplake.load(path, access_method=access_method)
deeplake.load(path, access_method=access_method)

with pytest.raises(ReadOnlyModeError):
ds = deeplake.load(
deeplake.load(
path, read_only=True, access_method=access_method, reset=True
)

@@ -116,7 +116,7 @@ def test_load_corrupted_branch(local_path):
save_head = ds.pending_commit_id

with pytest.raises(DatasetCorruptError):
ds = deeplake.load(f"{local_path}@alt")
deeplake.load(f"{local_path}@alt")

ds = deeplake.load(f"{local_path}@alt", reset=True)
verify_reset_on_checkout(ds, "alt", main_2, save_head, {"abc": [[1], [2]]})
@@ -131,10 +131,10 @@ def test_load_corrupted_branch(local_path):
save_head = ds.pending_commit_id

with pytest.raises(DatasetCorruptError):
ds = deeplake.load(f"{local_path}@alt")
deeplake.load(f"{local_path}@alt")

with pytest.raises(DatasetCorruptError):
ds = deeplake.load(f"{local_path}@{save_head}")
deeplake.load(f"{local_path}@{save_head}")

ds = deeplake.load(f"{local_path}@alt", reset=True)
verify_reset_on_checkout(ds, "alt", alt_2, save_head, {"abc": [[1], [2], [3], [4]]})
@@ -200,10 +200,10 @@ def test_load_corrupt_dataset_with_no_commits(local_path):
corrupt_ds(ds, "abc", 1)

with pytest.raises(DatasetCorruptError):
ds = deeplake.load(local_path)
deeplake.load(local_path)

with pytest.raises(ReadOnlyModeError):
ds = deeplake.load(local_path, read_only=True, reset=True)
deeplake.load(local_path, read_only=True, reset=True)

ds = deeplake.load(local_path, reset=True)

5 changes: 2 additions & 3 deletions deeplake/api/tests/test_update_samples.py
@@ -51,9 +51,8 @@ def _make_update_assert_equal(
# this is necessary because `expected` uses `aslist=True` to handle dynamic cases.
# with `aslist=False`, this wouldn't be necessary.
expected_value = value
if hasattr(value, "__len__"):
if len(value) == 1:
expected_value = value[0]
if hasattr(value, "__len__") and len(value) == 1:
expected_value = value[0]

# make updates
tensor[index] = value
4 changes: 2 additions & 2 deletions deeplake/api/tests/test_video.py
@@ -111,7 +111,7 @@ def test_video_timestamps(vstream_path, hub_token):
ds = deeplake.load(vstream_path, read_only=True, token=hub_token)

with pytest.raises(ValueError):
stamps = ds.mp4_videos[:2].timestamps
ds.mp4_videos[:2].timestamps

stamps = ds.large_video[0, 12000:1199:-100].timestamps

@@ -131,7 +131,7 @@ def test_video_exception(local_ds):
with local_ds as ds:
ds.create_tensor("abc")
with pytest.raises(Exception):
stamps = ds.abc.timestamps
ds.abc.timestamps


@pytest.mark.skipif(
2 changes: 1 addition & 1 deletion deeplake/auto/structured/dataframe.py
@@ -58,7 +58,7 @@ def _get_most_frequent_image_extension(self, fn_iterator):

if len(fn_iterator) == 0:
raise IngestionError(
f"Cannot determine the most frequent image compression because no valid image files were provided."
"Cannot determine the most frequent image compression because no valid image files were provided."
)

supported_image_extensions = tuple(
2 changes: 1 addition & 1 deletion deeplake/cli/test_cli.py
@@ -27,5 +27,5 @@ def test_cli_auth(hub_cloud_dev_credentials, hub_cloud_dev_token, method):
def test_bad_token():
runner = CliRunner()

result = runner.invoke(login, f"-t abcd")
result = runner.invoke(login, "-t abcd")
assert isinstance(result.exception, LoginException)
3 changes: 2 additions & 1 deletion deeplake/constants.py
@@ -159,7 +159,8 @@
"gcp://",
"gs://",
"az://",
"azure://" "gdrive://",
"azure://",
"gdrive://",
)

_ENABLE_HUB_SUB_DATASETS = False
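Context for the change above: Python concatenates adjacent string literals, so the missing comma fused two prefixes into the single entry "azure://gdrive://". A small illustration of the pitfall (toy tuples, not the repository's actual constant):

    # The missing comma fuses the two literals into one string.
    broken = ("az://", "azure://" "gdrive://")
    fixed = ("az://", "azure://", "gdrive://")

    assert broken == ("az://", "azure://gdrive://")
    assert fixed == ("az://", "azure://", "gdrive://")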
5 changes: 2 additions & 3 deletions deeplake/enterprise/convert_to_libdeeplake.py
@@ -211,8 +211,7 @@ def dataset_to_libdeeplake(hub2_dataset):
commit_id = hub2_dataset.pending_commit_id
libdeeplake_dataset.checkout(commit_id)
slice_ = hub2_dataset.index.values[0].value
if slice_ != slice(None):
if isinstance(slice_, tuple):
slice_ = list(slice_)
if slice_ != slice(None) and isinstance(slice_, tuple):
slice_ = list(slice_)
libdeeplake_dataset = libdeeplake_dataset[slice_]
return libdeeplake_dataset
143 changes: 103 additions & 40 deletions deeplake/enterprise/dataloader.py
@@ -3,6 +3,7 @@
from deeplake.enterprise.convert_to_libdeeplake import dataset_to_libdeeplake
from deeplake.enterprise.dummy_dataloader import DummyDataloader # type: ignore
from deeplake.util.scheduling import create_fetching_schedule, find_primary_tensor
from deeplake.core.seed import DeeplakeRandom
from deeplake.enterprise.util import (
handle_mode,
raise_indra_installation_error,
@@ -34,6 +35,7 @@
BatchSampler = None # type: ignore

import numpy as np
import warnings

import math

@@ -108,6 +110,7 @@ def __init__(
_dataloader=None,
_world_size=1,
_ignore_errors=False,
_offset=None,
**kwargs,
):
import_indra_loader()
@@ -132,6 +135,7 @@
self._dataloader = _dataloader
self._world_size = _world_size
self._ignore_errors = _ignore_errors
self._offset = _offset
for k, v in kwargs.items():
setattr(self, k, v)

@@ -258,6 +262,25 @@ def batch(self, batch_size: int, drop_last: bool = False):
all_vars["_drop_last"] = drop_last
return self.__class__(**all_vars)

def offset(self, off: int = 0):
"""Returns a shifted :class:`DeepLakeDataLoader` object.
Args:
off (int): Index from which the dataloader will start iterating. Defaults to 0.
Returns:
DeepLakeDataLoader: A :class:`DeepLakeDataLoader` object.
Raises:
ValueError: If .offset() has already been called.
"""
if self._offset is not None:
raise ValueError("offset is already set")

all_vars = self.__dict__.copy()
all_vars["_offset"] = off
return self.__class__(**all_vars)
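A hypothetical usage sketch of the new method, assuming the usual ds.dataloader() builder chain; the dataset path and numbers are placeholders, not part of this commit:

    import deeplake

    ds = deeplake.load("hub://org/dataset")  # placeholder path
    loader = ds.dataloader().offset(100).batch(32).numpy()  # start iterating at index 100
    for batch in loader:
        ...  # consume batches as usual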

def shuffle(self, shuffle: bool = True, buffer_size: int = 2048):
"""Returns a shuffled :class:`DeepLakeDataLoader` object.
@@ -620,15 +643,85 @@ def _get_suboptimal_thread_count(self) -> Optional[int]:
return num_suboptimal_threads
return self._num_threads

def __create_dummy_dataloader(
self,
dataset,
tensors: Optional[List[str]] = None,
raw_tensors: Optional[List[str]] = None,
pil_compressed_tensors: Optional[List[str]] = None,
) -> DummyDataloader:
return DummyDataloader(
deeplake_dataset=dataset,
batch_size=self._batch_size,
shuffle=self._shuffle,
num_workers=self._num_workers,
collate_fn=self.collate_fn,
transform_fn=self._transform,
distributed=self._distributed,
prefetch_factor=self._prefetch_factor,
tensors=tensors,
drop_last=self._drop_last,
upcast=self._mode == "pytorch", # upcast to handle unsupported dtypes,
return_index=self._return_index,
raw_tensors=raw_tensors,
pil_compressed_tensors=pil_compressed_tensors,
persistent_workers=self._persistent_workers,
)

def __get_indra_dataloader(
self,
indra_dataset,
tensors: Optional[List[str]] = None,
raw_tensors: Optional[List[str]] = None,
pil_compressed_tensors: Optional[List[str]] = None,
json_tensors: Optional[List[str]] = None,
list_tensors: Optional[List[str]] = None,
htype_dict: Optional[dict] = None,
ndim_dict: Optional[dict] = None,
tensor_info_dict: Optional[dict] = None,
):
num_threads = (
self._get_suboptimal_thread_count()
if self._distributed
else self._num_threads
)
seed = DeeplakeRandom().get_seed()
if self._offset is not None and self._shuffle and seed is None:
warnings.warn(
"To keep dataloader consistent during setting offset and shuffling params please confider seeting deeplake.random.seed"
)

return INDRA_LOADER(
indra_dataset,
batch_size=self._batch_size,
num_threads=num_threads,
shuffle=self._shuffle,
num_workers=self._num_workers,
collate_fn=self.collate_fn,
transform_fn=self._transform,
distributed=self._distributed,
prefetch_factor=self._prefetch_factor,
tensors=tensors,
drop_last=self._drop_last,
ignore_errors=self._ignore_errors,
upcast=self._mode == "pytorch", # upcast to handle unsupported dtypes,
return_index=self._return_index,
primary_tensor=self._primary_tensor_name,
buffer_size=self._buffer_size,
raw_tensors=raw_tensors,
pil_compressed_tensors=pil_compressed_tensors,
json_tensors=json_tensors,
list_tensors=list_tensors,
persistent_workers=self._persistent_workers,
htype_dict=htype_dict,
ndim_dict=ndim_dict,
tensor_info_dict=tensor_info_dict,
offset=self._offset,
)
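Following the warning above, a hedged sketch of seeding before combining shuffle with offset (assuming the deeplake.random.seed API referenced in the message; ds is a placeholder dataset handle):

    import deeplake

    deeplake.random.seed(42)  # fix the shuffle order so the offset stays meaningful across runs
    loader = ds.dataloader().shuffle().offset(1000).batch(16).numpy()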

def __iter__(self):
if self._dataloader is None:
dataset = self._orig_dataset
collate_fn = self.collate_fn
upcast = self._mode == "pytorch" # upcast to handle unsupported dtypes

primary_tensor_name = self._primary_tensor_name
buffer_size = self._buffer_size

tensors = self._tensors or map_tensor_keys(dataset, None)

jpeg_png_compressed_tensors, json_tensors, list_tensors = check_tensors(
@@ -655,61 +748,31 @@ def __iter__(self):
dataset, data_tensors, tensor_info_tensors
)
if deeplake.constants.RETURN_DUMMY_DATA_FOR_DATALOADER:
self._dataloader = DummyDataloader(
deeplake_dataset=dataset,
batch_size=self._batch_size,
shuffle=self._shuffle,
num_workers=self._num_workers,
collate_fn=collate_fn,
transform_fn=self._transform,
distributed=self._distributed,
prefetch_factor=self._prefetch_factor,
self._dataloader = self.__create_dummy_dataloader(
dataset,
tensors=tensors,
drop_last=self._drop_last,
upcast=upcast,
return_index=self._return_index,
raw_tensors=raw_tensors,
pil_compressed_tensors=pil_compressed_tensors,
persistent_workers=self._persistent_workers,
)
else:
if not hasattr(self, "_indra_dataset"):
indra_dataset = dataset_to_libdeeplake(dataset)
else:
indra_dataset = self._indra_dataset

num_threads = (
self._get_suboptimal_thread_count()
if self._distributed
else self._num_threads
)
self._dataloader = INDRA_LOADER(
self._dataloader = self.__get_indra_dataloader(
indra_dataset,
batch_size=self._batch_size,
num_threads=num_threads,
shuffle=self._shuffle,
num_workers=self._num_workers,
collate_fn=collate_fn,
transform_fn=self._transform,
distributed=self._distributed,
prefetch_factor=self._prefetch_factor,
tensors=tensors,
drop_last=self._drop_last,
ignore_errors=self._ignore_errors,
upcast=upcast,
return_index=self._return_index,
primary_tensor=primary_tensor_name,
buffer_size=buffer_size,
raw_tensors=raw_tensors,
pil_compressed_tensors=pil_compressed_tensors,
json_tensors=json_tensors,
list_tensors=list_tensors,
persistent_workers=self._persistent_workers,
htype_dict=htype_dict,
ndim_dict=ndim_dict,
tensor_info_dict=tensor_info_dict,
worker_init_fn=self.worker_init_fn,
)

dataset_read(self._orig_dataset)

if self._internal_iterator is not None:
3 changes: 1 addition & 2 deletions deeplake/enterprise/libdeeplake_query.py
@@ -39,8 +39,7 @@ def query(dataset, query_string: str):
elif dataset.libdeeplake_dataset is not None:
ds = dataset.libdeeplake_dataset
slice_ = dataset.index.values[0].value
if slice_ != slice(None):
if isinstance(slice_, tuple):
if slice_ != slice(None) and isinstance(slice_, tuple):
slice_ = list(slice_)
ds = ds[slice_]
else: