Commit
Signed-off-by: Fabrice Normandin <[email protected]>
Showing 6 changed files with 416 additions and 0 deletions.
@@ -0,0 +1,4 @@

# Benchmark

Rewrite this README to explain what the benchmark is!
@@ -0,0 +1,9 @@
from milabench.pack import Package


class BraxBenchmark(Package):
    base_requirements = "requirements.in"
    main_script = "main.py"


__pack__ = BraxBenchmark
@@ -0,0 +1,325 @@
"""PureJaxRL"""
import dataclasses
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
from typing import NamedTuple

import distrax
import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from gymnax.wrappers.purerl import FlattenObservationWrapper, LogWrapper


class ActorCritic(nn.Module):
    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray


@dataclasses.dataclass(frozen=True)
class Config:
    LR: float = 2.5e-4
    NUM_ENVS: int = 4
    NUM_STEPS: int = 128
    TOTAL_TIMESTEPS: int = int(5e5)
    UPDATE_EPOCHS: int = 4
    NUM_MINIBATCHES: int = 4
    GAMMA: float = 0.99
    GAE_LAMBDA: float = 0.95
    CLIP_EPS: float = 0.2
    ENT_COEF: float = 0.01
    VF_COEF: float = 0.5
    MAX_GRAD_NORM: float = 0.5
    ACTIVATION: str = "tanh"
    ENV_NAME: str = "CartPole-v1"
    ANNEAL_LR: bool = True

    @property
    def NUM_UPDATES(self) -> int:
        return self.TOTAL_TIMESTEPS // self.NUM_STEPS // self.NUM_ENVS

    @property
    def MINIBATCH_SIZE(self) -> int:
        return self.NUM_ENVS * self.NUM_STEPS // self.NUM_MINIBATCHES
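    # With the defaults above: NUM_UPDATES = 500_000 // 128 // 4 = 976 PPO updates,
    # and MINIBATCH_SIZE = 4 * 128 // 4 = 128 transitions per minibatch.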


def make_train(config: Config):
    env, env_params = gymnax.make(config.ENV_NAME)
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config.NUM_MINIBATCHES * config.UPDATE_EPOCHS))
            / config.NUM_UPDATES
        )
        return config.LR * frac
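    # `count` passed in by optax is the number of minibatch gradient steps taken
    # so far, so linear_schedule decays the learning rate from LR towards 0 over
    # the course of training.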

    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).n, activation=config.ACTIVATION
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config.ANNEAL_LR:
            tx = optax.chain(
                optax.clip_by_global_norm(config.MAX_GRAD_NORM),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config.MAX_GRAD_NORM),
                optax.adam(config.LR, eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config.NUM_ENVS)
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
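        # in_axes=(0, None): vectorize over one PRNG key per environment while
        # broadcasting the same env_params to every environment copy.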

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config.NUM_ENVS)
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config.NUM_STEPS
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config.GAMMA * next_value * (1 - done) - value
                    gae = delta + config.GAMMA * config.GAE_LAMBDA * (1 - done) * gae
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)
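            # Standard GAE recursion, scanned in reverse over the trajectory:
            #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
            # The value-function targets are A_t + V(s_t).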

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config.CLIP_EPS, config.CLIP_EPS)
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config.CLIP_EPS,
                                1.0 + config.CLIP_EPS,
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config.VF_COEF * value_loss
                            - config.ENT_COEF * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)
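                    # _loss_fn above is the usual PPO objective: the clipped
                    # surrogate policy loss, plus a clipped value loss weighted by
                    # VF_COEF, minus an entropy bonus weighted by ENT_COEF.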

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config.MINIBATCH_SIZE * config.NUM_MINIBATCHES
                assert (
                    batch_size == config.NUM_STEPS * config.NUM_ENVS
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config.NUM_MINIBATCHES, -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config.UPDATE_EPOCHS
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config.NUM_UPDATES
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train


def main():
    rng = jax.random.PRNGKey(42)

    config = Config()

    train_jit = jax.jit(make_train(config))
    out = train_jit(rng)

    import time

    import matplotlib.pyplot as plt

    rng = jax.random.PRNGKey(42)
    t0 = time.time()
    out = jax.block_until_ready(train_jit(rng))
    print(f"time: {time.time() - t0:.2f} s")
    plt.plot(out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1))
    plt.xlabel("Update Step")
    plt.ylabel("Return")
    plt.show()

    rng = jax.random.PRNGKey(42)
    rngs = jax.random.split(rng, 256)
    train_vjit = jax.jit(jax.vmap(make_train(config)))
    t0 = time.time()
    outs = jax.block_until_ready(train_vjit(rngs))
    print(f"time: {time.time() - t0:.2f} s")

    for i in range(256):
        plt.plot(outs["metrics"]["returned_episode_returns"][i].mean(-1).reshape(-1))
    plt.xlabel("Update Step")
    plt.ylabel("Return")
    plt.show()


if __name__ == "__main__":
    main()
@@ -0,0 +1,16 @@
#!/usr/bin/env python

import os

if __name__ == "__main__":
    # If you need the whole configuration:
    # config = json.loads(os.environ["MILABENCH_CONFIG"])

    data_directory = os.environ["MILABENCH_DIR_DATA"]

    # Download (or generate) the needed dataset(s). You are responsible for
    # checking whether the data has already been properly downloaded, and for
    # doing nothing if it has.
    print("Hello I am doing some data stuff!")

    # If there is nothing to download or generate, just delete this file.
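For illustration only (not part of the commit): the benchmark in this commit trains PPO on gymnax's CartPole-v1 and needs no dataset, so this prepare script is a placeholder. Below is a minimal sketch of what an idempotent prepare step could look like, following the template's comments. MILABENCH_DIR_DATA comes from the template above; the URL and archive name are hypothetical placeholders.

#!/usr/bin/env python
"""Hypothetical prepare.py sketch; DATASET_URL and the archive name are placeholders."""
import os
import urllib.request

DATASET_URL = "https://example.com/dataset.tar.gz"  # placeholder, not a real dataset

if __name__ == "__main__":
    data_directory = os.environ["MILABENCH_DIR_DATA"]
    archive_path = os.path.join(data_directory, "dataset.tar.gz")

    if os.path.exists(archive_path):
        # Already prepared: do nothing so repeated runs stay fast.
        print(f"{archive_path} already exists, skipping download.")
    else:
        os.makedirs(data_directory, exist_ok=True)
        print(f"Downloading {DATASET_URL} to {archive_path}...")
        urllib.request.urlretrieve(DATASET_URL, archive_path)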
@@ -0,0 +1,13 @@
jax>=0.2.26
jaxlib>=0.1.74
gymnax
evosax
distrax
optax
flax
numpy
brax
wandb
flashbax
navix
voir>=0.2.10,<0.3