Commit

Update

Routhleck committed May 9, 2024
1 parent af69690 commit 69812a2
Showing 11 changed files with 41 additions and 46 deletions.
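The commit replaces the last uses of the removed `jax.experimental.host_callback` module (`id_tap` and `call`) with `jax.pure_callback`, moves from the deprecated `jax.tree_map` to `jax.tree.map`, and skips the customized-operator test modules. A minimal sketch of the `id_tap`-style replacement, assuming a recent JAX release; the step function below is illustrative and not from the repository:

```python
import jax
import jax.numpy as jnp

# Before (API removed from recent JAX):
#   from jax.experimental.host_callback import id_tap
#   id_tap(lambda x, transforms: print("step", x), i)

def _tick(i):
    # Host-side side effect; returns None so it matches the empty
    # result shape () declared below.
    print("step", int(i))

@jax.jit
def step(i, x):
    # jax.pure_callback(callback, result_shape_dtypes, *args)
    jax.pure_callback(_tick, (), i)
    return x + 1.0

step(0, jnp.zeros(3))
```

Because `pure_callback` assumes the callback is pure, JAX may elide or reorder it when its result is unused; `jax.debug.callback` is the side-effect-oriented alternative, but this commit standardizes on `pure_callback` throughout.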
4 changes: 1 addition & 3 deletions brainpy/_src/integrators/runner.py
@@ -9,7 +9,6 @@
import jax.numpy as jnp
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten

from brainpy import math as bm
@@ -245,8 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i):

# progress bar
if self.progress_bar:
jax.pure_callback(lambda *args: self._pbar.update(), ())
# id_tap(lambda *args: self._pbar.update(), ())
jax.pure_callback(lambda *args: self._pbar.update(), ())

# return of function monitors
shared = dict(t=t + self.dt, dt=self.dt, i=i)
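For context, a standalone sketch of the progress-bar pattern this hunk keeps: a host-side `tqdm` bar ticked from inside a traced loop via `jax.pure_callback`. The scan loop here is illustrative, not BrainPy's actual runner:

```python
import jax
import jax.numpy as jnp
from tqdm.auto import tqdm

n_steps = 100
pbar = tqdm(total=n_steps)

def _update_bar(*_):
    pbar.update()  # discard tqdm's return value so the callback yields None

def body(carry, _):
    # Empty result shape (): the callback is run only for its side effect.
    jax.pure_callback(_update_bar, ())
    return carry + 1.0, None

carry, _ = jax.lax.scan(body, jnp.float32(0.0), None, length=n_steps)
pbar.close()
```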
2 changes: 2 additions & 0 deletions brainpy/_src/math/event/tests/test_event_csrmv.py
@@ -11,6 +11,8 @@
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

pytest.skip('Remove customize op tests', allow_module_level=True)

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

2 changes: 2 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_event_matvec.py
@@ -8,6 +8,8 @@
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

pytest.skip('Remove customize op tests', allow_module_level=True)

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

2 changes: 2 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_matvec.py
@@ -8,6 +8,8 @@
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

pytest.skip('Remove customize op tests', allow_module_level=True)

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

11 changes: 4 additions & 7 deletions brainpy/_src/math/object_transform/controls.py
@@ -7,7 +7,6 @@
import jax
import jax.numpy as jnp
from jax.errors import UnexpectedTracerError
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_flatten, tree_unflatten
from tqdm.auto import tqdm

@@ -421,14 +420,14 @@ def call(pred, x=None):
def _warp(f):
@functools.wraps(f)
def new_f(*args, **kwargs):
return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
return jax.tree.map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))

return new_f


def _warp_data(data):
def new_f(*args, **kwargs):
return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
return jax.tree.map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))

return new_f

@@ -728,7 +727,6 @@ def fun2scan(carry, x):
results = body_fun(*x, **unroll_kwargs)
if progress_bar:
jax.pure_callback(lambda *arg: bar.update(), ())
# id_tap(lambda *arg: bar.update(), ())
return dyn_vars.dict_data(), results

if remat:
@@ -918,15 +916,14 @@ def fun2scan(carry, x):
carry, results = body_fun(carry, x)
if progress_bar:
jax.pure_callback(lambda *arg: bar.update(), ())
# id_tap(lambda *arg: bar.update(), ())
carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
return (dyn_vars.dict_data(), carry), results

if remat:
fun2scan = jax.checkpoint(fun2scan)

def call(init, operands):
init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
init = jax.tree.map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
return jax.lax.scan(f=fun2scan,
init=(dyn_vars.dict_data(), init),
xs=operands,
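The same file also switches from the deprecated top-level `jax.tree_map` to `jax.tree.map` (the `jax.tree` namespace is available in recent JAX releases; the signature, including `is_leaf`, is unchanged). A small illustrative sketch:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Deprecated spelling (warns on recent JAX):
#   jax.tree_map(lambda x: x * 2, params)
# Spelling used by this commit:
scaled = jax.tree.map(lambda x: x * 2, params)
```

BrainPy keeps the `is_leaf=lambda a: isinstance(a, Array)` predicate so its own `Array` wrapper is treated as a leaf rather than traversed.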
53 changes: 27 additions & 26 deletions brainpy/_src/math/random.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-

import warnings
from collections import namedtuple
from functools import partial
@@ -10,7 +11,6 @@
import numpy as np
from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
from jax._src.array import ArrayImpl
from jax.experimental.host_callback import call
from jax.tree_util import register_pytree_node_class

from brainpy.check import jit_error_checking, jit_error_checking_no_args
@@ -1233,9 +1233,9 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona
if size is None:
size = jnp.shape(a)
dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
r = call(lambda x: np.random.zipf(x, size).astype(dtype),
a,
result_shape=jax.ShapeDtypeStruct(size, dtype))
r = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype),
jax.ShapeDtypeStruct(size, dtype),
a)
return _return(r)

def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
Expand All @@ -1244,9 +1244,9 @@ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Option
size = jnp.shape(a)
size = _size2shape(size)
dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
r = call(lambda a: np.random.power(a=a, size=size).astype(dtype),
a,
result_shape=jax.ShapeDtypeStruct(size, dtype))
r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype),
jax.ShapeDtypeStruct(size, dtype),
a)
return _return(r)

def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
Expand All @@ -1260,11 +1260,11 @@ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None,
size = _size2shape(size)
d = {'dfnum': dfnum, 'dfden': dfden}
dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
r = call(lambda x: np.random.f(dfnum=x['dfnum'],
dfden=x['dfden'],
size=size).astype(dtype),
d,
result_shape=jax.ShapeDtypeStruct(size, dtype))
r = jax.pure_callback(lambda x: np.random.f(dfnum=x['dfnum'],
dfden=x['dfden'],
size=size).astype(dtype),
jax.ShapeDtypeStruct(size, dtype),
d)
return _return(r)

def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None,
Expand All @@ -1280,12 +1280,12 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc
size = _size2shape(size)
dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample}
r = call(lambda d: np.random.hypergeometric(ngood=d['ngood'],
nbad=d['nbad'],
nsample=d['nsample'],
size=size).astype(dtype),
d,
result_shape=jax.ShapeDtypeStruct(size, dtype))
r = jax.pure_callback(lambda d: np.random.hypergeometric(ngood=d['ngood'],
nbad=d['nbad'],
nsample=d['nsample'],
size=size).astype(dtype),
jax.ShapeDtypeStruct(size, dtype),
d)
return _return(r)

def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
Expand All @@ -1295,9 +1295,9 @@ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None,
size = jnp.shape(p)
size = _size2shape(size)
dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
r = call(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
p,
result_shape=jax.ShapeDtypeStruct(size, dtype))
r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype),
jax.ShapeDtypeStruct(size, dtype),
p)
return _return(r)

def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None,
Expand All @@ -1312,11 +1312,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[in
size = _size2shape(size)
d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc}
dtype = jax.dtypes.canonicalize_dtype(jnp.float_)
r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
dfden=x['dfden'],
nonc=x['nonc'],
size=size).astype(dtype),
d, result_shape=jax.ShapeDtypeStruct(size, dtype))
r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'],
dfden=x['dfden'],
nonc=x['nonc'],
size=size).astype(dtype),
jax.ShapeDtypeStruct(size, dtype),
d)
return _return(r)

# PyTorch compatibility #
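These hunks replace `host_callback.call`, which took a `result_shape=` keyword, with `jax.pure_callback`, which takes the result specification as its second positional argument followed by the operands. A sketch of the pattern for a NumPy-only sampler; the wrapper name and distribution choice are illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np

def zipf_sample(a, size):
    dtype = jax.dtypes.canonicalize_dtype(jnp.int_)
    # pure_callback(callback, result_shape_dtypes, *args): the host-side
    # NumPy sampler runs outside the trace and must return an array
    # matching the declared shape and dtype.
    return jax.pure_callback(
        lambda x: np.random.zipf(x, size).astype(dtype),
        jax.ShapeDtypeStruct(size, dtype),
        a,
    )

samples = jax.jit(zipf_sample, static_argnums=1)(2.0, (3,))
```

As in the original code, the NumPy generator is not keyed to JAX's PRNG and is not actually pure, so repeated or cached calls may not behave like native `jax.random` samplers.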
3 changes: 3 additions & 0 deletions brainpy/_src/math/sparse/tests/test_csrmv.py
@@ -9,6 +9,9 @@
import brainpy as bp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

pytest.skip('Remove customize op tests', allow_module_level=True)

if import_taichi(error_if_not_found=False) is None:
pytest.skip('no taichi', allow_module_level=True)

3 changes: 0 additions & 3 deletions brainpy/_src/runners.py
@@ -11,7 +11,6 @@
import jax.numpy as jnp
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_map, tree_flatten

from brainpy import math as bm, tools
@@ -633,7 +632,6 @@ def _step_func_predict(self, i, *x, shared_args=None):
# finally
if self.progress_bar:
jax.pure_callback(lambda: self._pbar.update(), ())
# id_tap(lambda *arg: self._pbar.update(), ())
# share.clear_shargs()
clear_input(self.target)

@@ -644,7 +642,6 @@ def _step_func_predict(self, i, *x, shared_args=None):
mon_shape_dtype,
mon,
)
# id_tap(self._step_mon_on_cpu, mon)
return out, None
else:
return out, mon
2 changes: 0 additions & 2 deletions brainpy/_src/train/offline.py
@@ -5,7 +5,6 @@
import jax
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap

import brainpy.math as bm
from brainpy import tools
@@ -221,7 +220,6 @@ def _fun_train(self,
node.offline_fit(targets, fit_record)
if self.progress_bar:
jax.pure_callback(lambda *args: self._pbar.update(), ())
# id_tap(lambda *args: self._pbar.update(), ())

def _step_func_monitor(self):
res = dict()
2 changes: 0 additions & 2 deletions brainpy/_src/train/online.py
@@ -5,7 +5,6 @@
import jax
import numpy as np
import tqdm.auto
from jax.experimental.host_callback import id_tap
from jax.tree_util import tree_map

from brainpy import math as bm, tools
@@ -253,7 +252,6 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):
# finally
if self.progress_bar:
jax.pure_callback(lambda *arg: self._pbar.update(), ())
# id_tap(lambda *arg: self._pbar.update(), ())
return out, monitors

def _check_interface(self):
3 changes: 0 additions & 3 deletions brainpy/check.py
@@ -7,7 +7,6 @@
import numpy as np
import numpy as onp
from jax import numpy as jnp
from jax.experimental.host_callback import id_tap
from jax.lax import cond

conn = None
@@ -575,7 +574,6 @@ def _err_jit_true_branch(err_fun, x):
else:
x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
jax.pure_callback(err_fun, x_shape_dtype, x)
# id_tap(err_fun, x)
return


@@ -634,6 +632,5 @@ def true_err_fun(arg, transforms):
raise err

cond(remove_vmap(as_jax(pred)),
# lambda: id_tap(true_err_fun, None),
lambda: jax.pure_callback(true_err_fun, None),
lambda: None)
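A self-contained sketch of the `cond` + `pure_callback` error-raising pattern that this hunk preserves (simplified relative to BrainPy's `jit_error_checking` helpers; the checking function below is illustrative):

```python
import jax
import jax.numpy as jnp

def check_positive(x):
    def _raise(*_):
        # Runs on the host only when the traced predicate is True.
        raise ValueError('expected all elements to be positive')

    jax.lax.cond(jnp.any(x <= 0),
                 lambda: jax.pure_callback(_raise, None),
                 lambda: None)
    return x

jax.jit(check_positive)(jnp.array([1.0, 2.0]))  # passes silently
# jax.jit(check_positive)(jnp.array([-1.0]))    # raises at run time (possibly wrapped by the runtime)
```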
