From f6beddc89db6444ed24d7048971cc57d31cbb4ad Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 6 Sep 2023 17:15:58 +0800 Subject: [PATCH 01/10] update dependencies --- brainpy/_src/checkpoints/io.py | 4 +- brainpy/_src/checkpoints/tests/test_io.py | 48 +++++++++---------- .../_src/running/pathos_multiprocessing.py | 4 +- requirements-dev.txt | 2 - 4 files changed, 28 insertions(+), 30 deletions(-) diff --git a/brainpy/_src/checkpoints/io.py b/brainpy/_src/checkpoints/io.py index bf254bf0e..4e712c5ca 100644 --- a/brainpy/_src/checkpoints/io.py +++ b/brainpy/_src/checkpoints/io.py @@ -151,7 +151,7 @@ def save_as_h5(filename: str, variables: dict): raise ValueError(f'Cannot save variables as a HDF5 file. We only support file with ' f'postfix of ".hdf5" and ".h5". But we got {filename}') - import h5py + import h5py # noqa # check variables check_dict_data(variables, name='variables') @@ -184,7 +184,7 @@ def load_by_h5(filename: str, target, verbose: bool = False): f'postfix of ".hdf5" and ".h5". But we got {filename}') # read data - import h5py + import h5py # noqa load_vars = dict() with h5py.File(filename, "r") as f: for key in f.keys(): diff --git a/brainpy/_src/checkpoints/tests/test_io.py b/brainpy/_src/checkpoints/tests/test_io.py index f8ed80210..36c8f374b 100644 --- a/brainpy/_src/checkpoints/tests/test_io.py +++ b/brainpy/_src/checkpoints/tests/test_io.py @@ -40,18 +40,18 @@ def __init__(self): print(self.net.vars().keys()) print(self.net.vars().unique().keys()) - def test_h5(self): - bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars()) - bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True) - - bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars()) - bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) - - def test_h5_postfix(self): - with self.assertRaises(ValueError): - bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars()) - with self.assertRaises(ValueError): - bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True) + # def test_h5(self): + # bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars()) + # bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True) + # + # bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars()) + # bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) + # + # def test_h5_postfix(self): + # with self.assertRaises(ValueError): + # bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars()) + # with self.assertRaises(ValueError): + # bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True) def test_npz(self): bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars()) @@ -120,18 +120,18 @@ def __init__(self): print(self.net.vars().keys()) print(self.net.vars().unique().keys()) - def test_h5(self): - bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars()) - bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True) - - bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars()) - bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) - - def test_h5_postfix(self): - with self.assertRaises(ValueError): - bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars()) - with self.assertRaises(ValueError): - bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True) + # def test_h5(self): + # bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars()) + # bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True) + # + # 
bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars()) + # bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) + # + # def test_h5_postfix(self): + # with self.assertRaises(ValueError): + # bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars()) + # with self.assertRaises(ValueError): + # bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True) def test_npz(self): bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars()) diff --git a/brainpy/_src/running/pathos_multiprocessing.py b/brainpy/_src/running/pathos_multiprocessing.py index b58b1691e..1573a541c 100644 --- a/brainpy/_src/running/pathos_multiprocessing.py +++ b/brainpy/_src/running/pathos_multiprocessing.py @@ -18,8 +18,8 @@ from brainpy.errors import PackageMissingError try: - from pathos.helpers import cpu_count - from pathos.multiprocessing import ProcessPool + from pathos.helpers import cpu_count # noqa + from pathos.multiprocessing import ProcessPool # noqa except ModuleNotFoundError: cpu_count = None ProcessPool = None diff --git a/requirements-dev.txt b/requirements-dev.txt index d8e87ac5f..6e6392b31 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,8 +6,6 @@ jax>=0.4.1 jaxlib>=0.4.1 scipy>=1.1.0 brainpylib -h5py -pathos # test requirements pytest From b4aeecc39d15e8d8180c7149c9f026c1a450a030 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 6 Sep 2023 17:21:15 +0800 Subject: [PATCH 02/10] updates --- brainpy/_src/dyn/neurons/hh.py | 4 +- brainpy/_src/dyn/projections/aligns.py | 55 +++++- examples/dynamics_simulation/COBA_parallel.py | 167 ++++++++++++++---- 3 files changed, 184 insertions(+), 42 deletions(-) diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py index afb4ab262..2069f4e65 100644 --- a/brainpy/_src/dyn/neurons/hh.py +++ b/brainpy/_src/dyn/neurons/hh.py @@ -180,7 +180,7 @@ def update(self, x=None): return super().update(x) -class HHLTC(NeuDyn): +class HHLTC(HHTypedNeuron): r"""Hodgkin–Huxley neuron model with liquid time constant. **Model Descriptions** @@ -758,7 +758,7 @@ def update(self, x=None): return super().update(x) -class WangBuzsakiHHLTC(NeuDyn): +class WangBuzsakiHHLTC(HHTypedNeuron): r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model with liquid time constant. 
Each model is described by a single compartment and obeys the current balance equation: diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py index d63033eb7..2dfa2dd14 100644 --- a/brainpy/_src/dyn/projections/aligns.py +++ b/brainpy/_src/dyn/projections/aligns.py @@ -1,7 +1,5 @@ from typing import Optional, Callable, Union -import jax - from brainpy import math as bm, check from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return from brainpy._src.dynsys import DynamicalSystem, Projection @@ -127,6 +125,7 @@ def __init__( # references self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access def update(self, x): current = self.comm(x) @@ -218,6 +217,7 @@ def __init__( self.refs = dict(post=post) # invisible to ``self.nodes()`` self.refs['syn'] = post.get_bef_update(self._post_repr).syn self.refs['out'] = post.get_bef_update(self._post_repr).out + self.refs['comm'] = comm # unify the access def update(self, x): current = self.comm(x) @@ -342,6 +342,9 @@ def __init__( self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` self.refs['syn'] = post.get_bef_update(self._post_repr).syn # invisible to ``self.node()`` self.refs['out'] = post.get_bef_update(self._post_repr).out # invisible to ``self.node()`` + # unify the access + self.refs['comm'] = comm + self.refs['delay'] = pre.get_aft_update(delay_identifier) def update(self): x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) @@ -422,9 +425,13 @@ def __init__( post.add_bef_update(self.name, _AlignPost(syn, out)) # reference - self.refs = dict(post=post) # invisible to ``self.nodes()`` + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['post'] = post self.refs['syn'] = post.get_bef_update(self.name).syn self.refs['out'] = post.get_bef_update(self.name).out + # unify the access + self.refs['comm'] = comm def update(self, x): current = self.comm(x) @@ -538,8 +545,15 @@ def __init__( post.add_inp_fun(out_name, out) # references - self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post self.refs['out'] = out + # unify the access + self.refs['delay'] = pre.get_aft_update(delay_identifier) + self.refs['comm'] = comm + self.refs['syn'] = syn def update(self): x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) @@ -655,8 +669,15 @@ def __init__( post.add_inp_fun(out_name, out) # references - self.refs = dict(pre=pre, post=post, out=out, delay=delay_cls) # invisible to ``self.nodes()`` + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls self.refs['syn'] = pre.get_aft_update(self._syn_id).syn + # unify the access + self.refs['comm'] = comm def update(self, x=None): if x is None: @@ -778,9 +799,14 @@ def __init__( post.add_inp_fun(out_name, out) # references - self.refs = dict(pre=pre, post=post) # invisible to `self.nodes()` + self.refs = dict() + # invisible to `self.nodes()` + self.refs['pre'] = pre + self.refs['post'] = post self.refs['syn'] = delay_cls.get_bef_update(self._syn_id).syn self.refs['out'] = out + # unify the access + self.refs['comm'] = comm def update(self): x = _get_return(self.refs['syn'].return_info()) @@ -890,9 +916,15 @@ def __init__( post.add_inp_fun(out_name, out) # references - self.refs = dict(pre=pre, 
post=post, out=out) # invisible to ``self.nodes()`` + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out self.refs['delay'] = delay_cls self.refs['syn'] = syn + # unify the access + self.refs['comm'] = comm def update(self, x=None): if x is None: @@ -1006,8 +1038,15 @@ def __init__( post.add_inp_fun(out_name, out) # references - self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()`` + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out self.refs['delay'] = pre.get_aft_update(delay_identifier) + # unify the access + self.refs['syn'] = syn + self.refs['comm'] = comm def update(self): spk = self.refs['delay'].at(self.name) diff --git a/examples/dynamics_simulation/COBA_parallel.py b/examples/dynamics_simulation/COBA_parallel.py index a0f10de09..45cf81953 100644 --- a/examples/dynamics_simulation/COBA_parallel.py +++ b/examples/dynamics_simulation/COBA_parallel.py @@ -2,10 +2,23 @@ import brainpy as bp import brainpy.math as bm +from jax.experimental.maps import xmap + # bm.set_host_device_count(4) +class ExpJIT(bp.Projection): + def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPostMg1( + comm=bp.dnn.EventJitFPHomoLinear(pre_num, post.num, prob=prob, weight=g_max), + syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), + out=bp.dyn.COBA.desc(E=E), + post=post + ) + + class EINet1(bp.DynSysGroup): def __init__(self): super().__init__() @@ -13,18 +26,8 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.), sharding=[bm.sharding.NEU_AXIS]) self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPostMg1( - comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=4000, tau=5., sharding=[bm.sharding.NEU_AXIS]), - out=bp.dyn.COBA.desc(E=0.), - post=self.N - ) - self.I = bp.dyn.ProjAlignPostMg1( - comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=4000, tau=10., sharding=[bm.sharding.NEU_AXIS]), - out=bp.dyn.COBA.desc(E=-80.), - post=self.N - ) + self.E = ExpJIT(3200, self.N, 0.02, 0.6) + self.I = ExpJIT(800, self.N, 0.02, 6.7, E=-80., tau=10.) def update(self, input): spk = self.delay.at('I') @@ -34,6 +37,18 @@ def update(self, input): return self.N.spike.value +class ExpMasked(bp.Projection): + def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPostMg1( + comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(prob, pre=pre_num, post=post.num), weight=g_max, + sharding=[None, bm.sharding.NEU_AXIS]), + syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), + out=bp.dyn.COBA.desc(E=E), + post=post + ) + + class EINet2(bp.DynSysGroup): def __init__(self): super().__init__() @@ -41,21 +56,79 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.), sharding=[bm.sharding.NEU_AXIS]) self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPostMg1( - comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(0.02, pre=3200, post=4000), weight=0.6, - sharding=[None, bm.sharding.NEU_AXIS]), - syn=bp.dyn.Expon.desc(size=4000, tau=5., sharding=[bm.sharding.NEU_AXIS]), - out=bp.dyn.COBA.desc(E=0.), - post=self.N + self.E = ExpMasked(3200, self.N, 0.02, 0.6) + self.I = ExpMasked(800, self.N, 0.02, 6.7, E=-80., tau=10.) 
+ + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + +class PCSR(bp.dnn.Layer): + def __init__(self, conn, weight, num_shard, transpose=True): + super().__init__() + + self.conn = conn + self.transpose = transpose + self.num_shard = num_shard + + # connection + self.indices = [] + self.inptr = [] + for _ in range(num_shard): + indices, inptr = self.conn.require('csr') + self.indices.append(indices) + self.inptr.append(inptr) + self.indices = bm.asarray(self.indices) + self.inptr = bm.asarray(self.inptr) + + # weight + weight = bp.init.parameter(weight, (self.indices.size,)) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, v): + # ax1 = None if bm.size(self.weight) > 1 else (None, bm.sharding.NEU_AXIS) + mapped = xmap( + self._f, + in_axes=((bm.sharding.NEU_AXIS, None), (bm.sharding.NEU_AXIS, None), (None, )), + out_axes=(bm.sharding.NEU_AXIS, None), + # axis_resources={bm.sharding.NEU_AXIS: bm.sharding.NEU_AXIS}, ) - self.I = bp.dyn.ProjAlignPostMg1( - comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(0.02, pre=800, post=4000), weight=0.6, - sharding=[None, bm.sharding.NEU_AXIS]), - syn=bp.dyn.Expon.desc(size=4000, tau=10., sharding=[bm.sharding.NEU_AXIS]), - out=bp.dyn.COBA.desc(E=-80.), - post=self.N + r = mapped(self.indices, self.inptr, v) + return r.flatten() + + def _f(self, indices, indptr, x): + return bm.event.csrmv(self.weight, indices, indptr, x, + shape=(self.conn.pre_num, self.conn.post_num // self.num_shard), + transpose=self.transpose) + + +class ExpMasked2(bp.Projection): + def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPostMg1( + comm=PCSR(bp.conn.FixedProb(prob, pre=pre_num, post=post.num), weight=g_max, num_shard=4), + syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), + out=bp.dyn.COBA.desc(E=E), + post=post ) + +class EINet3(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.), + sharding=[bm.sharding.NEU_AXIS]) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = ExpMasked2(3200, self.N, 0.02, 0.6) + self.I = ExpMasked2(800, self.N, 0.02, 6.7, E=-80., tau=10.) 
+ def update(self, input): spk = self.delay.at('I') self.E(spk[:3200]) @@ -64,14 +137,44 @@ def update(self, input): return self.N.spike.value -@bm.jit -def run(indexes): - return bm.for_loop(lambda i: model.step_run(i, 20.), indexes) +def try_ei_net1(): + @bm.jit + def run(indexes): + return bm.for_loop(lambda i: model.step_run(i, 20.), indexes) + + with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]): + model = EINet1() + indices = bm.arange(1000) + spks = run(indices) + bp.visualize.raster_plot(indices, spks, show=True) + + +def try_ei_net2(): + @bm.jit + def run(indexes): + return bm.for_loop(lambda i: model.step_run(i, 20.), indexes) + + with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]): + model = EINet2() + indices = bm.arange(1000) + spks = run(indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + +def try_ei_net3(): + @bm.jit + def run(indexes): + return bm.for_loop(lambda i: model.step_run(i, 20.), indexes) + with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]): + model = EINet3() + indices = bm.arange(1000) + spks = run(indices) + bp.visualize.raster_plot(indices, spks, show=True) -with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]): - model = EINet2() - indices = bm.arange(1000) - spks = run(indices) -bp.visualize.raster_plot(indices, spks, show=True) +if __name__ == '__main__': + # try_ei_net1() + # try_ei_net2() + try_ei_net3() From 94491112e20f1a4988e85298267e6c2922205056 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 8 Sep 2023 22:01:50 +0800 Subject: [PATCH 03/10] [random] add `brainpy.math.random.split_keys()` --- brainpy/_src/math/random.py | 13 +++++++++++++ brainpy/math/random.py | 1 + 2 files changed, 14 insertions(+) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index ddd4753a9..eb04c5d2e 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -1253,6 +1253,19 @@ def split_key(): return DEFAULT.split_key() +def split_keys(n): + """Create multiple seeds from the current seed. This is used + internally by `pmap` and `vmap` to ensure that random numbers + are different in parallel threads. + + Parameters + ---------- + n : int + The number of seeds to generate. 
+ """ + return DEFAULT.split_keys(n) + + def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState: if seed_or_key is None: return DEFAULT.clone() if clone else DEFAULT diff --git a/brainpy/math/random.py b/brainpy/math/random.py index ed3fbeea4..dde1f4832 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -7,6 +7,7 @@ seed as seed, split_key as split_key, + split_keys as split_keys, default_rng as default_rng, # numpy compatibility From 1ea18e62339dcc383ad6e7e89c2957d04e34b330 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 8 Sep 2023 22:02:22 +0800 Subject: [PATCH 04/10] update requirements --- requirements-dev.txt | 1 + requirements-doc.txt | 1 + requirements.txt | 1 + 3 files changed, 3 insertions(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 6e6392b31..126f0bd27 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,6 +6,7 @@ jax>=0.4.1 jaxlib>=0.4.1 scipy>=1.1.0 brainpylib +numba # test requirements pytest diff --git a/requirements-doc.txt b/requirements-doc.txt index dc67a4b04..d41a8cf41 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -6,6 +6,7 @@ jax>=0.4.1 matplotlib>=3.4 jaxlib>=0.4.1 scipy>=1.1.0 +numba # document requirements pandoc diff --git a/requirements.txt b/requirements.txt index d8343cde7..74db0a68a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy jax>=0.4.1 tqdm msgpack +numba \ No newline at end of file From df42206be95904294bb8caa4048dba94accf0c6c Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 8 Sep 2023 22:03:11 +0800 Subject: [PATCH 05/10] [math] support initializing in-trace variables --- brainpy/_src/math/object_transform/base.py | 33 ++++++++++++----- .../object_transform/tests/test_variable.py | 36 +++++++++++++++++++ .../_src/math/object_transform/variables.py | 33 +++++++++-------- 3 files changed, 79 insertions(+), 23 deletions(-) create mode 100644 brainpy/_src/math/object_transform/tests/test_variable.py diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 851e23776..af6ae6e67 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -102,17 +102,35 @@ def __init__(self, name=None): def setattr(self, key: str, value: Any) -> None: super().__setattr__(key, value) + def in_trace_variable(self, key: str, value: Variable) -> Variable: + """Initialize and get the in-trace variable. + + Args: + key: str. The name of the variable. + value: Array. The data of the in-trace variable. + + Returns: + variable. + """ + if not hasattr(self, key): + if not isinstance(value, Variable): + value = Variable(value) + value._ready_to_trace = True + self.setattr(key, value) + else: + var = getattr(self, key) + var.value = value + value = var + return value + def __setattr__(self, key: str, value: Any) -> None: """Overwrite `__setattr__` method for changing :py:class:`~.Variable` values. .. versionadded:: 2.3.1 - Parameters - ---------- - key: str - The attribute. - value: Any - The value. + Args: + key: str. The attribute. + value: Any. The value. 
""" if key in self.__dict__: val = self.__dict__[key] @@ -252,7 +270,7 @@ def vars(self, continue v = getattr(node, k) if isinstance(v, Variable) and not isinstance(v, exclude_types): - gather[f'{node_path}.{k}' if node_path else k] = v + gather[f'{node_path}.{k}' if node_path else k] = v elif isinstance(v, VarList): for i, vv in enumerate(v): if not isinstance(vv, exclude_types): @@ -702,4 +720,3 @@ def __setitem__(self, key, value) -> 'VarDict': node_dict = NodeDict - diff --git a/brainpy/_src/math/object_transform/tests/test_variable.py b/brainpy/_src/math/object_transform/tests/test_variable.py new file mode 100644 index 000000000..ef703fba6 --- /dev/null +++ b/brainpy/_src/math/object_transform/tests/test_variable.py @@ -0,0 +1,36 @@ +import brainpy.math as bm +import unittest + + +class TestVar(unittest.TestCase): + def test1(self): + class A(bm.BrainPyObject): + def __init__(self): + super().__init__() + self.a = bm.Variable(1) + self.f1 = bm.jit(self.f) + self.f2 = bm.jit(self.ff) + + def f(self): + b = self.in_trace_variable('b', bm.ones(1,)) + self.a += (b * 2) + return self.a.value + + def ff(self): + self.b += 1. + + print() + f_jit = bm.jit(A().f) + f_jit() + self.assertTrue(len(f_jit._dyn_vars) == 2) + + print() + a = A() + self.assertTrue(bm.all(a.f1() == 2.)) + self.assertTrue(len(a.f1._dyn_vars) == 2) + print(a.f2()) + self.assertTrue(len(a.f2._dyn_vars) == 1) + + + + diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index f526a6680..a8a3e54d0 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -6,6 +6,7 @@ from jax import numpy as jnp from jax.dtypes import canonicalize_dtype from jax.tree_util import register_pytree_node_class +from jax._src.array import ArrayImpl from brainpy._src.math.sharding import BATCH_AXIS from brainpy._src.math.ndarray import Array @@ -38,7 +39,13 @@ def add(self, var: 'Variable'): id_ = id(var) if id_ not in self: self[id_] = var - self._values[id_] = var._value + # self._values[id_] = var._value + v = var._value + if not isinstance(v, ArrayImpl): + with jax.ensure_compile_time_eval(): + v = jnp.zeros_like(v) + var._value = v + self._values[id_] = v def collect_values(self): """Collect the value of each variable once again.""" @@ -71,7 +78,7 @@ def dict_data(self) -> dict: """Get all data in the collected variables with a python dict structure.""" new_dict = dict() for id_, elem in tuple(self.items()): - new_dict[id_] = elem.value if isinstance(elem, Array) else elem + new_dict[id_] = elem.value return new_dict def list_data(self) -> list: @@ -163,14 +170,11 @@ class Variable(Array): Note that when initializing a `Variable` by the data shape, all values in this `Variable` will be initialized as zeros. - Parameters - ---------- - value_or_size: Shape, Array, int - The value or the size of the value. - dtype: - The type of the data. - batch_axis: optional, int - The batch axis. + Args: + value_or_size: Shape, Array, int. The value or the size of the value. + dtype: Any. The type of the data. + batch_axis: optional, int. The batch axis. + axis_names: sequence of str. The name for each axis. 
""" __slots__ = ('_value', '_batch_axis', '_ready_to_trace', 'axis_names') @@ -191,7 +195,7 @@ def __init__( else: value = value_or_size - super(Variable, self).__init__(value, dtype=dtype) + super().__init__(value, dtype=dtype) # check batch axis if isinstance(value, Variable): @@ -276,7 +280,6 @@ def value(self, v): v = v self._value = v - def _append_to_stack(self): if self._ready_to_trace: for stack in var_stack_list: @@ -319,7 +322,7 @@ def __init__( axis_names: Optional[Sequence[str]] = None, _ready_to_trace: bool = True ): - super(TrainVar, self).__init__( + super().__init__( value_or_size, dtype=dtype, batch_axis=batch_axis, @@ -342,7 +345,7 @@ def __init__( axis_names: Optional[Sequence[str]] = None, _ready_to_trace: bool = True ): - super(Parameter, self).__init__( + super().__init__( value_or_size, dtype=dtype, batch_axis=batch_axis, @@ -390,7 +393,7 @@ def __init__( self.index = jax.tree_util.tree_map(_as_jax_array_, index, is_leaf=lambda a: isinstance(a, Array)) if not isinstance(value, Variable): raise ValueError('Must be instance of Variable.') - super(VariableView, self).__init__(value.value, batch_axis=value.batch_axis, _ready_to_trace=False) + super().__init__(value.value, batch_axis=value.batch_axis, _ready_to_trace=False) self._value = value def __repr__(self) -> str: From 8d66e6892b4257cd1964fa2d26e95f931b7dc4a7 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 8 Sep 2023 22:04:09 +0800 Subject: [PATCH 06/10] [math] updates --- brainpy/_src/connect/random_conn.py | 7 +++++-- brainpy/_src/math/sparse/_csr_mv.py | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index ee98ea135..1f5b1db6d 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -128,8 +128,11 @@ def build_csr(self): return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) def build_mat(self): - pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state + if self.pre_ratio < 1.: + pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state + else: + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) mat = bm.asarray(mat) if not self.include_self: bm.fill_diagonal(mat, False) diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index e43965d4d..9a37f0902 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -81,6 +81,9 @@ def csrmv( indptr = as_jax(indptr) vector = as_jax(vector) + if vector.dtype == jnp.bool_: + vector = as_jax(vector, dtype=data.dtype) + if method == 'cusparse': if jax.default_backend() == 'gpu': if data.shape[0] == 1: From fa30736d8ac308572851d0232c6ad694e73060f6 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 8 Sep 2023 23:12:57 +0800 Subject: [PATCH 07/10] [math] fix the bug in tracing in-comp variables --- brainpy/_src/math/object_transform/base.py | 6 +++++- .../math/object_transform/tests/test_variable.py | 15 +++++++++++++++ brainpy/_src/math/object_transform/variables.py | 16 +++++++--------- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index af6ae6e67..6103bf6df 100644 --- 
a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -20,7 +20,7 @@ from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, - VarList, VarDict) + VarList, VarDict, var_stack_list) StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) @@ -116,6 +116,10 @@ def in_trace_variable(self, key: str, value: Variable) -> Variable: if not isinstance(value, Variable): value = Variable(value) value._ready_to_trace = True + v = value._value + if len(var_stack_list) > 0 and isinstance(v, jax.core.Tracer): + with jax.ensure_compile_time_eval(): + value._value = jax.numpy.zeros_like(v) self.setattr(key, value) else: var = getattr(self, key) diff --git a/brainpy/_src/math/object_transform/tests/test_variable.py b/brainpy/_src/math/object_transform/tests/test_variable.py index ef703fba6..d4a289694 100644 --- a/brainpy/_src/math/object_transform/tests/test_variable.py +++ b/brainpy/_src/math/object_transform/tests/test_variable.py @@ -10,6 +10,7 @@ def __init__(self): self.a = bm.Variable(1) self.f1 = bm.jit(self.f) self.f2 = bm.jit(self.ff) + self.f3 = bm.jit(self.fff) def f(self): b = self.in_trace_variable('b', bm.ones(1,)) @@ -19,6 +20,12 @@ def f(self): def ff(self): self.b += 1. + def fff(self): + self.f() + self.ff() + self.b *= self.a + return self.b.value + print() f_jit = bm.jit(A().f) f_jit() @@ -31,6 +38,14 @@ def ff(self): print(a.f2()) self.assertTrue(len(a.f2._dyn_vars) == 1) + print() + a = A() + print() + self.assertTrue(bm.allclose(a.f3(), 4.)) + self.assertTrue(len(a.f3._dyn_vars) == 2) + + bm.clear_buffer_memory() + diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index a8a3e54d0..06020f4cc 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -6,7 +6,6 @@ from jax import numpy as jnp from jax.dtypes import canonicalize_dtype from jax.tree_util import register_pytree_node_class -from jax._src.array import ArrayImpl from brainpy._src.math.sharding import BATCH_AXIS from brainpy._src.math.ndarray import Array @@ -39,13 +38,13 @@ def add(self, var: 'Variable'): id_ = id(var) if id_ not in self: self[id_] = var - # self._values[id_] = var._value - v = var._value - if not isinstance(v, ArrayImpl): - with jax.ensure_compile_time_eval(): - v = jnp.zeros_like(v) - var._value = v - self._values[id_] = v + self._values[id_] = var._value + # v = var._value + # if isinstance(v, Tracer): + # with jax.ensure_compile_time_eval(): + # v = jnp.zeros_like(v) + # var._value = v + # self._values[id_] = v def collect_values(self): """Collect the value of each variable once again.""" @@ -115,7 +114,6 @@ def __add__(self, other: dict): new_dict._values.update(other._values) return new_dict - var_stack_list: List[VariableStack] = [] transform_stack: List[Callable] = [] From 2e258f963e38487e4545aaf9a9f230d8c2015bca Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 9 Sep 2023 10:21:18 +0800 Subject: [PATCH 08/10] [math] update `.tracing_variable()` function --- brainpy/_src/math/object_transform/base.py | 48 ++++++++++++++----- .../object_transform/tests/test_variable.py | 2 +- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 6103bf6df..99bd548ef 100644 --- 
a/brainpy/_src/math/object_transform/base.py
+++ b/brainpy/_src/math/object_transform/base.py
@@ -102,27 +102,53 @@ def __init__(self, name=None):
   def setattr(self, key: str, value: Any) -> None:
     super().__setattr__(key, value)
 
-  def in_trace_variable(self, key: str, value: Variable) -> Variable:
-    """Initialize and get the in-trace variable.
+  def tracing_variable(self, name: str, value: Union[jax.Array, Array]) -> Variable:
+    """Initialize and get the variable which can be traced during computation.
+
+    Although this function is designed to initialize tracing variables during computation or compilation,
+    it can also be used for initialization of variables before or after computation and compilation.
+
+    - If ``name`` has been used in this object, a ``KeyError`` will be raised.
+    - If the variable has not been instantiated, the given ``value`` will be used to
+      instantiate a :py:class:`~.Variable`.
+    - If the variable has been created, the further call of this function will
+      refresh the value of the variable with the given ``value``.
+
+    Here is the usage example::
+
+      class Example(bm.BrainPyObject):
+        def fun(self):
+          # this line will create a Variable instance
+          self.tracing_variable('a', bm.zeros(10))
+
+          # calling this function again will assign a different value
+          # to the created Variable instance
+          self.tracing_variable('a', bm.random.random(10))
 
     Args:
-      key: str. The name of the variable.
-      value: Array. The data of the in-trace variable.
+      name: str. The variable name.
+      value: Array. The data of the in-trace variable. It can also be the instance of
+        :py:class:`~.Variable`, so that users can control the property of the created
+        variable instance. If an ``Array`` is provided, then it will be instantiated
+        as a :py:class:`~.Variable`.
 
     Returns:
-      variable.
+      The instance of :py:class:`~.Variable`.
     """
-    if not hasattr(self, key):
+    if not hasattr(self, name):
       if not isinstance(value, Variable):
         value = Variable(value)
       value._ready_to_trace = True
-      v = value._value
-      if len(var_stack_list) > 0 and isinstance(v, jax.core.Tracer):
+      if len(var_stack_list) > 0 and isinstance(value._value, jax.core.Tracer):
         with jax.ensure_compile_time_eval():
-          value._value = jax.numpy.zeros_like(v)
-      self.setattr(key, value)
+          value._value = jax.numpy.zeros_like(value._value)
+      self.setattr(name, value)
     else:
-      var = getattr(self, key)
+      var = getattr(self, name)
+      if not isinstance(var, Variable):
+        raise KeyError(f'"{name}" has been used in this class. Please assign '
+                       f'another name for the initialization of variables '
+                       f'tracing during computation and compilation.')
       var.value = value
       value = var
     return value
diff --git a/brainpy/_src/math/object_transform/tests/test_variable.py b/brainpy/_src/math/object_transform/tests/test_variable.py
index d4a289694..aed07ee3b 100644
--- a/brainpy/_src/math/object_transform/tests/test_variable.py
+++ b/brainpy/_src/math/object_transform/tests/test_variable.py
@@ -13,7 +13,7 @@ def __init__(self):
     self.f3 = bm.jit(self.fff)
 
   def f(self):
-    b = self.in_trace_variable('b', bm.ones(1,))
+    b = self.tracing_variable('b', bm.ones(1, ))
     self.a += (b * 2)
     return self.a.value
 

From 8bc7e23a8e09dc8b2aebad46abd823618a3e0567 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sat, 9 Sep 2023 10:28:19 +0800
Subject: [PATCH 09/10] [math] update `.tracing_variable()` function

---
 brainpy/_src/math/object_transform/base.py | 33 ++++++++++++----------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py
index 99bd548ef..16b6e1d32 100644
--- a/brainpy/_src/math/object_transform/base.py
+++ b/brainpy/_src/math/object_transform/base.py
@@ -135,22 +135,25 @@ def fun(self):
 
     Returns:
       The instance of :py:class:`~.Variable`.
     """
-    if not hasattr(self, name):
-      if not isinstance(value, Variable):
-        value = Variable(value)
-      value._ready_to_trace = True
-      if len(var_stack_list) > 0 and isinstance(value._value, jax.core.Tracer):
-        with jax.ensure_compile_time_eval():
-          value._value = jax.numpy.zeros_like(value._value)
-      self.setattr(name, value)
-    else:
+    # the variable has been created
+    if hasattr(self, name):
       var = getattr(self, name)
-      if not isinstance(var, Variable):
-        raise KeyError(f'"{name}" has been used in this class. Please assign '
-                       f'another name for the initialization of variables '
-                       f'tracing during computation and compilation.')
-      var.value = value
-      value = var
+      if isinstance(var, Variable):
+        var.value = value
+        return var
+
+    # create the variable
+    if not isinstance(value, Variable):
+      value = Variable(value)
+    value._ready_to_trace = True
+    if len(var_stack_list) > 0 and isinstance(value._value, jax.core.Tracer):
+      with jax.ensure_compile_time_eval():
+        value._value = jax.numpy.zeros_like(value._value)
+    self.setattr(name, value)
+    # if not isinstance(var, Variable):
+    #   raise KeyError(f'"{name}" has been used in this class. Please assign '
+    #                  f'another name for the initialization of variables '
+    #                  f'tracing during computation and compilation.')
     return value
 

From beb6cc598eb1be281fdbd4dab0beddf3a02aa3b4 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sat, 9 Sep 2023 12:04:24 +0800
Subject: [PATCH 10/10] [math] update `.tracing_variable()` functionality

---
 brainpy/_src/math/modes.py                  |  4 +
 brainpy/_src/math/object_transform/base.py  | 87 ++++++++++++-------
 .../object_transform/tests/test_variable.py |  2 +-
 3 files changed, 62 insertions(+), 31 deletions(-)

diff --git a/brainpy/_src/math/modes.py b/brainpy/_src/math/modes.py
index 5e72ff09c..674035e18 100644
--- a/brainpy/_src/math/modes.py
+++ b/brainpy/_src/math/modes.py
@@ -61,6 +61,10 @@ class NonBatchingMode(Mode):
   """
   pass
 
+  @property
+  def batch_size(self):
+    return tuple()
+
 
 class BatchingMode(Mode):
   """Batching mode.
diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 16b6e1d32..daa8a55bb 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -21,7 +21,11 @@ check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict, var_stack_list) +from brainpy._src.math.modes import Mode +from brainpy._src.math.sharding import BATCH_AXIS + +variable_ = None StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) __all__ = [ @@ -102,35 +106,52 @@ def __init__(self, name=None): def setattr(self, key: str, value: Any) -> None: super().__setattr__(key, value) - def tracing_variable(self, name: str, value: Union[jax.Array, Array]) -> Variable: - """Initialize and get the variable which can be traced during computation. + def tracing_variable( + self, + name: str, + init: Union[Callable, Array, jax.Array], + shape: Union[int, Sequence[int]], + batch_or_mode: Union[int, bool, Mode] = None, + batch_axis: int = 0, + axis_names: Optional[Sequence[str]] = None, + batch_axis_name: Optional[str] = BATCH_AXIS, + ) -> Variable: + """Initialize the variable which can be traced during computations and transformations. Although this function is designed to initialize tracing variables during computation or compilation, - it can also be used for initialization of variables before or after computation and compilation. + it can also be used for the initialization of variables before computation and compilation. - - If ``name`` has been used in this object, a ``KeyError`` will be raised. - - If the variable has not been instantiated, the given ``value`` will be used to - instantiate a :py:class:`~.Variable`. - - If the variable has been created, the further call of this function will - refresh the value of the variable with the given ``value``. + - If the variable has not been instantiated, a :py:class:`~.Variable` will be instantiated. + - If the variable has been created, the further call of this function will return the created variable. Here is the usage example:: class Example(bm.BrainPyObject): def fun(self): - # this line will create a Variable instance - self.tracing_variable('a', bm.zeros(10)) + # The first time of calling `.fun()`, this line will create a Variable instance. + # If users repeatedly call `.fun()` function, this line will not initialize variables again. + # Instead, it will return the variable has been created. + self.tracing_variable('a', bm.zeros, (10,)) - # calling this function again will assign a different value - # to the created Variable instance - self.tracing_variable('a', bm.random.random(10)) + # The created variable can be accessed with self.xxx + self.a.value = bm.ones(10) + + # Calling this function again will not reinitialize the + # variable again, Instead, it will return the variable + # that has been created. + a = self.tracing_variable('a', bm.zeros, (10,)) Args: name: str. The variable name. - value: Array. The data of the in-trace variable. It can also be the instance of - :py:class:`~.Variable`, so that users can control the property of the created - variable instance. If an ``Array`` is provided, then it will be instantiated - as a :py:class:`~.Variable`. + init: callable, Array. The data to be initialized as a ``Variable``. + batch_or_mode: int, bool, Mode. This is used to specify the batch size of this variable. 
+ If it is a boolean or an instance of ``Mode``, the batch size will be 1. + If it is None, the variable has no batch axis. + shape: int, sequence of int. The shape of the variable. + batch_axis: int. The batch axis, if batch size is given. + axis_names: sequence of str. The name for each axis. These names should match the given ``axes``. + batch_axis_name: str. The name for the batch axis. The name will be used + if ``batch_or_mode`` is given. Default is ``brainpy.math.sharding.BATCH_AXIS``. Returns: The instance of :py:class:`~.Variable`. @@ -139,21 +160,27 @@ def fun(self): if hasattr(self, name): var = getattr(self, name) if isinstance(var, Variable): - var.value = value return var - - # create the variable - if not isinstance(value, Variable): - value = Variable(value) - value._ready_to_trace = True - if len(var_stack_list) > 0 and isinstance(value._value, jax.core.Tracer): - with jax.ensure_compile_time_eval(): - value._value = jax.numpy.zeros_like(value._value) + # if var.shape != value.shape: + # raise ValueError( + # f'"{name}" has been used in this class with the shape of {var.shape} (!= {value.shape}). ' + # f'Please assign another name for the initialization of variables ' + # f'tracing during computation and compilation.' + # ) + # if var.dtype != value.dtype: + # raise ValueError( + # f'"{name}" has been used in this class with the dtype of {var.dtype} (!= {value.dtype}). ' + # f'Please assign another name for the initialization of variables ' + # f'tracing during computation and compilation.' + # ) + + global variable_ + if variable_ is None: + from brainpy.initialize import variable_ + with jax.ensure_compile_time_eval(): + value = variable_(init, shape, batch_or_mode, batch_axis, axis_names, batch_axis_name) + value._ready_to_trace = True self.setattr(name, value) - # if not isinstance(var, Variable): - # raise KeyError(f'"{name}" has been used in this class. Please assign ' - # f'another name for the initialization of variables ' - # f'tracing during computation and compilation.') return value def __setattr__(self, key: str, value: Any) -> None: diff --git a/brainpy/_src/math/object_transform/tests/test_variable.py b/brainpy/_src/math/object_transform/tests/test_variable.py index aed07ee3b..ddf7c8d22 100644 --- a/brainpy/_src/math/object_transform/tests/test_variable.py +++ b/brainpy/_src/math/object_transform/tests/test_variable.py @@ -13,7 +13,7 @@ def __init__(self): self.f3 = bm.jit(self.fff) def f(self): - b = self.tracing_variable('b', bm.ones(1, )) + b = self.tracing_variable('b', bm.ones, (1,)) self.a += (b * 2) return self.a.value
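``brainpy.math.random.split_keys(n)`` (PATCH 03) derives ``n`` fresh PRNG keys
from the global default random state, e.g. for seeding parallel workers. A
minimal sketch, assuming only the signature shown in that patch; the variable
names here are illustrative::

  import brainpy.math as bm

  bm.random.seed(42)
  keys = bm.random.split_keys(4)  # e.g. one key per device under pmap/vmap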
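``BrainPyObject.tracing_variable()`` (PATCHES 05-10) settles on a
lazy-initialization pattern: the first traced call creates the variable
(through ``variable_`` in PATCH 10), and every later call returns the
already-created instance. A minimal sketch against the final signature in
PATCH 10 and the updated call in ``test_variable.py``; the ``Accumulator``
class and its names are illustrative only, not taken from the BrainPy test
suite::

  import brainpy.math as bm

  class Accumulator(bm.BrainPyObject):
    def step(self):
      # The first traced call creates `self.b`; later calls return the
      # same Variable instead of re-initializing it.
      b = self.tracing_variable('b', bm.zeros, (1,))
      b += 1.
      return b.value

  acc = Accumulator()
  jitted = bm.jit(acc.step)
  jitted()  # `b` is created lazily during the first trace
  jitted()  # the same Variable is reused on later calls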