From eed157a509567149d16d23bd4d721bc4cd547df3 Mon Sep 17 00:00:00 2001 From: Roger Creus <31919499+roger-creus@users.noreply.github.com> Date: Tue, 6 Aug 2024 11:19:11 -0400 Subject: [PATCH] Initial commit Torch_PPO_Cleanrl_Atari_Envpool (#243) * Initial commit Torch_PPO_Cleanrl_Atari_Envpool * Simplify requirements.in * Code now runs * instrumentation concept * Fix CPU scaling --------- Co-authored-by: Pierre Delaunay --- benchmarks/torch_ppo_atari_envpool/Makefile | 31 ++ benchmarks/torch_ppo_atari_envpool/README.md | 4 + .../torch_ppo_atari_envpool/benchfile.py | 31 ++ benchmarks/torch_ppo_atari_envpool/dev.yaml | 17 + benchmarks/torch_ppo_atari_envpool/main.py | 349 ++++++++++++++++++ benchmarks/torch_ppo_atari_envpool/prepare.py | 16 + .../torch_ppo_atari_envpool/requirements.in | 9 + .../torch_ppo_atari_envpool/voirfile.py | 87 +++++ .../mark_torch_ppo_atari_envpool | 0 milabench/_version.py | 8 +- milabench/sizer.py | 14 +- 11 files changed, 556 insertions(+), 10 deletions(-) create mode 100644 benchmarks/torch_ppo_atari_envpool/Makefile create mode 100644 benchmarks/torch_ppo_atari_envpool/README.md create mode 100644 benchmarks/torch_ppo_atari_envpool/benchfile.py create mode 100644 benchmarks/torch_ppo_atari_envpool/dev.yaml create mode 100644 benchmarks/torch_ppo_atari_envpool/main.py create mode 100755 benchmarks/torch_ppo_atari_envpool/prepare.py create mode 100644 benchmarks/torch_ppo_atari_envpool/requirements.in create mode 100644 benchmarks/torch_ppo_atari_envpool/voirfile.py create mode 100644 extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool diff --git a/benchmarks/torch_ppo_atari_envpool/Makefile b/benchmarks/torch_ppo_atari_envpool/Makefile new file mode 100644 index 000000000..81443ce2b --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/Makefile @@ -0,0 +1,31 @@ +# Use global base if possible +ifndef MILABENCH_BASE + MILABENCH_BASE="base" +endif + +export MILABENCH_BASE + +BENCH_NAME=torch_ppo_atari_envpool +MILABENCH_CONFIG=dev.yaml +MILABENCH_ARGS=--config $(MILABENCH_CONFIG) --base $(MILABENCH_BASE) + +all: + install prepare single gpus nodes + +install: + milabench install $(MILABENCH_ARGS) --force + +prepare: + milabench prepare $(MILABENCH_ARGS) + +tests: + MILABENCH_CPU_AUTO=1 CUDA_VISIBLE_DEVICES=0,1 milabench run $(MILABENCH_ARGS) + +single: + MILABENCH_CPU_AUTO=1 milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME) + +gpus: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-gpus + +nodes: + milabench run $(MILABENCH_ARGS) --select $(BENCH_NAME)-nodes diff --git a/benchmarks/torch_ppo_atari_envpool/README.md b/benchmarks/torch_ppo_atari_envpool/README.md new file mode 100644 index 000000000..44de20162 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/README.md @@ -0,0 +1,4 @@ + +# Torch_ppo_atari_envpool + +Rewrite this README to explain what the benchmark is! diff --git a/benchmarks/torch_ppo_atari_envpool/benchfile.py b/benchmarks/torch_ppo_atari_envpool/benchfile.py new file mode 100644 index 000000000..5625f7ed9 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/benchfile.py @@ -0,0 +1,31 @@ +from milabench.pack import Package + + +class Torch_ppo_atari_envpool(Package): + # Requirements file installed by install(). It can be empty or absent. + base_requirements = "requirements.in" + + # The preparation script called by prepare(). It must be executable, + # but it can be any type of script. It can be empty or absent. + prepare_script = "prepare.py" + + # The main script called by run(). It must be a Python file. 
It has to + # be present. + main_script = "main.py" + + # You can remove the functions below if you don't need to modify them. + + def make_env(self): + # Return a dict of environment variables for prepare_script and + # main_script. + return super().make_env() + + async def install(self): + await super().install() # super() call installs the requirements + + async def prepare(self): + await super().prepare() # super() call executes prepare_script + + + +__pack__ = Torch_ppo_atari_envpool diff --git a/benchmarks/torch_ppo_atari_envpool/dev.yaml b/benchmarks/torch_ppo_atari_envpool/dev.yaml new file mode 100644 index 000000000..338bed075 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/dev.yaml @@ -0,0 +1,17 @@ + +torch_ppo_atari_envpool: + max_duration: 600 + inherits: _defaults + definition: . + install-variant: unpinned + install_group: torch + plan: + method: per_gpu + + argv: + --num-minibatches: 16 + --update-epochs: 4 + --num-steps: 128 + --num-envs: auto({cpu_per_gpu}, 128) + --total-timesteps: 1000000 + --env-id: Breakout-v5 \ No newline at end of file diff --git a/benchmarks/torch_ppo_atari_envpool/main.py b/benchmarks/torch_ppo_atari_envpool/main.py new file mode 100644 index 000000000..62c9b3a07 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/main.py @@ -0,0 +1,349 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy +import os +import random +import time +from collections import deque +from dataclasses import dataclass + +import envpool +import gym +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import tyro +from torch.distributions.categorical import Categorical +from torch.utils.tensorboard import SummaryWriter +import torchcompat.core as acc + +@dataclass +class Args: + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """the name of this experiment""" + seed: int = 1 + """seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "cleanRL" + """the wandb's project name""" + wandb_entity: str = None + """the entity (team) of wandb's project""" + capture_video: bool = False + """whether to capture videos of the agent performances (check out `videos` folder)""" + + # Algorithm specific arguments + env_id: str = "Breakout-v5" + """the id of the environment""" + total_timesteps: int = 10000000 + """total timesteps of the experiments""" + learning_rate: float = 2.5e-4 + """the learning rate of the optimizer""" + num_envs: int = 128 + """the number of parallel game environments""" + num_steps: int = 128 + """the number of steps to run in each environment per policy rollout""" + anneal_lr: bool = True + """Toggle learning rate annealing for policy and value networks""" + gamma: float = 0.99 + """the discount factor gamma""" + gae_lambda: float = 0.95 + """the lambda for the general advantage estimation""" + num_minibatches: int = 16 + """the number of mini-batches""" + update_epochs: int = 4 + """the K epochs to update the policy""" + norm_adv: bool = True + """Toggles advantages normalization""" + clip_coef: float = 0.1 + """the surrogate clipping coefficient""" + clip_vloss: bool = True + """Toggles whether or not to use a clipped loss for the value function, as per the paper.""" + ent_coef: float = 
0.01 + """coefficient of the entropy""" + vf_coef: float = 0.5 + """coefficient of the value function""" + max_grad_norm: float = 0.5 + """the maximum norm for the gradient clipping""" + target_kl: float = None + """the target KL divergence threshold""" + + # to be filled in runtime + batch_size: int = 0 + """the batch size (computed in runtime)""" + minibatch_size: int = 0 + """the mini-batch size (computed in runtime)""" + num_iterations: int = 0 + """the number of iterations (computed in runtime)""" + + +class RecordEpisodeStatistics(gym.Wrapper): + def __init__(self, env, deque_size=100): + super().__init__(env) + self.num_envs = getattr(env, "num_envs", 1) + self.episode_returns = None + self.episode_lengths = None + + def reset(self, **kwargs): + observations = super().reset(**kwargs) + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + self.lives = np.zeros(self.num_envs, dtype=np.int32) + self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + return observations + + def step(self, action): + observations, rewards, dones, infos = super().step(action) + self.episode_returns += infos["reward"] + self.episode_lengths += 1 + self.returned_episode_returns[:] = self.episode_returns + self.returned_episode_lengths[:] = self.episode_lengths + self.episode_returns *= 1 - infos["terminated"] + self.episode_lengths *= 1 - infos["terminated"] + infos["r"] = self.returned_episode_returns + infos["l"] = self.returned_episode_lengths + return ( + observations, + rewards, + dones, + infos, + ) + + +def layer_init(layer, std=np.sqrt(2), bias_const=0.0): + torch.nn.init.orthogonal_(layer.weight, std) + torch.nn.init.constant_(layer.bias, bias_const) + return layer + + +class Agent(nn.Module): + def __init__(self, envs): + super().__init__() + self.network = nn.Sequential( + layer_init(nn.Conv2d(4, 32, 8, stride=4)), + nn.ReLU(), + layer_init(nn.Conv2d(32, 64, 4, stride=2)), + nn.ReLU(), + layer_init(nn.Conv2d(64, 64, 3, stride=1)), + nn.ReLU(), + nn.Flatten(), + layer_init(nn.Linear(64 * 7 * 7, 512)), + nn.ReLU(), + ) + self.actor = layer_init(nn.Linear(512, envs.single_action_space.n), std=0.01) + self.critic = layer_init(nn.Linear(512, 1), std=1) + + def get_value(self, x): + return self.critic(self.network(x / 255.0)) + + def get_action_and_value(self, x, action=None): + hidden = self.network(x / 255.0) + logits = self.actor(hidden) + probs = Categorical(logits=logits) + if action is None: + action = probs.sample() + return action, probs.log_prob(action), probs.entropy(), self.critic(hidden) + + +def main(): + args = tyro.cli(Args) + args.batch_size = int(args.num_envs * args.num_steps) + args.minibatch_size = int(args.batch_size // args.num_minibatches) + args.num_iterations = args.total_timesteps // args.batch_size + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + 
torch.backends.cudnn.deterministic = args.torch_deterministic + + device = acc.fetch_device(0) + + # env setup + envs = envpool.make( + args.env_id, + env_type="gym", + num_envs=args.num_envs, + episodic_life=True, + reward_clip=True, + seed=args.seed, + ) + envs.num_envs = args.num_envs + envs.single_action_space = envs.action_space + envs.single_observation_space = envs.observation_space + envs = RecordEpisodeStatistics(envs) + assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported" + + agent = Agent(envs).to(device) + optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5) + + # ALGO Logic: Storage setup + obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) + actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) + logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) + rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) + dones = torch.zeros((args.num_steps, args.num_envs)).to(device) + values = torch.zeros((args.num_steps, args.num_envs)).to(device) + avg_returns = deque(maxlen=20) + + # TRY NOT TO MODIFY: start the game + global_step = 0 + start_time = time.time() + next_obs = torch.Tensor(envs.reset()).to(device) + next_done = torch.zeros(args.num_envs).to(device) + iterations = range(1, args.num_iterations + 1) + + for iteration in iterations: + # Annealing the rate if instructed to do so. + if args.anneal_lr: + frac = 1.0 - (iteration - 1.0) / args.num_iterations + lrnow = frac * args.learning_rate + optimizer.param_groups[0]["lr"] = lrnow + + for step in range(0, args.num_steps): + global_step += args.num_envs + obs[step] = next_obs + dones[step] = next_done + + # ALGO LOGIC: action logic + with torch.no_grad(): + action, logprob, _, value = agent.get_action_and_value(next_obs) + values[step] = value.flatten() + actions[step] = action + logprobs[step] = logprob + + # TRY NOT TO MODIFY: execute the game and log data. 
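+            # envpool in "gym" mode with the pinned gym==0.23.1 still uses the old
+            # 4-tuple step API: (obs, reward, done, info), each entry a numpy array
+            # batched over all `num_envs` environments. The RecordEpisodeStatistics
+            # wrapper above accumulates info["reward"] into info["r"] / info["l"],
+            # which the logging block below reads once a done env has 0 lives left.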
+ next_obs, reward, next_done, info = envs.step(action.cpu().numpy()) + rewards[step] = torch.tensor(reward).to(device).view(-1) + next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) + + for idx, d in enumerate(next_done): + if d and info["lives"][idx] == 0: + # print(f"global_step={global_step}, episodic_return={info['r'][idx]}") + avg_returns.append(info["r"][idx]) + writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step) + writer.add_scalar("charts/episodic_return", info["r"][idx], global_step) + writer.add_scalar("charts/episodic_length", info["l"][idx], global_step) + + # bootstrap value if not done + with torch.no_grad(): + next_value = agent.get_value(next_obs).reshape(1, -1) + advantages = torch.zeros_like(rewards).to(device) + lastgaelam = 0 + for t in reversed(range(args.num_steps)): + if t == args.num_steps - 1: + nextnonterminal = 1.0 - next_done + nextvalues = next_value + else: + nextnonterminal = 1.0 - dones[t + 1] + nextvalues = values[t + 1] + delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t] + advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam + returns = advantages + values + + # flatten the batch + b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) + b_logprobs = logprobs.reshape(-1) + b_actions = actions.reshape((-1,) + envs.single_action_space.shape) + b_advantages = advantages.reshape(-1) + b_returns = returns.reshape(-1) + b_values = values.reshape(-1) + + # Optimizing the policy and value network + b_inds = np.arange(args.batch_size) + clipfracs = [] + for epoch in range(args.update_epochs): + np.random.shuffle(b_inds) + for start in range(0, args.batch_size, args.minibatch_size): + end = start + args.minibatch_size + mb_inds = b_inds[start:end] + + _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds]) + logratio = newlogprob - b_logprobs[mb_inds] + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] + + mb_advantages = b_advantages[mb_inds] + if args.norm_adv: + mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) + + # Policy loss + pg_loss1 = -mb_advantages * ratio + pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if args.clip_vloss: + v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 + v_clipped = b_values[mb_inds] + torch.clamp( + newvalue - b_values[mb_inds], + -args.clip_coef, + args.clip_coef, + ) + v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() + + entropy_loss = entropy.mean() + loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef + + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) + optimizer.step() + + if args.target_kl is not None and approx_kl > args.target_kl: + break + + y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - 
np.var(y_true - y_pred) / var_y + + # TRY NOT TO MODIFY: record rewards for plotting purposes + writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step) + writer.add_scalar("losses/value_loss", v_loss.item(), global_step) + writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) + writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) + writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) + writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) + writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) + writer.add_scalar("losses/explained_variance", explained_var, global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + writer.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmarks/torch_ppo_atari_envpool/prepare.py b/benchmarks/torch_ppo_atari_envpool/prepare.py new file mode 100755 index 000000000..32bd5901d --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/prepare.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import os + +if __name__ == "__main__": + # If you need the whole configuration: + # config = json.loads(os.environ["MILABENCH_CONFIG"]) + + data_directory = os.environ["MILABENCH_DIR_DATA"] + + # Download (or generate) the needed dataset(s). You are responsible + # to check if it has already been properly downloaded or not, and to + # do nothing if it has been. + print("Hello I am doing some data stuff!") + + # If there is nothing to download or generate, just delete this file. diff --git a/benchmarks/torch_ppo_atari_envpool/requirements.in b/benchmarks/torch_ppo_atari_envpool/requirements.in new file mode 100644 index 000000000..c264f5563 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/requirements.in @@ -0,0 +1,9 @@ +envpool +gym==0.23.1 +numpy +torch +tyro +voir +tensorboard +torchcompat +cantilever diff --git a/benchmarks/torch_ppo_atari_envpool/voirfile.py b/benchmarks/torch_ppo_atari_envpool/voirfile.py new file mode 100644 index 000000000..7b8873852 --- /dev/null +++ b/benchmarks/torch_ppo_atari_envpool/voirfile.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass + +from voir import configurable +from voir.phase import StopProgram +from benchmate.observer import BenchObserver +from benchmate.monitor import voirfile_monitor + + +@dataclass +class Config: + """voir configuration""" + + # Whether to display the dash or not + dash: bool = False + + # How often to log the rates + interval: str = "1s" + + # Number of rates to skip before logging + skip: int = 5 + + # Number of rates to log before stopping + stop: int = 20 + + # Number of seconds between each gpu poll + gpu_poll: int = 3 + + +@configurable +def instrument_main(ov, options: Config): + yield ov.phases.init + + # GPU monitor, rate, loss etc... 
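+    # The dash / interval / gpu_poll settings from Config above are presumably what
+    # voirfile_monitor consumes; skip and stop feed the BenchObserver early-stop
+    # threshold (stop + skip) configured further down.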
+ voirfile_monitor(ov, options) + + yield ov.phases.load_script + + step_per_iteration = 0 + + def fetch_args(args): + nonlocal step_per_iteration + step_per_iteration = args.num_envs * args.num_steps + return args + + def batch_size(x): + return step_per_iteration + + observer = BenchObserver( + earlystop=options.stop + options.skip, + batch_size_fn=batch_size, + ) + + probe = ov.probe("//main > args", overridable=True) + probe['args'].override(fetch_args) + + # measure the time it took to execute the body + probe = ov.probe("//main > iterations", overridable=True) + probe['iterations'].override(observer.loader) + + # Too many losses + # probe = ov.probe("//main > loss", overridable=True) + # probe["loss"].override(observer.record_loss) + + def record_starts(writer): + old_add_scalar = writer.add_scalar + + def add_scalar(name, *values): + if name == "losses/value_loss": + observer.record_loss(values[0]) + old_add_scalar(name, *values) + + writer.add_scalar = add_scalar + return writer + + probe = ov.probe("//main > writer", overridable=True) + probe["writer"].override(record_starts) + + probe = ov.probe("//main > optimizer", overridable=True) + probe['optimizer'].override(observer.optimizer) + + # + # Run the benchmark + # + try: + yield ov.phases.run_script + except StopProgram: + print("early stopped") \ No newline at end of file diff --git a/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool b/extra/torch_ppo_atari_envpool/mark_torch_ppo_atari_envpool new file mode 100644 index 000000000..e69de29bb diff --git a/milabench/_version.py b/milabench/_version.py index 23cf810bc..d07e39f06 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,7 @@ """This file is generated, do not modify""" -__tag__ = "v0.1.0-38-gfb01d691" -__commit__ = "fb01d691aa0d88717dcb3fea8852f61e111cc75f" -__date__ = "2024-08-01 18:59:13 -0400" + +__tag__ = "c7ae304" +__commit__ = "c7ae3043a12faef4da3eb0ddd6dc33e355b265fc" +__date__ = "2024-08-01 17:03:10 -0400" + diff --git a/milabench/sizer.py b/milabench/sizer.py index 2ae877213..2ae8ecd6b 100644 --- a/milabench/sizer.py +++ b/milabench/sizer.py @@ -339,16 +339,16 @@ def new_argument_resolver(pack): context = deepcopy(system_config) arch = context.get("arch", "cpu") + device_count_used = 1 + device_count_system = len(get_gpu_info()["gpus"]) if hasattr(pack, "config"): - device_count = len(pack.config.get("devices", [0])) - else: - device_count = len(get_gpu_info()["gpus"]) + device_count_used = len(pack.config.get("devices", [0])) - ccl = {"hpu": "hccl", "cuda": "nccl", "rocm": "rccl", "xpu": "ccl", "cpu": "gloo"} + if device_count_used <= 0: + device_count_used = 1 - if device_count <= 0: - device_count = 1 + ccl = {"hpu": "hccl", "cuda": "nccl", "rocm": "rccl", "xpu": "ccl", "cpu": "gloo"} cpu_opt = CPUOptions() def auto(value, default): @@ -363,7 +363,7 @@ def clamp(x, mn=cpu_opt.cpu_min, mx=cpu_opt.cpu_max): total_available = total_cpu - cpu_opt.reserved_cores context["cpu_count"] = total_available - context["cpu_per_gpu"] = total_available // device_count + context["cpu_per_gpu"] = total_available // device_count_system context["n_worker"] = clamp(context["cpu_per_gpu"]) if cpu_opt.n_workers is not None: