Skip to content

Commit

Permalink
feat: better mcts and some logging
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielBarton446 committed Apr 16, 2024
1 parent ffc5af5 commit dc07634
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 29 deletions.
6 changes: 5 additions & 1 deletion shogi-ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Agent class is the base class for all agents.
"""

from typing import Optional

from env.environment import Environment
from shogi import Board

Expand All @@ -17,9 +19,11 @@ def __init__(self, env: Environment, player: int, strategy=None):
self.player = player
self.strategy = strategy

def select_action(self):
def select_action(self, board: Optional[Board] = None):
"""
Select an action based on the state of the environment.
Optionally provide a board to select an action from and update
the environment accordingly.
NotImplementedError: This method must be implemented by the subclass
"""
Expand Down
72 changes: 60 additions & 12 deletions shogi-ai/agents/mcts_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from agents.agent import Agent
from env.environment import Environment
from shogi import Board, Move
from util.common import get_logger

logger = get_logger(__name__)
logger.setLevel("DEBUG")


class Node: # pylint: disable=too-few-public-methods
Expand All @@ -28,6 +32,8 @@ def __init__(self, move: Optional[Move], parent=None):
self.children: List[Node] = []
self.visits = 0
self.value = 0
self.ucb1 = float("inf")
self.expanded = False

if parent:
parent.children.append(self)
Expand All @@ -41,6 +47,27 @@ def get_child_from_move(self, move: Move):
return child
return None

def all_subchild_nodes(self) -> List["Node"]:
"""
returns a list of all subchild nodes. EXCLUDING the root.
"""
bfs_queue = [self]
nodes = []
while bfs_queue:
current_node = bfs_queue.pop()
nodes.append(current_node)
bfs_queue.extend(current_node.children)
nodes.remove(self) # remove the root node

return nodes


def __repr__(self):
msg = f"Node from -- Move: {self.move} - Visits: {self.visits}\n"
for child in self.children:
msg += f"{[str(child.move) for child in self.children]}\n"
return msg


class MctsAgent(Agent):
"""
Expand All @@ -67,7 +94,9 @@ def __init__(self, env: Environment, player: int, strategy=None):
self.exploration_coefficient = 1.41
super().__init__(env=env, player=player, strategy=strategy)

def select_action(self):
def select_action(self, board: Optional[Board] = None) -> Move:
self._env.board = board or self._env.board

if self.player != self.env.board.turn:
raise ValueError("Not the MCTS_AGENT's turn")

Expand All @@ -84,34 +113,39 @@ def select_action(self):
node_to_simulate = self._selection()
self._simulation(node_to_simulate)
self.games_simulated += 1
logger.debug(f"{[child.visits for child in self.tree.children]}")

# select the immediate child (depth 1) with the most visits
# as we revisit the most promising nodes
best_node = max(self.tree.children, key=lambda n: n.visits)

logger.info(f"Games simulated: {self.games_simulated}")
logger.info(f"Selected move: {best_node.move}")
return best_node.move

def _selection(self) -> Node:
queue = []
queue.extend(self.tree.children)
selection_queue: List[Node] = self.tree.all_subchild_nodes()

max_uct_ucb1 = float("-inf")
node_to_rollout = None

while queue:
current_node: Node = queue.pop(0)
while selection_queue:
current_node: Node = selection_queue.pop()

curr_node_visits = max(1, current_node.visits)
if current_node.visits == 0:
return current_node
tree_visits = max(1, self.tree.visits)

uct_ucb1 = (
current_node.value / curr_node_visits
current_node.value / current_node.visits
+ self.exploration_coefficient
* math.sqrt(math.log(tree_visits) / curr_node_visits)
* math.sqrt(math.log(tree_visits) / current_node.visits)
)
current_node.ucb1 = uct_ucb1
if uct_ucb1 > max_uct_ucb1:
max_uct_ucb1 = uct_ucb1
node_to_rollout = current_node
queue.extend(current_node.children)
selection_queue.extend(current_node.children)

if node_to_rollout is None:
raise ValueError("No node to rollout")
Expand All @@ -133,7 +167,7 @@ def _simulation(self, node_to_rollout: Node):
if node_to_rollout.visits == 0:
value = self._rollout(board_copy=board_copy)
else:
self._expansion(board_copy, node_to_rollout)
self._expansion(board_copy, node_to_rollout.parent)
value = self._rollout(board_copy=board_copy)

self._backpropagation(node_to_rollout, value)
Expand All @@ -157,6 +191,7 @@ def _utility(self, board_copy: Board):

def _backpropagation(self, leaf_node: Node, value: int):
leaf_node.value += value
leaf_node.visits += 1 # turn out its important to say we visited this node
p = leaf_node.parent
while p:
p.visits += 1
Expand All @@ -175,8 +210,21 @@ def _random_move(self, board: Board) -> Move:
return move

def _expansion(self, board: Board, parent_node: Node) -> None:
move = self._random_move(board)
Node(move=move, parent=parent_node)
# Seems really gross that we have to check if this layer
# of the tree has already been expanded. We might
# be able to just check if the move is fetchable already in
# the tree and by that assertion we can know that we dont
# need to create new nodes.
if parent_node.expanded:
return

for legal_move in board.legal_moves:
if self.tree.get_child_from_move(legal_move) is not None:
logger.warning(f"Already expanded this legal move somehow: {legal_move}")
continue
else:
Node(move=legal_move, parent=parent_node)
parent_node.expanded = True

@classmethod
def from_board(cls, board: Board):
Expand Down
17 changes: 2 additions & 15 deletions shogi-ai/evaluation/grimbergen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,8 @@
"""

from shogi import (
BISHOP,
GOLD,
KNIGHT,
LANCE,
PAWN,
PROM_BISHOP,
PROM_KNIGHT,
PROM_LANCE,
PROM_PAWN,
PROM_ROOK,
PROM_SILVER,
ROOK,
SILVER,
)
from shogi import (BISHOP, GOLD, KNIGHT, LANCE, PAWN, PROM_BISHOP, PROM_KNIGHT,
PROM_LANCE, PROM_PAWN, PROM_ROOK, PROM_SILVER, ROOK, SILVER)

GRIMBERGEN = {
PAWN: 1,
Expand Down
2 changes: 1 addition & 1 deletion shogi-ai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main() -> None:
agent2: RandomAgent = RandomAgent(env, player=1)

while not board.is_game_over():
agent1_action: Move = agent1.select_action()
agent1_action: Move = agent1.select_action(board)
board.push(agent1_action)
print(f"Agent 1 move: {agent1_action}")
print(f"Games simulated: {agent1.games_simulated}")
Expand Down
Empty file added shogi-ai/util/__init__.py
Empty file.
19 changes: 19 additions & 0 deletions shogi-ai/util/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import logging
import os
from datetime import datetime
from logging import getLogger
from pathlib import Path


def get_logger(name: str) -> logging.Logger:
logger = getLogger(name)

logdir = Path(os.environ.get("VIRTUAL_ENV")) / "logs"
if not logdir.exists():
logdir.mkdir()

form = '%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - %(message)s'
time = datetime.now().strftime("%Y%m%d.%H%M%S.%f")
logging.basicConfig(filename=f"{logdir}/{name}.{time}.log",
level=logging.INFO, format=form, datefmt='%H:%M:%S')
return logger

0 comments on commit dc07634

Please sign in to comment.