Skip to content

Commit

Permalink
feat: modules for significant sections
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielBarton446 committed Apr 12, 2024
1 parent 2f1b970 commit ffc5af5
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 88 deletions.
Empty file added shogi-ai/agents/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion shogi-ai/agent.py → shogi-ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Agent class is the base class for all agents.
"""

from environment import Environment
from env.environment import Environment
from shogi import Board


Expand Down
45 changes: 30 additions & 15 deletions shogi-ai/mcts_agent.py → shogi-ai/agents/mcts_agent.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,41 @@
"""
A monte-carlo tree search for playing shogi
"""

import time
import random
import math
import random
import time
from typing import List, Optional

from shogi import Board
from shogi import Move
from environment import Environment
from agent import Agent
from typing import Optional
from agents.agent import Agent
from env.environment import Environment
from shogi import Board, Move


class Node:
class Node: # pylint: disable=too-few-public-methods
"""
Node class for moves in the MCTS tree.
"""

# we are at first going to have a max depth of recorded
# nodes of 1. This might be a bad idea, but we will see

def __init__(self, move: Optional[Move], parent=None):
self.move = move
self.parent = parent
self.children = []
self.children: List[Node] = []
self.visits = 0
self.value = 0

if parent:
parent.children.append(self)

def get_child_from_move(self, move: Move):
"""
Fetch a child move if it exists for the given move.
"""
for child in self.children:
if child.move == move:
return child
Expand Down Expand Up @@ -92,14 +103,19 @@ def _selection(self) -> Node:
curr_node_visits = max(1, current_node.visits)
tree_visits = max(1, self.tree.visits)

uct_ucb1 = current_node.value / curr_node_visits + self.exploration_coefficient * math.sqrt(
math.log(tree_visits) / curr_node_visits
uct_ucb1 = (
current_node.value / curr_node_visits
+ self.exploration_coefficient
* math.sqrt(math.log(tree_visits) / curr_node_visits)
)
if uct_ucb1 > max_uct_ucb1:
max_uct_ucb1 = uct_ucb1
node_to_rollout = current_node
queue.extend(current_node.children)

if node_to_rollout is None:
raise ValueError("No node to rollout")

return node_to_rollout

def _simulation(self, node_to_rollout: Node):
Expand Down Expand Up @@ -136,8 +152,7 @@ def _utility(self, board_copy: Board):
if board_copy.is_checkmate():
if board_copy.turn != self.player:
return 1
else:
return -1
return -1
return 0

def _backpropagation(self, leaf_node: Node, value: int):
Expand All @@ -150,7 +165,7 @@ def _backpropagation(self, leaf_node: Node, value: int):

def _random_move(self, board: Board) -> Move:
self.positions_checked += 1
moves = [move for move in board.pseudo_legal_moves]
moves = list(board.pseudo_legal_moves)
move = None
while True:
move = random.choice(moves)
Expand All @@ -161,7 +176,7 @@ def _random_move(self, board: Board) -> Move:

def _expansion(self, board: Board, parent_node: Node) -> None:
move = self._random_move(board)
node = Node(move=move, parent=parent_node)
Node(move=move, parent=parent_node)

@classmethod
def from_board(cls, board: Board):
Expand Down
7 changes: 3 additions & 4 deletions shogi-ai/random_agent.py → shogi-ai/agents/random_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@

import random

from agent import Agent
from environment import Environment
from agents.agent import Agent
from env.environment import Environment
from shogi import Board
from shogi.Move import Move
from typing import Optional


class RandomAgent(Agent):
Expand All @@ -33,7 +32,7 @@ def __init__(self, env: Environment, player: int, strategy=None):
strategy = "random"
super().__init__(env=env, player=player, strategy=strategy)

def select_action(self, board: Optional[Board] = None) -> Move:
def select_action(self) -> Move:
if self.player != self.env.board.turn:
raise ValueError("Not the player's turn")
legal_moves = self._env.action_space
Expand Down
60 changes: 0 additions & 60 deletions shogi-ai/environment.py

This file was deleted.

Empty file added shogi-ai/evaluation/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions shogi-ai/evaluation/grimbergen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
This contains the mapping used for
the Grimbergen evaluation of pieces.
"""

from shogi import (
BISHOP,
GOLD,
KNIGHT,
LANCE,
PAWN,
PROM_BISHOP,
PROM_KNIGHT,
PROM_LANCE,
PROM_PAWN,
PROM_ROOK,
PROM_SILVER,
ROOK,
SILVER,
)

GRIMBERGEN = {
PAWN: 1,
LANCE: 3,
KNIGHT: 3,
SILVER: 5,
GOLD: 5,
BISHOP: 8,
ROOK: 9,
PROM_PAWN: 5,
PROM_LANCE: 5,
PROM_KNIGHT: 5,
PROM_SILVER: 5,
PROM_BISHOP: 12,
PROM_ROOK: 13,
}
12 changes: 4 additions & 8 deletions shogi-ai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
"""

import shogi
import sys
from environment import Environment
from random_agent import RandomAgent
from mcts_agent import MctsAgent
from agents.mcts_agent import MctsAgent
from agents.random_agent import RandomAgent
from env.environment import Environment
from shogi import Board, Move

sys.stdout = open(sys.stdout.fileno(), mode="w", encoding="utf-8", buffering=1)


def main() -> None:
"""
Expand All @@ -29,7 +26,6 @@ def main() -> None:

agent2: RandomAgent = RandomAgent(env, player=1)

agent1_action: Move = agent1.select_action()
while not board.is_game_over():
agent1_action: Move = agent1.select_action()
board.push(agent1_action)
Expand All @@ -50,7 +46,7 @@ def main() -> None:
print(f"Simulated games: {agent1.games_simulated}")
print(f"Rollouts: {agent1.rollouts}")
print(f"Positions Checked: {agent1.positions_checked}")
with open("game.txt", "w") as f:
with open("game.txt", "w", encoding="utf-8") as f:
for move in board.move_stack:
f.write(str(move) + "\n")

Expand Down

0 comments on commit ffc5af5

Please sign in to comment.