From a6bb275e184d99ddf5c35feaaeba82b77a0b2cea Mon Sep 17 00:00:00 2001 From: roger-creus Date: Thu, 1 Aug 2024 11:27:52 -0400 Subject: [PATCH 01/13] Initial commit Torch_PPO_Cleanrl_Atari_Envpool --- benchmarks/torch_ppo_atari_envpool/Makefile | 31 ++ benchmarks/torch_ppo_atari_envpool/README.md | 4 + .../torch_ppo_atari_envpool/benchfile.py | 31 ++ benchmarks/torch_ppo_atari_envpool/dev.yaml | 7 + .../mark_torch_ppo_atari_envpool | 0 benchmarks/torch_ppo_atari_envpool/main.py | 344 ++++++++++++++++++ benchmarks/torch_ppo_atari_envpool/prepare.py | 16 + .../torch_ppo_atari_envpool/requirements.in | 88 +++++ .../torch_ppo_atari_envpool/voirfile.py | 38 ++ .../mark_torch_ppo_atari_envpool | 0 milabench/_version.py | 6 +- 11 files changed, 562 insertions(+), 3 deletions(-) create mode 100644 benchmarks/torch_ppo_atari_envpool/Makefile create mode 100644 benchmarks/torch_ppo_atari_envpool/README.md create mode 100644 benchmarks/torch_ppo_atari_envpool/benchfile.py create mode 100644 benchmarks/torch_ppo_atari_envpool/dev.yaml create mode 100644 benchmarks/torch_ppo_atari_envpool/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool create mode 100644 benchmarks/torch_ppo_atari_envpool/main.py create mode 100755 benchmarks/torch_ppo_atari_envpool/prepare.py create mode 100644 benchmarks/torch_ppo_atari_envpool/requirements.in create mode 100644 benchmarks/torch_ppo_atari_envpool/voirfile.py create mode 100644 extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool diff --git a/benchmarks/torch_ppo_atari_envpool/Makefile b/benchmarks/torch_ppo_atari_envpool/Makefile new file mode 100644 index 000000000..80cff09bb --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/Makefile @@ -0,0 +1,31 @@ +# Use global base if possible +ifndef MILABENCH_BASE + MILABENCH_BASE="base" +endif + +export MILABENCH_BASE + +BENCH_NAME=torch_ppo_atari_envpool +MILABENCH_CONFIG=dev.yaml +MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) + +all: + install prepare single gpus nodes + +install: + milabench install $(MILABENCH_ARGS) --force + +prepare: + milabench prepare $(MILABENCH_ARGS) + +tests: install prepare + milabench run $(MILABENCH_ARGS) + +single: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single + +gpus: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus + +nodes: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes diff --git a/benchmarks/torch_ppo_atari_envpool/README.md b/benchmarks/torch_ppo_atari_envpool/README.md new file mode 100644 index 000000000..44de20162 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/README.md @@ -0,0 +1,4 @@ + +# Torch_ppo_atari_envpool + +Rewrite this README to explain what the benchmark is! diff --git a/benchmarks/torch_ppo_atari_envpool/benchfile.py b/benchmarks/torch_ppo_atari_envpool/benchfile.py new file mode 100644 index 000000000..5625f7ed9 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/benchfile.py @@ -0,0 +1,31 @@ +from milabench.pack import Package + + +class Torch_ppo_atari_envpool(Package): + # Requirements file installed by install(). It can be empty or absent. + base_requirements = "requirements.in" + + # The preparation script called by prepare(). It must be executable, + # but it can be any type of script. It can be empty or absent. + prepare_script = "prepare.py" + + # The main script called by run(). It must be a Python file. It has to + # be present. + main_script = "main.py" + + # You can remove the functions below if you don't need to modify them. 
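    # A sketch for illustration (not part of this patch, and the variable name
    # and value are hypothetical): a non-trivial make_env() override would
    # merge extra variables into the environment dict the base Package already
    # provides, e.g.
    #
    #     def make_env(self):
    #         env = super().make_env()
    #         env["OMP_NUM_THREADS"] = "1"   # example value only
    #         return env
    #
    # The overrides below simply defer to the base implementation.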
+ + def make_env(self): + # Return a dict of environment variables for prepare_script and + # main_script. + return super().make_env() + + async def install(self): + await super().install() # super() call installs the requirements + + async def prepare(self): + await super().prepare() # super() call executes prepare_script + + + +__pack__ = Torch_ppo_atari_envpool diff --git a/benchmarks/torch_ppo_atari_envpool/dev.yaml b/benchmarks/torch_ppo_atari_envpool/dev.yaml new file mode 100644 index 000000000..aae0fff44 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/dev.yaml @@ -0,0 +1,7 @@ + +torch_ppo_atari_envpool: + inherits: _defaults + definition: . + install-variant: unpinned + plan: + method: per_gpu diff --git a/benchmarks/torch_ppo_atari_envpool/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool b/benchmarks/torch_ppo_atari_envpool/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/torch_ppo_atari_envpool/main.py b/benchmarks/torch_ppo_atari_envpool/main.py new file mode 100644 index 000000000..7af2e7bbf --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/main.py @@ -0,0 +1,344 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy +import os +import random +import time +from collections import deque +from dataclasses import dataclass + +import envpool +import gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import tyro +from torch.distributions.categorical import Categorical +from torch.utils.tensorboard import SummaryWriter + + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" + + # Algorithm specific arguments + env_id: str = "Breakout-v5" + """the id of the environment""" + total_timesteps: int = 10000000 + """total timesteps of the experiments""" + learning_rate: float = 2.5e-4 + """the learning rate of the optimizer""" + num_envs: int = 128 + """the number of parallel game environments""" + num_steps: int = 128 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = True + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.99 + """the discount factor gamma""" + gae_lambda: float = 0.95 + """the lambda for the general advantage estimation""" + num_minibatches: int = 16 + """the number of mini-batches""" + update_epochs: int = 4 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages normalization""" + clip_coef: float = 0.1 + """the surrogate clipping coefficient""" + clip_vloss: bool = True + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 0.01 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the 
maximum norm for the gradient clipping""" + target_kl: float = None + """the target KL divergence threshold""" + + # to be filled in runtime + batch_size: int = 0 + """the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + + +class RecordEpisodeStatistics(gym.Wrapper): + def __init__(self, env, deque_size=100): + super().__init__(env) + self.num_envs = getattr(env, "num_envs", 1) + self.episode_returns = None + self.episode_lengths = None + + def reset(self, **kwargs): + observations = super().reset(**kwargs) + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + self.lives = np.zeros(self.num_envs, dtype=np.int32) + self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + return observations + + def step(self, action): + observations, rewards, dones, infos = super().step(action) + self.episode_returns += infos["reward"] + self.episode_lengths += 1 + self.returned_episode_returns[:] = self.episode_returns + self.returned_episode_lengths[:] = self.episode_lengths + self.episode_returns *= 1 - infos["terminated"] + self.episode_lengths *= 1 - infos["terminated"] + infos["r"] = self.returned_episode_returns + infos["l"] = self.returned_episode_lengths + return ( + observations, + rewards, + dones, + infos, + ) + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.network = nn.Sequential( + layer_init(nn.Conv2d(4, 32, 8, stride=4)), + nn.ReLU(), + layer_init(nn.Conv2d(32, 64, 4, stride=2)), + nn.ReLU(), + layer_init(nn.Conv2d(64, 64, 3, stride=1)), + nn.ReLU(), + nn.Flatten(), + layer_init(nn.Linear(64 * 7 * 7, 512)), + nn.ReLU(), + ) + self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01) + self.critic = layer_init(nn.Linear(512, 1), std=1) + + def get_value(self, x): + return self.critic(self.network(x / 255.0)) + + def get_action_and_value(self, x, action=None): + hidden = self.network(x / 255.0) + logits = self.actor(hidden) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) + + +if __name__ == "__main__": + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda 
else "cpu") + + # env setup + envs = envpool.make( + args.env_id, + env_type="gym", + num_envs=args.num_envs, + episodic_life=True, + reward_clip=True, + seed=args.seed, + ) + envs.num_envs = args.num_envs + envs.single_action_space = envs.action_space + envs.single_observation_space = envs.observation_space + envs = RecordEpisodeStatistics(envs) + assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported" + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + avg_returns = deque(maxlen=20) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = torch.Tensor(envs.reset()).to(device) + next_done = torch.zeros(args.num_envs).to(device) + + for iteration in range(1, args.num_iterations + 1): + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (iteration - 1.0) / args.num_iterations + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. 
+ next_obs, reward, next_done, info = envs.step(action.cpu().numpy()) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) + + for idx, d in enumerate(next_done): + if d and info["lives"][idx] == 0: + print(f"global_step={global_step}, episodic_return={info['r'][idx]}") + avg_returns.append(info["r"][idx]) + writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step) + writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) + writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None and approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - 
np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + writer.close() \ No newline at end of file diff --git a/benchmarks/torch_ppo_atari_envpool/prepare.py b/benchmarks/torch_ppo_atari_envpool/prepare.py new file mode 100755 index 000000000..32bd5901d --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/prepare.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import os + +if __name__ == "__main__": + # If you need the whole configuration: + # config = json.loads(os.environ["MILABENCH_CONFIG"]) + + data_directory = os.environ["MILABENCH_DIR_DATA"] + + # Download (or generate) the needed dataset(s). You are responsible + # to check if it has already been properly downloaded or not, and to + # do nothing if it has been. + print("Hello I am doing some data stuff!") + + # If there is nothing to download or generate, just delete this file. diff --git a/benchmarks/torch_ppo_atari_envpool/requirements.in b/benchmarks/torch_ppo_atari_envpool/requirements.in new file mode 100644 index 000000000..7a663dabd --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/requirements.in @@ -0,0 +1,88 @@ +absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11" +appdirs==1.4.4 ; python_version >= "3.8" and python_version < "3.11" +bitmath==1.3.3.1 ; python_version >= "3.8" and python_version < "3.11" +cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11" +certifi==2023.5.7 ; python_version >= "3.8" and python_version < "3.11" +chardet==4.0.0 ; python_version >= "3.8" and python_version < "3.11" +charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "3.11" +click==8.1.3 ; python_version >= "3.8" and python_version < "3.11" +cloudpickle==2.2.1 ; python_version >= "3.8" and python_version < "3.11" +colorama==0.4.4 ; python_version >= "3.8" and python_version < "3.11" +commonmark==0.9.1 ; python_version >= "3.8" and python_version < "3.11" +cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11" +decorator==4.4.2 ; python_version >= "3.8" and python_version < "3.11" +dill==0.3.6 ; python_version >= "3.8" and python_version < "3.11" +dm-env==1.6 ; python_version >= "3.8" and python_version < "3.11" +dm-tree==0.1.8 ; python_version >= "3.8" and python_version < "3.11" +docker-pycreds==0.4.0 ; python_version >= "3.8" and python_version < "3.11" +docstring-parser==0.15 ; python_version >= "3.8" and python_version < "3.11" +enum-tools==0.9.0.post1 ; python_version >= "3.8" and python_version < "3.11" +envpool==0.6.6 ; python_version >= "3.8" and python_version < "3.11" +farama-notifications==0.0.4 ; python_version >= "3.8" and python_version < "3.11" +filelock==3.12.0 ; python_version >= "3.8" and python_version < "3.11" +fonttools==4.38.0 ; python_version >= 
"3.8" and python_version < "3.11" +gitdb==4.0.10 ; python_version >= "3.8" and python_version < "3.11" +gitpython==3.1.31 ; python_version >= "3.8" and python_version < "3.11" +google-auth-oauthlib==0.4.6 ; python_version >= "3.8" and python_version < "3.11" +google-auth==2.18.0 ; python_version >= "3.8" and python_version < "3.11" +graphviz==0.20.1 ; python_version >= "3.8" and python_version < "3.11" +grpcio==1.54.0 ; python_version >= "3.8" and python_version < "3.11" +gym-notices==0.0.8 ; python_version >= "3.8" and python_version < "3.11" +gym==0.23.1 ; python_version >= "3.8" and python_version < "3.11" +gymnasium==0.28.1 ; python_version >= "3.8" and python_version < "3.11" +hbutils==0.8.6 ; python_version >= "3.8" and python_version < "3.11" +huggingface-hub==0.11.1 ; python_version >= "3.8" and python_version < "3.11" +idna==3.4 ; python_version >= "3.8" and python_version < "3.11" +imageio-ffmpeg==0.3.0 ; python_version >= "3.8" and python_version < "3.11" +imageio==2.28.1 ; python_version >= "3.8" and python_version < "3.11" +importlib-metadata==5.2.0 ; python_version >= "3.8" and python_version < "3.10" +jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" +kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11" +markdown==3.3.7 ; python_version >= "3.8" and python_version < "3.11" +markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11" +matplotlib==3.5.3 ; python_version >= "3.8" and python_version < "3.11" +moviepy==1.0.3 ; python_version >= "3.8" and python_version < "3.11" +numpy==1.24.4 ; python_version >= "3.8" and python_version < "3.11" +oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11" +packaging==23.1 ; python_version >= "3.8" and python_version < "3.11" +pandas==1.3.5 ; python_version >= "3.8" and python_version < "3.11" +pathtools==0.1.2 ; python_version >= "3.8" and python_version < "3.11" +pillow==9.5.0 ; python_version >= "3.8" and python_version < "3.11" +proglog==0.1.10 ; python_version >= "3.8" and python_version < "3.11" +protobuf==3.20.3 ; python_version < "3.11" and python_version >= "3.8" +psutil==5.9.5 ; python_version >= "3.8" and python_version < "3.11" +pyasn1-modules==0.3.0 ; python_version >= "3.8" and python_version < "3.11" +pyasn1==0.5.0 ; python_version >= "3.8" and python_version < "3.11" +pygame==2.1.0 ; python_version >= "3.8" and python_version < "3.11" +pygments==2.15.1 ; python_version >= "3.8" and python_version < "3.11" +pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11" +python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11" +pytimeparse==1.1.8 ; python_version >= "3.8" and python_version < "3.11" +pytz==2023.3 ; python_version >= "3.8" and python_version < "3.11" +pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "3.11" +requests-oauthlib==1.3.1 ; python_version >= "3.8" and python_version < "3.11" +requests==2.30.0 ; python_version >= "3.8" and python_version < "3.11" +rich==11.2.0 ; python_version >= "3.8" and python_version < "3.11" +rsa==4.7.2 ; python_version >= "3.8" and python_version < "3.11" +sentry-sdk==1.22.2 ; python_version >= "3.8" and python_version < "3.11" +setproctitle==1.3.2 ; python_version >= "3.8" and python_version < "3.11" +setuptools==67.7.2 ; python_version >= "3.8" and python_version < "3.11" +shtab==1.6.4 ; python_version >= "3.8" and python_version < "3.11" +six==1.16.0 ; python_version >= "3.8" and python_version < "3.11" +smmap==5.0.0 ; python_version >= "3.8" and python_version < 
"3.11" +stable-baselines3==2.0.0 ; python_version >= "3.8" and python_version < "3.11" +tenacity==8.2.3 ; python_version >= "3.8" and python_version < "3.11" +tensorboard-data-server==0.6.1 ; python_version >= "3.8" and python_version < "3.11" +tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11" +tensorboard==2.11.2 ; python_version >= "3.8" and python_version < "3.11" +torch==1.12.1 ; python_version >= "3.8" and python_version < "3.11" +tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11" +treevalue==1.4.10 ; python_version >= "3.8" and python_version < "3.11" +types-protobuf==4.23.0.1 ; python_version >= "3.8" and python_version < "3.11" +typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11" +tyro==0.5.10 ; python_version >= "3.8" and python_version < "3.11" +urllib3==1.26.15 ; python_version >= "3.8" and python_version < "3.11" +wandb==0.13.11 ; python_version >= "3.8" and python_version < "3.11" +werkzeug==2.2.3 ; python_version >= "3.8" and python_version < "3.11" +wheel==0.40.0 ; python_version >= "3.8" and python_version < "3.11" +zipp==3.15.0 ; python_version >= "3.8" and python_version < "3.10" \ No newline at end of file diff --git a/benchmarks/torch_ppo_atari_envpool/voirfile.py b/benchmarks/torch_ppo_atari_envpool/voirfile.py new file mode 100644 index 000000000..d93f886cd --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/voirfile.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +from voir import configurable +from voir.instruments import dash, early_stop, log, rate +from benchmate.monitor import monitor_monogpu + +@dataclass +class Config: + """voir configuration""" + + # Whether to display the dash or not + dash: bool = False + + # How often to log the rates + interval: str = "1s" + + # Number of rates to skip before logging + skip: int = 5 + + # Number of rates to log before stopping + stop: int = 20 + + # Number of seconds between each gpu poll + gpu_poll: int = 3 + + +@configurable +def instrument_main(ov, options: Config): + yield ov.phases.init + + if options.dash: + ov.require(dash) + + ov.require( + log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), + early_stop(n=options.stop, key="rate", task="train"), + monitor_monogpu(poll_interval=options.gpu_poll), + ) diff --git a/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool b/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool new file mode 100644 index 000000000..e69de29bb diff --git a/milabench/_version.py b/milabench/_version.py index d8ae9287b..59bfbea09 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,5 @@ """This file is generated, do not modify""" -__tag__ = "v0.1.0-30-g64aa548b" -__commit__ = "64aa548ba07d3c6bb298e435b8ac43c69eb75738" -__date__ = "2024-07-26 13:07:25 -0400" +__tag__ = "3a30894" +__commit__ = "3a30894b44bc570983fb3f19bb316babddeb83f2" +__date__ = "2024-08-01 11:06:49 -0400" From c31e7be9b01fffeb8f1461761cc114b111ffdd08 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Thu, 1 Aug 2024 12:14:45 -0400 Subject: [PATCH 02/13] Simplify requirements.in --- benchmarks/torch_ppo_atari_envpool/Makefile | 2 +- benchmarks/torch_ppo_atari_envpool/dev.yaml | 1 + .../mark_torch_ppo_atari_envpool | 0 .../torch_ppo_atari_envpool/requirements.in | 95 ++----------------- milabench/_version.py | 6 +- 5 files changed, 12 insertions(+), 92 deletions(-) delete mode 100644 benchmarks/torch_ppo_atari_envpool/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool diff 
--git a/benchmarks/torch_ppo_atari_envpool/Makefile b/benchmarks/torch_ppo_atari_envpool/Makefile index 80cff09bb..edd4ad425 100644 --- a/benchmarks/torch_ppo_atari_envpool/Makefile +++ b/benchmarks/torch_ppo_atari_envpool/Makefile @@ -22,7 +22,7 @@ tests: install prepare milabench run $(MILABENCH_ARGS) single: - milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME) gpus: milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus diff --git a/benchmarks/torch_ppo_atari_envpool/dev.yaml b/benchmarks/torch_ppo_atari_envpool/dev.yaml index aae0fff44..c01211d98 100644 --- a/benchmarks/torch_ppo_atari_envpool/dev.yaml +++ b/benchmarks/torch_ppo_atari_envpool/dev.yaml @@ -3,5 +3,6 @@ torch_ppo_atari_envpool: inherits: _defaults definition: . install-variant: unpinned + install_group: torch plan: method: per_gpu diff --git a/benchmarks/torch_ppo_atari_envpool/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool b/benchmarks/torch_ppo_atari_envpool/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool deleted file mode 100644 index e69de29bb..000000000 diff --git a/benchmarks/torch_ppo_atari_envpool/requirements.in b/benchmarks/torch_ppo_atari_envpool/requirements.in index 7a663dabd..60f05fd46 100644 --- a/benchmarks/torch_ppo_atari_envpool/requirements.in +++ b/benchmarks/torch_ppo_atari_envpool/requirements.in @@ -1,88 +1,7 @@ -absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11" -appdirs==1.4.4 ; python_version >= "3.8" and python_version < "3.11" -bitmath==1.3.3.1 ; python_version >= "3.8" and python_version < "3.11" -cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11" -certifi==2023.5.7 ; python_version >= "3.8" and python_version < "3.11" -chardet==4.0.0 ; python_version >= "3.8" and python_version < "3.11" -charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "3.11" -click==8.1.3 ; python_version >= "3.8" and python_version < "3.11" -cloudpickle==2.2.1 ; python_version >= "3.8" and python_version < "3.11" -colorama==0.4.4 ; python_version >= "3.8" and python_version < "3.11" -commonmark==0.9.1 ; python_version >= "3.8" and python_version < "3.11" -cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11" -decorator==4.4.2 ; python_version >= "3.8" and python_version < "3.11" -dill==0.3.6 ; python_version >= "3.8" and python_version < "3.11" -dm-env==1.6 ; python_version >= "3.8" and python_version < "3.11" -dm-tree==0.1.8 ; python_version >= "3.8" and python_version < "3.11" -docker-pycreds==0.4.0 ; python_version >= "3.8" and python_version < "3.11" -docstring-parser==0.15 ; python_version >= "3.8" and python_version < "3.11" -enum-tools==0.9.0.post1 ; python_version >= "3.8" and python_version < "3.11" -envpool==0.6.6 ; python_version >= "3.8" and python_version < "3.11" -farama-notifications==0.0.4 ; python_version >= "3.8" and python_version < "3.11" -filelock==3.12.0 ; python_version >= "3.8" and python_version < "3.11" -fonttools==4.38.0 ; python_version >= "3.8" and python_version < "3.11" -gitdb==4.0.10 ; python_version >= "3.8" and python_version < "3.11" -gitpython==3.1.31 ; python_version >= "3.8" and python_version < "3.11" -google-auth-oauthlib==0.4.6 ; python_version >= "3.8" and python_version < "3.11" -google-auth==2.18.0 ; python_version >= "3.8" and python_version < "3.11" -graphviz==0.20.1 ; python_version >= "3.8" and python_version < "3.11" -grpcio==1.54.0 ; python_version >= "3.8" and python_version < "3.11" -gym-notices==0.0.8 
; python_version >= "3.8" and python_version < "3.11" -gym==0.23.1 ; python_version >= "3.8" and python_version < "3.11" -gymnasium==0.28.1 ; python_version >= "3.8" and python_version < "3.11" -hbutils==0.8.6 ; python_version >= "3.8" and python_version < "3.11" -huggingface-hub==0.11.1 ; python_version >= "3.8" and python_version < "3.11" -idna==3.4 ; python_version >= "3.8" and python_version < "3.11" -imageio-ffmpeg==0.3.0 ; python_version >= "3.8" and python_version < "3.11" -imageio==2.28.1 ; python_version >= "3.8" and python_version < "3.11" -importlib-metadata==5.2.0 ; python_version >= "3.8" and python_version < "3.10" -jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" -kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11" -markdown==3.3.7 ; python_version >= "3.8" and python_version < "3.11" -markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11" -matplotlib==3.5.3 ; python_version >= "3.8" and python_version < "3.11" -moviepy==1.0.3 ; python_version >= "3.8" and python_version < "3.11" -numpy==1.24.4 ; python_version >= "3.8" and python_version < "3.11" -oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11" -packaging==23.1 ; python_version >= "3.8" and python_version < "3.11" -pandas==1.3.5 ; python_version >= "3.8" and python_version < "3.11" -pathtools==0.1.2 ; python_version >= "3.8" and python_version < "3.11" -pillow==9.5.0 ; python_version >= "3.8" and python_version < "3.11" -proglog==0.1.10 ; python_version >= "3.8" and python_version < "3.11" -protobuf==3.20.3 ; python_version < "3.11" and python_version >= "3.8" -psutil==5.9.5 ; python_version >= "3.8" and python_version < "3.11" -pyasn1-modules==0.3.0 ; python_version >= "3.8" and python_version < "3.11" -pyasn1==0.5.0 ; python_version >= "3.8" and python_version < "3.11" -pygame==2.1.0 ; python_version >= "3.8" and python_version < "3.11" -pygments==2.15.1 ; python_version >= "3.8" and python_version < "3.11" -pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11" -python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11" -pytimeparse==1.1.8 ; python_version >= "3.8" and python_version < "3.11" -pytz==2023.3 ; python_version >= "3.8" and python_version < "3.11" -pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "3.11" -requests-oauthlib==1.3.1 ; python_version >= "3.8" and python_version < "3.11" -requests==2.30.0 ; python_version >= "3.8" and python_version < "3.11" -rich==11.2.0 ; python_version >= "3.8" and python_version < "3.11" -rsa==4.7.2 ; python_version >= "3.8" and python_version < "3.11" -sentry-sdk==1.22.2 ; python_version >= "3.8" and python_version < "3.11" -setproctitle==1.3.2 ; python_version >= "3.8" and python_version < "3.11" -setuptools==67.7.2 ; python_version >= "3.8" and python_version < "3.11" -shtab==1.6.4 ; python_version >= "3.8" and python_version < "3.11" -six==1.16.0 ; python_version >= "3.8" and python_version < "3.11" -smmap==5.0.0 ; python_version >= "3.8" and python_version < "3.11" -stable-baselines3==2.0.0 ; python_version >= "3.8" and python_version < "3.11" -tenacity==8.2.3 ; python_version >= "3.8" and python_version < "3.11" -tensorboard-data-server==0.6.1 ; python_version >= "3.8" and python_version < "3.11" -tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11" -tensorboard==2.11.2 ; python_version >= "3.8" and python_version < "3.11" -torch==1.12.1 ; python_version >= "3.8" and python_version < "3.11" -tqdm==4.65.0 ; 
python_version >= "3.8" and python_version < "3.11" -treevalue==1.4.10 ; python_version >= "3.8" and python_version < "3.11" -types-protobuf==4.23.0.1 ; python_version >= "3.8" and python_version < "3.11" -typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11" -tyro==0.5.10 ; python_version >= "3.8" and python_version < "3.11" -urllib3==1.26.15 ; python_version >= "3.8" and python_version < "3.11" -wandb==0.13.11 ; python_version >= "3.8" and python_version < "3.11" -werkzeug==2.2.3 ; python_version >= "3.8" and python_version < "3.11" -wheel==0.40.0 ; python_version >= "3.8" and python_version < "3.11" -zipp==3.15.0 ; python_version >= "3.8" and python_version < "3.10" \ No newline at end of file +envpool +gym +numpy +torch +tyro +voir +tensorboard diff --git a/milabench/_version.py b/milabench/_version.py index 59bfbea09..c87992389 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,5 @@ """This file is generated, do not modify""" -__tag__ = "3a30894" -__commit__ = "3a30894b44bc570983fb3f19bb316babddeb83f2" -__date__ = "2024-08-01 11:06:49 -0400" +__tag__ = "v0.0.6-80-ga6bb275" +__commit__ = "a6bb275e184d99ddf5c35feaaeba82b77a0b2cea" +__date__ = "2024-08-01 11:27:52 -0400" From 2b10866f53d0328d02cb04af236d0ec2707645ed Mon Sep 17 00:00:00 2001 From: roger-creus Date: Thu, 1 Aug 2024 17:16:32 -0400 Subject: [PATCH 03/13] Code now runs --- benchmarks/torch_ppo_atari_envpool/requirements.in | 2 +- milabench/_version.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/torch_ppo_atari_envpool/requirements.in b/benchmarks/torch_ppo_atari_envpool/requirements.in index 60f05fd46..dbd35ac19 100644 --- a/benchmarks/torch_ppo_atari_envpool/requirements.in +++ b/benchmarks/torch_ppo_atari_envpool/requirements.in @@ -1,5 +1,5 @@ envpool -gym +gym==0.23.1 numpy torch tyro diff --git a/milabench/_version.py b/milabench/_version.py index c87992389..ceafcb294 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,5 @@ """This file is generated, do not modify""" -__tag__ = "v0.0.6-80-ga6bb275" -__commit__ = "a6bb275e184d99ddf5c35feaaeba82b77a0b2cea" -__date__ = "2024-08-01 11:27:52 -0400" +__tag__ = "c7ae304" +__commit__ = "c7ae3043a12faef4da3eb0ddd6dc33e355b265fc" +__date__ = "2024-08-01 17:03:10 -0400" From deb8771c140f4c588a8878ee96a090e817713823 Mon Sep 17 00:00:00 2001 From: Pierre Delaunay Date: Thu, 1 Aug 2024 19:34:10 -0400 Subject: [PATCH 04/13] instrumentation concept --- benchmarks/torch_ppo_atari_envpool/dev.yaml | 1 + benchmarks/torch_ppo_atari_envpool/main.py | 17 ++++--- .../torch_ppo_atari_envpool/requirements.in | 2 + .../torch_ppo_atari_envpool/voirfile.py | 50 ++++++++++++++++--- 4 files changed, 56 insertions(+), 14 deletions(-) diff --git a/benchmarks/torch_ppo_atari_envpool/dev.yaml b/benchmarks/torch_ppo_atari_envpool/dev.yaml index c01211d98..c8668abb4 100644 --- a/benchmarks/torch_ppo_atari_envpool/dev.yaml +++ b/benchmarks/torch_ppo_atari_envpool/dev.yaml @@ -1,5 +1,6 @@ torch_ppo_atari_envpool: + max_duration: 60000 inherits: _defaults definition: . 
install-variant: unpinned diff --git a/benchmarks/torch_ppo_atari_envpool/main.py b/benchmarks/torch_ppo_atari_envpool/main.py index 7af2e7bbf..62c9b3a07 100644 --- a/benchmarks/torch_ppo_atari_envpool/main.py +++ b/benchmarks/torch_ppo_atari_envpool/main.py @@ -14,7 +14,7 @@ import tyro from torch.distributions.categorical import Categorical from torch.utils.tensorboard import SummaryWriter - +import torchcompat.core as acc @dataclass class Args: @@ -149,7 +149,7 @@ def get_action_and_value(self, x, action=None): return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) -if __name__ == "__main__": +def main(): args = tyro.cli(Args) args.batch_size = int(args.num_envs * args.num_steps) args.minibatch_size = int(args.batch_size // args.num_minibatches) @@ -179,7 +179,7 @@ def get_action_and_value(self, x, action=None): torch.manual_seed(args.seed) torch.backends.cudnn.deterministic = args.torch_deterministic - device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + device = acc.fetch_device(0) # env setup envs = envpool.make( @@ -213,8 +213,9 @@ def get_action_and_value(self, x, action=None): start_time = time.time() next_obs = torch.Tensor(envs.reset()).to(device) next_done = torch.zeros(args.num_envs).to(device) + iterations = range(1, args.num_iterations + 1) - for iteration in range(1, args.num_iterations + 1): + for iteration in iterations: # Annealing the rate if instructed to do so. if args.anneal_lr: frac = 1.0 - (iteration - 1.0) / args.num_iterations @@ -240,7 +241,7 @@ def get_action_and_value(self, x, action=None): for idx, d in enumerate(next_done): if d and info["lives"][idx] == 0: - print(f"global_step={global_step}, episodic_return={info['r'][idx]}") + # print(f"global_step={global_step}, episodic_return={info['r'][idx]}") avg_returns.append(info["r"][idx]) writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step) writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) @@ -341,4 +342,8 @@ def get_action_and_value(self, x, action=None): writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) envs.close() - writer.close() \ No newline at end of file + writer.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmarks/torch_ppo_atari_envpool/requirements.in b/benchmarks/torch_ppo_atari_envpool/requirements.in index dbd35ac19..c264f5563 100644 --- a/benchmarks/torch_ppo_atari_envpool/requirements.in +++ b/benchmarks/torch_ppo_atari_envpool/requirements.in @@ -5,3 +5,5 @@ torch tyro voir tensorboard +torchcompat +cantilever diff --git a/benchmarks/torch_ppo_atari_envpool/voirfile.py b/benchmarks/torch_ppo_atari_envpool/voirfile.py index d93f886cd..ba11707d0 100644 --- a/benchmarks/torch_ppo_atari_envpool/voirfile.py +++ b/benchmarks/torch_ppo_atari_envpool/voirfile.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from voir import configurable -from voir.instruments import dash, early_stop, log, rate -from benchmate.monitor import monitor_monogpu +from voir.phase import StopProgram +from benchmate.observer import BenchObserver +from benchmate.monitor import voirfile_monitor + @dataclass class Config: @@ -28,11 +30,43 @@ class Config: def instrument_main(ov, options: Config): yield ov.phases.init - if options.dash: - ov.require(dash) + # GPU monitor, rate, loss etc... 
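    # How the probes below hook into main.py (a sketch of the intent; the exact
    # mechanics belong to voir/ptera): a selector such as "//main > iterations"
    # matches the local variable `iterations` inside main(), and .override(...)
    # substitutes its value before it is used. Conceptually, the overrides in
    # this file amount to rewriting main() as
    #
    #     args = fetch_args(tyro.cli(Args))   # captures num_envs * num_steps
    #     iterations = observer.loader(range(1, args.num_iterations + 1))
    #     optimizer = observer.optimizer(
    #         optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5))
    #
    # so each outer iteration is timed with a batch size of num_envs * num_steps
    # frames, and optimizer steps are tracked by the observer.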
+ voirfile_monitor(ov, options) + + yield ov.phases.load_script + + step_per_iteration = 0 + + def fetch_args(args): + nonlocal step_per_iteration + step_per_iteration = args.num_envs * args.num_steps + return args + + def batch_size(x): + return step_per_iteration - ov.require( - log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), - early_stop(n=options.stop, key="rate", task="train"), - monitor_monogpu(poll_interval=options.gpu_poll), + observer = BenchObserver( + earlystop=options.stop + options.skip, + batch_size_fn=batch_size, ) + + probe = ov.probe("//main > args", overridable=True) + probe['args'].override(fetch_args) + + # measure the time it took to execute the body + probe = ov.probe("//main > iterations", overridable=True) + probe['iterations'].override(observer.loader) + + probe = ov.probe("//main > loss", overridable=True) + probe["loss"].override(observer.record_loss) + + probe = ov.probe("//main > optimizer", overridable=True) + probe['optimizer'].override(observer.optimizer) + + # + # Run the benchmark + # + try: + yield ov.phases.run_script + except StopProgram: + print("early stopped") \ No newline at end of file From ecc80a8299bc2fb73984775c2bc1236fc402f2b6 Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Fri, 2 Aug 2024 10:22:46 -0400 Subject: [PATCH 05/13] Fix CPU scaling --- benchmarks/torch_ppo_atari_envpool/Makefile | 6 +++--- benchmarks/torch_ppo_atari_envpool/dev.yaml | 10 +++++++++- .../torch_ppo_atari_envpool/voirfile.py | 19 +++++++++++++++++-- milabench/sizer.py | 14 +++++++------- 4 files changed, 36 insertions(+), 13 deletions(-) diff --git a/benchmarks/torch_ppo_atari_envpool/Makefile b/benchmarks/torch_ppo_atari_envpool/Makefile index edd4ad425..81443ce2b 100644 --- a/benchmarks/torch_ppo_atari_envpool/Makefile +++ b/benchmarks/torch_ppo_atari_envpool/Makefile @@ -18,11 +18,11 @@ install: prepare: milabench prepare $(MILABENCH_ARGS) -tests: install prepare - milabench run $(MILABENCH_ARGS) +tests: + MILABENCH_CPU_AUTO=1 CUDA_VISIBLE_DEVICES=0,1 milabench run $(MILABENCH_ARGS) single: - milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME) + MILABENCH_CPU_AUTO=1 milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME) gpus: milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus diff --git a/benchmarks/torch_ppo_atari_envpool/dev.yaml b/benchmarks/torch_ppo_atari_envpool/dev.yaml index c8668abb4..338bed075 100644 --- a/benchmarks/torch_ppo_atari_envpool/dev.yaml +++ b/benchmarks/torch_ppo_atari_envpool/dev.yaml @@ -1,9 +1,17 @@ torch_ppo_atari_envpool: - max_duration: 60000 + max_duration: 600 inherits: _defaults definition: . 
install-variant: unpinned install_group: torch plan: method: per_gpu + + argv: + --num-minibatches: 16 + --update-epochs: 4 + --num-steps: 128 + --num-envs: auto({cpu_per_gpu}, 128) + --total-timesteps: 1000000 + --env-id: Breakout-v5 \ No newline at end of file diff --git a/benchmarks/torch_ppo_atari_envpool/voirfile.py b/benchmarks/torch_ppo_atari_envpool/voirfile.py index ba11707d0..7b8873852 100644 --- a/benchmarks/torch_ppo_atari_envpool/voirfile.py +++ b/benchmarks/torch_ppo_atari_envpool/voirfile.py @@ -57,8 +57,23 @@ def batch_size(x): probe = ov.probe("//main > iterations", overridable=True) probe['iterations'].override(observer.loader) - probe = ov.probe("//main > loss", overridable=True) - probe["loss"].override(observer.record_loss) + # Too many losses + # probe = ov.probe("//main > loss", overridable=True) + # probe["loss"].override(observer.record_loss) + + def record_starts(writer): + old_add_scalar = writer.add_scalar + + def add_scalar(name, *values): + if name == "losses/value_loss": + observer.record_loss(values[0]) + old_add_scalar(name, *values) + + writer.add_scalar = add_scalar + return writer + + probe = ov.probe("//main > writer", overridable=True) + probe["writer"].override(record_starts) probe = ov.probe("//main > optimizer", overridable=True) probe['optimizer'].override(observer.optimizer) diff --git a/milabench/sizer.py b/milabench/sizer.py index bc88e355a..3eacd3ae3 100644 --- a/milabench/sizer.py +++ b/milabench/sizer.py @@ -339,16 +339,16 @@ def new_argument_resolver(pack): context = deepcopy(system_config) arch = context.get("arch", "cpu") + device_count_used = 1 + device_count_system = len(get_gpu_info()["gpus"]) if hasattr(pack, "config"): - device_count = len(pack.config.get("devices", [0])) - else: - device_count = len(get_gpu_info()["gpus"]) + device_count_used = len(pack.config.get("devices", [0])) - ccl = {"hpu": "hccl", "cuda": "nccl", "rocm": "rccl", "xpu": "ccl", "cpu": "gloo"} + if device_count_used <= 0: + device_count_used = 1 - if device_count <= 0: - device_count = 1 + ccl = {"hpu": "hccl", "cuda": "nccl", "rocm": "rccl", "xpu": "ccl", "cpu": "gloo"} options = CPUOptions() def auto(value, default): @@ -363,7 +363,7 @@ def clamp(x, mn=options.cpu_min, mx=options.cpu_max): total_available = total_cpu - options.reserved_cores context["cpu_count"] = total_available - context["cpu_per_gpu"] = total_available // device_count + context["cpu_per_gpu"] = total_available // device_count_system context["n_worker"] = clamp(context["cpu_per_gpu"]) if options.n_workers is not None: From 40a96b9e66fccbde9de332471764520176167f6c Mon Sep 17 00:00:00 2001 From: roger-creus Date: Tue, 6 Aug 2024 11:35:31 -0400 Subject: [PATCH 06/13] added metric pusher --- benchmarks/torch_ppo_atari_envpool/main.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmarks/torch_ppo_atari_envpool/main.py b/benchmarks/torch_ppo_atari_envpool/main.py index 62c9b3a07..267a4f080 100644 --- a/benchmarks/torch_ppo_atari_envpool/main.py +++ b/benchmarks/torch_ppo_atari_envpool/main.py @@ -15,6 +15,8 @@ from torch.distributions.categorical import Categorical from torch.utils.tensorboard import SummaryWriter import torchcompat.core as acc +from benchmate.metrics import give_push + @dataclass class Args: @@ -155,6 +157,8 @@ def main(): args.minibatch_size = int(args.batch_size // args.num_minibatches) args.num_iterations = args.total_timesteps // args.batch_size run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + metric_pusher = 
give_push() + if args.track: import wandb @@ -340,7 +344,10 @@ def main(): writer.add_scalar("losses/explained_variance", explained_var, global_step) print("SPS:", int(global_step / (time.time() - start_time))) writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - + + metric_pusher(progress=int(global_step / (time.time() - start_time)), unit="steps/s") + metric_pusher(loss=loss.item(), avg_returns=np.average(avg_returns)) + envs.close() writer.close() From f88c03667f790c39d214b2e81701e9780bac81ed Mon Sep 17 00:00:00 2001 From: rogercc Date: Tue, 27 Aug 2024 13:27:49 -0400 Subject: [PATCH 07/13] initial commit cleanrl jax --- benchmarks/cleanrl_jax/Makefile | 31 ++ benchmarks/cleanrl_jax/README.md | 4 + benchmarks/cleanrl_jax/benchfile.py | 31 ++ benchmarks/cleanrl_jax/dev.yaml | 8 + benchmarks/cleanrl_jax/main.py | 522 +++++++++++++++++++++++++ benchmarks/cleanrl_jax/prepare.py | 16 + benchmarks/cleanrl_jax/requirements.in | 99 +++++ benchmarks/cleanrl_jax/voirfile.py | 38 ++ 8 files changed, 749 insertions(+) create mode 100644 benchmarks/cleanrl_jax/Makefile create mode 100644 benchmarks/cleanrl_jax/README.md create mode 100644 benchmarks/cleanrl_jax/benchfile.py create mode 100644 benchmarks/cleanrl_jax/dev.yaml create mode 100644 benchmarks/cleanrl_jax/main.py create mode 100755 benchmarks/cleanrl_jax/prepare.py create mode 100644 benchmarks/cleanrl_jax/requirements.in create mode 100644 benchmarks/cleanrl_jax/voirfile.py diff --git a/benchmarks/cleanrl_jax/Makefile b/benchmarks/cleanrl_jax/Makefile new file mode 100644 index 000000000..20c442249 --- /dev/null +++ b/benchmarks/cleanrl_jax/Makefile @@ -0,0 +1,31 @@ +# Use global base if possible +ifndef MILABENCH_BASE + MILABENCH_BASE="base" +endif + +export MILABENCH_BASE + +BENCH_NAME=cleanrl_jax +MILABENCH_CONFIG=dev.yaml +MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) + +all: + install prepare single gpus nodes + +install: + milabench install $(MILABENCH_ARGS) --force + +prepare: + milabench prepare $(MILABENCH_ARGS) + +tests: install prepare + milabench run $(MILABENCH_ARGS) + +single: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME) + +gpus: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus + +nodes: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes diff --git a/benchmarks/cleanrl_jax/README.md b/benchmarks/cleanrl_jax/README.md new file mode 100644 index 000000000..255d805b7 --- /dev/null +++ b/benchmarks/cleanrl_jax/README.md @@ -0,0 +1,4 @@ + +# Cleanrl_jax + +Rewrite this README to explain what the benchmark is! diff --git a/benchmarks/cleanrl_jax/benchfile.py b/benchmarks/cleanrl_jax/benchfile.py new file mode 100644 index 000000000..afc5f5a65 --- /dev/null +++ b/benchmarks/cleanrl_jax/benchfile.py @@ -0,0 +1,31 @@ +from milabench.pack import Package + + +class Cleanrl_jax(Package): + # Requirements file installed by install(). It can be empty or absent. + base_requirements = "requirements.in" + + # The preparation script called by prepare(). It must be executable, + # but it can be any type of script. It can be empty or absent. + prepare_script = "prepare.py" + + # The main script called by run(). It must be a Python file. It has to + # be present. + main_script = "main.py" + + # You can remove the functions below if you don't need to modify them. + + def make_env(self): + # Return a dict of environment variables for prepare_script and + # main_script. 
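        # (Optional refactor, sketched for illustration only and not part of
        # this patch): the XLA settings that main.py sets inline via os.environ
        # could instead be surfaced here, since make_env() is where the
        # benchmark's process environment is assembled, e.g.
        #
        #     env = super().make_env()
        #     env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"  # same value as main.py
        #     return env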
+ return super().make_env() + + async def install(self): + await super().install() # super() call installs the requirements + + async def prepare(self): + await super().prepare() # super() call executes prepare_script + + + +__pack__ = Cleanrl_jax diff --git a/benchmarks/cleanrl_jax/dev.yaml b/benchmarks/cleanrl_jax/dev.yaml new file mode 100644 index 000000000..37ea10198 --- /dev/null +++ b/benchmarks/cleanrl_jax/dev.yaml @@ -0,0 +1,8 @@ + +cleanrl_jax: + inherits: _defaults + definition: . + install-variant: unpinned + install_group: torch + plan: + method: per_gpu diff --git a/benchmarks/cleanrl_jax/main.py b/benchmarks/cleanrl_jax/main.py new file mode 100644 index 000000000..3d6597563 --- /dev/null +++ b/benchmarks/cleanrl_jax/main.py @@ -0,0 +1,522 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy +import os +import random +import time +from dataclasses import dataclass +from functools import partial +from typing import Sequence + +import envpool +import flax +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import tyro +from flax.linen.initializers import constant, orthogonal +from flax.training.train_state import TrainState +from torch.utils.tensorboard import SummaryWriter + +# Fix weird OOM https://github.com/google/jax/discussions/6332#discussioncomment-1279991 +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6" +# Fix CUDNN non-determinisim; https://github.com/google/jax/issues/4823#issuecomment-952835771 +os.environ["TF_XLA_FLAGS"] = "--xla_gpu_autotune_level=2 --xla_gpu_deterministic_reductions" +os.environ["TF_CUDNN DETERMINISTIC"] = "1" + + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" + save_model: bool = False + """whether to save model into the `runs/{run_name}` folder""" + upload_model: bool = False + """whether to upload the saved model to huggingface""" + hf_entity: str = "" + """the user or org name of the model repository from the Hugging Face Hub""" + + # Algorithm specific arguments + env_id: str = "Breakout-v5" + """the id of the environment""" + total_timesteps: int = 10000000 + """total timesteps of the experiments""" + learning_rate: float = 2.5e-4 + """the learning rate of the optimizer""" + num_envs: int = 8 + """the number of parallel game environments""" + num_steps: int = 128 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = True + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.99 + """the discount factor gamma""" + gae_lambda: float = 0.95 + """the lambda for the general advantage estimation""" + num_minibatches: int = 4 + """the number of mini-batches""" + update_epochs: int = 4 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages normalization""" + clip_coef: float 
= 0.1 + """the surrogate clipping coefficient""" + clip_vloss: bool = True + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 0.01 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the maximum norm for the gradient clipping""" + target_kl: float = None + """the target KL divergence threshold""" + + # to be filled in runtime + batch_size: int = 0 + """the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + + +def make_env(env_id, seed, num_envs): + def thunk(): + envs = envpool.make( + env_id, + env_type="gym", + num_envs=num_envs, + episodic_life=True, + reward_clip=True, + seed=seed, + ) + envs.num_envs = num_envs + envs.single_action_space = envs.action_space + envs.single_observation_space = envs.observation_space + envs.is_vector_env = True + return envs + + return thunk + + +class Network(nn.Module): + @nn.compact + def __call__(self, x): + x = jnp.transpose(x, (0, 2, 3, 1)) + x = x / (255.0) + x = nn.Conv( + 32, + kernel_size=(8, 8), + strides=(4, 4), + padding="VALID", + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + x = nn.relu(x) + x = nn.Conv( + 64, + kernel_size=(4, 4), + strides=(2, 2), + padding="VALID", + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + x = nn.relu(x) + x = nn.Conv( + 64, + kernel_size=(3, 3), + strides=(1, 1), + padding="VALID", + kernel_init=orthogonal(np.sqrt(2)), + bias_init=constant(0.0), + )(x) + x = nn.relu(x) + x = x.reshape((x.shape[0], -1)) + x = nn.Dense(512, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) + x = nn.relu(x) + return x + + +class Critic(nn.Module): + @nn.compact + def __call__(self, x): + return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x) + + +class Actor(nn.Module): + action_dim: Sequence[int] + + @nn.compact + def __call__(self, x): + return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x) + + +@flax.struct.dataclass +class AgentParams: + network_params: flax.core.FrozenDict + actor_params: flax.core.FrozenDict + critic_params: flax.core.FrozenDict + + +@flax.struct.dataclass +class Storage: + obs: jnp.array + actions: jnp.array + logprobs: jnp.array + dones: jnp.array + values: jnp.array + advantages: jnp.array + returns: jnp.array + rewards: jnp.array + + +@flax.struct.dataclass +class EpisodeStatistics: + episode_returns: jnp.array + episode_lengths: jnp.array + returned_episode_returns: jnp.array + returned_episode_lengths: jnp.array + + +if __name__ == "__main__": + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + 
np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, network_key, actor_key, critic_key = jax.random.split(key, 4) + + # env setup + envs = make_env(args.env_id, args.seed, args.num_envs)() + episode_stats = EpisodeStatistics( + episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32), + episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32), + returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32), + returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32), + ) + handle, recv, send, step_env = envs.xla() + + def step_env_wrappeed(episode_stats, handle, action): + handle, (next_obs, reward, next_done, info) = step_env(handle, action) + new_episode_return = episode_stats.episode_returns + info["reward"] + new_episode_length = episode_stats.episode_lengths + 1 + episode_stats = episode_stats.replace( + episode_returns=(new_episode_return) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]), + episode_lengths=(new_episode_length) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]), + # only update the `returned_episode_returns` if the episode is done + returned_episode_returns=jnp.where( + info["terminated"] + info["TimeLimit.truncated"], new_episode_return, episode_stats.returned_episode_returns + ), + returned_episode_lengths=jnp.where( + info["terminated"] + info["TimeLimit.truncated"], new_episode_length, episode_stats.returned_episode_lengths + ), + ) + return episode_stats, handle, (next_obs, reward, next_done, info) + + assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" + + def linear_schedule(count): + # anneal learning rate linearly after one training iteration which contains + # (args.num_minibatches * args.update_epochs) gradient updates + frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_iterations + return args.learning_rate * frac + + network = Network() + actor = Actor(action_dim=envs.single_action_space.n) + critic = Critic() + network_params = network.init(network_key, np.array([envs.single_observation_space.sample()])) + agent_state = TrainState.create( + apply_fn=None, + params=AgentParams( + network_params, + actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))), + critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))), + ), + tx=optax.chain( + optax.clip_by_global_norm(args.max_grad_norm), + optax.inject_hyperparams(optax.adam)( + learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5 + ), + ), + ) + network.apply = jax.jit(network.apply) + actor.apply = jax.jit(actor.apply) + critic.apply = jax.jit(critic.apply) + + @jax.jit + def get_action_and_value( + agent_state: TrainState, + next_obs: np.ndarray, + key: jax.random.PRNGKey, + ): + """sample action, calculate value, logprob, entropy, and update storage""" + hidden = network.apply(agent_state.params.network_params, next_obs) + logits = actor.apply(agent_state.params.actor_params, hidden) + # sample action: Gumbel-softmax trick + # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey, shape=logits.shape) + action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1) + logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] + value = critic.apply(agent_state.params.critic_params, hidden) + return action, logprob, 
value.squeeze(1), key + + @jax.jit + def get_action_and_value2( + params: flax.core.FrozenDict, + x: np.ndarray, + action: np.ndarray, + ): + """calculate value, logprob of supplied `action`, and entropy""" + hidden = network.apply(params.network_params, x) + logits = actor.apply(params.actor_params, hidden) + logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] + # normalize the logits https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/ + logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) + logits = logits.clip(min=jnp.finfo(logits.dtype).min) + p_log_p = logits * jax.nn.softmax(logits) + entropy = -p_log_p.sum(-1) + value = critic.apply(params.critic_params, hidden).squeeze() + return logprob, entropy, value + + def compute_gae_once(carry, inp, gamma, gae_lambda): + advantages = carry + nextdone, nextvalues, curvalues, reward = inp + nextnonterminal = 1.0 - nextdone + + delta = reward + gamma * nextvalues * nextnonterminal - curvalues + advantages = delta + gamma * gae_lambda * nextnonterminal * advantages + return advantages, advantages + + compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda) + + @jax.jit + def compute_gae( + agent_state: TrainState, + next_obs: np.ndarray, + next_done: np.ndarray, + storage: Storage, + ): + next_value = critic.apply( + agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs) + ).squeeze() + + advantages = jnp.zeros((args.num_envs,)) + dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0) + values = jnp.concatenate([storage.values, next_value[None, :]], axis=0) + _, advantages = jax.lax.scan( + compute_gae_once, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True + ) + storage = storage.replace( + advantages=advantages, + returns=advantages + storage.values, + ) + return storage + + def ppo_loss(params, x, a, logp, mb_advantages, mb_returns): + newlogprob, entropy, newvalue = get_action_and_value2(params, x, a) + logratio = newlogprob - logp + ratio = jnp.exp(logratio) + approx_kl = ((ratio - 1) - logratio).mean() + + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean() + + # Value loss + v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl)) + + ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) + + @jax.jit + def update_ppo( + agent_state: TrainState, + storage: Storage, + key: jax.random.PRNGKey, + ): + def update_epoch(carry, unused_inp): + agent_state, key = carry + key, subkey = jax.random.split(key) + + def flatten(x): + return x.reshape((-1,) + x.shape[2:]) + + # taken from: https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py + def convert_data(x: jnp.ndarray): + x = jax.random.permutation(subkey, x) + x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:]) + return x + + flatten_storage = jax.tree_map(flatten, storage) + shuffled_storage = jax.tree_map(convert_data, flatten_storage) + + def update_minibatch(agent_state, minibatch): + (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn( + 
agent_state.params, + minibatch.obs, + minibatch.actions, + minibatch.logprobs, + minibatch.advantages, + minibatch.returns, + ) + agent_state = agent_state.apply_gradients(grads=grads) + return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) + + agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan( + update_minibatch, agent_state, shuffled_storage + ) + return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) + + (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan( + update_epoch, (agent_state, key), (), length=args.update_epochs + ) + return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = envs.reset() + next_done = jnp.zeros(args.num_envs, dtype=jax.numpy.bool_) + + # based on https://github.dev/google/evojax/blob/0625d875262011d8e1b6aa32566b236f44b4da66/evojax/sim_mgr.py + def step_once(carry, step, env_step_fn): + agent_state, episode_stats, obs, done, key, handle = carry + action, logprob, value, key = get_action_and_value(agent_state, obs, key) + + episode_stats, handle, (next_obs, reward, next_done, _) = env_step_fn(episode_stats, handle, action) + storage = Storage( + obs=obs, + actions=action, + logprobs=logprob, + dones=done, + values=value, + rewards=reward, + returns=jnp.zeros_like(reward), + advantages=jnp.zeros_like(reward), + ) + return ((agent_state, episode_stats, next_obs, next_done, key, handle), storage) + + def rollout(agent_state, episode_stats, next_obs, next_done, key, handle, step_once_fn, max_steps): + (agent_state, episode_stats, next_obs, next_done, key, handle), storage = jax.lax.scan( + step_once_fn, (agent_state, episode_stats, next_obs, next_done, key, handle), (), max_steps + ) + return agent_state, episode_stats, next_obs, next_done, storage, key, handle + + rollout = partial(rollout, step_once_fn=partial(step_once, env_step_fn=step_env_wrappeed), max_steps=args.num_steps) + + for iteration in range(1, args.num_iterations + 1): + iteration_time_start = time.time() + agent_state, episode_stats, next_obs, next_done, storage, key, handle = rollout( + agent_state, episode_stats, next_obs, next_done, key, handle + ) + global_step += args.num_steps * args.num_envs + storage = compute_gae(agent_state, next_obs, next_done, storage) + agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo( + agent_state, + storage, + key, + ) + avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns)) + print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}") + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step) + writer.add_scalar( + "charts/avg_episodic_length", np.mean(jax.device_get(episode_stats.returned_episode_lengths)), global_step + ) + writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step) + writer.add_scalar("losses/value_loss", v_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss[-1, -1].item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl[-1, -1].item(), global_step) + writer.add_scalar("losses/loss", loss[-1, -1].item(), global_step) + print("SPS:", int(global_step / (time.time() - 
start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + writer.add_scalar( + "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - iteration_time_start)), global_step + ) + + if args.save_model: + model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" + with open(model_path, "wb") as f: + f.write( + flax.serialization.to_bytes( + [ + vars(args), + [ + agent_state.params.network_params, + agent_state.params.actor_params, + agent_state.params.critic_params, + ], + ] + ) + ) + print(f"model saved to {model_path}") + from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate + + episodic_returns = evaluate( + model_path, + make_env, + args.env_id, + eval_episodes=10, + run_name=f"{run_name}-eval", + Model=(Network, Actor, Critic), + ) + for idx, episodic_return in enumerate(episodic_returns): + writer.add_scalar("eval/episodic_return", episodic_return, idx) + + if args.upload_model: + from cleanrl_utils.huggingface import push_to_hub + + repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" + repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name + push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval") + + envs.close() + writer.close() \ No newline at end of file diff --git a/benchmarks/cleanrl_jax/prepare.py b/benchmarks/cleanrl_jax/prepare.py new file mode 100755 index 000000000..32bd5901d --- /dev/null +++ b/benchmarks/cleanrl_jax/prepare.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import os + +if __name__ == "__main__": + # If you need the whole configuration: + # config = json.loads(os.environ["MILABENCH_CONFIG"]) + + data_directory = os.environ["MILABENCH_DIR_DATA"] + + # Download (or generate) the needed dataset(s). You are responsible + # to check if it has already been properly downloaded or not, and to + # do nothing if it has been. + print("Hello I am doing some data stuff!") + + # If there is nothing to download or generate, just delete this file. 
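For this benchmark the script above is effectively a no-op: the envpool Atari environments are created directly in main.py, so there is no dataset to stage. When a benchmark does need data, the contract spelled out in the comments (check whether the data is already present, download only if it is missing) can be met with a small guard. The sketch below is illustrative only; DATASET_URL and the archive name are placeholders, not real milabench or benchmark assets.

#!/usr/bin/env python
# Hypothetical prepare.py sketch: DATASET_URL and the archive name are
# placeholders for illustration, not actual milabench assets.
import os
import urllib.request

DATASET_URL = "https://example.com/dataset.tar.gz"

if __name__ == "__main__":
    data_dir = os.environ["MILABENCH_DIR_DATA"]
    target = os.path.join(data_dir, "dataset.tar.gz")

    if os.path.exists(target):
        # Already prepared: do nothing so repeated `milabench prepare` calls stay cheap.
        print(f"{target} already exists, skipping download")
    else:
        os.makedirs(data_dir, exist_ok=True)
        urllib.request.urlretrieve(DATASET_URL, target)
        print(f"downloaded {target}")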
diff --git a/benchmarks/cleanrl_jax/requirements.in b/benchmarks/cleanrl_jax/requirements.in new file mode 100644 index 000000000..77b50182c --- /dev/null +++ b/benchmarks/cleanrl_jax/requirements.in @@ -0,0 +1,99 @@ +voir>=0.2.17,<0.3 +absl-py==1.4.0 ; python_version >= "3.8" and python_version < "3.11" +appdirs==1.4.4 ; python_version >= "3.8" and python_version < "3.11" +cached-property==1.5.2 ; python_version >= "3.8" and python_version < "3.11" +cachetools==5.3.0 ; python_version >= "3.8" and python_version < "3.11" +certifi==2023.5.7 ; python_version >= "3.8" and python_version < "3.11" +charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "3.11" +chex==0.1.5 ; python_version >= "3.8" and python_version < "3.11" +click==8.1.3 ; python_version >= "3.8" and python_version < "3.11" +cloudpickle==2.2.1 ; python_version >= "3.8" and python_version < "3.11" +colorama==0.4.4 ; python_version >= "3.8" and python_version < "3.11" +commonmark==0.9.1 ; python_version >= "3.8" and python_version < "3.11" +cycler==0.11.0 ; python_version >= "3.8" and python_version < "3.11" +decorator==4.4.2 ; python_version >= "3.8" and python_version < "3.11" +dm-tree==0.1.8 ; python_version >= "3.8" and python_version < "3.11" +docker-pycreds==0.4.0 ; python_version >= "3.8" and python_version < "3.11" +docstring-parser==0.15 ; python_version >= "3.8" and python_version < "3.11" +etils==0.9.0 ; python_version >= "3.8" and python_version < "3.11" +exceptiongroup==1.1.1 ; python_version >= "3.8" and python_version < "3.11" +farama-notifications==0.0.4 ; python_version >= "3.8" and python_version < "3.11" +filelock==3.12.0 ; python_version >= "3.8" and python_version < "3.11" +flax==0.6.8 ; python_version >= "3.8" and python_version < "3.11" +fonttools==4.38.0 ; python_version >= "3.8" and python_version < "3.11" +gitdb==4.0.10 ; python_version >= "3.8" and python_version < "3.11" +gitpython==3.1.31 ; python_version >= "3.8" and python_version < "3.11" +google-auth-oauthlib==0.4.6 ; python_version >= "3.8" and python_version < "3.11" +google-auth==2.18.0 ; python_version >= "3.8" and python_version < "3.11" +grpcio==1.54.0 ; python_version >= "3.8" and python_version < "3.11" +gym-notices==0.0.8 ; python_version >= "3.8" and python_version < "3.11" +gym==0.23.1 ; python_version >= "3.8" and python_version < "3.11" +gymnasium==0.28.1 ; python_version >= "3.8" and python_version < "3.11" +huggingface-hub==0.11.1 ; python_version >= "3.8" and python_version < "3.11" +idna==3.4 ; python_version >= "3.8" and python_version < "3.11" +imageio-ffmpeg==0.3.0 ; python_version >= "3.8" and python_version < "3.11" +imageio==2.28.1 ; python_version >= "3.8" and python_version < "3.11" +importlib-metadata==5.2.0 ; python_version >= "3.8" and python_version < "3.10" +importlib-resources==5.12.0 ; python_version >= "3.8" and python_version < "3.11" +iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "3.11" +jax-jumpy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" +jax==0.4.8 ; python_version >= "3.8" and python_version < "3.11" +jaxlib==0.4.7 ; python_version >= "3.8" and python_version < "3.11" +kiwisolver==1.4.4 ; python_version >= "3.8" and python_version < "3.11" +markdown==3.3.7 ; python_version >= "3.8" and python_version < "3.11" +markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "3.11" +matplotlib==3.5.3 ; python_version >= "3.8" and python_version < "3.11" +ml-dtypes==0.2.0 ; python_version >= "3.8" and python_version < "3.11" +moviepy==1.0.3 ; python_version 
>= "3.8" and python_version < "3.11" +msgpack==1.0.5 ; python_version >= "3.8" and python_version < "3.11" +numpy==1.24.4 ; python_version < "3.11" and python_version >= "3.8" +oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.11" +opt-einsum==3.3.0 ; python_version >= "3.8" and python_version < "3.11" +optax==0.1.4 ; python_version >= "3.8" and python_version < "3.11" +orbax==0.1.0 ; python_version >= "3.8" and python_version < "3.11" +packaging==23.1 ; python_version >= "3.8" and python_version < "3.11" +pandas==1.3.5 ; python_version >= "3.8" and python_version < "3.11" +pathtools==0.1.2 ; python_version >= "3.8" and python_version < "3.11" +pillow==9.5.0 ; python_version >= "3.8" and python_version < "3.11" +pluggy==1.0.0 ; python_version >= "3.8" and python_version < "3.11" +proglog==0.1.10 ; python_version >= "3.8" and python_version < "3.11" +protobuf==3.20.3 ; python_version < "3.11" and python_version >= "3.8" +psutil +envpool +pyasn1-modules==0.3.0 ; python_version >= "3.8" and python_version < "3.11" +pyasn1==0.5.0 ; python_version >= "3.8" and python_version < "3.11" +pygame==2.1.0 ; python_version >= "3.8" and python_version < "3.11" +pygments==2.15.1 ; python_version >= "3.8" and python_version < "3.11" +pyparsing==3.0.9 ; python_version >= "3.8" and python_version < "3.11" +pytest==7.3.1 ; python_version >= "3.8" and python_version < "3.11" +python-dateutil==2.8.2 ; python_version >= "3.8" and python_version < "3.11" +pytz==2023.3 ; python_version >= "3.8" and python_version < "3.11" +pyyaml==6.0.1 ; python_version >= "3.8" and python_version < "3.11" +requests-oauthlib==1.3.1 ; python_version >= "3.8" and python_version < "3.11" +requests==2.30.0 ; python_version >= "3.8" and python_version < "3.11" +rich +rsa==4.7.2 ; python_version >= "3.8" and python_version < "3.11" +scipy==1.10.1 ; python_version >= "3.8" and python_version < "3.11" +sentry-sdk==1.22.2 ; python_version >= "3.8" and python_version < "3.11" +setproctitle==1.3.2 ; python_version >= "3.8" and python_version < "3.11" +setuptools==67.7.2 ; python_version >= "3.8" and python_version < "3.11" +shtab==1.6.4 ; python_version >= "3.8" and python_version < "3.11" +six==1.16.0 ; python_version >= "3.8" and python_version < "3.11" +smmap==5.0.0 ; python_version >= "3.8" and python_version < "3.11" +stable-baselines3==2.0.0 ; python_version >= "3.8" and python_version < "3.11" +tenacity==8.2.3 ; python_version >= "3.8" and python_version < "3.11" +tensorboard-data-server==0.6.1 ; python_version >= "3.8" and python_version < "3.11" +tensorboard-plugin-wit==1.8.1 ; python_version >= "3.8" and python_version < "3.11" +tensorboard==2.11.2 ; python_version >= "3.8" and python_version < "3.11" +tensorstore==0.1.28 ; python_version >= "3.8" and python_version < "3.11" +tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11" +toolz==0.12.0 ; python_version >= "3.8" and python_version < "3.11" +torch==1.12.1 ; python_version >= "3.8" and python_version < "3.11" +tqdm==4.65.0 ; python_version >= "3.8" and python_version < "3.11" +typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "3.11" +tyro==0.5.10 ; python_version >= "3.8" and python_version < "3.11" +urllib3==1.26.15 ; python_version >= "3.8" and python_version < "3.11" +wandb==0.13.11 ; python_version >= "3.8" and python_version < "3.11" +werkzeug==2.2.3 ; python_version >= "3.8" and python_version < "3.11" +wheel==0.40.0 ; python_version >= "3.8" and python_version < "3.11" +zipp==3.15.0 ; python_version >= "3.8" and 
python_version < "3.10" \ No newline at end of file diff --git a/benchmarks/cleanrl_jax/voirfile.py b/benchmarks/cleanrl_jax/voirfile.py new file mode 100644 index 000000000..d93f886cd --- /dev/null +++ b/benchmarks/cleanrl_jax/voirfile.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +from voir import configurable +from voir.instruments import dash, early_stop, log, rate +from benchmate.monitor import monitor_monogpu + +@dataclass +class Config: + """voir configuration""" + + # Whether to display the dash or not + dash: bool = False + + # How often to log the rates + interval: str = "1s" + + # Number of rates to skip before logging + skip: int = 5 + + # Number of rates to log before stopping + stop: int = 20 + + # Number of seconds between each gpu poll + gpu_poll: int = 3 + + +@configurable +def instrument_main(ov, options: Config): + yield ov.phases.init + + if options.dash: + ov.require(dash) + + ov.require( + log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), + early_stop(n=options.stop, key="rate", task="train"), + monitor_monogpu(poll_interval=options.gpu_poll), + ) From c018fbb535802b733241bb3a512443e06e398751 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Wed, 4 Sep 2024 10:22:39 -0400 Subject: [PATCH 08/13] Rl argparse (#264) * Initial commit with purejaxrl * Added instrumentation * Updated benchmark configs * Add argparse * update benchmark --------- Co-authored-by: Darshan Patil Co-authored-by: pierre.delaunay --- benchmarks/purejaxrl/Makefile | 31 +++ benchmarks/purejaxrl/README.md | 4 + benchmarks/purejaxrl/benchfile.py | 31 +++ benchmarks/purejaxrl/dev.yaml | 28 +++ benchmarks/purejaxrl/dqn.py | 325 ++++++++++++++++++++++++ benchmarks/purejaxrl/main.py | 37 +++ benchmarks/purejaxrl/ppo.py | 359 +++++++++++++++++++++++++++ benchmarks/purejaxrl/requirements.in | 15 ++ benchmarks/purejaxrl/voirfile.py | 38 +++ benchmarks/purejaxrl/wrappers.py | 349 ++++++++++++++++++++++++++ benchmate/benchmate/timings.py | 64 +++++ config/base.yaml | 19 +- milabench/_version.py | 8 +- 13 files changed, 1300 insertions(+), 8 deletions(-) create mode 100644 benchmarks/purejaxrl/Makefile create mode 100644 benchmarks/purejaxrl/README.md create mode 100644 benchmarks/purejaxrl/benchfile.py create mode 100644 benchmarks/purejaxrl/dev.yaml create mode 100644 benchmarks/purejaxrl/dqn.py create mode 100644 benchmarks/purejaxrl/main.py create mode 100644 benchmarks/purejaxrl/ppo.py create mode 100644 benchmarks/purejaxrl/requirements.in create mode 100644 benchmarks/purejaxrl/voirfile.py create mode 100644 benchmarks/purejaxrl/wrappers.py create mode 100644 benchmate/benchmate/timings.py diff --git a/benchmarks/purejaxrl/Makefile b/benchmarks/purejaxrl/Makefile new file mode 100644 index 000000000..48cd14f88 --- /dev/null +++ b/benchmarks/purejaxrl/Makefile @@ -0,0 +1,31 @@ +# Use global base if possible +ifndef MILABENCH_BASE + MILABENCH_BASE="base" +endif + +export MILABENCH_BASE + +BENCH_NAME=template +MILABENCH_CONFIG=dev.yaml +MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) + +all: + install prepare single gpus nodes + +install: + milabench install $(MILABENCH_ARGS) --force + +prepare: + milabench prepare $(MILABENCH_ARGS) + +tests:# install prepare + CUDA_VISIBLE_DEVICES=0 milabench run $(MILABENCH_ARGS) + +single: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-single + +gpus: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus + +nodes: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes diff --git 
a/benchmarks/purejaxrl/README.md b/benchmarks/purejaxrl/README.md new file mode 100644 index 000000000..239a2dfaf --- /dev/null +++ b/benchmarks/purejaxrl/README.md @@ -0,0 +1,4 @@ + +# Template + +Rewrite this README to explain what the benchmark is! diff --git a/benchmarks/purejaxrl/benchfile.py b/benchmarks/purejaxrl/benchfile.py new file mode 100644 index 000000000..08a51cef0 --- /dev/null +++ b/benchmarks/purejaxrl/benchfile.py @@ -0,0 +1,31 @@ +from milabench.pack import Package + + +class Template(Package): + # Requirements file installed by install(). It can be empty or absent. + base_requirements = "requirements.in" + + # The preparation script called by prepare(). It must be executable, + # but it can be any type of script. It can be empty or absent. + prepare_script = "prepare.py" + + # The main script called by run(). It must be a Python file. It has to + # be present. + main_script = "main.py" + + # You can remove the functions below if you don't need to modify them. + + def make_env(self): + # Return a dict of environment variables for prepare_script and + # main_script. + return super().make_env() + + async def install(self): + await super().install() # super() call installs the requirements + + async def prepare(self): + await super().prepare() # super() call executes prepare_script + + + +__pack__ = Template diff --git a/benchmarks/purejaxrl/dev.yaml b/benchmarks/purejaxrl/dev.yaml new file mode 100644 index 000000000..ad1ebf87f --- /dev/null +++ b/benchmarks/purejaxrl/dev.yaml @@ -0,0 +1,28 @@ +_purejaxrl: + inherits: _defaults + definition: . + install-variant: unpinned + install_group: torch + plan: + method: per_gpu + +dqn: + inherits: _purejaxrl + argv: + dqn: true + --num_envs: auto({cpu_per_gpu}, 128) + --buffer_batch_size: 128 + --env_name: CartPole-v1 + --training_interval: 10 + --learning_starts: 10000 + +ppo: + inherits: _purejaxrl + argv: + ppo: true + --num_envs: auto({cpu_per_gpu}, 128) + --num_steps: 10 + --num_minibatches: 32 + --update_epochs: 4 + --env_name: hopper + --total_timesteps: 200000 \ No newline at end of file diff --git a/benchmarks/purejaxrl/dqn.py b/benchmarks/purejaxrl/dqn.py new file mode 100644 index 000000000..16fa55f52 --- /dev/null +++ b/benchmarks/purejaxrl/dqn.py @@ -0,0 +1,325 @@ +""" +PureJaxRL version of CleanRL's DQN: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py +""" +from dataclasses import dataclass +import time + +import jax +import jax.numpy as jnp +import chex +import flax +import optax +import flax.linen as nn +from flax.training.train_state import TrainState +from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper +import gymnax +import flashbax as fbx + +from benchmate.metrics import give_push + + + +class QNetwork(nn.Module): + action_dim: int + + @nn.compact + def __call__(self, x: jnp.ndarray): + x = nn.Dense(120)(x) + x = nn.relu(x) + x = nn.Dense(84)(x) + x = nn.relu(x) + x = nn.Dense(self.action_dim)(x) + return x + + +@chex.dataclass(frozen=True) +class TimeStep: + obs: chex.Array + action: chex.Array + reward: chex.Array + done: chex.Array + + +class CustomTrainState(TrainState): + target_network_params: flax.core.FrozenDict + timesteps: int + n_updates: int + + +def make_train(config): + config["NUM_UPDATES"] = config["TOTAL_TIMESTEPS"] // config["NUM_ENVS"] + + from benchmate.timings import StepTimer + step_timer = StepTimer(give_push()) + + basic_env, env_params = gymnax.make(config["ENV_NAME"]) + env = FlattenObservationWrapper(basic_env) + env = LogWrapper(env) + + 
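+    # The two lambdas below batch the single-environment gymnax API:
+    # env.reset / env.step are vmapped over a batch of PRNG keys (one key per
+    # environment), while env_params is shared across the batch (in_axes=None).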
vmap_reset = lambda n_envs: lambda rng: jax.vmap(env.reset, in_axes=(0, None))( + jax.random.split(rng, n_envs), env_params + ) + vmap_step = lambda n_envs: lambda rng, env_state, action: jax.vmap( + env.step, in_axes=(0, 0, 0, None) + )(jax.random.split(rng, n_envs), env_state, action, env_params) + + def train(rng): + + # INIT ENV + rng, _rng = jax.random.split(rng) + init_obs, env_state = vmap_reset(config["NUM_ENVS"])(_rng) + + # INIT BUFFER + buffer = fbx.make_flat_buffer( + max_length=config["BUFFER_SIZE"], + min_length=config["BUFFER_BATCH_SIZE"], + sample_batch_size=config["BUFFER_BATCH_SIZE"], + add_sequences=False, + add_batch_size=config["NUM_ENVS"], + ) + buffer = buffer.replace( + init=jax.jit(buffer.init), + add=jax.jit(buffer.add, donate_argnums=0), + sample=jax.jit(buffer.sample), + can_sample=jax.jit(buffer.can_sample), + ) + rng = jax.random.PRNGKey(0) # use a dummy rng here + _action = basic_env.action_space().sample(rng) + _, _env_state = env.reset(rng, env_params) + _obs, _, _reward, _done, _ = env.step(rng, _env_state, _action, env_params) + _timestep = TimeStep(obs=_obs, action=_action, reward=_reward, done=_done) + buffer_state = buffer.init(_timestep) + + # INIT NETWORK AND OPTIMIZER + network = QNetwork(action_dim=env.action_space(env_params).n) + rng, _rng = jax.random.split(rng) + init_x = jnp.zeros(env.observation_space(env_params).shape) + network_params = network.init(_rng, init_x) + + def linear_schedule(count): + frac = 1.0 - (count / config["NUM_UPDATES"]) + return config["LR"] * frac + + lr = linear_schedule if config.get("LR_LINEAR_DECAY", False) else config["LR"] + tx = optax.adam(learning_rate=lr) + + train_state = CustomTrainState.create( + apply_fn=network.apply, + params=network_params, + target_network_params=jax.tree_map(lambda x: jnp.copy(x), network_params), + tx=tx, + timesteps=0, + n_updates=0, + ) + + # epsilon-greedy exploration + def eps_greedy_exploration(rng, q_vals, t): + rng_a, rng_e = jax.random.split( + rng, 2 + ) # a key for sampling random actions and one for picking + eps = jnp.clip( # get epsilon + ( + (config["EPSILON_FINISH"] - config["EPSILON_START"]) + / config["EPSILON_ANNEAL_TIME"] + ) + * t + + config["EPSILON_START"], + config["EPSILON_FINISH"], + ) + greedy_actions = jnp.argmax(q_vals, axis=-1) # get the greedy actions + chosed_actions = jnp.where( + jax.random.uniform(rng_e, greedy_actions.shape) + < eps, # pick the actions that should be random + jax.random.randint( + rng_a, shape=greedy_actions.shape, minval=0, maxval=q_vals.shape[-1] + ), # sample random actions, + greedy_actions, + ) + return chosed_actions + + # TRAINING LOOP + def _update_step(runner_state, unused): + + train_state, buffer_state, env_state, last_obs, rng = runner_state + + # STEP THE ENV + rng, rng_a, rng_s = jax.random.split(rng, 3) + q_vals = network.apply(train_state.params, last_obs) + action = eps_greedy_exploration( + rng_a, q_vals, train_state.timesteps + ) # explore with epsilon greedy_exploration + obs, env_state, reward, done, info = vmap_step(config["NUM_ENVS"])( + rng_s, env_state, action + ) + train_state = train_state.replace( + timesteps=train_state.timesteps + config["NUM_ENVS"] + ) # update timesteps count + + # BUFFER UPDATE + timestep = TimeStep(obs=last_obs, action=action, reward=reward, done=done) + buffer_state = buffer.add(buffer_state, timestep) + + # NETWORKS UPDATE + def _learn_phase(train_state, rng): + + learn_batch = buffer.sample(buffer_state, rng).experience + + q_next_target = network.apply( + 
train_state.target_network_params, learn_batch.second.obs + ) # (batch_size, num_actions) + q_next_target = jnp.max(q_next_target, axis=-1) # (batch_size,) + target = ( + learn_batch.first.reward + + (1 - learn_batch.first.done) * config["GAMMA"] * q_next_target + ) + + def _loss_fn(params): + q_vals = network.apply( + params, learn_batch.first.obs + ) # (batch_size, num_actions) + chosen_action_qvals = jnp.take_along_axis( + q_vals, + jnp.expand_dims(learn_batch.first.action, axis=-1), + axis=-1, + ).squeeze(axis=-1) + return jnp.mean((chosen_action_qvals - target) ** 2) + + loss, grads = jax.value_and_grad(_loss_fn)(train_state.params) + train_state = train_state.apply_gradients(grads=grads) + train_state = train_state.replace(n_updates=train_state.n_updates + 1) + return train_state, loss + + rng, _rng = jax.random.split(rng) + is_learn_time = ( + (buffer.can_sample(buffer_state)) + & ( # enough experience in buffer + train_state.timesteps > config["LEARNING_STARTS"] + ) + & ( # pure exploration phase ended + train_state.timesteps % config["TRAINING_INTERVAL"] == 0 + ) # training interval + ) + train_state, loss = jax.lax.cond( + is_learn_time, + lambda train_state, rng: _learn_phase(train_state, rng), + lambda train_state, rng: (train_state, jnp.array(0.0)), # do nothing + train_state, + _rng, + ) + + # update target network + train_state = jax.lax.cond( + train_state.timesteps % config["TARGET_UPDATE_INTERVAL"] == 0, + lambda train_state: train_state.replace( + target_network_params=optax.incremental_update( + train_state.params, + train_state.target_network_params, + config["TAU"], + ) + ), + lambda train_state: train_state, + operand=train_state, + ) + + metrics = { + "timesteps": train_state.timesteps, + "updates": train_state.n_updates, + "loss": loss.mean(), + "returns": info["returned_episode_returns"].mean(), + } + + def callback(metrics): + # .block_until_ready() + if (metrics["timesteps"] + 1) % 1000: + returns = metrics["returns"].item() + loss = metrics["loss"].block_until_ready().item() + delta = metrics["timesteps"] - step_timer.timesteps + step_timer.timestep = metrics["timesteps"] + + step_timer.step(delta.item()) + step_timer.log(returns=returns, loss=loss) + step_timer.end() + + jax.debug.callback(callback, metrics) + + runner_state = (train_state, buffer_state, env_state, obs, rng) + + return runner_state, metrics + + # train + rng, _rng = jax.random.split(rng) + runner_state = (train_state, buffer_state, env_state, init_obs, _rng) + + runner_state, metrics = jax.lax.scan( + _update_step, runner_state, None, config["NUM_UPDATES"] + ) + return {"runner_state": runner_state, "metrics": metrics} + + return train + + +@dataclass +class Arguments: + num_envs: int = 10 + buffer_size: int = 10000 + buffer_batch_size: int = 128 + total_timesteps: int = 100_000 + epsilon_start: float = 1.0 + epsilon_finish: float = 0.05 + epsilon_anneal_time: int = 25e4 + target_update_interval: int = 500 + lr: float = 2.5e-4 + learning_starts: int = 10000 + training_interval: int = 10 + lr_linear_decay: bool = False + gamma: float = 0.99 + tau: float = 1.0 + env_name: str = "CartPole-v1" + seed: int = 0 + num_seeds: int = 1 + project: str = "" + + +def add_dqn_command(subparser): + parser = subparser.add_parser('dqn', help='RL dqn benchmark') + parser.add_arguments(Arguments) + + + +def main(args: Arguments = None): + if args is None: + args = Arguments() + + config = { + "NUM_ENVS": args.num_envs, + "BUFFER_SIZE": args.buffer_size, + "BUFFER_BATCH_SIZE": args.buffer_batch_size, + 
"TOTAL_TIMESTEPS": args.total_timesteps, + "EPSILON_START": args.epsilon_start, + "EPSILON_FINISH": args.epsilon_finish, + "EPSILON_ANNEAL_TIME": args.epsilon_anneal_time, + "TARGET_UPDATE_INTERVAL": args.target_update_interval, + "LR": args.lr, + "LEARNING_STARTS": args.learning_starts, + "TRAINING_INTERVAL": args.training_interval, + "LR_LINEAR_DECAY": args.lr_linear_decay, + "GAMMA": args.gamma, + "TAU": args.tau, + "ENV_NAME": args.env_name, + "SEED": args.seed, + "NUM_SEEDS": args.num_seeds, + "PROJECT": args.project, + } + + rng = jax.random.PRNGKey(config["SEED"]) + rngs = jax.random.split(rng, config["NUM_SEEDS"]) + train_vjit = jax.jit(jax.vmap(make_train(config), in_axes=(0,))) + compiled_fn = train_vjit.lower(rngs).compile() + + from benchmate.monitor import bench_monitor + with bench_monitor(): + outs = jax.block_until_ready(compiled_fn(rngs)) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/purejaxrl/main.py b/benchmarks/purejaxrl/main.py new file mode 100644 index 000000000..38eaf2792 --- /dev/null +++ b/benchmarks/purejaxrl/main.py @@ -0,0 +1,37 @@ +# This is the script run by milabench run (by default) + +# It is possible to use a script from a GitHub repo if it is cloned using +# clone_subtree in the benchfile.py, in which case this file can simply +# be deleted. + +import argparse +import argklass + + +from dqn import add_dqn_command, main as dqn_main +from ppo import add_ppo_command, main as ppo_main + + +def main(): + parser = argklass.ArgumentParser(description="PureJaxRL") + subparser = parser.add_subparsers(title="Benchmark", dest="benchmark") + + add_dqn_command(subparser) + add_ppo_command(subparser) + + bench = { + "dqn": dqn_main, + "ppo": ppo_main + } + + args = parser.parse_args() + + if benchmark := bench.get(args.benchmark): + benchmark(args) + + else: + raise ValueError(f"Unknown benchmark: {args.benchmark}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/purejaxrl/ppo.py b/benchmarks/purejaxrl/ppo.py new file mode 100644 index 000000000..a053373f3 --- /dev/null +++ b/benchmarks/purejaxrl/ppo.py @@ -0,0 +1,359 @@ +from dataclasses import dataclass +import time + +import jax +import jax.numpy as jnp +import flax.linen as nn +import numpy as np +import optax +from flax.linen.initializers import constant, orthogonal +from typing import Sequence, NamedTuple, Any +from flax.training.train_state import TrainState +import distrax +from benchmate.metrics import give_push + +from wrappers import ( + LogWrapper, + BraxGymnaxWrapper, + VecEnv, + NormalizeVecObservation, + NormalizeVecReward, + ClipAction, +) + + + + +class ActorCritic(nn.Module): + action_dim: Sequence[int] + activation: str = "tanh" + + @nn.compact + def __call__(self, x): + if self.activation == "relu": + activation = nn.relu + else: + activation = nn.tanh + actor_mean = nn.Dense( + 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(x) + actor_mean = activation(actor_mean) + actor_mean = nn.Dense( + 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(actor_mean) + actor_mean = activation(actor_mean) + actor_mean = nn.Dense( + self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + )(actor_mean) + actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,)) + pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd)) + + critic = nn.Dense( + 256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(x) + critic = activation(critic) + critic = nn.Dense( + 256, 
kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(critic) + critic = activation(critic) + critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))( + critic + ) + + return pi, jnp.squeeze(critic, axis=-1) + + +class Transition(NamedTuple): + done: jnp.ndarray + action: jnp.ndarray + value: jnp.ndarray + reward: jnp.ndarray + log_prob: jnp.ndarray + obs: jnp.ndarray + info: jnp.ndarray + + +def make_train(config): + from benchmate.timings import StepTimer + step_timer = StepTimer(give_push()) + + config["NUM_UPDATES"] = ( + config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"] + ) + config["MINIBATCH_SIZE"] = ( + config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"] + ) + env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None + env = LogWrapper(env) + env = ClipAction(env) + env = VecEnv(env) + if config["NORMALIZE_ENV"]: + env = NormalizeVecObservation(env) + env = NormalizeVecReward(env, config["GAMMA"]) + + def linear_schedule(count): + frac = ( + 1.0 + - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) + / config["NUM_UPDATES"] + ) + return config["LR"] * frac + + def train(rng): + # INIT NETWORK + network = ActorCritic( + env.action_space(env_params).shape[0], activation=config["ACTIVATION"] + ) + rng, _rng = jax.random.split(rng) + init_x = jnp.zeros(env.observation_space(env_params).shape) + network_params = network.init(_rng, init_x) + if config["ANNEAL_LR"]: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(learning_rate=linear_schedule, eps=1e-5), + ) + else: + tx = optax.chain( + optax.clip_by_global_norm(config["MAX_GRAD_NORM"]), + optax.adam(config["LR"], eps=1e-5), + ) + train_state = TrainState.create( + apply_fn=network.apply, + params=network_params, + tx=tx, + ) + + # INIT ENV + rng, _rng = jax.random.split(rng) + reset_rng = jax.random.split(_rng, config["NUM_ENVS"]) + obsv, env_state = env.reset(reset_rng, env_params) + + # TRAIN LOOP + def _update_step(runner_state, unused): + # COLLECT TRAJECTORIES + def _env_step(runner_state, unused): + train_state, env_state, last_obs, rng = runner_state + + # SELECT ACTION + rng, _rng = jax.random.split(rng) + pi, value = network.apply(train_state.params, last_obs) + action = pi.sample(seed=_rng) + log_prob = pi.log_prob(action) + + # STEP ENV + rng, _rng = jax.random.split(rng) + rng_step = jax.random.split(_rng, config["NUM_ENVS"]) + obsv, env_state, reward, done, info = env.step( + rng_step, env_state, action, env_params + ) + transition = Transition( + done, action, value, reward, log_prob, last_obs, info + ) + runner_state = (train_state, env_state, obsv, rng) + return runner_state, transition + + runner_state, traj_batch = jax.lax.scan( + _env_step, runner_state, None, config["NUM_STEPS"] + ) + + # CALCULATE ADVANTAGE + train_state, env_state, last_obs, rng = runner_state + _, last_val = network.apply(train_state.params, last_obs) + + def _calculate_gae(traj_batch, last_val): + def _get_advantages(gae_and_next_value, transition): + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + delta = reward + config["GAMMA"] * next_value * (1 - done) - value + gae = ( + delta + + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae + ) + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + 
traj_batch.value + + advantages, targets = _calculate_gae(traj_batch, last_val) + + # UPDATE NETWORK + def _update_epoch(update_state, unused): + def _update_minbatch(train_state, batch_info): + traj_batch, advantages, targets = batch_info + + def _loss_fn(params, traj_batch, gae, targets): + # RERUN NETWORK + pi, value = network.apply(params, traj_batch.obs) + log_prob = pi.log_prob(traj_batch.action) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + ( + value - traj_batch.value + ).clip(-config["CLIP_EPS"], config["CLIP_EPS"]) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = ( + 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + ) + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config["CLIP_EPS"], + 1.0 + config["CLIP_EPS"], + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + entropy = pi.entropy().mean() + + total_loss = ( + loss_actor + + config["VF_COEF"] * value_loss + - config["ENT_COEF"] * entropy + ) + return total_loss, (value_loss, loss_actor, entropy) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + total_loss, grads = grad_fn( + train_state.params, traj_batch, advantages, targets + ) + train_state = train_state.apply_gradients(grads=grads) + return train_state, total_loss + + train_state, traj_batch, advantages, targets, rng = update_state + rng, _rng = jax.random.split(rng) + batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] + assert ( + batch_size == config["NUM_STEPS"] * config["NUM_ENVS"] + ), "batch size must be equal to number of steps * number of envs" + permutation = jax.random.permutation(_rng, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree_util.tree_map( + lambda x: x.reshape((batch_size,) + x.shape[2:]), batch + ) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_util.tree_map( + lambda x: jnp.reshape( + x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) + ), + shuffled_batch, + ) + train_state, total_loss = jax.lax.scan( + _update_minbatch, train_state, minibatches + ) + update_state = (train_state, traj_batch, advantages, targets, rng) + return update_state, total_loss + + update_state = (train_state, traj_batch, advantages, targets, rng) + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config["UPDATE_EPOCHS"] + ) + train_state = update_state[0] + metric = traj_batch.info + rng = update_state[-1] + + metrics = { + "loss": loss_info, + } + + def callback(info): + total_loss, (value_loss, loss_actor, entropy) = info["loss"] + loss = total_loss.mean().item() + + step_timer.step(config["NUM_ENVS"] * config["NUM_STEPS"]) + step_timer.log(loss=loss) + step_timer.end() + + jax.debug.callback(callback, metrics) + + runner_state = (train_state, env_state, last_obs, rng) + return runner_state, metric + + rng, _rng = jax.random.split(rng) + runner_state = (train_state, env_state, obsv, _rng) + runner_state, metric = jax.lax.scan( + _update_step, runner_state, None, config["NUM_UPDATES"] + ) + return {"runner_state": runner_state, "metrics": metric} + + return train + + + +@dataclass +class Arguments: + lr: float = 3e-4 + num_envs: int = 2048 + num_steps: int = 10 + total_timesteps: float = 50_000_000 + 
update_epochs: int = 4 + num_minibatches: int = 32 + gamma: float = 0.99 + gae_lambda: float = 0.95 + clip_eps: float = 0.2 + ent_coef: float = 0.0 + vf_coef: float = 0.5 + max_grad_norm: float = 0.5 + activation: str = "tanh" + env_name: str = "hopper" + anneal_lr: bool = False + normalize_env: bool = True + + +def add_ppo_command(subparser): + parser = subparser.add_parser('ppo', help='RL dqn benchmark') + parser.add_arguments(Arguments) + + +def main(args: Arguments = None): + if args is None: + args = Arguments() + + config = { + "LR": args.lr, + "NUM_ENVS": args.num_envs, + "NUM_STEPS": args.num_steps, + "TOTAL_TIMESTEPS": args.total_timesteps, + "UPDATE_EPOCHS": args.update_epochs, + "NUM_MINIBATCHES": args.num_minibatches, + "GAMMA": args.gamma, + "GAE_LAMBDA": args.gae_lambda, + "CLIP_EPS": args.clip_eps, + "ENT_COEF": args.ent_coef, + "VF_COEF": args.vf_coef, + "MAX_GRAD_NORM": args.max_grad_norm, + "ACTIVATION": args.activation, + "ENV_NAME": args.env_name, + "ANNEAL_LR": args.anneal_lr, + "NORMALIZE_ENV": args.normalize_env, + } + rng = jax.random.PRNGKey(30) + train_jit = jax.jit(make_train(config)) + compiled_fn = train_jit.lower(rng).compile() + + from benchmate.monitor import bench_monitor + + with bench_monitor(): + out = compiled_fn(rng) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/purejaxrl/requirements.in b/benchmarks/purejaxrl/requirements.in new file mode 100644 index 000000000..2d3f51759 --- /dev/null +++ b/benchmarks/purejaxrl/requirements.in @@ -0,0 +1,15 @@ +voir +torch +jax[cuda12] +--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +gymnax +evosax +distrax +optax +flax +numpy +brax +flashbax +navix +torch +argklass \ No newline at end of file diff --git a/benchmarks/purejaxrl/voirfile.py b/benchmarks/purejaxrl/voirfile.py new file mode 100644 index 000000000..e78cff32b --- /dev/null +++ b/benchmarks/purejaxrl/voirfile.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass + +from voir import configurable +from voir.instruments import dash, early_stop, log, rate +from benchmate.monitor import monitor_monogpu + +@dataclass +class Config: + """voir configuration""" + + # Whether to display the dash or not + dash: bool = False + + # How often to log the rates + interval: str = "1s" + + # Number of rates to skip before logging + skip: int = 5 + + # Number of rates to log before stopping + stop: int = 20 + + # Number of seconds between each gpu poll + gpu_poll: int = 0.5 + + +@configurable +def instrument_main(ov, options: Config): + yield ov.phases.init + + if options.dash: + ov.require(dash) + + ov.require( + log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), + # early_stop(n=options.stop, key="rate", task="train"), + monitor_monogpu(poll_interval=options.gpu_poll), + ) diff --git a/benchmarks/purejaxrl/wrappers.py b/benchmarks/purejaxrl/wrappers.py new file mode 100644 index 000000000..81b397211 --- /dev/null +++ b/benchmarks/purejaxrl/wrappers.py @@ -0,0 +1,349 @@ +import jax +import jax.numpy as jnp +import chex +import numpy as np +from flax import struct +from functools import partial +from typing import Optional, Tuple, Union, Any +from gymnax.environments import environment, spaces +from brax import envs +from brax.envs.wrappers.training import EpisodeWrapper, AutoResetWrapper +import navix as nx + + +class GymnaxWrapper(object): + """Base class for Gymnax wrappers.""" + + def __init__(self, env): + self._env = env + + # provide proxy access to regular attributes of wrapped object + 
def __getattr__(self, name): + return getattr(self._env, name) + + +class FlattenObservationWrapper(GymnaxWrapper): + """Flatten the observations of the environment.""" + + def __init__(self, env: environment.Environment): + super().__init__(env) + + def observation_space(self, params) -> spaces.Box: + assert isinstance( + self._env.observation_space(params), spaces.Box + ), "Only Box spaces are supported for now." + return spaces.Box( + low=self._env.observation_space(params).low, + high=self._env.observation_space(params).high, + shape=(np.prod(self._env.observation_space(params).shape),), + dtype=self._env.observation_space(params).dtype, + ) + + @partial(jax.jit, static_argnums=(0,)) + def reset( + self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None + ) -> Tuple[chex.Array, environment.EnvState]: + obs, state = self._env.reset(key, params) + obs = jnp.reshape(obs, (-1,)) + return obs, state + + @partial(jax.jit, static_argnums=(0,)) + def step( + self, + key: chex.PRNGKey, + state: environment.EnvState, + action: Union[int, float], + params: Optional[environment.EnvParams] = None, + ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: + obs, state, reward, done, info = self._env.step(key, state, action, params) + obs = jnp.reshape(obs, (-1,)) + return obs, state, reward, done, info + + +@struct.dataclass +class LogEnvState: + env_state: environment.EnvState + episode_returns: float + episode_lengths: int + returned_episode_returns: float + returned_episode_lengths: int + timestep: int + + +class LogWrapper(GymnaxWrapper): + """Log the episode returns and lengths.""" + + def __init__(self, env: environment.Environment): + super().__init__(env) + + @partial(jax.jit, static_argnums=(0,)) + def reset( + self, key: chex.PRNGKey, params: Optional[environment.EnvParams] = None + ) -> Tuple[chex.Array, environment.EnvState]: + obs, env_state = self._env.reset(key, params) + state = LogEnvState(env_state, 0, 0, 0, 0, 0) + return obs, state + + @partial(jax.jit, static_argnums=(0,)) + def step( + self, + key: chex.PRNGKey, + state: environment.EnvState, + action: Union[int, float], + params: Optional[environment.EnvParams] = None, + ) -> Tuple[chex.Array, environment.EnvState, float, bool, dict]: + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + new_episode_return = state.episode_returns + reward + new_episode_length = state.episode_lengths + 1 + state = LogEnvState( + env_state=env_state, + episode_returns=new_episode_return * (1 - done), + episode_lengths=new_episode_length * (1 - done), + returned_episode_returns=state.returned_episode_returns * (1 - done) + + new_episode_return * done, + returned_episode_lengths=state.returned_episode_lengths * (1 - done) + + new_episode_length * done, + timestep=state.timestep + 1, + ) + info["returned_episode_returns"] = state.returned_episode_returns + info["returned_episode_lengths"] = state.returned_episode_lengths + info["timestep"] = state.timestep + info["returned_episode"] = done + return obs, state, reward, done, info + + +class BraxGymnaxWrapper: + def __init__(self, env_name, backend="positional"): + env = envs.get_environment(env_name=env_name, backend=backend) + env = EpisodeWrapper(env, episode_length=1000, action_repeat=1) + env = AutoResetWrapper(env) + self._env = env + self.action_size = env.action_size + self.observation_size = (env.observation_size,) + + def reset(self, key, params=None): + state = self._env.reset(key) + return state.obs, state + + def 
step(self, key, state, action, params=None): + next_state = self._env.step(state, action) + return next_state.obs, next_state, next_state.reward, next_state.done > 0.5, {} + + def observation_space(self, params): + return spaces.Box( + low=-jnp.inf, + high=jnp.inf, + shape=(self._env.observation_size,), + ) + + def action_space(self, params): + return spaces.Box( + low=-1.0, + high=1.0, + shape=(self._env.action_size,), + ) + + +class NavixGymnaxWrapper: + def __init__(self, env_name): + self._env = nx.make(env_name) + + def reset(self, key, params=None): + timestep = self._env.reset(key) + return timestep.observation, timestep + + def step(self, key, state, action, params=None): + timestep = self._env.step(state, action) + return timestep.observation, timestep, timestep.reward, timestep.is_done(), {} + + def observation_space(self, params): + return spaces.Box( + low=self._env.observation_space.minimum, + high=self._env.observation_space.maximum, + shape=(np.prod(self._env.observation_space.shape),), + dtype=self._env.observation_space.dtype, + ) + + def action_space(self, params): + return spaces.Discrete( + num_categories=self._env.action_space.maximum.item() + 1, + ) + + +class ClipAction(GymnaxWrapper): + def __init__(self, env, low=-1.0, high=1.0): + super().__init__(env) + self.low = low + self.high = high + + def step(self, key, state, action, params=None): + """TODO: In theory the below line should be the way to do this.""" + # action = jnp.clip(action, self.env.action_space.low, self.env.action_space.high) + action = jnp.clip(action, self.low, self.high) + return self._env.step(key, state, action, params) + + +class TransformObservation(GymnaxWrapper): + def __init__(self, env, transform_obs): + super().__init__(env) + self.transform_obs = transform_obs + + def reset(self, key, params=None): + obs, state = self._env.reset(key, params) + return self.transform_obs(obs), state + + def step(self, key, state, action, params=None): + obs, state, reward, done, info = self._env.step(key, state, action, params) + return self.transform_obs(obs), state, reward, done, info + + +class TransformReward(GymnaxWrapper): + def __init__(self, env, transform_reward): + super().__init__(env) + self.transform_reward = transform_reward + + def step(self, key, state, action, params=None): + obs, state, reward, done, info = self._env.step(key, state, action, params) + return obs, state, self.transform_reward(reward), done, info + + +class VecEnv(GymnaxWrapper): + def __init__(self, env): + super().__init__(env) + self.reset = jax.vmap(self._env.reset, in_axes=(0, None)) + self.step = jax.vmap(self._env.step, in_axes=(0, 0, 0, None)) + + +@struct.dataclass +class NormalizeVecObsEnvState: + mean: jnp.ndarray + var: jnp.ndarray + count: float + env_state: environment.EnvState + + +class NormalizeVecObservation(GymnaxWrapper): + def __init__(self, env): + super().__init__(env) + + def reset(self, key, params=None): + obs, state = self._env.reset(key, params) + state = NormalizeVecObsEnvState( + mean=jnp.zeros_like(obs), + var=jnp.ones_like(obs), + count=1e-4, + env_state=state, + ) + batch_mean = jnp.mean(obs, axis=0) + batch_var = jnp.var(obs, axis=0) + batch_count = obs.shape[0] + + delta = batch_mean - state.mean + tot_count = state.count + batch_count + + new_mean = state.mean + delta * batch_count / tot_count + m_a = state.var * state.count + m_b = batch_var * batch_count + M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + 
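+        # new_mean / new_var fold this batch's statistics into the running ones
+        # (the usual parallel mean/variance merge used by running mean/std
+        # normalizers), and the normalized observation is returned below.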
state = NormalizeVecObsEnvState( + mean=new_mean, + var=new_var, + count=new_count, + env_state=state.env_state, + ) + + return (obs - state.mean) / jnp.sqrt(state.var + 1e-8), state + + def step(self, key, state, action, params=None): + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + + batch_mean = jnp.mean(obs, axis=0) + batch_var = jnp.var(obs, axis=0) + batch_count = obs.shape[0] + + delta = batch_mean - state.mean + tot_count = state.count + batch_count + + new_mean = state.mean + delta * batch_count / tot_count + m_a = state.var * state.count + m_b = batch_var * batch_count + M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + state = NormalizeVecObsEnvState( + mean=new_mean, + var=new_var, + count=new_count, + env_state=env_state, + ) + return ( + (obs - state.mean) / jnp.sqrt(state.var + 1e-8), + state, + reward, + done, + info, + ) + + +@struct.dataclass +class NormalizeVecRewEnvState: + mean: jnp.ndarray + var: jnp.ndarray + count: float + return_val: float + env_state: environment.EnvState + + +class NormalizeVecReward(GymnaxWrapper): + def __init__(self, env, gamma): + super().__init__(env) + self.gamma = gamma + + def reset(self, key, params=None): + obs, state = self._env.reset(key, params) + batch_count = obs.shape[0] + state = NormalizeVecRewEnvState( + mean=0.0, + var=1.0, + count=1e-4, + return_val=jnp.zeros((batch_count,)), + env_state=state, + ) + return obs, state + + def step(self, key, state, action, params=None): + obs, env_state, reward, done, info = self._env.step( + key, state.env_state, action, params + ) + return_val = state.return_val * self.gamma * (1 - done) + reward + + batch_mean = jnp.mean(return_val, axis=0) + batch_var = jnp.var(return_val, axis=0) + batch_count = obs.shape[0] + + delta = batch_mean - state.mean + tot_count = state.count + batch_count + + new_mean = state.mean + delta * batch_count / tot_count + m_a = state.var * state.count + m_b = batch_var * batch_count + M2 = m_a + m_b + jnp.square(delta) * state.count * batch_count / tot_count + new_var = M2 / tot_count + new_count = tot_count + + state = NormalizeVecRewEnvState( + mean=new_mean, + var=new_var, + count=new_count, + return_val=return_val, + env_state=env_state, + ) + return obs, state, reward / jnp.sqrt(state.var + 1e-8), done, info diff --git a/benchmate/benchmate/timings.py b/benchmate/benchmate/timings.py new file mode 100644 index 000000000..a4b7343e3 --- /dev/null +++ b/benchmate/benchmate/timings.py @@ -0,0 +1,64 @@ +import os +import time + + +def getenv(name, type, default): + try: + return type(os.getenv(name, default)) + except: + return default + + +def total_observations(): + return ( + getenv("VOIR_EARLYSTOP_COUNT", int, 60) + + getenv("VOIR_EARLYSTOP_SKIP", int, 5) + ) + + +class StepTimer: + """ + + Examples + -------- + + .. code-block:: python + + step_timer = StepTimer() + for i in range(epochs): + + for i, batch in enumerate(data): + step_timer.step(batch.shape[0]) + step_timer.log(loss=...) 
+
+            if (i + 1) % grad_acc == 0:
+                optimizer.step()
+                step_timer.end()
+
+    """
+    def __init__(self, pusher, sync = lambda: None):
+        self.start_time = time.perf_counter()
+        self.end_time = 0
+        self.n_size = 0
+        self.n_obs = 0
+        self.total_obs = total_observations()
+        self.pusher = pusher
+        self.sync = sync
+        self.timesteps = 0
+
+    def step(self, step_size):
+        """Log a batch size or amount of work that was done"""
+        self.n_size += step_size
+
+    def end(self):
+        """Push a new perf observation"""
+        self.sync()
+        self.end_time = time.perf_counter()
+        self.pusher(rate=self.n_size/(self.end_time - self.start_time), units="items/s", task="train")
+        self.pusher(progress=(self.n_obs, self.total_obs), task="early_stop")
+        self.n_size = 0
+        self.n_obs += 1
+        self.start_time = self.end_time
+
+    def log(self, **kwargs):
+        self.pusher(**kwargs)
diff --git a/config/base.yaml b/config/base.yaml
index 3d02f33e6..438485de9 100644
--- a/config/base.yaml
+++ b/config/base.yaml
@@ -647,12 +647,23 @@ llm-full-mp-nodes:
   requires_capabilities:
     - "len(nodes) >= ${num_machines}"
 
+_purejaxrl:
+  inherits: _defaults
+  definition: ../benchmarks/purejaxrl
+  plan:
+    method: per_gpu
+dqn:
+  inherits: _purejaxrl
+  argv:
+    --benchmark: dqn
+
+ppo:
+  inherits: _purejaxrl
+  argv:
+    --benchmark: ppo
 
 _geo_gnn:
   inherits: _defaults
-  definition: .
-  # FIXME: torch cluster is laging behind pytorch
-  #        we are forced to use torch==2.3 instead of torch==2.4
   install_group: gnn
   group: geo_gnn
   definition: ../benchmarks/geo_gnn
@@ -666,7 +677,6 @@ dimenet:
     --num-samples: 10000
     --use3d: True
 
-
 recursiongfn:
   inherits: _defaults
   definition: ../benchmarks/recursiongfn
@@ -682,7 +692,6 @@ recursiongfn:
     --layer_width: 128
     --num_layers: 4
 
-
 torchatari:
   inherits: _defaults
   definition: ../benchmarks/torchatari
diff --git a/milabench/_version.py b/milabench/_version.py
index b6553ed7d..c3f78d9db 100644
--- a/milabench/_version.py
+++ b/milabench/_version.py
@@ -1,5 +1,7 @@
 """This file is generated, do not modify"""
 
-__tag__ = "v0.1.0-55-g4230ec2d"
-__commit__ = "4230ec2d7c587501f8fa8496fba68b6985423a05"
-__date__ = "2024-08-29 15:34:06 -0400"
+
+__tag__ = "v0.1.0-28-g8069946"
+__commit__ = "8069946d331fb92090057d7eedd598515249521d"
+__date__ = "2024-08-01 12:39:13 -0400"
+

From 2015b663da93ad8f7780a240ce7d4adf908b40a3 Mon Sep 17 00:00:00 2001
From: Setepenre
Date: Wed, 4 Sep 2024 14:40:15 +0000
Subject: [PATCH 09/13] Tweaks

---
 milabench/_version.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/milabench/_version.py b/milabench/_version.py
index b6553ed7d..ff9c3d47d 100644
--- a/milabench/_version.py
+++ b/milabench/_version.py
@@ -1,5 +1,5 @@
 """This file is generated, do not modify"""
 
-__tag__ = "v0.1.0-55-g4230ec2d"
-__commit__ = "4230ec2d7c587501f8fa8496fba68b6985423a05"
-__date__ = "2024-08-29 15:34:06 -0400"
+__tag__ = "95f5fc9"
+__commit__ = "95f5fc9c43e67751bdd9c3187e4b5b6e8b60ff6f"
+__date__ = "2024-08-29 20:30:41 -0400"

From ef74bf8b5ba4f122e73187695c28a902e355dab3 Mon Sep 17 00:00:00 2001
From: "pierre.delaunay"
Date: Wed, 4 Sep 2024 12:45:25 -0400
Subject: [PATCH 10/13] Update command regression

---
 .gitignore                            |  1 +
 benchmarks/llm/tune                   |  1 +
 scripts/article/run_cuda.sh           |  1 +
 .../test_command_reg_one_node.txt     | 30 +++++++++++++++++++
 .../test_command_reg_two_nodes.txt    | 30 +++++++++++++++++++
 5 files changed, 63 insertions(+)
 create mode 160000 benchmarks/llm/tune

diff --git a/.gitignore b/.gitignore
index 90a1e78d7..1bc7f879c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,6 +34,7 @@ scripts/article/xpu/
 dependencies/
benchmarks/recursiongfn/gflownet benchmarks/recursiongfn/logs/ +benchmarks/llm/tune/ scripts/inventory.yaml output/ diff --git a/benchmarks/llm/tune b/benchmarks/llm/tune new file mode 160000 index 000000000..a83eeff00 --- /dev/null +++ b/benchmarks/llm/tune @@ -0,0 +1 @@ +Subproject commit a83eeff0079a73ee04a11e8fc2573ed8f671b231 diff --git a/scripts/article/run_cuda.sh b/scripts/article/run_cuda.sh index 59e61a754..8acf3959c 100644 --- a/scripts/article/run_cuda.sh +++ b/scripts/article/run_cuda.sh @@ -79,6 +79,7 @@ fi if [ "$MILABENCH_PREPARE" -eq 0 ]; then cd $MILABENCH_WORDIR + milabench pin --variant cuda --from-scratch $ARGS # # Run the benchmakrs milabench run --system $MILABENCH_WORDIR/system.yaml "$@" diff --git a/tests/test_command_reg/test_command_reg_one_node.txt b/tests/test_command_reg/test_command_reg_one_node.txt index f3ff218ae..bb58ec9bf 100644 --- a/tests/test_command_reg/test_command_reg_one_node.txt +++ b/tests/test_command_reg/test_command_reg_one_node.txt @@ -474,6 +474,36 @@ time ( wait ) +echo "---" +echo "dqn" +echo "===" +time ( + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + wait +) + +echo "---" +echo "ppo" +echo "===" +time ( + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + wait +) + echo "---" echo "dimenet" echo "=======" diff --git a/tests/test_command_reg/test_command_reg_two_nodes.txt b/tests/test_command_reg/test_command_reg_two_nodes.txt index bda22033e..218330a91 100644 --- a/tests/test_command_reg/test_command_reg_two_nodes.txt +++ b/tests/test_command_reg/test_command_reg_two_nodes.txt @@ -477,6 +477,36 @@ time ( wait ) +echo "---" +echo "dqn" +echo "===" +time ( + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=5 python 
$SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + wait +) + +echo "---" +echo "ppo" +echo "===" +time ( + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + wait +) + echo "---" echo "dimenet" echo "=======" From 4de3f7bbf0f5e2cd07e14dc21a57e2033ebcdb8a Mon Sep 17 00:00:00 2001 From: Setepenre Date: Wed, 4 Sep 2024 12:48:45 -0400 Subject: [PATCH 11/13] Tweaks 3 (#261) * Update pins for CUDA * Add tags and fix diffusion single --------- Co-authored-by: pierre.delaunay --- benchmarks/diffusion/main.py | 2 ++ config/base.yaml | 29 ++++++++++++++++++++++------- milabench/_version.py | 1 + milabench/scripts/vcs.py | 12 +++++++++--- 4 files changed, 34 insertions(+), 10 deletions(-) mode change 100644 => 100755 benchmarks/diffusion/main.py diff --git a/benchmarks/diffusion/main.py b/benchmarks/diffusion/main.py old mode 100644 new mode 100755 index 2b4fe9bfd..09513f606 --- a/benchmarks/diffusion/main.py +++ b/benchmarks/diffusion/main.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python + from dataclasses import dataclass from accelerate import Accelerator diff --git a/config/base.yaml b/config/base.yaml index 438485de9..11d50c018 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -394,9 +394,8 @@ _diffusion: inherits: _defaults definition: ../benchmarks/diffusion install_group: torch - plan: - method: njobs - n: 1 + tags: + - diffusion argv: --num_epochs: 5 @@ -408,10 +407,13 @@ diffusion-single: inherits: _diffusion num_machines: 1 plan: - method: njobs + method: per_gpu diffusion-gpus: inherits: _diffusion + plan: + method: njobs + n: 1 num_machines: 1 diffusion-nodes: @@ -426,6 +428,8 @@ _lightning: inherits: _defaults definition: ../benchmarks/lightning install_group: torch + tags: + - lightning argv: --epochs: 10 --num-workers: "auto({n_worker}, 8)" @@ -452,6 +456,9 @@ _dinov2: definition: ../benchmarks/dinov2 install_group: torch num_machines: 1 + tags: + - image + - transformer plan: method: njobs n: 1 @@ -505,7 +512,9 @@ _llm: voir: options: stop: 30 - + tags: + - nlp + - llm max_duration: 1200 num_machines: 1 inherits: _defaults @@ -517,7 +526,6 @@ llm-lora-single: inherits: _llm plan: method: per_gpu - argv: "{milabench_code}/recipes/lora_finetune_single_device.py": true --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml" @@ -664,6 +672,10 @@ ppo: _geo_gnn: inherits: _defaults + tags: + - graph + # FIXME: torch cluster is laging behind pytorch + # we are forced to use torch==2.3 instead of torch==2.4 install_group: gnn group: geo_gnn definition: ../benchmarks/geo_gnn @@ -682,6 +694,8 @@ recursiongfn: definition: ../benchmarks/recursiongfn install_group: gnn group: recursiongfn_gnn + 
tags: + - graph plan: method: per_gpu @@ -698,7 +712,8 @@ torchatari: install_group: torch plan: method: per_gpu - + tags: + - rl argv: --num-minibatches: 16 --update-epochs: 4 diff --git a/milabench/_version.py b/milabench/_version.py index 70125045a..0202d13c4 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -3,3 +3,4 @@ __tag__ = "v0.1.0-28-g8069946" __commit__ = "8069946d331fb92090057d7eedd598515249521d" __date__ = "2024-08-01 12:39:13 -0400" + diff --git a/milabench/scripts/vcs.py b/milabench/scripts/vcs.py index 0f895f886..54bc7638d 100644 --- a/milabench/scripts/vcs.py +++ b/milabench/scripts/vcs.py @@ -26,10 +26,16 @@ def retrieve_git_versions(tag="", commit="", date=""): } +def version_file(): + return os.path.join(ROOT, "milabench", "_version.py") + def read_previous(): info = ["", "", ""] - - with open(os.path.join(ROOT, "milabench", "_version.py"), "r") as file: + + if not os.path.exists(version_file()): + return info + + with open(version_file(), "r") as file: for line in file.readlines(): if "tag" in line: _, v = line.split("=") @@ -49,7 +55,7 @@ def read_previous(): def update_version_file(): version_info = retrieve_git_versions(*read_previous()) - with open(os.path.join(ROOT, "milabench", "_version.py"), "w") as file: + with open(version_file(), "w") as file: file.write('"""') file.write("This file is generated, do not modify") file.write('"""\n\n') From 860749cf5ae262dfd18680cf9ed361e013264b52 Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Wed, 4 Sep 2024 18:47:30 -0400 Subject: [PATCH 12/13] update config to make ppo & dqn work --- .pin/constraints-cuda-torch.txt | 21 +++------------------ benchmarks/purejaxrl/requirements.cuda.txt | 2 +- benchmarks/purejaxrl/requirements.in | 3 +-- config/base.yaml | 15 +++++++++++++-- scripts/article/run_cuda.sh | 1 - 5 files changed, 18 insertions(+), 24 deletions(-) diff --git a/.pin/constraints-cuda-torch.txt b/.pin/constraints-cuda-torch.txt index da1d8d5ab..f5d6ac6d5 100644 --- a/.pin/constraints-cuda-torch.txt +++ b/.pin/constraints-cuda-torch.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile --output-file=.pin/constraints-cuda-torch.txt .pin/tmp-constraints.txt benchmarks/brax/requirements.in benchmarks/diffusion/requirements.in benchmarks/dinov2/requirements.in benchmarks/flops/requirements.in benchmarks/huggingface/requirements.in benchmarks/lightning/requirements.in benchmarks/llama/requirements.in benchmarks/llm/requirements.in benchmarks/super-slomo/requirements.in benchmarks/timm/requirements.in benchmarks/torchatari/requirements.in benchmarks/torchvision/requirements.in benchmarks/torchvision_ddp/requirements.in constraints/extra/torch.cuda.txt +# pip-compile --output-file=.pin/constraints-cuda-torch.txt .pin/tmp-constraints.txt benchmarks/brax/requirements.in benchmarks/diffusion/requirements.in benchmarks/dinov2/requirements.in benchmarks/flops/requirements.in benchmarks/huggingface/requirements.in benchmarks/lightning/requirements.in benchmarks/llm/requirements.in benchmarks/super-slomo/requirements.in benchmarks/timm/requirements.in benchmarks/torchatari/requirements.in benchmarks/torchvision/requirements.in benchmarks/torchvision_ddp/requirements.in constraints/extra/torch.cuda.txt # --extra-index-url https://pypi.ngc.nvidia.com --extra-index-url https://download.pytorch.org/whl/cu121 @@ -74,7 +74,6 @@ contextlib2==21.6.0 datasets==2.21.0 # via # -r benchmarks/diffusion/requirements.in - # -r 
benchmarks/llama/requirements.in # torchtune diffusers[torch]==0.30.2 # via -r benchmarks/diffusion/requirements.in @@ -101,8 +100,6 @@ etils[epath,epy]==1.9.4 # orbax-checkpoint executing==1.2.0 # via varname -fairscale==0.4.13 - # via -r benchmarks/llama/requirements.in farama-notifications==0.0.4 # via gymnasium filelock==3.15.4 @@ -114,8 +111,6 @@ filelock==3.15.4 # torch # transformers # triton -fire==0.6.0 - # via -r benchmarks/llama/requirements.in flask==3.0.3 # via # brax @@ -278,7 +273,6 @@ numpy==1.26.4 # diffusers # dm-env # envpool - # fairscale # fvcore # gym # gymnasium @@ -476,15 +470,12 @@ scipy==1.14.1 # jaxopt # mujoco-mjx sentencepiece==0.2.0 - # via - # -r benchmarks/llama/requirements.in - # torchtune + # via torchtune shtab==1.7.1 # via tyro six==1.16.0 # via # asttokens - # fire # ml-collections # python-dateutil # tensorboard @@ -505,9 +496,7 @@ tensorstore==0.1.64 # flax # orbax-checkpoint termcolor==2.4.0 - # via - # fire - # fvcore + # via fvcore tiktoken==0.7.0 # via torchtune tokenizers==0.19.1 @@ -521,7 +510,6 @@ torch==2.4.0+cu121 # -r benchmarks/flops/requirements.in # -r benchmarks/huggingface/requirements.in # -r benchmarks/lightning/requirements.in - # -r benchmarks/llama/requirements.in # -r benchmarks/llm/requirements.in # -r benchmarks/super-slomo/requirements.in # -r benchmarks/timm/requirements.in @@ -530,7 +518,6 @@ torch==2.4.0+cu121 # -r benchmarks/torchvision_ddp/requirements.in # accelerate # diffusers - # fairscale # lightning # pytorch-lightning # torchmetrics @@ -582,7 +569,6 @@ transformers==4.44.2 # via # -r benchmarks/diffusion/requirements.in # -r benchmarks/huggingface/requirements.in - # -r benchmarks/llama/requirements.in trimesh==4.4.8 # via # brax @@ -629,7 +615,6 @@ voir==0.2.19 # -r benchmarks/flops/requirements.in # -r benchmarks/huggingface/requirements.in # -r benchmarks/lightning/requirements.in - # -r benchmarks/llama/requirements.in # -r benchmarks/llm/requirements.in # -r benchmarks/super-slomo/requirements.in # -r benchmarks/timm/requirements.in diff --git a/benchmarks/purejaxrl/requirements.cuda.txt b/benchmarks/purejaxrl/requirements.cuda.txt index 5d65fd88c..9c7294cb8 100644 --- a/benchmarks/purejaxrl/requirements.cuda.txt +++ b/benchmarks/purejaxrl/requirements.cuda.txt @@ -479,7 +479,7 @@ toolz==0.12.1 # via chex torch==2.4.1+cu121 # via -r benchmarks/purejaxrl/requirements.in -trimesh==4.4.8 +trimesh==4.4.9 # via # brax # mujoco-mjx diff --git a/benchmarks/purejaxrl/requirements.in b/benchmarks/purejaxrl/requirements.in index 2d3f51759..dc225d151 100644 --- a/benchmarks/purejaxrl/requirements.in +++ b/benchmarks/purejaxrl/requirements.in @@ -1,7 +1,6 @@ voir torch -jax[cuda12] ---find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +jax gymnax evosax distrax diff --git a/config/base.yaml b/config/base.yaml index 0c729b169..4d93cef35 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -662,15 +662,26 @@ _purejaxrl: definition: ../benchmarks/purejaxrl plan: method: per_gpu + dqn: inherits: _purejaxrl argv: - --benchmark: dqn + dqn: true + --num_envs: auto({cpu_per_gpu}, 128) + --buffer_batch_size: 128 + --env_name: CartPole-v1 + --training_interval: 10 ppo: inherits: _purejaxrl argv: - --benchmark: ppo + ppo: true + --num_envs: auto({cpu_per_gpu}, 128) + --num_steps: 10 + --num_minibatches: 32 + --update_epochs: 4 + --env_name: hopper + --total_timesteps: 200000 _geo_gnn: inherits: _defaults diff --git a/scripts/article/run_cuda.sh b/scripts/article/run_cuda.sh index 8acf3959c..59e61a754 
100644 --- a/scripts/article/run_cuda.sh +++ b/scripts/article/run_cuda.sh @@ -79,7 +79,6 @@ fi if [ "$MILABENCH_PREPARE" -eq 0 ]; then cd $MILABENCH_WORDIR - milabench pin --variant cuda --from-scratch $ARGS # # Run the benchmakrs milabench run --system $MILABENCH_WORDIR/system.yaml "$@" From 554f136eaf8cfdc89f3f0b81f4b80a41a494a424 Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Thu, 5 Sep 2024 10:17:00 -0400 Subject: [PATCH 13/13] Update reg commands --- .../test_command_reg_one_node.txt | 32 +++++++++---------- .../test_command_reg_two_nodes.txt | 32 +++++++++---------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/test_command_reg/test_command_reg_one_node.txt b/tests/test_command_reg/test_command_reg_one_node.txt index 54ef48067..fa898ed7c 100644 --- a/tests/test_command_reg/test_command_reg_one_node.txt +++ b/tests/test_command_reg/test_command_reg_one_node.txt @@ -492,14 +492,14 @@ echo "---" echo "dqn" echo "===" time ( - CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & wait ) @@ -507,14 +507,14 @@ echo "---" echo "ppo" echo "===" time ( - CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=4 python 
$SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & wait ) diff --git a/tests/test_command_reg/test_command_reg_two_nodes.txt b/tests/test_command_reg/test_command_reg_two_nodes.txt index b88aa309c..97b3f683c 100644 --- a/tests/test_command_reg/test_command_reg_two_nodes.txt +++ b/tests/test_command_reg/test_command_reg_two_nodes.txt @@ -502,14 +502,14 @@ echo "---" echo "dqn" echo "===" time ( - CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & - CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark dqn & + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=3 python 
$SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py dqn --num_envs 128 --buffer_batch_size 128 --env_name CartPole-v1 --training_interval 10 & wait ) @@ -517,14 +517,14 @@ echo "---" echo "ppo" echo "===" time ( - CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & - CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py --benchmark ppo & + CUDA_VISIBLE_DEVICES=0 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=1 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=2 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=3 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=4 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=5 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=6 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & + CUDA_VISIBLE_DEVICES=7 python $SRC/milabench/benchmarks/purejaxrl/main.py ppo --num_envs 128 --num_steps 10 --num_minibatches 32 --update_epochs 4 --env_name hopper --total_timesteps 200000 & wait )
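
A minimal usage sketch for the StepTimer introduced in benchmate/benchmate/timings.py above, assuming only the API the patch defines; the print-based pusher and the time.sleep call are placeholders for the metric sink and the real training step that milabench wires in (for example voir's logging callback):

import time

from benchmate.timings import StepTimer


def pusher(**metrics):
    # Stand-in metric sink: milabench normally forwards these to its logging
    # backend; printing keeps the sketch self-contained.
    print(metrics)


timer = StepTimer(pusher)

for epoch in range(2):
    for step in range(5):
        batch_size = 128
        time.sleep(0.01)        # stand-in for the actual training step
        timer.step(batch_size)  # record how many items this step processed
        timer.log(loss=0.0)     # extra metrics are forwarded as-is
        timer.end()             # pushes rate (items/s) and early-stop progress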