Feat: Support latest Jumanji version #1134

Open · wants to merge 17 commits into develop
Changes from 16 commits
4 changes: 2 additions & 2 deletions examples/Quickstart.ipynb
@@ -537,7 +537,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"id": "eWjNSGvZ7ALw"
},
@@ -573,7 +573,7 @@
" )\n",
"\n",
" # Initialise observation with obs of all agents.\n",
" obs = env.observation_spec().generate_value()\n",
" obs = env.observation_spec.generate_value()\n",
" init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)\n",
"\n",
" # Initialise actor params and optimiser state.\n",
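The recurring change in this PR is that the latest Jumanji exposes `observation_spec` and `action_spec` as properties rather than methods, so the trailing parentheses are dropped throughout. A minimal sketch of the new usage, assuming a recent Jumanji release and the `Connector-v2` registration used by the updated scenario configs:

```python
import jax.numpy as jnp
from jax import tree_util
import jumanji

# Assumption: a Jumanji version where specs are properties and "Connector-v2" is registered.
env = jumanji.make("Connector-v2")

# Old API (method call): env.observation_spec().generate_value()
# New API (property):    env.observation_spec.generate_value()
obs = env.observation_spec.generate_value()
act = env.action_spec.generate_value()

# Add a leading batch dimension before network initialisation, as in the diff
# (the diff's `tree.map` is the same pytree mapping as `tree_util.tree_map`).
init_x = tree_util.tree_map(lambda x: x[jnp.newaxis, ...], obs)
```

If `Connector-v2` is not available in the installed Jumanji version, any registered environment id works the same way for this snippet.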
2 changes: 1 addition & 1 deletion mava/advanced_usage/README.md
@@ -12,7 +12,7 @@ dummy_flashbax_transition = {
"observation": jnp.zeros(
(
config.system.num_agents,
env.observation_spec().agents_view.shape[1],
env.observation_spec.agents_view.shape[1],
),
dtype=jnp.float32,
),
6 changes: 3 additions & 3 deletions mava/advanced_usage/ff_ippo_store_experience.py
@@ -361,7 +361,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

@@ -378,7 +378,7 @@
)

# Initialise observation with obs of all agents.
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
@@ -508,7 +508,7 @@ def run_experiment(_config: DictConfig) -> None:
"observation": jnp.zeros(
(
config.system.num_agents,
env.observation_spec().agents_view.shape[1],
env.observation_spec.agents_view.shape[1],
),
dtype=jnp.float32,
),
2 changes: 1 addition & 1 deletion mava/configs/default/ff_ippo.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: ppo/ff_ippo
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/ff_mappo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/ff_mappo
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/ff_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/ff_sable
- network: ff_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/mat.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: mat/mat
- network: transformer
- env: rware # [gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_ippo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/rec_ippo
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_iql.yaml
@@ -3,8 +3,8 @@ defaults:
- logger: logger
- arch: anakin
- system: q_learning/rec_iql
- network: rnn # [rnn, rcnn]
- env: smax # [cleaner, connector, gigastep, lbf, matrax, rware, smax]
- network: rnn # [rnn, rcnn]
- env: smax # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax]

hydra:
searchpath:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_mappo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/rec_mappo
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/rec_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/rec_sable
- network: rec_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe]
- _self_

hydra:
5 changes: 4 additions & 1 deletion mava/configs/env/connector.yaml
@@ -4,7 +4,10 @@ defaults:
- scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a]
# Further environment config details in "con-10x10x5a" file.

env_name: MaConnector # Used for logging purposes.
env_name: Connector # Used for logging purposes.

# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
use_individual_rewards: False # If True, use the list of individual rewards.

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-10x10x10a.yaml
@@ -1,5 +1,5 @@
# The config of the 10x10x10a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-10x10x10a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-15x15x23a.yaml
@@ -1,5 +1,5 @@
# The config of the 15x15x23a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-15x15x23a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-5x5x3a.yaml
@@ -1,5 +1,5 @@
# The config of the 5x5x3a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-5x5x3a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-7x7x5a.yaml
@@ -1,5 +1,5 @@
# The config of the 7x7x5a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-7x7x5a

task_config:
5 changes: 4 additions & 1 deletion mava/configs/env/vector-connector.yaml
@@ -4,7 +4,10 @@ defaults:
- scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a]
# Further environment config details in "con-10x10x5a" file.

env_name: VectorMaConnector # Used for logging purposes.
env_name: VectorConnector # Used for logging purposes.

# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
use_individual_rewards: True # If True, use the list of individual rewards.

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
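Both connector configs now expose a `use_individual_rewards` flag: `connector.yaml` defaults to the aggregated team reward, while `vector-connector.yaml` keeps the per-agent rewards. A rough illustration of the two modes, with invented values and without claiming this is the wrapper's exact aggregation rule:

```python
import jax.numpy as jnp

# Hypothetical per-agent rewards for one step of a 3-agent connector episode.
individual_rewards = jnp.array([0.0, 1.0, 0.5])

use_individual_rewards = False  # connector.yaml default; vector-connector.yaml sets True

if use_individual_rewards:
    rewards = individual_rewards  # each agent keeps its own reward
else:
    # Aggregation shown as a sum purely for illustration; every agent then
    # receives the same team reward.
    rewards = jnp.full_like(individual_rewards, jnp.sum(individual_rewards))
```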
4 changes: 2 additions & 2 deletions mava/systems/mat/anakin/mat.py
@@ -352,10 +352,10 @@ def learner_setup(
key, actor_net_key = keys

# Initialise observation: Obs for all agents.
init_x = env.observation_spec().generate_value()
init_x = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[None, ...], init_x)

_, action_space_type = get_action_head(env.action_spec())
_, action_space_type = get_action_head(env.action_spec)

if action_space_type == "discrete":
init_action = jnp.zeros((1, config.system.num_agents), dtype=jnp.int32)
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/ff_ippo.py
@@ -362,7 +362,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

@@ -382,7 +382,7 @@
)

# Initialise observation with obs of all agents.
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/ff_mappo.py
@@ -346,7 +346,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

@@ -366,7 +366,7 @@
)

# Initialise observation with obs of all agents.
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/rec_ippo.py
@@ -457,7 +457,7 @@ def learner_setup(
# Define network and optimisers.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
@@ -487,7 +487,7 @@
)

# Initialise observation with obs of all agents.
init_obs = env.observation_spec().generate_value()
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(
lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0),
init_obs,
4 changes: 2 additions & 2 deletions mava/systems/ppo/anakin/rec_mappo.py
@@ -452,7 +452,7 @@ def learner_setup(
# Define network and optimiser.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
@@ -483,7 +483,7 @@
)

# Initialise observation with obs of all agents.
init_obs = env.observation_spec().generate_value()
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(
lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0),
init_obs,
4 changes: 2 additions & 2 deletions mava/systems/q_learning/anakin/rec_iql.py
@@ -92,7 +92,7 @@ def replicate(x: Any) -> Any:
# N: Agent

# Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...)
init_obs = env.observation_spec().generate_value() # (N, ...)
init_obs = env.observation_spec.generate_value() # (N, ...)
# (B, T, N, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
@@ -130,7 +130,7 @@ def replicate(x: Any) -> Any:
init_hidden_state = replicate(init_hidden_state)

# Create dummy transition
init_acts = env.action_spec().generate_value() # (N,)
init_acts = env.action_spec.generate_value() # (N,)
init_transition = Transition(
obs=init_obs, # (N, ...)
action=init_acts,
6 changes: 3 additions & 3 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -94,7 +94,7 @@ def replicate(x: Any) -> Any:
# N: Agent

# Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...)
init_obs = env.observation_spec().generate_value() # (N, ...)
init_obs = env.observation_spec.generate_value() # (N, ...)
# (B, T, N, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
@@ -126,7 +126,7 @@ def replicate(x: Any) -> Any:
dtype=float,
)
global_env_state_shape = (
env.observation_spec().generate_value().global_state[0, :].shape
env.observation_spec.generate_value().global_state[0, :].shape
) # NOTE: Env wrapper currently duplicates env state for each agent
dummy_global_env_state = jnp.zeros(
(
@@ -159,7 +159,7 @@ def replicate(x: Any) -> Any:
opt_state = replicate(opt_state)
init_hidden_state = replicate(init_hidden_state)

init_acts = env.action_spec().generate_value()
init_acts = env.action_spec.generate_value()

# NOTE: term_or_trunc refers to the joint done, i.e. when all agents are done or when the
# episode horizon has been reached. We use this exclusively in QMIX.
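The `term_or_trunc` note above describes a joint done flag: the episode only counts as finished once every agent is terminated or truncated (which covers the episode horizon). A small sketch of that reduction, illustrative rather than the wrapper's actual code:

```python
import jax.numpy as jnp

# Hypothetical per-agent flags at one step of a 3-agent episode.
terminated = jnp.array([True, False, True])
truncated = jnp.array([False, True, False])  # e.g. set for every agent at the horizon

# Joint done: every agent is either terminated or truncated.
term_or_trunc = jnp.all(terminated | truncated)  # -> Array(True)
```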
10 changes: 3 additions & 7 deletions mava/systems/sable/anakin/ff_sable.py
@@ -372,17 +372,13 @@ def learner_setup(
# Get available TPU cores.
n_devices = len(jax.devices())

# Get number of agents.
config.system.num_agents = env.num_agents

# PRNG keys.
key, net_key = keys

# Get number of agents and actions.
action_dim = env.action_dim
n_agents = env.action_spec().shape[0]
n_agents = env.num_agents
config.system.num_agents = n_agents
config.system.num_actions = action_dim

# Setting the chunksize - many agent problems require chunking agents
# Create a dummy decay factor for FF Sable
@@ -397,7 +393,7 @@
# Set positional encoding to False, since ff-sable does not use temporal dependencies.
config.network.memory_config.timestep_positional_encoding = False

_, action_space_type = get_action_head(env.action_spec())
_, action_space_type = get_action_head(env.action_spec)

# Define network.
sable_network = SableNetwork(
@@ -417,7 +413,7 @@
)

# Get mock inputs to initialise network.
init_obs = env.observation_spec().generate_value()
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(lambda x: x[jnp.newaxis, ...], init_obs) # Add batch dim
init_hs = get_init_hidden_state(config.network.net_config, config.arch.num_envs)
init_hs = tree.map(lambda x: x[0, jnp.newaxis], init_hs)
7 changes: 3 additions & 4 deletions mava/systems/sable/anakin/rec_sable.py
@@ -417,9 +417,8 @@ def learner_setup(

# Get number of agents and actions.
action_dim = env.action_dim
n_agents = env.action_spec().shape[0]
n_agents = env.num_agents
config.system.num_agents = n_agents
config.system.num_actions = action_dim

# Setting the chunksize - smaller chunks save memory at the cost of speed
if config.network.memory_config.timestep_chunk_size:
@@ -429,7 +428,7 @@
else:
config.network.memory_config.chunk_size = config.system.rollout_length * n_agents

_, action_space_type = get_action_head(env.action_spec())
_, action_space_type = get_action_head(env.action_spec)

# Define network.
sable_network = SableNetwork(
Expand All @@ -449,7 +448,7 @@ def learner_setup(
)

# Get mock inputs to initialise network.
init_obs = env.observation_spec().generate_value()
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(lambda x: x[jnp.newaxis, ...], init_obs) # Add batch dim
init_hs = get_init_hidden_state(config.network.net_config, config.arch.num_envs)
init_hs = tree.map(lambda x: x[0, jnp.newaxis], init_hs)
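For the Sable chunking touched above, the visible fallback in `rec_sable.py` sets `chunk_size = rollout_length * n_agents`, and the comment notes that smaller chunks save memory at the cost of speed. A quick worked example with assumed values (the configured `timestep_chunk_size` branch is collapsed in this diff, so only the fallback is shown):

```python
# Assumed values, for illustration only.
rollout_length = 128  # config.system.rollout_length
n_agents = 4          # env.num_agents

# Fallback from rec_sable.py: one chunk spans the whole rollout for all agents.
chunk_size = rollout_length * n_agents  # 128 * 4 = 512 agent-timestep entries per chunk
```

When `timestep_chunk_size` is set in the memory config, the chunk size is derived from that value instead, giving smaller chunks.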