
Question #692

Open
madara-1112 opened this issue Nov 6, 2024 · 16 comments

Comments

@madara-1112

Hello, I have a question about the function pre2post_event_sum. The documentation says "The pre-to-post event-driven synaptic summation with CSR synapse structure. When values is a scalar, this function is equivalent to", followed by:

```python
post_val = np.zeros(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
    if events[i]:
        for j in range(idnptr[i], idnptr[i+1]):
            post_val[post_ids[i]] += values
```

But I wonder if "post_val[post_ids[i]] += values" was written incorrectly and should be changed to "post_val[post_ids[j]] += values".

@madara-1112
Author

madara-1112 commented Nov 6, 2024

I may also have encountered another bug. I tried to compute the firing rate of an HH neuron using the following code. When I set the input strength to 0.280 and the duration to 1000, the figure was obviously incorrect.

[screenshot: membrane potential trace, cut off partway through the simulation]

And when I set the strength to another value such as 0.278, it worked successfully, as shown below.

[screenshot: membrane potential trace covering the full duration]

```python
import brainpy as bp
import brainpy.math as bm
import numpy as np
import matplotlib.pyplot as plt


class HH(bp.dyn.NeuGroup):

    def __init__(self, size, ENa=50., gNa=1.2, EK=-77.,
                 gK=0.36, EL=-54.387, gL=0.003, V_th=0., C=0.01):
        super(HH, self).__init__(size=size)
        self.ENa = ENa
        self.EK = EK
        self.EL = EL
        self.gNa = gNa
        self.gK = gK
        self.gL = gL
        self.C = C
        self.V_th = V_th

        self.V = bm.Variable(-65 * bm.ones(self.num))
        self.m = bm.Variable(0.0529 * bm.ones(self.num))
        self.h = bm.Variable(0.5961 * bm.ones(self.num))
        self.n = bm.Variable(0.3177 * bm.ones(self.num))
        self.gNa_ = bm.Variable(0 * bm.ones(self.num))
        self.gK_ = bm.Variable(0 * bm.ones(self.num))

        self.input = bm.Variable(bm.zeros(self.num))
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

        self.integral = bp.odeint(f=self.derivative, method='exp_auto')

    @property
    def derivative(self):
        return bp.JointEq(self.dV, self.dm, self.dh, self.dn)

    def dm(self, m, t, V):
        alpha = 0.1 * (V + 40) / (1 - bm.exp(-(V + 40) / 10))
        beta = 4.0 * bm.exp(-(V + 65) / 18)
        dmdt = alpha * (1 - m) - beta * m
        return dmdt

    def dh(self, h, t, V):
        alpha = 0.07 * bm.exp(-(V + 65) / 20.)
        beta = 1 / (1 + bm.exp(-(V + 35) / 10))
        dhdt = alpha * (1 - h) - beta * h
        return dhdt

    def dn(self, n, t, V):
        alpha = 0.01 * (V + 55) / (1 - bm.exp(-(V + 55) / 10))
        beta = 0.125 * bm.exp(-(V + 65) / 80)
        dndt = alpha * (1 - n) - beta * n
        return dndt

    def dV(self, V, t, h, n, m):
        I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
        I_K = (self.gK * n ** 4.0) * (V - self.EK)
        I_leak = self.gL * (V - self.EL)
        dVdt = (- I_Na - I_K - I_leak + self.input) / self.C
        return dVdt

    def update(self, tdi):
        t, dt = tdi.t, tdi.dt
        V, m, h, n = self.integral(self.V, self.m, self.h, self.n, t, dt=dt)
        self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
        self.t_last_spike.value = bm.where(self.spike, t, self.t_last_spike)
        self.V.value = V
        self.m.value = m
        self.h.value = h
        self.n.value = n

        self.gNa_.value = self.gNa * m ** 3.0 * h  # record the sodium conductance
        self.gK_.value = self.gK * n ** 4.0  # record the potassium conductance
        self.input[:] = 0.  # reset the input received by the neuron


currents, length = bp.inputs.section_input(values=[0.280],
                                           durations=[1000],
                                           return_length=True)
hh = HH(1)
runner = bp.dyn.DSRunner(hh,
                         monitors=['V', 'm', 'h', 'n', 'gNa_', 'gK_'],
                         inputs=['input', currents, 'iter'])
runner.run(length)

fig, axe = plt.subplots(2, 1)
axe[0].plot(runner.mon.ts, runner.mon.V, linewidth=2)
print(runner.mon.V.shape)
axe[0].set_ylabel('V (mV)')

axe[1].plot(runner.mon.ts, runner.mon.gNa_, linewidth=2, color='blue')
axe[1].plot(runner.mon.ts, runner.mon.gK_, linewidth=2, color='red')
axe[1].set_ylabel('Conductance')

axe[1].plot(runner.mon.ts, runner.mon.m, linewidth=2, color='blue', label='m')
axe[1].plot(runner.mon.ts, runner.mon.n, linewidth=2, color='red', label='n')
axe[1].plot(runner.mon.ts, runner.mon.h, linewidth=2, color='orange', label='h')
axe[1].set_ylabel('Channel')
plt.legend()
plt.tight_layout()
plt.show()
```

@madara-1112
Author

madara-1112 commented Nov 6, 2024

The version I use is brainpy 2.4.5

@Routhleck
Collaborator

Thank you for your question. Indeed, in the documentation, it should be corrected from "post_val[post_ids[i]] += values" to "post_val[post_ids[j]] += values". We will make the necessary changes to the documentation as soon as possible.
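With that fix, the equivalent code in the docstring reads:

```python
post_val = np.zeros(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
    if events[i]:
        for j in range(idnptr[i], idnptr[i + 1]):
            post_val[post_ids[j]] += values
```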

Regarding your second question, I am not quite clear about it. Could you provide the problematic code and specify which parameters should be modified to address the issue? Perhaps you could try upgrading brainpy to the latest version?

```
pip install brainpy -U
```

@madara-1112
Author

madara-1112 commented Nov 9, 2024

Thank you for answering my first question!
As for the second question, the original code is attached. I'm not sure what caused the problem, so let me restate it: when the current strength was set to 0.280, there was an obvious error in the figure, which showed only the first 5000 time steps even though I had set the model to simulate 10000 time steps. However, when I set the current strength to some other value such as 0.278, the figure correctly showed all time steps. I don't know why that particular current strength triggers the bug.
code.txt

@madara-1112
Author

I adjusted the simulation length and found that when the current strength is set to 0.280, the figure shows at most the first 5000 time points, no matter how many time points I set. That is, it displayed all time steps correctly as long as the total was under 5000.

@madara-1112
Author

I might also try upgrading brainpy😂

@Routhleck
Collaborator

It seems I know what the issue is. Is your JAX version 0.4.32 or above? JAX introduced an asynchronous CPU scheduling mechanism in version 0.4.32, which can cause `runner.run()` to return prematurely and allow the subsequent code to execute. You can consider downgrading JAX to 0.4.31 or below, or changing `runner.run(length)` to `jax.block_until_ready(runner.run(length))`.
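That is, something along these lines (using the runner from the script above):

```python
import jax

# Block until all pending device computations have finished, so the code
# after the run does not observe a partially-filled monitor:
jax.block_until_ready(runner.run(length))
```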

@Routhleck
Collaborator

Could you please provide the specific hardware information of your device? I'm having difficulty reproducing the error on mine.

@madara-1112
Author

The JAX version (and jaxlib version) I used is 0.4.16. I ran this code on a MacBook Pro (M2) running macOS Sequoia 15.2.

@madara-1112
Author

madara-1112 commented Nov 10, 2024

I changed `runner.run(length)` to `jax.block_until_ready(runner.run(length))`, but it did not work either.

@Routhleck
Collaborator

@ztqakita

@madara-1112
Author

madara-1112 commented Nov 10, 2024

I wanted to see how the firing rate varies with the current strength, and it output like this 😂.

```python
duration = 1000
I = np.arange(0, 0.5, 0.002)
group = HH(len(I))
runner = bp.dyn.DSRunner(group, monitors=['spike'], inputs=['input', I])
runner.run(duration)
F = runner.mon.spike.sum(axis=0) / (duration / 1000)
print(F)
plt.plot(I, F, linewidth=2)
plt.xlabel('I (mA/mm^2)')
plt.ylabel('F (Hz)')
plt.title('firing rate vs current')
plt.show()
```

[figure: firing rate vs. current strength]

I printed the values of V[5000:10000] when the current strength was set to 0.280 and found that they were all NaN.
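(A minimal way to check this from the monitor, assuming the single-neuron runner from the script above:)

```python
import numpy as np

V = np.asarray(runner.mon.V)   # monitored membrane potential, shape (n_steps, 1)
print(V[5000:10000])           # these values were all NaN in my run
print(np.isnan(V).sum())       # count the NaN samples
```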

@ztqakita
Collaborator

There are two places in your code that can be fixed with a few changes, as shown below:
[screenshot: the suggested code changes]
You can replace bm.exp with bm.exprel to avoid the NaN problem. When x is near zero, exp(x) is near 1, so the numerical calculation of exp(x) - 1 can suffer from a catastrophic loss of precision; exprel(x) is implemented to avoid exactly this loss of precision near zero.
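Concretely, based on the dm and dn methods in the code above, the change presumably looks like this (using the identity x / (1 - exp(-x)) = 1 / exprel(-x)):

```python
# In class HH, rewrite the two alpha rate functions whose denominators
# vanish at V = -40 mV and V = -55 mV. bm.exprel(x) computes
# (exp(x) - 1) / x stably near x = 0.

def dm(self, m, t, V):
    alpha = 1. / bm.exprel(-(V + 40) / 10)   # == 0.1*(V+40)/(1-exp(-(V+40)/10))
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    return alpha * (1 - m) - beta * m

def dn(self, n, t, V):
    alpha = 0.1 / bm.exprel(-(V + 55) / 10)  # == 0.01*(V+55)/(1-exp(-(V+55)/10))
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    return alpha * (1 - n) - beta * n
```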

@madara-1112
Author

Thank you for your reply! It seems that I must update my version of brainpy to use this function, and I'll try it.

@madara-1112
Author

I updated brainpy and jax, and replaced bm.exp with bm.exprel as you did. It did work at current strength = 0.280, but I'm not sure it also works at other values of the current strength: I now see issues I never encountered before replacing the function, with the firing rates shown below.

[figure: firing rate vs. current strength after switching to bm.exprel]

I set the current strength to zero and was surprised to find that the HH neuron still produced spikes.

[figure: simulation output at zero input current, showing spikes]

@madara-1112
Author

madara-1112 commented Nov 10, 2024

When I went back to the original function bm.exp, it behaved just like before, working well at all values of the current strength except 0.280. That rules out the package updates as the cause.
