Skip to content

Commit

Permalink
[math] update .tracing_variable() functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 9, 2023
1 parent 8bc7e23 commit beb6cc5
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 31 deletions.
4 changes: 4 additions & 0 deletions brainpy/_src/math/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class NonBatchingMode(Mode):
"""
pass

@property
def batch_size(self):
    """Batch size of the non-batching mode: always an empty tuple."""
    return ()


class BatchingMode(Mode):
"""Batching mode.
Expand Down
87 changes: 57 additions & 30 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
check_name_uniqueness)
from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
VarList, VarDict, var_stack_list)
from brainpy._src.math.modes import Mode
from brainpy._src.math.sharding import BATCH_AXIS


variable_ = None
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])

__all__ = [
Expand Down Expand Up @@ -102,35 +106,52 @@ def __init__(self, name=None):
def setattr(self, key: str, value: Any) -> None:
    """Set the attribute ``key`` to ``value`` on this object.

    Delegates directly to ``super().__setattr__``, bypassing any attribute
    bookkeeping performed by this class's overridden ``__setattr__``.

    Args:
        key: str. The attribute name.
        value: Any. The attribute value.
    """
    super().__setattr__(key, value)

def tracing_variable(
        self,
        name: str,
        init: Union[Callable, Array, jax.Array],
        shape: Union[int, Sequence[int]],
        batch_or_mode: Optional[Union[int, bool, Mode]] = None,
        batch_axis: int = 0,
        axis_names: Optional[Sequence[str]] = None,
        batch_axis_name: Optional[str] = BATCH_AXIS,
) -> Variable:
    """Initialize a variable that can be traced during computations and transformations.

    Although this function is designed to initialize tracing variables during
    computation or compilation, it can also be used for the initialization of
    variables before computation and compilation.

    - If the variable named ``name`` has not been instantiated, a
      :py:class:`~.Variable` will be created and attached to ``self``.
    - If the variable has already been created, further calls of this function
      will return the created variable without re-initializing it.

    Here is the usage example::

        class Example(bm.BrainPyObject):
            def fun(self):
                # The first time of calling `.fun()`, this line will create a
                # Variable instance. Repeated calls will not initialize the
                # variable again; they return the variable already created.
                self.tracing_variable('a', bm.zeros, (10,))

                # The created variable can be accessed with self.xxx
                self.a.value = bm.ones(10)

                # Calling this function again will not reinitialize the
                # variable. Instead, it will return the variable that has
                # been created.
                a = self.tracing_variable('a', bm.zeros, (10,))

    Args:
        name: str. The variable name.
        init: callable, Array. The data to be initialized as a ``Variable``.
        shape: int, sequence of int. The shape of the variable.
        batch_or_mode: int, bool, Mode, optional. Specifies the batch size of
            this variable. If it is a boolean or an instance of ``Mode``, the
            batch size will be 1. If it is ``None``, the variable has no
            batch axis.
        batch_axis: int. The batch axis, if a batch size is given.
        axis_names: sequence of str. The name for each axis. These names
            should match the given ``axes``.
        batch_axis_name: str. The name for the batch axis. The name will be
            used if ``batch_or_mode`` is given. Default is
            ``brainpy.math.sharding.BATCH_AXIS``.

    Returns:
        The instance of :py:class:`~.Variable`.
    """
    # Return the variable that has already been created under this name.
    if hasattr(self, name):
        var = getattr(self, name)
        if isinstance(var, Variable):
            return var
        # NOTE(review): a non-Variable attribute with this name is silently
        # replaced below — confirm this is the intended behavior.

    # Lazily resolve `variable_` to avoid a circular import with
    # `brainpy.initialize` at module load time; the `global` declaration
    # caches the resolved function for subsequent calls.
    global variable_
    if variable_ is None:
        from brainpy.initialize import variable_

    # Materialize the initial value with concrete (compile-time) data even
    # when this method is invoked inside a traced computation.
    with jax.ensure_compile_time_eval():
        value = variable_(init, shape, batch_or_mode, batch_axis, axis_names, batch_axis_name)
    value._ready_to_trace = True
    self.setattr(name, value)
    return value

def __setattr__(self, key: str, value: Any) -> None:
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self):
self.f3 = bm.jit(self.fff)

def f(self):
    # Create (on the first call) or fetch the traced variable 'b' of shape
    # (1,) initialized with ones, then accumulate into `self.a`.
    b = self.tracing_variable('b', bm.ones, (1,))
    self.a += (b * 2)
    return self.a.value

Expand Down

0 comments on commit beb6cc5

Please sign in to comment.