
Reset state_dict after resume #330

Closed · wants to merge 5 commits
Conversation

@vgurev (Contributor) commented Aug 12, 2024

Reset of state_dict after resume.
  • [No] Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • [Yes] Did you read the contributor guideline, Pull Request section?
  • [No] Did you make sure to update the docs?
  • [No] Did you write any new necessary tests?

What does this PR do?

Currently, resume after a restart is triggered by a check that self._state_dict is not None when the dataset iterator is created:

if self._state_dict:
    self._resume(workers_chunks, workers_intervals)

However, self._state_dict is never reset after the restart. At the next epoch, when a new dataset iterator is created, the resume is triggered again from the same state_dict. To fix this bug, this PR assigns None to self._state_dict after the resume.
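To illustrate the bug and the fix with a minimal, self-contained sketch (toy class, not the actual litdata source; the "index" key and resume logic are hypothetical stand-ins for the real state_dict layout):

```python
# Toy dataset mimicking the resume pattern: __iter__ resumes from
# _state_dict once, then clears it so later epochs start fresh.
class ToyStreamingDataset:
    def __init__(self, n):
        self.n = n
        self._state_dict = None  # set by load_state_dict() on resume

    def load_state_dict(self, state):
        self._state_dict = state

    def __iter__(self):
        start = 0
        if self._state_dict:
            start = self._state_dict["index"]  # hypothetical key
            self._state_dict = None  # the fix: reset after resuming
        yield from range(start, self.n)


ds = ToyStreamingDataset(5)
ds.load_state_dict({"index": 3})
first_epoch = list(ds)   # resumes mid-epoch -> [3, 4]
second_epoch = list(ds)  # fresh epoch -> [0, 1, 2, 3, 4]
```

Without the reset line, the second epoch would also start at index 3, which is the behavior this PR fixes.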

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Did you have fun?

Yes

codecov bot commented Aug 12, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (main@a39f31c). Learn more about missing BASE report.

Additional details and impacted files
@@          Coverage Diff          @@
##             main   #330   +/-   ##
=====================================
  Coverage        ?    79%           
=====================================
  Files           ?     34           
  Lines           ?   4989           
  Branches        ?      0           
=====================================
  Hits            ?   3949           
  Misses          ?   1040           
  Partials        ?      0           

@tchaton (Collaborator) commented Aug 12, 2024

Hey @vgurev. Great catch! Can you add a test?

@vgurev (Contributor, Author) commented Aug 13, 2024

OK, I will

@bhimrazy (Collaborator) commented Aug 15, 2024

Hi @vgurev,

Thanks for reporting the bug and getting started on it!

FYI, this issue was also addressed in PR #318, which was released yesterday.
So I tested by commenting out the lines in dataset.py that seem to reset the state at the end of the epoch: L345 and L352.

Test script

Create the optimized dataset:

from litdata import optimize


def random_data(index):
    return index


if __name__ == "__main__":
    optimize(
        fn=random_data,
        inputs=list(range(100)),
        output_dir="my_optimized_dataset",
        num_workers=4,
        chunk_bytes="64MB",
    )

Stream and check the state reset:

from litdata import StreamingDataLoader, StreamingDataset

dataset = StreamingDataset("my_optimized_dataset")
dataloader = StreamingDataLoader(dataset, batch_size=4, num_workers=2)

for batch_idx, batch in enumerate(dataloader):
    if batch_idx == 10:
        break

assert dataset._state_dict is None
dataloader.load_state_dict(dataloader.state_dict())
assert dataset._state_dict is not None

for batch_idx, batch in enumerate(dataloader):
    pass

assert dataset._state_dict is None

After testing, I noticed that the state reset issue still occurs when num_workers > 1, which is likely due to the multi-worker processing.

However, since this has already been addressed, please feel free to try it out with the latest release and let us know how it goes.

Thanks again for bringing this to our attention and contributing to the improvements!
cc: @tchaton

@tchaton (Collaborator) commented Aug 15, 2024

> (quoting @bhimrazy's comment above)

Yes, we need to reset it from the DataLoader when the StopIteration occurs with num_workers > 0.
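A rough sketch of what that DataLoader-side reset could look like (toy classes, not the litdata implementation; the _state_dict layout and worker handling are simplified assumptions):

```python
# With num_workers > 0, the workers iterate over copies of the dataset,
# so the parent process's dataset never clears its own _state_dict.
# The loader can catch StopIteration at the end of the epoch and reset
# the state itself.
class ToyDataset:
    def __init__(self, n):
        self.n = n
        self._state_dict = {"index": 0}  # pretend a resume state is loaded


class ToyLoader:
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        it = iter(range(self.dataset.n))  # stand-in for worker iterators
        while True:
            try:
                yield next(it)
            except StopIteration:
                # Reset the restore state when the epoch ends, since the
                # workers' dataset copies cannot do it for the parent.
                self.dataset._state_dict = None
                return


ds = ToyDataset(3)
batches = list(ToyLoader(ds))  # exhaust one epoch; ds._state_dict is cleared
```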

@vgurev vgurev closed this Aug 16, 2024
4 participants