From d16211202be736f466e83e0707f8e710b4968a9b Mon Sep 17 00:00:00 2001 From: Theo Date: Sun, 28 Jul 2024 01:12:23 +0100 Subject: [PATCH] Capture prints with Textual --- bootstrap/launch_experiment.py | 9 ++-- bootstrap/tui/logger.py | 84 ++++++++++++++++++++++++++++++++++ bootstrap/tui/training_ui.py | 84 ++++++---------------------------- src/base_tester.py | 7 --- src/base_trainer.py | 8 +--- 5 files changed, 105 insertions(+), 87 deletions(-) create mode 100644 bootstrap/tui/logger.py diff --git a/bootstrap/launch_experiment.py b/bootstrap/launch_experiment.py index 7cfe718..820813f 100644 --- a/bootstrap/launch_experiment.py +++ b/bootstrap/launch_experiment.py @@ -163,6 +163,7 @@ def launch_experiment( async def launch_with_async_gui(): tui = TrainingUI(run_name, project_conf.LOG_SCALE_PLOT) task = asyncio.create_task(tui.run_async()) + await asyncio.sleep(0.5) # Wait for the app to start up while not tui.is_running: await asyncio.sleep(0.01) # Wait for the app to start up model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode) @@ -174,7 +175,7 @@ async def launch_with_async_gui(): tui=tui, ) if run.training_mode: - tui.print("Training started!") + print("Training started!") if training_loss_inst is None: raise ValueError("training_loss must be defined in training mode!") if val_loader_inst is None or train_loader_inst is None: @@ -195,9 +196,9 @@ async def launch_with_async_gui(): visualize_train_every=run.viz_train_every, visualize_n_samples=run.viz_num_samples, ) - tui.print("Training finished!") + print("Training finished!") else: - tui.print("Testing started!") + print("Testing started!") if test_loader_inst is None: raise ValueError("test_loader must be defined in testing mode!") await tester( @@ -207,7 +208,7 @@ async def launch_with_async_gui(): visualize_every=run.viz_every, **asdict(run), ) - tui.print("Testing finished!") + print("Testing finished!") _ = await task asyncio.run(launch_with_async_gui()) diff --git a/bootstrap/tui/logger.py b/bootstrap/tui/logger.py new file mode 100644 index 0000000..7bfc548 --- /dev/null +++ b/bootstrap/tui/logger.py @@ -0,0 +1,84 @@ +from datetime import datetime +from typing import ( + Any, +) + +from rich.console import Group, RenderableType +from rich.pretty import Pretty +from rich.text import Text +from textual.app import ComposeResult +from textual.events import Print +from textual.widgets import ( + RichLog, + Static, +) + + +class Logger(Static): + def compose(self) -> ComposeResult: + yield RichLog(highlight=True, markup=True, wrap=True) + + def on_mount(self): + self.begin_capture_print() + + def on_print(self, event: Print) -> None: + self.wite(event.text, event.stderr) + + def wite(self, message: Any, is_stderr: bool): + if isinstance(message, str) and message.strip() == "": + # FIXME: Why do we need this hack?! + return + logger: RichLog = self.query_one(RichLog) + if isinstance(message, (RenderableType, str)): + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan" if not is_stderr else "bold red", + end="", + ), + message, + ), + ) + else: + ppable, pp_msg = True, None + try: + pp_msg = Pretty(message) + except Exception: + ppable = False + if ppable and pp_msg is not None: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Text(str(type(message)) + " ", style="italic blue", end=""), + pp_msg, + ) + ) + else: + try: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + message, + ), + ) + except Exception as e: + logger.write( + Group( + Text( + datetime.now().strftime("[%H:%M] "), + style="dim cyan", + end="", + ), + Text("Logging error: ", style="bold red"), + Text(str(e), style="bold red"), + ) + ) diff --git a/bootstrap/tui/training_ui.py b/bootstrap/tui/training_ui.py index 3dfc204..6070b67 100644 --- a/bootstrap/tui/training_ui.py +++ b/bootstrap/tui/training_ui.py @@ -1,5 +1,4 @@ import asyncio -from datetime import datetime from itertools import cycle from random import random from typing import ( @@ -13,8 +12,6 @@ import numpy as np import torch import torch.multiprocessing as mp -from rich.console import Group, RenderableType -from rich.pretty import Pretty from rich.text import Text from textual.app import App, ComposeResult from textual.reactive import var @@ -22,13 +19,13 @@ Footer, Header, Placeholder, - RichLog, ) from torch.utils.data.dataloader import DataLoader from torchvision.datasets import MNIST from torchvision.transforms.functional import to_tensor from bootstrap.tui import Plot_BestModel, Task +from bootstrap.tui.logger import Logger from bootstrap.tui.widgets.plotting import PlotterWidget from bootstrap.tui.widgets.progress import DatasetProgressBar @@ -72,9 +69,7 @@ def compose(self) -> ComposeResult: use_log_scale=self._log_scale, classes="box", ) - yield RichLog( - highlight=True, markup=True, wrap=True, id="logger", classes="box" - ) + yield Logger(id="logger", classes="box") yield DatasetProgressBar() yield Placeholder(classes="box") yield Footer() @@ -94,57 +89,8 @@ def action_marker(self) -> None: """Cycle to the next marker type.""" self.marker = next(self._markers) # skipcq: PTC-W0063 - def print(self, message: Any): - logger: RichLog = self.query_one(RichLog) - if isinstance(message, (RenderableType, str)): - logger.write( - Group( - Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""), - message, - ), - ) - else: - ppable, pp_msg = True, None - try: - pp_msg = Pretty(message) - except Exception: - ppable = False - if ppable and pp_msg is not None: - logger.write( - Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - Text(str(type(message)) + " ", style="italic blue", end=""), - pp_msg, - ) - ) - else: - try: - logger.write( - Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - message, - ), - ) - except Exception as e: - logger.write( - Group( - Text( - datetime.now().strftime("[%H:%M] "), - style="dim cyan", - end="", - ), - Text("Logging error: ", style="bold red"), - Text(str(e), style="bold red"), - ) - ) + def print_rich(self, message: Any): + self.query_one(Logger).wite(message, is_stderr=False) def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]: """Return an iterable that tracks the progress of the training process, and a @@ -188,19 +134,19 @@ async def run_my_app(): task = asyncio.create_task(gui.run_async()) while not gui.is_running: await asyncio.sleep(0.01) # Wait for the app to start up - gui.print("Hello, World!") + gui.print_rich("Hello, World!") await asyncio.sleep(2) - gui.print(Text("Let's log some tensors :)", style="bold magenta")) + gui.print_rich(Text("Let's log some tensors :)", style="bold magenta")) await asyncio.sleep(0.5) - gui.print(torch.rand(2, 4)) + gui.print_rich(torch.rand(2, 4)) await asyncio.sleep(2) - gui.print(Text("How about some numpy arrays?!", style="italic green")) + gui.print_rich(Text("How about some numpy arrays?!", style="italic green")) await asyncio.sleep(1) - gui.print(np.random.rand(3, 3)) + gui.print_rich(np.random.rand(3, 3)) pbar, update_progress_loss = gui.track_training(range(10), 10) for i, e in enumerate(pbar): - gui.print(f"[{i+1}/10]: We can iterate over iterables") - gui.print(e) + gui.print_rich(f"[{i+1}/10]: We can iterate over iterables") + gui.print_rich(e) await asyncio.sleep(0.1) await asyncio.sleep(2) mnist = MNIST(root="data", train=False, download=True, transform=to_tensor) @@ -212,16 +158,16 @@ async def run_my_app(): for i, batch in enumerate(pbar): await asyncio.sleep(0.01) if i % 10 == 0: - gui.print(batch) + gui.print_rich(batch) update_progress_loss(random()) gui.plot(epoch=i, train_loss=random(), val_loss=random()) - gui.print( + gui.print_rich( f"[{i+1}/{len(dataloader)}]: " + "We can also iterate over PyTorch dataloaders!" ) if i == 0: - gui.print(batch) - gui.print("Goodbye, world!") + gui.print_rich(batch) + gui.print_rich("Goodbye, world!") _ = await task diff --git a/src/base_tester.py b/src/base_tester.py index 7fc8572..500a2e5 100644 --- a/src/base_tester.py +++ b/src/base_tester.py @@ -15,7 +15,6 @@ from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union import torch -from rich.console import Console from rich.text import Text from torch import Tensor from torch.nn import Module @@ -29,10 +28,6 @@ T = TypeVar("T") -console = Console() -print = console.print # skipcq: PYL-W0603, PYL-W0622 - - class BaseTester(BaseTrainer): def __init__( self, @@ -54,8 +49,6 @@ def __init__( _args = kwargs _loss = training_loss self._tui = tui - global print # skipcq: PYL-W0603 - print = self._tui.print # skipcq: PYL-W0603, PYL-W0622 self._run_name = run_name self._model = model if model_ckpt_path is None: diff --git a/src/base_trainer.py b/src/base_trainer.py index e54b9e8..9d8266c 100644 --- a/src/base_trainer.py +++ b/src/base_trainer.py @@ -19,7 +19,6 @@ import torch import wandb from hydra.core.hydra_config import HydraConfig -from rich.console import Console from rich.text import Text from torch import Tensor from torch.nn import Module @@ -34,9 +33,6 @@ from utils.helpers import BestNModelSaver from utils.training import visualize_model_predictions -console = Console() -print = console.print # skipcq: PYL-W0603, PYL-W0622 - class BaseTrainer: def __init__( @@ -77,8 +73,6 @@ def __init__( self._viz_n_samples = 1 self._n_ctrl_c = 0 self._tui = tui - global print # skipcq: PYL-W0603 - print = self._tui.print # skipcq: PYL-W0603, PYL-W0622 if model_ckpt_path is not None: self._load_checkpoint(model_ckpt_path) signal.signal(signal.SIGINT, self._terminator) # FIXME: Textual broke this @@ -263,7 +257,7 @@ async def train( Returns: None """ - print( + self._tui.print_rich( Text( f"[*] Training {self._run_name} for {epochs} epochs", style="bold green" )