Skip to content

Commit

Permalink
Merge pull request #472 from chaoming0625/master
Browse files Browse the repository at this point in the history
Support tracing ``Variable`` during computation and compilation by using ``tracing_variable()`` function
  • Loading branch information
chaoming0625 authored Sep 9, 2023
2 parents 21848a1 + beb6cc5 commit d83caa4
Show file tree
Hide file tree
Showing 17 changed files with 394 additions and 98 deletions.
4 changes: 2 additions & 2 deletions brainpy/_src/checkpoints/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def save_as_h5(filename: str, variables: dict):
raise ValueError(f'Cannot save variables as a HDF5 file. We only support file with '
f'postfix of ".hdf5" and ".h5". But we got {filename}')

import h5py
import h5py # noqa

# check variables
check_dict_data(variables, name='variables')
Expand Down Expand Up @@ -184,7 +184,7 @@ def load_by_h5(filename: str, target, verbose: bool = False):
f'postfix of ".hdf5" and ".h5". But we got {filename}')

# read data
import h5py
import h5py # noqa
load_vars = dict()
with h5py.File(filename, "r") as f:
for key in f.keys():
Expand Down
48 changes: 24 additions & 24 deletions brainpy/_src/checkpoints/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def __init__(self):
print(self.net.vars().keys())
print(self.net.vars().unique().keys())

def test_h5(self):
bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)

bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)

def test_h5_postfix(self):
with self.assertRaises(ValueError):
bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
with self.assertRaises(ValueError):
bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
# def test_h5(self):
# bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
# bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
#
# bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
# bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
#
# def test_h5_postfix(self):
# with self.assertRaises(ValueError):
# bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
# with self.assertRaises(ValueError):
# bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)

def test_npz(self):
bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
Expand Down Expand Up @@ -120,18 +120,18 @@ def __init__(self):
print(self.net.vars().keys())
print(self.net.vars().unique().keys())

def test_h5(self):
bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)

bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)

def test_h5_postfix(self):
with self.assertRaises(ValueError):
bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
with self.assertRaises(ValueError):
bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
# def test_h5(self):
# bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
# bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
#
# bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
# bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
#
# def test_h5_postfix(self):
# with self.assertRaises(ValueError):
# bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
# with self.assertRaises(ValueError):
# bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)

def test_npz(self):
bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
Expand Down
7 changes: 5 additions & 2 deletions brainpy/_src/connect/random_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,11 @@ def build_csr(self):
return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type())

def build_mat(self):
pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio
mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
if self.pre_ratio < 1.:
pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio
mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
else:
mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob)
mat = bm.asarray(mat)
if not self.include_self:
bm.fill_diagonal(mat, False)
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/dyn/neurons/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def update(self, x=None):
return super().update(x)


class HHLTC(NeuDyn):
class HHLTC(HHTypedNeuron):
r"""Hodgkin–Huxley neuron model with liquid time constant.
**Model Descriptions**
Expand Down Expand Up @@ -758,7 +758,7 @@ def update(self, x=None):
return super().update(x)


class WangBuzsakiHHLTC(NeuDyn):
class WangBuzsakiHHLTC(HHTypedNeuron):
r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model with liquid time constant.
Each model is described by a single compartment and obeys the current balance equation:
Expand Down
55 changes: 47 additions & 8 deletions brainpy/_src/dyn/projections/aligns.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Optional, Callable, Union

import jax

from brainpy import math as bm, check
from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return
from brainpy._src.dynsys import DynamicalSystem, Projection
Expand Down Expand Up @@ -127,6 +125,7 @@ def __init__(

# references
self.refs = dict(post=post, out=out) # invisible to ``self.nodes()``
self.refs['comm'] = comm # unify the access

def update(self, x):
current = self.comm(x)
Expand Down Expand Up @@ -218,6 +217,7 @@ def __init__(
self.refs = dict(post=post) # invisible to ``self.nodes()``
self.refs['syn'] = post.get_bef_update(self._post_repr).syn
self.refs['out'] = post.get_bef_update(self._post_repr).out
self.refs['comm'] = comm # unify the access

def update(self, x):
current = self.comm(x)
Expand Down Expand Up @@ -342,6 +342,9 @@ def __init__(
self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
self.refs['syn'] = post.get_bef_update(self._post_repr).syn # invisible to ``self.node()``
self.refs['out'] = post.get_bef_update(self._post_repr).out # invisible to ``self.node()``
# unify the access
self.refs['comm'] = comm
self.refs['delay'] = pre.get_aft_update(delay_identifier)

def update(self):
x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
Expand Down Expand Up @@ -422,9 +425,13 @@ def __init__(
post.add_bef_update(self.name, _AlignPost(syn, out))

# reference
self.refs = dict(post=post) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['post'] = post
self.refs['syn'] = post.get_bef_update(self.name).syn
self.refs['out'] = post.get_bef_update(self.name).out
# unify the access
self.refs['comm'] = comm

def update(self, x):
current = self.comm(x)
Expand Down Expand Up @@ -538,8 +545,15 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['out'] = out
# unify the access
self.refs['delay'] = pre.get_aft_update(delay_identifier)
self.refs['comm'] = comm
self.refs['syn'] = syn

def update(self):
x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
Expand Down Expand Up @@ -655,8 +669,15 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post, out=out, delay=delay_cls) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['out'] = out
self.refs['delay'] = delay_cls
self.refs['syn'] = pre.get_aft_update(self._syn_id).syn
# unify the access
self.refs['comm'] = comm

def update(self, x=None):
if x is None:
Expand Down Expand Up @@ -778,9 +799,14 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post) # invisible to `self.nodes()`
self.refs = dict()
# invisible to `self.nodes()`
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['syn'] = delay_cls.get_bef_update(self._syn_id).syn
self.refs['out'] = out
# unify the access
self.refs['comm'] = comm

def update(self):
x = _get_return(self.refs['syn'].return_info())
Expand Down Expand Up @@ -890,9 +916,15 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['out'] = out
self.refs['delay'] = delay_cls
self.refs['syn'] = syn
# unify the access
self.refs['comm'] = comm

def update(self, x=None):
if x is None:
Expand Down Expand Up @@ -1006,8 +1038,15 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['out'] = out
self.refs['delay'] = pre.get_aft_update(delay_identifier)
# unify the access
self.refs['syn'] = syn
self.refs['comm'] = comm

def update(self):
spk = self.refs['delay'].at(self.name)
Expand Down
4 changes: 4 additions & 0 deletions brainpy/_src/math/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class NonBatchingMode(Mode):
"""
pass

