-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* first pass at benchmarking the recursion GFN repo. currently non-functional as MilaBench attempts to access the len() of the DataSource class * isort, black * milabench default * downloads weights to specified location * dependencies handled * Avoid implicit call to len * Tweaks metric gathering * now count batch_size as the number of nodes in the batch, added argparse * black --------- Co-authored-by: pierre.delaunay <[email protected]>
- Loading branch information
1 parent
eed157a
commit 3ca4d4b
Showing
9 changed files
with
321 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Use global base if possible | ||
ifndef MILABENCH_BASE | ||
MILABENCH_BASE="base" | ||
endif | ||
|
||
export MILABENCH_BASE | ||
|
||
BENCH_NAME=recursiongfn | ||
MILABENCH_CONFIG=dev.yaml | ||
MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) | ||
|
||
all: | ||
install prepare single gpus nodes | ||
|
||
install: | ||
milabench install $(MILABENCH_ARGS) --force | ||
|
||
prepare: | ||
milabench prepare $(MILABENCH_ARGS) | ||
|
||
tests: # install prepare | ||
milabench run $(MILABENCH_ARGS) | ||
|
||
single: | ||
milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single | ||
|
||
gpus: | ||
milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus | ||
|
||
nodes: | ||
milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
# Recursiongfn | ||
|
||
Rewrite this README to explain what the benchmark is! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from milabench.pack import Package | ||
|
||
|
||
class Recursiongfn(Package): | ||
# Requirements file installed by install(). It can be empty or absent. | ||
base_requirements = "requirements.in" | ||
|
||
# The preparation script called by prepare(). It must be executable, | ||
# but it can be any type of script. It can be empty or absent. | ||
prepare_script = "prepare.py" | ||
|
||
# The main script called by run(). It must be a Python file. It has to | ||
# be present. | ||
main_script = "main.py" | ||
|
||
# You can remove the functions below if you don't need to modify them. | ||
|
||
def make_env(self): | ||
# Return a dict of environment variables for prepare_script and | ||
# main_script. | ||
return super().make_env() | ||
|
||
async def install(self): | ||
await super().install() # super() call installs the requirements | ||
|
||
async def prepare(self): | ||
await super().prepare() # super() call executes prepare_script | ||
|
||
|
||
|
||
__pack__ = Recursiongfn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
recursiongfn: | ||
inherits: _defaults | ||
definition: . | ||
install-variant: unpinned | ||
install_group: torch | ||
plan: | ||
method: per_gpu | ||
|
||
argv: | ||
--batch_size: 128 | ||
--num_workers: 8 | ||
--num_steps: 100 | ||
--layer_width: 128 | ||
--num_layers: 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# This is the script run by milabench run (by default) | ||
# It is possible to use a script from a GitHub repo if it is cloned using | ||
# clone_subtree in the benchfile.py, in which case this file can simply | ||
# be deleted. | ||
|
||
import datetime | ||
import os | ||
import random | ||
import time | ||
from pathlib import Path | ||
from typing import Callable | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torchcompat.core as accelerator | ||
from gflownet.config import Config, init_empty | ||
from gflownet.models import bengio2021flow | ||
from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask | ||
from gflownet.utils.conditioning import TemperatureConditional | ||
from gflownet.utils.misc import get_worker_device | ||
from torch import Tensor | ||
from torch.utils.data import DataLoader, Dataset | ||
|
||
from benchmate.observer import BenchObserver | ||
|
||
|
||
class SEHFragTrainerMonkeyPatch(SEHFragTrainer): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
self.batch_size_in_nodes = [] | ||
|
||
def batch_size(x): | ||
"""Measures the batch size as the sum of all nodes in the batch.""" | ||
return self.batch_size_in_nodes.pop() | ||
|
||
self.observer = BenchObserver( | ||
accelerator.Event, | ||
earlystop=65, | ||
batch_size_fn=batch_size, | ||
raise_stop_program=False, | ||
stdout=False, | ||
) | ||
|
||
def _maybe_resolve_shared_buffer(self, *args, **kwargs): | ||
batch = super()._maybe_resolve_shared_buffer(*args, **kwargs) | ||
|
||
# Accumulate the size of all graphs in the batch measured in nodes. | ||
acc = 0 | ||
n = len(batch) | ||
for i in range(n): | ||
elem = batch[i] | ||
acc += elem.x.shape[0] | ||
|
||
self.batch_size_in_nodes.append(acc) | ||
return batch | ||
|
||
def step(self, loss: Tensor): | ||
original_output = super().step(loss) | ||
self.observer.record_loss(loss) | ||
return original_output | ||
|
||
def build_training_data_loader(self) -> DataLoader: | ||
original_output = super().build_training_data_loader() | ||
return self.observer.loader(original_output) | ||
|
||
def setup_task(self): | ||
self.task = SEHTaskMonkeyPatch( | ||
dataset=self.training_data, | ||
cfg=self.cfg, | ||
rng=self.rng, | ||
wrap_model=self._wrap_for_mp, | ||
) | ||
|
||
|
||
class SEHTaskMonkeyPatch(SEHTask): | ||
"""Allows us to specify the location of the original model download.""" | ||
|
||
def __init__( | ||
self, | ||
dataset: Dataset, | ||
cfg: Config, | ||
rng: np.random.Generator = None, | ||
wrap_model: Callable[[nn.Module], nn.Module] = None, | ||
): | ||
self._wrap_model = wrap_model | ||
self.rng = rng | ||
self.models = self._load_task_models() | ||
self.dataset = dataset | ||
self.temperature_conditional = TemperatureConditional(cfg, rng) | ||
self.num_cond_dim = self.temperature_conditional.encoding_size() | ||
|
||
def _load_task_models(self): | ||
xdg_cache = os.environ["XDG_CACHE_HOME"] | ||
model = bengio2021flow.load_original_model( | ||
cache=True, | ||
location=Path(os.path.join(xdg_cache, "bengio2021flow_proxy.pkl.gz")), | ||
) | ||
model.to(get_worker_device()) | ||
model = self._wrap_model(model) | ||
return {"seh": model} | ||
|
||
|
||
def main( | ||
batch_size: int, num_workers: int, num_steps: int, layer_width: int, num_layers: int | ||
): | ||
# This script runs on an A100 with 8 cpus and 32Gb memory, but the A100 is probably | ||
# overkill here. VRAM peaks at 6Gb and GPU usage peaks at 25%. | ||
|
||
config = init_empty(Config()) | ||
config.print_every = 1 | ||
config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" | ||
config.device = accelerator.fetch_device(0) # This is your CUDA device. | ||
config.overwrite_existing_exp = True | ||
|
||
config.num_training_steps = num_steps # Change this to train for longer. | ||
config.checkpoint_every = 5 # 500 | ||
config.validate_every = 0 | ||
config.num_final_gen_steps = 0 | ||
config.opt.lr_decay = 20_000 | ||
config.opt.clip_grad_type = "total_norm" | ||
config.algo.sampling_tau = 0.9 | ||
config.cond.temperature.sample_dist = "constant" | ||
config.cond.temperature.dist_params = [64.0] | ||
config.replay.use = False | ||
|
||
# Things it may be fun to play with. | ||
config.num_workers = num_workers | ||
config.model.num_emb = layer_width | ||
config.model.num_layers = num_layers | ||
batch_size = batch_size | ||
|
||
if config.replay.use: | ||
config.algo.num_from_policy = 0 | ||
config.replay.num_new_samples = batch_size | ||
config.replay.num_from_replay = batch_size | ||
else: | ||
config.algo.num_from_policy = batch_size | ||
|
||
# This may need to be adjusted if the batch_size is made bigger | ||
config.mp_buffer_size = 32 * 1024**2 # 32Mb | ||
trial = SEHFragTrainerMonkeyPatch(config, print_config=False) | ||
trial.run() | ||
trial.terminate() | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description="Description of your program") | ||
parser.add_argument("-b", "--batch_size", help="Batch Size", default=128) | ||
parser.add_argument("-n", "--num_workers", help="Number of Workers", default=8) | ||
parser.add_argument( | ||
"-s", "--num_steps", help="Number of Training Steps", default=100 | ||
) | ||
parser.add_argument( | ||
"-w", "--layer_width", help="Width of each policy hidden layer", default=128 | ||
) | ||
parser.add_argument("-l", "--num_layers", help="Number of hidden layers", default=4) | ||
args = parser.parse_args() | ||
|
||
main( | ||
args.batch_size, | ||
args.num_workers, | ||
args.num_steps, | ||
args.layer_width, | ||
args.num_layers, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#!/usr/bin/env python | ||
|
||
import os | ||
from gflownet.models.bengio2021flow import load_original_model | ||
from pathlib import Path | ||
|
||
|
||
if __name__ == "__main__": | ||
# If you need the whole configuration: | ||
# config = json.loads(os.environ["MILABENCH_CONFIG"]) | ||
print("+ Full environment:\n{}\n***".format(os.environ)) | ||
|
||
#milabench_cfg = os.environ["MILABENCH_CONFIG"] | ||
#print(milabench_cfg) | ||
|
||
xdg_cache = os.environ["XDG_CACHE_HOME"] | ||
|
||
print("+ Loading proxy model weights to MILABENCH_DIR_DATA={}".format(xdg_cache)) | ||
_ = load_original_model( | ||
cache=True, | ||
location=Path(os.path.join(xdg_cache, "bengio2021flow_proxy.pkl.gz")), | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
voir>=0.2.17,<0.3 | ||
torch | ||
gflownet @ git+https://github.com/recursionpharma/gflownet@bengioe-mila-demo | ||
--find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from dataclasses import dataclass | ||
|
||
from voir import configurable | ||
from voir.instruments import dash, early_stop, log, rate | ||
from benchmate.monitor import monitor_monogpu | ||
|
||
@dataclass | ||
class Config: | ||
"""voir configuration""" | ||
|
||
# Whether to display the dash or not | ||
dash: bool = False | ||
|
||
# How often to log the rates | ||
interval: str = "1s" | ||
|
||
# Number of rates to skip before logging | ||
skip: int = 5 | ||
|
||
# Number of rates to log before stopping | ||
stop: int = 20 | ||
|
||
# Number of seconds between each gpu poll | ||
gpu_poll: int = 3 | ||
|
||
|
||
@configurable | ||
def instrument_main(ov, options: Config): | ||
yield ov.phases.init | ||
|
||
if options.dash: | ||
ov.require(dash) | ||
|
||
ov.require( | ||
log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), | ||
early_stop(n=options.stop, key="rate", task="train"), | ||
monitor_monogpu(poll_interval=options.gpu_poll), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters