From effa80c1a49de3f6126bfd36875c513e2ac4df80 Mon Sep 17 00:00:00 2001 From: Theo Date: Fri, 31 May 2024 17:56:36 +0100 Subject: [PATCH] Refactor plotting a tiny bit --- src/base_trainer.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/base_trainer.py b/src/base_trainer.py index bffae53..301a0d1 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -25,7 +25,7 @@ from tqdm import tqdm from conf import project as project_conf -from utils import blink_pbar, to_cuda, update_pbar_str +from utils import blink_pbar, colorize, to_cuda, update_pbar_str from utils.helpers import BestNModelSaver from utils.training import visualize_model_predictions @@ -260,9 +260,12 @@ def train( """ if model_ckpt_path is not None: self._load_checkpoint(model_ckpt_path) - if project_conf.PLOT_ENABLED: - self._setup_plot() - print(f"[*] Training for {epochs} epochs") + print( + colorize( + f"[*] Training {self._run_name} for {epochs} epochs", + project_conf.ANSI_COLORS["green"], + ) + ) self._viz_n_samples = visualize_n_samples train_losses: List[float] = [] val_losses: List[float] = [] @@ -311,12 +314,16 @@ def train( ) @staticmethod - def _setup_plot(): + def _setup_plot(run_name: str, log_scale: bool = False): """Setup the plot for training and validation losses.""" - plt.title("Training and validation losses") + plt.title(f"Training curves for {run_name}") plt.theme("dark") plt.xlabel("Epoch") - plt.ylabel("Loss") + if log_scale: + plt.ylabel("Loss (log scale)") + plt.yscale("log") + else: + plt.ylabel("Loss") plt.grid(True, True) def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]): @@ -329,18 +336,12 @@ def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]): None """ plt.clf() - plt.theme("dark") - plt.xlabel("Epoch") if project_conf.LOG_SCALE_PLOT: if any(loss_val <= 0 for loss_val in train_losses + val_losses): raise ValueError( "Cannot plot on a log scale if there are non-positive losses." ) - plt.ylabel("Loss (log scale)") - plt.yscale("log") - else: - plt.ylabel("Loss") - plt.grid(True, True) + self._setup_plot(self._run_name, log_scale=project_conf.LOG_SCALE_PLOT) plt.plot( list(range(self._starting_epoch, epoch + 1)), train_losses,