@property
def batch_size(self):
return tuple()


class BatchingMode(Mode):
"""Batching mode.
Expand Down
95 changes: 86 additions & 9 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
from brainpy._src.math.object_transform.naming import (get_unique_name,
check_name_uniqueness)
from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
VarList, VarDict)
VarList, VarDict, var_stack_list)
from brainpy._src.math.modes import Mode
from brainpy._src.math.sharding import BATCH_AXIS


variable_ = None
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])

__all__ = [
Expand Down Expand Up @@ -102,17 +106,91 @@ def __init__(self, name=None):
def setattr(self, key: str, value: Any) -> None:
super().__setattr__(key, value)

def tracing_variable(
self,
name: str,
init: Union[Callable, Array, jax.Array],
shape: Union[int, Sequence[int]],
batch_or_mode: Union[int, bool, Mode] = None,
batch_axis: int = 0,
axis_names: Optional[Sequence[str]] = None,
batch_axis_name: Optional[str] = BATCH_AXIS,
) -> Variable:
"""Initialize the variable which can be traced during computations and transformations.
Although this function is designed to initialize tracing variables during computation or compilation,
it can also be used for the initialization of variables before computation and compilation.
- If the variable has not been instantiated, a :py:class:`~.Variable` will be instantiated.
- If the variable has been created, the further call of this function will return the created variable.
Here is the usage example::
class Example(bm.BrainPyObject):
def fun(self):
# The first time of calling `.fun()`, this line will create a Variable instance.
# If users repeatedly call `.fun()` function, this line will not initialize variables again.
# Instead, it will return the variable has been created.
self.tracing_variable('a', bm.zeros, (10,))
# The created variable can be accessed with self.xxx
self.a.value = bm.ones(10)
# Calling this function again will not reinitialize the
# variable again, Instead, it will return the variable
# that has been created.
a = self.tracing_variable('a', bm.zeros, (10,))
Args:
name: str. The variable name.
init: callable, Array. The data to be initialized as a ``Variable``.
batch_or_mode: int, bool, Mode. This is used to specify the batch size of this variable.
If it is a boolean or an instance of ``Mode``, the batch size will be 1.
If it is None, the variable has no batch axis.
shape: int, sequence of int. The shape of the variable.
batch_axis: int. The batch axis, if batch size is given.
axis_names: sequence of str. The name for each axis. These names should match the given ``axes``.
batch_axis_name: str. The name for the batch axis. The name will be used
if ``batch_or_mode`` is given. Default is ``brainpy.math.sharding.BATCH_AXIS``.
Returns:
The instance of :py:class:`~.Variable`.
"""
# the variable has been created
if hasattr(self, name):
var = getattr(self, name)
if isinstance(var, Variable):
return var
# if var.shape != value.shape:
# raise ValueError(
# f'"{name}" has been used in this class with the shape of {var.shape} (!= {value.shape}). '
# f'Please assign another name for the initialization of variables '
# f'tracing during computation and compilation.'
# )
# if var.dtype != value.dtype:
# raise ValueError(
# f'"{name}" has been used in this class with the dtype of {var.dtype} (!= {value.dtype}). '
# f'Please assign another name for the initialization of variables '
# f'tracing during computation and compilation.'
# )

global variable_
if variable_ is None:
from brainpy.initialize import variable_
with jax.ensure_compile_time_eval():
value = variable_(init, shape, batch_or_mode, batch_axis, axis_names, batch_axis_name)
value._ready_to_trace = True
self.setattr(name, value)
return value

def __setattr__(self, key: str, value: Any) -> None:
"""Overwrite `__setattr__` method for changing :py:class:`~.Variable` values.
.. versionadded:: 2.3.1
Parameters
----------
key: str
The attribute.
value: Any
The value.
Args:
key: str. The attribute.
value: Any. The value.
"""
if key in self.__dict__:
val = self.__dict__[key]
Expand Down Expand Up @@ -252,7 +330,7 @@ def vars(self,
continue
v = getattr(node, k)
if isinstance(v, Variable) and not isinstance(v, exclude_types):
gather[f'{node_path}.{k}' if node_path else k] = v
gather[f'{node_path}.{k}' if node_path else k] = v
elif isinstance(v, VarList):
for i, vv in enumerate(v):
if not isinstance(vv, exclude_types):
Expand Down Expand Up @@ -702,4 +780,3 @@ def __setitem__(self, key, value) -> 'VarDict':


node_dict = NodeDict

Loading

0 comments on commit d83caa4

Please sign in to comment.