From 953b3628cfd1f654652c458697384d2f488b389e Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 9 Aug 2023 14:30:59 +0800 Subject: [PATCH 1/5] Fix bugs in the connect module and Complete FixedTotalNum class --- brainpy/_src/connect/base.py | 53 ++++++++++++++++++- brainpy/_src/connect/random_conn.py | 22 +++++--- .../_src/connect/tests/test_random_conn.py | 12 +++++ brainpy/connect.py | 2 + 4 files changed, 82 insertions(+), 7 deletions(-) diff --git a/brainpy/_src/connect/base.py b/brainpy/_src/connect/base.py index 3a264d313..eef74dfcb 100644 --- a/brainpy/_src/connect/base.py +++ b/brainpy/_src/connect/base.py @@ -29,6 +29,7 @@ 'mat2coo', 'mat2csc', 'mat2csr', 'csr2csc', 'csr2mat', 'csr2coo', 'coo2csr', 'coo2csc', 'coo2mat', + 'coo2mat_num', 'mat2mat_num', # visualize 'visualizeMat', @@ -426,7 +427,8 @@ def require(self, *structures): elif POST_IDS in structures and _has_coo_imp: return bm.as_jax(self.build_coo()[1], dtype=IDX_DTYPE) elif COO in structures and _has_coo_imp: - return bm.as_jax(self.build_coo(), dtype=IDX_DTYPE) + r = self.build_coo() + return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE) elif len(structures) == 2: if (PRE_IDS in structures and POST_IDS in structures and _has_coo_imp): @@ -725,6 +727,55 @@ def coo2csc(coo, post_num, data=None): data_new = data[sort_ids] return pre_ids_new, indptr_new, data_new +def coo2mat_num(ij, num_pre, num_post, num, seed=0): + """ + convert (indices, indptr) to a dense connection number matrix.\n + Specific for FixedTotalNum. + """ + rng = bm.random.RandomState(seed) + mat = coo2mat(ij, num_pre, num_post) + + # get nonzero indices and number + nonzero_idx = jnp.nonzero(mat) + nonzero_num = jnp.count_nonzero(mat) + + # get multi connection number + multi_conn_num = num - nonzero_num + + # alter the element type to int + mat = mat.astype(jnp.int32) + + # 随机在mat中选取nonzero_idx的元素,将其值加1 + index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False) + for i in index: + mat = mat.at[nonzero_idx[0][i], nonzero_idx[1][i]].set(mat[nonzero_idx[0][i], nonzero_idx[1][i]] + 1) + + return mat + +def mat2mat_num(mat, num, seed=0): + """ + Convert boolean matrix to a dense connection number matrix.\n + Specific for FixedTotalNum. + """ + rng = bm.random.RandomState(seed) + + # get nonzero indices and number + nonzero_idx = jnp.nonzero(mat) + nonzero_num = jnp.count_nonzero(mat) + + # get multi connection number + multi_conn_num = num - nonzero_num + + # alter the element type to int + mat = mat.astype(jnp.int32) + + # 随机在mat中选取nonzero_idx的元素,将其值加1 + index = rng.choice(nonzero_num, size=(multi_conn_num,), replace=False) + for i in index: + mat = mat.at[nonzero_idx[0][i], nonzero_idx[1][i]].set(mat[nonzero_idx[0][i], nonzero_idx[1][i]] + 1) + + return mat + def visualizeMat(mat, description='Untitled'): try: diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 5c66e47c7..3009f28fc 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -147,8 +147,11 @@ class FixedTotalNum(TwoEndConnector): The random number seed. """ - def __init__(self, num, seed=None, **kwargs): - super(FixedTotalNum, self).__init__(**kwargs) + def __init__(self, + num, + allow_multi_conn=False, + seed=None, **kwargs): + super().__init__(**kwargs) if isinstance(num, int): assert num >= 0, '"num" must be a non-negative integer.' elif isinstance(num, float): @@ -157,14 +160,21 @@ def __init__(self, num, seed=None, **kwargs): raise ConnectorError(f'Unknown type: {type(num)}') self.num = num self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn self.rng = bm.random.RandomState(self.seed) def build_coo(self): - if self.num > self.pre_num * self.post_num: + mat_element_num = self.pre_num * self.post_num + if self.num > mat_element_num: raise ConnectorError(f'"num" must be smaller than "all2all num", ' - f'but got {self.num} > {self.pre_num * self.post_num}') - selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) - selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) + f'but got {self.num} > {mat_element_num}') + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) + selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) + else: + index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) + selected_pre_ids = index // self.post_num + selected_post_ids = index % self.post_num return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE) def __repr__(self): diff --git a/brainpy/_src/connect/tests/test_random_conn.py b/brainpy/_src/connect/tests/test_random_conn.py index 195761548..b918d0f4b 100644 --- a/brainpy/_src/connect/tests/test_random_conn.py +++ b/brainpy/_src/connect/tests/test_random_conn.py @@ -87,6 +87,18 @@ def test_random_fix_post3(): conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) conn1.require(bp.connect.CONN_MAT) +def test_random_fix_total1(): + conn1 = bp.connect.FixedTotalNum(num=8, allow_multi_conn=False, seed=1234)(pre_size=3, post_size=4) + coo1 = conn1.require(bp.connect.COO) + conn_mat = bp.connect.coo2mat_num(ij=coo1, num_pre=3, num_post=4, num=conn1.num, seed=1234) + bp.connect.visualizeMat(conn_mat, 'FixedTotalNum: allow_multi_conn=False') + +def test_random_fix_total2(): + conn1 = bp.connect.FixedTotalNum(num=8, allow_multi_conn=True, seed=1234)(pre_size=3, post_size=4) + mat1 = conn1.require(bp.connect.CONN_MAT) + conn_mat = bp.connect.mat2mat_num(mat=mat1, num=conn1.num, seed=1234) + bp.connect.visualizeMat(conn_mat, 'FixedTotalNum: allow_multi_conn=True') + def test_gaussian_prob1(): conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100) diff --git a/brainpy/connect.py b/brainpy/connect.py index 0024b08aa..c3005f595 100644 --- a/brainpy/connect.py +++ b/brainpy/connect.py @@ -13,6 +13,8 @@ coo2csr as coo2csr, coo2csc as coo2csc, coo2mat as coo2mat, + coo2mat_num as coo2mat_num, + mat2mat_num as mat2mat_num, visualizeMat as visualizeMat, CONN_MAT, From f213be2744f3be5f1412c4acfe1e2bdba253f033 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 9 Aug 2023 14:37:02 +0800 Subject: [PATCH 2/5] Add description for FixedTotalNum --- brainpy/_src/connect/random_conn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 3009f28fc..21e47a5c0 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -143,6 +143,8 @@ class FixedTotalNum(TwoEndConnector): ---------- num : float,int The conn total number. + allow_multi_conn : bool, optional + Whether allow multiple connections between two neurons. seed: int, optional The random number seed. """ From 026378ce7e3091ac8ac60b78d3150bba8d4d1540 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 9 Aug 2023 14:42:51 +0800 Subject: [PATCH 3/5] dd description for `visualizeMat` --- brainpy/_src/connect/base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/brainpy/_src/connect/base.py b/brainpy/_src/connect/base.py index eef74dfcb..9b7636d3d 100644 --- a/brainpy/_src/connect/base.py +++ b/brainpy/_src/connect/base.py @@ -778,6 +778,16 @@ def mat2mat_num(mat, num, seed=0): def visualizeMat(mat, description='Untitled'): + """ + Visualize the matrix. (Need seaborn and matplotlib) + + parameters + ---------- + mat : jnp.ndarray + The matrix to be visualized. + description : str + The title of the figure. + """ try: import seaborn as sns import matplotlib.pyplot as plt From 917ec9e4f0dcd6eea43d2293bb09d859ac42f7cc Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 9 Aug 2023 15:28:34 +0800 Subject: [PATCH 4/5] Update test_random_conn.py --- brainpy/_src/connect/random_conn.py | 2 +- brainpy/_src/connect/tests/test_random_conn.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 21e47a5c0..ff4c2d50d 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -144,7 +144,7 @@ class FixedTotalNum(TwoEndConnector): num : float,int The conn total number. allow_multi_conn : bool, optional - Whether allow multiple connections between two neurons. + Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. seed: int, optional The random number seed. """ diff --git a/brainpy/_src/connect/tests/test_random_conn.py b/brainpy/_src/connect/tests/test_random_conn.py index b918d0f4b..68531ded7 100644 --- a/brainpy/_src/connect/tests/test_random_conn.py +++ b/brainpy/_src/connect/tests/test_random_conn.py @@ -91,13 +91,11 @@ def test_random_fix_total1(): conn1 = bp.connect.FixedTotalNum(num=8, allow_multi_conn=False, seed=1234)(pre_size=3, post_size=4) coo1 = conn1.require(bp.connect.COO) conn_mat = bp.connect.coo2mat_num(ij=coo1, num_pre=3, num_post=4, num=conn1.num, seed=1234) - bp.connect.visualizeMat(conn_mat, 'FixedTotalNum: allow_multi_conn=False') def test_random_fix_total2(): conn1 = bp.connect.FixedTotalNum(num=8, allow_multi_conn=True, seed=1234)(pre_size=3, post_size=4) mat1 = conn1.require(bp.connect.CONN_MAT) conn_mat = bp.connect.mat2mat_num(mat=mat1, num=conn1.num, seed=1234) - bp.connect.visualizeMat(conn_mat, 'FixedTotalNum: allow_multi_conn=True') def test_gaussian_prob1(): From 0088b54a83448fdea65941566528344e049a276f Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 9 Aug 2023 16:01:18 +0800 Subject: [PATCH 5/5] Add connect module deprecations --- brainpy/_add_deprecations.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py index 0b782d3cf..89fd1dd8c 100644 --- a/brainpy/_add_deprecations.py +++ b/brainpy/_add_deprecations.py @@ -1,6 +1,6 @@ from ._src import checking, train, integrators -from . import tools, math, integrators, dyn, dnn, neurons, synapses, layers +from . import tools, math, integrators, dyn, dnn, neurons, synapses, layers, connect from .integrators import ode, fde, sde from brainpy._src.integrators.base import Integrator from brainpy._src.integrators.runner import IntegratorRunner @@ -114,3 +114,11 @@ # layers.__getattr__ = deprecation_getattr2('brainpy.layers', layers.__deprecations) +connect.__deprecations = { + 'one2one': ('brainpy.connect.one2one', 'brainpy.connect.One2One', connect.One2One), + 'all2all': ('brainpy.connect.all2all', 'brainpy.connect.All2All', connect.All2All), + 'grid_four': ('brainpy.connect.grid_four', 'brainpy.connect.GridFour', connect.GridFour), + 'grid_eight': ('brainpy.connect.grid_eight', 'brainpy.connect.GridEight', connect.GridEight), +} +connect.__getattr__ = deprecation_getattr2('brainpy.connect', connect.__deprecations) +