Feat: unified gae #1129

Open
wants to merge 14 commits into base: develop
52 changes: 16 additions & 36 deletions mava/systems/ppo/anakin/ff_ippo.py
@@ -42,6 +42,7 @@
unreplicate_n_dims,
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
@@ -81,7 +82,7 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup

def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
"""Step the environment."""
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state

# SELECT ACTION
key, policy_key = jax.random.split(key)
@@ -102,15 +103,15 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
info = timestep.extras["episode_metrics"]

transition = PPOTransition(
done,
last_done,
action,
value,
timestep.reward,
log_prob,
last_timestep.observation,
info,
)
learner_state = LearnerState(params, opt_states, key, env_state, timestep)
learner_state = LearnerState(params, opt_states, key, env_state, timestep, done)
return learner_state, transition

# STEP ENVIRONMENT FOR ROLLOUT LENGTH
@@ -119,37 +120,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
)

# CALCULATE ADVANTAGE
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state
last_val = critic_apply_fn(params.critic_params, last_timestep.observation)

def _calculate_gae(
traj_batch: PPOTransition, last_val: chex.Array
) -> Tuple[chex.Array, chex.Array]:
"""Calculate the GAE."""

def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple:
"""Calculate the GAE for a single transition."""
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
@@ -312,7 +288,7 @@ def _critic_loss_fn(
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep)
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done)
metric = traj_batch.info
return learner_state, (metric, loss_info)

@@ -430,9 +406,13 @@ def learner_setup(
params = restored_params

# Define params to be replicated across devices and batches.
dones = jnp.zeros(
(config.arch.num_envs, config.system.num_agents),
dtype=bool,
)
key, step_keys = jax.random.split(key)
opt_states = OptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states, step_keys)
replicate_learner = (params, opt_states, step_keys, dones)

# Duplicate learner for update_batch_size.
broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape))
@@ -442,8 +422,8 @@
replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

# Initialise learner state.
params, opt_states, step_keys = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps)
params, opt_states, step_keys, dones = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones)

return learn, actor_network, init_learner_state

52 changes: 16 additions & 36 deletions mava/systems/ppo/anakin/ff_mappo.py
@@ -37,6 +37,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
@@ -76,7 +77,7 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup

def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
"""Step the environment."""
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state

# SELECT ACTION
key, policy_key = jax.random.split(key)
@@ -96,9 +97,9 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
info = timestep.extras["episode_metrics"]

transition = PPOTransition(
done, action, value, timestep.reward, log_prob, last_timestep.observation, info
last_done, action, value, timestep.reward, log_prob, last_timestep.observation, info
)
learner_state = LearnerState(params, opt_states, key, env_state, timestep)
learner_state = LearnerState(params, opt_states, key, env_state, timestep, done)
return learner_state, transition

# STEP ENVIRONMENT FOR ROLLOUT LENGTH
@@ -107,37 +108,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
)

# CALCULATE ADVANTAGE
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state
last_val = critic_apply_fn(params.critic_params, last_timestep.observation)

def _calculate_gae(
traj_batch: PPOTransition, last_val: chex.Array
) -> Tuple[chex.Array, chex.Array]:
"""Calculate the GAE."""

def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple:
"""Calculate the GAE for a single transition."""
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
@@ -296,7 +272,7 @@ def _critic_loss_fn(
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep)
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done)
metric = traj_batch.info
return learner_state, (metric, loss_info)

@@ -414,9 +390,13 @@ def learner_setup(
params = restored_params

# Define params to be replicated across devices and batches.
dones = jnp.zeros(
(config.arch.num_envs, config.system.num_agents),
dtype=bool,
)
key, step_keys = jax.random.split(key)
opt_states = OptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states, step_keys)
replicate_learner = (params, opt_states, step_keys, dones)

# Duplicate learner for update_batch_size.
broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape))
@@ -426,8 +406,8 @@
replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

# Initialise learner state.
params, opt_states, step_keys = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps)
params, opt_states, step_keys, dones = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones)

return learn, actor_network, init_learner_state

27 changes: 4 additions & 23 deletions mava/systems/ppo/anakin/rec_ippo.py
@@ -51,6 +51,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
@@ -179,29 +180,9 @@ def _env_step(
# Squeeze out the batch dimension and mask out the value of terminal states.
last_val = last_val.squeeze(0)

def _calculate_gae(
traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array
) -> Tuple[chex.Array, chex.Array]:
def _get_advantages(
carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition
) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]:
gae, next_value, next_done = carry
done, value, reward = transition.done, transition.value, transition.reward
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - next_done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae
return (gae, value, done), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val, last_done),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
27 changes: 4 additions & 23 deletions mava/systems/ppo/anakin/rec_mappo.py
@@ -51,6 +51,7 @@
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.total_timestep_checker import check_total_timesteps
from mava.utils.training import make_learning_rate
@@ -175,29 +176,9 @@ def _env_step(
# Squeeze out the batch dimension and mask out the value of terminal states.
last_val = last_val.squeeze(0)

def _calculate_gae(
traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array
) -> Tuple[chex.Array, chex.Array]:
def _get_advantages(
carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition
) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]:
gae, next_value, next_done = carry
done, value, reward = transition.done, transition.value, transition.reward
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - next_done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae
return (gae, value, done), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val, last_done),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
1 change: 1 addition & 0 deletions mava/systems/ppo/types.py
@@ -52,6 +52,7 @@ class LearnerState(NamedTuple):
key: chex.PRNGKey
env_state: State
timestep: TimeStep
dones: Done


class RNNLearnerState(NamedTuple):
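For reference, here is a sketch of how the full PPO LearnerState would read after this change. Only the tail of the class is visible in the hunk above; the remaining fields and their order are inferred from the LearnerState(params, opt_states, key, env_state, timestep, dones) constructor calls elsewhere in this diff, and the placeholder aliases stand in for Mava's concrete types:

from typing import Any, NamedTuple

import chex

# Placeholder aliases standing in for Mava's concrete types (not shown in this hunk).
Params = Any
OptStates = Any
State = Any
TimeStep = Any
Done = chex.Array  # bool flags of shape (num_envs, num_agents)


class LearnerState(NamedTuple):
    """State carried between PPO update steps (sketch)."""

    params: Params
    opt_states: OptStates
    key: chex.PRNGKey
    env_state: State
    timestep: TimeStep
    dones: Done  # terminal flags of the current observation, carried into the next rollout

The new dones field lets the feed-forward systems carry terminal flags across rollouts, so each transition can store last_done and calculate_gae can bootstrap using the done flag of the following step.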
68 changes: 68 additions & 0 deletions mava/utils/multistep.py
@@ -0,0 +1,68 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union

import chex
import jax
import jax.numpy as jnp

from mava.systems.ppo.types import PPOTransition, RNNPPOTransition


def calculate_gae(
traj_batch: Union[PPOTransition, RNNPPOTransition],
last_val: chex.Array,
last_done: chex.Array,
gamma: float,
gae_lambda: float,
unroll: int = 16
) -> Tuple[chex.Array, chex.Array]:
"""Computes truncated generalized advantage estimates.

The advantages are computed in a backwards fashion according to the equation:
Âₜ = δₜ + (γλ) * δₜ₊₁ + ... + (γλ)ᵏ⁻¹⁻ᵗ * δₖ₋₁
where δₜ = rₜ₊₁ + γₜ₊₁ * v(sₜ₊₁) - v(sₜ).
See Proximal Policy Optimization Algorithms, Schulman et al.:
https://arxiv.org/abs/1707.06347

Args:
traj_batch (T, B, A, ...): a batch of trajectories, with time as the leading (scanned) axis.
last_val (B, A): value of the final timestep.
last_done (B, A): whether the last timestep was terminated or truncated.
gamma (float): discount factor.
gae_lambda (float): GAE mixing parameter.
unroll (int): how much XLA should unroll the scan used to calculate GAE.

Returns Tuple[(T, B, A), (T, B, A)]: advantages and target values.
"""

def _get_advantages(
carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition
) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]:
gae, next_value, next_done = carry
done, value, reward = transition.done, transition.value, transition.reward

delta = reward + gamma * next_value * (1 - next_done) - value
gae = delta + gamma * gae_lambda * (1 - next_done) * gae
return (gae, value, done), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val, last_done),
traj_batch,
reverse=True,
unroll=unroll,
)
return advantages, advantages + traj_batch.value
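
As a sanity check on the new helper, a minimal, self-contained usage sketch follows. DummyTransition is a hypothetical stand-in for PPOTransition that carries only the fields calculate_gae reads, and the sizes and values are purely illustrative:

from typing import NamedTuple

import chex
import jax
import jax.numpy as jnp

from mava.utils.multistep import calculate_gae


class DummyTransition(NamedTuple):
    # Hypothetical stand-in for PPOTransition: only the fields calculate_gae reads.
    done: chex.Array
    value: chex.Array
    reward: chex.Array


T, B, A = 16, 4, 3  # rollout length, parallel envs, agents (illustrative sizes)
val_key, rew_key = jax.random.split(jax.random.PRNGKey(0))

traj_batch = DummyTransition(
    done=jnp.zeros((T, B, A), dtype=bool),         # no terminations in this toy rollout
    value=jax.random.normal(val_key, (T, B, A)),   # per-step value estimates
    reward=jax.random.normal(rew_key, (T, B, A)),  # per-step rewards
)
last_val = jnp.zeros((B, A))                       # bootstrap value after the rollout
last_done = jnp.zeros((B, A), dtype=bool)          # done flags after the rollout

advantages, targets = calculate_gae(traj_batch, last_val, last_done, 0.99, 0.95)
assert advantages.shape == (T, B, A)
assert targets.shape == (T, B, A)

With all done flags false this reduces to plain truncated GAE bootstrapped from last_val; any terminal step would zero out its bootstrap term through the (1 - next_done) mask.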