Skip to content

Commit

Permalink
finish adding the rest of doubles gimmick handling, prepping for debu…
Browse files Browse the repository at this point in the history
…gging
  • Loading branch information
cameronangliss committed Jan 5, 2025
1 parent 1f7b506 commit 9db5d70
Showing 1 changed file with 35 additions and 19 deletions.
54 changes: 35 additions & 19 deletions src/poke_env/player/gymnasium_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def action_to_order(action: ActionType, battle: AbstractBattle) -> BattleOrder:
43 <= element <= 46: move with target = 2 and terastallize
"""
try:
print(action)
if isinstance(battle, Battle):
assert isinstance(action, (int, np.integer))
a = action.item() if isinstance(action, np.integer) else action
Expand Down Expand Up @@ -467,6 +468,7 @@ def _singles_action_to_order(action: int, battle: Battle) -> BattleOrder:
and battle.available_moves[0].id in ["struggle", "recharge"]
else list(active_mon.moves.values())
)
print(active_mon.base_species, [m.id for m in mvs])
order = Player.create_order(
mvs[(action - 6) % 4],
mega=battle.can_mega_evolve and 10 <= action < 14,
Expand Down Expand Up @@ -528,11 +530,15 @@ def _doubles_action_to_order_individual(
and battle.available_moves[pos][0].id in ["struggle", "recharge"]
else list(active_mon.moves.values())
)
print(pos, active_mon.base_species, [m.id for m in mvs])
order = Player.create_order(

Check warning on line 534 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L533-L534

Added lines #L533 - L534 were not covered by tests
mvs[(action - 7) % 4],
terastallize=battle.can_tera[pos] is not None
and bool((action - 7) // 20),
move_target=(action - 7) % 20 // 4 - 2,
mega=battle.can_mega_evolve[pos] and (action - 7) // 20 == 1,
z_move=battle.can_z_move[pos] and (action - 7) // 20 == 2,
dynamax=battle.can_dynamax[pos] and (action - 7) // 20 == 3,
terastallize=battle.can_tera[pos] is not None
and (action - 7) // 20 == 4,
)
assert isinstance(order.order, Move)
assert order.order.id in [

Check warning on line 544 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L543-L544

Added lines #L543 - L544 were not covered by tests
Expand All @@ -542,6 +548,9 @@ def _doubles_action_to_order_individual(
assert order.move_target in battle.get_possible_showdown_targets(

Check warning on line 548 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L547-L548

Added lines #L547 - L548 were not covered by tests
move, active_mon
), "invalid pick"
assert not order.mega or battle.can_mega_evolve[pos], "invalid pick"
assert not order.z_move or battle.can_z_move[pos], "invalid pick"
assert not order.dynamax or battle.can_dynamax[pos], "invalid pick"
assert (

Check warning on line 554 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L551-L554

Added lines #L551 - L554 were not covered by tests
not order.terastallize or battle.can_tera[pos] is not False
), "invalid pick"
Expand Down Expand Up @@ -580,6 +589,7 @@ def _singles_order_to_action(order: BattleOrder, battle: Battle) -> np.int64:
and battle.available_moves[0].id in ["struggle", "recharge"]
else list(active_mon.moves.values())
)
print(order.message, [m.id for m in mvs])
action = [m.id for m in mvs].index(order.order.id)
if order.mega:
gimmick = 1
Expand Down Expand Up @@ -638,20 +648,30 @@ def _doubles_order_to_action_individual(
and battle.available_moves[pos][0].id in ["struggle", "recharge"]
else list(active_mon.moves.values())
)
action = mvs.index(order.order)
print(order.message, [m.id for m in mvs])
action = [m.id for m in mvs].index(order.order.id)
target = order.move_target + 2
if order.terastallize:
if order.mega:
gimmick = 1
elif order.z_move:
gimmick = 2
elif order.dynamax:
gimmick = 3
elif order.terastallize:
gimmick = 4

Check warning on line 661 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L651-L661

Added lines #L651 - L661 were not covered by tests
else:
gimmick = 0
action = 6 + action + 4 * target + 20 * gimmick
action = 1 + 6 + action + 4 * target + 20 * gimmick
assert order.order.id in [

Check warning on line 665 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L663-L665

Added lines #L663 - L665 were not covered by tests
m.id for m in battle.available_moves[pos]
], "invalid pick"
move = [m for m in battle.available_moves[pos] if m.id == order.order.id][0]
assert order.move_target in battle.get_possible_showdown_targets(

Check warning on line 669 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L668-L669

Added lines #L668 - L669 were not covered by tests
move, active_mon
), "invalid pick"
assert not order.mega or battle.can_mega_evolve[pos], "invalid pick"
assert not order.z_move or battle.can_z_move[pos], "invalid pick"
assert not order.dynamax or battle.can_dynamax[pos], "invalid pick"
assert (

Check warning on line 675 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L672-L675

Added lines #L672 - L675 were not covered by tests
not order.terastallize or battle.can_tera[pos] is not False
), "invalid pick"
Expand Down Expand Up @@ -784,22 +804,18 @@ def get_action_space_size(battle_format: str):
or "metronome" in format_lowercase
):
num_targets = 5

Check warning on line 806 in src/poke_env/player/gymnasium_api.py

View check run for this annotation

Codecov / codecov/patch

src/poke_env/player/gymnasium_api.py#L806

Added line #L806 was not covered by tests
if format_lowercase.startswith("gen9"):
num_gimmicks = 1
else:
num_gimmicks = 0
else:
num_targets = 1
if format_lowercase.startswith("gen6"):
num_gimmicks = 1
elif format_lowercase.startswith("gen7"):
num_gimmicks = 2
elif format_lowercase.startswith("gen8"):
num_gimmicks = 3
elif format_lowercase.startswith("gen9"):
num_gimmicks = 4
else:
num_gimmicks = 0
if format_lowercase.startswith("gen6"):
num_gimmicks = 1
elif format_lowercase.startswith("gen7"):
num_gimmicks = 2
elif format_lowercase.startswith("gen8"):
num_gimmicks = 3
elif format_lowercase.startswith("gen9"):
num_gimmicks = 4
else:
num_gimmicks = 0
return num_switches + num_moves * num_targets * (num_gimmicks + 1)

@staticmethod
Expand Down

0 comments on commit 9db5d70

Please sign in to comment.