optimization: evaluate has been optimized by creating buffers which save the results of forward passes.

Buffer use gives a 20% reduction in execution time with a small number of games, and execution time improves further with a larger number of games, since a hit in the buffer becomes more likely.
ChristianFredrikJohnsen committed Apr 27, 2024
1 parent e35d626 commit 0037e4d
Showing 10 changed files with 169 additions and 117 deletions.
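The buffering described in the commit message is, in essence, memoization of network evaluations keyed by the observation tensor. As a minimal sketch of the idea (hypothetical names, not the project's API; the actual implementation is in src/alphazero/tree_search_methods/evaluate.py below):

    import torch

    evaluation_buffer = {}  # maps an observation (as a tuple) to the cached network output

    def cached_evaluate(observation: list[float], network: torch.nn.Module):
        key = tuple(observation)              # make the observation hashable
        if key in evaluation_buffer:          # buffer hit: the forward pass is skipped entirely
            return evaluation_buffer[key]
        with torch.no_grad():
            policy, value = network(torch.tensor(observation).unsqueeze(0))
        evaluation_buffer[key] = (policy, value)  # store the result for later visits to this position
        return evaluation_buffer[key]

The more games are played per call, the more often the same positions recur (especially near the opening), which is why the hit rate, and with it the speed-up, grows with the number of games.
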
72 changes: 39 additions & 33 deletions main.py

@@ -17,21 +17,15 @@
 # This will make the overhead of creating a new multiprocessing process less significant.


-def test_overfit():
-    mp.set_start_method('spawn')
-
-    overfit_context = GameContext(
-        game_name="tic_tac_toe",
-        nn=NeuralNetwork(),
-        save_path="./models/overfit_nn"
-    )
-
+def test_overfit(context: GameContext):
+
+    mp.set_start_method('spawn')
     train_alphazero_model(
-        context=overfit_context,
-        num_games=1,
-        num_simulations=1000,
-        epochs=1000,
-        batch_size=16
+        context=context,
+        num_games=3,
+        num_simulations=100,
+        epochs=1,
+        batch_size=64
     )

 def train_tic_tac_toe(context: GameContext):
@@ -41,7 +35,7 @@ def train_tic_tac_toe(context: GameContext):
     for i in range(int(1e6)):
         train_alphazero_model(
             context=context,
-            num_games=96,
+            num_games=48,
             num_simulations=100,
             epochs=3,
             batch_size=32
@@ -56,7 +50,7 @@ def train_connect_four(context: GameContext):
     for i in range(int(1e6)):
         train_alphazero_model(
             context=context,
-            num_games=48,
+            num_games=384,
             num_simulations=100,
             epochs=3,
             batch_size=256,
@@ -75,27 +69,39 @@ def play(context: GameContext, first: bool):
         first=first
     )

-if __name__ == '__main__': # Needed for multiprocessing to work
+overfit_path = "./models/connect_four/overfit_nn"
+overfit_context = GameContext(
+    game_name="connect_four",
+    nn=NeuralNetworkConnectFour().load(overfit_path),
+    save_path="./models/overfit_waste"
+)

-    tic_tac_toe_path = "./models/test_nn"
-    tic_tac_toe_context = GameContext(
-        game_name="tic_tac_toe",
-        nn=NeuralNetwork().load(tic_tac_toe_path),
-        save_path=tic_tac_toe_path
-    )
+tic_tac_toe_path = "./models/test_nn"
+tic_tac_toe_context = GameContext(
+    game_name="tic_tac_toe",
+    nn=NeuralNetwork().load(tic_tac_toe_path),
+    save_path=tic_tac_toe_path
+)

-    connect4_path = "./models/connect_four/initial_test"
-    connect4_context = GameContext(
-        game_name="connect_four",
-        nn=NeuralNetworkConnectFour().load(connect4_path),
-        save_path=connect4_path
-    )
+connect4_path = "./models/connect_four/initial_test"
+connect4_context = GameContext(
+    game_name="connect_four",
+    nn=NeuralNetworkConnectFour().load(connect4_path),
+    save_path=connect4_path
+)
+
+
+if __name__ == '__main__': # Needed for multiprocessing to work

-    # test_overfit()
+
+    # test_overfit(overfit_context)
     # train_tic_tac_toe(tic_tac_toe_context)
     # train_connect_four(connect4_context)
-    # self_play(context)
-    play(tic_tac_toe_context, first=False)
+
+    # self_play(tic_tac_toe_context)
+    # self_play(connect4_context)
+    # play(tic_tac_toe_context, first=False)
+    play(connect4_context, first=False)

     # create_tic_tac_toe_model("initial_test")
     # create_connect_four_model("initial_test")
     # create_connect_four_model("overfit_nn")

7 changes: 4 additions & 3 deletions src/alphazero/agents/alphazero_play_agent.py

@@ -10,23 +10,24 @@ class AlphaZero:

     def __init__(self, context: GameContext):
         self.context = context
+        self.shape = [1] + context.game.observation_tensor_shape()
         self.c = 4.0 # Exploration constant

     def run_simulation(self, state, num_simulations=800): # Num-simulations 800 is good for tic-tac-toe
         """
         Selection, expansion & evaluation, backpropagation.
         """
-        root_node = Node(parent=None, state=state, action=None, policy_value=None) # Initialize root node.
-        policy, value = evaluate(root_node, self.context) # Evaluate the root node
+        root_node = Node(parent=None, state=state, action=None, policy_value=None)
+        policy, value = evaluate(root_node.state.observation_tensor(), self.shape, self.context.nn, self.context.device)
         print("Root node value: ", value)

         for _ in range(num_simulations): # Do selection, expansion & evaluation, backpropagation

             node = vectorized_select(root_node, self.c)

             if not node.state.is_terminal():
-                policy, value = evaluate(node, self.context) # Evaluate the node, using the neural network
+                policy, value = evaluate(node.state.observation_tensor(), self.shape, self.context.nn, self.context.device)
                 expand(node, policy)

             else:

10 changes: 7 additions & 3 deletions src/alphazero/agents/alphazero_training_agent.py

@@ -36,6 +36,11 @@ def __init__(self, context: GameContext, c: float = 4.0, alpha: float = 0.3, eps
         Contains useful information like the game, neural network and device.
         """

+        self.shape: list[int] = [1] + context.game.observation_tensor_shape()
+        """
+        The shape which the state tensor must have in order to be compatible with the neural network.
+        """
+
         self.c = c
         """
         An exploration constant, used when calculating PUCT-values.
@@ -59,7 +64,6 @@ def __init__(self, context: GameContext, c: float = 4.0, alpha: float = 0.3, eps
         After temperature_moves, the move played is deterministically the one visited the most.
         """

-    # @profile
     def run_simulation(
         self, state: pyspiel.State, move_number: int, num_simulations: int = 800
     ) -> tuple[int, torch.Tensor]:
@@ -69,7 +73,7 @@ def run_simulation(
         """
         try:
             root_node = Node(parent=None, state=state, action=None, policy_value=None)
-            policy, value = evaluate(root_node, self.context) # Evaluate the root node
+            policy, value = evaluate(root_node.state.observation_tensor(), self.shape, self.context.nn, self.context.device) # Evaluate the root node
             dirichlet_expand(root_node, policy, self.a, self.e)
             backpropagate(root_node, value)

@@ -78,7 +82,7 @@
                 node = vectorized_select(root_node, self.c)

                 if not node.state.is_terminal():
-                    policy, value = evaluate(node, self.context)
+                    policy, value = evaluate(node.state.observation_tensor(), self.shape, self.context.nn, self.context.device)
                     expand(node, policy)

                 else:

37 changes: 20 additions & 17 deletions src/alphazero/alphazero_generate_training_data.py

@@ -27,9 +27,9 @@ def play_alphazero_game(
     while not state.is_terminal():
         action, probability_target = alphazero.run_simulation(state, move_number, num_simulations=num_simulations)
         game_data.append((
-            reshape_pyspiel_state(state, alphazero.context),
-            probability_target
-        ))
+                reshape_pyspiel_state(state, alphazero.context),
+                probability_target
+            ))
         state.apply_action(action)
         move_number += 1

@@ -57,7 +57,6 @@ def play_alphazero_games(
         training_data.extend(play_alphazero_game(alphazero, num_simulations))
     return training_data

-
 def generate_training_data(alphazero: AlphaZero, num_games: int, num_simulations: int = 100) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Takes in a neural network, and generates training data by making the neural network play games against itself.
@@ -79,19 +78,23 @@ def generate_training_data(alphazero: AlphaZero, num_games: int, num_simulations

     training_data = []

-    # result_list = [play_alphazero_games(alphazero, num_games, num_simulations)] # Single-threaded
-    multicore_args, thread_count = get_play_alphazero_games_arguments(alphazero, num_games, num_simulations)
-    try:
-        print(f"Generating training data with {thread_count} threads...")
-        start_time = time.time()
-        with mp.Pool(thread_count) as pool:
-            result_list = list(tqdm(pool.starmap(play_alphazero_games, multicore_args)))
-        end_time = time.time()
-        print(f"Generated training data with {thread_count} threads in {end_time - start_time:.2f} seconds.")
-
-    except KeyboardInterrupt:
-        print("KeyboardInterrupt: Terminating training data generation...")
-        raise
+    start_time = time.time()
+    result_list = [play_alphazero_games(alphazero, num_games, num_simulations)] # Single-threaded
+    end_time = time.time()
+    print(f"Generated training data in {end_time - start_time:.2f} seconds.")
+
+    # multicore_args, thread_count = get_play_alphazero_games_arguments(alphazero, num_games, num_simulations)
+    # try:
+    #     print(f"Generating training data with {thread_count} threads...")
+    #     start_time = time.time()
+    #     with mp.Pool(thread_count) as pool:
+    #         result_list = list(tqdm(pool.starmap(play_alphazero_games, multicore_args)))
+    #     end_time = time.time()
+    #     print(f"Generated training data with {thread_count} threads in {end_time - start_time:.2f} seconds.")
+
+    # except KeyboardInterrupt:
+    #     print("KeyboardInterrupt: Terminating training data generation...")
+    #     raise

     for i in range(len(result_list)):
         training_data.extend(result_list[i])

53 changes: 47 additions & 6 deletions src/alphazero/tree_search_methods/evaluate.py

@@ -1,13 +1,54 @@
 import torch
-from src.alphazero.node import Node
-from src.utils.game_context import GameContext
-from src.utils.nn_utils import forward_state
+from src.neuralnet.neural_network import NeuralNetwork

+state_tensor_buffer = {}
+policy_value_buffer = {}

-def evaluate(node: Node, context: GameContext) -> tuple[torch.Tensor, float]:
+def get_state_tensor(observation_tensor: list[int], shape: list[int], device: torch.device) -> torch.Tensor:
+    """
+    Get the state tensor for the given observation.
+    If the state tensor has already been calculated, return it from the buffer.
+    Otherwise, calculate the state tensor and store it in the buffer.
+
+    Parameters:
+    - observation_tensor: list[int] - The observation tensor of the state to evaluate
+    - shape: list[int] - The shape the state tensor must have to be compatible with the neural network
+    - device: torch.device - The device on which the state tensor is created
+
+    Returns:
+    - torch.Tensor - The state tensor for the given observation
+    """
+    observation_key = tuple(observation_tensor)
+    if observation_key in state_tensor_buffer:
+        return state_tensor_buffer[observation_key]
+    else:
+        state_tensor = torch.tensor(observation_key, device=device).reshape(shape)
+        state_tensor_buffer[observation_key] = state_tensor
+        return state_tensor
+
+def evaluate(observation_tensor: list[int], shape: list[int], nn: NeuralNetwork, device: torch.device) -> tuple[torch.Tensor, float]:
     """
     Neural network evaluation of the state of the input node.
     Will not be run on a leaf node (terminal state)
+    Forward propagates the state tensor through the neural network, reshaping it behind the scenes
+    to make it compatible with the neural network. Results are cached in policy_value_buffer, so
+    repeated evaluations of the same observation skip the forward pass.
+
+    Parameters:
+    - observation_tensor: list[int] - The observation tensor of the state to evaluate
+    - shape: list[int] - The shape the state tensor must have to be compatible with the neural network
+    - nn: NeuralNetwork - The neural network used for the evaluation
+    - device: torch.device - The device on which the forward pass is run
+
+    Returns:
+    - tuple[torch.Tensor, float] - The policy and value output of the neural network for the given observation
     """
-    policy, value = forward_state(node.state, context)
-    return policy, value.item()
+    observation_key = tuple(observation_tensor)
+    if observation_key in policy_value_buffer:
+        return policy_value_buffer[observation_key]
+    else:
+        state_tensor = get_state_tensor(observation_tensor, shape, device)
+        with torch.no_grad(): ## Disable gradient calculation
+            policy, value = nn.forward_for_alphazero(state_tensor)
+        policy_value_buffer[observation_key] = (policy, value)
+        return policy, value

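For illustration, here is how the new evaluate signature is called from the agents above; a short usage sketch, not part of the commit, where context and state mirror the names used elsewhere in this diff:

    # Usage sketch (assumes a GameContext and a pyspiel state as used elsewhere in this commit).
    shape = [1] + context.game.observation_tensor_shape()                     # batch dimension + observation shape
    observation = state.observation_tensor()                                  # list of values describing the position
    policy, value = evaluate(observation, shape, context.nn, context.device)  # first call: forward pass, result buffered
    policy, value = evaluate(observation, shape, context.nn, context.device)  # repeat call: served from policy_value_buffer
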
9 changes: 9 additions & 0 deletions src/neuralnet/neural_network.py

@@ -35,6 +35,7 @@
         self.hidden_dimension = hidden_dimension
         self.input_dimension = input_dimension
         self.res_blocks = res_blocks
+        self.legal_moves = legal_moves

         self.initial = nn.Sequential(
             nn.Conv2d(
@@ -79,6 +80,14 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         value = self.value(x)
         return policy, value

+    def forward_for_alphazero(self, x: torch.Tensor) -> tuple[torch.Tensor, float]:
+        x = self.initial(x)
+        for residual_block in self.residual_blocks:
+            x = residual_block(x)
+        policy = self.policy(x).reshape(self.legal_moves)
+        value = self.value(x).item()
+        return policy, value
+
     def save(self, path: str) -> None:
         directory = os.path.dirname(path)

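The new method runs the same pipeline as forward() but assumes a single reshaped state (as produced by get_state_tensor above) and strips the batch dimension for tree search. A hedged comparison sketch, with names taken from this diff:

    # Sketch: forward() keeps the batch dimension, forward_for_alphazero() collapses it.
    state_tensor = get_state_tensor(state.observation_tensor(), shape, device)  # shape [1, *observation_shape]
    policy_batch, value_batch = nn.forward(state_tensor)    # tensors with a leading batch dimension
    policy, value = nn.forward_for_alphazero(state_tensor)  # policy reshaped to (legal_moves,), value as a plain float
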
9 changes: 9 additions & 0 deletions src/neuralnet/neural_network_connect_four.py

@@ -36,6 +36,7 @@
         self.hidden_dimension = hidden_dimension
         self.input_dimension = input_dimension
         self.res_blocks = res_blocks
+        self.legal_moves = legal_moves

         self.initial = nn.Sequential(
             nn.Conv2d(
@@ -80,6 +81,14 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         value = self.value(x)
         return policy, value

+    def forward_for_alphazero(self, x: torch.Tensor) -> tuple[torch.Tensor, float]:
+        x = self.initial(x)
+        for residual_block in self.residual_blocks:
+            x = residual_block(x)
+        policy = self.policy(x).reshape(self.legal_moves)
+        value = self.value(x).item()
+        return policy, value
+
     def save(self, path: str) -> None:
         directory = os.path.dirname(path)

4 changes: 2 additions & 2 deletions src/utils/multi_core_utils.py

@@ -17,8 +17,8 @@ def get_play_alphazero_games_arguments(
     - number_of_threads: int - The number of threads to use for multiprocessing
     """

-    max_num_threads = mp.cpu_count()
-    number_of_threads = max(1, min(max_num_threads, num_games // 4)) # We estimate that we should have at least 4 games per process to get the best time efficiency.
+    max_num_threads = mp.cpu_count() - 1
+    number_of_threads = max(1, min(max_num_threads, num_games // 20)) # We estimate that we should have at least 20 games per process to get the best time efficiency.

     num_games_per_thread = num_games // number_of_threads
     remainder = num_games % number_of_threads

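As a worked example of the new heuristic (illustrative numbers only, assuming a 16-core machine):

    num_games = 384                                                      # the connect four setting in main.py
    max_num_threads = 16 - 1                                             # mp.cpu_count() - 1 on the assumed machine
    number_of_threads = max(1, min(max_num_threads, num_games // 20))    # min(15, 19) = 15
    num_games_per_thread = num_games // number_of_threads                # 25 games per thread
    remainder = num_games % number_of_threads                            # 9 extra games to spread over the threads
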
21 changes: 0 additions & 21 deletions src/utils/nn_utils.py

@@ -2,27 +2,6 @@
 import pyspiel
 from src.utils.game_context import GameContext

-def forward_state(state: torch.Tensor, context: GameContext) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Forward propagates the state tensor through the neural network.
-    Does some reshaping behind the scenes to make the state tensor compatible with the neural network.
-
-    Parameters:
-    - state: torch.Tensor - The state tensor to forward propagate
-    - context: GameContext - Information about the shape of the state tensor, neural network and device.
-
-    Returns:
-    - torch.Tensor - The output of the neural network after forward propagating the state tensor
-    """
-    shape = context.game.observation_tensor_shape() ## Get the shape of the state tensor
-    state_tensor = torch.reshape(torch.tensor(state.observation_tensor(), device=context.device), shape).unsqueeze(0) ## Reshape the state tensor to the correct shape and add a batch dimension
-
-    with torch.no_grad(): ## Disable gradient calculation
-        policy, value = context.nn.forward(state_tensor) ## Forward propagate the state tensor through the neural network
-    del state_tensor ## Delete the state tensor to free up memory
-
-    return policy.squeeze(0), value.squeeze(0) ## Remove the batch dimension from the output tensors and return them
-
 def reshape_pyspiel_state(state: pyspiel.State, context: GameContext) -> torch.Tensor:
     """
     Reshapes the pyspiel state tensor to the correct shape for the neural network.