diff --git a/shogi-ai/main.py b/shogi-ai/main.py index 0a45921..80ab46c 100644 --- a/shogi-ai/main.py +++ b/shogi-ai/main.py @@ -29,6 +29,7 @@ def main() -> None: agent2: RandomAgent = RandomAgent(env, player=1) + agent1_action: Move = agent1.select_action() while not board.is_game_over(): agent1_action: Move = agent1.select_action() board.push(agent1_action) @@ -44,7 +45,11 @@ def main() -> None: print("Final State of board:") print(board) + print(f"Player {board.turn} Lost!") print(f"Number of moves {len(board.move_stack)}") + print(f"Simulated games: {agent1.games_simulated}") + print(f"Rollouts: {agent1.rollouts}") + print(f"Positions Checked: {agent1.positions_checked}") with open("game.txt", "w") as f: for move in board.move_stack: f.write(str(move) + "\n") diff --git a/shogi-ai/mcts_agent.py b/shogi-ai/mcts_agent.py index 5706917..3c0af7b 100644 --- a/shogi-ai/mcts_agent.py +++ b/shogi-ai/mcts_agent.py @@ -1,6 +1,7 @@ import time import random +import math from shogi import Board from shogi import Move @@ -16,20 +17,15 @@ class Node: def __init__(self, move: Optional[Move], parent=None): self.move = move self.parent = parent - self.explored_children = [] - self.visits = 1 + self.children = [] + self.visits = 0 self.value = 0 if parent: - parent.explored_children.append(self) - - p = self.parent - while p: - p.visits += 1 - p = p.parent + parent.children.append(self) def get_child_from_move(self, move: Move): - for child in self.explored_children: + for child in self.children: if child.move == move: return child return None @@ -55,6 +51,9 @@ def __init__(self, env: Environment, player: int, strategy=None): self.time_limit = 5 self.tree = Node(move=None, parent=None) self.games_simulated = 0 + self.positions_checked = 0 + self.rollouts = 0 + self.exploration_coefficient = 1.41 super().__init__(env=env, player=player, strategy=strategy) def select_action(self): @@ -66,42 +65,70 @@ def select_action(self): start_time = time.time() time_delta = 0 - while time_delta < self.time_limit: + # Seed initial expansion + self._expansion(self.env.board, self.tree) + + while time_delta < self.time_limit and self.games_simulated < 1000: time_delta = time.time() - start_time - self._simulation() + node_to_simulate = self._selection() + self._simulation(node_to_simulate) self.games_simulated += 1 - best_node = max(self.tree.explored_children, key=lambda n: n.value) + # select the immediate child (depth 1) with the most visits + # as we revisit the most promising nodes + best_node = max(self.tree.children, key=lambda n: n.visits) return best_node.move - def _simulation(self): - base_board = Board(self._env.board.sfen()) - node = self._expansion(board=base_board) - if node.visits == 0: - value = self._rollout(board_copy=base_board, move=node.move) + def _selection(self) -> Node: + queue = [] + queue.extend(self.tree.children) + max_uct_ucb1 = float("-inf") + node_to_rollout = None + + while queue: + current_node: Node = queue.pop(0) + + curr_node_visits = max(1, current_node.visits) + tree_visits = max(1, self.tree.visits) + + uct_ucb1 = current_node.value / curr_node_visits + self.exploration_coefficient * math.sqrt( + math.log(tree_visits) / curr_node_visits + ) + if uct_ucb1 > max_uct_ucb1: + max_uct_ucb1 = uct_ucb1 + node_to_rollout = current_node + queue.extend(current_node.children) + + return node_to_rollout + + def _simulation(self, node_to_rollout: Node): + board_copy = Board(self._env.board.sfen()) + + # make all the moves to the board that got us to this node. + moves_stack = [] + inspect = node_to_rollout + while inspect.parent: + moves_stack.append(inspect.move) + inspect = inspect.parent + while moves_stack: + board_copy.push(moves_stack.pop()) + + if node_to_rollout.visits == 0: + value = self._rollout(board_copy=board_copy) else: - board_after_explored_node = Board(base_board.sfen()) - board_after_explored_node.push(node.move) - - # need to consider when we pick a move that - # loses us the game - if board_after_explored_node.is_game_over(): - value = self._utility(board_after_explored_node) - self._backpropagation(node, value) - - expanded_move = self._select_random_move(board=board_after_explored_node) - value = self._rollout(board_copy=board_after_explored_node, - move=expanded_move) + self._expansion(board_copy, node_to_rollout) + value = self._rollout(board_copy=board_copy) - self._backpropagation(node, value) + self._backpropagation(node_to_rollout, value) - def _rollout(self, board_copy: Board, move: Move) -> int: - board_copy.push(move) + def _rollout(self, board_copy: Board) -> int: + # we just play moves after we get to the current move position + self.rollouts += 1 - while not board_copy.is_game_over(): - move = self._select_random_move(board_copy) - board_copy.push(move) + while not board_copy.is_game_over() and board_copy.move_number < 100: + new_random_move = self._random_move(board_copy) + board_copy.push(new_random_move) return self._utility(board_copy) @@ -113,23 +140,28 @@ def _utility(self, board_copy: Board): return -1 return 0 - def _backpropagation(self, node: Node, value: int): - node.value += value - - def _select_random_move(self, board: Board) -> Move: - moves = [move for move in board.legal_moves] - return random.choice(moves) - - def _expansion(self, board: Board) -> Node: - move = self._select_random_move(board) - - existing_node = self.tree.get_child_from_move(move) - if existing_node: - existing_node.visits += 1 - return existing_node + def _backpropagation(self, leaf_node: Node, value: int): + leaf_node.value += value + p = leaf_node.parent + while p: + p.visits += 1 + p.value += value + p = p.parent - node = Node(move, self.tree) - return node + def _random_move(self, board: Board) -> Move: + self.positions_checked += 1 + moves = [move for move in board.pseudo_legal_moves] + move = None + while True: + move = random.choice(moves) + # only need to ensure we dont win with dropping a pawn + if not board.was_check_by_dropping_pawn(move): + break + return move + + def _expansion(self, board: Board, parent_node: Node) -> None: + move = self._random_move(board) + node = Node(move=move, parent=parent_node) @classmethod def from_board(cls, board: Board):