Skip to content

Commit

Permalink
[math] update .tracing_variable() function
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 9, 2023
1 parent 2e258f9 commit 8bc7e23
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,25 @@ def fun(self):
Returns:
The instance of :py:class:`~.Variable`.
"""
if not hasattr(self, name):
if not isinstance(value, Variable):
value = Variable(value)
value._ready_to_trace = True
if len(var_stack_list) > 0 and isinstance(value._value, jax.core.Tracer):
with jax.ensure_compile_time_eval():
value._value = jax.numpy.zeros_like(value._value)
self.setattr(name, value)
else:
# the variable has been created
if hasattr(self, name):
var = getattr(self, name)
if not isinstance(var, Variable):
raise KeyError(f'"{name}" has been used in this class. Please assign '
f'another name for the initialization of variables '
f'tracing during computation and compilation.')
var.value = value
value = var
if isinstance(var, Variable):
var.value = value
return var

# create the variable
if not isinstance(value, Variable):
value = Variable(value)
value._ready_to_trace = True
if len(var_stack_list) > 0 and isinstance(value._value, jax.core.Tracer):
with jax.ensure_compile_time_eval():
value._value = jax.numpy.zeros_like(value._value)
self.setattr(name, value)
# if not isinstance(var, Variable):
# raise KeyError(f'"{name}" has been used in this class. Please assign '
# f'another name for the initialization of variables '
# f'tracing during computation and compilation.')
return value

def __setattr__(self, key: str, value: Any) -> None:
Expand Down

0 comments on commit 8bc7e23

Please sign in to comment.