Skip to content

Commit

Permalink
Change deprecated jax.tree_util.tree_map to jax.tree.map. Fix argumen…
Browse files Browse the repository at this point in the history
…t passed to jax.numpy.finfo call.
  • Loading branch information
carlosgmartin committed Jul 22, 2024
1 parent 9fb7339 commit a8ea1be
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mctx/_src/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def _mask_invalid_actions(logits, invalid_actions):


def _get_logits_from_probs(probs):
tiny = jnp.finfo(probs).tiny
tiny = jnp.finfo(probs.dtype).tiny
return jnp.log(jnp.maximum(probs, tiny))


Expand Down
6 changes: 3 additions & 3 deletions mctx/_src/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def expand(
chex.assert_shape([parent_index, action, next_node_index], (batch_size,))

# Retrieve states for nodes to be evaluated.
embedding = jax.tree_util.tree_map(
embedding = jax.tree.map(
lambda x: x[batch_range, parent_index], tree.embeddings)

# Evaluate and create a new node.
Expand Down Expand Up @@ -335,7 +335,7 @@ def update_tree_node(
tree.node_values, value, node_index),
node_visits=batch_update(
tree.node_visits, new_visit, node_index),
embeddings=jax.tree_util.tree_map(
embeddings=jax.tree.map(
lambda t, s: batch_update(t, s, node_index),
tree.embeddings, embedding))

Expand Down Expand Up @@ -375,7 +375,7 @@ def _zeros(x):
children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32),
children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype),
children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype),
embeddings=jax.tree_util.tree_map(_zeros, root.embedding),
embeddings=jax.tree.map(_zeros, root.embedding),
root_invalid_actions=root_invalid_actions,
extra_data=extra_data)

Expand Down
2 changes: 1 addition & 1 deletion mctx/_src/tests/policies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_gumbel_muzero_policy(self):

# Testing max_depth.
leaf, max_found_depth = _get_deepest_leaf(
jax.tree_util.tree_map(lambda x: x[0], policy_output.search_tree),
jax.tree.map(lambda x: x[0], policy_output.search_tree),
policy_output.search_tree.ROOT_INDEX)
self.assertEqual(max_depth, max_found_depth)
self.assertEqual(6, policy_output.search_tree.node_visits[0, leaf])
Expand Down

0 comments on commit a8ea1be

Please sign in to comment.