diff --git a/integration_tests/test_baselines.py b/integration_tests/test_baselines.py index e2f820b08..33cca3a1d 100644 --- a/integration_tests/test_baselines.py +++ b/integration_tests/test_baselines.py @@ -23,7 +23,7 @@ async def test_random_players(): players = [RandomPlayer(), RandomPlayer()] await asyncio.wait_for( simple_cross_evaluation(5, players=players), - timeout=5, + timeout=10, ) @@ -33,19 +33,19 @@ async def test_random_players_in_doubles(): RandomPlayer(battle_format="gen9randomdoublesbattle"), RandomPlayer(battle_format="gen9randomdoublesbattle"), ] - await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=5) + await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=10) @pytest.mark.asyncio async def test_shp(): players = [RandomPlayer(), SimpleHeuristicsPlayer()] - await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=5) + await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=10) @pytest.mark.asyncio async def test_max_base_power(): players = [RandomPlayer(), MaxBasePowerPlayer()] - await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=5) + await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=10) @pytest.mark.asyncio @@ -54,4 +54,4 @@ async def test_max_base_power_in_doubles(): RandomPlayer(battle_format="gen9randomdoublesbattle"), MaxBasePowerPlayer(battle_format="gen9randomdoublesbattle"), ] - await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=5) + await asyncio.wait_for(simple_cross_evaluation(5, players=players), timeout=10) diff --git a/src/poke_env/environment/battle.py b/src/poke_env/environment/battle.py index 655baf728..5f77f3fc1 100644 --- a/src/poke_env/environment/battle.py +++ b/src/poke_env/environment/battle.py @@ -94,6 +94,16 @@ def parse_request(self, request: Dict[str, Any]) -> None: self._teampreview = False self._update_team_from_request(request["side"]) + if self.active_pokemon is not None: + active_mon = self.get_pokemon( + request["side"]["pokemon"][0]["ident"], + force_self_team=True, + details=request["side"]["pokemon"][0]["details"], + ) + if active_mon != self.active_pokemon: + self.active_pokemon.switch_out() + active_mon.switch_in() + if "active" in request: active_request = request["active"][0] diff --git a/src/poke_env/environment/double_battle.py b/src/poke_env/environment/double_battle.py index 07f1e2f4e..76490e7ce 100644 --- a/src/poke_env/environment/double_battle.py +++ b/src/poke_env/environment/double_battle.py @@ -130,6 +130,13 @@ def parse_request(self, request: Dict[str, Any]) -> None: if side["pokemon"]: self._player_role = side["pokemon"][0]["ident"][:2] self._update_team_from_request(side) + if self.player_role is not None: + self._active_pokemon[f"{self.player_role}a"] = self.team[ + request["side"]["pokemon"][0]["ident"] + ] + self._active_pokemon[f"{self.player_role}b"] = self.team[ + request["side"]["pokemon"][1]["ident"] + ] if "active" in request: for active_pokemon_number, active_request in enumerate(request["active"]): @@ -139,41 +146,6 @@ def parse_request(self, request: Dict[str, Any]) -> None: force_self_team=True, details=pokemon_dict["details"], ) - if self.player_role is not None: - if ( - active_pokemon_number == 0 - and f"{self.player_role}a" not in self._active_pokemon - ): - self._active_pokemon[f"{self.player_role}a"] = active_pokemon - elif f"{self.player_role}b" not in self._active_pokemon: - self._active_pokemon[f"{self.player_role}b"] = active_pokemon - elif ( - active_pokemon_number == 0 - and self._active_pokemon[f"{self.player_role}a"].fainted - and self._active_pokemon[f"{self.player_role}b"] - == active_pokemon - ): - ( - self._active_pokemon[f"{self.player_role}a"], - self._active_pokemon[f"{self.player_role}b"], - ) = ( - self._active_pokemon[f"{self.player_role}b"], - self._active_pokemon[f"{self.player_role}a"], - ) - elif ( - active_pokemon_number == 1 - and self._active_pokemon[f"{self.player_role}b"].fainted - and not active_pokemon.fainted - and self._active_pokemon[f"{self.player_role}a"] - == active_pokemon - ): - ( - self._active_pokemon[f"{self.player_role}a"], - self._active_pokemon[f"{self.player_role}b"], - ) = ( - self._active_pokemon[f"{self.player_role}b"], - self._active_pokemon[f"{self.player_role}a"], - ) if active_pokemon.fainted: continue diff --git a/src/poke_env/environment/pokemon.py b/src/poke_env/environment/pokemon.py index 7d9b260d5..47ca9a5f0 100644 --- a/src/poke_env/environment/pokemon.py +++ b/src/poke_env/environment/pokemon.py @@ -378,6 +378,7 @@ def set_hp_status(self, hp_status: str, store=False): self.end_effect("yawn") else: hp = hp_status + self._status = None current_hp, max_hp = "".join([c for c in hp if c in "0123456789/"]).split("/") self._current_hp = int(current_hp) diff --git a/src/poke_env/player/baselines.py b/src/poke_env/player/baselines.py index 60a9e02f2..e5d077c28 100644 --- a/src/poke_env/player/baselines.py +++ b/src/poke_env/player/baselines.py @@ -1,5 +1,5 @@ import random -from typing import List +from typing import List, Optional from poke_env.environment.abstract_battle import AbstractBattle from poke_env.environment.double_battle import DoubleBattle @@ -29,7 +29,7 @@ def choose_singles_move(self, battle: AbstractBattle): return self.choose_random_move(battle) def choose_doubles_move(self, battle: DoubleBattle): - orders: List[BattleOrder] = [] + orders: List[Optional[BattleOrder]] = [] switched_in = None if any(battle.force_switch): @@ -51,7 +51,7 @@ def choose_doubles_move(self, battle: DoubleBattle): switches = [s for s in switches if s != switched_in] if not mon or mon.fainted: - orders.append(DefaultBattleOrder()) + orders.append(None) continue elif not moves and switches: mon_to_switch_in = random.choice(switches) diff --git a/src/poke_env/player/battle_order.py b/src/poke_env/player/battle_order.py index 249002489..da84806a5 100644 --- a/src/poke_env/player/battle_order.py +++ b/src/poke_env/player/battle_order.py @@ -73,9 +73,9 @@ def message(self) -> str: + self.second_order.message.replace("/choose ", "") ) elif self.first_order: - return self.first_order.message + ", default" + return self.first_order.message + ", pass" elif self.second_order: - return self.second_order.message + ", default" + return "/choose pass, " + self.second_order.message.replace("/choose ", "") else: return self.DEFAULT_ORDER @@ -95,10 +95,11 @@ def join_orders(first_orders: List[BattleOrder], second_orders: List[BattleOrder if orders: return orders elif first_orders: - return [DoubleBattleOrder(first_order=order) for order in first_orders] + return [DoubleBattleOrder(order, None) for order in first_orders] elif second_orders: - return [DoubleBattleOrder(first_order=order) for order in second_orders] - return [DefaultBattleOrder()] + return [DoubleBattleOrder(None, order) for order in second_orders] + else: + return [DoubleBattleOrder(None, None)] class ForfeitBattleOrder(BattleOrder): diff --git a/src/poke_env/player/player.py b/src/poke_env/player/player.py index fd26f9eb7..7546e88cc 100644 --- a/src/poke_env/player/player.py +++ b/src/poke_env/player/player.py @@ -264,6 +264,9 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): :type split_message: str """ # Battle messages can be multiline + should_process_request = False + is_from_teampreview = False + should_maybe_default = False if ( len(split_messages) > 1 and len(split_messages[1]) > 1 @@ -286,7 +289,7 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): request = orjson.loads(split_message[2]) battle.parse_request(request) if battle.move_on_next_request: - await self._handle_battle_request(battle) + should_process_request = True battle.move_on_next_request = False elif split_message[1] == "win" or split_message[1] == "tie": if split_message[1] == "win": @@ -306,7 +309,7 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): "[Invalid choice] Sorry, too late to make a different move" ): if battle.trapped: - await self._handle_battle_request(battle) + should_process_request = True elif split_message[2].startswith( "[Unavailable choice] Can't switch: The active Pokémon is " "trapped" @@ -314,38 +317,45 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): "[Invalid choice] Can't switch: The active Pokémon is trapped" ): battle.trapped = True - await self._handle_battle_request(battle) + should_process_request = True elif split_message[2].startswith( "[Invalid choice] Can't switch: You can't switch to an active " "Pokémon" ): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True elif split_message[2].startswith( "[Invalid choice] Can't switch: You can't switch to a fainted " "Pokémon" ): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True elif split_message[2].startswith( "[Invalid choice] Can't move: Invalid target for" ): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True elif split_message[2].startswith( "[Invalid choice] Can't move: You can't choose a target for" ): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True elif split_message[2].startswith( "[Invalid choice] Can't move: " ) and split_message[2].endswith("needs a target"): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True elif ( split_message[2].startswith("[Invalid choice] Can't move: Your") and " doesn't have a move matching " in split_message[2] ): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True elif split_message[2].startswith( "[Invalid choice] Incomplete choice: " ): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True elif split_message[2].startswith( "[Unavailable choice]" ) and split_message[2].endswith("is disabled"): @@ -358,25 +368,34 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): "[Invalid choice] Can't move: You sent more choices than unfainted" " Pokémon." ): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True elif split_message[2].startswith( "[Invalid choice] Can't move: You can only Terastallize once per battle." ): - await self._handle_battle_request(battle, maybe_default_order=True) + should_process_request = True + should_maybe_default = True else: self.logger.critical("Unexpected error message: %s", split_message) elif split_message[1] == "turn": battle.parse_message(split_message) - await self._handle_battle_request(battle) + should_process_request = True elif split_message[1] == "teampreview": battle.parse_message(split_message) - await self._handle_battle_request(battle, from_teampreview_request=True) + should_process_request = True + is_from_teampreview = True elif split_message[1] == "bigerror": self.logger.warning("Received 'bigerror' message: %s", split_message) elif split_message[1] == "uhtml" and split_message[2] == "otsrequest": await self._handle_ots_request(battle.battle_tag) else: battle.parse_message(split_message) + if should_process_request: + await self._handle_battle_request( + battle, + from_teampreview_request=is_from_teampreview, + maybe_default_order=should_maybe_default, + ) async def _handle_battle_request( self, @@ -532,9 +551,7 @@ def choose_random_doubles_move(battle: DoubleBattle) -> BattleOrder: second_switch_in = random.choice(available_switches) second_order = BattleOrder(second_switch_in) - if first_order and second_order: - return DoubleBattleOrder(first_order, second_order) - return DoubleBattleOrder(first_order or second_order, None) + return DoubleBattleOrder(first_order, second_order) for ( orders, diff --git a/src/poke_env/ps_client/ps_client.py b/src/poke_env/ps_client/ps_client.py index 8199b7da6..233d7a164 100644 --- a/src/poke_env/ps_client/ps_client.py +++ b/src/poke_env/ps_client/ps_client.py @@ -7,7 +7,7 @@ from asyncio import CancelledError, Event, Lock, create_task, sleep from logging import Logger from time import perf_counter -from typing import Any, List, Optional, Set +from typing import Any, Dict, List, Optional, Set import requests import websockets.client as ws @@ -85,6 +85,7 @@ def __init__( self._sending_lock = create_in_poke_loop(Lock) self.websocket: WebSocketClientProtocol + self.reqs: Dict[str, List[List[str]]] = {} self._logger: Logger = self._create_logger(log_level) if start_listening: @@ -139,8 +140,20 @@ async def _handle_message(self, message: str): # For battles, this is the zero-th entry # Otherwise it is the one-th entry if split_messages[0][0].startswith(">battle"): + # Determine protocol and request + battle_tag = split_messages[0][0][1:] + request = self.reqs.pop(battle_tag, None) + if "|request|" in message: + protocol = None + self.reqs[battle_tag] = split_messages + else: + protocol = split_messages # Battle update - await self._handle_battle_message(split_messages) # type: ignore + if protocol is not None or request is not None: + split_messages = protocol or [[f">{battle_tag}"]] + if request is not None: + split_messages += [request[1]] + await self._handle_battle_message(split_messages) # type: ignore elif split_messages[0][1] == "challstr": # Confirms connection to the server: we can login await self.log_in(split_messages[0]) diff --git a/unit_tests/player/test_battle_orders.py b/unit_tests/player/test_battle_orders.py index 9cac1a927..e5d1a741c 100644 --- a/unit_tests/player/test_battle_orders.py +++ b/unit_tests/player/test_battle_orders.py @@ -35,10 +35,8 @@ def test_double_orders(): DoubleBattleOrder(mon, move).message == "/choose switch lugia, move selfdestruct 2" ) - assert DoubleBattleOrder(mon).message == "/choose switch lugia, default" - assert ( - DoubleBattleOrder(None, move).message == "/choose move selfdestruct 2, default" - ) + assert DoubleBattleOrder(mon).message == "/choose switch lugia, pass" + assert DoubleBattleOrder(None, move).message == "/choose pass, move selfdestruct 2" assert DoubleBattleOrder().message == "/choose default" orders = [move, mon] @@ -53,12 +51,12 @@ def test_double_orders(): "/choose switch lugia, move selfdestruct 2", } assert first == { - "/choose move selfdestruct 2, default", - "/choose switch lugia, default", + "/choose move selfdestruct 2, pass", + "/choose switch lugia, pass", } assert second == { - "/choose move selfdestruct 2, default", - "/choose switch lugia, default", + "/choose pass, move selfdestruct 2", + "/choose pass, switch lugia", } assert none == {"/choose default"} diff --git a/unit_tests/player/test_doubles_baselines.py b/unit_tests/player/test_doubles_baselines.py index 81910a902..b7ec13430 100644 --- a/unit_tests/player/test_doubles_baselines.py +++ b/unit_tests/player/test_doubles_baselines.py @@ -17,11 +17,11 @@ def test_doubles_max_damage_player(): battle._active_pokemon["p1a"] = active_pikachu # calls player.choose_random_doubles_move(battle) - assert player.choose_move(battle).message == "/choose default, default" + assert player.choose_move(battle).message == "/choose default, pass" # calls player.choose_random_doubles_move(battle) battle._available_switches[0].append(Pokemon(species="ponyta", gen=8)) - assert player.choose_move(battle).message == "/choose switch ponyta, default" + assert player.choose_move(battle).message == "/choose switch ponyta, pass" active_raichu = Pokemon(species="raichu", gen=8) active_raichu.switch_in() @@ -65,12 +65,12 @@ def test_doubles_max_damage_player(): # forced switch battle._force_switch = [True, False] assert player.choose_move(battle).message in [ - "/choose switch ponyta, default", + "/choose switch ponyta, pass", ] battle._force_switch = [False, True] assert player.choose_move(battle).message in [ - "/choose switch rapidash, default", + "/choose pass, switch rapidash", ] battle._force_switch = [True, True] diff --git a/unit_tests/ps_client/test_ps_client.py b/unit_tests/ps_client/test_ps_client.py index 1bca5bfb4..2ddf354ba 100644 --- a/unit_tests/ps_client/test_ps_client.py +++ b/unit_tests/ps_client/test_ps_client.py @@ -107,8 +107,11 @@ async def test_handle_message(): client._update_challenges.assert_called_once_with(["", "updatechallenges"]) client._handle_battle_message = AsyncMock() - await client._handle_message(">battle|thing") - client._handle_battle_message.assert_called_once_with([[">battle", "thing"]]) + await client._handle_message(">battle\n|request|request-thing") + await client._handle_message(">battle\n|turn|15") + client._handle_battle_message.assert_called_once_with( + [[">battle"], ["", "turn", "15"], ["", "request", "request-thing"]] + ) await client._handle_message("|updatesearch")