train_test_split fails when asked for splits=[0.1, 0.2, 0.7] #186

Closed
deependujha opened this issue Jun 27, 2024 · 1 comment · Fixed by #187
Labels
bug Something isn't working help wanted Extra attention is needed

Comments

deependujha (Collaborator) commented Jun 27, 2024

🐛 Bug

train_test_split works when the dataset is split with splits=[0.1, 0.7, 0.2], but fails with splits=[0.1, 0.2, 0.7]. The fractions are the same; only their order differs.

To Reproduce

Try this script:

from litdata import StreamingDataLoader, StreamingDataset, train_test_split

x, y, z = train_test_split(streaming_dataset=StreamingDataset("output_dir"), splits=[0.1, 0.2, 0.7])

print(f"{len(x)=}")
print(f"{len(y)=}")
print(f"{len(z)=}")

print(f"{x[:]=}")
print(f"{y[:]=}")
print(f"{z[:]=}")  # raises an error

x = StreamingDataLoader(x, batch_size=5)
y = StreamingDataLoader(y, batch_size=5)
z = StreamingDataLoader(z, batch_size=5)


print("-"*80)
print("iterate X")
for _x in x:
    print(_x)

print("-"*80)
print("iterate Y")
for _y in y:
    print(_y)

print("-"*80)
print("iterate Z")
for _z in z:  # raises an error
    print(_z)

print("-"*80)
print("All done!")

Code sample

Code that generates output_dir:

from litdata import optimize

def compress(index):
    return (index, index ** 2)

optimize(
    fn=compress,
    inputs=list(range(100)),
    num_workers=4,
    output_dir="output_dir",
    chunk_bytes="64MB",
    mode="overwrite",
)

Expected behavior

train_test_split should work regardless of the order in which the split fractions are given.

Environment

  • PyTorch Version (e.g., 1.0): torch==2.2.1+cu121
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): Already installed on Lightning Studio
  • Build command you used (if compiling from source): pip install -e .
  • Python version: Python 3.10.10

Additional context

This appears to be caused by a logic issue in subsample_filenames_and_roi().
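To make the expected behavior concrete, here is a minimal sketch of order-independent splitting over chunked data. It is not litdata's implementation: split_chunk_intervals and split_lengths are hypothetical helpers, and the layout of four chunks with 25 items each is an assumption for illustration. The idea is to keep a single cursor (chunk index plus offset) that carries over from one split to the next, so every split picks up exactly where the previous one stopped.

```python
from itertools import permutations


def split_chunk_intervals(chunk_sizes, split_counts):
    """Assign consecutive item ranges to each split, walking chunks in order.

    Returns one list per split of (chunk_index, start, end) half-open ranges.
    Hypothetical helper for illustration; not litdata's implementation.
    """
    if sum(split_counts) > sum(chunk_sizes):
        raise ValueError("splits request more items than the dataset holds")

    results = []
    chunk_index, offset = 0, 0  # cursor shared by all splits
    for count in split_counts:
        regions = []
        remaining = count
        while remaining > 0:
            available = chunk_sizes[chunk_index] - offset
            if available == 0:
                # Current chunk is used up; move to the start of the next one.
                chunk_index, offset = chunk_index + 1, 0
                continue
            take = min(available, remaining)
            regions.append((chunk_index, offset, offset + take))
            offset += take
            remaining -= take
        results.append(regions)
    return results


def split_lengths(regions_per_split):
    """Number of items covered by each split."""
    return [
        sum(end - start for _, start, end in regions)
        for regions in regions_per_split
    ]


chunk_sizes = [25, 25, 25, 25]  # assumed chunk layout for 100 items
for counts in permutations([10, 20, 70]):
    regions = split_chunk_intervals(chunk_sizes, list(counts))
    # Every ordering yields splits of exactly the requested sizes.
    assert split_lengths(regions) == list(counts)
```

Because the cursor is never reset between splits, a large split that spans several chunks is handled the same way whether it comes first, in the middle, or last.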

deependujha added the bug and help wanted labels on Jun 27, 2024

Hi! Thanks for your contribution! Great first issue!
