Commit: Capture prints with Textual
DubiousCactus committed Jul 28, 2024
1 parent 67a09cb commit d162112
Showing 5 changed files with 105 additions and 87 deletions.
9 changes: 5 additions & 4 deletions bootstrap/launch_experiment.py
@@ -163,6 +163,7 @@ def launch_experiment(
async def launch_with_async_gui():
tui = TrainingUI(run_name, project_conf.LOG_SCALE_PLOT)
task = asyncio.create_task(tui.run_async())
await asyncio.sleep(0.5) # Wait for the app to start up
while not tui.is_running:
await asyncio.sleep(0.01) # Wait for the app to start up
model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode)
@@ -174,7 +175,7 @@ async def launch_with_async_gui():
tui=tui,
)
if run.training_mode:
tui.print("Training started!")
print("Training started!")
if training_loss_inst is None:
raise ValueError("training_loss must be defined in training mode!")
if val_loader_inst is None or train_loader_inst is None:
@@ -195,9 +196,9 @@ async def launch_with_async_gui():
visualize_train_every=run.viz_train_every,
visualize_n_samples=run.viz_num_samples,
)
tui.print("Training finished!")
print("Training finished!")
else:
tui.print("Testing started!")
print("Testing started!")
if test_loader_inst is None:
raise ValueError("test_loader must be defined in testing mode!")
await tester(
@@ -207,7 +208,7 @@ async def launch_with_async_gui():
visualize_every=run.viz_every,
**asdict(run),
)
tui.print("Testing finished!")
print("Testing finished!")
_ = await task

asyncio.run(launch_with_async_gui())
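
For reference, the launch pattern used above (start the Textual app with run_async() as an asyncio task, poll is_running before doing any work, then await the task) reduces to the minimal sketch below. DemoApp and launch are illustrative names, not part of the repository.

import asyncio

from textual.app import App, ComposeResult
from textual.widgets import Footer, Header, Placeholder


class DemoApp(App):
    def compose(self) -> ComposeResult:
        yield Header()
        yield Placeholder()
        yield Footer()


async def launch() -> None:
    app = DemoApp()
    # Run the TUI concurrently with the rest of this coroutine.
    task = asyncio.create_task(app.run_async())
    await asyncio.sleep(0.5)  # give the app a moment to start up
    while not app.is_running:
        await asyncio.sleep(0.01)
    # The app is now running; kick off training or testing work here.
    _ = await task  # returns once the user quits the app


asyncio.run(launch())
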
84 changes: 84 additions & 0 deletions bootstrap/tui/logger.py
@@ -0,0 +1,84 @@
from datetime import datetime
from typing import (
Any,
)

from rich.console import Group, RenderableType
from rich.pretty import Pretty
from rich.text import Text
from textual.app import ComposeResult
from textual.events import Print
from textual.widgets import (
RichLog,
Static,
)


class Logger(Static):
def compose(self) -> ComposeResult:
yield RichLog(highlight=True, markup=True, wrap=True)

def on_mount(self):
self.begin_capture_print()

def on_print(self, event: Print) -> None:
self.write(event.text, event.stderr)

def write(self, message: Any, is_stderr: bool):
if isinstance(message, str) and message.strip() == "":
# FIXME: Why do we need this hack?!
return
logger: RichLog = self.query_one(RichLog)
if isinstance(message, (RenderableType, str)):
logger.write(
Group(
Text(
datetime.now().strftime("[%H:%M] "),
style="dim cyan" if not is_stderr else "bold red",
end="",
),
message,
),
)
else:
ppable, pp_msg = True, None
try:
pp_msg = Pretty(message)
except Exception:
ppable = False
if ppable and pp_msg is not None:
logger.write(
Group(
Text(
datetime.now().strftime("[%H:%M] "),
style="dim cyan",
end="",
),
Text(str(type(message)) + " ", style="italic blue", end=""),
pp_msg,
)
)
else:
try:
logger.write(
Group(
Text(
datetime.now().strftime("[%H:%M] "),
style="dim cyan",
end="",
),
message,
),
)
except Exception as e:
logger.write(
Group(
Text(
datetime.now().strftime("[%H:%M] "),
style="dim cyan",
end="",
),
Text("Logging error: ", style="bold red"),
Text(str(e), style="bold red"),
)
)
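
A minimal sketch of how this Logger is meant to be used, assuming the bootstrap package is importable and with DemoApp as a placeholder app: once the widget is mounted, begin_capture_print() re-routes stdout and stderr to it as Print events, so ordinary print() calls land in its RichLog with a timestamp.

import asyncio
import sys

from textual.app import App, ComposeResult

from bootstrap.tui.logger import Logger


class DemoApp(App):
    def compose(self) -> ComposeResult:
        yield Logger()  # begins print capture in its on_mount


async def main() -> None:
    app = DemoApp()
    task = asyncio.create_task(app.run_async())
    while not app.is_running:
        await asyncio.sleep(0.01)  # wait for the app to start up
    print("Captured by the Logger")       # stdout: dim cyan timestamp
    print("Errors too", file=sys.stderr)  # stderr: bold red timestamp
    await task  # returns once the user quits the app


asyncio.run(main())
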
84 changes: 15 additions & 69 deletions bootstrap/tui/training_ui.py
@@ -1,5 +1,4 @@
import asyncio
from datetime import datetime
from itertools import cycle
from random import random
from typing import (
@@ -13,22 +12,20 @@
import numpy as np
import torch
import torch.multiprocessing as mp
from rich.console import Group, RenderableType
from rich.pretty import Pretty
from rich.text import Text
from textual.app import App, ComposeResult
from textual.reactive import var
from textual.widgets import (
Footer,
Header,
Placeholder,
RichLog,
)
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor

from bootstrap.tui import Plot_BestModel, Task
from bootstrap.tui.logger import Logger
from bootstrap.tui.widgets.plotting import PlotterWidget
from bootstrap.tui.widgets.progress import DatasetProgressBar

@@ -72,9 +69,7 @@ def compose(self) -> ComposeResult:
use_log_scale=self._log_scale,
classes="box",
)
yield RichLog(
highlight=True, markup=True, wrap=True, id="logger", classes="box"
)
yield Logger(id="logger", classes="box")
yield DatasetProgressBar()
yield Placeholder(classes="box")
yield Footer()
@@ -94,57 +89,8 @@ def action_marker(self) -> None:
"""Cycle to the next marker type."""
self.marker = next(self._markers) # skipcq: PTC-W0063

def print(self, message: Any):
logger: RichLog = self.query_one(RichLog)
if isinstance(message, (RenderableType, str)):
logger.write(
Group(
Text(datetime.now().strftime("[%H:%M] "), style="dim cyan", end=""),
message,
),
)
else:
ppable, pp_msg = True, None
try:
pp_msg = Pretty(message)
except Exception:
ppable = False
if ppable and pp_msg is not None:
logger.write(
Group(
Text(
datetime.now().strftime("[%H:%M] "),
style="dim cyan",
end="",
),
Text(str(type(message)) + " ", style="italic blue", end=""),
pp_msg,
)
)
else:
try:
logger.write(
Group(
Text(
datetime.now().strftime("[%H:%M] "),
style="dim cyan",
end="",
),
message,
),
)
except Exception as e:
logger.write(
Group(
Text(
datetime.now().strftime("[%H:%M] "),
style="dim cyan",
end="",
),
Text("Logging error: ", style="bold red"),
Text(str(e), style="bold red"),
)
)
def print_rich(self, message: Any):
self.query_one(Logger).write(message, is_stderr=False)

def track_training(self, iterable, total: int) -> Tuple[Iterable, Callable]:
"""Return an iterable that tracks the progress of the training process, and a
@@ -188,19 +134,19 @@ async def run_my_app():
task = asyncio.create_task(gui.run_async())
while not gui.is_running:
await asyncio.sleep(0.01) # Wait for the app to start up
gui.print("Hello, World!")
gui.print_rich("Hello, World!")
await asyncio.sleep(2)
gui.print(Text("Let's log some tensors :)", style="bold magenta"))
gui.print_rich(Text("Let's log some tensors :)", style="bold magenta"))
await asyncio.sleep(0.5)
gui.print(torch.rand(2, 4))
gui.print_rich(torch.rand(2, 4))
await asyncio.sleep(2)
gui.print(Text("How about some numpy arrays?!", style="italic green"))
gui.print_rich(Text("How about some numpy arrays?!", style="italic green"))
await asyncio.sleep(1)
gui.print(np.random.rand(3, 3))
gui.print_rich(np.random.rand(3, 3))
pbar, update_progress_loss = gui.track_training(range(10), 10)
for i, e in enumerate(pbar):
gui.print(f"[{i+1}/10]: We can iterate over iterables")
gui.print(e)
gui.print_rich(f"[{i+1}/10]: We can iterate over iterables")
gui.print_rich(e)
await asyncio.sleep(0.1)
await asyncio.sleep(2)
mnist = MNIST(root="data", train=False, download=True, transform=to_tensor)
@@ -212,16 +158,16 @@ async def run_my_app():
for i, batch in enumerate(pbar):
await asyncio.sleep(0.01)
if i % 10 == 0:
gui.print(batch)
gui.print_rich(batch)
update_progress_loss(random())
gui.plot(epoch=i, train_loss=random(), val_loss=random())
gui.print(
gui.print_rich(
f"[{i+1}/{len(dataloader)}]: "
+ "We can also iterate over PyTorch dataloaders!"
)
if i == 0:
gui.print(batch)
gui.print("Goodbye, world!")
gui.print_rich(batch)
gui.print_rich("Goodbye, world!")
_ = await task


7 changes: 0 additions & 7 deletions src/base_tester.py
@@ -15,7 +15,6 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union

import torch
from rich.console import Console
from rich.text import Text
from torch import Tensor
from torch.nn import Module
@@ -29,10 +28,6 @@
T = TypeVar("T")


console = Console()
print = console.print # skipcq: PYL-W0603, PYL-W0622


class BaseTester(BaseTrainer):
def __init__(
self,
@@ -54,8 +49,6 @@ def __init__(
_args = kwargs
_loss = training_loss
self._tui = tui
global print # skipcq: PYL-W0603
print = self._tui.print # skipcq: PYL-W0603, PYL-W0622
self._run_name = run_name
self._model = model
if model_ckpt_path is None:
8 changes: 1 addition & 7 deletions src/base_trainer.py
@@ -19,7 +19,6 @@
import torch
import wandb
from hydra.core.hydra_config import HydraConfig
from rich.console import Console
from rich.text import Text
from torch import Tensor
from torch.nn import Module
@@ -34,9 +33,6 @@
from utils.helpers import BestNModelSaver
from utils.training import visualize_model_predictions

console = Console()
print = console.print # skipcq: PYL-W0603, PYL-W0622


class BaseTrainer:
def __init__(
@@ -77,8 +73,6 @@ def __init__(
self._viz_n_samples = 1
self._n_ctrl_c = 0
self._tui = tui
global print # skipcq: PYL-W0603
print = self._tui.print # skipcq: PYL-W0603, PYL-W0622
if model_ckpt_path is not None:
self._load_checkpoint(model_ckpt_path)
signal.signal(signal.SIGINT, self._terminator) # FIXME: Textual broke this
@@ -263,7 +257,7 @@ async def train(
Returns:
None
"""
print(
self._tui.print_rich(
Text(
f"[*] Training {self._run_name} for {epochs} epochs", style="bold green"
)
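
After this change the trainer keeps a direct TUI reference only for styled output; plain progress messages can go through the builtin print, which the running TUI captures. A rough illustration of the two paths, using a stand-in class rather than the real BaseTrainer:

from rich.text import Text


class ToyTrainer:
    """Stand-in for BaseTrainer, showing only the two logging paths."""

    def __init__(self, tui, run_name: str) -> None:
        self._tui = tui  # a running TrainingUI instance
        self._run_name = run_name

    async def train(self, epochs: int) -> None:
        # Styled output still goes through the TUI explicitly ...
        self._tui.print_rich(
            Text(f"[*] Training {self._run_name} for {epochs} epochs", style="bold green")
        )
        for epoch in range(epochs):
            # ... while plain progress text uses the builtin print,
            # which the TUI's Logger widget captures and timestamps.
            print(f"Epoch {epoch + 1}/{epochs}")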
