Builder #11

Merged · 18 commits · Oct 23, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -15,3 +15,6 @@ FIGURES/
.mypy_cache/
.ruff_cache/
utils/data/MNIST/raw
node_modules/.yarn-integrity
ltex.dictionary.en-GB.txt
yarn.lock
31 changes: 31 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,31 @@
{
  "configurations": [
    {
      "name": "Launch Train [exp_a]",
      "type": "python",
      "request": "launch",
      "python": "/Users/cactus/miniforge3/envs/bellsnw/bin/python",
      "autoReload": { "enable": true },
      "program": "${workspaceFolder}/train.py",
      "args": ["+experiment=exp_a", "dataset.tiny=1"]
    },
    {
      "name": "Launch Build [exp_a]",
      "type": "python",
      "request": "launch",
      "python": "/Users/cactus/miniforge3/envs/bellsnw/bin/python",
      "autoReload": { "enable": true },
      "program": "${workspaceFolder}/build.py",
      "args": ["+experiment=exp_a", "dataset.tiny=1"]
    },
    {
      "name": "Attach Build [exp_a]",
      "type": "python",
      "request": "attach",
      "connect": {
        "host": "localhost",
        "port": 5555
      }
    }
  ]
}
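
For the "Attach Build [exp_a]" configuration to have something to attach to on localhost:5555, build.py needs to expose a debug adapter on that port. That wiring is not part of this diff, so the following is only a sketch of one way to do it with debugpy (the placement inside build.py is an assumption):

import debugpy

# Hypothetical wiring, not part of this PR: expose a debug adapter so the
# "Attach Build [exp_a]" configuration can connect to localhost:5555.
debugpy.listen(("localhost", 5555))
print("Waiting for the debugger to attach on port 5555...")
debugpy.wait_for_client()  # Blocks until VS Code attaches.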
38 changes: 38 additions & 0 deletions bootstrap/__init__.py
@@ -0,0 +1,38 @@
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional

from hydra_zen.typing import Partial


class MatchboxModule:
    # TODO: This is used as a sentinel/enum value; it should probably become a proper Enum.
    PREV = "MatchboxModule.PREV"

    def __init__(self, name: str, fn: Callable | Partial, *args, **kwargs):
        # TODO: Figure out this entire class. It's a hack, I'm still figuring things
        # out as I go.
        self._str_rep = name
        self.underlying_fn = fn.func if isinstance(fn, partial) else fn
        self.partial = partial(fn, *args, **kwargs)

    def __call__(self, prev_result: Any) -> Any:
        # Replace the .PREV placeholder in the partial's args/kwargs with prev_result.
        # functools.partial stores args as a tuple, so rebuild them rather than
        # assigning in place.
        args = list(self.partial.args)
        for i, arg in enumerate(args):
            if arg == self.PREV:
                assert prev_result is not None
                args[i] = prev_result
        kwargs = dict(self.partial.keywords)
        for key, value in kwargs.items():
            if value == self.PREV:
                assert prev_result is not None
                kwargs[key] = prev_result
        return self.partial.func(*args, **kwargs)

    def __str__(self) -> str:
        return self._str_rep


@dataclass
class MatchboxModuleState:
    first_run: bool
    result: Any
    is_frozen: bool
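
As an aside on how the PREV placeholder is meant to be used: each module wraps a callable plus its (possibly placeholder) arguments, and the chain runner substitutes the previous module's result wherever PREV appears. A minimal, self-contained sketch with toy functions follows; the chain-driving loop is an assumption about what BuilderUI.run_chain does, only MatchboxModule itself comes from this PR:

from bootstrap import MatchboxModule


def load_numbers(n: int) -> list[int]:
    return list(range(n))


def sum_numbers(values: list[int]) -> int:
    return sum(values)


# Two-step chain: the second module consumes the first module's output via PREV.
chain = [
    MatchboxModule("Load", load_numbers, n=5),
    MatchboxModule("Sum", sum_numbers, values=MatchboxModule.PREV),
]

result = None
for module in chain:
    result = module(result)  # PREV placeholders are replaced by `result`.
print(result)  # -> 10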
142 changes: 142 additions & 0 deletions bootstrap/launch_experiment.py
@@ -25,6 +25,7 @@
from rich.syntax import Syntax
from torch.utils.data import DataLoader, Dataset

from bootstrap import MatchboxModule
from bootstrap.factories import (
    make_dataloaders,
    make_datasets,
@@ -34,6 +35,7 @@
    make_training_loss,
    parallelize_model,
)
from bootstrap.tui.builder_ui import BuilderUI
from bootstrap.tui.training_ui import TrainingUI
from conf import project as project_conf
from src.base_tester import BaseTester
@@ -105,6 +107,145 @@ def init_wandb(
    wandb.watch(model, log=log, log_graph=log_graph)  # type: ignore


def launch_builder(
    run,  # type: ignore
    data_loader: Partial[DataLoader[Any]],
    optimizer: Partial[torch.optim.Optimizer],  # pyright: ignore
    scheduler: Partial[torch.optim.lr_scheduler.LRScheduler],
    trainer: Partial[BaseTrainer],
    tester: Partial[BaseTester],
    dataset: Partial[Dataset[Any]],
    model: Partial[torch.nn.Module],
    training_loss: Partial[torch.nn.Module],
):
    exp_conf = hydra_zen.to_yaml(
        dict(
            run_conf=run,
            dataset=dataset,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            training_loss=training_loss,
        )
    )
    # TODO: Overwrite data_loader.num_workers=0
    # data_loader.num_workers = 0

    async def launch_with_async_gui():
        tui = BuilderUI()
        task = asyncio.create_task(tui.run_async())
        await asyncio.sleep(0.5)  # Wait for the app to start up
        while not tui.is_running:
            await asyncio.sleep(0.01)  # Wait for the app to start up
        # trace_catcher = TraceCatcher(tui)

        # ============ Partials instantiation ============
        # NOTE: This still needs a lot more thought. We basically need a
        # mechanism that does conditional hot code reloading in the following
        # places. Of course, we'll never re-run the entire program while in the
        # builder; we'll just reload pieces of code and restart execution at
        # specific places.

        # train_dataset = await trace_catcher.catch_and_hang(
        #     dataset, split="train", seed=run.seed, progress=None, job_id=None
        # )
        # model_inst = await trace_catcher.catch_and_hang(
        #     make_model, model, train_dataset
        # )
        # opt_inst = await trace_catcher.catch_and_hang(
        #     make_optimizer, optimizer, model_inst
        # )
        # scheduler_inst = await trace_catcher.catch_and_hang(
        #     make_scheduler, scheduler, opt_inst, run.epochs
        # )
        # training_loss_inst = await trace_catcher.catch_and_hang(
        #     make_training_loss, run.training_mode, training_loss
        # )
        # if model_inst is not None:
        #     model_inst = to_cuda_(parallelize_model(model_inst))
        # if training_loss_inst is not None:
        #     training_loss_inst = to_cuda_(training_loss_inst)
        tui.chain_up(
            [
                MatchboxModule(
                    "Dataset",
                    dataset,  # TODO: Fix the code reloading, then revert to using the dataset factory
                    split="train",
                    seed=run.seed,
                    progress=None,
                    job_id=None,
                ),
                MatchboxModule(
                    "Model",
                    make_model,
                    model,
                    dataset=dataset,
                ),
                MatchboxModule(
                    "Optimizer", make_optimizer, optimizer, model=MatchboxModule.PREV
                ),
                MatchboxModule(
                    "Scheduler",
                    make_scheduler,
                    scheduler,
                    optimizer=MatchboxModule.PREV,
                    epochs=run.epochs,
                ),
                MatchboxModule(
                    "Loss", make_training_loss, run.training_mode, training_loss
                ),
            ]
        )
        tui.run_chain()
        # all_success = False  # TODO:
        # if all_success:
        #     # TODO: idk how to handle this YET
        #     # Somehow, the dataloader will crash if it's not forked when using
        #     # multiprocessing along with Textual.
        #     mp.set_start_method("fork")
        #     train_loader_inst, val_loader_inst, test_loader_inst = make_dataloaders(
        #         data_loader,
        #         train_dataset,
        #         val_dataset,
        #         test_dataset,
        #         run.training_mode,
        #         run.seed,
        #     )
        #     init_wandb("test-run", model_inst, exp_conf)
        #
        #     model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode)
        #     common_args = dict(
        #         run_name="build-run",
        #         model=model_inst,
        #         model_ckpt_path=model_ckpt_path,
        #         training_loss=training_loss_inst,
        #         tui=tui,
        #     )
        #     if training_loss_inst is None:
        #         raise ValueError("training_loss must be defined in training mode!")
        #     if val_loader_inst is None or train_loader_inst is None:
        #         raise ValueError(
        #             "val_loader and train_loader must be defined in training mode!"
        #         )
        #     await trainer(
        #         train_loader=train_loader_inst,
        #         val_loader=val_loader_inst,
        #         opt=opt_inst,
        #         scheduler=scheduler_inst,
        #         **common_args,
        #         **asdict(run),
        #     ).train(
        #         epochs=run.epochs,
        #         val_every=run.val_every,
        #         visualize_every=run.viz_every,
        #         visualize_train_every=run.viz_train_every,
        #         visualize_n_samples=run.viz_num_samples,
        #     )
        _ = await task

    asyncio.run(launch_with_async_gui())


def launch_experiment(
    run,  # type: ignore
    data_loader: Partial[DataLoader[Any]],
@@ -163,6 +304,7 @@ def launch_experiment(
    async def launch_with_async_gui():
        tui = TrainingUI(run_name, project_conf.LOG_SCALE_PLOT)
        task = asyncio.create_task(tui.run_async())
        await asyncio.sleep(0.5)  # Wait for the app to start up
        while not tui.is_running:
            await asyncio.sleep(0.01)  # Wait for the app to start up
        model_ckpt_path = load_model_ckpt(run.load_from, run.training_mode)
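
The startup-wait pattern (sleep 0.5 s, then poll tui.is_running) now appears in both launch_builder and launch_experiment. A small helper could factor it out; this is only a sketch assuming nothing beyond the is_running flag shown in this diff:

import asyncio


async def wait_until_running(tui, startup_delay: float = 0.5, poll_interval: float = 0.01) -> None:
    # Give the Textual app a moment to boot, then poll until it reports running.
    await asyncio.sleep(startup_delay)
    while not tui.is_running:
        await asyncio.sleep(poll_interval)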