From ee2a4146691ac395db86b72ddca3fe8f057f2051 Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Thu, 18 Jul 2024 14:34:57 -0400
Subject: [PATCH] WIP: Adding a PureJaxRL example

Signed-off-by: Fabrice Normandin
---
 benchmarks/purejaxrl/README.md       |   4 +
 benchmarks/purejaxrl/benchfile.py    |   9 +
 benchmarks/purejaxrl/main.py         | 325 +++++++++++++++++++++++++++
 benchmarks/purejaxrl/prepare.py      |  16 ++
 benchmarks/purejaxrl/requirements.in |  13 ++
 benchmarks/purejaxrl/voirfile.py     |  49 ++++
 6 files changed, 416 insertions(+)
 create mode 100644 benchmarks/purejaxrl/README.md
 create mode 100644 benchmarks/purejaxrl/benchfile.py
 create mode 100644 benchmarks/purejaxrl/main.py
 create mode 100755 benchmarks/purejaxrl/prepare.py
 create mode 100644 benchmarks/purejaxrl/requirements.in
 create mode 100644 benchmarks/purejaxrl/voirfile.py

diff --git a/benchmarks/purejaxrl/README.md b/benchmarks/purejaxrl/README.md
new file mode 100644
index 000000000..ebce72bf5
--- /dev/null
+++ b/benchmarks/purejaxrl/README.md
@@ -0,0 +1,4 @@
+
+# PureJaxRL
+
+This benchmark runs PureJaxRL-style PPO on the gymnax CartPole-v1 environment, with the entire training loop (environment steps included) jitted with JAX; a single seed is timed first, then 256 seeds are trained in parallel with jax.vmap.
diff --git a/benchmarks/purejaxrl/benchfile.py b/benchmarks/purejaxrl/benchfile.py
new file mode 100644
index 000000000..0388956d6
--- /dev/null
+++ b/benchmarks/purejaxrl/benchfile.py
@@ -0,0 +1,9 @@
+from milabench.pack import Package
+
+
+class PureJaxRLBenchmark(Package):
+    base_requirements = "requirements.in"
+    main_script = "main.py"
+
+
+__pack__ = PureJaxRLBenchmark
diff --git a/benchmarks/purejaxrl/main.py b/benchmarks/purejaxrl/main.py
new file mode 100644
index 000000000..3e432c761
--- /dev/null
+++ b/benchmarks/purejaxrl/main.py
@@ -0,0 +1,325 @@
+"""PureJaxRL"""
+import dataclasses
+import os
+
+os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
+from typing import NamedTuple
+
+import distrax
+import flax.linen as nn
+import gymnax
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from flax.linen.initializers import constant, orthogonal
+from flax.training.train_state import TrainState
+from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper
+
+
+class ActorCritic(nn.Module):
+    action_dim: int
+    activation: str = "tanh"
+
+    @nn.compact
+    def __call__(self, x):
+        if self.activation == "relu":
+            activation = nn.relu
+        else:
+            activation = nn.tanh
+        actor_mean = nn.Dense(
+            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+        )(x)
+        actor_mean = activation(actor_mean)
+        actor_mean = nn.Dense(
+            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+        )(actor_mean)
+        actor_mean = activation(actor_mean)
+        actor_mean = nn.Dense(
+            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
+        )(actor_mean)
+        pi = distrax.Categorical(logits=actor_mean)
+
+        critic = nn.Dense(
+            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+        )(x)
+        critic = activation(critic)
+        critic = nn.Dense(
+            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
+        )(critic)
+        critic = activation(critic)
+        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
+            critic
+        )
+
+        return pi, jnp.squeeze(critic, axis=-1)
+
+
+class Transition(NamedTuple):
+    done: jnp.ndarray
+    action: jnp.ndarray
+    value: jnp.ndarray
+    reward: jnp.ndarray
+    log_prob: jnp.ndarray
+    obs: jnp.ndarray
+    info: jnp.ndarray
+
+
+@dataclasses.dataclass(frozen=True)
+class Config:
+    LR: float = 2.5e-4
+    NUM_ENVS: int = 4
+    NUM_STEPS: int = 128
+    TOTAL_TIMESTEPS: int = int(5e5)
+    UPDATE_EPOCHS: int = 4
+    NUM_MINIBATCHES: int = 4
+    GAMMA: float = 0.99
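+    # GAMMA (above) is the discount factor and GAE_LAMBDA (below) the smoothing
+    # factor used in the generalized advantage estimation (_calculate_gae); the
+    # remaining fields are the standard PPO clipping, entropy and value-loss
+    # coefficients, the gradient-clipping norm, and the environment settings.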
+    GAE_LAMBDA: float = 0.95
+    CLIP_EPS: float = 0.2
+    ENT_COEF: float = 0.01
+    VF_COEF: float = 0.5
+    MAX_GRAD_NORM: float = 0.5
+    ACTIVATION: str = "tanh"
+    ENV_NAME: str = "CartPole-v1"
+    ANNEAL_LR: bool = True
+
+    @property
+    def NUM_UPDATES(self) -> int:
+        return self.TOTAL_TIMESTEPS // self.NUM_STEPS // self.NUM_ENVS
+
+    @property
+    def MINIBATCH_SIZE(self) -> int:
+        return self.NUM_ENVS * self.NUM_STEPS // self.NUM_MINIBATCHES
+
+
+def make_train(config: Config):
+    env, env_params = gymnax.make(config.ENV_NAME)
+    env = FlattenObservationWrapper(env)
+    env = LogWrapper(env)
+
+    def linear_schedule(count):
+        frac = (
+            1.0
+            - (count // (config.NUM_MINIBATCHES * config.UPDATE_EPOCHS))
+            / config.NUM_UPDATES
+        )
+        return config.LR * frac
+
+    def train(rng):
+        # INIT NETWORK
+        network = ActorCritic(
+            env.action_space(env_params).n, activation=config.ACTIVATION
+        )
+        rng, _rng = jax.random.split(rng)
+        init_x = jnp.zeros(env.observation_space(env_params).shape)
+        network_params = network.init(_rng, init_x)
+        if config.ANNEAL_LR:
+            tx = optax.chain(
+                optax.clip_by_global_norm(config.MAX_GRAD_NORM),
+                optax.adam(learning_rate=linear_schedule, eps=1e-5),
+            )
+        else:
+            tx = optax.chain(
+                optax.clip_by_global_norm(config.MAX_GRAD_NORM),
+                optax.adam(config.LR, eps=1e-5),
+            )
+        train_state = TrainState.create(
+            apply_fn=network.apply,
+            params=network_params,
+            tx=tx,
+        )
+
+        # INIT ENV
+        rng, _rng = jax.random.split(rng)
+        reset_rng = jax.random.split(_rng, config.NUM_ENVS)
+        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
+
+        # TRAIN LOOP
+        def _update_step(runner_state, unused):
+            # COLLECT TRAJECTORIES
+            def _env_step(runner_state, unused):
+                train_state, env_state, last_obs, rng = runner_state
+
+                # SELECT ACTION
+                rng, _rng = jax.random.split(rng)
+                pi, value = network.apply(train_state.params, last_obs)
+                action = pi.sample(seed=_rng)
+                log_prob = pi.log_prob(action)
+
+                # STEP ENV
+                rng, _rng = jax.random.split(rng)
+                rng_step = jax.random.split(_rng, config.NUM_ENVS)
+                obsv, env_state, reward, done, info = jax.vmap(
+                    env.step, in_axes=(0, 0, 0, None)
+                )(rng_step, env_state, action, env_params)
+                transition = Transition(
+                    done, action, value, reward, log_prob, last_obs, info
+                )
+                runner_state = (train_state, env_state, obsv, rng)
+                return runner_state, transition
+
+            runner_state, traj_batch = jax.lax.scan(
+                _env_step, runner_state, None, config.NUM_STEPS
+            )
+
+            # CALCULATE ADVANTAGE
+            train_state, env_state, last_obs, rng = runner_state
+            _, last_val = network.apply(train_state.params, last_obs)
+
+            def _calculate_gae(traj_batch, last_val):
+                def _get_advantages(gae_and_next_value, transition):
+                    gae, next_value = gae_and_next_value
+                    done, value, reward = (
+                        transition.done,
+                        transition.value,
+                        transition.reward,
+                    )
+                    delta = reward + config.GAMMA * next_value * (1 - done) - value
+                    gae = delta + config.GAMMA * config.GAE_LAMBDA * (1 - done) * gae
+                    return (gae, value), gae
+
+                _, advantages = jax.lax.scan(
+                    _get_advantages,
+                    (jnp.zeros_like(last_val), last_val),
+                    traj_batch,
+                    reverse=True,
+                    unroll=16,
+                )
+                return advantages, advantages + traj_batch.value
+
+            advantages, targets = _calculate_gae(traj_batch, last_val)
+
+            # UPDATE NETWORK
+            def _update_epoch(update_state, unused):
+                def _update_minbatch(train_state, batch_info):
+                    traj_batch, advantages, targets = batch_info
+
+                    def _loss_fn(params, traj_batch, gae, targets):
+                        # RERUN NETWORK
+                        pi, value = network.apply(params, traj_batch.obs)
+                        log_prob = pi.log_prob(traj_batch.action)
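+                        # The network is re-evaluated on the stored observations so
+                        # that log_prob and value reflect the current parameters;
+                        # the clipped objective below compares this log_prob against
+                        # traj_batch.log_prob recorded at rollout time.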
+
+                        # CALCULATE VALUE LOSS
+                        value_pred_clipped = traj_batch.value + (
+                            value - traj_batch.value
+                        ).clip(-config.CLIP_EPS, config.CLIP_EPS)
+                        value_losses = jnp.square(value - targets)
+                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
+                        value_loss = (
+                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
+                        )
+
+                        # CALCULATE ACTOR LOSS
+                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
+                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
+                        loss_actor1 = ratio * gae
+                        loss_actor2 = (
+                            jnp.clip(
+                                ratio,
+                                1.0 - config.CLIP_EPS,
+                                1.0 + config.CLIP_EPS,
+                            )
+                            * gae
+                        )
+                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
+                        loss_actor = loss_actor.mean()
+                        entropy = pi.entropy().mean()
+
+                        total_loss = (
+                            loss_actor
+                            + config.VF_COEF * value_loss
+                            - config.ENT_COEF * entropy
+                        )
+                        return total_loss, (value_loss, loss_actor, entropy)
+
+                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
+                    total_loss, grads = grad_fn(
+                        train_state.params, traj_batch, advantages, targets
+                    )
+                    train_state = train_state.apply_gradients(grads=grads)
+                    return train_state, total_loss
+
+                train_state, traj_batch, advantages, targets, rng = update_state
+                rng, _rng = jax.random.split(rng)
+                batch_size = config.MINIBATCH_SIZE * config.NUM_MINIBATCHES
+                assert (
+                    batch_size == config.NUM_STEPS * config.NUM_ENVS
+                ), "batch size must be equal to number of steps * number of envs"
+                permutation = jax.random.permutation(_rng, batch_size)
+                batch = (traj_batch, advantages, targets)
+                batch = jax.tree_util.tree_map(
+                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
+                )
+                shuffled_batch = jax.tree_util.tree_map(
+                    lambda x: jnp.take(x, permutation, axis=0), batch
+                )
+                minibatches = jax.tree_util.tree_map(
+                    lambda x: jnp.reshape(
+                        x, [config.NUM_MINIBATCHES, -1] + list(x.shape[1:])
+                    ),
+                    shuffled_batch,
+                )
+                train_state, total_loss = jax.lax.scan(
+                    _update_minbatch, train_state, minibatches
+                )
+                update_state = (train_state, traj_batch, advantages, targets, rng)
+                return update_state, total_loss
+
+            update_state = (train_state, traj_batch, advantages, targets, rng)
+            update_state, loss_info = jax.lax.scan(
+                _update_epoch, update_state, None, config.UPDATE_EPOCHS
+            )
+            train_state = update_state[0]
+            metric = traj_batch.info
+            rng = update_state[-1]
+
+            runner_state = (train_state, env_state, last_obs, rng)
+            return runner_state, metric
+
+        rng, _rng = jax.random.split(rng)
+        runner_state = (train_state, env_state, obsv, _rng)
+        runner_state, metric = jax.lax.scan(
+            _update_step, runner_state, None, config.NUM_UPDATES
+        )
+        return {"runner_state": runner_state, "metrics": metric}
+
+    return train
+
+
+def main():
+    rng = jax.random.PRNGKey(42)
+
+    config = Config()
+
+    train_jit = jax.jit(make_train(config))
+    out = train_jit(rng)
+
+    import time
+
+    import matplotlib.pyplot as plt
+
+    rng = jax.random.PRNGKey(42)
+    t0 = time.time()
+    out = jax.block_until_ready(train_jit(rng))
+    print(f"time: {time.time() - t0:.2f} s")
+    plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
+    plt.xlabel("Update Step")
+    plt.ylabel("Return")
+    plt.show()
+
+    rng = jax.random.PRNGKey(42)
+    rngs = jax.random.split(rng, 256)
+    train_vjit = jax.jit(jax.vmap(make_train(config)))
+    t0 = time.time()
+    outs = jax.block_until_ready(train_vjit(rngs))
+    print(f"time: {time.time() - t0:.2f} s")
+
+    for i in range(256):
+        plt.plot(outs["metrics"]["returned_episode_returns"][i].mean(-1).reshape(-1))
+    plt.xlabel("Update Step")
+    plt.ylabel("Return")
+    plt.show()
+
+
+if __name__ == "__main__":
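+    # main() compiles the training function with jax.jit, times a single seed,
+    # then times 256 seeds trained in parallel with jax.vmap and plots the returns.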
+    main()
diff --git a/benchmarks/purejaxrl/prepare.py b/benchmarks/purejaxrl/prepare.py
new file mode 100755
index 000000000..32bd5901d
--- /dev/null
+++ b/benchmarks/purejaxrl/prepare.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+
+import os
+
+if __name__ == "__main__":
+    # If you need the whole configuration:
+    # config = json.loads(os.environ["MILABENCH_CONFIG"])
+
+    data_directory = os.environ["MILABENCH_DIR_DATA"]
+
+    # Download (or generate) the needed dataset(s). You are responsible
+    # for checking whether it has already been properly downloaded, and
+    # for doing nothing if it has been.
+    print("Hello I am doing some data stuff!")
+
+    # If there is nothing to download or generate, just delete this file.
diff --git a/benchmarks/purejaxrl/requirements.in b/benchmarks/purejaxrl/requirements.in
new file mode 100644
index 000000000..79e4c8e9b
--- /dev/null
+++ b/benchmarks/purejaxrl/requirements.in
@@ -0,0 +1,13 @@
+jax>=0.2.26
+jaxlib>=0.1.74
+gymnax
+evosax
+distrax
+optax
+flax
+numpy
+brax
+wandb
+flashbax
+navix
+voir>=0.2.10,<0.3
diff --git a/benchmarks/purejaxrl/voirfile.py b/benchmarks/purejaxrl/voirfile.py
new file mode 100644
index 000000000..fce6f66d0
--- /dev/null
+++ b/benchmarks/purejaxrl/voirfile.py
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+
+from voir import configurable
+from voir.phase import StopProgram
+from voir.instruments import dash, early_stop, gpu_monitor, log, rate
+from benchmate.monitor import monitor_monogpu
+
+
+@dataclass
+class Config:
+    """voir configuration"""
+
+    # Whether to display the dash or not
+    dash: bool = False
+
+    # How often to log the rates
+    interval: str = "1s"
+
+    # Number of rates to skip before logging
+    skip: int = 5
+
+    # Number of rates to log before stopping
+    stop: int = 20
+
+    # Number of seconds between each gpu poll
+    gpu_poll: int = 3
+
+
+@configurable
+def instrument_main(ov, options: Config):
+    yield ov.phases.init
+
+    if options.dash:
+        ov.require(dash)
+
+    ov.require(
+        log("value", "progress", "rate", "units", "loss", "gpudata", context="task"),
+        rate(
+            interval=options.interval,
+            sync=None,
+        ),
+        early_stop(n=options.stop, key="rate", task="train"),
+        monitor_monogpu(poll_interval=options.gpu_poll),
+    )
+
+    try:
+        yield ov.phases.run_script
+    except StopProgram:
+        print("early stopped")
\ No newline at end of file
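
Note on prepare.py above: the template asks the prepare step to be idempotent
(download or generate data only when it is not already present). Since gymnax
environments are generated on the fly, this benchmark may need no data at all,
but a minimal sketch of such a check could look like the following (the
"example-dataset" name is purely illustrative and not part of the patch):

    import os
    from pathlib import Path

    data_directory = Path(os.environ["MILABENCH_DIR_DATA"])
    target = data_directory / "example-dataset"  # hypothetical dataset name
    if target.exists():
        # Already prepared: do nothing, as the template requires.
        print(f"{target} already exists, nothing to do.")
    else:
        target.mkdir(parents=True)
        # ... download or generate the dataset into `target` here ...
        print(f"Prepared {target}.")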