Skip to content

Commit

Permalink
Merge pull request #121 from PKU-NIP-Lab/whole-brain-modeling
Browse files Browse the repository at this point in the history
Whole brain modeling
  • Loading branch information
chaoming0625 authored Mar 23, 2022
2 parents 3086c69 + 0de0693 commit 5eb9f03
Show file tree
Hide file tree
Showing 69 changed files with 3,548 additions and 1,274 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ BrainModels/
book/
docs/examples
docs/apis/jaxsetting.rst
docs/quickstart/data
examples/recurrent_neural_network/neurogym
develop/iconip_paper
develop/benchmark/COBA/results
Expand Down
27 changes: 26 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ runner.run(100.)
Numerical methods for delay differential equations (DDEs).

```python
xdelay = bm.TimeDelay(1, delay_len=1., before_t0=1., dt=0.01)
xdelay = bm.TimeDelay(bm.zeros(1), delay_len=1., before_t0=1., dt=0.01)


@bp.ddeint(method='rk4', state_delays={'x': xdelay})
Expand Down Expand Up @@ -191,6 +191,31 @@ runner = bp.dyn.DSRunner(net)
runner(100.)
```

Simulating a whole brain network by using rate models.

```python
import numpy as np

class WholeBrainNet(bp.dyn.Network):
def __init__(self, signal_speed=20.):
super(WholeBrainNet, self).__init__()

self.fhn = bp.dyn.RateFHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
self.syn = bp.dyn.DiffusiveDelayCoupling(self.fhn, self.fhn,
'x->input',
conn_mat=conn_mat,
delay_mat=delay_mat)

def update(self, _t, _dt):
self.syn.update(_t, _dt)
self.fhn.update(_t, _dt)


net = WholeBrainNet()
runner = bp.dyn.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72])
runner.run(6e3)
```



### 4. Dynamics training level
Expand Down
4 changes: 2 additions & 2 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.1.1"
__version__ = "2.1.2"


try:
Expand All @@ -15,7 +15,7 @@


# fundamental modules
from . import errors, tools
from . import errors, tools, check


# "base" module
Expand Down
60 changes: 39 additions & 21 deletions brainpy/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# -*- coding: utf-8 -*-

import inspect
import time
import warnings
from functools import partial

from jax import vmap
import jax.numpy
import numpy as np
from jax.scipy.optimize import minimize

import brainpy.math as bm
from brainpy import optimizers as optim
from brainpy.analysis import utils
from brainpy.errors import AnalyzerError
from brainpy import optimizers as optim

