recursiongfn benchmark (#249)
* First pass at benchmarking the recursion GFN repo; currently non-functional, as MilaBench attempts to access the len() of the DataSource class

* isort, black

* milabench default

* Downloads weights to the specified location

* Dependencies handled

* Avoid implicit call to len

* Tweaks metric gathering

* Now counts batch_size as the number of nodes in the batch; added argparse

* black

---------

Co-authored-by: pierre.delaunay <[email protected]>
josephdviviano and pierre.delaunay authored Aug 13, 2024
1 parent eed157a commit 3ca4d4b
Showing 9 changed files with 321 additions and 2 deletions.
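The "Avoid implicit call to len" bullet addresses the failure noted in the first bullet: the loader wrapper assumed len() was available on the data source, which the gflownet DataSource does not provide. A minimal sketch of the guard pattern, with illustrative names only (this is not MilaBench's actual API):

    # Sketch of guarding against iterables that do not implement __len__
    # (safe_len is a hypothetical helper, not part of this commit).
    def safe_len(loader, fallback=None):
        """Return len(loader) when available, else a fallback."""
        try:
            return len(loader)
        except TypeError:  # e.g. an iterable DataSource with no __len__
            return fallback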
31 changes: 31 additions & 0 deletions benchmarks/recursiongfn/Makefile
@@ -0,0 +1,31 @@
# Use global base if possible
ifndef MILABENCH_BASE
MILABENCH_BASE="base"
endif

export MILABENCH_BASE

BENCH_NAME=recursiongfn
MILABENCH_CONFIG=dev.yaml
MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE)

all: install prepare single gpus nodes

install:
	milabench install $(MILABENCH_ARGS) --force

prepare:
	milabench prepare $(MILABENCH_ARGS)

tests: # install prepare
	milabench run $(MILABENCH_ARGS)

single:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single

gpus:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus

nodes:
	milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes
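A typical local workflow is `make install prepare` to set up the environment and download the proxy weights, then `make single` to run just this benchmark on one GPU.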
4 changes: 4 additions & 0 deletions benchmarks/recursiongfn/README.md
@@ -0,0 +1,4 @@

# Recursiongfn

Benchmarks GFlowNet training from the [recursionpharma/gflownet](https://github.com/recursionpharma/gflownet) repository (`bengioe-mila-demo` branch) on the SEH fragment-based molecular design task. Batch size is reported as the total number of graph nodes in each batch (see `main.py`).
31 changes: 31 additions & 0 deletions benchmarks/recursiongfn/benchfile.py
@@ -0,0 +1,31 @@
from milabench.pack import Package


class Recursiongfn(Package):
    # Requirements file installed by install(). It can be empty or absent.
    base_requirements = "requirements.in"

    # The preparation script called by prepare(). It must be executable,
    # but it can be any type of script. It can be empty or absent.
    prepare_script = "prepare.py"

    # The main script called by run(). It must be a Python file. It has to
    # be present.
    main_script = "main.py"

    # You can remove the functions below if you don't need to modify them.

    def make_env(self):
        # Return a dict of environment variables for prepare_script and
        # main_script.
        return super().make_env()

    async def install(self):
        await super().install()  # super() call installs the requirements

    async def prepare(self):
        await super().prepare()  # super() call executes prepare_script


__pack__ = Recursiongfn
15 changes: 15 additions & 0 deletions benchmarks/recursiongfn/dev.yaml
@@ -0,0 +1,15 @@

recursiongfn:
  inherits: _defaults
  definition: .
  install-variant: unpinned
  install_group: torch
  plan:
    method: per_gpu

  argv:
    --batch_size: 128
    --num_workers: 8
    --num_steps: 100
    --layer_width: 128
    --num_layers: 4
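Each key under argv maps one-to-one onto the argparse flags defined at the bottom of main.py, so `--batch_size: 128` here becomes `args.batch_size` in the script.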
169 changes: 169 additions & 0 deletions benchmarks/recursiongfn/main.py
@@ -0,0 +1,169 @@
# This is the script run by milabench run (by default)
# It is possible to use a script from a GitHub repo if it is cloned using
# clone_subtree in the benchfile.py, in which case this file can simply
# be deleted.

import datetime
import os
import random
import time
from pathlib import Path
from typing import Callable

import numpy as np
import torch
import torch.nn as nn
import torchcompat.core as accelerator
from gflownet.config import Config, init_empty
from gflownet.models import bengio2021flow
from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask
from gflownet.utils.conditioning import TemperatureConditional
from gflownet.utils.misc import get_worker_device
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from benchmate.observer import BenchObserver


class SEHFragTrainerMonkeyPatch(SEHFragTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.batch_size_in_nodes = []

        def batch_size(x):
            """Measures the batch size as the sum of all nodes in the batch."""
            return self.batch_size_in_nodes.pop()

        self.observer = BenchObserver(
            accelerator.Event,
            earlystop=65,
            batch_size_fn=batch_size,
            raise_stop_program=False,
            stdout=False,
        )

    def _maybe_resolve_shared_buffer(self, *args, **kwargs):
        batch = super()._maybe_resolve_shared_buffer(*args, **kwargs)

        # Accumulate the size of all graphs in the batch, measured in nodes.
        acc = 0
        n = len(batch)
        for i in range(n):
            elem = batch[i]
            acc += elem.x.shape[0]

        self.batch_size_in_nodes.append(acc)
        return batch

    def step(self, loss: Tensor):
        original_output = super().step(loss)
        self.observer.record_loss(loss)
        return original_output

    def build_training_data_loader(self) -> DataLoader:
        original_output = super().build_training_data_loader()
        return self.observer.loader(original_output)

    def setup_task(self):
        self.task = SEHTaskMonkeyPatch(
            dataset=self.training_data,
            cfg=self.cfg,
            rng=self.rng,
            wrap_model=self._wrap_for_mp,
        )


class SEHTaskMonkeyPatch(SEHTask):
    """Allows us to specify the location of the original model download."""

    def __init__(
        self,
        dataset: Dataset,
        cfg: Config,
        rng: np.random.Generator = None,
        wrap_model: Callable[[nn.Module], nn.Module] = None,
    ):
        self._wrap_model = wrap_model
        self.rng = rng
        self.models = self._load_task_models()
        self.dataset = dataset
        self.temperature_conditional = TemperatureConditional(cfg, rng)
        self.num_cond_dim = self.temperature_conditional.encoding_size()

    def _load_task_models(self):
        xdg_cache = os.environ["XDG_CACHE_HOME"]
        model = bengio2021flow.load_original_model(
            cache=True,
            location=Path(os.path.join(xdg_cache, "bengio2021flow_proxy.pkl.gz")),
        )
        model.to(get_worker_device())
        model = self._wrap_model(model)
        return {"seh": model}


def main(
    batch_size: int, num_workers: int, num_steps: int, layer_width: int, num_layers: int
):
    # This script runs on an A100 with 8 CPUs and 32 GB of memory, but the A100 is
    # probably overkill here: VRAM peaks at 6 GB and GPU usage peaks at 25%.

    config = init_empty(Config())
    config.print_every = 1
    config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    config.device = accelerator.fetch_device(0)  # This is your CUDA device.
    config.overwrite_existing_exp = True

    config.num_training_steps = num_steps  # Change this to train for longer.
    config.checkpoint_every = 5  # 500
    config.validate_every = 0
    config.num_final_gen_steps = 0
    config.opt.lr_decay = 20_000
    config.opt.clip_grad_type = "total_norm"
    config.algo.sampling_tau = 0.9
    config.cond.temperature.sample_dist = "constant"
    config.cond.temperature.dist_params = [64.0]
    config.replay.use = False

    # Things it may be fun to play with.
    config.num_workers = num_workers
    config.model.num_emb = layer_width
    config.model.num_layers = num_layers

    if config.replay.use:
        config.algo.num_from_policy = 0
        config.replay.num_new_samples = batch_size
        config.replay.num_from_replay = batch_size
    else:
        config.algo.num_from_policy = batch_size

    # This may need to be adjusted if the batch size is made bigger.
    config.mp_buffer_size = 32 * 1024**2  # 32 MiB
    trial = SEHFragTrainerMonkeyPatch(config, print_config=False)
    trial.run()
    trial.terminate()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Recursion GFN benchmark")
    parser.add_argument("-b", "--batch_size", type=int, help="Batch size", default=128)
    parser.add_argument(
        "-n", "--num_workers", type=int, help="Number of workers", default=8
    )
    parser.add_argument(
        "-s", "--num_steps", type=int, help="Number of training steps", default=100
    )
    parser.add_argument(
        "-w",
        "--layer_width",
        type=int,
        help="Width of each policy hidden layer",
        default=128,
    )
    parser.add_argument(
        "-l", "--num_layers", type=int, help="Number of hidden layers", default=4
    )
    args = parser.parse_args()

    main(
        args.batch_size,
        args.num_workers,
        args.num_steps,
        args.layer_width,
        args.num_layers,
    )
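To make the node-counting convention above concrete, here is a self-contained sketch; ToyGraph is a hypothetical stand-in for the graph objects in a batch, whose x attribute holds one row per node:

    # Illustration of "batch size = total nodes" (ToyGraph is a stand-in for
    # the real graph objects; it is not part of this commit).
    import torch


    class ToyGraph:
        def __init__(self, num_nodes):
            self.x = torch.zeros(num_nodes, 4)  # node-feature matrix


    batch = [ToyGraph(5), ToyGraph(3), ToyGraph(7)]
    total_nodes = sum(g.x.shape[0] for g in batch)
    assert total_nodes == 15  # this is what gets reported as the batch size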
23 changes: 23 additions & 0 deletions benchmarks/recursiongfn/prepare.py
@@ -0,0 +1,23 @@
#!/usr/bin/env python

import os
from pathlib import Path

from gflownet.models.bengio2021flow import load_original_model

if __name__ == "__main__":
    # If you need the whole configuration:
    # config = json.loads(os.environ["MILABENCH_CONFIG"])
    print("+ Full environment:\n{}\n***".format(os.environ))

    xdg_cache = os.environ["XDG_CACHE_HOME"]

    print("+ Loading proxy model weights to XDG_CACHE_HOME={}".format(xdg_cache))
    _ = load_original_model(
        cache=True,
        location=Path(os.path.join(xdg_cache, "bengio2021flow_proxy.pkl.gz")),
    )
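prepare.py assumes the harness exports XDG_CACHE_HOME; a defensive variant (an assumption, not part of this commit) would fall back to the conventional default:

    # Hypothetical fallback when XDG_CACHE_HOME is not set.
    import os
    from pathlib import Path

    xdg_cache = os.environ.get("XDG_CACHE_HOME", str(Path.home() / ".cache"))
    weights_path = Path(xdg_cache) / "bengio2021flow_proxy.pkl.gz"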

5 changes: 5 additions & 0 deletions benchmarks/recursiongfn/requirements.in
@@ -0,0 +1,5 @@
voir>=0.2.17,<0.3
torch
gflownet @ git+https://github.com/recursionpharma/gflownet@bengioe-mila-demo
--find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html
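The --find-links line points pip at the PyTorch Geometric wheel index built for torch 2.1.2 with CUDA 12.1; the assumption, based on the URL, is that gflownet's graph dependencies (torch_geometric and its compiled companions) ship prebuilt wheels there rather than on PyPI.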

38 changes: 38 additions & 0 deletions benchmarks/recursiongfn/voirfile.py
@@ -0,0 +1,38 @@
from dataclasses import dataclass

from voir import configurable
from voir.instruments import dash, early_stop, log, rate

from benchmate.monitor import monitor_monogpu


@dataclass
class Config:
    """voir configuration"""

    # Whether to display the dash or not
    dash: bool = False

    # How often to log the rates
    interval: str = "1s"

    # Number of rates to skip before logging
    skip: int = 5

    # Number of rates to log before stopping
    stop: int = 20

    # Number of seconds between each gpu poll
    gpu_poll: int = 3


@configurable
def instrument_main(ov, options: Config):
    yield ov.phases.init

    if options.dash:
        ov.require(dash)

    ov.require(
        log("value", "progress", "rate", "units", "loss", "gpudata", context="task"),
        early_stop(n=options.stop, key="rate", task="train"),
        monitor_monogpu(poll_interval=options.gpu_poll),
    )
7 changes: 5 additions & 2 deletions benchmate/benchmate/observer.py
@@ -75,10 +75,13 @@ def iterate(self, iterator, custom_step=False):
     def step(self):
         self.instance.step()
 
+    def original_dataloader(self):
+        return self.instance
+
     def loader(self, loader, custom_step=False):
         """Wrap a dataloader or an iterable which enable accurate measuring of time spent in the loop's body"""
-        if self.instance:
-            return self.instance
+        if self.instance is not None:
+            return self.instance.loader
 
         cls = TimedIterator
         if custom_step:
