Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AlphaZero subtree persistence #86

Closed
wants to merge 20 commits into from
Closed

Conversation

lowrollr
Copy link

@lowrollr lowrollr commented Jan 6, 2024

Requested by #51, this PR introduces the capability to pass a Tree to muzero_policy and gumbel_muzero_policy, allowing for MCTS to continue from a pre-initialized tree.

The main use-case is for users implementing AlphaZero, where environment dynamics are known, not modeled and therefore saving work from a previous MCTS call becomes useful.

I introduce a new public API function get_subtree, which extracts a subtree rooted at a given root child index, which can be utilized by AlphaZero-esque implementations to extract the subtree corresponding to a taken action.

I also include a utility function reset_search_tree, which can be used to reset/zero out the search tree, useful in the case of a terminated episode where the search tree can be discarded.

Including this feature within an AlphaZero implementation might look something like this (pseudo-code)

output = mctx.muzero_policy(..., tree=tree)
tree = mctx.get_subtree(output.search_tree, output.action)
terminated = env.step(output.action)
tree = mctx.reset_search_tree(tree, terminated)

In the case where no trees have been initialized mctx.muzero_policy(..., tree=None) still works and will instantiate a new search tree (as before).

I've also decoupled num_simulations from the capacity of the search tree, which is now specified as an argument to muzero_policy or gumbel_muzero_policy called max_nodes. If max_nodes is not specified, the tree capacity defaults to num_simulations (just as it worked before). This is useful in the case of AlphaZero, where the number of occupied nodes in the search tree may grow/shrink from call to call so it's useful to include extra capacity.

I also included tests for get_subtree that run on each of the existing test pytrees. The tests run get_subtree on each of the root children and compare against the source tree. I'd be happy to only run on a subset of the child nodes if test runtime is too long (~60s total on my machine).

Calls the public API work as they did before, I did not introduce any new mandatory arguments. Happy to re-organize & re-tool any of these changes if the maintainers have suggestions.

@lowrollr lowrollr changed the title Utility functions for AlphaZero subtree persistence functions for AlphaZero subtree persistence Jan 7, 2024
@lowrollr lowrollr changed the title functions for AlphaZero subtree persistence AlphaZero subtree persistence Jan 7, 2024
@lowrollr
Copy link
Author

lowrollr commented Jan 7, 2024

I thought of one concern regarding the Tree property num_simulations. The number of simulations that a particular Tree object supported used to be equivalent to its capacity, but in this PR this is no longer the case, which could make the name of this property deceiving (as it now just tied to capacity, or maximum number of simulations).

@fidlej
Copy link
Collaborator

fidlej commented Jan 14, 2024

Thanks for trying the get_subtree() and sending the PR.
Sorry for my slow response.

I worry that the subtree reuse is not compatible with the current gumbel_muzero_policy implementation.
That policy assumes that the tree starts empty. To implement the sequential halving, the action selection uses a simulation_index.
https://github.com/google-deepmind/mctx/blob/d40d32e1a18fb73030762bac33819f95fff9787c/mctx/_src/action_selection.py#L145C3-L145C19

@lowrollr
Copy link
Author

lowrollr commented Jan 15, 2024

I see -- I'm not aware of a good way to incorporate any existing visit counts into the sequential halving algorithm, especially given that they were generated by the interior action selection algorithm -- perhaps devising a way to do this would be a good research problem but is probably out of scope for this PR.

I will remove the option for subtree reuse from gumbel_muzero_policy and just allow it for muzero_policy.
If you'd prefer, I could instead create a new policy alphazero_policy that allows for subtree reuse and is otherwise identical to muzero_policy and restore muzero_policy to the way it was before. I wanted to minimize changes to the public API but this could help disambiguate.

@fidlej
Copy link
Collaborator

fidlej commented Jan 16, 2024

Thanks for the comment.
Are you sure that the implementation works correctly?
I left some comments on the code, but I have not checked everything.

@lowrollr
Copy link
Author

lowrollr commented Jan 16, 2024

Thanks for the comment.
Are you sure that the implementation works correctly?

As far as I can tell -- all subtrees of the provided test trees are reproduced accurately in the tests I wrote. I also tested the feature in the Connect 4 example notebook linked in the readme and had no issues.

I'd be happy to write some more granular test cases if you'd like.

I left some comments on the code, but I have not checked everything.

I'm not able to see your comments yet

Copy link
Collaborator

