Skip to content

Commit

Permalink
Update random.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 10, 2024
1 parent d30a552 commit d26a3d5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona
if size is None:
size = jnp.shape(a)
dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
dtype = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
jax.ShapeDtypeStruct(size, dtype),
a)
return _return(r)
Expand Down Expand Up @@ -1280,7 +1280,7 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc
size = _size2shape(size)
dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
r = jax.pure_callback()(lambda d: np.random.hypergeometric(ngood=d['ngood'],
r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
nbad=d['nbad'],
nsample=d['nsample'],
size=size).astype(dtype),
Expand Down

0 comments on commit d26a3d5

Please sign in to comment.