diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 355d2c98..cf37810e 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -1233,7 +1233,7 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona if size is None: size = jnp.shape(a) dtype = jax.dtypes.canonicalize_dtype(jnp.int_) - dtype = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype), + r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype), jax.ShapeDtypeStruct(size, dtype), a) return _return(r) @@ -1280,7 +1280,7 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc size = _size2shape(size) dtype = jax.dtypes.canonicalize_dtype(jnp.int_) d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} - r = jax.pure_callback()(lambda d: np.random.hypergeometric(ngood=d['ngood'], + r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'], nbad=d['nbad'], nsample=d['nsample'], size=size).astype(dtype),