__all__ = [
'SlowPointFinder',
Expand Down Expand Up @@ -56,15 +57,15 @@ def __init__(self, f_cell, f_type='continuous', f_loss_batch=None, verbose=True)
if f_loss_batch is None:
if f_type == 'discrete':
self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h)) ** 2))
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
self.f_loss_batch = bm.jit(lambda h: bm.mean((h - vmap(f_cell)(h)) ** 2, axis=1))
if f_type == 'continuous':
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
self.f_loss_batch = bm.jit(lambda h: bm.mean((bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
self.f_loss_batch = bm.jit(lambda h: bm.mean((vmap(f_cell)(h)) ** 2, axis=1))

else:
self.f_loss_batch = f_loss_batch
self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))
self.f_jacob_batch = bm.jit(vmap(bm.jacobian(f_cell)))

# essential variables
self._losses = None
Expand All @@ -87,8 +88,13 @@ def selected_ids(self):
"""The selected ids of candidate points."""
return self._selected_ids

def find_fps_with_gd_method(self, candidates, tolerance=1e-5, num_batch=100,
num_opt=10000, opt_setting=None):
def find_fps_with_gd_method(self,
candidates,
tolerance=1e-5,
num_batch=100,
num_opt=10000,
optimizer=None,
opt_setting=None):
"""Optimize fixed points with gradient descent methods.
Parameters
Expand All @@ -104,44 +110,56 @@ def find_fps_with_gd_method(self, candidates, tolerance=1e-5, num_batch=100,
Print training information during optimization every so often.
opt_setting: optional, dict
The optimization settings.
.. deprecated:: 2.1.2
Use "optimizer" to set optimization method instead.
optimizer: optim.Optimizer
The optimizer instance.
.. versionadded:: 2.1.2
"""

# optimization settings
if opt_setting is None:
opt_method = optim.Adam
opt_lr = optim.ExponentialDecay(0.2, 1, 0.9999)
opt_setting = {'beta1': 0.9,
'beta2': 0.999,
'eps': 1e-8,
'name': None}
if optimizer is None:
optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
assert isinstance(optimizer, optim.Optimizer), (f'Must be an instance of '
f'{optim.Optimizer.__name__}, '
f'while we got {type(optimizer)}')
else:
warnings.warn('Please use "optimizer" to set optimization method. '
'"opt_setting" is deprecated since version 2.1.2. ',
DeprecationWarning)

assert isinstance(opt_setting, dict)
assert 'method' in opt_setting
assert 'lr' in opt_setting
opt_method = opt_setting.pop('method')
if isinstance(opt_method, str):
assert opt_method in optim.__dict__
opt_method = getattr(optim, opt_method)
assert isinstance(opt_method, type)
if optim.Optimizer not in inspect.getmro(opt_method):
raise ValueError
assert issubclass(opt_method, optim.Optimizer)
opt_lr = opt_setting.pop('lr')
assert isinstance(opt_lr, (int, float, optim.Scheduler))
opt_setting = opt_setting
optimizer = opt_method(lr=opt_lr, **opt_setting)

if self.verbose:
print(f"Optimizing with {opt_method.__name__} to find fixed points:")
print(f"Optimizing with {optimizer.__name__} to find fixed points:")

# set up optimization
fixed_points = bm.Variable(bm.asarray(candidates))
grad_f = bm.grad(lambda: self.f_loss_batch(fixed_points.value).mean(),
grad_vars={'a': fixed_points}, return_value=True)
opt = opt_method(train_vars={'a': fixed_points}, lr=opt_lr, **opt_setting)
dyn_vars = opt.vars() + {'_a': fixed_points}
optimizer.register_vars({'a': fixed_points})
dyn_vars = optimizer.vars() + {'_a': fixed_points}

def train(idx):
gradients, loss = grad_f()
opt.update(gradients)
optimizer.update(gradients)
return loss

@partial(bm.jit, dyn_vars=dyn_vars, static_argnames=('start_i', 'num_batch'))
Expand Down Expand Up @@ -191,7 +209,7 @@ def find_fps_with_opt_solver(self, candidates, opt_method=None):
opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
if self.verbose:
print(f"Optimizing to find fixed points:")
f_opt = bm.jit(bm.vmap(lambda x0: opt_method(self.f_loss, x0)))
f_opt = bm.jit(vmap(lambda x0: opt_method(self.f_loss, x0)))
res = f_opt(bm.as_device_array(candidates))
valid_ids = jax.numpy.where(res.success)[0]
self._fixed_points = np.asarray(res.x[valid_ids])
Expand Down
39 changes: 20 additions & 19 deletions brainpy/analysis/lowdim/lowdim_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial

import numpy as np
from jax import vmap
from jax import numpy as jnp
from jax.scipy.optimize import minimize

Expand Down Expand Up @@ -262,7 +263,7 @@ def F_fx(self):
@property
def F_vmap_fx(self):
if C.F_vmap_fx not in self.analyzed_results:
self.analyzed_results[C.F_vmap_fx] = bm.jit(bm.vmap(self.F_fx), device=self.jit_device)
self.analyzed_results[C.F_vmap_fx] = bm.jit(vmap(self.F_fx), device=self.jit_device)
return self.analyzed_results[C.F_vmap_fx]

@property
Expand All @@ -289,7 +290,7 @@ def F_vmap_fp_aux(self):
# ---
# "X": a two-dimensional matrix: (num_batch, num_var)
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(bm.vmap(self.F_fixed_point_aux))
self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(vmap(self.F_fixed_point_aux))
return self.analyzed_results[C.F_vmap_fp_aux]

@property
Expand All @@ -308,7 +309,7 @@ def F_vmap_fp_opt(self):
# ---
# "X": a two-dimensional matrix: (num_batch, num_var)
# "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(bm.vmap(self.F_fixed_point_opt))
self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(vmap(self.F_fixed_point_opt))
return self.analyzed_results[C.F_vmap_fp_opt]

def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None):
Expand Down Expand Up @@ -501,7 +502,7 @@ def F_y_by_x_in_fy(self):
@property
def F_vmap_fy(self):
if C.F_vmap_fy not in self.analyzed_results:
self.analyzed_results[C.F_vmap_fy] = bm.jit(bm.vmap(self.F_fy), device=self.jit_device)
self.analyzed_results[C.F_vmap_fy] = bm.jit(vmap(self.F_fy), device=self.jit_device)
return self.analyzed_results[C.F_vmap_fy]

