From f210d699cb82498b89d1f6dd7be2b9cdfe6093d6 Mon Sep 17 00:00:00 2001 From: Jet Date: Mon, 4 Dec 2023 13:59:45 +0000 Subject: [PATCH] change to parallel env --- .../pz_envs/quadx_envs/ma_quadx_base_env.py | 191 +++++++++--------- .../pz_envs/quadx_envs/ma_quadx_hover_env.py | 19 +- pyproject.toml | 2 +- tests/test_pz_envs.py | 64 +----- 4 files changed, 118 insertions(+), 158 deletions(-) diff --git a/PyFlyt/pz_envs/quadx_envs/ma_quadx_base_env.py b/PyFlyt/pz_envs/quadx_envs/ma_quadx_base_env.py index c009ba9b..d69746a2 100644 --- a/PyFlyt/pz_envs/quadx_envs/ma_quadx_base_env.py +++ b/PyFlyt/pz_envs/quadx_envs/ma_quadx_base_env.py @@ -1,41 +1,30 @@ -"""Base Multiagent QuadX Environment for use with the Pettingzoo API.""" -from __future__ import annotations - +"""Base Multiagent QuadX Environment.""" from copy import deepcopy from typing import Any import numpy as np import pybullet as p from gymnasium import Space, spaces -from pettingzoo import AECEnv -from pettingzoo.utils import agent_selector +from pettingzoo import ParallelEnv from PyFlyt.core import Aviary -class MAQuadXBaseEnv(AECEnv): - """Base Multiagent QuadX Environment for use with the Pettingzoo API. - - Args: - start_pos (np.ndarray): start_pos - start_orn (np.ndarray): start_orn - flight_dome_size (float): flight_dome_size - max_duration_seconds (float): max_duration_seconds - angle_representation (str): angle_representation - agent_hz (int): agent_hz - render_mode (None | str): render_mode - """ - - metadata = {"render_modes": ["human"], "name": "ma_quadx_hover"} +class MAQuadXBaseEnv(ParallelEnv): + """MAQuadXBaseEnv.""" def __init__( self, - start_pos: np.ndarray = np.array([[0.0, 0.0, 1.0]]), - start_orn: np.ndarray = np.array([[0.0, 0.0, 0.0]]), - flight_dome_size: float = np.inf, - max_duration_seconds: float = 10.0, - angle_representation: str = "quaternion", - agent_hz: int = 30, + start_pos: np.ndarray = np.array( + [[-1.0, -1.0, 1.0], [1.0, -1.0, 0.0], [-1.0, 1.0, 0.0], [1.0, 1.0, 0.0]] + ), + start_orn: np.ndarray = np.array( + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + ), + flight_dome_size: float = 10.0, + max_duration_seconds: float = 30.0, + angle_representation: str = "euler", + agent_hz: int = 40, render_mode: None | str = None, ): """__init__. @@ -154,11 +143,14 @@ def __init__( ) ) - def observation_space(self, _): + def observation_space(self, _) -> Space: """observation_space. Args: _: + + Returns: + Space: """ raise NotImplementedError @@ -169,38 +161,24 @@ def action_space(self, _) -> spaces.Box: _: Returns: - Space: + spaces.Box: """ return self._action_space - def observe(self, agent: str) -> np.ndarray: - """observe. - - Args: - agent (str): agent - """ - agent_id = self.agent_name_mapping[agent] - return self.observe_by_id(agent_id) - - def observe_by_id(self, agent_id: int): - """observe_by_id. - - Args: - agent_id (int): agent_id - """ - raise NotImplementedError - def close(self): """close.""" if hasattr(self, "aviary"): self.aviary.disconnect() - def reset(self, seed=None, options=dict()): + def reset(self, seed=None, options=dict()) -> tuple[dict[str, Any], dict[str, Any]]: """reset. Args: - seed: seed - options: options + seed: + options: + + Returns: + tuple[dict[str, Any], dict[str, Any]]: observation and infos """ raise NotImplementedError @@ -211,19 +189,6 @@ def begin_reset(self, seed=None, options=dict()): self.aviary.disconnect() self.step_count = 0 self.agents = self.possible_agents[:] - self.rewards = {agent: 0.0 for agent in self.agents} - self._cumulative_rewards = {agent: 0.0 for agent in self.agents} - self.terminations = {agent: False for agent in self.agents} - self.truncations = {agent: False for agent in self.agents} - self.infos = {agent: dict() for agent in self.agents} - for agent in self.agents: - self.infos[agent]["out_of_bounds"] = False - self.infos[agent]["collision"] = False - self.infos[agent]["env_complete"] = False - - # Our agent_selector utility allows easy cyclic stepping through the agents list. - self._agent_selector = agent_selector(self.agents) - self.agent_selection = self._agent_selector.next() # rebuild the environment self.aviary = Aviary( @@ -275,6 +240,17 @@ def compute_attitude_by_id( return ang_vel, ang_pos, lin_vel, lin_pos, quaternion + def compute_observation_by_id(self, agent_id: int) -> Any: + """compute_observation_by_id. + + Args: + agent_id (int): agent_id + + Returns: + Any: + """ + raise NotImplementedError + def compute_base_term_trunc_reward_info_by_id( self, agent_id: int ) -> tuple[bool, bool, float, dict[str, Any]]: @@ -315,49 +291,64 @@ def compute_term_trunc_reward_info_by_id( """ raise NotImplementedError - def step(self, action: np.ndarray): + def step( + self, actions: dict[str, np.ndarray] + ) -> tuple[ + dict[str, Any], + dict[str, float], + dict[str, bool], + dict[str, bool], + dict[str, dict[str, Any]], + ]: """step. Args: - action (np.ndarray): action - """ - agent = self.agent_selection + actions (dict[str, np.ndarray]): actions - # terminate if agent is dead - if self.terminations[agent] or self.truncations[agent]: - self._was_dead_step(action) - return + Returns: + tuple[dict[str, Any], dict[str, float], dict[str, bool], dict[str, bool], dict[str, dict[str, Any]]]: + """ + # copy over the past actions + self.past_actions = deepcopy(self.current_actions) + + # set the new actions and send to aviary + for k, v in actions.items(): + self.current_actions[self.agent_name_mapping[k]] = v + self.aviary.set_all_setpoints(self.current_actions) + + # observation and rewards dictionary + observations = dict() + terminations = {k: False for k in self.agents} + truncations = {k: False for k in self.agents} + rewards = {k: 0.0 for k in self.agents} + infos = {k: dict() for k in self.agents} + + # step enough times for one RL step + for _ in range(self.env_step_ratio): + self.aviary.step() - # set the actions, clear the agent's cumulative rewards since it's seen it - self.current_actions[self.agent_name_mapping[agent]] = action - self._cumulative_rewards[agent] = 0 + # update reward, term, trunc, for each agent + for ag in self.agents: + # compute term trunc reward + term, trunc, rew, info = self.compute_term_trunc_reward_info_by_id( + self.agent_name_mapping[ag] + ) + terminations[ag] |= term + truncations[ag] |= trunc + rewards[ag] += rew + infos[ag] = {**infos[ag], **info} + + # compute observations at the end + observations = { + ag: self.compute_observation_by_id(self.agent_name_mapping[ag]) + for ag in self.agents + } + + # cull dead agents for the next round + self.agents = [ + agent + for agent in self.agents + if not (terminations[agent] or truncations[agent]) + ] - # environment logic - if not self._agent_selector.is_last(): - # don't do anything if all agents haven't acted - self._clear_rewards() - else: - # collect reward if it is the last agent to act - self.aviary.set_all_setpoints(self.current_actions) - self.past_actions = deepcopy(self.current_actions) - - # step enough times for one RL step - for _ in range(self.env_step_ratio): - self.aviary.step() - - # update reward, term, trunc, for each agent - for ag in self.agents: - ag_id = self.agent_name_mapping[ag] - - # compute term trunc reward - term, trunc, rew, info = self.compute_term_trunc_reward_info_by_id( - ag_id - ) - self.terminations[ag] |= term - self.truncations[ag] |= trunc - self.rewards[ag] += rew - self.infos[ag] = {**self.infos[ag], **info} - - # accumulate rewards and select next agent - self.agent_selection = self._agent_selector.next() - self._accumulate_rewards() + return observations, rewards, terminations, truncations, infos diff --git a/PyFlyt/pz_envs/quadx_envs/ma_quadx_hover_env.py b/PyFlyt/pz_envs/quadx_envs/ma_quadx_hover_env.py index 9b86f261..ee059203 100644 --- a/PyFlyt/pz_envs/quadx_envs/ma_quadx_hover_env.py +++ b/PyFlyt/pz_envs/quadx_envs/ma_quadx_hover_env.py @@ -26,7 +26,11 @@ class MAQuadXHoverEnv(MAQuadXBaseEnv): render_mode (None | str): render_mode """ - metadata = {"render_modes": ["human"], "name": "ma_quadx_hover", "is_parallelizable": True} + metadata = { + "render_modes": ["human"], + "name": "ma_quadx_hover", + "is_parallelizable": True, + } def __init__( self, @@ -82,7 +86,7 @@ def observation_space(self, _): """ return self._observation_space - def reset(self, seed=None, options=dict()): + def reset(self, seed=None, options=dict()) -> tuple[dict[str, Any], dict[str, Any]]: """reset. Args: @@ -92,8 +96,15 @@ def reset(self, seed=None, options=dict()): super().begin_reset(seed, options) super().end_reset(seed, options) - def observe_by_id(self, agent_id: int) -> np.ndarray: - """observe_by_id. + observations = { + ag: self.compute_observation_by_id(self.agent_name_mapping[ag]) + for ag in self.agents + } + infos = {ag: dict() for ag in self.agents} + return observations, infos + + def compute_observation_by_id(self, agent_id: int) -> np.ndarray: + """compute_observation_by_id. Args: agent_id (int): agent_id diff --git a/pyproject.toml b/pyproject.toml index b1677b67..ab4876fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "PyFlyt" -version = "0.14.1" +version = "0.15.0" authors = [ { name="Jet", email="taijunjet@hotmail.com" }, ] diff --git a/tests/test_pz_envs.py b/tests/test_pz_envs.py index 590d517d..c7623684 100644 --- a/tests/test_pz_envs.py +++ b/tests/test_pz_envs.py @@ -1,10 +1,11 @@ """Tests the API compatibility of all PyFlyt Pettingzoo Environments.""" import warnings +from typing import Any import pytest -from gymnasium.utils.env_checker import data_equivalence -from pettingzoo.test import api_test -from pettingzoo.utils import wrappers +from pettingzoo import ParallelEnv +from pettingzoo.test import parallel_api_test +from pettingzoo.test.seed_test import check_environment_deterministic_parallel from PyFlyt.pz_envs import MAQuadXHoverEnv @@ -41,14 +42,13 @@ @pytest.mark.parametrize("env_config", _ALL_ENV_CONFIGS) -def test_check_env(env_config): +def test_check_env(env_config: tuple[ParallelEnv, dict[str, Any]]): """Check that environment pass the pettingzoo api_test.""" - env = env_config[0](**env_config[1]) - env = wrappers.OrderEnforcingWrapper(env) + env = env_config[0](**env_config[1]) # pyright: ignore with warnings.catch_warnings(record=True) as caught_warnings: print(caught_warnings) - api_test(env) + parallel_api_test(env, num_cycles=1000) for warning_message in caught_warnings: assert isinstance(warning_message.message, Warning) @@ -59,51 +59,9 @@ def test_check_env(env_config): @pytest.mark.parametrize("env_config", _ALL_ENV_CONFIGS) -def test_seeding(env_config): +def test_seeding(env_config: tuple[ParallelEnv, dict[str, Any]]): """Check that two AEC environments execute the same way.""" - env1 = env_config[0](**env_config[1]) - env2 = env_config[0](**env_config[1]) - env1 = wrappers.OrderEnforcingWrapper(env1) - env2 = wrappers.OrderEnforcingWrapper(env2) - env1.reset(seed=42) - env2.reset(seed=42) + env1 = env_config[0](**env_config[1]) # pyright: ignore + env2 = env_config[0](**env_config[1]) # pyright: ignore - for i in env1.agents: - env1.action_space(i).seed(seed=42) - env2.action_space(i).seed(seed=42) - env1.observation_space(i).seed(seed=42) - env2.observation_space(i).seed(seed=42) - - iterations = 0 - for agent1, agent2 in zip(env1.agent_iter(), env2.agent_iter()): - assert data_equivalence(agent1, agent2), f"Incorrect agent: {agent1} {agent2}" - - obs1, reward1, termination1, truncation1, info1 = env1.last() - obs2, reward2, termination2, truncation2, info2 = env2.last() - - assert data_equivalence(obs1, obs2), "Incorrect observation" - assert data_equivalence(reward1, reward2), "Incorrect reward." - assert data_equivalence(termination1, termination2), "Incorrect termination." - assert data_equivalence(truncation1, truncation2), "Incorrect truncation." - assert data_equivalence(info1, info2), "Incorrect info." - - if termination1 or truncation1: - break - - action1 = env1.action_space(agent1).sample() - action2 = env2.action_space(agent2).sample() - - assert data_equivalence( - action1, action2 - ), f"Incorrect actions: {action1} {action2}" - - env1.step(action1) - env2.step(action2) - - iterations += 1 - - if iterations >= 100: - break - - env1.close() - env2.close() + check_environment_deterministic_parallel(env1, env2, num_cycles=1000)