Fix JIT bugs and replace deprecated functions
Routhleck committed May 14, 2024
1 parent 5253b59 commit 92292af
Showing 12 changed files with 42 additions and 25 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
```diff
@@ -329,7 +329,7 @@ def find_fps_with_gd_method(
     """
     # optimization settings
     if optimizer is None:
-      optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
+      optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999),
                              beta1=0.9, beta2=0.999, eps=1e-8)
     else:
       if not isinstance(optimizer, optim.Optimizer):
```
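The only change here is the scheduler's name: `ExponentialDecay` became `ExponentialDecayLR` in `brainpy.optim`. A minimal sketch of the new spelling (the keyword meanings are inferred from the positional call above, so treat them as assumptions):

```python
# Hedged sketch: optim.ExponentialDecayLR is assumed to take the same
# positional arguments (initial lr, decay steps, decay rate) as the
# old optim.ExponentialDecay it replaces.
from brainpy import optim

lr = optim.ExponentialDecayLR(0.2, 1, 0.9999)
opt = optim.Adam(lr=lr, beta1=0.9, beta2=0.999, eps=1e-8)
```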
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/rates/tests/test_nvar.py
```diff
@@ -11,7 +11,7 @@ class Test_NVAR(parameterized.TestCase):
   def test_NVAR(self,mode):
     bm.random.seed()
     input=bm.random.randn(1,5)
-    layer=bp.dnn.NVAR(num_in=5,
+    layer=bp.dyn.NVAR(num_in=5,
                       delay=10,
                       mode=mode)
     if mode in [bm.NonBatchingMode()]:
```
4 changes: 2 additions & 2 deletions brainpy/_src/initialize/tests/test_decay_inits.py
```diff
@@ -14,8 +14,8 @@
 # visualization
 def mat_visualize(matrix, cmap=None):
   if cmap is None:
-    cmap = plt.cm.get_cmap('coolwarm')
-    plt.cm.get_cmap('coolwarm')
+    cmap = plt.colormaps.get_cmap('coolwarm')
+    plt.colormaps.get_cmap('coolwarm')
   im = plt.matshow(matrix, cmap=cmap)
   plt.colorbar(mappable=im, shrink=0.8, aspect=15)
   plt.show()
```
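`plt.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was deprecated in Matplotlib 3.7 in favour of the colormap registry. A minimal standalone sketch of the registry-based lookup, using `matplotlib.colormaps` (available since Matplotlib 3.5):

```python
# Registry-based colormap lookup; matplotlib.colormaps replaces the
# deprecated plt.cm.get_cmap helper.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

cmap = matplotlib.colormaps['coolwarm']  # dict-style lookup by name
im = plt.matshow(np.random.rand(8, 8), cmap=cmap)
plt.colorbar(mappable=im, shrink=0.8, aspect=15)
plt.show()
```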
```diff
@@ -94,8 +94,8 @@ def dV(self, V, t, h, n, Iext):
 
     return dVdt
 
-  def update(self, tdi):
-    t, dt = tdi.t, tdi.dt
+  def update(self):
+    t, dt = bp.share['t'], bp.share['dt']
     V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt)
     self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
     self.V.value = V
```
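This hunk reflects BrainPy's move away from the `tdi` argument: `update()` no longer receives a time dict and instead reads the shared step variables from `bp.share`. A minimal sketch of the convention (the model body is elided; `bp.share['t']` and `bp.share['dt']` are taken directly from the diff):

```python
# Sketch of the bp.share-based update() convention shown above.
import brainpy as bp

class MyModel(bp.DynamicalSystem):
  def update(self):
    t = bp.share['t']    # current simulation time
    dt = bp.share['dt']  # integration step size
    # ... advance the model state by one step using t and dt ...
```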
3 changes: 1 addition & 2 deletions brainpy/_src/math/object_transform/jit.py
```diff
@@ -491,9 +491,8 @@ def call_fun(self, *args, **kwargs):
 
   return call_fun
 
-
 def _make_transform(fun, stack):
-  @wraps(fun)
+  # @wraps(fun)
   def _transform_function(variable_data: Dict, *args, **kwargs):
     for key, v in stack.items():
       v._value = variable_data[key]
```
14 changes: 7 additions & 7 deletions brainpy/_src/math/object_transform/tests/test_base.py
```diff
@@ -237,14 +237,14 @@ def test1(self):
     hh = bp.dyn.HH(1)
     hh.reset()
 
-    tree = jax.tree_structure(hh)
-    leaves = jax.tree_leaves(hh)
+    tree = jax.tree.structure(hh)
+    leaves = jax.tree.leaves(hh)
     # tree = jax.tree.structure(hh)
     # leaves = jax.tree.leaves(hh)
 
     print(tree)
     print(leaves)
-    print(jax.tree_unflatten(tree, leaves))
+    print(jax.tree.unflatten(tree, leaves))
     # print(jax.tree.unflatten(tree, leaves))
     print()
 
@@ -284,16 +284,16 @@ def not_close(x, y):
   def all_close(x, y):
     assert bm.allclose(x, y)
 
-  jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
+  jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
   # jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
 
-  random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
-  jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
+  random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
+  jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
   # random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
   # jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
 
   obj.load_state_dict(random_state)
-  jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
+  jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
   # jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
```

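These test edits track JAX's deprecation of the flat tree utilities (`jax.tree_map`, `jax.tree_structure`, `jax.tree_leaves`, `jax.tree_unflatten`, deprecated around JAX 0.4.25) in favour of the `jax.tree` namespace. A minimal sketch of the one-to-one correspondence:

```python
# The jax.tree namespace mirrors the deprecated flat jax.tree_* helpers.
import jax

tree = {'a': [1, 2], 'b': 3}
leaves = jax.tree.leaves(tree)                 # [1, 2, 3]
treedef = jax.tree.structure(tree)             # PyTreeDef for the dict
doubled = jax.tree.map(lambda x: 2 * x, tree)  # {'a': [2, 4], 'b': 6}
rebuilt = jax.tree.unflatten(treedef, leaves)  # == tree
```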
```diff
@@ -65,7 +65,7 @@ def test_nodes():
   A.pre = B
   B.pre = A
 
-  net = bp.dyn.Network(A, B)
+  net = bp.Network(A, B)
   abs_nodes = net.nodes(method='absolute')
   rel_nodes = net.nodes(method='relative')
   print()
```
4 changes: 2 additions & 2 deletions brainpy/_src/math/object_transform/tests/test_collector.py
```diff
@@ -7,7 +7,7 @@
 import brainpy as bp
 
 
-class GABAa_without_Variable(bp.TwoEndConn):
+class GABAa_without_Variable(bp.synapses.TwoEndConn):
   def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
                alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
     super(GABAa_without_Variable, self).__init__(pre=pre, post=post, **kwargs)
@@ -192,7 +192,7 @@ def test_neu_nodes_1():
   assert len(neu.nodes(method='relative', include_self=False)) == 1
 
 
-class GABAa_with_Variable(bp.TwoEndConn):
+class GABAa_with_Variable(bp.synapses.TwoEndConn):
   def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
                alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
     super(GABAa_with_Variable, self).__init__(pre=pre, post=post, **kwargs)
```
2 changes: 1 addition & 1 deletion brainpy/_src/math/object_transform/tests/test_controls.py
```diff
@@ -234,7 +234,7 @@ def f1():
       branches=[f1,
                 lambda: 2, lambda: 3,
                 lambda: 4, lambda: 5],
-      dyn_vars=var_a,
+      # dyn_vars=var_a,
       show_code=True)
 
     self.assertTrue(f(11) == 1)
```
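Commenting out `dyn_vars` matches the rest of the commit: in recent BrainPy releases the object-oriented transforms trace `Variable` instances automatically, so the explicit `dyn_vars` argument is deprecated and can simply be omitted.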
4 changes: 2 additions & 2 deletions brainpy/_src/math/object_transform/tests/test_jit.py
```diff
@@ -52,7 +52,7 @@ def __call__(self, *args, **kwargs):
   def test_jit_with_static(self):
     a = bm.Variable(bm.ones(2))
 
-    @bm.jit(static_argnums=0)
+    @bm.jit(static_argnums=1)
     def f(b, c):
       a.value *= b
       a.value /= c
@@ -104,7 +104,7 @@ def __init__(self):
         self.a = bm.zeros(2)
         self.b = bm.Variable(bm.ones(2))
 
-        self.call1 = bm.jit(self.call, static_argnums=1)
+        self.call1 = bm.jit(self.call, static_argnums=0)
         self.call2 = bm.jit(self.call, static_argnames=['fit'])
 
       def call(self, fit=True):
```
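Both fixes above are index corrections: `static_argnums` counts positional arguments as the jitted callable actually receives them. For the free function `f(b, c)`, `c` sits at index 1; for the bound method `self.call`, `self` is already bound, so `fit` is index 0. A minimal sketch of the rule with plain `jax.jit` (assuming `bm.jit` follows the same convention):

```python
# Index rule for static_argnums, shown with plain jax.jit.
import jax

def f(b, c):
  return b * c

f_jit = jax.jit(f, static_argnums=1)  # c (index 1) is static

class Model:
  def call(self, fit=True):
    # `fit` drives a Python-level branch, so it must be static.
    return 1.0 if fit else 2.0

m = Model()
call_jit = jax.jit(m.call, static_argnums=0)  # bound method: fit is index 0
print(f_jit(2.0, 3.0), call_jit(True))
```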
24 changes: 21 additions & 3 deletions brainpy/_src/math/random.py
```diff
@@ -4,12 +4,14 @@
 from collections import namedtuple
 from functools import partial
 from operator import index
-from typing import Optional, Union, Sequence
+from typing import Optional, Union, Sequence, Any
 
 import jax
 import numpy as np
 from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
 from jax._src.array import ArrayImpl
+from jax._src.core import _canonicalize_dimension, _invalid_shape_error
+from jax._src.typing import Shape
 from jax.tree_util import register_pytree_node_class
 
 from brainpy.check import jit_error_checking, jit_error_checking_no_args
@@ -33,7 +35,7 @@
   'hypergeometric', 'logseries', 'multinomial', 'multivariate_normal',
   'negative_binomial', 'noncentral_chisquare', 'noncentral_f', 'power',
   'rayleigh', 'triangular', 'vonmises', 'wald', 'weibull', 'weibull_min',
-  'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical',
+  'zipf', 'maxwell', 't', 'orthogonal', 'loggamma', 'categorical', 'canonicalize_shape',
 
   # pytorch compatibility
   'rand_like', 'randint_like', 'randn_like',
@@ -437,6 +439,22 @@ def _check_py_seq(seq):
   return jnp.asarray(seq) if isinstance(seq, (tuple, list)) else seq
 
 
+def canonicalize_shape(shape: Shape, context: str = "") -> tuple[Any, ...]:
+  """Canonicalizes and checks for errors in a user-provided shape value.
+
+  Args:
+    shape: a Python value that represents a shape.
+
+  Returns:
+    A tuple of canonical dimension values.
+  """
+  try:
+    return tuple(map(_canonicalize_dimension, shape))
+  except TypeError:
+    pass
+  raise _invalid_shape_error(shape, context)
+
+
 @register_pytree_node_class
 class RandomState(Variable):
   """RandomState that track the random generator state. """
@@ -1097,7 +1115,7 @@ def weibull_min(self, a, scale=None, size: Optional[Union[int, Sequence[int]]] =
 
   def maxwell(self, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None):
     key = self.split_key() if key is None else _formalize_key(key)
-    shape = core.canonicalize_shape(_size2shape(size)) + (3,)
+    shape = canonicalize_shape(_size2shape(size)) + (3,)
     norm_rvs = jr.normal(key=key, shape=shape)
     r = jnp.linalg.norm(norm_rvs, axis=-1)
     return _return(r)
```
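The new `canonicalize_shape` vendors what `jax.core.canonicalize_shape` used to provide; it appears to have been removed from JAX's public `core` namespace, so BrainPy rebuilds the same check from JAX's private helpers. A small usage sketch (assuming the definitions added in the hunk above):

```python
# Usage sketch of the vendored helper: shape-like iterables are
# normalized to tuples of canonical dimensions; non-shapes raise.
canonicalize_shape((2, 3))  # -> (2, 3)
canonicalize_shape([4])     # -> (4,)
canonicalize_shape(5)       # raises TypeError: not an iterable shape
```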
2 changes: 1 addition & 1 deletion brainpy/_src/optimizers/tests/test_ModifyLr.py
```diff
@@ -28,7 +28,7 @@ def train_data():
 class RNN(bp.DynamicalSystem):
   def __init__(self, num_in, num_hidden):
     super(RNN, self).__init__()
-    self.rnn = bp.dnn.RNNCell(num_in, num_hidden, train_state=True)
+    self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
     self.out = bp.dnn.Dense(num_hidden, 1)
 
   def update(self, x):
```