@fidlej fidlej left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, my comments were pending.

mctx/_src/search.py Show resolved Hide resolved
mctx/_src/search.py Show resolved Hide resolved
tree = expand(
params, expand_key, tree, recurrent_fn, parent_index,
action, next_node_index)
# if next_node_index goes out of bounds (i.e. no room left for new nodes)
# backward its (in-bounds) parent
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we bound the next_node_index before calling expand()?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm exploiting that out-of-bounds updates are no-ops in JAX here, otherwise I'd have to change some of the logic in expand(). If this is a bad pattern I can try something else.

I assumed that if the tree is full we do not want to overwrite any already-expanded node, but still backpropagate the value normally as if we did do an expansion. (this is why I put the out of bounds logic after expand() )

invalid_actions: a mask with invalid actions. Invalid actions
have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
max_depth: maximum search tree depth allowed during simulation.
max_nodes: maximum number of nodes allowed in the search tree. If `None`,
max_nodes == num_simulations + 1. This only applies when `tree` is `None`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to specify max_nodes? Cannot we always deduce the max_nodes from the num_simulations?

Copy link
Author

@lowrollr lowrollr Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea behind max_nodes is that when initializing a Tree we need to choose a capacity. When the tree is discarded after each call to search(), we can just initialize a tree with capacity = num_simulations.

However, in the case where we want to re-use a tree, I thought it might be useful to decouple num_simulations from tree capacity, s.t. tree capacity >= num_simulations, allowing for extra room for the next search call's expanded nodes.

In the case where a selected subtree contains on average S nodes, one might want to set the tree capacity to some value >= S + num_simulations, so that there is room for most node expansions. How high/low to set this capacity relative to num_simulations becomes a problem-dependent memory/(accuracy?) trade-off.

I do think that having max_nodes alongside num_simulations could be confusing, perhaps a cleaner way to organize this would be:

  • change tree to be a mandatory argument to search()
    • muzero_policy etc. each call instantiate_tree_from_root and pass the initialized tree to search()
    • lets us remove max_nodes argument to search() (and coincidentally root, extra_data, and root_invalid_actions)
  • create a new policy alphazero_policy:
    • [option 1]: accepts an optional tree and max_nodes, initializes a tree with capacity max_nodes if a tree is not passed as an argument
    • [option 2]: requires an initialized tree to be passed (would need another new public API function to initialize one)
  • remove max_nodes and tree as arguments to muzero_policy and only support tree re-use in alphazero_policy

This isolates passing a subtree to alphazero_policy alone, which I like given that you'd never actually want to pass a subtree in MuZero. Maybe it could even return the subtree as part of the output??

I implemented these proposed changes (w/ option 1) in a branch on my mctx fork: https://github.com/lowrollr/mctx/tree/alphazero_policy

I also pushed a modified version of the Connect 4 example to that branch that uses alphazero_policy and get_subtree. If you feel this is a good approach I could write some unit tests for alphazero_policy as well.

@fidlej
Copy link
Collaborator

fidlej commented Jan 17, 2024

Thanks for the clarifications.
You understand the code well.

Would it be OK to keep the functionality unmerged?
If people want this alphazero-specific functionality, they can look at your repository.

@lowrollr
Copy link
Author

lowrollr commented Jan 17, 2024

You mention AlphaZero in the readme, so in my opinion supporting subtree re-use should be included functionality.

I understand wanting to keep the codebase as lightweight and simple as possible. If you have specific concerns, constraints, or parts of the code you'd prefer be left unchanged I'd be happy to work around them to get this ok to merge.

@fidlej
Copy link
Collaborator

fidlej commented Jan 17, 2024

I want to ensure that mctx will work correctly
and I currently do not have time to carefully review the proposed changes.
Mctx will probably remain mostly frozen.

@lowrollr
Copy link
Author

That is understandable.
In that case I can document the new functionality in my repository and provide a few examples.
Would appreciate you adding a link to my repo in 'Example Projects' in the mctx readme when I am done.

Thank you for taking a look at my code, I admire this repo a lot.

@fidlej
Copy link
Collaborator

fidlej commented Jan 18, 2024

Thank you.
When you are ready with your repo, please ping me and I will add the link.

@lowrollr lowrollr closed this Jan 18, 2024
@lowrollr
Copy link
Author

Here's the link: https://github.com/lowrollr/mctx-az

@fidlej

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants