Commit 59fb681: Replace
Routhleck committed May 12, 2024 (1 parent: d795517)

Showing 7 changed files with 50 additions and 45 deletions.
brainpy/_src/integrators/runner.py (3 changes: 1 addition & 2 deletions)
@@ -9,7 +9,6 @@
 import jax.numpy as jnp
 import numpy as np
 import tqdm.auto
-from jax.experimental.host_callback import id_tap
 from jax.tree_util import tree_flatten

 from brainpy import math as bm
@@ -245,7 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):

     # progress bar
     if self.progress_bar:
-      id_tap(lambda *args: self._pbar.update(), ())
+      jax.pure_callback(lambda *args: self._pbar.update(), ())

     # return of function monitors
     shared = dict(t=t + self.dt, dt=self.dt, i=i)
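The hunks above drop the `jax.experimental.host_callback` import and replace the `id_tap` progress-bar hook with `jax.pure_callback`; the same pattern recurs in the training and runner files below. The sketch that follows is a minimal, self-contained illustration of the idea and is not BrainPy code; it uses `jax.debug.callback`, JAX's side-effect-oriented callback, because a `pure_callback` whose result is never consumed can be dropped by dead-code elimination.

import jax
import jax.numpy as jnp
from tqdm.auto import tqdm

# Illustrative sketch, not BrainPy code: update a host-side tqdm bar once per
# step of a jitted scan. jax.debug.callback stages a host call with no result,
# so no result shape or dtype needs to be declared.
n_steps = 100
pbar = tqdm(total=n_steps)

def _tick():
  pbar.update()  # runs on the host; deliberately returns None

def step(carry, x):
  carry = carry + x
  jax.debug.callback(_tick)  # fired each time the scan body executes
  return carry, carry

final, _ = jax.lax.scan(step, jnp.zeros(()), jnp.ones(n_steps))
pbar.close()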
brainpy/_src/math/object_transform/controls.py (13 changes: 6 additions & 7 deletions)
@@ -7,7 +7,6 @@
 import jax
 import jax.numpy as jnp
 from jax.errors import UnexpectedTracerError
-from jax.experimental.host_callback import id_tap
 from jax.tree_util import tree_flatten, tree_unflatten
 from tqdm.auto import tqdm

@@ -421,14 +420,14 @@ def call(pred, x=None):
 def _warp(f):
   @functools.wraps(f)
   def new_f(*args, **kwargs):
-    return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
+    return jax.tree.map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))

   return new_f


 def _warp_data(data):
   def new_f(*args, **kwargs):
-    return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
+    return jax.tree.map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))

   return new_f

@@ -727,7 +726,7 @@ def fun2scan(carry, x):
       dyn_vars[k]._value = carry[k]
     results = body_fun(*x, **unroll_kwargs)
     if progress_bar:
-      id_tap(lambda *arg: bar.update(), ())
+      jax.pure_callback(lambda *arg: bar.update(), ())
     return dyn_vars.dict_data(), results

   if remat:
@@ -916,15 +915,15 @@ def fun2scan(carry, x):
       dyn_vars[k]._value = dyn_vars_data[k]
     carry, results = body_fun(carry, x)
     if progress_bar:
-      id_tap(lambda *arg: bar.update(), ())
-    carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
+      jax.pure_callback(lambda *arg: bar.update(), ())
+    carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
     return (dyn_vars.dict_data(), carry), results

   if remat:
     fun2scan = jax.checkpoint(fun2scan)

   def call(init, operands):
-    init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
+    init = jax.tree.map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
     return jax.lax.scan(f=fun2scan,
                         init=(dyn_vars.dict_data(), init),
                         xs=operands,
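Besides the callback migration, controls.py swaps the deprecated `jax.tree_map` alias for `jax.tree.map`. The call is otherwise identical, including the `is_leaf` predicate used by the `_warp` helpers; a small standalone sketch:

import jax
import jax.numpy as jnp

# jax.tree.map is the current spelling of the old jax.tree_map alias.
tree = {'a': jnp.arange(3), 'b': (jnp.zeros(2), 1.0)}
print(jax.tree.map(lambda x: x * 2, tree))

# is_leaf stops recursion at chosen nodes, here treating tuples as leaves.
pairs = [(1, 2), (3, 4)]
print(jax.tree.map(lambda p: p[0] + p[1], pairs, is_leaf=lambda x: isinstance(x, tuple)))  # [3, 7]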
brainpy/_src/math/random.py (52 changes: 26 additions & 26 deletions)
@@ -10,7 +10,6 @@
 import numpy as np
 from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
 from jax._src.array import ArrayImpl
-from jax.experimental.host_callback import call
 from jax.tree_util import register_pytree_node_class

 from brainpy.check import jit_error_checking, jit_error_checking_no_args
@@ -1233,9 +1232,9 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona
     if size is None:
       size = jnp.shape(a)
     dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
-    r = call(lambda x: np.random.zipf(x, size).astype(dtype),
-             a,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          a)
     return _return(r)

   def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
@@ -1244,9 +1243,9 @@ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Option
       size = jnp.shape(a)
     size = _size2shape(size)
     dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
-    r = call(lambda a: np.random.power(a=a, size=size).astype(dtype),
-             a,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          a)
     return _return(r)

   def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1260,11 +1259,11 @@ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
     size = _size2shape(size)
     d = {'dfnum': dfnum, 'dfden': dfden}
     dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
-    r = call(lambda x: np.random.f(dfnum=x['dfnum'],
-                                   dfden=x['dfden'],
-                                   size=size).astype(dtype),
-             d,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda x: np.random.f(dfnum=x['dfnum'],
+                                                dfden=x['dfden'],
+                                                size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          d)
     return _return(r)

   def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1280,12 +1279,12 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc
     size = _size2shape(size)
     dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
     d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
-    r = call(lambda d: np.random.hypergeometric(ngood=d['ngood'],
-                                                nbad=d['nbad'],
-                                                nsample=d['nsample'],
-                                                size=size).astype(dtype),
-             d,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
+                                                              nbad=d['nbad'],
+                                                              nsample=d['nsample'],
+                                                              size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          d)
     return _return(r)

   def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1295,9 +1294,9 @@ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
       size = jnp.shape(p)
     size = _size2shape(size)
     dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
-    r = call(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
-             p,
-             result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          p)
     return _return(r)

   def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None,
@@ -1312,11 +1311,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[in
     size = _size2shape(size)
     d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
     dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
-    r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
-                                              dfden=x['dfden'],
-                                              nonc=x['nonc'],
-                                              size=size).astype(dtype),
-             d, result_shape=jax.ShapeDtypeStruct(size, dtype))
+    r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
+                                                            dfden=x['dfden'],
+                                                            nonc=x['nonc'],
+                                                            size=size).astype(dtype),
+                          jax.ShapeDtypeStruct(size, dtype),
+                          d)
     return _return(r)

   # PyTorch compatibility #
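All of the random.py changes follow one pattern: `host_callback.call(fn, operand, result_shape=...)` becomes `jax.pure_callback(fn, result_shape_dtypes, operand)`, i.e. the declared result shape/dtype moves to the second positional argument and the operands come last. Below is a hedged sketch of the zipf case; `host_zipf` is a hypothetical helper for illustration, not part of BrainPy.

import jax
import jax.numpy as jnp
import numpy as np

def host_zipf(a, size=(5,)):
  # The NumPy sampler runs on the host; its output shape and dtype must be
  # declared up front via jax.ShapeDtypeStruct (second positional argument).
  dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
  return jax.pure_callback(
    lambda x: np.random.zipf(x, size).astype(dtype),  # host-side callback
    jax.ShapeDtypeStruct(size, dtype),                # declared result spec
    a,                                                # operand passed to the callback
  )

print(jax.jit(host_zipf)(jnp.asarray(2.0)))  # five Zipf(2.0) samples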
brainpy/_src/runners.py (10 changes: 7 additions & 3 deletions)
@@ -11,7 +11,6 @@
 import jax.numpy as jnp
 import numpy as np
 import tqdm.auto
-from jax.experimental.host_callback import id_tap
 from jax.tree_util import tree_map, tree_flatten

 from brainpy import math as bm, tools
@@ -632,12 +631,17 @@ def _step_func_predict(self, i, *x, shared_args=None):

     # finally
     if self.progress_bar:
-      id_tap(lambda *arg: self._pbar.update(), ())
+      jax.pure_callback(lambda: self._pbar.update(), ())
     # share.clear_shargs()
     clear_input(self.target)

     if self._memory_efficient:
-      id_tap(self._step_mon_on_cpu, mon)
+      mon_shape_dtype = jax.ShapeDtypeStruct(mon.shape, mon.dtype)
+      result = jax.pure_callback(
+        self._step_mon_on_cpu,
+        mon_shape_dtype,
+        mon,
+      )
       return out, None
     else:
       return out, mon
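The `_memory_efficient` branch above builds a `ShapeDtypeStruct` for the monitor data before handing it to `_step_mon_on_cpu` on the host. When the monitors form a pytree rather than a single array, the declared result structure can mirror the pytree leaf by leaf; a hedged sketch follows (not BrainPy code, `offload` and `on_host` are hypothetical names).

import jax
import jax.numpy as jnp
import numpy as np

def on_host(mon):
  # Stand-in for host-side processing (e.g. staging monitors to CPU memory);
  # here the data is simply passed back unchanged.
  return {k: np.asarray(v) for k, v in mon.items()}

def offload(mon):
  # One ShapeDtypeStruct per leaf, built with jax.tree.map so the declared
  # result structure matches the monitor pytree exactly.
  specs = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), mon)
  return jax.pure_callback(on_host, specs, mon)

mon = {'V': jnp.zeros((4, 3)), 'spike': jnp.zeros((4, 3), dtype=bool)}
out = jax.jit(offload)(mon)
print(jax.tree.map(lambda x: x.shape, out))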
brainpy/_src/train/offline.py (4 changes: 2 additions & 2 deletions)
@@ -2,9 +2,9 @@

 from typing import Dict, Sequence, Union, Callable, Any

+import jax
 import numpy as np
 import tqdm.auto
-from jax.experimental.host_callback import id_tap

 import brainpy.math as bm
 from brainpy import tools
@@ -219,7 +219,7 @@ def _fun_train(self,
       targets = target_data[node.name]
       node.offline_fit(targets, fit_record)
       if self.progress_bar:
-        id_tap(lambda *args: self._pbar.update(), ())
+        jax.pure_callback(lambda *args: self._pbar.update(), ())

   def _step_func_monitor(self):
     res = dict()
brainpy/_src/train/online.py (4 changes: 2 additions & 2 deletions)
@@ -2,9 +2,9 @@
 import functools
 from typing import Dict, Sequence, Union, Callable

+import jax
 import numpy as np
 import tqdm.auto
-from jax.experimental.host_callback import id_tap
 from jax.tree_util import tree_map

 from brainpy import math as bm, tools
@@ -252,7 +252,7 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):

     # finally
     if self.progress_bar:
-      id_tap(lambda *arg: self._pbar.update(), ())
+      jax.pure_callback(lambda *arg: self._pbar.update(), ())
     return out, monitors

   def _check_interface(self):
brainpy/check.py (9 changes: 6 additions & 3 deletions)
@@ -7,7 +7,6 @@
 import numpy as np
 import numpy as onp
 from jax import numpy as jnp
-from jax.experimental.host_callback import id_tap
 from jax.lax import cond

 conn = None
@@ -570,7 +569,11 @@ def is_all_objs(targets: Any, out_as: str = 'tuple'):


 def _err_jit_true_branch(err_fun, x):
-  id_tap(err_fun, x)
+  if isinstance(x, (tuple, list)):
+    x_shape_dtype = tuple(jax.ShapeDtypeStruct(arr.shape, arr.dtype) for arr in x)
+  else:
+    x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
+  jax.pure_callback(err_fun, x_shape_dtype, x)
   return

@@ -629,6 +632,6 @@ def true_err_fun(arg, transforms):
     raise err

   cond(remove_vmap(as_jax(pred)),
-       lambda: id_tap(true_err_fun, None),
+       lambda: jax.pure_callback(true_err_fun, None),
        lambda: None)
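The check.py change keeps the existing jit-compatible error-check shape: a `lax.cond` whose True branch fires a host callback that raises, while the False branch is a no-op. A self-contained sketch of that pattern follows; it is not the BrainPy code, and it uses `jax.debug.callback`, the side-effect-oriented callback, in place of `pure_callback`.

import jax
import jax.numpy as jnp

def _raise_negative(x):
  raise ValueError(f'expected a non-negative value, got {x}')

def checked_sqrt(x):
  jax.lax.cond(
    x < 0,
    lambda: jax.debug.callback(_raise_negative, x),  # host-side raise on the True branch
    lambda: None,                                    # no-op on the False branch
  )
  return jnp.sqrt(x)

print(jax.jit(checked_sqrt)(jnp.asarray(4.0)))  # 2.0
# jax.jit(checked_sqrt)(jnp.asarray(-1.0)) surfaces the ValueError from the runtime.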
