From 51ee228df22533f861d55a1cf073a53ed03b17bf Mon Sep 17 00:00:00 2001 From: Satya Ortiz-Gagne Date: Tue, 25 Apr 2023 01:06:35 -0400 Subject: [PATCH] Add checkpointing example --- docs/examples/data/checkpointing/README.rst | 327 ++++++++++++++++++++ docs/examples/data/checkpointing/_index.rst | 34 ++ docs/examples/data/checkpointing/job.sh | 59 ++++ docs/examples/data/checkpointing/main.py | 221 +++++++++++++ docs/examples/data/index.rst | 1 + docs/examples/generate_diffs.sh | 4 + 6 files changed, 646 insertions(+) create mode 100644 docs/examples/data/checkpointing/README.rst create mode 100644 docs/examples/data/checkpointing/_index.rst create mode 100644 docs/examples/data/checkpointing/job.sh create mode 100644 docs/examples/data/checkpointing/main.py diff --git a/docs/examples/data/checkpointing/README.rst b/docs/examples/data/checkpointing/README.rst new file mode 100644 index 00000000..52eed0e0 --- /dev/null +++ b/docs/examples/data/checkpointing/README.rst @@ -0,0 +1,327 @@ +Checkpointing +============= + + +**Prerequisites** + +Make sure to read the following sections of the documentation before using this +example: + +* :ref:`pytorch_setup` +* :ref:`001 - Single GPU Job` + +The full source code for this example is available on `the mila-docs GitHub +repository. +`_ + + +**job.sh** + +.. code:: diff + + # distributed/001_single_gpu/job.sh -> data/checkpointing/job.sh + #!/bin/bash + #SBATCH --gpus-per-task=rtx8000:1 + #SBATCH --cpus-per-task=4 + #SBATCH --ntasks-per-node=1 + #SBATCH --mem=16G + #SBATCH --time=00:15:00 + +#SBATCH --signal=B:TERM@300 # tells the controller to send SIGTERM to the job 5 + + # min before its time ends to give it a chance for + + # better cleanup. If you cancel the job manually, + + # make sure that you specify the signal as TERM like + + # so scancel --signal=TERM . + + # https://dhruveshp.com/blog/2021/signal-propagation-on-slurm/ + + + +# trap the signal to the main BATCH script here. + +sig_handler() + +{ + + echo "BATCH interrupted" + + wait # wait for all children, this is important! + +} + + + +trap 'sig_handler' SIGINT SIGTERM SIGCONT + + + # Echo time and hostname into log + echo "Date: $(date)" + echo "Hostname: $(hostname)" + + + # Ensure only anaconda/3 module loaded. + module --quiet purge + # This example uses Conda to manage package dependencies. + # See https://docs.mila.quebec/Userguide.html#conda for more information. + module load anaconda/3 + module load cuda/11.7 + + + + # Creating the environment for the first time: + # conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \ + -# pytorch-cuda=11.7 -c pytorch -c nvidia + +# pytorch-cuda=11.7 scipy -c pytorch -c nvidia + # Other conda packages: + # conda install -y -n pytorch -c conda-forge rich tqdm + + # Activate pre-existing environment. + conda activate pytorch + + + # Stage dataset into $SLURM_TMPDIR + mkdir -p $SLURM_TMPDIR/data + cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/ + # General-purpose alternatives combining copy and unpack: + # unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/ + # tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/ + + + # Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0 + unset CUDA_VISIBLE_DEVICES + + # Execute Python script + python main.py + + +**main.py** + +.. 
code:: diff + + # distributed/001_single_gpu/main.py -> data/checkpointing/main.py + """Single-GPU training example.""" + import logging + import os + -from pathlib import Path + +import shutil + + import rich.logging + import torch + from torch import Tensor, nn + from torch.nn import functional as F + from torch.utils.data import DataLoader, random_split + from torchvision import transforms + from torchvision.datasets import CIFAR10 + from torchvision.models import resnet18 + from tqdm import tqdm + + + +try: + + _CHECKPTS_DIR = f"{os.environ['SCRATCH']}/checkpoints" + +except KeyError: + + _CHECKPTS_DIR = "../checkpoints" + + + + + def main(): + training_epochs = 10 + learning_rate = 5e-4 + weight_decay = 1e-4 + batch_size = 128 + + resume_file = f"{_CHECKPTS_DIR}/resnet18_cifar10/checkpoint.pth.tar" + + start_epoch = 0 + + best_acc = 0 + + # Check that the GPU is available + assert torch.cuda.is_available() and torch.cuda.device_count() > 0 + device = torch.device("cuda", 0) + + # Setup logging (optional, but much better than using print statements) + logging.basicConfig( + level=logging.INFO, + handlers=[rich.logging.RichHandler(markup=True)], # Very pretty, uses the `rich` package. + ) + + logger = logging.getLogger(__name__) + + - # Create a model and move it to the GPU. + + # Create a model. + model = resnet18(num_classes=10) + + + + # Resume from a checkpoint + + if os.path.isfile(resume_file): + + logger.debug(f"=> loading checkpoint '{resume_file}'") + + # Map model to be loaded to gpu. + + checkpoint = torch.load(resume_file, map_location="cuda:0") + + start_epoch = checkpoint["epoch"] + + best_acc = checkpoint["best_acc"] + + # best_acc may be from a checkpoint from a different GPU + + best_acc = best_acc.to(device) + + model.load_state_dict(checkpoint["state_dict"]) + + optimizer.load_state_dict(checkpoint["optimizer"]) + + logger.debug(f"=> loaded checkpoint '{resume_file}' (epoch {checkpoint['epoch']})") + + else: + + logger.debug(f"=> no checkpoint found at '{resume_file}'") + + + + # Move the model to the GPU. + model.to(device=device) + + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) + + # Setup CIFAR10 + num_workers = get_num_workers() + - dataset_path = Path(os.environ.get("SLURM_TMPDIR", ".")) / "data" + - train_dataset, valid_dataset, test_dataset = make_datasets(str(dataset_path)) + + try: + + dataset_path = f"{os.environ['SLURM_TMPDIR']}/data" + + except KeyError: + + dataset_path = "../dataset" + + train_dataset, valid_dataset, test_dataset = make_datasets(dataset_path) + train_dataloader = DataLoader( + train_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + ) + valid_dataloader = DataLoader( + valid_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + ) + test_dataloader = DataLoader( # NOTE: Not used in this example. + test_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + ) + + # Checkout the "checkpointing and preemption" example for more info! + logger.debug("Starting training from scratch.") + + - for epoch in range(training_epochs): + + for epoch in range(start_epoch, training_epochs): + logger.debug(f"Starting epoch {epoch}/{training_epochs}") + + - # Set the model in training mode (important for e.g. BatchNorm and Dropout layers) + + # Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers) + model.train() + + # NOTE: using a progress bar from tqdm because it's nicer than using `print`. 
+ progress_bar = tqdm( + total=len(train_dataloader), + desc=f"Train epoch {epoch}", + ) + + # Training loop + for batch in train_dataloader: + # Move the batch to the GPU before we pass it to the model + batch = tuple(item.to(device) for item in batch) + x, y = batch + + # Forward pass + logits: Tensor = model(x) + + loss = F.cross_entropy(logits, y) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Calculate some metrics: + n_correct_predictions = logits.detach().argmax(-1).eq(y).sum() + n_samples = y.shape[0] + accuracy = n_correct_predictions / n_samples + + logger.debug(f"Accuracy: {accuracy.item():.2%}") + logger.debug(f"Average Loss: {loss.item()}") + + # Advance the progress bar one step, and update the "postfix" () the progress bar. (nicer than just) + progress_bar.update(1) + progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item()) + progress_bar.close() + + val_loss, val_accuracy = validation_loop(model, valid_dataloader, device) + logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}") + + + # remember best acc and save checkpoint + + is_best = val_accuracy > best_acc + + best_acc = max(val_accuracy, best_acc) + + + + save_checkpoint({ + + "epoch": epoch + 1, + + "arch": "resnet18", + + "state_dict": model.state_dict(), + + "best_acc": best_acc, + + "optimizer": optimizer.state_dict(), + + }, is_best) + + + print("Done!") + + + @torch.no_grad() + def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device): + model.eval() + + total_loss = 0.0 + n_samples = 0 + correct_predictions = 0 + + for batch in dataloader: + batch = tuple(item.to(device) for item in batch) + x, y = batch + + logits: Tensor = model(x) + loss = F.cross_entropy(logits, y) + + batch_n_samples = x.shape[0] + batch_correct_predictions = logits.argmax(-1).eq(y).sum() + + total_loss += loss.item() + n_samples += batch_n_samples + correct_predictions += batch_correct_predictions + + accuracy = correct_predictions / n_samples + return total_loss, accuracy + + + def make_datasets( + dataset_path: str, + val_split: float = 0.1, + val_split_seed: int = 42, + ): + """Returns the training, validation, and test splits for CIFAR10. + + NOTE: We don't use image transforms here for simplicity. + Having different transformations for train and validation would complicate things a bit. + Later examples will show how to do the train/val/test split properly when using transforms. + """ + train_dataset = CIFAR10( + root=dataset_path, transform=transforms.ToTensor(), download=True, train=True + ) + test_dataset = CIFAR10( + root=dataset_path, transform=transforms.ToTensor(), download=True, train=False + ) + # Split the training dataset into a training and validation set. 
+ - n_samples = len(train_dataset) + - n_valid = int(val_split * n_samples) + - n_train = n_samples - n_valid + train_dataset, valid_dataset = random_split( + - train_dataset, (n_train, n_valid), torch.Generator().manual_seed(val_split_seed) + + train_dataset, ((1 - val_split), val_split), torch.Generator().manual_seed(val_split_seed) + ) + return train_dataset, valid_dataset, test_dataset + + + def get_num_workers() -> int: + """Gets the optimal number of DatLoader workers to use in the current job.""" + if "SLURM_CPUS_PER_TASK" in os.environ: + return int(os.environ["SLURM_CPUS_PER_TASK"]) + if hasattr(os, "sched_getaffinity"): + return len(os.sched_getaffinity(0)) + return torch.multiprocessing.cpu_count() + + + +def save_checkpoint(state, is_best, filename=f"{_CHECKPTS_DIR}/checkpoint.pth.tar"): + + torch.save(state, filename) + + if is_best: + + _dir = os.path.dirname(filename) + + shutil.copyfile(filename, f"{_dir}/model_best.pth.tar") + + + + + if __name__ == "__main__": + main() + + +**Running this example** + +.. code-block:: bash + + $ sbatch job.sh diff --git a/docs/examples/data/checkpointing/_index.rst b/docs/examples/data/checkpointing/_index.rst new file mode 100644 index 00000000..afbc2282 --- /dev/null +++ b/docs/examples/data/checkpointing/_index.rst @@ -0,0 +1,34 @@ +Checkpointing +============= + + +**Prerequisites** + +Make sure to read the following sections of the documentation before using this +example: + +* :ref:`pytorch_setup` +* :ref:`001 - Single GPU Job` + +The full source code for this example is available on `the mila-docs GitHub +repository. +`_ + + +**job.sh** + +.. literalinclude:: examples/data/checkpointing/job.sh.diff + :language: diff + + +**main.py** + +.. literalinclude:: examples/data/checkpointing/main.py.diff + :language: diff + + +**Running this example** + +.. code-block:: bash + + $ sbatch job.sh diff --git a/docs/examples/data/checkpointing/job.sh b/docs/examples/data/checkpointing/job.sh new file mode 100644 index 00000000..1b3cf292 --- /dev/null +++ b/docs/examples/data/checkpointing/job.sh @@ -0,0 +1,59 @@ +#!/bin/bash +#SBATCH --gpus-per-task=rtx8000:1 +#SBATCH --cpus-per-task=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --mem=16G +#SBATCH --time=00:15:00 +#SBATCH --signal=B:TERM@300 # tells the controller to send SIGTERM to the job 5 + # min before its time ends to give it a chance for + # better cleanup. If you cancel the job manually, + # make sure that you specify the signal as TERM like + # so scancel --signal=TERM . + # https://dhruveshp.com/blog/2021/signal-propagation-on-slurm/ + +# trap the signal to the main BATCH script here. +sig_handler() +{ + echo "BATCH interrupted" + wait # wait for all children, this is important! +} + +trap 'sig_handler' SIGINT SIGTERM SIGCONT + + +# Echo time and hostname into log +echo "Date: $(date)" +echo "Hostname: $(hostname)" + + +# Ensure only anaconda/3 module loaded. +module --quiet purge +# This example uses Conda to manage package dependencies. +# See https://docs.mila.quebec/Userguide.html#conda for more information. +module load anaconda/3 +module load cuda/11.7 + + +# Creating the environment for the first time: +# conda create -y -n pytorch python=3.9 pytorch torchvision torchaudio \ +# pytorch-cuda=11.7 scipy -c pytorch -c nvidia +# Other conda packages: +# conda install -y -n pytorch -c conda-forge rich tqdm + +# Activate pre-existing environment. 
+conda activate pytorch
+
+
+# Stage dataset into $SLURM_TMPDIR
+mkdir -p $SLURM_TMPDIR/data
+cp /network/datasets/cifar10/cifar-10-python.tar.gz $SLURM_TMPDIR/data/
+# General-purpose alternatives combining copy and unpack:
+#     unzip /network/datasets/some/file.zip -d $SLURM_TMPDIR/data/
+#     tar -xf /network/datasets/some/file.tar -C $SLURM_TMPDIR/data/
+
+
+# Fixes issues with MIG-ed GPUs with versions of PyTorch < 2.0
+unset CUDA_VISIBLE_DEVICES
+
+# Execute Python script
+python main.py
diff --git a/docs/examples/data/checkpointing/main.py b/docs/examples/data/checkpointing/main.py
new file mode 100644
index 00000000..e274acd3
--- /dev/null
+++ b/docs/examples/data/checkpointing/main.py
@@ -0,0 +1,221 @@
+"""Single-GPU training example."""
+import logging
+import os
+import shutil
+
+import rich.logging
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, random_split
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+from torchvision.models import resnet18
+from tqdm import tqdm
+
+
+try:
+    _CHECKPTS_DIR = f"{os.environ['SCRATCH']}/checkpoints"
+except KeyError:
+    _CHECKPTS_DIR = "../checkpoints"
+
+
+def main():
+    training_epochs = 10
+    learning_rate = 5e-4
+    weight_decay = 1e-4
+    batch_size = 128
+    resume_file = f"{_CHECKPTS_DIR}/resnet18_cifar10/checkpoint.pth.tar"
+    start_epoch = 0
+    best_acc = 0
+
+    # Check that the GPU is available
+    assert torch.cuda.is_available() and torch.cuda.device_count() > 0
+    device = torch.device("cuda", 0)
+
+    # Setup logging (optional, but much better than using print statements)
+    logging.basicConfig(
+        level=logging.INFO,
+        handlers=[rich.logging.RichHandler(markup=True)],  # Very pretty, uses the `rich` package.
+    )
+
+    logger = logging.getLogger(__name__)
+
+    # Create a model and move it to the GPU.
+    model = resnet18(num_classes=10)
+    model.to(device=device)
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+
+    # Resume from a checkpoint, if one exists. This has to happen after the model
+    # and optimizer are created, since their state dicts get loaded here.
+    if os.path.isfile(resume_file):
+        logger.debug(f"=> loading checkpoint '{resume_file}'")
+        # Map the loaded tensors to the GPU.
+        checkpoint = torch.load(resume_file, map_location="cuda:0")
+        start_epoch = checkpoint["epoch"]
+        best_acc = checkpoint["best_acc"]
+        # best_acc may come from a checkpoint saved on a different GPU.
+        best_acc = best_acc.to(device)
+        model.load_state_dict(checkpoint["state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        logger.debug(f"=> loaded checkpoint '{resume_file}' (epoch {checkpoint['epoch']})")
+    else:
+        logger.debug(f"=> no checkpoint found at '{resume_file}'")
+
+    # Setup CIFAR10
+    num_workers = get_num_workers()
+    try:
+        dataset_path = f"{os.environ['SLURM_TMPDIR']}/data"
+    except KeyError:
+        dataset_path = "../dataset"
+    train_dataset, valid_dataset, test_dataset = make_datasets(dataset_path)
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        shuffle=True,
+    )
+    valid_dataloader = DataLoader(
+        valid_dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        shuffle=False,
+    )
+    test_dataloader = DataLoader(  # NOTE: Not used in this example.
+        test_dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        shuffle=False,
+    )
+
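+    # If a checkpoint was found above, `start_epoch` and `best_acc` were restored
+    # from it, so the loop below resumes where the previous run left off.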
+    logger.debug(f"Starting training at epoch {start_epoch}.")
+
+    for epoch in range(start_epoch, training_epochs):
+        logger.debug(f"Starting epoch {epoch}/{training_epochs}")
+
+        # Set the model in training mode (this is important for e.g. BatchNorm and Dropout layers)
+        model.train()
+
+        # NOTE: using a progress bar from tqdm because it's nicer than using `print`.
+        progress_bar = tqdm(
+            total=len(train_dataloader),
+            desc=f"Train epoch {epoch}",
+        )
+
+        # Training loop
+        for batch in train_dataloader:
+            # Move the batch to the GPU before we pass it to the model
+            batch = tuple(item.to(device) for item in batch)
+            x, y = batch
+
+            # Forward pass
+            logits: Tensor = model(x)
+
+            loss = F.cross_entropy(logits, y)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # Calculate some metrics:
+            n_correct_predictions = logits.detach().argmax(-1).eq(y).sum()
+            n_samples = y.shape[0]
+            accuracy = n_correct_predictions / n_samples
+
+            logger.debug(f"Accuracy: {accuracy.item():.2%}")
+            logger.debug(f"Average Loss: {loss.item()}")
+
+            # Advance the progress bar one step and update its "postfix" with the current metrics.
+            progress_bar.update(1)
+            progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())
+        progress_bar.close()
+
+        val_loss, val_accuracy = validation_loop(model, valid_dataloader, device)
+        logger.info(f"Epoch {epoch}: Val loss: {val_loss:.3f} accuracy: {val_accuracy:.2%}")
+
+        # Remember the best accuracy seen so far and save a checkpoint.
+        is_best = val_accuracy > best_acc
+        best_acc = max(val_accuracy, best_acc)
+
+        save_checkpoint({
+            "epoch": epoch + 1,
+            "arch": "resnet18",
+            "state_dict": model.state_dict(),
+            "best_acc": best_acc,
+            "optimizer": optimizer.state_dict(),
+        }, is_best)
+
+    print("Done!")
+
+
+@torch.no_grad()
+def validation_loop(model: nn.Module, dataloader: DataLoader, device: torch.device):
+    model.eval()
+
+    total_loss = 0.0
+    n_samples = 0
+    correct_predictions = 0
+
+    for batch in dataloader:
+        batch = tuple(item.to(device) for item in batch)
+        x, y = batch
+
+        logits: Tensor = model(x)
+        loss = F.cross_entropy(logits, y)
+
+        batch_n_samples = x.shape[0]
+        batch_correct_predictions = logits.argmax(-1).eq(y).sum()
+
+        total_loss += loss.item()
+        n_samples += batch_n_samples
+        correct_predictions += batch_correct_predictions
+
+    accuracy = correct_predictions / n_samples
+    return total_loss, accuracy
+
+
+def make_datasets(
+    dataset_path: str,
+    val_split: float = 0.1,
+    val_split_seed: int = 42,
+):
+    """Returns the training, validation, and test splits for CIFAR10.
+
+    NOTE: We don't use image transforms here for simplicity.
+    Having different transformations for train and validation would complicate things a bit.
+    Later examples will show how to do the train/val/test split properly when using transforms.
+    """
+    train_dataset = CIFAR10(
+        root=dataset_path, transform=transforms.ToTensor(), download=True, train=True
+    )
+    test_dataset = CIFAR10(
+        root=dataset_path, transform=transforms.ToTensor(), download=True, train=False
+    )
+    # Split the training dataset into a training and validation set.
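+    # NOTE: passing fractional lengths to `random_split` (as done below) requires
+    # PyTorch >= 1.13; on older versions, pass absolute counts instead.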
+    train_dataset, valid_dataset = random_split(
+        train_dataset, ((1 - val_split), val_split), torch.Generator().manual_seed(val_split_seed)
+    )
+    return train_dataset, valid_dataset, test_dataset
+
+
+def get_num_workers() -> int:
+    """Gets the optimal number of DataLoader workers to use in the current job."""
+    if "SLURM_CPUS_PER_TASK" in os.environ:
+        return int(os.environ["SLURM_CPUS_PER_TASK"])
+    if hasattr(os, "sched_getaffinity"):
+        return len(os.sched_getaffinity(0))
+    return torch.multiprocessing.cpu_count()
+
+
+def save_checkpoint(state, is_best, filename=f"{_CHECKPTS_DIR}/resnet18_cifar10/checkpoint.pth.tar"):
+    # Make sure the checkpoint directory exists, then write the checkpoint there.
+    # The default filename matches `resume_file` in `main()`, so the next run can resume from it.
+    os.makedirs(os.path.dirname(filename), exist_ok=True)
+    torch.save(state, filename)
+    if is_best:
+        _dir = os.path.dirname(filename)
+        shutil.copyfile(filename, f"{_dir}/model_best.pth.tar")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/examples/data/index.rst b/docs/examples/data/index.rst
index bd8e2691..30b6a1b2 100644
--- a/docs/examples/data/index.rst
+++ b/docs/examples/data/index.rst
@@ -5,3 +5,4 @@ Data Handling during Training
 
 .. include:: examples/data/torchvision/_index.rst
 .. include:: examples/data/hf/_index.rst
+.. include:: examples/data/checkpointing/_index.rst
diff --git a/docs/examples/generate_diffs.sh b/docs/examples/generate_diffs.sh
index 5d8f1cea..456cac24 100755
--- a/docs/examples/generate_diffs.sh
+++ b/docs/examples/generate_diffs.sh
@@ -31,6 +31,10 @@ generate_diff distributed/001_single_gpu/main.py distributed/002_multi_gpu/main.
 generate_diff distributed/002_multi_gpu/job.sh distributed/003_multi_node/job.sh
 generate_diff distributed/002_multi_gpu/main.py distributed/003_multi_node/main.py
 
+# single_gpu -> checkpointing
+generate_diff distributed/001_single_gpu/job.sh data/checkpointing/job.sh
+generate_diff distributed/001_single_gpu/main.py data/checkpointing/main.py
+
 # single_gpu -> huggingface
 generate_diff distributed/001_single_gpu/job.sh data/hf/job.sh
 generate_diff distributed/001_single_gpu/main.py data/hf/main.py
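
**Reacting to the early SIGTERM (optional sketch)**

``job.sh`` asks the scheduler (via ``#SBATCH --signal=B:TERM@300``) to send ``SIGTERM``
five minutes before the time limit, but ``main.py`` only relies on its per-epoch
checkpoint and never listens for that signal. The snippet below is a minimal,
self-contained sketch of how a training script could react: a handler sets a flag,
and the loop checks it once the latest checkpoint is on disk. The names, the dummy
loop and the ``time.sleep`` stand-in are illustrative only, not part of the example;
also note that with the ``B:`` prefix the signal is delivered to the batch shell, so
``job.sh`` would still need to forward it to the Python process, as discussed in the
blog post linked in ``job.sh``.

.. code-block:: python

   import signal
   import time

   stop_requested = False


   def _request_stop(signum, frame):
       # Just set a flag here; do the actual checkpointing from the training
       # loop, where the model and optimizer are in a consistent state.
       global stop_requested
       stop_requested = True


   signal.signal(signal.SIGTERM, _request_stop)

   for epoch in range(10):
       time.sleep(1)  # stand-in for one epoch of training + save_checkpoint(...)
       if stop_requested:
           print(f"SIGTERM received, stopping cleanly after epoch {epoch}")
           break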