Commit
change to parallel env
jjshoots committed Dec 4, 2023
1 parent 2acf712 commit f210d69
Showing 4 changed files with 118 additions and 158 deletions.
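The change migrates MAQuadXBaseEnv from PettingZoo's AECEnv (turn-based agent stepping) to ParallelEnv, where every live agent submits an action each step and receives its own observation, reward, termination, truncation, and info entry. A minimal sketch of driving the converted MAQuadXHoverEnv under the parallel API (the random-action loop is illustrative and not part of this commit):

    from PyFlyt.pz_envs import MAQuadXHoverEnv

    env = MAQuadXHoverEnv()
    observations, infos = env.reset(seed=42)

    # the parallel API takes one action per live agent, keyed by agent name
    while env.agents:
        actions = {agent: env.action_space(agent).sample() for agent in env.agents}
        observations, rewards, terminations, truncations, infos = env.step(actions)

    env.close()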
191 changes: 91 additions & 100 deletions PyFlyt/pz_envs/quadx_envs/ma_quadx_base_env.py
@@ -1,41 +1,30 @@
"""Base Multiagent QuadX Environment for use with the Pettingzoo API."""
from __future__ import annotations

"""Base Multiagent QuadX Environment."""
from copy import deepcopy
from typing import Any

import numpy as np
import pybullet as p
from gymnasium import Space, spaces
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo import ParallelEnv

from PyFlyt.core import Aviary


class MAQuadXBaseEnv(AECEnv):
"""Base Multiagent QuadX Environment for use with the Pettingzoo API.
Args:
start_pos (np.ndarray): start_pos
start_orn (np.ndarray): start_orn
flight_dome_size (float): flight_dome_size
max_duration_seconds (float): max_duration_seconds
angle_representation (str): angle_representation
agent_hz (int): agent_hz
render_mode (None | str): render_mode
"""

metadata = {"render_modes": ["human"], "name": "ma_quadx_hover"}
class MAQuadXBaseEnv(ParallelEnv):
"""MAQuadXBaseEnv."""

def __init__(
self,
start_pos: np.ndarray = np.array([[0.0, 0.0, 1.0]]),
start_orn: np.ndarray = np.array([[0.0, 0.0, 0.0]]),
flight_dome_size: float = np.inf,
max_duration_seconds: float = 10.0,
angle_representation: str = "quaternion",
agent_hz: int = 30,
start_pos: np.ndarray = np.array(
[[-1.0, -1.0, 1.0], [1.0, -1.0, 0.0], [-1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
),
start_orn: np.ndarray = np.array(
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
),
flight_dome_size: float = 10.0,
max_duration_seconds: float = 30.0,
angle_representation: str = "euler",
agent_hz: int = 40,
render_mode: None | str = None,
):
"""__init__.
@@ -154,11 +143,14 @@ def __init__(
)
)

def observation_space(self, _):
def observation_space(self, _) -> Space:
"""observation_space.
Args:
_:
Returns:
Space:
"""
raise NotImplementedError

@@ -169,38 +161,24 @@ def action_space(self, _) -> spaces.Box:
_:
Returns:
Space:
spaces.Box:
"""
return self._action_space

def observe(self, agent: str) -> np.ndarray:
"""observe.
Args:
agent (str): agent
"""
agent_id = self.agent_name_mapping[agent]
return self.observe_by_id(agent_id)

def observe_by_id(self, agent_id: int):
"""observe_by_id.
Args:
agent_id (int): agent_id
"""
raise NotImplementedError

def close(self):
"""close."""
if hasattr(self, "aviary"):
self.aviary.disconnect()

def reset(self, seed=None, options=dict()):
def reset(self, seed=None, options=dict()) -> tuple[dict[str, Any], dict[str, Any]]:
"""reset.
Args:
seed: seed
options: options
seed:
options:
Returns:
tuple[dict[str, Any], dict[str, Any]]: observation and infos
"""
raise NotImplementedError

@@ -211,19 +189,6 @@ def begin_reset(self, seed=None, options=dict()):
self.aviary.disconnect()
self.step_count = 0
self.agents = self.possible_agents[:]
self.rewards = {agent: 0.0 for agent in self.agents}
self._cumulative_rewards = {agent: 0.0 for agent in self.agents}
self.terminations = {agent: False for agent in self.agents}
self.truncations = {agent: False for agent in self.agents}
self.infos = {agent: dict() for agent in self.agents}
for agent in self.agents:
self.infos[agent]["out_of_bounds"] = False
self.infos[agent]["collision"] = False
self.infos[agent]["env_complete"] = False

# Our agent_selector utility allows easy cyclic stepping through the agents list.
self._agent_selector = agent_selector(self.agents)
self.agent_selection = self._agent_selector.next()

# rebuild the environment
self.aviary = Aviary(
@@ -275,6 +240,17 @@ def compute_attitude_by_id(

return ang_vel, ang_pos, lin_vel, lin_pos, quaternion

def compute_observation_by_id(self, agent_id: int) -> Any:
"""compute_observation_by_id.
Args:
agent_id (int): agent_id
Returns:
Any:
"""
raise NotImplementedError

def compute_base_term_trunc_reward_info_by_id(
self, agent_id: int
) -> tuple[bool, bool, float, dict[str, Any]]:
@@ -315,49 +291,64 @@ def compute_term_trunc_reward_info_by_id(
"""
raise NotImplementedError

def step(self, action: np.ndarray):
def step(
self, actions: dict[str, np.ndarray]
) -> tuple[
dict[str, Any],
dict[str, float],
dict[str, bool],
dict[str, bool],
dict[str, dict[str, Any]],
]:
"""step.
Args:
action (np.ndarray): action
"""
agent = self.agent_selection
actions (dict[str, np.ndarray]): actions
# terminate if agent is dead
if self.terminations[agent] or self.truncations[agent]:
self._was_dead_step(action)
return
Returns:
tuple[dict[str, Any], dict[str, float], dict[str, bool], dict[str, bool], dict[str, dict[str, Any]]]:
"""
# copy over the past actions
self.past_actions = deepcopy(self.current_actions)

# set the new actions and send to aviary
for k, v in actions.items():
self.current_actions[self.agent_name_mapping[k]] = v
self.aviary.set_all_setpoints(self.current_actions)

# observation and rewards dictionary
observations = dict()
terminations = {k: False for k in self.agents}
truncations = {k: False for k in self.agents}
rewards = {k: 0.0 for k in self.agents}
infos = {k: dict() for k in self.agents}

# step enough times for one RL step
for _ in range(self.env_step_ratio):
self.aviary.step()

# set the actions, clear the agent's cumulative rewards since it's seen it
self.current_actions[self.agent_name_mapping[agent]] = action
self._cumulative_rewards[agent] = 0
# update reward, term, trunc, for each agent
for ag in self.agents:
# compute term trunc reward
term, trunc, rew, info = self.compute_term_trunc_reward_info_by_id(
self.agent_name_mapping[ag]
)
terminations[ag] |= term
truncations[ag] |= trunc
rewards[ag] += rew
infos[ag] = {**infos[ag], **info}

# compute observations at the end
observations = {
ag: self.compute_observation_by_id(self.agent_name_mapping[ag])
for ag in self.agents
}

# cull dead agents for the next round
self.agents = [
agent
for agent in self.agents
if not (terminations[agent] or truncations[agent])
]

# environment logic
if not self._agent_selector.is_last():
# don't do anything if all agents haven't acted
self._clear_rewards()
else:
# collect reward if it is the last agent to act
self.aviary.set_all_setpoints(self.current_actions)
self.past_actions = deepcopy(self.current_actions)

# step enough times for one RL step
for _ in range(self.env_step_ratio):
self.aviary.step()

# update reward, term, trunc, for each agent
for ag in self.agents:
ag_id = self.agent_name_mapping[ag]

# compute term trunc reward
term, trunc, rew, info = self.compute_term_trunc_reward_info_by_id(
ag_id
)
self.terminations[ag] |= term
self.truncations[ag] |= trunc
self.rewards[ag] += rew
self.infos[ag] = {**self.infos[ag], **info}

# accumulate rewards and select next agent
self.agent_selection = self._agent_selector.next()
self._accumulate_rewards()
return observations, rewards, terminations, truncations, infos
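The rewritten step() depends on two per-agent hooks that concrete environments must override: compute_observation_by_id and compute_term_trunc_reward_info_by_id, both called once per live agent after the aviary has been stepped. A hypothetical minimal subclass is sketched below; the hook names and the helpers compute_attitude_by_id and compute_base_term_trunc_reward_info_by_id come from this file, while the observation and reward logic is a placeholder (a complete subclass, like MAQuadXHoverEnv below, also defines its observation space and reset):

    import numpy as np

    from PyFlyt.pz_envs.quadx_envs.ma_quadx_base_env import MAQuadXBaseEnv


    class MyMultiQuadEnv(MAQuadXBaseEnv):
        """Illustrative subclass; observation and reward shaping are placeholders."""

        def compute_observation_by_id(self, agent_id: int) -> np.ndarray:
            # reuse the base class attitude helper for a simple state observation
            ang_vel, ang_pos, lin_vel, lin_pos, quaternion = self.compute_attitude_by_id(agent_id)
            return np.concatenate([ang_vel, ang_pos, lin_vel, lin_pos], axis=-1)

        def compute_term_trunc_reward_info_by_id(self, agent_id: int):
            # start from the shared out-of-bounds / collision / timeout checks...
            term, trunc, reward, info = self.compute_base_term_trunc_reward_info_by_id(agent_id)
            # ...then add task-specific reward shaping (placeholder here)
            reward += 0.0
            return term, trunc, reward, info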
19 changes: 15 additions & 4 deletions PyFlyt/pz_envs/quadx_envs/ma_quadx_hover_env.py
@@ -26,7 +26,11 @@ class MAQuadXHoverEnv(MAQuadXBaseEnv):
render_mode (None | str): render_mode
"""

metadata = {"render_modes": ["human"], "name": "ma_quadx_hover", "is_parallelizable": True}
metadata = {
"render_modes": ["human"],
"name": "ma_quadx_hover",
"is_parallelizable": True,
}

def __init__(
self,
@@ -82,7 +86,7 @@ def observation_space(self, _):
"""
return self._observation_space

def reset(self, seed=None, options=dict()):
def reset(self, seed=None, options=dict()) -> tuple[dict[str, Any], dict[str, Any]]:
"""reset.
Args:
@@ -92,8 +96,15 @@ def reset(self, seed=None, options=dict()):
super().begin_reset(seed, options)
super().end_reset(seed, options)

def observe_by_id(self, agent_id: int) -> np.ndarray:
"""observe_by_id.
observations = {
ag: self.compute_observation_by_id(self.agent_name_mapping[ag])
for ag in self.agents
}
infos = {ag: dict() for ag in self.agents}
return observations, infos

def compute_observation_by_id(self, agent_id: int) -> np.ndarray:
"""compute_observation_by_id.
Args:
agent_id (int): agent_id
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "PyFlyt"
version = "0.14.1"
version = "0.15.0"
authors = [
{ name="Jet", email="[email protected]" },
]
64 changes: 11 additions & 53 deletions tests/test_pz_envs.py
@@ -1,10 +1,11 @@
"""Tests the API compatibility of all PyFlyt Pettingzoo Environments."""
import warnings
from typing import Any

import pytest
from gymnasium.utils.env_checker import data_equivalence
from pettingzoo.test import api_test
from pettingzoo.utils import wrappers
from pettingzoo import ParallelEnv
from pettingzoo.test import parallel_api_test
from pettingzoo.test.seed_test import check_environment_deterministic_parallel

from PyFlyt.pz_envs import MAQuadXHoverEnv

@@ -41,14 +42,13 @@


@pytest.mark.parametrize("env_config", _ALL_ENV_CONFIGS)
def test_check_env(env_config):
def test_check_env(env_config: tuple[ParallelEnv, dict[str, Any]]):
"""Check that environment pass the pettingzoo api_test."""
env = env_config[0](**env_config[1])
env = wrappers.OrderEnforcingWrapper(env)
env = env_config[0](**env_config[1]) # pyright: ignore

with warnings.catch_warnings(record=True) as caught_warnings:
print(caught_warnings)
api_test(env)
parallel_api_test(env, num_cycles=1000)

for warning_message in caught_warnings:
assert isinstance(warning_message.message, Warning)
@@ -59,51 +59,9 @@ def test_check_env(env_config):


@pytest.mark.parametrize("env_config", _ALL_ENV_CONFIGS)
def test_seeding(env_config):
def test_seeding(env_config: tuple[ParallelEnv, dict[str, Any]]):
"""Check that two AEC environments execute the same way."""
env1 = env_config[0](**env_config[1])
env2 = env_config[0](**env_config[1])
env1 = wrappers.OrderEnforcingWrapper(env1)
env2 = wrappers.OrderEnforcingWrapper(env2)
env1.reset(seed=42)
env2.reset(seed=42)
env1 = env_config[0](**env_config[1]) # pyright: ignore
env2 = env_config[0](**env_config[1]) # pyright: ignore

for i in env1.agents:
env1.action_space(i).seed(seed=42)
env2.action_space(i).seed(seed=42)
env1.observation_space(i).seed(seed=42)
env2.observation_space(i).seed(seed=42)

iterations = 0
for agent1, agent2 in zip(env1.agent_iter(), env2.agent_iter()):
assert data_equivalence(agent1, agent2), f"Incorrect agent: {agent1} {agent2}"

obs1, reward1, termination1, truncation1, info1 = env1.last()
obs2, reward2, termination2, truncation2, info2 = env2.last()

assert data_equivalence(obs1, obs2), "Incorrect observation"
assert data_equivalence(reward1, reward2), "Incorrect reward."
assert data_equivalence(termination1, termination2), "Incorrect termination."
assert data_equivalence(truncation1, truncation2), "Incorrect truncation."
assert data_equivalence(info1, info2), "Incorrect info."

if termination1 or truncation1:
break

action1 = env1.action_space(agent1).sample()
action2 = env2.action_space(agent2).sample()

assert data_equivalence(
action1, action2
), f"Incorrect actions: {action1} {action2}"

env1.step(action1)
env2.step(action2)

iterations += 1

if iterations >= 100:
break

env1.close()
env2.close()
check_environment_deterministic_parallel(env1, env2, num_cycles=1000)
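For reference, the migrated check in test_check_env is equivalent to running PettingZoo's parallel API test directly against a single environment instance, without the warning capture used in the test file:

    from pettingzoo.test import parallel_api_test

    from PyFlyt.pz_envs import MAQuadXHoverEnv

    env = MAQuadXHoverEnv()
    parallel_api_test(env, num_cycles=1000)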
