From 69812a21a3921fdd7b201ff6a21b323fd437ddd5 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 9 May 2024 16:57:20 +0800 Subject: [PATCH] Update --- brainpy/_src/integrators/runner.py | 4 +- .../_src/math/event/tests/test_event_csrmv.py | 2 + .../math/jitconn/tests/test_event_matvec.py | 2 + .../_src/math/jitconn/tests/test_matvec.py | 2 + .../_src/math/object_transform/controls.py | 11 ++-- brainpy/_src/math/random.py | 53 ++++++++++--------- brainpy/_src/math/sparse/tests/test_csrmv.py | 3 ++ brainpy/_src/runners.py | 3 -- brainpy/_src/train/offline.py | 2 - brainpy/_src/train/online.py | 2 - brainpy/check.py | 3 -- 11 files changed, 41 insertions(+), 46 deletions(-) diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py index e4f9e79c..631b01d7 100644 --- a/brainpy/_src/integrators/runner.py +++ b/brainpy/_src/integrators/runner.py @@ -9,7 +9,6 @@ import jax.numpy as jnp import numpy as np import tqdm.auto -from jax.experimental.host_callback import id_tap from jax.tree_util import tree_flatten from brainpy import math as bm @@ -245,8 +244,7 @@ def _step_fun_integrator(self, static_args, dyn_args, t, i): # progress bar if self.progress_bar: - jax.pure_callback(lambda *args: self._pbar.update(), ()) - # id_tap(lambda *args: self._pbar.update(), ()) + jax.pure_callback(lambda *args: self._pbar.update(), ())\ # return of function monitors shared = dict(t=t + self.dt, dt=self.dt, i=i) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 181ee552..0190628f 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -11,6 +11,8 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi +pytest.skip('Remove customize op tests', allow_module_level=True) + if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index dd1bafde..0be7a550 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -8,6 +8,8 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi +pytest.skip('Remove customize op tests', allow_module_level=True) + if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index e42bd369..e85045a8 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -8,6 +8,8 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi +pytest.skip('Remove customize op tests', allow_module_level=True) + if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index bf974968..126ca15c 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -7,7 +7,6 @@ import jax import jax.numpy as jnp from jax.errors import UnexpectedTracerError -from jax.experimental.host_callback import id_tap from jax.tree_util import tree_flatten, tree_unflatten from tqdm.auto import tqdm @@ -421,14 +420,14 @@ def call(pred, x=None): def _warp(f): @functools.wraps(f) def new_f(*args, **kwargs): - return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array)) + return jax.tree.map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array)) return new_f def _warp_data(data): def new_f(*args, **kwargs): - return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array)) + return jax.tree.map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array)) return new_f @@ -728,7 +727,6 @@ def fun2scan(carry, x): results = body_fun(*x, **unroll_kwargs) if progress_bar: jax.pure_callback(lambda *arg: bar.update(), ()) - # id_tap(lambda *arg: bar.update(), ()) return dyn_vars.dict_data(), results if remat: @@ -918,15 +916,14 @@ def fun2scan(carry, x): carry, results = body_fun(carry, x) if progress_bar: jax.pure_callback(lambda *arg: bar.update(), ()) - # id_tap(lambda *arg: bar.update(), ()) - carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array)) + carry = jax.tree.map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array)) return (dyn_vars.dict_data(), carry), results if remat: fun2scan = jax.checkpoint(fun2scan) def call(init, operands): - init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array)) + init = jax.tree.map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array)) return jax.lax.scan(f=fun2scan, init=(dyn_vars.dict_data(), init), xs=operands, diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 9ae012bc..355d2c98 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- +1 import warnings from collections import namedtuple from functools import partial @@ -10,7 +11,6 @@ import numpy as np from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes from jax._src.array import ArrayImpl -from jax.experimental.host_callback import call from jax.tree_util import register_pytree_node_class from brainpy.check import jit_error_checking, jit_error_checking_no_args @@ -1233,9 +1233,9 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona if size is None: size = jnp.shape(a) dtype = jax.dtypes.canonicalize_dtype(jnp.int_) - r = call(lambda x: np.random.zipf(x, size).astype(dtype), - a, - result_shape=jax.ShapeDtypeStruct(size, dtype)) + dtype = jax.pure_callback(lambda x: np.random.zipf(x, size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + a) return _return(r) def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): @@ -1244,9 +1244,9 @@ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Option size = jnp.shape(a) size = _size2shape(size) dtype = jax.dtypes.canonicalize_dtype(jnp.float_) - r = call(lambda a: np.random.power(a=a, size=size).astype(dtype), - a, - result_shape=jax.ShapeDtypeStruct(size, dtype)) + r = jax.pure_callback(lambda a: np.random.power(a=a, size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + a) return _return(r) def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, @@ -1260,11 +1260,11 @@ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden} dtype = jax.dtypes.canonicalize_dtype(jnp.float_) - r = call(lambda x: np.random.f(dfnum=x['dfnum'], - dfden=x['dfden'], - size=size).astype(dtype), - d, - result_shape=jax.ShapeDtypeStruct(size, dtype)) + r = jax.pure_callback(lambda x: np.random.f(dfnum=x['dfnum'], + dfden=x['dfden'], + size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + d) return _return(r) def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None, @@ -1280,12 +1280,12 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc size = _size2shape(size) dtype = jax.dtypes.canonicalize_dtype(jnp.int_) d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} - r = call(lambda d: np.random.hypergeometric(ngood=d['ngood'], - nbad=d['nbad'], - nsample=d['nsample'], - size=size).astype(dtype), - d, - result_shape=jax.ShapeDtypeStruct(size, dtype)) + r = jax.pure_callback()(lambda d: np.random.hypergeometric(ngood=d['ngood'], + nbad=d['nbad'], + nsample=d['nsample'], + size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + d) return _return(r) def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, @@ -1295,9 +1295,9 @@ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, size = jnp.shape(p) size = _size2shape(size) dtype = jax.dtypes.canonicalize_dtype(jnp.int_) - r = call(lambda p: np.random.logseries(p=p, size=size).astype(dtype), - p, - result_shape=jax.ShapeDtypeStruct(size, dtype)) + r = jax.pure_callback(lambda p: np.random.logseries(p=p, size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + p) return _return(r) def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None, @@ -1312,11 +1312,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[in size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc} dtype = jax.dtypes.canonicalize_dtype(jnp.float_) - r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], - dfden=x['dfden'], - nonc=x['nonc'], - size=size).astype(dtype), - d, result_shape=jax.ShapeDtypeStruct(size, dtype)) + r = jax.pure_callback(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], + dfden=x['dfden'], + nonc=x['nonc'], + size=size).astype(dtype), + jax.ShapeDtypeStruct(size, dtype), + d) return _return(r) # PyTorch compatibility # diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index acedcff1..1ef98be3 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -9,6 +9,9 @@ import brainpy as bp import brainpy.math as bm from brainpy._src.dependency_check import import_taichi + +pytest.skip('Remove customize op tests', allow_module_level=True) + if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index 6340781d..80609608 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -11,7 +11,6 @@ import jax.numpy as jnp import numpy as np import tqdm.auto -from jax.experimental.host_callback import id_tap from jax.tree_util import tree_map, tree_flatten from brainpy import math as bm, tools @@ -633,7 +632,6 @@ def _step_func_predict(self, i, *x, shared_args=None): # finally if self.progress_bar: jax.pure_callback(lambda: self._pbar.update(), ()) - # id_tap(lambda *arg: self._pbar.update(), ()) # share.clear_shargs() clear_input(self.target) @@ -644,7 +642,6 @@ def _step_func_predict(self, i, *x, shared_args=None): mon_shape_dtype, mon, ) - # id_tap(self._step_mon_on_cpu, mon) return out, None else: return out, mon diff --git a/brainpy/_src/train/offline.py b/brainpy/_src/train/offline.py index fe778eb4..e801a29e 100644 --- a/brainpy/_src/train/offline.py +++ b/brainpy/_src/train/offline.py @@ -5,7 +5,6 @@ import jax import numpy as np import tqdm.auto -from jax.experimental.host_callback import id_tap import brainpy.math as bm from brainpy import tools @@ -221,7 +220,6 @@ def _fun_train(self, node.offline_fit(targets, fit_record) if self.progress_bar: jax.pure_callback(lambda *args: self._pbar.update(), ()) - # id_tap(lambda *args: self._pbar.update(), ()) def _step_func_monitor(self): res = dict() diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index c7398a79..932de501 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -5,7 +5,6 @@ import jax import numpy as np import tqdm.auto -from jax.experimental.host_callback import id_tap from jax.tree_util import tree_map from brainpy import math as bm, tools @@ -253,7 +252,6 @@ def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None): # finally if self.progress_bar: jax.pure_callback(lambda *arg: self._pbar.update(), ()) - # id_tap(lambda *arg: self._pbar.update(), ()) return out, monitors def _check_interface(self): diff --git a/brainpy/check.py b/brainpy/check.py index 5bf3cda3..26b2afaa 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -7,7 +7,6 @@ import numpy as np import numpy as onp from jax import numpy as jnp -from jax.experimental.host_callback import id_tap from jax.lax import cond conn = None @@ -575,7 +574,6 @@ def _err_jit_true_branch(err_fun, x): else: x_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype) jax.pure_callback(err_fun, x_shape_dtype, x) - # id_tap(err_fun, x) return @@ -634,6 +632,5 @@ def true_err_fun(arg, transforms): raise err cond(remove_vmap(as_jax(pred)), - # lambda: id_tap(true_err_fun, None), lambda: jax.pure_callback(true_err_fun, None), lambda: None)