Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568283111
  • Loading branch information
Jake VanderPlas authored and MctxDev committed Sep 26, 2023
1 parent c13a660 commit 49cbb4a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mctx/_src/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,15 @@ def body_fun(state):

node_index = jnp.array(Tree.ROOT_INDEX, dtype=jnp.int32)
depth = jnp.zeros((), dtype=tree.children_prior_logits.dtype)
# pytype: disable=wrong-arg-types # jnp-type
initial_state = _SimulationState(
rng_key=rng_key,
node_index=tree.NO_PARENT,
action=tree.NO_PARENT,
next_node_index=node_index,
depth=depth,
is_continuing=jnp.array(True))
# pytype: enable=wrong-arg-types
end_state = jax.lax.while_loop(cond_fun, body_fun, initial_state)

# Returning a node with a selected action.
Expand Down
2 changes: 2 additions & 0 deletions mctx/_src/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ def num_simulations(self):

def qvalues(self, indices):
"""Compute q-values for any node indices in the tree."""
# pytype: disable=wrong-arg-types # jnp-type
if jnp.asarray(indices).shape:
return jax.vmap(_unbatched_qvalues)(self, indices)
else:
return _unbatched_qvalues(self, indices)
# pytype: enable=wrong-arg-types

def summary(self) -> SearchSummary:
"""Extract summary statistics for the root node."""
Expand Down

0 comments on commit 49cbb4a

Please sign in to comment.