Skip to content

Commit

Permalink
fix: actually implements mcts
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielBarton446 committed Apr 8, 2024
1 parent 3d4992a commit f802bbc
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 51 deletions.
5 changes: 5 additions & 0 deletions shogi-ai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def main() -> None:

agent2: RandomAgent = RandomAgent(env, player=1)

agent1_action: Move = agent1.select_action()
while not board.is_game_over():
agent1_action: Move = agent1.select_action()
board.push(agent1_action)
Expand All @@ -44,7 +45,11 @@ def main() -> None:

print("Final State of board:")
print(board)
print(f"Player {board.turn} Lost!")
print(f"Number of moves {len(board.move_stack)}")
print(f"Simulated games: {agent1.games_simulated}")
print(f"Rollouts: {agent1.rollouts}")
print(f"Positions Checked: {agent1.positions_checked}")
with open("game.txt", "w") as f:
for move in board.move_stack:
f.write(str(move) + "\n")
Expand Down
134 changes: 83 additions & 51 deletions shogi-ai/mcts_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import time
import random
import math

from shogi import Board
from shogi import Move
Expand All @@ -16,20 +17,15 @@ class Node:
def __init__(self, move: Optional[Move], parent=None):
self.move = move
self.parent = parent
self.explored_children = []
self.visits = 1
self.children = []
self.visits = 0
self.value = 0

if parent:
parent.explored_children.append(self)

p = self.parent
while p:
p.visits += 1
p = p.parent
parent.children.append(self)

def get_child_from_move(self, move: Move):
for child in self.explored_children:
for child in self.children:
if child.move == move:
return child
return None
Expand All @@ -55,6 +51,9 @@ def __init__(self, env: Environment, player: int, strategy=None):
self.time_limit = 5
self.tree = Node(move=None, parent=None)
self.games_simulated = 0
self.positions_checked = 0
self.rollouts = 0
self.exploration_coefficient = 1.41
super().__init__(env=env, player=player, strategy=strategy)

def select_action(self):
Expand All @@ -66,42 +65,70 @@ def select_action(self):
start_time = time.time()
time_delta = 0

while time_delta < self.time_limit:
# Seed initial expansion
self._expansion(self.env.board, self.tree)

while time_delta < self.time_limit and self.games_simulated < 1000:
time_delta = time.time() - start_time
self._simulation()
node_to_simulate = self._selection()
self._simulation(node_to_simulate)
self.games_simulated += 1

best_node = max(self.tree.explored_children, key=lambda n: n.value)
# select the immediate child (depth 1) with the most visits
# as we revisit the most promising nodes
best_node = max(self.tree.children, key=lambda n: n.visits)

return best_node.move

def _simulation(self):
base_board = Board(self._env.board.sfen())
node = self._expansion(board=base_board)
if node.visits == 0:
value = self._rollout(board_copy=base_board, move=node.move)
def _selection(self) -> Node:
queue = []
queue.extend(self.tree.children)
max_uct_ucb1 = float("-inf")
node_to_rollout = None

while queue:
current_node: Node = queue.pop(0)

curr_node_visits = max(1, current_node.visits)
tree_visits = max(1, self.tree.visits)

uct_ucb1 = current_node.value / curr_node_visits + self.exploration_coefficient * math.sqrt(
math.log(tree_visits) / curr_node_visits
)
if uct_ucb1 > max_uct_ucb1:
max_uct_ucb1 = uct_ucb1
node_to_rollout = current_node
queue.extend(current_node.children)

return node_to_rollout

def _simulation(self, node_to_rollout: Node):
board_copy = Board(self._env.board.sfen())

# make all the moves to the board that got us to this node.
moves_stack = []
inspect = node_to_rollout
while inspect.parent:
moves_stack.append(inspect.move)
inspect = inspect.parent
while moves_stack:
board_copy.push(moves_stack.pop())

if node_to_rollout.visits == 0:
value = self._rollout(board_copy=board_copy)
else:
board_after_explored_node = Board(base_board.sfen())
board_after_explored_node.push(node.move)

# need to consider when we pick a move that
# loses us the game
if board_after_explored_node.is_game_over():
value = self._utility(board_after_explored_node)
self._backpropagation(node, value)

expanded_move = self._select_random_move(board=board_after_explored_node)
value = self._rollout(board_copy=board_after_explored_node,
move=expanded_move)
self._expansion(board_copy, node_to_rollout)
value = self._rollout(board_copy=board_copy)

self._backpropagation(node, value)
self._backpropagation(node_to_rollout, value)

def _rollout(self, board_copy: Board, move: Move) -> int:
board_copy.push(move)
def _rollout(self, board_copy: Board) -> int:
# we just play moves after we get to the current move position
self.rollouts += 1

while not board_copy.is_game_over():
move = self._select_random_move(board_copy)
board_copy.push(move)
while not board_copy.is_game_over() and board_copy.move_number < 100:
new_random_move = self._random_move(board_copy)
board_copy.push(new_random_move)

return self._utility(board_copy)

Expand All @@ -113,23 +140,28 @@ def _utility(self, board_copy: Board):
return -1
return 0

def _backpropagation(self, node: Node, value: int):
node.value += value

def _select_random_move(self, board: Board) -> Move:
moves = [move for move in board.legal_moves]
return random.choice(moves)

def _expansion(self, board: Board) -> Node:
move = self._select_random_move(board)

existing_node = self.tree.get_child_from_move(move)
if existing_node:
existing_node.visits += 1
return existing_node
def _backpropagation(self, leaf_node: Node, value: int):
leaf_node.value += value
p = leaf_node.parent
while p:
p.visits += 1
p.value += value
p = p.parent

node = Node(move, self.tree)
return node
def _random_move(self, board: Board) -> Move:
self.positions_checked += 1
moves = [move for move in board.pseudo_legal_moves]
move = None
while True:
move = random.choice(moves)
# only need to ensure we dont win with dropping a pawn
if not board.was_check_by_dropping_pawn(move):
break
return move

def _expansion(self, board: Board, parent_node: Node) -> None:
move = self._random_move(board)
node = Node(move=move, parent=parent_node)

@classmethod
def from_board(cls, board: Board):
Expand Down

0 comments on commit f802bbc

Please sign in to comment.