WIP: Adding a PureJaxRL example #237

Draft: wants to merge 2 commits into master
4 changes: 4 additions & 0 deletions benchmarks/purejaxrl/README.md
@@ -0,0 +1,4 @@

# Benchmark

This benchmark runs the PureJaxRL PPO example (CartPole-v1 via gymnax) entirely under JAX, adapted from https://github.com/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb.
9 changes: 9 additions & 0 deletions benchmarks/purejaxrl/benchfile.py
@@ -0,0 +1,9 @@
from milabench.pack import Package


class PureJaxRLBenchmark(Package):
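    # base_requirements: pip requirements installed into this benchmark's environment.
    # main_script: entry point executed when the benchmark runs.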
base_requirements = "requirements.in"
main_script = "main.py"


__pack__ = PureJaxRLBenchmark()
331 changes: 331 additions & 0 deletions benchmarks/purejaxrl/main.py
@@ -0,0 +1,331 @@
"""PureJaxRL example based on https://github.com/luchris429/purejaxrl/blob/main/examples/walkthrough.ipynb

TODOs:
- This is a simple example on CartPole; it would be preferable to run something more substantial (and for longer).
- Should probably try something like https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/experimental/s5/ppo_s5.py

"""
import dataclasses
import os

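# Ask XLA not to preallocate most of the GPU memory; must be set before JAX initializes its backend.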
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
from typing import NamedTuple

import distrax
import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper


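# Actor-critic MLP: two 64-unit hidden layers per head; the actor emits categorical
# logits, the critic a scalar state value.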
class ActorCritic(nn.Module):
action_dim: int
activation: str = "tanh"

@nn.compact
def __call__(self, x):
if self.activation == "relu":
activation = nn.relu
else:
activation = nn.tanh
actor_mean = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(x)
actor_mean = activation(actor_mean)
actor_mean = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(actor_mean)
actor_mean = activation(actor_mean)
actor_mean = nn.Dense(
self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
)(actor_mean)
pi = distrax.Categorical(logits=actor_mean)

critic = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(x)
critic = activation(critic)
critic = nn.Dense(
64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
)(critic)
critic = activation(critic)
critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
critic
)

return pi, jnp.squeeze(critic, axis=-1)


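# One environment step collected during the rollout; lax.scan stacks these along a leading time axis.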
class Transition(NamedTuple):
done: jnp.ndarray
action: jnp.ndarray
value: jnp.ndarray
reward: jnp.ndarray
log_prob: jnp.ndarray
obs: jnp.ndarray
info: jnp.ndarray


@dataclasses.dataclass(frozen=True)
class Config:
LR: float = 2.5e-4
NUM_ENVS: int = 4
NUM_STEPS: int = 128
TOTAL_TIMESTEPS: int = int(5e5)
UPDATE_EPOCHS: int = 4
NUM_MINIBATCHES: int = 4
GAMMA: float = 0.99
GAE_LAMBDA: float = 0.95
CLIP_EPS: float = 0.2
ENT_COEF: float = 0.01
VF_COEF: float = 0.5
MAX_GRAD_NORM: float = 0.5
ACTIVATION: str = "tanh"
ENV_NAME: str = "CartPole-v1"
ANNEAL_LR: bool = True

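    # Number of PPO update iterations: total timesteps divided by the timesteps
    # gathered per rollout (NUM_STEPS * NUM_ENVS).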
@property
def NUM_UPDATES(self) -> int:
return self.TOTAL_TIMESTEPS // self.NUM_STEPS // self.NUM_ENVS

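    # Transitions per minibatch: one rollout (NUM_ENVS * NUM_STEPS) split across NUM_MINIBATCHES.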
@property
def MINIBATCH_SIZE(self) -> int:
return self.NUM_ENVS * self.NUM_STEPS // self.NUM_MINIBATCHES


def make_train(config: Config):
env, env_params = gymnax.make(config.ENV_NAME)
env = FlattenObservationWrapper(env)
env = LogWrapper(env)

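    # Anneal the learning rate linearly to zero: optax passes the optimizer step count
    # (one per minibatch), which the integer division converts into the outer-update index.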
def linear_schedule(count):
frac = (
1.0
- (count // (config.NUM_MINIBATCHES * config.UPDATE_EPOCHS))
/ config.NUM_UPDATES
)
return config.LR * frac

def train(rng):
# INIT NETWORK
network = ActorCritic(
env.action_space(env_params).n, activation=config.ACTIVATION
)
rng, _rng = jax.random.split(rng)
init_x = jnp.zeros(env.observation_space(env_params).shape)
network_params = network.init(_rng, init_x)
if config.ANNEAL_LR:
tx = optax.chain(
optax.clip_by_global_norm(config.MAX_GRAD_NORM),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(
optax.clip_by_global_norm(config.MAX_GRAD_NORM),
optax.adam(config.LR, eps=1e-5),
)
train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
)

# INIT ENV
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, config.NUM_ENVS)
obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

# TRAIN LOOP
def _update_step(runner_state, unused):
# COLLECT TRAJECTORIES
def _env_step(runner_state, unused):
train_state, env_state, last_obs, rng = runner_state

# SELECT ACTION
rng, _rng = jax.random.split(rng)
pi, value = network.apply(train_state.params, last_obs)
action = pi.sample(seed=_rng)
log_prob = pi.log_prob(action)

# STEP ENV
rng, _rng = jax.random.split(rng)
rng_step = jax.random.split(_rng, config.NUM_ENVS)
obsv, env_state, reward, done, info = jax.vmap(
env.step, in_axes=(0, 0, 0, None)
)(rng_step, env_state, action, env_params)
transition = Transition(
done, action, value, reward, log_prob, last_obs, info
)
runner_state = (train_state, env_state, obsv, rng)
return runner_state, transition

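            # Roll out NUM_STEPS steps across all envs; traj_batch fields gain a leading time dimension.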
runner_state, traj_batch = jax.lax.scan(
_env_step, runner_state, None, config.NUM_STEPS
)

# CALCULATE ADVANTAGE
train_state, env_state, last_obs, rng = runner_state
_, last_val = network.apply(train_state.params, last_obs)

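            # Generalized advantage estimation, computed with a reverse scan over the trajectory:
            #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}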
def _calculate_gae(traj_batch, last_val):
def _get_advantages(gae_and_next_value, transition):
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
delta = reward + config.GAMMA * next_value * (1 - done) - value
gae = delta + config.GAMMA * config.GAE_LAMBDA * (1 - done) * gae
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)

# UPDATE NETWORK
def _update_epoch(update_state, unused):
def _update_minbatch(train_state, batch_info):
traj_batch, advantages, targets = batch_info

def _loss_fn(params, traj_batch, gae, targets):
# RERUN NETWORK
pi, value = network.apply(params, traj_batch.obs)
log_prob = pi.log_prob(traj_batch.action)

# CALCULATE VALUE LOSS
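                        # Clip the new value prediction to within CLIP_EPS of the old value and
                        # keep the larger of the clipped/unclipped losses (PPO-style value clipping).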
value_pred_clipped = traj_batch.value + (
value - traj_batch.value
).clip(-config.CLIP_EPS, config.CLIP_EPS)
value_losses = jnp.square(value - targets)
value_losses_clipped = jnp.square(value_pred_clipped - targets)
value_loss = (
0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
)

# CALCULATE ACTOR LOSS
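                        # PPO clipped surrogate: advantages are normalized per minibatch, and the
                        # objective takes the minimum of the unclipped and clipped ratio terms.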
ratio = jnp.exp(log_prob - traj_batch.log_prob)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = (
jnp.clip(
ratio,
1.0 - config.CLIP_EPS,
1.0 + config.CLIP_EPS,
)
* gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()

total_loss = (
loss_actor
+ config.VF_COEF * value_loss
- config.ENT_COEF * entropy
)
return total_loss, (value_loss, loss_actor, entropy)

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
total_loss, grads = grad_fn(
train_state.params, traj_batch, advantages, targets
)
train_state = train_state.apply_gradients(grads=grads)
return train_state, total_loss

train_state, traj_batch, advantages, targets, rng = update_state
rng, _rng = jax.random.split(rng)
batch_size = config.MINIBATCH_SIZE * config.NUM_MINIBATCHES
assert (
batch_size == config.NUM_STEPS * config.NUM_ENVS
), "batch size must be equal to number of steps * number of envs"
permutation = jax.random.permutation(_rng, batch_size)
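                # Flatten the (time, env) axes into one batch, shuffle it, then split it into NUM_MINIBATCHES minibatches.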
batch = (traj_batch, advantages, targets)
batch = jax.tree_util.tree_map(
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
)
shuffled_batch = jax.tree_util.tree_map(
lambda x: jnp.take(x, permutation, axis=0), batch
)
minibatches = jax.tree_util.tree_map(
lambda x: jnp.reshape(
x, [config.NUM_MINIBATCHES, -1] + list(x.shape[1:])
),
shuffled_batch,
)
train_state, total_loss = jax.lax.scan(
_update_minbatch, train_state, minibatches
)
update_state = (train_state, traj_batch, advantages, targets, rng)
return update_state, total_loss

update_state = (train_state, traj_batch, advantages, targets, rng)
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config.UPDATE_EPOCHS
)
train_state = update_state[0]
metric = traj_batch.info
rng = update_state[-1]

runner_state = (train_state, env_state, last_obs, rng)
return runner_state, metric

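        # Run NUM_UPDATES iterations of rollout + update inside a single lax.scan so the whole loop stays jittable.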
rng, _rng = jax.random.split(rng)
runner_state = (train_state, env_state, obsv, _rng)
runner_state, metric = jax.lax.scan(
_update_step, runner_state, None, config.NUM_UPDATES
)
return {"runner_state": runner_state, "metrics": metric}

return train


def main():
rng = jax.random.PRNGKey(42)

config = Config()

train_jit = jax.jit(make_train(config))
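    # The first call below compiles the full training loop; the timed run afterwards reuses the compiled program.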
out = train_jit(rng)

import time

import matplotlib.pyplot as plt

rng = jax.random.PRNGKey(42)
t0 = time.time()
out = jax.block_until_ready(train_jit(rng))
print(f"time: {time.time() - t0:.2f} s")
plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
plt.xlabel("Update Step")
plt.ylabel("Return")
plt.show()

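    # Train 256 independent agents in parallel by vmapping the whole training function over random seeds.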
rng = jax.random.PRNGKey(42)
rngs = jax.random.split(rng, 256)
train_vjit = jax.jit(jax.vmap(make_train(config)))
t0 = time.time()
outs = jax.block_until_ready(train_vjit(rngs))
print(f"time: {time.time() - t0:.2f} s")

for i in range(256):
plt.plot(outs["metrics"]["returned_episode_returns"][i].mean(-1).reshape(-1))
plt.xlabel("Update Step")
plt.ylabel("Return")
plt.show()


if __name__ == "__main__":
main()
16 changes: 16 additions & 0 deletions benchmarks/purejaxrl/prepare.py
@@ -0,0 +1,16 @@
#!/usr/bin/env python

import os

if __name__ == "__main__":
# If you need the whole configuration:
# config = json.loads(os.environ["MILABENCH_CONFIG"])

data_directory = os.environ["MILABENCH_DIR_DATA"]

    # Download (or generate) the needed dataset(s). You are responsible for
    # checking whether they have already been downloaded, and for skipping
    # the work if they have.
print("Hello I am doing some data stuff!")

# If there is nothing to download or generate, just delete this file.
13 changes: 13 additions & 0 deletions benchmarks/purejaxrl/requirements.in
@@ -0,0 +1,13 @@
jax>=0.2.26
jaxlib>=0.1.74
gymnax
evosax
distrax
optax
flax
numpy
brax
wandb
flashbax
navix
voir>=0.2.10,<0.3