-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AlphaZero subtree persistence #86
Conversation
I thought of one concern regarding the |
Thanks for trying the get_subtree() and sending the PR. I worry that the subtree reuse is not compatible with the current gumbel_muzero_policy implementation. |
I see -- I'm not aware of a good way to incorporate any existing visit counts into the sequential halving algorithm, especially given that they were generated by the interior action selection algorithm -- perhaps devising a way to do this would be a good research problem but is probably out of scope for this PR. I will remove the option for subtree reuse from |
Thanks for the comment. |
As far as I can tell -- all subtrees of the provided test trees are reproduced accurately in the tests I wrote. I also tested the feature in the Connect 4 example notebook linked in the readme and had no issues. I'd be happy to write some more granular test cases if you'd like.
I'm not able to see your comments yet |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, my comments were pending.
tree = expand( | ||
params, expand_key, tree, recurrent_fn, parent_index, | ||
action, next_node_index) | ||
# if next_node_index goes out of bounds (i.e. no room left for new nodes) | ||
# backward its (in-bounds) parent |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we bound the next_node_index before calling expand()?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm exploiting that out-of-bounds updates are no-ops in JAX here, otherwise I'd have to change some of the logic in expand()
. If this is a bad pattern I can try something else.
I assumed that if the tree is full we do not want to overwrite any already-expanded node, but still backpropagate the value normally as if we did do an expansion. (this is why I put the out of bounds logic after expand() )
invalid_actions: a mask with invalid actions. Invalid actions | ||
have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`. | ||
max_depth: maximum search tree depth allowed during simulation. | ||
max_nodes: maximum number of nodes allowed in the search tree. If `None`, | ||
max_nodes == num_simulations + 1. This only applies when `tree` is `None` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to specify max_nodes? Cannot we always deduce the max_nodes from the num_simulations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The idea behind max_nodes is that when initializing a Tree we need to choose a capacity. When the tree is discarded after each call to search(), we can just initialize a tree with capacity = num_simulations
.
However, in the case where we want to re-use a tree, I thought it might be useful to decouple num_simulations from tree capacity, s.t. tree capacity >= num_simulations, allowing for extra room for the next search call's expanded nodes.
In the case where a selected subtree contains on average S nodes, one might want to set the tree capacity to some value >= S + num_simulations
, so that there is room for most node expansions. How high/low to set this capacity relative to num_simulations becomes a problem-dependent memory/(accuracy?) trade-off.
I do think that having max_nodes
alongside num_simulations
could be confusing, perhaps a cleaner way to organize this would be:
- change
tree
to be a mandatory argument tosearch()
muzero_policy
etc. each callinstantiate_tree_from_root
and pass the initialized tree tosearch()
- lets us remove
max_nodes
argument tosearch()
(and coincidentallyroot
,extra_data
, androot_invalid_actions
)
- create a new policy
alphazero_policy
:- [option 1]: accepts an optional
tree
andmax_nodes
, initializes a tree with capacitymax_nodes
if a tree is not passed as an argument - [option 2]: requires an initialized tree to be passed (would need another new public API function to initialize one)
- [option 1]: accepts an optional
- remove
max_nodes
andtree
as arguments tomuzero_policy
and only support tree re-use inalphazero_policy
This isolates passing a subtree to alphazero_policy
alone, which I like given that you'd never actually want to pass a subtree in MuZero. Maybe it could even return the subtree as part of the output??
I implemented these proposed changes (w/ option 1) in a branch on my mctx fork: https://github.com/lowrollr/mctx/tree/alphazero_policy
I also pushed a modified version of the Connect 4 example to that branch that uses alphazero_policy
and get_subtree
. If you feel this is a good approach I could write some unit tests for alphazero_policy
as well.
Thanks for the clarifications. Would it be OK to keep the functionality unmerged? |
You mention AlphaZero in the readme, so in my opinion supporting subtree re-use should be included functionality. I understand wanting to keep the codebase as lightweight and simple as possible. If you have specific concerns, constraints, or parts of the code you'd prefer be left unchanged I'd be happy to work around them to get this ok to merge. |
I want to ensure that mctx will work correctly |
That is understandable. Thank you for taking a look at my code, I admire this repo a lot. |
Thank you. |
Here's the link: https://github.com/lowrollr/mctx-az |
Requested by #51, this PR introduces the capability to pass a Tree to
muzero_policy
andgumbel_muzero_policy
, allowing for MCTS to continue from a pre-initialized tree.The main use-case is for users implementing AlphaZero, where environment dynamics are known, not modeled and therefore saving work from a previous MCTS call becomes useful.
I introduce a new public API function
get_subtree
, which extracts a subtree rooted at a given root child index, which can be utilized by AlphaZero-esque implementations to extract the subtree corresponding to a taken action.I also include a utility function
reset_search_tree
, which can be used to reset/zero out the search tree, useful in the case of a terminated episode where the search tree can be discarded.Including this feature within an AlphaZero implementation might look something like this (pseudo-code)
In the case where no trees have been initialized
mctx.muzero_policy(..., tree=None)
still works and will instantiate a new search tree (as before).I've also decoupled
num_simulations
from the capacity of the search tree, which is now specified as an argument tomuzero_policy
orgumbel_muzero_policy
calledmax_nodes
. Ifmax_nodes
is not specified, the tree capacity defaults tonum_simulations
(just as it worked before). This is useful in the case of AlphaZero, where the number of occupied nodes in the search tree may grow/shrink from call to call so it's useful to include extra capacity.I also included tests for get_subtree that run on each of the existing test pytrees. The tests run
get_subtree
on each of the root children and compare against the source tree. I'd be happy to only run on a subset of the child nodes if test runtime is too long (~60s total on my machine).Calls the public API work as they did before, I did not introduce any new mandatory arguments. Happy to re-organize & re-tool any of these changes if the maintainers have suggestions.