Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielBarton446 committed Apr 21, 2024
1 parent f1c9cf4 commit 8857501
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 47 deletions.
81 changes: 51 additions & 30 deletions shogi-ai/agents/mcts_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
"""

import concurrent.futures
import math
import os
import random
import time
from typing import List, Optional
from typing import List, Optional, Tuple

from agents.agent import Agent
from env.environment import Environment
Expand Down Expand Up @@ -42,9 +44,11 @@ def get_child_from_move(self, move: Move):
"""
Fetch a child move if it exists for the given move.
"""
for child in self.children:
all_children = self.all_subchild_nodes()
for child in all_children:
if child.move == move:
return child
# logger.warning(f"Couldn't find {move} in current tree:\n {[node.move for node in self.all_subchild_nodes()]}")
return None

def all_subchild_nodes(self) -> List["Node"]:
Expand Down Expand Up @@ -87,43 +91,60 @@ def __init__(self, env: Environment, player: int, strategy=None):
strategy = "mcts"
self.time_limit = 10
self.tree = Node(move=None, parent=None)
self.games_simulated = 0
self.total_games_simulated = 0
self.positions_checked = 0
self.rollouts = 0
self.exploration_coefficient = 1.41
super().__init__(env=env, player=player, strategy=strategy)

def current_board_sims(self) -> int:
    """Return the visit count stored on the root of the current search tree."""
    root = self.tree
    return root.visits

def select_action(self, board: Optional[Board] = None) -> Move:
    """Run a time-boxed Monte Carlo tree search and return the chosen move.

    A fresh tree is built on every call. Batches of simulations are
    submitted to a process pool until ``self.time_limit`` seconds have
    elapsed; every returned ``(move, value)`` pair is backpropagated into
    the tree, and the root child with the most visits is selected.

    Args:
        board: Position to search from. When omitted (or falsy) the
            environment's current board is kept.

    Returns:
        The move of the most-visited immediate child of the root.

    Raises:
        ValueError: If it is not this agent's turn on the board.
        concurrent.futures.TimeoutError: If a batch of simulations does
            not complete within ``self.time_limit`` seconds.
    """
    self._env.board = board or self._env.board

    if self.player != self.env.board.turn:
        raise ValueError("Not the MCTS_AGENT's turn")

    self.tree = Node(move=None, parent=None)
    start_time = time.time()

    # Seed initial expansion
    self._expansion(self.env.board, self.tree)

    # this really shouldn't be done here, and should be a config
    # Leave two cores free, but never go below one worker:
    # ProcessPoolExecutor raises ValueError when max_workers <= 0, which
    # is what `cpu_count() - 2` yields on one- and two-core machines.
    num_workers = max(1, (os.cpu_count() or 1) - 2)

    # NOTE(review): `self._simulation` is a bound method, so each submit
    # presumably pickles the agent into the worker; tree changes made
    # inside the worker would then not be visible here -- confirm.
    # The pool is created once instead of once per batch, since spawning
    # worker processes on every loop iteration is pure overhead.
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Re-read the clock on every check so no extra batch is started
        # after the time budget has already been spent.
        while time.time() - start_time < self.time_limit:
            futures = [
                executor.submit(self._simulation, self._selection([self.tree]))
                for _ in range(num_workers)
            ]
            results = [
                future.result()
                for future in concurrent.futures.as_completed(futures, self.time_limit)
            ]
            for move, value in results:
                simulated_node = self.tree.get_child_from_move(move)
                if simulated_node is None:
                    # get_child_from_move returns None when the move is
                    # not in the tree; backpropagating onto None would
                    # raise AttributeError, so drop the result instead.
                    logger.warning(
                        "Simulated move %s not found in current tree; result dropped",
                        move,
                    )
                    continue
                self._backpropagation(simulated_node, value)

            logger.debug("%s", [child.visits for child in self.tree.children])

    self.total_games_simulated += self.current_board_sims()
    logger.info("Games simulated: %i", self.tree.visits)

    # select the immediate child (depth 1) with the most visits
    # as we revisit the most promising nodes
    best_node = max(self.tree.children, key=lambda n: n.visits)

    logger.info("Selected move: %s", best_node.move)

    return best_node.move

def _selection(self) -> Node:
selection_queue: List[Node] = self.tree.all_subchild_nodes()
def _selection(self, nodes: List[Node]) -> Node:
selection_queue: List[Node] = []
for node in nodes:
selection_queue.extend(node.all_subchild_nodes())
random.shuffle(selection_queue)

max_uct_ucb1 = float("-inf")
Expand Down Expand Up @@ -153,7 +174,24 @@ def _selection(self) -> Node:

return node_to_rollout

def _simulation(self, node_to_rollout: Node):
def _utility(self, board_copy: Board):
    """Score ``board_copy`` from this agent's point of view.

    Returns 0 when the position is not checkmate. For a checkmate,
    returns -1 if the side to move is this agent and 1 otherwise.
    """
    if not board_copy.is_checkmate():
        return 0
    return -1 if board_copy.turn == self.player else 1

def _rollout(self, board_copy: Board) -> int:
    """Play random moves on ``board_copy`` until the game is over, then score it.

    ``board_copy`` is modified in place and ``self.rollouts`` is
    incremented once per call. The result is whatever ``self._utility``
    returns for the final position.
    """
    self.rollouts += 1

    # From the current position onwards, moves are picked by _random_move.
    while True:
        if board_copy.is_game_over():
            break
        board_copy.push(self._random_move(board_copy))

    return self._utility(board_copy)

def _simulation(self, node_to_rollout: Node) -> Tuple[Move, int]:
board_copy = Board(self._env.board.sfen())

# make all the moves to the board that got us to this node.
Expand All @@ -171,24 +209,7 @@ def _simulation(self, node_to_rollout: Node):
self._expansion(board_copy, node_to_rollout.parent)
value = self._rollout(board_copy=board_copy)

self._backpropagation(node_to_rollout, value)

def _rollout(self, board_copy: Board) -> int:
    """Play random moves on ``board_copy`` until the game is over, then score it.

    ``board_copy`` is modified in place and ``self.rollouts`` is
    incremented once per call. The result is whatever ``self._utility``
    returns for the final position.
    """
    self.rollouts += 1

    # From the current position onwards, moves are picked by _random_move.
    while True:
        if board_copy.is_game_over():
            break
        board_copy.push(self._random_move(board_copy))

    return self._utility(board_copy)

def _utility(self, board_copy: Board):
    """Score ``board_copy`` from this agent's point of view.

    Returns 0 when the position is not checkmate. For a checkmate,
    returns -1 if the side to move is this agent and 1 otherwise.
    """
    if not board_copy.is_checkmate():
        return 0
    return -1 if board_copy.turn == self.player else 1
return (node_to_rollout.move, value)

def _backpropagation(self, leaf_node: Node, value: int):
leaf_node.value += value
Expand Down
17 changes: 2 additions & 15 deletions shogi-ai/evaluation/grimbergen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,8 @@
"""

from shogi import (
BISHOP,
GOLD,
KNIGHT,
LANCE,
PAWN,
PROM_BISHOP,
PROM_KNIGHT,
PROM_LANCE,
PROM_PAWN,
PROM_ROOK,
PROM_SILVER,
ROOK,
SILVER,
)
from shogi import (BISHOP, GOLD, KNIGHT, LANCE, PAWN, PROM_BISHOP, PROM_KNIGHT,
PROM_LANCE, PROM_PAWN, PROM_ROOK, PROM_SILVER, ROOK, SILVER)

GRIMBERGEN = {
PAWN: 1,
Expand Down
4 changes: 2 additions & 2 deletions shogi-ai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main() -> None:
agent1_action: Move = agent1.select_action(board)
board.push(agent1_action)
print(f"Agent 1 move: {agent1_action}")
print(f"Games simulated: {agent1.games_simulated}")
print(f"Games simulated: {agent1.current_board_sims()}")
print(board)
print(f"Move: {len(board.move_stack)}")
if board.is_game_over():
Expand All @@ -43,7 +43,7 @@ def main() -> None:
print(board)
print(f"Player {board.turn} Lost!")
print(f"Number of moves {len(board.move_stack)}")
print(f"Simulated games: {agent1.games_simulated}")
print(f"Simulated games: {agent1.total_games_simulated}")
print(f"Rollouts: {agent1.rollouts}")
print(f"Positions Checked: {agent1.positions_checked}")
with open("game.txt", "w", encoding="utf-8") as f:
Expand Down

0 comments on commit 8857501

Please sign in to comment.