Skip to content

Commit

Permalink
feat: mcts agent work
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielBarton446 committed Apr 7, 2024
1 parent 8cbd79b commit 4483594
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 8 deletions.
3 changes: 2 additions & 1 deletion shogi-ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ class Agent:
Agent should not be used for anything other than inheritance.
"""

def __init__(self, env: Environment, strategy=None):
def __init__(self, env: Environment, player: int, strategy=None):
self._env = env
self.player = player
self.strategy = strategy

def select_action(self):
Expand Down
15 changes: 14 additions & 1 deletion shogi-ai/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def __init__(self, board: Board):
self._moves: List = []
self._last_state = board.piece_bb

def get_player(self):
"""
Get the player of the environment.
"""
return self.board.turn

@property
def action_space(self):
"""
Expand All @@ -40,7 +46,14 @@ def action_space(self):
self._moves.append(move)
return self._moves

def from_board(self, board: Board):
def is_terminal_state(self) -> bool:
"""
Check if the game is in a terminal state.
"""
return self.board.is_game_over()

@classmethod
def from_board(cls, board: Board):
"""
currently just is another constructor basically.
"""
Expand Down
16 changes: 14 additions & 2 deletions shogi-ai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,33 @@
"""

import shogi
import sys
from environment import Environment
from random_agent import RandomAgent
from mcts_agent import MctsAgent
from shogi import Board, Move

sys.stdout = open(sys.stdout.fileno(), mode="w", encoding="utf-8", buffering=1)


def main() -> None:
"""
Main function for the example.
"""
board: Board = shogi.Board()

agent1: RandomAgent = RandomAgent.from_board(board)
agent1: MctsAgent = MctsAgent.from_board(board)
env: Environment = agent1.env
agent2: RandomAgent = RandomAgent(env)

agent2: RandomAgent = RandomAgent(env, player=1)

while not board.is_game_over():
agent1_action: Move = agent1.select_action()
board.push(agent1_action)
print(f"Agent 1 move: {agent1_action}")
print(f"Games simulated: {agent1.games_simulated}")
print(board)
print(f"Move: {len(board.move_stack)}")
if board.is_game_over():
break

Expand All @@ -36,6 +45,9 @@ def main() -> None:
print("Final State of board:")
print(board)
print(f"Number of moves {len(board.move_stack)}")
with open("game.txt", "w") as f:
for move in board.move_stack:
f.write(str(move) + "\n")


if __name__ == "__main__":
Expand Down
136 changes: 136 additions & 0 deletions shogi-ai/mcts_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@

import time
import random

from shogi import Board
from shogi import Move
from environment import Environment
from agent import Agent
from typing import Optional


class Node:
# we are at first going to have a max depth of recorded
# nodes of 1. This might be a bad idea, but we will see

def __init__(self, move: Optional[Move], parent=None):
self.move = move
self.parent = parent
self.explored_children = []
self.visits = 1
self.value = 0

if parent:
parent.explored_children.append(self)

p = self.parent
while p:
p.visits += 1
p = p.parent

def get_child_from_move(self, move: Move):
for child in self.explored_children:
if child.move == move:
return child
return None


class MctsAgent(Agent):
"""
Monte Carlo Tree Search Agent class for shogi.
```
from shogi import Board
from mcts_agent import MctsAgent
board = Board()
agent = MctsAgent.from_board(board, player=0)
move: Move = agent.select_action(board)
```
"""

def __init__(self, env: Environment, player: int, strategy=None):
strategy = "mcts"
self.time_limit = 5
self.tree = Node(move=None, parent=None)
self.games_simulated = 0
super().__init__(env=env, player=player, strategy=strategy)

def select_action(self):
if self.player != self.env.board.turn:
raise ValueError("Not the MCTS_AGENT's turn")

self.tree = Node(move=None, parent=None)
self.games_simulated = 0
start_time = time.time()
time_delta = 0

while time_delta < self.time_limit:
time_delta = time.time() - start_time
self._simulation()
self.games_simulated += 1

best_node = max(self.tree.explored_children, key=lambda n: n.value)

return best_node.move

def _simulation(self):
base_board = Board(self._env.board.sfen())
node = self._expansion(board=base_board)
if node.visits == 0:
value = self._rollout(board_copy=base_board, move=node.move)
else:
board_after_explored_node = Board(base_board.sfen())
board_after_explored_node.push(node.move)

# need to consider when we pick a move that
# loses us the game
if board_after_explored_node.is_game_over():
value = self._utility(board_after_explored_node)
self._backpropagation(node, value)

expanded_move = self._select_random_move(board=board_after_explored_node)
value = self._rollout(board_copy=board_after_explored_node,
move=expanded_move)

self._backpropagation(node, value)

def _rollout(self, board_copy: Board, move: Move) -> int:
board_copy.push(move)

while not board_copy.is_game_over():
move = self._select_random_move(board_copy)
board_copy.push(move)

return _utility(board_copy)

def _utility(self, board_copy: Board):
if board_copy.is_checkmate():
if board_copy.turn != self.player:
return 1
else:
return -1
return 0

def _backpropagation(self, node: Node, value: int):
node.value += value

def _select_random_move(self, board: Board) -> Move:
moves = [move for move in board.legal_moves]
return random.choice(moves)

def _expansion(self, board: Board) -> Node:
move = self._select_random_move(board)

existing_node = self.tree.get_child_from_move(move)
if existing_node:
existing_node.visits += 1
return existing_node

node = Node(move, self.tree)
return node

@classmethod
def from_board(cls, board: Board):
return MctsAgent(env=Environment(board), player=board.turn)
11 changes: 7 additions & 4 deletions shogi-ai/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from environment import Environment
from shogi import Board
from shogi.Move import Move
from typing import Optional


class RandomAgent(Agent):
Expand All @@ -28,14 +29,16 @@ class RandomAgent(Agent):
```
"""

def __init__(self, env: Environment, strategy=None):
def __init__(self, env: Environment, player: int, strategy=None):
strategy = "random"
super().__init__(env, strategy)
super().__init__(env=env, player=player, strategy=strategy)

def select_action(self) -> Move:
def select_action(self, board: Optional[Board] = None) -> Move:
if self.player != self.env.board.turn:
raise ValueError("Not the player's turn")
legal_moves = self._env.action_space
return random.choice(legal_moves)

@classmethod
def from_board(cls, board: Board):
return RandomAgent(Environment(board))
return RandomAgent(Environment(board), player=board.turn)

0 comments on commit 4483594

Please sign in to comment.