Skip to content

Commit

Permalink
refactor PPO, bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn committed Aug 24, 2023
1 parent 9219f84 commit 352aa6f
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 61 deletions.
8 changes: 4 additions & 4 deletions cleanba/cleanba_impala.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,10 @@ def update_minibatch(agent_state, minibatch):
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/loss", loss.item(), global_step)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/loss", loss[-1].item(), global_step)
if learner_policy_version >= args.num_updates:
break

Expand Down
135 changes: 79 additions & 56 deletions cleanba/cleanba_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import List, NamedTuple, Optional, Sequence, Tuple
from functools import partial

import envpool
import flax
Expand Down Expand Up @@ -61,7 +62,7 @@ class Args:
"total timesteps of the experiments"
learning_rate: float = 2.5e-4
"the learning rate of the optimizer"
local_num_envs: int = 64
local_num_envs: int = 60
"the number of parallel game environments"
num_actor_threads: int = 2
"the number of actor threads to use"
Expand Down Expand Up @@ -214,7 +215,8 @@ class Transition(NamedTuple):
obs: list
dones: list
actions: list
logitss: list
logprobs: list
values: list
env_ids: list
rewards: list
truncations: list
Expand Down Expand Up @@ -242,7 +244,7 @@ def rollout(
start_time = time.time()

@jax.jit
def get_action(
def get_action_and_value(
params: flax.core.FrozenDict,
next_obs: np.ndarray,
key: jax.random.PRNGKey,
Expand All @@ -255,20 +257,22 @@ def get_action(
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey, shape=logits.shape)
action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
return next_obs, action, logits, key
logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
value = Critic().apply(params.critic_params, hidden)
return next_obs, action, logprob, value.squeeze(), key

# put data in the last index
episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32)
returned_episode_returns = np.zeros((args.local_num_envs,), dtype=np.float32)
episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32)
returned_episode_lengths = np.zeros((args.local_num_envs,), dtype=np.float32)
envs.async_reset()

params_queue_get_time = deque(maxlen=10)
rollout_time = deque(maxlen=10)
rollout_queue_put_time = deque(maxlen=10)
actor_policy_version = 0
storage = []
next_obs = envs.reset()
next_done = jnp.zeros(args.local_num_envs, dtype=jax.numpy.bool_)

@jax.jit
def prepare_data(storage: List[Transition]) -> Transition:
Expand All @@ -281,9 +285,6 @@ def prepare_data(storage: List[Transition]) -> Transition:
storage_time = 0
d2h_time = 0
env_send_time = 0
num_steps_with_bootstrap = (
args.num_steps + 1 + int(len(storage) == 0)
) # num_steps + 1 to get the states for value bootstrapping.
# NOTE: `update != 2` is actually IMPORTANT — it allows us to start running policy collection
# concurrently with the learning process. It also ensures the actor's policy version is only 1 step
# behind the learner's policy version
Expand All @@ -292,8 +293,8 @@ def prepare_data(storage: List[Transition]) -> Transition:
if update != 2:
params = params_queue.get()
# NOTE: block here is important because otherwise this thread will call
# the jitted `get_action` function that hangs until the params are ready.
# This blocks the `get_action` function in other actor threads.
# the jitted `get_action_and_value` function that hangs until the params are ready.
# This blocks the `get_action_and_value` function in other actor threads.
# See https://excalidraw.com/#json=hSooeQL707gE5SWY8wOSS,GeaN1eb2r24PPi75a3n14Q for a visual explanation.
params.network_params["params"]["Dense_0"][
"kernel"
Expand All @@ -304,22 +305,22 @@ def prepare_data(storage: List[Transition]) -> Transition:
actor_policy_version += 1
params_queue_get_time.append(time.time() - params_queue_get_time_start)
rollout_time_start = time.time()
for _ in range(1, num_steps_with_bootstrap):
env_recv_time_start = time.time()
next_obs, next_reward, next_done, info = envs.recv()
env_recv_time += time.time() - env_recv_time_start
storage = []
for _ in range(0, args.num_steps):
cached_next_obs = next_obs
cached_next_done = next_done
global_step += len(next_done) * args.num_actor_threads * len_actor_device_ids * args.world_size
env_id = info["env_id"]

inference_time_start = time.time()
next_obs, action, logits, key = get_action(params, next_obs, key)
cached_next_obs, action, logprob, value, key = get_action_and_value(params, cached_next_obs, key)
inference_time += time.time() - inference_time_start

d2h_time_start = time.time()
cpu_action = np.array(action)
d2h_time += time.time() - d2h_time_start

env_send_time_start = time.time()
envs.send(cpu_action, env_id)
next_obs, next_reward, next_done, info = envs.step(cpu_action)
env_id = info["env_id"]
env_send_time += time.time() - env_send_time_start
storage_time_start = time.time()

Expand All @@ -328,10 +329,11 @@ def prepare_data(storage: List[Transition]) -> Transition:
truncated = info["elapsed_step"] >= envs.spec.config.max_episode_steps
storage.append(
Transition(
obs=next_obs,
dones=next_done,
obs=cached_next_obs,
dones=cached_next_done,
actions=action,
logitss=logits,
logprobs=logprob,
values=value,
env_ids=env_id,
rewards=next_reward,
truncations=truncated,
Expand All @@ -357,21 +359,23 @@ def prepare_data(storage: List[Transition]) -> Transition:
sharded_storage = Transition(
*list(map(lambda x: jax.device_put_sharded(x, devices=learner_devices), partitioned_storage))
)
# next_obs, next_done are still in the host
sharded_next_obs = jax.device_put_sharded(np.split(next_obs, len(learner_devices)), devices=learner_devices)
sharded_next_done = jax.device_put_sharded(np.split(next_done, len(learner_devices)), devices=learner_devices)
payload = (
global_step,
actor_policy_version,
update,
sharded_storage,
sharded_next_obs,
sharded_next_done,
np.mean(params_queue_get_time),
device_thread_id,
)
rollout_queue_put_time_start = time.time()
rollout_queue.put(payload)
rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start)

# move bootstrapping step to the beginning of the next update
storage = storage[-1:]

if update % args.log_frequency == 0:
if device_thread_id == 0:
print(
Expand Down Expand Up @@ -526,16 +530,38 @@ def get_logprob_entropy_value(
value = Critic().apply(params.critic_params, hidden).squeeze(-1)
return logprob, entropy, value

def compute_gae_once(carry, inp, gamma, gae_lambda):
advantages = carry
nextdone, nextvalues, curvalues, reward = inp
nextnonterminal = 1.0 - nextdone

delta = reward + gamma * nextvalues * nextnonterminal - curvalues
advantages = delta + gamma * gae_lambda * nextnonterminal * advantages
return advantages, advantages

compute_gae_once = partial(compute_gae_once, gamma=args.gamma, gae_lambda=args.gae_lambda)

@jax.jit
def compute_gae(
agent_state: TrainState,
next_obs: np.ndarray,
next_done: np.ndarray,
storage: Transition,
):
next_value = critic.apply(
agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs)
).squeeze()

advantages = jnp.zeros((args.local_num_envs,))
dones = jnp.concatenate([storage.dones, next_done[None, :]], axis=0)
values = jnp.concatenate([storage.values, next_value[None, :]], axis=0)
_, advantages = jax.lax.scan(
compute_gae_once, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True
)
return advantages, advantages + storage.values

def ppo_loss(params, obs, actions, behavior_logprobs, firststeps, advantages, target_values):
# TODO: figure out when to use `mask`
# mask = 1.0 - firststeps
newlogprob, entropy, newvalue = jax.vmap(get_logprob_entropy_value, in_axes=(None, 0, 0))(params, obs, actions)
behavior_logprobs = behavior_logprobs[:-1]
newlogprob = newlogprob[:-1]
entropy = entropy[:-1]
actions = actions[:-1]
# mask = mask[:-1]

logratio = newlogprob - behavior_logprobs
ratio = jnp.exp(logratio)
approx_kl = ((ratio - 1) - logratio).mean()
Expand All @@ -546,7 +572,7 @@ def ppo_loss(params, obs, actions, behavior_logprobs, firststeps, advantages, ta
pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()

# Value loss
v_loss = 0.5 * ((newvalue[:-1] - target_values) ** 2).mean()
v_loss = 0.5 * ((newvalue - target_values) ** 2).mean()
entropy_loss = entropy.mean()
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
Expand All @@ -555,26 +581,15 @@ def ppo_loss(params, obs, actions, behavior_logprobs, firststeps, advantages, ta
def single_device_update(
agent_state: TrainState,
sharded_storages: List,
sharded_next_obs: List,
sharded_next_done: List,
key: jax.random.PRNGKey,
):
storage = jax.tree_map(lambda *x: jnp.hstack(x), *sharded_storages)
next_obs = jnp.concatenate(sharded_next_obs)
next_done = jnp.concatenate(sharded_next_done)
ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
behavior_logprobs = jax.vmap(lambda logits, action: jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action])(
storage.logitss, storage.actions
)
values = jax.vmap(get_value, in_axes=(None, 0))(agent_state.params, storage.obs)
discounts = (1.0 - storage.dones) * args.gamma

def gae_advantages(rewards: jnp.array, discounts: jnp.array, values: jnp.array) -> Tuple[jnp.ndarray, jnp.array]:
advantages = rlax.truncated_generalized_advantage_estimation(rewards, discounts, args.gae_lambda, values)
advantages = jax.lax.stop_gradient(advantages)
target_values = values[:-1] + advantages
target_values = jax.lax.stop_gradient(target_values)
return advantages, target_values

advantages, target_values = jax.vmap(gae_advantages, in_axes=1, out_axes=1)(
storage.rewards[:-1], discounts[:-1], values
)
advantages, target_values = compute_gae(agent_state, next_obs, next_done, storage)
# NOTE: notable implementation difference: we normalize advantage at the batch level
if args.norm_adv:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
Expand Down Expand Up @@ -604,7 +619,7 @@ def update_minibatch(agent_state, minibatch):
(
jnp.array(jnp.split(storage.obs, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
jnp.array(jnp.split(storage.actions, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
jnp.array(jnp.split(behavior_logprobs, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
jnp.array(jnp.split(storage.logprobs, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
jnp.array(jnp.split(storage.firststeps, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
jnp.array(jnp.split(advantages, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
jnp.array(jnp.split(target_values, args.num_minibatches * args.gradient_accumulation_steps, axis=1)),
Expand Down Expand Up @@ -661,22 +676,30 @@ def update_minibatch(agent_state, minibatch):
learner_policy_version += 1
rollout_queue_get_time_start = time.time()
sharded_storages = []
sharded_next_obss = []
sharded_next_dones = []
for d_idx, d_id in enumerate(args.actor_device_ids):
for thread_id in range(args.num_actor_threads):
(
global_step,
actor_policy_version,
update,
sharded_storage,
sharded_next_obs,
sharded_next_done,
avg_params_queue_get_time,
device_thread_id,
) = rollout_queues[d_idx * args.num_actor_threads + thread_id].get()
sharded_storages.append(sharded_storage)
sharded_next_obss.append(sharded_next_obs)
sharded_next_dones.append(sharded_next_done)
rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
training_time_start = time.time()
(agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, learner_keys) = multi_device_update(
agent_state,
sharded_storages,
sharded_next_obss,
sharded_next_dones,
learner_keys,
)
unreplicated_params = flax.jax_utils.unreplicate(agent_state.params)
Expand All @@ -703,11 +726,11 @@ def update_minibatch(agent_state, minibatch):
writer.add_scalar(
"charts/learning_rate", agent_state.opt_state[2][1].hyperparams["learning_rate"][-1].item(), global_step
)
writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
writer.add_scalar("losses/loss", loss.item(), global_step)
writer.add_scalar("losses/value_loss", v_loss[-1].item(), global_step)
writer.add_scalar("losses/policy_loss", pg_loss[-1].item(), global_step)
writer.add_scalar("losses/entropy", entropy_loss[-1].item(), global_step)
writer.add_scalar("losses/approx_kl", approx_kl[-1].item(), global_step)
writer.add_scalar("losses/loss", loss[-1].item(), global_step)
if learner_policy_version >= args.num_updates:
break

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cleanba"
version = "1.0.0b2"
version = "1.0.0b3"
description = ""
authors = ["Costa Huang <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit 352aa6f

Please sign in to comment.