Skip to content

Commit

Permalink
Refactor plotting a tiny bit
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed May 31, 2024
1 parent b1c6c2e commit effa80c
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions src/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tqdm import tqdm

from conf import project as project_conf
from utils import blink_pbar, to_cuda, update_pbar_str
from utils import blink_pbar, colorize, to_cuda, update_pbar_str
from utils.helpers import BestNModelSaver
from utils.training import visualize_model_predictions

Expand Down Expand Up @@ -260,9 +260,12 @@ def train(
"""
if model_ckpt_path is not None:
self._load_checkpoint(model_ckpt_path)
if project_conf.PLOT_ENABLED:
self._setup_plot()
print(f"[*] Training for {epochs} epochs")
print(
colorize(
f"[*] Training {self._run_name} for {epochs} epochs",
project_conf.ANSI_COLORS["green"],
)
)
self._viz_n_samples = visualize_n_samples
train_losses: List[float] = []
val_losses: List[float] = []
Expand Down Expand Up @@ -311,12 +314,16 @@ def train(
)

@staticmethod
def _setup_plot():
def _setup_plot(run_name: str, log_scale: bool = False):
"""Setup the plot for training and validation losses."""
plt.title("Training and validation losses")
plt.title(f"Training curves for {run_name}")
plt.theme("dark")
plt.xlabel("Epoch")
plt.ylabel("Loss")
if log_scale:
plt.ylabel("Loss (log scale)")
plt.yscale("log")
else:
plt.ylabel("Loss")
plt.grid(True, True)

def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]):
Expand All @@ -329,18 +336,12 @@ def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]):
None
"""
plt.clf()
plt.theme("dark")
plt.xlabel("Epoch")
if project_conf.LOG_SCALE_PLOT:
if any(loss_val <= 0 for loss_val in train_losses + val_losses):
raise ValueError(
"Cannot plot on a log scale if there are non-positive losses."
)
plt.ylabel("Loss (log scale)")
plt.yscale("log")
else:
plt.ylabel("Loss")
plt.grid(True, True)
self._setup_plot(self._run_name, log_scale=project_conf.LOG_SCALE_PLOT)
plt.plot(
list(range(self._starting_epoch, epoch + 1)),
train_losses,
Expand Down

0 comments on commit effa80c

Please sign in to comment.