Skip to content

Commit

Permalink
Hook running state to the trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed Jul 30, 2024
1 parent 5b7d47c commit 2ac3093
Showing 1 changed file with 31 additions and 29 deletions.
60 changes: 31 additions & 29 deletions src/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import asyncio
import os
import random
import signal
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -64,7 +63,6 @@ def __init__(
self._val_loader = val_loader
self._epoch = 0
self._starting_epoch = 0
self._running = True
self._model_saver = BestNModelSaver(
project_conf.BEST_N_MODELS_TO_KEEP, self._save_checkpoint
)
Expand All @@ -75,7 +73,11 @@ def __init__(
self._tui = tui
if model_ckpt_path is not None:
self._load_checkpoint(model_ckpt_path)
signal.signal(signal.SIGINT, self._terminator) # FIXME: Textual broke this
# signal.signal(signal.SIGINT, self._terminator) # FIXME: Textual broke this

@property
def is_running(self) -> bool:
    """Report whether the attached TUI says it is still running.

    Returns:
        The value of ``is_running`` read from ``self._tui``.
    """
    tui = self._tui
    return tui.is_running

@to_cuda
def _visualize(
Expand Down Expand Up @@ -139,7 +141,7 @@ def _train_epoch(
)
for i, batch in enumerate(pbar):
if (
not self._running
not self.is_running
and project_conf.SIGINT_BEHAVIOR
== project_conf.TerminationBehavior.ABORT_EPOCH
):
Expand Down Expand Up @@ -194,7 +196,7 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float:
)
for i, batch in enumerate(pbar):
if (
not self._running
not self.is_running
and project_conf.SIGINT_BEHAVIOR
== project_conf.TerminationBehavior.ABORT_EPOCH
):
Expand Down Expand Up @@ -269,7 +271,7 @@ async def train(
for epoch in range(self._epoch, epochs):
print(f"Epoch: {epoch}")
self._epoch = epoch # Update for the model saver
if not self._running:
if not self.is_running:
break
self._model.train()
train_loss: float = await asyncio.to_thread(
Expand Down Expand Up @@ -373,26 +375,26 @@ def _load_checkpoint(self, ckpt_path: str, model_only: bool = False) -> None:
if self._scheduler is not None:
self._scheduler.load_state_dict(ckpt["scheduler_ckpt"])

def _terminator(self, sig, frame):  # FIXME: Textual broke this
    """
    Handles the SIGINT signal (Ctrl+C) and stops the training loop.

    The running flag is cleared on every path. With the
    ``WAIT_FOR_EPOCH_END`` behavior, the first Ctrl+C only prints a notice
    and counts the interrupt; with ``ABORT_EPOCH``, or on any later Ctrl+C,
    ``KeyboardInterrupt`` is raised.

    Args:
        sig: Signal number passed by the signal machinery (unused).
        frame: Interrupted stack frame (unused).

    Raises:
        KeyboardInterrupt: If the configured behavior is ``ABORT_EPOCH`` or
            a previous Ctrl+C has already been counted.
    """
    del sig, frame  # Required by the handler signature, not needed here.

    # Clear the flag *before* branching: previously the abort branch raised
    # first, so the loops polling `self._running` never saw it cleared on
    # that path.
    # NOTE(review): assumes the original reset sat after the if/elif (the
    # indentation is not visible here) -- confirm against the real file.
    self._running = False

    behavior = project_conf.SIGINT_BEHAVIOR
    if (
        behavior == project_conf.TerminationBehavior.WAIT_FOR_EPOCH_END
        and self._n_ctrl_c == 0
    ):
        print(
            f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}."
            + " Press Ctrl+C again to abort."
        )
        self._n_ctrl_c += 1
    elif (
        behavior == project_conf.TerminationBehavior.ABORT_EPOCH
        or self._n_ctrl_c > 0
    ):
        print(f"[!] SIGINT received. Aborting epoch for {self._run_name}!")
        raise KeyboardInterrupt
# def _terminator(self, sig, frame): # FIXME: Textual broke this
# """
# Handles the SIGINT signal (Ctrl+C) and stops the training loop.
# """
# _ = sig
# _ = frame
# if (
# project_conf.SIGINT_BEHAVIOR
# == project_conf.TerminationBehavior.WAIT_FOR_EPOCH_END
# and self._n_ctrl_c == 0
# ):
# print(
# f"[!] SIGINT received. Waiting for epoch to end for {self._run_name}."
# + " Press Ctrl+C again to abort."
# )
# self._n_ctrl_c += 1
# elif (
# project_conf.SIGINT_BEHAVIOR == project_conf.TerminationBehavior.ABORT_EPOCH
# or self._n_ctrl_c > 0
# ):
# print(f"[!] SIGINT received. Aborting epoch for {self._run_name}!")
# raise KeyboardInterrupt
# self._running = False

0 comments on commit 2ac3093

Please sign in to comment.