From d26a3d533eb10b4b4633ab2a9d6ee9c0e01c20b8 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 10 May 2024 11:51:11 +0800 Subject: [PATCH] Update random.py --- brainpy/_src/math/random.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 355d2c98..cf37810e 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -1233,7 +1233,7 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona if size is None: size = jnp.shape(a) dtype = jax.dtypes.canonicalize_dtype(jnp.int_) - dtype = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype), + r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype), jax.ShapeDtypeStruct(size, dtype), a) return _return(r) @@ -1280,7 +1280,7 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc size = _size2shape(size) dtype = jax.dtypes.canonicalize_dtype(jnp.int_) d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} - r = jax.pure_callback()(lambda d: np.random.hypergeometric(ngood=d['ngood'], + r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'], nbad=d['nbad'], nsample=d['nsample'], size=size).astype(dtype),