Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielBarton446 committed Apr 21, 2024
1 parent e99d6be commit d2d3463
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 15 deletions.
21 changes: 14 additions & 7 deletions shogi-ai/agents/mcts_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
"""

import math
import os
import random
import time
from functools import partial
from typing import List, Optional, Tuple, Callable
from typing import Callable, List, Optional, Tuple

from agents.agent import Agent
from environments.environment import Environment
Expand Down Expand Up @@ -49,7 +48,9 @@ def get_child_from_move(self, move: Move):
for child in all_children:
if child.move == move:
return child
# logger.warning(f"Couldn't find {move} in current tree:\n {[node.move for node in self.all_subchild_nodes()]}")
# Report the missing move together with the moves that are present.
# Both values are passed as lazy %-style logging arguments: the previous
# form wrapped the list in a set literal (`{[...]}`), which raises
# TypeError (lists are unhashable), and its "{move}" placeholder was a
# plain string that was never interpolated.
logger.warning(
    "%s not in tree:\n %s",
    move,
    [str(child.move) for child in all_children],
)
return None

def all_subchild_nodes(self) -> List["Node"]:
Expand Down Expand Up @@ -100,6 +101,10 @@ def __init__(self, env: Environment, player: int, strategy=None):
super().__init__(env=env, player=player, strategy=strategy)

def current_board_sims(self) -> int:
"""
Return the number of simulations that have been run for the current
tree, read from ``self.tree.visits``.
"""
return self.tree.visits

def select_action(self, board: Optional[Board] = None) -> Move:
Expand All @@ -118,10 +123,12 @@ def select_action(self, board: Optional[Board] = None) -> Move:
while time_delta < self.time_limit:
time_delta = time.time() - start_time
task_futures = self.multiproc_manager.spawn_tasks(
task_lambda=self._simulation,
selector=partial(self._selection, [self.tree]))
results = self.multiproc_manager.futures_results(task_futures,
self.time_limit)
task_lambda=self._simulation,
selector=partial(self._selection, [self.tree]),
)
results = self.multiproc_manager.futures_results(
task_futures, self.time_limit
)
for res in results:
self._backpropagation(self.tree.get_child_from_move(res[0]), res[1])

Expand Down
1 change: 1 addition & 0 deletions shogi-ai/environments/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Environment:
"""
Environment class for a shogi board.
"""

def __init__(self, board: Board):
self.board = board
self._moves: List = []
Expand Down
17 changes: 15 additions & 2 deletions shogi-ai/evaluation/grimbergen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,21 @@
"""

from shogi import (BISHOP, GOLD, KNIGHT, LANCE, PAWN, PROM_BISHOP, PROM_KNIGHT,
PROM_LANCE, PROM_PAWN, PROM_ROOK, PROM_SILVER, ROOK, SILVER)
from shogi import (
BISHOP,
GOLD,
KNIGHT,
LANCE,
PAWN,
PROM_BISHOP,
PROM_KNIGHT,
PROM_LANCE,
PROM_PAWN,
PROM_ROOK,
PROM_SILVER,
ROOK,
SILVER,
)

GRIMBERGEN = {
PAWN: 1,
Expand Down
36 changes: 30 additions & 6 deletions shogi-ai/util/multiproc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
"""
This module is a simple wrapper around the ProcessPoolExecutor
to spawn tasks and get the results from the tasks.
Example usage:
```
from util.multiproc import MultiProcManager
import random
manager = MultiProcManager(num_workers=4)
futures = manager.spawn_tasks(task_lambda=lambda: x, random.randint(1, 10))
results = manager.futures_results(futures, timeout=5)
print(results)
```
"""

import concurrent.futures
import os
from concurrent.futures import Future, ProcessPoolExecutor
Expand All @@ -16,10 +32,14 @@ class MultiProcManager:
the ProcessPoolExecutor.
"""

def __init__(self, num_workers: int = os.cpu_count() - 2):
def __init__(self, num_workers: "int | None" = None):
    """
    Store the number of worker processes used when spawning tasks.

    num_workers: how many processes to use. When omitted (None), it
    defaults to the CPU count minus two, clamped to at least one.
    """
    if num_workers is None:
        # os.cpu_count() may return None when the count cannot be
        # determined, and machines with one or two CPUs would otherwise
        # yield a non-positive worker count, which ProcessPoolExecutor
        # rejects with a ValueError.
        num_workers = max(1, (os.cpu_count() or 1) - 2)
    self.num_workers = num_workers

def spawn_tasks(self, task_lambda: Callable, **kwargs: Any) -> List[Future]:
"""
Generate a list of futures from the task_lambda function to be
created on num_workers processes.
"""
futures = []

with ProcessPoolExecutor(self.num_workers) as executor:
Expand All @@ -28,11 +48,15 @@ def spawn_tasks(self, task_lambda: Callable, **kwargs: Any) -> List[Future]:
futures.append(executor.submit(command))
return futures

# TODO: it would be nice so have List[Any] actually be
# of the type: task_lambda return type
def futures_results(self,
futures: List[concurrent.futures.Future],
timeout: int) -> List[Any]:
# It would be nice to have List[Any] actually be
# of the type: task_lambda return type
def futures_results(
self, futures: List[concurrent.futures.Future], timeout: int
) -> List[Any]:
"""
Fetch the results from the futures. This will block until all
futures are completed or the timeout is reached.
"""
results = []
for future in concurrent.futures.as_completed(futures, timeout):
results.append(future.result())
Expand Down

0 comments on commit d2d3463

Please sign in to comment.