Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568283111
  • Loading branch information
Jake VanderPlas authored and MctxDev committed Sep 26, 2023
1 parent c13a660 commit d36c68e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mctx/_src/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def body_fun(state):

node_index = jnp.array(Tree.ROOT_INDEX, dtype=jnp.int32)
depth = jnp.zeros((), dtype=tree.children_prior_logits.dtype)
initial_state = _SimulationState(
initial_state = _SimulationState( # pytype: disable=wrong-arg-types # jnp-type
rng_key=rng_key,
node_index=tree.NO_PARENT,
action=tree.NO_PARENT,
Expand Down
4 changes: 2 additions & 2 deletions mctx/_src/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ def num_simulations(self):
def qvalues(self, indices):
"""Compute q-values for any node indices in the tree."""
if jnp.asarray(indices).shape:
return jax.vmap(_unbatched_qvalues)(self, indices)
return jax.vmap(_unbatched_qvalues)(self, indices) # pytype: disable=wrong-arg-types # jnp-type
else:
return _unbatched_qvalues(self, indices)
return _unbatched_qvalues(self, indices) # pytype: disable=wrong-arg-types # jnp-type

def summary(self) -> SearchSummary:
"""Extract summary statistics for the root node."""
Expand Down

0 comments on commit d36c68e

Please sign in to comment.