Feat: Support latest Jumanji version #1134

Open · wants to merge 18 commits into base: develop

Changes from 12 commits
4 changes: 2 additions & 2 deletions examples/Quickstart.ipynb
@@ -537,7 +537,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"id": "eWjNSGvZ7ALw"
},
@@ -573,7 +573,7 @@
" )\n",
"\n",
" # Initialise observation with obs of all agents.\n",
" obs = env.observation_spec().generate_value()\n",
" obs = env.observation_spec.generate_value()\n",
" init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)\n",
"\n",
" # Initialise actor params and optimiser state.\n",
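Note: this change follows the latest Jumanji API, where environment specs are exposed as cached properties rather than methods. A minimal sketch of the new access pattern (assuming Jumanji >= 1.0 and the multi-agent connector registered as `Connector-v2`, which is the name the scenario configs in this diff use):

```python
import jax.numpy as jnp
import jumanji
from jax import tree

env = jumanji.make("Connector-v2")  # assumed registration name, per the scenario configs below

# Specs are now properties, so the call parentheses are dropped.
obs = env.observation_spec.generate_value()

# Add a leading batch dimension to every leaf of the observation pytree,
# as the notebook does before initialising the actor network.
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)
```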
2 changes: 1 addition & 1 deletion mava/advanced_usage/README.md
@@ -12,7 +12,7 @@ dummy_flashbax_transition = {
"observation": jnp.zeros(
(
config.system.num_agents,
env.observation_spec().agents_view.shape[1],
env.observation_spec.agents_view.shape[1],
),
dtype=jnp.float32,
),
4 changes: 2 additions & 2 deletions mava/advanced_usage/ff_ippo_store_experience.py
@@ -378,7 +378,7 @@ def learner_setup(
)

# Initialise observation with obs of all agents.
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
@@ -507,7 +507,7 @@ def run_experiment(_config: DictConfig) -> None:
"observation": jnp.zeros(
(
config.system.num_agents,
env.observation_spec().agents_view.shape[1],
env.observation_spec.agents_view.shape[1],
),
dtype=jnp.float32,
),
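Note: the store-experience example sizes its dummy replay transition from the observation spec's `agents_view`; with the property-style spec that lookup reads as below (a sketch, with `env` and `config` assumed to be set up as in the script, and the remaining transition fields omitted):

```python
import jax.numpy as jnp

# agents_view is specced as (num_agents, obs_dim); index 1 picks the per-agent obs width.
obs_dim = env.observation_spec.agents_view.shape[1]

dummy_flashbax_transition = {
    "observation": jnp.zeros((config.system.num_agents, obs_dim), dtype=jnp.float32),
    # ... action, reward, done, etc. as in the script.
}
```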
2 changes: 1 addition & 1 deletion mava/configs/default/ff_ippo.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: ppo/ff_ippo
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/ff_mappo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/ff_mappo
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/ff_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/ff_sable
- network: ff_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/mat.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: mat/mat
- network: transformer
- env: rware # [gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_ippo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/rec_ippo
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_iql.yaml
@@ -3,8 +3,8 @@ defaults:
- logger: logger
- arch: anakin
- system: q_learning/rec_iql
- network: rnn # [rnn, rcnn]
- env: smax # [cleaner, connector, gigastep, lbf, matrax, rware, smax]
- network: rnn # [rnn, rcnn]
- env: smax # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe]

hydra:
searchpath:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_mappo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/rec_mappo
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/rec_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/rec_sable
- network: rec_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe]
- _self_

hydra:
5 changes: 4 additions & 1 deletion mava/configs/env/connector.yaml
@@ -4,7 +4,10 @@ defaults:
- scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a]
# Further environment config details in "con-10x10x5a" file.

env_name: MaConnector # Used for logging purposes.
env_name: Connector # Used for logging purposes.

# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
use_individual_rewards: False # If True, use the list of individual rewards.

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
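Note: the new `use_individual_rewards` flag switches between the per-agent reward list and a single aggregated team reward. A hypothetical sketch of what that aggregation could look like in a wrapper (the wrapper logic itself is not part of this diff, and whether the team reward is a mean or a sum is an assumption here):

```python
import jax.numpy as jnp


def aggregate_rewards(rewards: jnp.ndarray, use_individual_rewards: bool) -> jnp.ndarray:
    """Return per-agent rewards unchanged, or broadcast an aggregated team reward to all agents.

    rewards: array of shape (num_agents,), one reward per agent.
    """
    if use_individual_rewards:
        return rewards
    team_reward = jnp.mean(rewards)  # assumed aggregation; could equally be a sum
    return jnp.full_like(rewards, team_reward)
```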
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-10x10x10a.yaml
@@ -1,5 +1,5 @@
# The config of the 10x10x10a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-10x10x10a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-15x15x23a.yaml
@@ -1,5 +1,5 @@
# The config of the 15x15x23a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-15x15x23a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-5x5x3a.yaml
@@ -1,5 +1,5 @@
# The config of the 5x5x3a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-5x5x3a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-7x7x5a.yaml
@@ -1,5 +1,5 @@
# The config of the 7x7x5a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-7x7x5a

task_config:
5 changes: 4 additions & 1 deletion mava/configs/env/vector-connector.yaml
@@ -4,7 +4,10 @@ defaults:
- scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a]
# Further environment config details in "con-10x10x5a" file.

env_name: VectorMaConnector # Used for logging purposes.
env_name: VectorConnector # Used for logging purposes.

# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
use_individual_rewards: True # If True, use the list of individual rewards.

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
2 changes: 1 addition & 1 deletion mava/systems/mat/anakin/mat.py
@@ -352,7 +352,7 @@ def learner_setup(
key, actor_net_key = keys

# Initialise observation: Obs for all agents.
init_x = env.observation_spec().generate_value()
init_x = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[None, ...], init_x)

_, action_space_type = get_action_head(env.action_spec())
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/ff_ippo.py
@@ -382,7 +382,7 @@ def learner_setup(
)

# Initialise observation with obs of all agents.
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/ff_mappo.py
@@ -366,7 +366,7 @@ def learner_setup(
)

# Initialise observation with obs of all agents.
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/rec_ippo.py
@@ -487,7 +487,7 @@ def learner_setup(
)

# Initialise observation with obs of all agents.
init_obs = env.observation_spec().generate_value()
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(
lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0),
init_obs,
2 changes: 1 addition & 1 deletion mava/systems/ppo/anakin/rec_mappo.py
@@ -483,7 +483,7 @@ def learner_setup(
)

# Initialise observation with obs of all agents.
init_obs = env.observation_spec().generate_value()
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(
lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0),
init_obs,
4 changes: 2 additions & 2 deletions mava/systems/q_learning/anakin/rec_iql.py
@@ -92,7 +92,7 @@ def replicate(x: Any) -> Any:
# N: Agent

# Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...)
init_obs = env.observation_spec().generate_value() # (N, ...)
init_obs = env.observation_spec.generate_value() # (N, ...)
# (B, T, N, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
@@ -130,7 +130,7 @@ def replicate(x: Any) -> Any:
init_hidden_state = replicate(init_hidden_state)

# Create dummy transition
init_acts = env.action_spec().generate_value() # (N,)
init_acts = env.action_spec.generate_value() # (N,)
init_transition = Transition(
obs=init_obs, # (N, ...)
action=init_acts,
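Note: as the comments in this hunk say, the recurrent Q-network is initialised with dummy inputs of shape (T, B, N, ...), built directly from the property-style specs. A condensed sketch of that shaping (assuming `env` exposes the new spec properties):

```python
import jax.numpy as jnp
from jax import tree

init_obs = env.observation_spec.generate_value()  # (N, ...): one entry per agent
# Add singleton time and batch axes in front -> (1, 1, N, ...).
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool)  # (T, B, 1)

# Dummy joint action for the replay transition, also taken from the spec.
init_acts = env.action_spec.generate_value()  # (N,)
```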
6 changes: 3 additions & 3 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -94,7 +94,7 @@ def replicate(x: Any) -> Any:
# N: Agent

# Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...)
init_obs = env.observation_spec().generate_value() # (N, ...)
init_obs = env.observation_spec.generate_value() # (N, ...)
# (B, T, N, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
@@ -126,7 +126,7 @@ def replicate(x: Any) -> Any:
dtype=float,
)
global_env_state_shape = (
env.observation_spec().generate_value().global_state[0, :].shape
env.observation_spec.generate_value().global_state[0, :].shape
) # NOTE: Env wrapper currently duplicates env state for each agent
dummy_global_env_state = jnp.zeros(
(
@@ -159,7 +159,7 @@ def replicate(x: Any) -> Any:
opt_state = replicate(opt_state)
init_hidden_state = replicate(init_hidden_state)

init_acts = env.action_spec().generate_value()
init_acts = env.action_spec.generate_value()

# NOTE: term_or_trunc refers to the joint done, i.e. when all agents are done or when the
# episode horizon has been reached. We use this exclusively in QMIX.
2 changes: 1 addition & 1 deletion mava/systems/sable/anakin/ff_sable.py
@@ -417,7 +417,7 @@ def learner_setup(
)

# Get mock inputs to initialise network.
init_obs = env.observation_spec().generate_value()
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(lambda x: x[jnp.newaxis, ...], init_obs) # Add batch dim
init_hs = get_init_hidden_state(config.network.net_config, config.arch.num_envs)
init_hs = tree.map(lambda x: x[0, jnp.newaxis], init_hs)
2 changes: 1 addition & 1 deletion mava/systems/sable/anakin/rec_sable.py
@@ -449,7 +449,7 @@ def learner_setup(
)

# Get mock inputs to initialise network.
init_obs = env.observation_spec().generate_value()
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(lambda x: x[jnp.newaxis, ...], init_obs) # Add batch dim
init_hs = get_init_hidden_state(config.network.net_config, config.arch.num_envs)
init_hs = tree.map(lambda x: x[0, jnp.newaxis], init_hs)
6 changes: 3 additions & 3 deletions mava/systems/sac/anakin/ff_hasac.py
@@ -144,11 +144,11 @@ def replicate(x: Any) -> Any:
key, actor_key, q1_key, q2_key, q1_target_key, q2_target_key = jax.random.split(key, 6)
actor_keys = jax.random.split(actor_key, n_agents)

acts = env.action_spec().generate_value() # all agents actions
acts = env.action_spec.generate_value() # all agents actions
act_single = acts[0] # single agents action
concat_acts = jnp.concatenate([act_single for _ in range(n_agents)], axis=0)
concat_acts_batched = concat_acts[jnp.newaxis, ...] # batch + concat of all agents actions
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
obs_single_batched = tree.map(lambda x: x[0][jnp.newaxis, ...], obs)

# Making actor network
@@ -285,7 +285,7 @@ def make_update_fns(
actor_net, q_net = networks
actor_opt, q_opt, alpha_opt = optims

full_action_shape = (cfg.arch.num_envs, *env.action_spec().shape)
full_action_shape = (cfg.arch.num_envs, *env.action_spec.shape)

# losses:
def q_loss_fn(
6 changes: 3 additions & 3 deletions mava/systems/sac/anakin/ff_isac.py
@@ -104,9 +104,9 @@ def replicate(x: Any) -> Any:

key, actor_key, q1_key, q2_key, q1_target_key, q2_target_key = jax.random.split(key, 6)

acts = env.action_spec().generate_value() # all agents actions
acts = env.action_spec.generate_value() # all agents actions
act_single_batched = acts[0][jnp.newaxis, ...] # batch single agent action
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
obs_single_batched = tree.map(lambda x: x[0][jnp.newaxis, ...], obs)

# Making actor network
@@ -242,7 +242,7 @@ def make_update_fns(
actor_net, q_net = networks
actor_opt, q_opt, alpha_opt = optims

full_action_shape = (cfg.arch.num_envs, *env.action_spec().shape)
full_action_shape = (cfg.arch.num_envs, *env.action_spec.shape)

# losses:
def q_loss_fn(
6 changes: 3 additions & 3 deletions mava/systems/sac/anakin/ff_masac.py
@@ -105,11 +105,11 @@ def replicate(x: Any) -> Any:

key, actor_key, q1_key, q2_key, q1_target_key, q2_target_key = jax.random.split(key, 6)

acts = env.action_spec().generate_value() # all agents actions
acts = env.action_spec.generate_value() # all agents actions
act_single = acts[0] # single agents action
joint_acts = jnp.concatenate([act_single for _ in range(n_agents)], axis=0)
joint_acts_batched = joint_acts[jnp.newaxis, ...] # joint actions with a batch dim
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
obs_single_batched = tree.map(lambda x: x[0][jnp.newaxis, ...], obs)

# Making actor network
@@ -245,7 +245,7 @@ def make_update_fns(
actor_net, q_net = networks
actor_opt, q_opt, alpha_opt = optims

full_action_shape = (cfg.arch.num_envs, *env.action_spec().shape)
full_action_shape = (cfg.arch.num_envs, *env.action_spec.shape)

# losses:
def q_loss_fn(
5 changes: 5 additions & 0 deletions mava/types.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cached_property
from typing import Any, Callable, Dict, Generic, Optional, Protocol, Tuple, TypeVar, Union

import chex
@@ -67,6 +68,7 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep]:
"""
...

@cached_property
def observation_spec(self) -> specs.Spec:
"""Returns the observation spec.

@@ -75,6 +77,7 @@ def observation_spec(self) -> specs.Spec:
"""
...

@cached_property
def action_spec(self) -> specs.Spec:
"""Returns the action spec.

@@ -83,6 +86,7 @@ def action_spec(self) -> specs.Spec:
"""
...

@cached_property
def reward_spec(self) -> specs.Array:
"""Describes the reward returned by the environment. By default, this is assumed to be a
single float.
@@ -92,6 +96,7 @@ def reward_spec(self) -> specs.Array:
"""
...

@cached_property
def discount_spec(self) -> specs.BoundedArray:
"""Describes the discount returned by the environment. By default, this is assumed to be a
single float between 0 and 1.
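Note: the protocol in `mava/types.py` now declares the specs as cached properties, matching the new Jumanji interface. A stripped-down sketch of an environment that satisfies this shape of interface (a hypothetical toy env, not code from the PR; it assumes `jumanji.specs` for the spec classes):

```python
from functools import cached_property

import jax.numpy as jnp
from jumanji import specs


class ToyEnv:
    """Toy environment whose specs are computed once and then cached."""

    num_agents: int = 3

    @cached_property
    def observation_spec(self) -> specs.Array:
        # Built on first access, then cached on the instance.
        return specs.Array(shape=(self.num_agents, 4), dtype=jnp.float32, name="observation")

    @cached_property
    def action_spec(self) -> specs.DiscreteArray:
        return specs.DiscreteArray(num_values=5, name="action")


env = ToyEnv()
obs = env.observation_spec.generate_value()  # property access, no call parentheses
act = env.action_spec.generate_value()
```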