From 5c145a3b964451ea57715d391e224cc1d8a24c00 Mon Sep 17 00:00:00 2001 From: Olivier Dulcy <106678676+odulcy-mindee@users.noreply.github.com> Date: Tue, 17 Dec 2024 16:54:56 +0100 Subject: [PATCH] feat: :sparkles: specify `output_dir` in reference scripts (#1820) --- references/classification/train_pytorch_character.py | 4 +++- references/classification/train_pytorch_orientation.py | 4 +++- references/classification/train_tensorflow_character.py | 4 +++- references/classification/train_tensorflow_orientation.py | 4 +++- references/detection/train_pytorch.py | 6 ++++-- references/detection/train_pytorch_ddp.py | 6 ++++-- references/detection/train_tensorflow.py | 6 ++++-- references/recognition/train_pytorch.py | 3 ++- references/recognition/train_pytorch_ddp.py | 3 ++- references/recognition/train_tensorflow.py | 3 ++- 10 files changed, 30 insertions(+), 13 deletions(-) diff --git a/references/classification/train_pytorch_character.py b/references/classification/train_pytorch_character.py index 14bb749066..999ae08c22 100644 --- a/references/classification/train_pytorch_character.py +++ b/references/classification/train_pytorch_character.py @@ -11,6 +11,7 @@ import logging import multiprocessing as mp import time +from pathlib import Path import numpy as np import torch @@ -335,7 +336,7 @@ def main(args): val_loss, acc = evaluate(model, val_loader, batch_transforms) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.state_dict(), f"./{exp_name}.pt") + torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") min_loss = val_loss print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") # W&B @@ -370,6 +371,7 @@ def parse_args(): ) parser.add_argument("arch", type=str, help="text-recognition model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training") diff --git a/references/classification/train_pytorch_orientation.py b/references/classification/train_pytorch_orientation.py index d3a8a57fa2..207ededd56 100644 --- a/references/classification/train_pytorch_orientation.py +++ b/references/classification/train_pytorch_orientation.py @@ -11,6 +11,7 @@ import logging import multiprocessing as mp import time +from pathlib import Path import numpy as np import torch @@ -341,7 +342,7 @@ def main(args): val_loss, acc = evaluate(model, val_loader, batch_transforms) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.state_dict(), f"./{exp_name}.pt") + torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") min_loss = val_loss print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") # W&B @@ -376,6 +377,7 @@ def parse_args(): ) parser.add_argument("arch", type=str, help="classification model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on") parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") diff --git a/references/classification/train_tensorflow_character.py b/references/classification/train_tensorflow_character.py index 3049e60ecd..33f77fe941 100644 --- a/references/classification/train_tensorflow_character.py +++ b/references/classification/train_tensorflow_character.py @@ -14,6 +14,7 @@ import datetime import time +from pathlib import Path import numpy as np import tensorflow as tf @@ -298,7 +299,7 @@ def main(args): val_loss, acc = evaluate(model, val_loader, batch_transforms) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}.weights.h5") + model.save_weights(Path(args.output_dir) / f"{exp_name}.weights.h5") min_loss = val_loss print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") # W&B @@ -345,6 +346,7 @@ def parse_args(): ) parser.add_argument("arch", type=str, help="text-recognition model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train the model on") parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size for training") diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py index 05ae7fce96..d582f747b6 100644 --- a/references/classification/train_tensorflow_orientation.py +++ b/references/classification/train_tensorflow_orientation.py @@ -14,6 +14,7 @@ import datetime import time +from pathlib import Path import numpy as np import tensorflow as tf @@ -308,7 +309,7 @@ def main(args): val_loss, acc = evaluate(model, val_loader, batch_transforms) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}.weights.h5") + model.save_weights(Path(args.output_dir) / f"{exp_name}.weights.h5") min_loss = val_loss print(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") # W&B @@ -355,6 +356,7 @@ def parse_args(): ) parser.add_argument("arch", type=str, help="classification model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--type", type=str, required=True, choices=["page", "crop"], help="type of data to train on") parser.add_argument("--train_path", type=str, help="path to training data folder") parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py index 8d1bfa4499..48b5ca022a 100644 --- a/references/detection/train_pytorch.py +++ b/references/detection/train_pytorch.py @@ -12,6 +12,7 @@ import logging import multiprocessing as mp import time +from pathlib import Path import numpy as np import torch @@ -390,11 +391,11 @@ def main(args): val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.state_dict(), f"./{exp_name}.pt") + torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") min_loss = val_loss if args.save_interval_epoch: print(f"Saving state at epoch: {epoch + 1}") - torch.save(model.state_dict(), f"./{exp_name}_epoch{epoch + 1}.pt") + torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt") log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " if any(val is None for val in (recall, precision, mean_iou)): log_msg += "(Undefined metric value, caused by empty GTs or predictions)" @@ -428,6 +429,7 @@ def parse_args(): ) parser.add_argument("arch", type=str, help="text-detection model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") diff --git a/references/detection/train_pytorch_ddp.py b/references/detection/train_pytorch_ddp.py index ba0bf5d8f3..937bc208ac 100644 --- a/references/detection/train_pytorch_ddp.py +++ b/references/detection/train_pytorch_ddp.py @@ -11,6 +11,7 @@ import hashlib import multiprocessing import time +from pathlib import Path import numpy as np import torch @@ -410,11 +411,11 @@ def main(rank: int, world_size: int, args): ) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.module.state_dict(), f"./{exp_name}.pt") + torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") min_loss = val_loss if args.save_interval_epoch: print(f"Saving state at epoch: {epoch + 1}") - torch.save(model.state_dict(), f"./{exp_name}_epoch{epoch + 1}.pt") + torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}_epoch{epoch + 1}.pt") log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " if any(val is None for val in (recall, precision, mean_iou)): log_msg += "(Undefined metric value, caused by empty GTs or predictions)" @@ -453,6 +454,7 @@ def parse_args(): parser.add_argument("--backend", default="nccl", type=str, help="backend to use for torch DDP") parser.add_argument("arch", type=str, help="text-detection model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py index 561447c5f7..bde115e7b7 100644 --- a/references/detection/train_tensorflow.py +++ b/references/detection/train_tensorflow.py @@ -15,6 +15,7 @@ import datetime import hashlib import time +from pathlib import Path import numpy as np import tensorflow as tf @@ -353,11 +354,11 @@ def main(args): val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}.weights.h5") + model.save_weights(Path(args.output_dir) / f"{exp_name}.weights.h5") min_loss = val_loss if args.save_interval_epoch: print(f"Saving state at epoch: {epoch + 1}") - model.save_weights(f"./{exp_name}_{epoch + 1}.weights.h5") + model.save_weights(Path(args.output_dir) / f"{exp_name}_{epoch + 1}.weights.h5") log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " if any(val is None for val in (recall, precision, mean_iou)): log_msg += "(Undefined metric value, caused by empty GTs or predictions)" @@ -401,6 +402,7 @@ def parse_args(): ) parser.add_argument("arch", type=str, help="text-detection model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--train_path", type=str, required=True, help="path to training data folder") parser.add_argument("--val_path", type=str, required=True, help="path to validation data folder") parser.add_argument("--name", type=str, default=None, help="Name of your training experiment") diff --git a/references/recognition/train_pytorch.py b/references/recognition/train_pytorch.py index 608c5d4145..a5632e79bf 100644 --- a/references/recognition/train_pytorch.py +++ b/references/recognition/train_pytorch.py @@ -395,7 +395,7 @@ def main(args): val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.state_dict(), f"./{exp_name}.pt") + torch.save(model.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") min_loss = val_loss print( f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " @@ -427,6 +427,7 @@ def parse_args(): ) parser.add_argument("arch", type=str, help="text-recognition model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)") parser.add_argument("--val_path", type=str, default=None, help="path to val data folder") parser.add_argument( diff --git a/references/recognition/train_pytorch_ddp.py b/references/recognition/train_pytorch_ddp.py index 2f40ed766e..d1c29eec27 100644 --- a/references/recognition/train_pytorch_ddp.py +++ b/references/recognition/train_pytorch_ddp.py @@ -329,7 +329,7 @@ def main(rank: int, world_size: int, args): # random parameters and gradients are synchronized in backward passes. # Therefore, saving it in one process is sufficient. print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - torch.save(model.module.state_dict(), f"./{exp_name}.pt") + torch.save(model.module.state_dict(), Path(args.output_dir) / f"{exp_name}.pt") min_loss = val_loss print( f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " @@ -365,6 +365,7 @@ def parse_args(): parser.add_argument("--backend", default="nccl", type=str, help="Backend to use for Torch DDP") parser.add_argument("arch", type=str, help="text-recognition model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)") parser.add_argument("--val_path", type=str, default=None, help="path to val data folder") parser.add_argument( diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py index 8a6a9e1e01..bbb8d77475 100644 --- a/references/recognition/train_tensorflow.py +++ b/references/recognition/train_tensorflow.py @@ -350,7 +350,7 @@ def main(args): val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric) if val_loss < min_loss: print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") - model.save_weights(f"./{exp_name}.weights.h5") + model.save_weights(Path(args.output_dir) / f"{exp_name}.weights.h5") min_loss = val_loss print( f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " @@ -391,6 +391,7 @@ def parse_args(): ) parser.add_argument("arch", type=str, help="text-recognition model to train") + parser.add_argument("--output_dir", type=str, default=".", help="path to save checkpoints and final model") parser.add_argument("--train_path", type=str, default=None, help="path to train data folder(s)") parser.add_argument("--val_path", type=str, default=None, help="path to val data folder") parser.add_argument(