StreamingDataset incompatibility with PyTorch Lightning #133
Comments
Hey @enrico-stauss, can you share a reproducible script of the problem? I'm not sure I fully follow it. The size of the `StreamingDataset` should be exact. If not, there is a bug.
That is exactly the problem. As you can read in the second issue I linked, even if the size IS exact, when specifying drop_last=True, pytorch_lightning seems to skip the validation. I'll try to provide an MWE for it when I get some time to spare.
@tchaton
Maybe changing to the standard `Dataset` type could also help with this one: #135 (comment).
Hey @enrico-stauss, changing the base type is a very large task and not something I am planning to do.
I understand. Do you have any idea how to proceed though, as it does severely break compatibility? I might have a look at it but can't promise anything.
@enrico-stauss I think I have a fix. Could you try this branch: #139? This will only work with the `StreamingDataLoader`. Example of the issue: there are 300 samples, 2 workers, and a batch size of 4. That is 300 / (4 * 2) = 37.5 batches. Because there is an incomplete batch, a StopIteration is triggered while fetching the last batch and the validation is skipped. My PR extends the `StreamingDataLoader` to pass the number of workers and batch size to the dataset, so the shuffler can drop the extra 0.5 batches causing the issue.
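As a rough sketch, the batch accounting from that example looks like this (the helper name `samples_for_full_batches` is illustrative, not the actual PR code):

```python
# Sketch of the batch accounting described above; the helper name and the
# idea of dropping down to full batches follow the comment, not the PR code.

def samples_for_full_batches(num_samples: int, num_workers: int, batch_size: int) -> int:
    """Largest count <= num_samples that fills complete batches on every worker."""
    per_round = num_workers * batch_size           # 2 workers * 4 = 8 samples per round
    return (num_samples // per_round) * per_round  # 300 // 8 = 37 rounds -> 296 samples

# The 4 leftover samples (the fractional 0.5 batches) are dropped,
# so no worker ends on a partial batch.
assert samples_for_full_batches(300, num_workers=2, batch_size=4) == 296
```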
Hey @enrico-stauss, can you confirm it works for you with the PR?
Sorry @tchaton, I did not find time to test it earlier. My MWE however still shows that no validation is performed, even with the updates that are not merged into main. I don't think it's possible to resolve this from the side of LitData without either removing the `len()` definition or changing the base class.
Hey @enrico-stauss Trust me, we are going to figure this out. And I am one of the core devs of PyTorch Lightning, so we will find a way. But I think this is a litdata problem. Would you be available to pair-debug this with me sometime next week? Also, would you be interested in joining the core team of litdata?
Hi @tchaton. But with you being a core dev of PyTorch Lightning too, I'm confident that we can figure it out. I think we can schedule a meeting for next week; let's get in touch on Discord. Then we can also talk about what you proposed. :)
🐛 Bug
It is a known issue with PyTorch's `IterableDataset` that problems can occur when the dataset defines `len()`. PyTorch Lightning even raises a warning to make the user aware. Unfortunately, the `StreamingDataset` is based on the `IterableDataset` and defines `len()`. See the following issues for more context:
I am aware that the root of the issue lies not within LitData and that it is tricky to resolve from the side of PyTorch Lightning. Still, defining `len()` in the `StreamingDataset` seems unfortunate, as I now cannot avoid the issue.
To Reproduce
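A minimal sketch of the kind of reproduction discussed above (assuming lightning >= 2.0 imports; `SizedIterable` is a hypothetical stand-in for `StreamingDataset`, defining `len()` on an `IterableDataset`, which is the combination at fault):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset
import lightning.pytorch as pl


class SizedIterable(IterableDataset):
    """Stand-in for StreamingDataset: an IterableDataset that defines __len__."""

    def __init__(self, n: int = 300):
        self.n = n

    def __iter__(self):
        # No worker sharding; each worker yields all samples. Enough for the sketch.
        return iter(torch.randn(self.n, 1))

    def __len__(self):
        # Defining __len__ on an IterableDataset is what triggers the issue.
        return self.n


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).mean()

    def validation_step(self, batch, batch_idx):
        print("validation ran")  # never printed when the validation epoch is skipped
        return self.layer(batch).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    train = DataLoader(SizedIterable(), batch_size=4, num_workers=2, drop_last=True)
    val = DataLoader(SizedIterable(), batch_size=4, num_workers=2, drop_last=True)
    trainer = pl.Trainer(max_epochs=1, num_sanity_val_steps=0)
    trainer.fit(Model(), train, val)
```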
Expected behavior
The validation epoch should not be skipped.
Is it maybe possible to use the standard (map-style) torch `Dataset` instead of the `IterableDataset` as the base class for the `StreamingDataset`?
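For reference, a minimal sketch of the map-style interface the question refers to, where `__getitem__` and `__len__` let the DataLoader's sampler control ordering and drop_last:

```python
# Sketch of the map-style interface: random access via __getitem__ plus
# __len__, so the DataLoader's sampler decides ordering and drop_last
# behaves predictably, avoiding the IterableDataset/__len__ conflict.
from torch.utils.data import Dataset


class MapStyleExample(Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]
```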
Environment
Additional context