diff --git a/mctx/_src/search.py b/mctx/_src/search.py index 42707be..b40a924 100644 --- a/mctx/_src/search.py +++ b/mctx/_src/search.py @@ -171,13 +171,15 @@ def body_fun(state): node_index = jnp.array(Tree.ROOT_INDEX, dtype=jnp.int32) depth = jnp.zeros((), dtype=tree.children_prior_logits.dtype) - initial_state = _SimulationState( + # pytype: disable=wrong-arg-types # jnp-type + initial_state = _SimulationState( rng_key=rng_key, node_index=tree.NO_PARENT, action=tree.NO_PARENT, next_node_index=node_index, depth=depth, is_continuing=jnp.array(True)) + # pytype: enable=wrong-arg-types end_state = jax.lax.while_loop(cond_fun, body_fun, initial_state) # Returning a node with a selected action. diff --git a/mctx/_src/tree.py b/mctx/_src/tree.py index 0cfc9fb..e9a5354 100644 --- a/mctx/_src/tree.py +++ b/mctx/_src/tree.py @@ -87,10 +87,12 @@ def num_simulations(self): def qvalues(self, indices): """Compute q-values for any node indices in the tree.""" + # pytype: disable=wrong-arg-types # jnp-type if jnp.asarray(indices).shape: return jax.vmap(_unbatched_qvalues)(self, indices) else: return _unbatched_qvalues(self, indices) + # pytype: enable=wrong-arg-types def summary(self) -> SearchSummary: """Extract summary statistics for the root node."""