Fix: unexpected behaviours (bugs) in train_test_split #192

Merged · 4 commits · Jun 27, 2024
4 changes: 2 additions & 2 deletions src/litdata/streaming/cache.py

```diff
@@ -129,7 +129,7 @@ def _add_item(self, index: int, data: Any) -> Optional[str]:
     def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]:
         """Read an item in the reader."""
         if isinstance(index, int):
-            index = ChunkedIndex(index, self._get_chunk_index_from_index(index))
+            index = ChunkedIndex(*self._get_chunk_index_from_index(index))
         return self._reader.read(index)
 
     def done(self) -> Optional[List[str]]:
@@ -150,5 +150,5 @@ def __len__(self) -> int:
     def get_chunk_intervals(self) -> List[Interval]:
         return self._reader.get_chunk_intervals()
 
-    def _get_chunk_index_from_index(self, index: int) -> int:
+    def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
         return self._reader._get_chunk_index_from_index(index)
```
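Previously the cache passed the global dataset index straight into `ChunkedIndex`, so reads landed on the wrong item whenever a chunk's region of interest (ROI) did not start at zero. The helper now returns both the physical in-chunk read position and the chunk index, which the call sites star-unpack. A minimal standalone sketch of the new contract (toy types and a toy chunk layout, not the library's actual classes):

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class ChunkedIndex:
    # Mirrors the two fields the call sites rely on: the read position inside
    # the chunk and which chunk to read from.
    index: int
    chunk_index: int
    is_last_index: bool = False


def _get_chunk_index_from_index(index: int) -> Tuple[int, int]:
    # Toy stand-in: every chunk holds 10 items and its ROI starts at 0.
    return index % 10, index // 10


# Old call site: ChunkedIndex(index, chunk_index) reused the global index as
# the read position. New call site: star-unpack the returned 2-tuple.
idx = ChunkedIndex(*_get_chunk_index_from_index(37))
assert (idx.index, idx.chunk_index) == (7, 3)
```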
22 changes: 19 additions & 3 deletions src/litdata/streaming/config.py

```diff
@@ -84,6 +84,7 @@ def __init__(
             self._compressor = _COMPRESSORS[self._compressor_name]
 
         self._skip_chunk_indexes_deletion: Optional[List[int]] = None
+        self.zero_based_roi: Optional[List[Tuple[int, int]]] = None
 
     def can_delete(self, chunk_index: int) -> bool:
         if self._skip_chunk_indexes_deletion is None:
@@ -179,10 +180,25 @@ def config(self) -> Dict[str, Any]:
             raise RuntimeError("The config should be defined.")
         return self._config
 
-    def _get_chunk_index_from_index(self, index: int) -> int:
-        for chunk_index, internal in enumerate(self._intervals):
+    def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
+        if self.zero_based_roi is None:
+            # zero_based_roi is a list of (start, end) tuples used to find the
+            # chunk index efficiently.
+            # Example:
+            #   self._intervals = [(0, 5, 10, 10), (10, 10, 20, 20)]
+            #   self.zero_based_roi = [(0, 5), (5, 15)]
+            self.zero_based_roi = []
+            start = 0
+            for curr_interval in self._intervals:
+                diff = curr_interval[2] - curr_interval[1]  # roi_end - roi_start
+                self.zero_based_roi.append((start, start + diff))
+                start += diff
 
+        for chunk_index, internal in enumerate(self.zero_based_roi):
             if internal[0] <= index < internal[-1]:
-                return chunk_index
+                real_index_to_read_from = self._intervals[chunk_index][1] + (index - internal[0])
+                return real_index_to_read_from, chunk_index
         raise ValueError(
             f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}."
         )
```
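The zero-based ROI makes the lookup concrete. Here is a standalone reproduction of the arithmetic (assuming, as the diff's comment does, that each interval is laid out as (chunk_start, roi_start, roi_end, chunk_end)):

```python
from typing import List, Tuple

# Intervals as in the diff's example: (chunk_start, roi_start, roi_end, chunk_end).
intervals = [(0, 5, 10, 10), (10, 10, 20, 20)]

# Build the zero-based ROI once: each chunk contributes roi_end - roi_start items.
zero_based_roi: List[Tuple[int, int]] = []
start = 0
for interval in intervals:
    size = interval[2] - interval[1]
    zero_based_roi.append((start, start + size))
    start += size
assert zero_based_roi == [(0, 5), (5, 15)]


def lookup(index: int) -> Tuple[int, int]:
    for chunk_index, (lo, hi) in enumerate(zero_based_roi):
        if lo <= index < hi:
            # Map the dataset-level index back to the physical position in the chunk.
            return intervals[chunk_index][1] + (index - lo), chunk_index
    raise ValueError(f"index {index} out of range")


# Dataset index 0 is physical item 5 of chunk 0 (that chunk's ROI starts at 5);
# dataset index 7 is physical item 12 of chunk 1.
assert lookup(0) == (5, 0)
assert lookup(7) == (12, 1)
```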
4 changes: 2 additions & 2 deletions src/litdata/streaming/dataset.py

```diff
@@ -290,11 +290,11 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
             self.cache = self._create_cache(worker_env=self.worker_env)
             self.shuffler = self._create_shuffler(self.cache)
         if isinstance(index, int):
-            index = ChunkedIndex(index, self.cache._get_chunk_index_from_index(index))
+            index = ChunkedIndex(*self.cache._get_chunk_index_from_index(index))
         elif isinstance(index, slice):
             start, stop, step = index.indices(len(self))
             _my_indices = list(range(start, stop, step))
-            _my_cache_indices = [ChunkedIndex(idx, self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices]
+            _my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices]
             return [self.cache[chnk_idx] for chnk_idx in _my_cache_indices]
         return self.cache[index]
```
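The slice branch leans on Python's `slice.indices` to normalize negative and open-ended bounds against the dataset length before one `ChunkedIndex` is built per element; for example:

```python
# slice.indices(length) resolves a slice into concrete (start, stop, step).
s = slice(None, None, 2)  # equivalent to ds[::2]
assert s.indices(10) == (0, 10, 2)
assert list(range(*s.indices(10))) == [0, 2, 4, 6, 8]

s = slice(-3, None)  # equivalent to ds[-3:]
assert s.indices(10) == (7, 10, 1)
```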
2 changes: 1 addition & 1 deletion src/litdata/streaming/reader.py

```diff
@@ -204,7 +204,7 @@ def __init__(
         self._last_chunk_index: Optional[int] = None
         self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0))
 
-    def _get_chunk_index_from_index(self, index: int) -> int:
+    def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
         # Load the config containing the index
         if self._config is None and self._try_load_config() is None:
             raise Exception("The reader index isn't defined.")
```
46 changes: 28 additions & 18 deletions src/litdata/utilities/subsample.py

```diff
@@ -39,31 +39,41 @@ def shuffle_lists_together(
 
 
 def subsample_filenames_and_roi(
-    chunks: List[Dict[str, Any]], roi_list: List[Tuple[int, int]], target: int
+    chunks: List[Dict[str, Any]], roi_list: List[Tuple[int, int]], item_count: int
 ) -> Tuple[List[str], List[Tuple[int, int]], List[Dict[str, Any]], List[Tuple[int, int]]]:
     assert len(chunks) == len(roi_list)
 
-    cumsum_sizes = np.cumsum([r[-1] for r in roi_list])
+    if item_count == 0:
+        return [], [], chunks, roi_list
 
-    match = np.argmax(cumsum_sizes > target)
-    left_chunk_filenames = [c["filename"] for c in chunks[: match + 1]]
-    left_chunk_roi = [r[-1] for r in roi_list[: match + 1]]
-    left_chunk_roi[-1] = target if match == 0 else (target - cumsum_sizes[match - 1])
+    cumsum_sizes = np.cumsum([r[1] - r[0] for r in roi_list])
 
-    assert np.sum(left_chunk_roi) == target
+    match = np.argmax(cumsum_sizes >= item_count)
 
-    right_chunk_filenames = chunks[match:]
-    right_chunk_roi = [r[-1] for r in roi_list[match:]]
-    right_chunk_roi[0] -= left_chunk_roi[-1]
+    exact_item_count_match = cumsum_sizes[match] == item_count
 
-    # exact match
-    if left_chunk_roi[-1] == 0:
-        left_chunk_filenames = left_chunk_filenames[:-1]
-        left_chunk_roi = left_chunk_roi[:-1]
+    subsampled_filenames = [c["filename"] for c in chunks[: match + 1]]
+    subsampled_chunk_roi = roi_list[: match + 1]
+    # because tuples don't support item assignment
+    subsampled_chunk_roi[-1] = (
+        subsampled_chunk_roi[-1][0],
+        subsampled_chunk_roi[-1][1] - (cumsum_sizes[match] - item_count),
+    )
+
+    assert sum(_chnk[1] - _chnk[0] for _chnk in subsampled_chunk_roi) == item_count
+
+    if exact_item_count_match:
+        match += 1  # start from the next chunk
+    left_over_chunks = chunks[match:]
+    left_over_chunk_roi = roi_list[match:]
+
+    if not exact_item_count_match:
+        # because tuples don't support item assignment
+        left_over_chunk_roi[0] = subsampled_chunk_roi[-1][1], left_over_chunk_roi[0][1]  # resume within the same chunk
 
     return (
-        left_chunk_filenames,
-        [(0, r) for r in left_chunk_roi],
-        right_chunk_filenames,
-        [(0, r) for r in right_chunk_roi],
+        subsampled_filenames,
+        subsampled_chunk_roi,
+        left_over_chunks,
+        left_over_chunk_roi,
     )
```
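To make the new behaviour concrete, here is the arithmetic traced by hand for a small case (a self-contained sketch with made-up chunk filenames, not a call into the library function):

```python
import numpy as np

# Two chunks of 10 items each with full ROIs, asking for 15 items.
chunks = [{"filename": "chunk-0.bin"}, {"filename": "chunk-1.bin"}]
roi_list = [(0, 10), (0, 10)]
item_count = 15

cumsum_sizes = np.cumsum([r[1] - r[0] for r in roi_list])  # [10, 20]
match = int(np.argmax(cumsum_sizes >= item_count))  # chunk 1 completes the count
exact = bool(cumsum_sizes[match] == item_count)  # False: 20 != 15

# The subsample takes chunks 0..match and trims the last ROI to hit item_count.
subsampled_roi = roi_list[: match + 1]
subsampled_roi[-1] = (
    subsampled_roi[-1][0],
    subsampled_roi[-1][1] - int(cumsum_sizes[match] - item_count),
)
assert subsampled_roi == [(0, 10), (0, 5)]

# The leftover keeps chunk 1 but resumes at item 5, so nothing is dropped or
# duplicated (a case the old (0, size) ROIs could not represent).
left_over_roi = roi_list[match:]
left_over_roi[0] = (subsampled_roi[-1][1], left_over_roi[0][1])
assert left_over_roi == [(5, 10)]
```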
10 changes: 5 additions & 5 deletions src/litdata/utilities/train_test_split.py

```diff
@@ -40,9 +40,6 @@ def train_test_split(
     if not all(0 <= _f <= 1 for _f in splits):
         raise ValueError("Each Split should be a float with each value in [0,1].")
 
-    if any(split == 0 for split in splits):
-        logging.warning("Warning: some splits are 0, this will lead to empty datasets")
-
     if sum(splits) > 1:
         raise ValueError("Splits' sum must be less than 1.")
 
@@ -77,9 +74,12 @@ def train_test_split(
         subsampled_chunks, dummy_subsampled_roi, np.random.RandomState([seed])
     )
 
-    for i, split in enumerate(splits):
-        item_count = int(dataset_length * split)
+    item_count_list = [int(dataset_length * split) for split in splits]
+
+    if any(item_count == 0 for item_count in item_count_list):
+        logging.warning("Warning: some splits have an item count of 0, which will lead to empty datasets")
+
+    for i, item_count in enumerate(item_count_list):
         curr_chunk_filename, curr_chunk_roi, left_chunks, left_roi = subsample_filenames_and_roi(
             subsampled_chunks, dummy_subsampled_roi, item_count
         )
```
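Item counts are floored per split and computed up front, so a small fraction on a small dataset can floor to zero; the warning now fires on the actual counts rather than only on literal 0.0 fractions. Plain arithmetic, mirroring the values used in the new test below:

```python
dataset_length = 200
splits = [0.1, 0.2, 0.7, 0.0]

item_count_list = [int(dataset_length * split) for split in splits]
assert item_count_list == [20, 40, 140, 0]

# The final split floors to zero items, which now triggers the warning before
# any subsampling work happens.
assert any(item_count == 0 for item_count in item_count_list)
```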
4 changes: 2 additions & 2 deletions tests/streaming/test_reader.py

```diff
@@ -29,7 +29,7 @@ def test_reader_chunk_removal(tmpdir):
     os.makedirs(cache_dir, exist_ok=True)
 
     for i in range(25):
-        index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), is_last_index=i == 24)
+        index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i
 
     assert len(os.listdir(cache_dir)) == 14
@@ -41,7 +41,7 @@
 
     for i in range(25):
         assert len(os.listdir(cache_dir)) <= 3
-        index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), is_last_index=i == 24)
+        index = ChunkedIndex(*cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i
 
     assert len(os.listdir(cache_dir)) in [2, 3]
```
2 changes: 1 addition & 1 deletion tests/utilities/test_subsample.py

```diff
@@ -48,7 +48,7 @@ def test_subsample_filenames_and_roi():
     _, roi_list, _, left_roi = subsample_filenames_and_roi(my_chunks, roi_list, target)
 
     assert target == sum([roi[1] - roi[0] for roi in roi_list])
-    assert (total_chunk_roi_length - target) == np.sum(left_roi)
+    assert (total_chunk_roi_length - target) == sum(_roi[1] - _roi[0] for _roi in left_roi)
 
 
 def test_subsample_filenames_and_roi_exact():
```
41 changes: 40 additions & 1 deletion tests/utilities/test_train_test_split.py

```diff
@@ -1,5 +1,5 @@
 import pytest
-from litdata import StreamingDataset, train_test_split
+from litdata import StreamingDataLoader, StreamingDataset, train_test_split
 from litdata.constants import _ZSTD_AVAILABLE
 from litdata.streaming.cache import Cache
 
@@ -68,3 +68,42 @@ def test_split_a_subsampled_dataset(tmpdir, compression):
     split_datasets = train_test_split(_sub_sampled_streaming_dataset, _split_fraction)
 
     assert all(len(split_datasets[i]) == int(50 * split) for i, split in enumerate(_split_fraction))
+
+
+@pytest.mark.parametrize(
+    "compression",
+    [
+        pytest.param(None),
+        pytest.param("zstd", marks=pytest.mark.skipif(condition=not _ZSTD_AVAILABLE, reason="Requires: ['zstd']")),
+    ],
+)
+def test_train_test_split_with_streaming_dataloader(tmpdir, compression):
+    cache = Cache(str(tmpdir), chunk_size=10, compression=compression)
+    for i in range(200):
+        cache[i] = i
+    cache.done()
+    cache.merge()
+
+    my_streaming_dataset = StreamingDataset(input_dir=str(tmpdir))
+
+    splits = [0.1, 0.2, 0.7, 0.0]
+
+    ds = train_test_split(my_streaming_dataset, splits=splits)
+
+    assert [len(ds[i]) for i in range(len(splits))] == [int(200 * split) for split in splits]
+
+    # check that the indices are unique across the split datasets (iterating over the datasets)
+    visited_indices = set()
+    for _ds in ds:
+        for idx in range(len(_ds)):
+            assert _ds[idx] not in visited_indices
+            visited_indices.add(_ds[idx])
+
+    # check that the indices are unique across the splits (iterating over the dataloaders)
+    visited_indices = set()
+    for _ds in ds:
+        dl = StreamingDataLoader(_ds, batch_size=10)
+        for batch in dl:
+            for curr_idx in batch:
+                assert curr_idx not in visited_indices
+                visited_indices.add(curr_idx)
```