WIP: Adding a PureJaxRL example
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jul 18, 2024
1 parent 1c96794 commit ee2a414
Showing 6 changed files with 416 additions and 0 deletions.
4 changes: 4 additions & 0 deletions benchmarks/purejaxrl/README.md
@@ -0,0 +1,4 @@

# PureJaxRL benchmark

Trains a PPO agent on the gymnax `CartPole-v1` environment, following the PureJaxRL recipe: the whole training loop, environment steps included, is compiled with `jax.jit`, and many seeds can be trained in parallel with `jax.vmap` (see `main.py`).
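
Outside of milabench, the training function can also be used on its own. A minimal sketch, assuming it is run from this directory (so that `main.py` is importable) and using a shorter run than the default config:

```python
import jax

from main import Config, make_train

# Build and jit-compile the whole training loop, then run it for one seed.
train = jax.jit(make_train(Config(TOTAL_TIMESTEPS=100_000)))
out = jax.block_until_ready(train(jax.random.PRNGKey(0)))
print(out["metrics"]["returned_episode_returns"].mean())
```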
9 changes: 9 additions & 0 deletions benchmarks/purejaxrl/benchfile.py
@@ -0,0 +1,9 @@
from milabench.pack import Package


class PureJaxRLBenchmark(Package):
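    # `base_requirements` is the pip requirements file used to install this benchmark's
    # dependencies, and `main_script` is the script that milabench runs for it.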
base_requirements = "requirements.in"
main_script = "main.py"


__pack__ = PureJaxRLBenchmark
325 changes: 325 additions & 0 deletions benchmarks/purejaxrl/main.py
@@ -0,0 +1,325 @@
"""PureJaxRL"""
import dataclasses
import os

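# Must be set before jax is imported: keep XLA from preallocating most of the GPU memory.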
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
from typing import NamedTuple

import distrax
import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper


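# Actor-critic network: two separate 2-layer, 64-unit MLPs on the flattened observation,
# one producing categorical policy logits and the other a scalar value estimate.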
class ActorCritic(nn.Module):
action_dim: int
activation: str = "tanh"

@nn.compact
def __call__(self, x):
if self.activation == "relu":
activation = nn.relu
else:
activation = nn.tanh
actor_mean = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(x)
actor_mean = activation(actor_mean)
actor_mean = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(actor_mean)
actor_mean = activation(actor_mean)
actor_mean = nn.Dense(
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
)(actor_mean)
pi = distrax.Categorical(logits=actor_mean)

critic = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(x)
critic = activation(critic)
critic = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(critic)
critic = activation(critic)
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
critic
)

return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
done: jnp.ndarray
action: jnp.ndarray
value: jnp.ndarray
reward: jnp.ndarray
log_prob: jnp.ndarray
obs: jnp.ndarray
info: jnp.ndarray


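# PPO hyperparameters; the upper-case field names follow the original PureJaxRL config dict.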
@dataclasses.dataclass(frozen=True)
class Config:
LR: float = 2.5e-4
NUM_ENVS: int = 4
NUM_STEPS: int = 128
TOTAL_TIMESTEPS: int = int(5e5)
UPDATE_EPOCHS: int = 4
NUM_MINIBATCHES: int = 4
GAMMA: float = 0.99
GAE_LAMBDA: float = 0.95
CLIP_EPS: float = 0.2
ENT_COEF: float = 0.01
VF_COEF: float = 0.5
MAX_GRAD_NORM: float = 0.5
ACTIVATION: str = "tanh"
ENV_NAME: str = "CartPole-v1"
ANNEAL_LR: bool = True

@property
def NUM_UPDATES(self) -> int:
return self.TOTAL_TIMESTEPS // self.NUM_STEPS // self.NUM_ENVS

@property
def MINIBATCH_SIZE(self) -> int:
return self.NUM_ENVS * self.NUM_STEPS // self.NUM_MINIBATCHES


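# Build a pure `train(rng)` function that runs the entire PPO training loop, so it can be
# wrapped in `jax.jit` (and `jax.vmap` over seeds) as done in `main()` below.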
def make_train(config: Config):
env, env_params = gymnax.make(config.ENV_NAME)
env = FlattenObservationWrapper(env)
env = LogWrapper(env)

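    # Linearly anneal the learning rate over training; `count` increases by one per minibatch update.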
def linear_schedule(count):
frac = (
1.0
- (count // (config.NUM_MINIBATCHES * config.UPDATE_EPOCHS))
/ config.NUM_UPDATES
)
return config.LR * frac

def train(rng):
# INIT NETWORK
network = ActorCritic(
env.action_space(env_params).n, activation=config.ACTIVATION
)
rng, _rng = jax.random.split(rng)
init_x = jnp.zeros(env.observation_space(env_params).shape)
network_params = network.init(_rng, init_x)
if config.ANNEAL_LR:
tx = optax.chain(
optax.clip_by_global_norm(config.MAX_GRAD_NORM),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(
optax.clip_by_global_norm(config.MAX_GRAD_NORM),
optax.adam(config.LR, eps=1e-5),
)
train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
)

# INIT ENV
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config.NUM_ENVS)
obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

# TRAIN LOOP
def _update_step(runner_state, unused):
# COLLECT TRAJECTORIES
def _env_step(runner_state, unused):
train_state, env_state, last_obs, rng = runner_state

# SELECT ACTION
rng, _rng = jax.random.split(rng)
pi, value = network.apply(train_state.params, last_obs)
action = pi.sample(seed=_rng)
log_prob = pi.log_prob(action)

# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config.NUM_ENVS)
obsv, env_state, reward, done, info = jax.vmap(
env.step, in_axes=(0, 0, 0, None)
)(rng_step, env_state, action, env_params)
transition = Transition(
done, action, value, reward, log_prob, last_obs, info
)
runner_state = (train_state, env_state, obsv, rng)
return runner_state, transition

runner_state, traj_batch = jax.lax.scan(
_env_step, runner_state, None, config.NUM_STEPS
)

# CALCULATE ADVANTAGE
train_state, env_state, last_obs, rng = runner_state
_, last_val = network.apply(train_state.params, last_obs)

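            # Generalized Advantage Estimation via a reverse scan over the trajectory;
            # the value targets are the advantages plus the predicted values.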
def _calculate_gae(traj_batch, last_val):
def _get_advantages(gae_and_next_value, transition):
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
delta = reward + config.GAMMA * next_value * (1 - done) - value
gae = delta + config.GAMMA * config.GAE_LAMBDA * (1 - done) * gae
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)

# UPDATE NETWORK
def _update_epoch(update_state, unused):
def _update_minbatch(train_state, batch_info):
traj_batch, advantages, targets = batch_info

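                    # PPO loss: clipped value loss, clipped policy surrogate on normalized
                    # advantages, and an entropy bonus, weighted by VF_COEF and ENT_COEF.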
def _loss_fn(params, traj_batch, gae, targets):
# RERUN NETWORK
pi, value = network.apply(params, traj_batch.obs)
log_prob = pi.log_prob(traj_batch.action)

# CALCULATE VALUE LOSS
value_pred_clipped = traj_batch.value + (
value - traj_batch.value
).clip(-config.CLIP_EPS, config.CLIP_EPS)
value_losses = jnp.square(value - targets)
value_losses_clipped = jnp.square(value_pred_clipped - targets)
value_loss = (
0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
)

# CALCULATE ACTOR LOSS
ratio = jnp.exp(log_prob - traj_batch.log_prob)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
jnp.clip(
ratio,
1.0 - config.CLIP_EPS,
1.0 + config.CLIP_EPS,
)
* gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()

total_loss = (
loss_actor
+ config.VF_COEF * value_loss
- config.ENT_COEF * entropy
)
return total_loss, (value_loss, loss_actor, entropy)

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
train_state.params, traj_batch, advantages, targets
)
train_state = train_state.apply_gradients(grads=grads)
return train_state, total_loss

train_state, traj_batch, advantages, targets, rng = update_state
rng, _rng = jax.random.split(rng)
batch_size = config.MINIBATCH_SIZE * config.NUM_MINIBATCHES
assert (
batch_size == config.NUM_STEPS * config.NUM_ENVS
), "batch size must be equal to number of steps * number of envs"
permutation = jax.random.permutation(_rng, batch_size)
batch = (traj_batch, advantages, targets)
batch = jax.tree_util.tree_map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
shuffled_batch = jax.tree_util.tree_map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
minibatches = jax.tree_util.tree_map(
lambda x: jnp.reshape(
x, [config.NUM_MINIBATCHES, -1] + list(x.shape[1:])
),
shuffled_batch,
)
train_state, total_loss = jax.lax.scan(
_update_minbatch, train_state, minibatches
)
update_state = (train_state, traj_batch, advantages, targets, rng)
return update_state, total_loss

update_state = (train_state, traj_batch, advantages, targets, rng)
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config.UPDATE_EPOCHS
)
train_state = update_state[0]
metric = traj_batch.info
rng = update_state[-1]

runner_state = (train_state, env_state, last_obs, rng)
return runner_state, metric

rng, _rng = jax.random.split(rng)
runner_state = (train_state, env_state, obsv, _rng)
runner_state, metric = jax.lax.scan(
_update_step, runner_state, None, config.NUM_UPDATES
)
return {"runner_state": runner_state, "metrics": metric}

return train


def main():
rng = jax.random.PRNGKey(42)

config = Config()

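    # The first call compiles the whole training loop; the timed call below reuses the compiled program.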
train_jit = jax.jit(make_train(config))
out = train_jit(rng)

import time

import matplotlib.pyplot as plt

rng = jax.random.PRNGKey(42)
t0 = time.time()
out = jax.block_until_ready(train_jit(rng))
print(f"time: {time.time() - t0:.2f} s")
plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
plt.xlabel("Update Step")
plt.ylabel("Return")
plt.show()

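    # Train 256 independent agents in parallel by vmapping the entire training function over seeds.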
rng = jax.random.PRNGKey(42)
rngs = jax.random.split(rng, 256)
train_vjit = jax.jit(jax.vmap(make_train(config)))
t0 = time.time()
outs = jax.block_until_ready(train_vjit(rngs))
print(f"time: {time.time() - t0:.2f} s")

for i in range(256):
plt.plot(outs["metrics"]["returned_episode_returns"][i].mean(-1).reshape(-1))
plt.xlabel("Update Step")
plt.ylabel("Return")
plt.show()


if __name__ == "__main__":
main()
16 changes: 16 additions & 0 deletions benchmarks/purejaxrl/prepare.py
@@ -0,0 +1,16 @@
#!/usr/bin/env python

import os

if __name__ == "__main__":
# If you need the whole configuration:
# config = json.loads(os.environ["MILABENCH_CONFIG"])

data_directory = os.environ["MILABENCH_DIR_DATA"]

    # Download (or generate) the needed dataset(s). You are responsible
    # for checking whether it has already been properly downloaded, and
    # for doing nothing if it has been.
print("Hello I am doing some data stuff!")

# If there is nothing to download or generate, just delete this file.
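    # In this benchmark's case, the gymnax environments used by main.py are created on
    # the fly, so there is most likely nothing to download here.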
13 changes: 13 additions & 0 deletions benchmarks/purejaxrl/requirements.in
@@ -0,0 +1,13 @@
jax>=0.2.26
jaxlib>=0.1.74
gymnax
evosax
distrax
optax
flax
numpy
brax
wandb
flashbax
navix
voir>=0.2.10,<0.3