Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: unexpected behaviours (bugs) in train_test_split fixed #192

Merged
merged 4 commits into from
Jun 27, 2024

Conversation

deependujha
Copy link
Collaborator

@deependujha deependujha commented Jun 27, 2024

Before submitting
  • Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

train_test_split works perfectly when asked to split dataset in splits=[0.1, 0.7, 0.2], but it fails when asked for splits=[0.1, 0.2, 0.7].

In the original code, this code will fail:

import os
from litdata import optimize, train_test_split, StreamingDataset, StreamingDataLoader

x, y, z = train_test_split(streaming_dataset=StreamingDataset("output_dir"), splits=[0.1, 0.2, 0.7])

print(f"{len(x)=}")
print(f"{len(y)=}")
print(f"{len(z)=}")

print(f"{x[:]=}")
print(f"{y[:]=}")
print(f"{z[:]=}") # this will raise error

x = StreamingDataLoader(x, batch_size=5)
y = StreamingDataLoader(y, batch_size=5)
z = StreamingDataLoader(z, batch_size=5)


print("-"*80)
print("iterate X")
for _x in x:
    print(_x)

print("-"*80)
print("iterate Y")
for _y in y:
    print(_y)

print("-"*80)
print("iterate Z")
for _z in z: # this will raise error
    print(_z)

print("-"*80)
print("All done!")

Except the failure, if you look at the values printed by y[:], it overlaps with x[:]. This was bcoz of the way reader was reading from chunks.


Code for output_dir

import os
from litdata import optimize, train_test_split, StreamingDataset

def compress(index):
    return (index, index ** 2)

optimize(
    fn=compress,
    inputs=list(range(100)),
    num_workers=4,
    output_dir="output_dir",
    chunk_bytes="64MB",
    mode="overwrite",
)

These bugs have been fixed in this PR. This PR originally aimed at closing a issue #186 , but it has been closed already, bcoz of some confusion.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

Copy link

codecov bot commented Jun 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (main@f2c5a7b). Learn more about missing BASE report.

Additional details and impacted files
@@          Coverage Diff          @@
##             main   #192   +/-   ##
=====================================
  Coverage        ?    78%           
=====================================
  Files           ?     33           
  Lines           ?   4488           
  Branches        ?      0           
=====================================
  Hits            ?   3492           
  Misses          ?    996           
  Partials        ?      0           

@tchaton
Copy link
Collaborator

tchaton commented Jun 27, 2024

Hey @deependujha. Can you describe which bugs this is fixing ?

@deependujha
Copy link
Collaborator Author

Hey @deependujha. Can you describe which bugs this is fixing ?

Sorry for the delay in response. I've updated the description. Plz have a look at it. It's an extension of a PR that has been merged already (#187 )

@tchaton tchaton merged commit 744265f into Lightning-AI:main Jun 27, 2024
28 checks passed
@deependujha deependujha deleted the fix/train-test-split-bug branch June 28, 2024 08:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants