diff --git a/integration_tests/test_env.py b/integration_tests/test_env.py index 833687021..5d28b6dda 100644 --- a/integration_tests/test_env.py +++ b/integration_tests/test_env.py @@ -21,12 +21,19 @@ def embed_battle(self, battle): return np.array([0]) -def play_function(env, n_battles): +def play_function(env: CIEnv, n_battles: int): for _ in range(n_battles): done = False env.reset() while not done: + assert env.battle1 is not None + assert env.battle2 is not None actions = {name: env.action_space(name).sample() for name in env.agents} + [a1, a2] = list(actions.values()) + o1 = CIEnv.action_to_order(a1, env.battle1) + o2 = CIEnv.action_to_order(a2, env.battle2) + assert a1 == CIEnv.order_to_action(o1, env.battle1) + assert a2 == CIEnv.order_to_action(o2, env.battle2) _, _, terminated, truncated, _ = env.step(actions) done = any(terminated.values()) or any(truncated.values()) diff --git a/src/poke_env/player/env.py b/src/poke_env/player/env.py index f20e6b21b..f72fa083d 100644 --- a/src/poke_env/player/env.py +++ b/src/poke_env/player/env.py @@ -360,10 +360,10 @@ def close(self, purge: bool = True): ) closing_task.result() - def observation_space(self, agent: str) -> Space: + def observation_space(self, agent: str) -> Space[ObsType]: return self.observation_spaces[agent] - def action_space(self, agent: str) -> Space: + def action_space(self, agent: str) -> Space[ActionType]: return self.action_spaces[agent] ###################################################################################