Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[connect] Complete FixedTotalNum class and fix bugs #434

Merged
merged 6 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion brainpy/_add_deprecations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

from ._src import checking, train, integrators
from . import tools, math, integrators, dyn, dnn, neurons, synapses, layers
from . import tools, math, integrators, dyn, dnn, neurons, synapses, layers, connect
from .integrators import ode, fde, sde
from brainpy._src.integrators.base import Integrator
from brainpy._src.integrators.runner import IntegratorRunner
Expand Down Expand Up @@ -114,3 +114,11 @@
# layers.__getattr__ = deprecation_getattr2('brainpy.layers', layers.__deprecations)


connect.__deprecations = {
'one2one': ('brainpy.connect.one2one', 'brainpy.connect.One2One', connect.One2One),
'all2all': ('brainpy.connect.all2all', 'brainpy.connect.All2All', connect.All2All),
'grid_four': ('brainpy.connect.grid_four', 'brainpy.connect.GridFour', connect.GridFour),
'grid_eight': ('brainpy.connect.grid_eight', 'brainpy.connect.GridEight', connect.GridEight),
}
connect.__getattr__ = deprecation_getattr2('brainpy.connect', connect.__deprecations)

63 changes: 62 additions & 1 deletion brainpy/_src/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'mat2coo', 'mat2csc', 'mat2csr',
'csr2csc', 'csr2mat', 'csr2coo',
'coo2csr', 'coo2csc', 'coo2mat',
'coo2mat_num', 'mat2mat_num',

# visualize
'visualizeMat',
Expand Down Expand Up @@ -426,7 +427,8 @@ def require(self, *structures):
elif POST_IDS in structures and _has_coo_imp:
return bm.as_jax(self.build_coo()[1], dtype=IDX_DTYPE)
elif COO in structures and _has_coo_imp:
return bm.as_jax(self.build_coo(), dtype=IDX_DTYPE)
r = self.build_coo()
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)

elif len(structures) == 2:
if (PRE_IDS in structures and POST_IDS in structures and _has_coo_imp):
Expand Down Expand Up @@ -725,8 +727,67 @@ def coo2csc(coo, post_num, data=None):
data_new = data[sort_ids]
return pre_ids_new, indptr_new, data_new

def coo2mat_num(ij, num_pre, num_post, num, seed=0):
"""
convert (indices, indptr) to a dense connection number matrix.\n
Specific for FixedTotalNum.
"""
rng = bm.random.RandomState(seed)
mat = coo2mat(ij, num_pre, num_post)

# get nonzero indices and number
nonzero_idx = jnp.nonzero(mat)
nonzero_num = jnp.count_nonzero(mat)

# get multi connection number
multi_conn_num = num - nonzero_num

# alter the element type to int
mat = mat.astype(jnp.int32)

# 随机在mat中选取nonzero_idx的元素,将其值加1
index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False)
for i in index:
mat = mat.at[nonzero_idx[0][i], nonzero_idx[1][i]].set(mat[nonzero_idx[0][i], nonzero_idx[1][i]] + 1)

return mat

def mat2mat_num(mat, num, seed=0):
"""
Convert boolean matrix to a dense connection number matrix.\n
Specific for FixedTotalNum.
"""
rng = bm.random.RandomState(seed)

# get nonzero indices and number
nonzero_idx = jnp.nonzero(mat)
nonzero_num = jnp.count_nonzero(mat)

# get multi connection number
multi_conn_num = num - nonzero_num

# alter the element type to int
mat = mat.astype(jnp.int32)

# 随机在mat中选取nonzero_idx的元素,将其值加1
index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False)
for i in index:
mat = mat.at[nonzero_idx[0][i], nonzero_idx[1][i]].set(mat[nonzero_idx[0][i], nonzero_idx[1][i]] + 1)

return mat


def visualizeMat(mat, description='Untitled'):
"""
Visualize the matrix. (Need seaborn and matplotlib)

parameters
----------
mat : jnp.ndarray
The matrix to be visualized.
description : str
The title of the figure.
"""
try:
import seaborn as sns
import matplotlib.pyplot as plt
Expand Down
24 changes: 18 additions & 6 deletions brainpy/_src/connect/random_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,17 @@ class FixedTotalNum(TwoEndConnector):
----------
num : float,int
The conn total number.
allow_multi_conn : bool, optional
Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons.
seed: int, optional
The random number seed.
"""

def __init__(self, num, seed=None, **kwargs):
super(FixedTotalNum, self).__init__(**kwargs)
def __init__(self,
num,
allow_multi_conn=False,
seed=None, **kwargs):
super().__init__(**kwargs)
if isinstance(num, int):
assert num >= 0, '"num" must be a non-negative integer.'
elif isinstance(num, float):
Expand All @@ -157,14 +162,21 @@ def __init__(self, num, seed=None, **kwargs):
raise ConnectorError(f'Unknown type: {type(num)}')
self.num = num
self.seed = format_seed(seed)
self.allow_multi_conn = allow_multi_conn
self.rng = bm.random.RandomState(self.seed)

def build_coo(self):
if self.num > self.pre_num * self.post_num:
mat_element_num = self.pre_num * self.post_num
if self.num > mat_element_num:
raise ConnectorError(f'"num" must be smaller than "all2all num", '
f'but got {self.num} > {self.pre_num * self.post_num}')
selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,))
selected_post_ids = self.rng.randint(0, self.post_num, (self.num,))
f'but got {self.num} > {mat_element_num}')
if self.allow_multi_conn:
selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,))
selected_post_ids = self.rng.randint(0, self.post_num, (self.num,))
else:
index = self.rng.choice(mat_element_num, size=(self.num,), replace=False)
selected_pre_ids = index // self.post_num
selected_post_ids = index % self.post_num
return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE)

def __repr__(self):
Expand Down
10 changes: 10 additions & 0 deletions brainpy/_src/connect/tests/test_random_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ def test_random_fix_post3():
conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4)
conn1.require(bp.connect.CONN_MAT)

def test_random_fix_total1():
conn1 = bp.connect.FixedTotalNum(num=8, allow_multi_conn=False, seed=1234)(pre_size=3, post_size=4)
coo1 = conn1.require(bp.connect.COO)
conn_mat = bp.connect.coo2mat_num(ij=coo1, num_pre=3, num_post=4, num=conn1.num, seed=1234)

def test_random_fix_total2():
conn1 = bp.connect.FixedTotalNum(num=8, allow_multi_conn=True, seed=1234)(pre_size=3, post_size=4)
mat1 = conn1.require(bp.connect.CONN_MAT)
conn_mat = bp.connect.mat2mat_num(mat=mat1, num=conn1.num, seed=1234)


def test_gaussian_prob1():
conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100)
Expand Down
2 changes: 2 additions & 0 deletions brainpy/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
coo2csr as coo2csr,
coo2csc as coo2csc,
coo2mat as coo2mat,
coo2mat_num as coo2mat_num,
mat2mat_num as mat2mat_num,
visualizeMat as visualizeMat,

CONN_MAT,
Expand Down
Loading