StreamingDataset incompatibility with PyTorch Lightning #133

Closed
enrico-stauss opened this issue May 20, 2024 · 11 comments · Fixed by #147
Labels
bug (Something isn't working), help wanted (Extra attention is needed)

Comments

@enrico-stauss
Contributor

enrico-stauss commented May 20, 2024

🐛 Bug

It is a known issue that problems can occur when a PyTorch IterableDataset defines __len__(). PyTorch Lightning even raises a warning to make the user aware. Unfortunately, StreamingDataset is based on IterableDataset and defines __len__(). See the following issues for more context:

I am aware that the root of the issue does not lie within LitData and that it is tricky to resolve on the PyTorch Lightning side. Still, defining __len__() in StreamingDataset seems unfortunate, as I now cannot avoid the issue.
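
For illustration, a minimal sketch (not litdata code) of the pattern at issue: an IterableDataset that also defines __len__(). PyTorch Lightning warns when it detects this combination, since the reported length and the number of batches the iterator actually yields can disagree (for example with drop_last or multiple workers).

from torch.utils.data import IterableDataset

class SizedIterable(IterableDataset):
    # An iterable-style dataset that also reports a length,
    # which is the same combination StreamingDataset uses.
    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):
        return self.n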

To Reproduce

import os

import torch
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS

from litdata import StreamingDataset, StreamingDataLoader, optimize


DATA_BASE_DIR = "./data_issue133"
DROP_LAST_TRAIN_BATCH = True  # setting this to False makes the validation epoch run (see discussion below)
BATCH_SIZE = 4
N_WORKERS_DATALOADER = 4


def generate_sample(i):
    return {"index": i, "data": torch.rand((3, 20, 20))}


class MyDataModule(LightningDataModule):
    def __init__(self, batch_size: int = BATCH_SIZE, num_workers: int = N_WORKERS_DATALOADER, **dataloader_kwargs):
        super().__init__()
        self.dataset_root = DATA_BASE_DIR
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataloader_kwargs = dataloader_kwargs

    def prepare_data(self) -> None:
        data_dir_train = os.path.join(self.dataset_root, "train")
        if not os.path.exists(data_dir_train):
            os.makedirs(data_dir_train)
            optimize(
                fn=generate_sample,
                inputs=list(range(300)),
                output_dir=data_dir_train,
                chunk_size=1,
                num_workers=max(1, os.cpu_count() - 1),
            )

        data_dir_val = os.path.join(self.dataset_root, "val")
        if not os.path.exists(data_dir_val):
            os.makedirs(data_dir_val)
            optimize(
                fn=generate_sample,
                inputs=list(range(100)),
                output_dir=data_dir_val,
                chunk_size=25,
                num_workers=max(1, os.cpu_count() - 1),
            )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        dataset = StreamingDataset(os.path.join(self.dataset_root, "train"))
        dataloader = StreamingDataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=True,
            shuffle=True,
            drop_last=DROP_LAST_TRAIN_BATCH,
            **self.dataloader_kwargs,
        )
        return dataloader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        dataset = StreamingDataset(os.path.join(self.dataset_root, "val"))
        dataloader = StreamingDataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            persistent_workers=True,
            shuffle=True,
            drop_last=True,
            **self.dataloader_kwargs,
        )
        return dataloader


class MyModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = torch.nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)
        self.did_validation = False

    def training_step(self, batch):
        loss = self.model(batch["data"]).mean()
        return loss

    def on_train_epoch_start(self) -> None:
        print("TRAINING START")

    def on_train_epoch_end(self) -> None:
        print("TRAINING END")

    def validation_step(self, batch):
        self.did_validation = True
        loss = self.model(batch["data"]).mean()
        return loss

    def on_validation_epoch_start(self) -> None:
        print("VALIDATION START")

    def on_validation_epoch_end(self) -> None:
        print("VALIDATION END")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


def main():
    model = MyModel()
    trainer = Trainer(logger=False, max_epochs=2, num_sanity_val_steps=0)
    datamodule = MyDataModule()
    trainer.fit(model, datamodule=datamodule)
    print("Performed validation:", model.did_validation)


if __name__ == "__main__":
    main()

Expected behavior

The validation epoch should not be skipped.
Would it maybe be possible to use the standard (map-style) torch Dataset instead of IterableDataset as the base class for StreamingDataset?
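
A hypothetical sketch of what that could look like, assuming StreamingDataset supports random access via dataset[index] in addition to __len__ (if it does not, this does not apply): a thin map-style view would let the regular sampler and drop_last machinery compute batch counts exactly.

from torch.utils.data import Dataset

class MapStyleView(Dataset):
    # Hypothetical wrapper: exposes a streaming dataset through the
    # map-style interface so len() and indexing drive batching.
    def __init__(self, streaming_dataset):
        self.ds = streaming_dataset

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        return self.ds[index]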

Environment

pytorch-lightning==2.2.4
litdata==0.2.6

Additional context

enrico-stauss added the bug (Something isn't working) and help wanted (Extra attention is needed) labels on May 20, 2024
@tchaton
Collaborator

tchaton commented May 20, 2024

Hey @enrico-stauss, can you share a reproducible script of the problem? I'm not sure I fully follow it. The size of the StreamingDataset should be exact. If not, there is a bug.

@enrico-stauss
Contributor Author

That is exactly the problem. As you can read in the second issue I linked, even if the size IS exact, PyTorch Lightning seems to skip the validation when drop_last=True is specified.
Also, the warning is still raised by PyTorch Lightning.

I'll try to provide an MWE when I find some time to spare.

@enrico-stauss
Contributor Author

@tchaton
Please have a look at the modified original post. You can set DROP_LAST_TRAIN_BATCH = False to see that the validation epoch then runs.

@enrico-stauss
Contributor Author

Maybe changing to the standard Dataset type could also help with this one #135 (comment).

@tchaton
Collaborator

tchaton commented May 22, 2024

Hey @enrico-stauss, changing the base type is a very large task and not something I am planning to do.

@enrico-stauss
Contributor Author

I understand. Do you have any idea how to proceed, though, as it does severely break compatibility? I might have a look at it but can't promise anything.
In all honesty, I think the change should be made on the PyTorch Lightning side, but as mentioned here, it seems this is just not possible at the moment.

@tchaton
Collaborator

tchaton commented May 22, 2024

@enrico-stauss I think I have a fix. Could you try this branch: #139? This will work only with the StreamingDataLoader.

Example of the issue: there are 300 samples, 2 workers, and a batch size of 4. That is 300 / (4 * 2) = 37.5 batches per worker. Because the last batch is incomplete, a StopIteration is triggered while fetching it and the validation is skipped.

My PR extends the StreamingDataLoader to pass the number of workers and the batch size to the dataset, so the shuffler can drop the extra 0.5 batches causing the issue.
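
A rough sketch of the arithmetic behind this (not the actual litdata implementation): each worker's share of samples is truncated to a whole multiple of the batch size, so every worker yields the same integer number of batches and none of them hits an incomplete last batch.

num_samples, num_workers, batch_size = 300, 2, 4

per_worker = num_samples // num_workers                        # 150 samples per worker
print(per_worker / batch_size)                                 # 37.5 -> incomplete last batch

# Drop the remainder so each worker sees a whole number of batches:
usable_per_worker = (per_worker // batch_size) * batch_size    # 148 samples
print(usable_per_worker // batch_size)                         # 37 full batches per worker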

@tchaton
Collaborator

tchaton commented May 24, 2024

Hey @enrico-stauss, can you confirm it works for you with the PR?

@enrico-stauss
Contributor Author

Sorry @tchaton, I did not find time to test it earlier. My MWE, however, still shows that no validation is performed, even with the updates that are not yet merged into main. I don't think it's possible to resolve this on the LitData side without either removing the __len__ method or switching to the standard Dataset base class.
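
A possible (untested) workaround sketch in the meantime: hide __len__ from Lightning by wrapping the StreamingDataset in a plain IterableDataset that only delegates iteration. Note this also hides the length from everything else, so it may interfere with litdata's own bookkeeping; it is only meant to illustrate the idea.

from torch.utils.data import IterableDataset

class UnsizedWrapper(IterableDataset):
    # Delegates iteration but deliberately does not define __len__,
    # so Lightning treats the dataset as having unknown size.
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(self.dataset)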

@tchaton
Collaborator

tchaton commented May 25, 2024

Hey @enrico-stauss Trust me, we are going to figure this out. And I am one of the core devs of PyTorch Lightning, so we will find a way. But I think this is a litdata problem.

Would you be available to pair debug this with me sometime next week?

Also, would you be interested in joining the core team of litdata?

@enrico-stauss
Contributor Author

Hi @tchaton
The reason I believe it's not a LitData problem is that the second issue I linked in the original post already reported the issue using IterableDataset as the base class.

But with you being a core dev of PyTorch Lightning too, I'm confident we can figure it out.

I think we can schedule a meeting for next week; let's get in touch on Discord. Then we can also talk about what you proposed. :)
