From 81a778c506d5536132be7a83298d9b6d19e4d835 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Thu, 1 Aug 2024 19:34:10 -0400 Subject: [PATCH] instrumentation concept --- benchmarks/torch_ppo_atari_envpool/dev.yaml | 1 + benchmarks/torch_ppo_atari_envpool/main.py | 5 +- .../torch_ppo_atari_envpool/requirements.in | 2 + .../torch_ppo_atari_envpool/voirfile.py | 49 ++++++++++++++++--- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/benchmarks/torch_ppo_atari_envpool/dev.yaml b/benchmarks/torch_ppo_atari_envpool/dev.yaml index c01211d98..c8668abb4 100644 --- a/benchmarks/torch_ppo_atari_envpool/dev.yaml +++ b/benchmarks/torch_ppo_atari_envpool/dev.yaml @@ -1,5 +1,6 @@ torch_ppo_atari_envpool: + max_duration: 60000 inherits: _defaults definition: . install-variant: unpinned diff --git a/benchmarks/torch_ppo_atari_envpool/main.py b/benchmarks/torch_ppo_atari_envpool/main.py index 7af2e7bbf..c7cbdf742 100644 --- a/benchmarks/torch_ppo_atari_envpool/main.py +++ b/benchmarks/torch_ppo_atari_envpool/main.py @@ -213,8 +213,9 @@ def get_action_and_value(self, x, action=None): start_time = time.time() next_obs = torch.Tensor(envs.reset()).to(device) next_done = torch.zeros(args.num_envs).to(device) + iterations = range(1, args.num_iterations + 1) - for iteration in range(1, args.num_iterations + 1): + for iteration in iterations: # Annealing the rate if instructed to do so. if args.anneal_lr: frac = 1.0 - (iteration - 1.0) / args.num_iterations @@ -240,7 +241,7 @@ def get_action_and_value(self, x, action=None): for idx, d in enumerate(next_done): if d and info["lives"][idx] == 0: - print(f"global_step={global_step}, episodic_return={info['r'][idx]}") + # print(f"global_step={global_step}, episodic_return={info['r'][idx]}") avg_returns.append(info["r"][idx]) writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) diff --git a/benchmarks/torch_ppo_atari_envpool/requirements.in b/benchmarks/torch_ppo_atari_envpool/requirements.in index dbd35ac19..c264f5563 100644 --- a/benchmarks/torch_ppo_atari_envpool/requirements.in +++ b/benchmarks/torch_ppo_atari_envpool/requirements.in @@ -5,3 +5,5 @@ torch tyro voir tensorboard +torchcompat +cantilever diff --git a/benchmarks/torch_ppo_atari_envpool/voirfile.py b/benchmarks/torch_ppo_atari_envpool/voirfile.py index d93f886cd..8556a7b4e 100644 --- a/benchmarks/torch_ppo_atari_envpool/voirfile.py +++ b/benchmarks/torch_ppo_atari_envpool/voirfile.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from voir import configurable -from voir.instruments import dash, early_stop, log, rate -from benchmate.monitor import monitor_monogpu +from voir.phase import StopProgram +from benchmate.observer import BenchObserver +from benchmate.monitor import voirfile_monitor + @dataclass class Config: @@ -28,11 +30,42 @@ class Config: def instrument_main(ov, options: Config): yield ov.phases.init - if options.dash: - ov.require(dash) + # GPU monitor, rate, loss etc... + voirfile_monitor(ov, options) + + yield ov.phases.load_script + + env_size = 0 + + def fetch_args(args): + nonlocal env_size + env_size = args.num_envs + return args + + def batch_size(x): + return env_size - ov.require( - log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), - early_stop(n=options.stop, key="rate", task="train"), - monitor_monogpu(poll_interval=options.gpu_poll), + observer = BenchObserver( + earlystop=options.stop + options.skip, + batch_size_fn=batch_size, ) + + probe = ov.probe("//main > args", overridable=True) + probe['args'].override(fetch_args) + + probe = ov.probe("//main > iterations", overridable=True) + probe['iterations'].override(observer.loader) + + probe = ov.probe("//main > loss", overridable=True) + probe["loss"].override(observer.record_loss) + + probe = ov.probe("//main > optimizer", overridable=True) + probe['optimizer'].override(observer.optimizer) + + # + # Run the benchmark + # + try: + yield ov.phases.run_script + except StopProgram: + print("early stopped") \ No newline at end of file