@property
Expand Down Expand Up @@ -663,7 +664,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

if self.F_x_by_y_in_fx is not None:
utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...")
vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fx), device=self.jit_device)
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((ys,) + pars))
Expand All @@ -679,7 +680,7 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

elif self.F_y_by_x_in_fx is not None:
utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...")
vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fx), device=self.jit_device)
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((xs,) + pars))
Expand All @@ -697,9 +698,9 @@ def _get_fx_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux
utils.output("I am evaluating fx-nullcline by optimization ...")
# auxiliary functions
f2 = lambda y, x, *pars: self.F_fx(x, y, *pars)
vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)

# num segments
for _j, Ps in enumerate(par_seg):
Expand Down Expand Up @@ -756,7 +757,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

if self.F_x_by_y_in_fy is not None:
utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...")
vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fy), device=self.jit_device)
vmap_f = bm.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((ys,) + pars))
Expand All @@ -772,7 +773,7 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

elif self.F_y_by_x_in_fy is not None:
utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...")
vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fy), device=self.jit_device)
vmap_f = bm.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
for j, pars in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
mesh_values = jnp.meshgrid(*((xs,) + pars))
Expand All @@ -791,9 +792,9 @@ def _get_fy_nullcline_points(self, coords=None, tol=1e-7, num_segments=1, fp_aux

# auxiliary functions
f2 = lambda y, x, *pars: self.F_fy(x, y, *pars)
vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)

for j, Ps in enumerate(par_seg):
if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
Expand Down Expand Up @@ -841,7 +842,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100):
xs = self.resolutions[self.x_var].value
ys = self.resolutions[self.y_var].value
P = tuple(self.resolutions[p].value for p in self.target_par_names)
f_select = bm.jit(bm.vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
f_select = bm.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))

# num seguments
if isinstance(num_segments, int):
Expand Down Expand Up @@ -921,10 +922,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,

if self.convert_type() == C.x_by_y:
num_seg = len(self.resolutions[self.y_var])
f_vmap = bm.jit(bm.vmap(self.F_y_convert[1]))
f_vmap = bm.jit(vmap(self.F_y_convert[1]))
else:
num_seg = len(self.resolutions[self.x_var])
f_vmap = bm.jit(bm.vmap(self.F_x_convert[1]))
f_vmap = bm.jit(vmap(self.F_x_convert[1]))
# get the signs
signs = jnp.sign(f_vmap(candidates, *args))
signs = signs.reshape((num_seg, -1))
Expand Down Expand Up @@ -954,10 +955,10 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
# get another value
if self.convert_type() == C.x_by_y:
y_values = fps
x_values = bm.jit(bm.vmap(self.F_y_convert[0]))(y_values, *args)
x_values = bm.jit(vmap(self.F_y_convert[0]))(y_values, *args)
else:
x_values = fps
y_values = bm.jit(bm.vmap(self.F_x_convert[0]))(x_values, *args)
y_values = bm.jit(vmap(self.F_x_convert[0]))(x_values, *args)
fps = jnp.stack([x_values, y_values]).T
return fps, selected_ids, args

Expand Down
Loading

0 comments on commit 5eb9f03

Please sign in to comment.