Skip to content

Commit

Permalink
optimization: dirichlet_expand is now probably 4 times faster, even t…
Browse files Browse the repository at this point in the history
…hough not much work has been done on it. The method is called so rarely that it doesn't matter much, but it's still nice to see that it's faster.
  • Loading branch information
ChristianFredrikJohnsen committed Apr 27, 2024
1 parent 1470892 commit bf2dfd3
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 188 deletions.
24 changes: 11 additions & 13 deletions src/alphazero/tree_search_methods/expand.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import torch
from torch.distributions.dirichlet import Dirichlet
from src.alphazero.node import Node
from src.utils.game_context import GameContext
from src.utils.tensor_utils import normalize_policy_values_with_noise
from src.utils.random_utils import generate_dirichlet_noise


# @profile
Expand Down Expand Up @@ -54,16 +53,15 @@ def dirichlet_expand(context: GameContext, node: Node, nn_policy_values: torch.T
-> [0.3931, 0.2297, 0.3772]
"""
state = node.state
legal_actions = state.legal_actions()
noise = Dirichlet(torch.tensor([alpha] * len(legal_actions), dtype=torch.float)).sample()
nn_policy_values = nn_policy_values.cpu()
policy_values_with_noise = torch.softmax(nn_policy_values[legal_actions], dim=0).mul_(epsilon).add_(noise.mul_(1 - epsilon))
node.set_children_policy_values(policy_values_with_noise)

legal_actions = node.state.legal_actions()
noise = generate_dirichlet_noise(context, len(legal_actions), alpha)
normalize_policy_values_with_noise(nn_policy_values, legal_actions, noise, epsilon)
policy_values = nn_policy_values.to("cpu")
node.set_children_policy_values(policy_values[legal_actions])

for action in legal_actions: # Add the children with correct policy values
new_state = node.state.clone()
children = node.children
for action, policy_value in zip(legal_actions, policy_values_with_noise):
new_state = state.clone()
new_state.apply_action(action)
node.children.append(
Node(node, new_state, action, policy_values[action])
)
children.append(Node(node, new_state, action, policy_value))
15 changes: 0 additions & 15 deletions src/utils/random_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
import torch
from torch.distributions.dirichlet import Dirichlet
from src.alphazero.node import Node
from src.utils.game_context import GameContext

def generate_dirichlet_noise(context: GameContext, num_legal_actions: int, alpha: float) -> torch.Tensor:
    """
    Sample one Dirichlet noise vector with a symmetric concentration parameter.

    The noise is used to encourage exploration in the policy values.
    The Dirichlet distribution is a multivariate generalization of the Beta distribution.

    Parameters:
    - context: GameContext - Supplies `device`, the device the concentration tensor is created on
    - num_legal_actions: int - The number of legal actions in the current state (length of the noise vector)
    - alpha: float - The concentration parameter of the Dirichlet distribution, shared by every component

    Returns:
    - torch.Tensor - The Dirichlet noise tensor, one float entry per legal action
    """
    # torch.full creates the concentration vector directly on the target device,
    # without first materialising a Python list of repeated alphas.
    concentration = torch.full((num_legal_actions,), alpha, dtype=torch.float, device=context.device)
    return Dirichlet(concentration).sample()


def generate_probabilty_target(root_node: Node, context: GameContext) -> torch.Tensor:
"""
Expand Down
61 changes: 0 additions & 61 deletions src/utils/tensor_utils.py

This file was deleted.

Empty file removed test/mcts/__init__.py
Empty file.
23 changes: 0 additions & 23 deletions test/utils/test_random_utils.py

This file was deleted.

76 changes: 0 additions & 76 deletions test/utils/test_tensor_utils.py

This file was deleted.

0 comments on commit bf2dfd3

Please sign in to comment.