Skip to content

Commit

Permalink
[Feature] Add Stack transform
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler committed Nov 27, 2024
1 parent d537dcb commit a74d3a4
Show file tree
Hide file tree
Showing 7 changed files with 757 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ conda deactivate && conda activate ./env
python -c "import mlagents_envs"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_transforms.py --instafail -v --durations 200 --capture no -k test_transform_env[unity]

coverage combine
coverage xml -i
157 changes: 155 additions & 2 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
from typing import Dict, List, Optional

import torch
import torch.nn as nn
Expand All @@ -24,7 +24,12 @@
from torchrl.data.utils import consolidate_spec
from torchrl.envs.common import EnvBase
from torchrl.envs.model_based.common import ModelBasedEnvBase
from torchrl.envs.utils import _terminated_or_truncated
from torchrl.envs.utils import (
_terminated_or_truncated,
check_marl_grouping,
MarlGroupMapType,
)


spec_dict = {
"bounded": Bounded,
Expand Down Expand Up @@ -1059,6 +1064,154 @@ def _step(
return tensordict


class MultiAgentCountingEnv(EnvBase):
"""A multi-agent env that is done after a given number of steps.
All agents have identical specs.
The count is incremented by 1 on each step.
"""

def __init__(
self,
n_agents: int,
group_map: MarlGroupMapType
| Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP,
max_steps: int = 5,
start_val: int = 0,
**kwargs,
):
super().__init__(**kwargs)
self.max_steps = max_steps
self.start_val = start_val
self.n_agents = n_agents
self.agent_names = [f"agent_{idx}" for idx in range(n_agents)]

if isinstance(group_map, MarlGroupMapType):
group_map = group_map.get_group_map(self.agent_names)
check_marl_grouping(group_map, self.agent_names)

self.group_map = group_map

observation_specs = {}
reward_specs = {}
done_specs = {}
action_specs = {}

for group_name, agents in group_map.items():
observation_specs[group_name] = {}
reward_specs[group_name] = {}
done_specs[group_name] = {}
action_specs[group_name] = {}

for agent_name in agents:
observation_specs[group_name][agent_name] = Composite(
observation=Unbounded(
(
*self.batch_size,
3,
4,
),
dtype=torch.float32,
device=self.device,
),
shape=self.batch_size,
device=self.device,
)
reward_specs[group_name][agent_name] = Composite(
reward=Unbounded(
(
*self.batch_size,
1,
),
device=self.device,
),
shape=self.batch_size,
device=self.device,
)
done_specs[group_name][agent_name] = Composite(
done=Categorical(
2,
dtype=torch.bool,
shape=(
*self.batch_size,
1,
),
device=self.device,
),
shape=self.batch_size,
device=self.device,
)
action_specs[group_name][agent_name] = Composite(
action=Binary(n=1, shape=[*self.batch_size, 1], device=self.device),
shape=self.batch_size,
device=self.device,
)

self.observation_spec = Composite(observation_specs)
self.reward_spec = Composite(reward_specs)
self.done_spec = Composite(done_specs)
self.action_spec = Composite(action_specs)
self.register_buffer(
"count",
torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int),
)

def _set_seed(self, seed: Optional[int]):
torch.manual_seed(seed)

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
if tensordict is not None and "_reset" in tensordict.keys():
_reset = tensordict.get("_reset")
self.count[_reset] = self.start_val
else:
self.count[:] = self.start_val

source = {}
for group_name, agents in self.group_map.items():
source[group_name] = {}
for agent_name in agents:
source[group_name][agent_name] = TensorDict(
source={
"observation": torch.rand(
(*self.batch_size, 3, 4), device=self.device
),
"done": self.count > self.max_steps,
"terminated": self.count > self.max_steps,
},
batch_size=self.batch_size,
device=self.device,
)

tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device)
return tensordict

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
self.count += 1
source = {}
for group_name, agents in self.group_map.items():
source[group_name] = {}
for agent_name in agents:
source[group_name][agent_name] = TensorDict(
source={
"observation": torch.rand(
(*self.batch_size, 3, 4), device=self.device
),
"done": self.count > self.max_steps,
"terminated": self.count > self.max_steps,
"reward": torch.zeros_like(self.count, dtype=torch.float),
},
batch_size=self.batch_size,
device=self.device,
)
tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device)
return tensordict


class IncrementingEnv(CountingEnv):
# Same as CountingEnv but always increments the count by 1 regardless of the action.
def _step(
Expand Down
Loading

0 comments on commit a74d3a4

Please sign in to comment.