-
Notifications
You must be signed in to change notification settings - Fork 2
/
dnf_2d.py
135 lines (106 loc) · 4.52 KB
/
dnf_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
##########################################################
## ##
## Module: dnf_2d.py ##
## ##
## Version: 0.2 ##
## ##
## Description: A running simulation of a dynamic ##
## neural field with several modifications for ##
## improving runtime. ##
## ##
##########################################################
from numpy import *
from scipy.signal import *
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.pyplot as plt
from matplotlib.ticker import LinearLocator, FormatStrFormatter
class dnf:
    """Two-dimensional dynamic neural field (DNF) on a periodic n x n grid.

    The field state ``u`` is advanced with an explicit Euler step in
    ``update``. Lateral interaction is a circular ('wrap') 2-D convolution
    of the firing rate with the weight kernel ``z`` built in ``__init__``.

    All n x n arrays are indexed ``[y_index, x_index]`` (see ``gauss_pbc``).
    """

    ######### Simulation variables ############
    n = 50                   # Number of nodes per side
    tau = 0.1                # Euler step factor multiplying the right-hand side in update()
    u = zeros((n, n))        # Class-level default state; __init__ gives each instance its own copy
    I = zeros((n, n))        # Input activity (kept for compatibility; update() takes I as an argument)
    dx = 2 * pi / n          # Length per node in the x dimension
    dy = 2 * pi / n          # Length per node in the y dimension
    sig = 2 * pi / 11 * 0.6  # Sigma of the Gaussian bump
    c = 0.095                # Global inhibition subtracted from every weight kernel
    X, Y = meshgrid(arange(n), arange(n))  # Node-index grid for plotting
    ###########################################

    @staticmethod
    def gauss_pbc(locx, locy, sig):
        """Return an n x n Gaussian bump on the periodic [0, 2*pi) grid.

        Distances use periodic boundary conditions. Along each axis the
        shorter of the direct and the wrapped-around distance is taken.

        Args:
            locx: Bump centre along x, nominally between 0 and 2*pi.
            locy: Bump centre along y, nominally between 0 and 2*pi.
            sig: Standard deviation of the Gaussian.

        Returns:
            An (n, n) array ``z`` where ``z[j, i]`` is the value at x-node
            ``i`` and y-node ``j``. The prefactor is ``1 / (sqrt(2*pi)*sig)``
            (the 1-D normalisation constant), so the 2-D bump is not
            unit-mass.
        """
        # Vectorised equivalent of visiting every (i, j) node in Python.
        offx = abs(arange(dnf.n) * dnf.dx - locx)
        offy = abs(arange(dnf.n) * dnf.dy - locy)
        # Periodic (wrap-around) distance along each axis.
        distx = minimum(offx, 2 * pi - offx)
        disty = minimum(offy, 2 * pi - offy)
        # Columns follow x and rows follow y: broadcast distx across, disty down.
        expo = (distx[None, :] ** 2 / (2 * sig ** 2)
                + disty[:, None] ** 2 / (2 * sig ** 2))
        return 1. / (sqrt(2 * pi) * sig) * exp(-expo)

    def __init__(self):
        """Build the Gaussian template and the five lateral-weight kernels."""
        # Per-instance state, so in-place edits to one field never leak into
        # other instances through the shared class-level array.
        self.u = zeros((self.n, self.n))
        # Template bump centred in the middle of the field.
        self.gauss = dnf.gauss_pbc(pi, pi, self.sig)
        # update() uses only self.z. The four shifted kernels are built but
        # no method of this class reads them.
        self.z = 1000 * (self.hebb() - self.c)
        self.zxn = 1000 * (self.hebb_PI_X_neg() - self.c)
        self.zxp = 1000 * (self.hebb_PI_X_pos() - self.c)
        self.zyn = 1000 * (self.hebb_PI_Y_neg() - self.c)
        self.zyp = 1000 * (self.hebb_PI_Y_pos() - self.c)

    def _hebb_shifted(self, shift, axis):
        """Circularly convolve the template with a copy rolled by ``shift`` nodes.

        Uses 'same' output size and 'wrap' (periodic) boundaries. The result
        is divided by ``n*2``. NOTE(review): that divisor is not ``n**2``. It
        is kept as in the original because ``c`` appears tuned against it;
        confirm the intended normalisation.
        """
        kernel = roll(self.gauss, shift, axis=axis)
        return convolve2d(self.gauss, kernel, 'same', 'wrap') / (self.n * 2)

    def hebb(self):
        """Return the unshifted weight pattern: the template convolved with itself."""
        return self._hebb_shifted(0, axis=0)

    def hebb_PI_X_neg(self):
        """Return the weight pattern with the second copy rolled by -1 node along x (axis 1)."""
        return self._hebb_shifted(-1, axis=1)

    def hebb_PI_X_pos(self):
        """Return the weight pattern with the second copy rolled by +1 node along x (axis 1)."""
        return self._hebb_shifted(1, axis=1)

    def hebb_PI_Y_neg(self):
        """Return the weight pattern with the second copy rolled by -1 node along y (axis 0)."""
        return self._hebb_shifted(-1, axis=0)

    def hebb_PI_Y_pos(self):
        """Return the weight pattern with the second copy rolled by +1 node along y (axis 0)."""
        return self._hebb_shifted(1, axis=0)

    def update(self, I):
        """Advance the field by one Euler step and return the new state.

        Args:
            I: External input added to the field, presumably an (n, n) array
                or something that broadcasts to it.

        Returns:
            The updated field, which is also stored as ``self.u``.
        """
        # Sigmoidal firing rate in (0, 1).
        rate = 0.5 * (tanh(0.1 * self.u) + 1)
        # Lateral interaction on the periodic grid; 'same' keeps the result n x n.
        lateral = convolve2d(rate, self.z, 'same', 'wrap')
        # Rebind rather than update in place, so no previously returned array changes.
        self.u = self.u + self.tau * (-self.u + lateral * self.dx + I)
        return self.u

    def plot(self):
        """Draw the current field state as a 3-D surface and call ``plt.show()``."""
        fig = plt.figure()
        # Figure.gca() no longer accepts keyword arguments (Matplotlib >= 3.6);
        # add_subplot(projection='3d') is the supported way to get 3-D axes.
        ax = fig.add_subplot(projection='3d')
        surf = ax.plot_surface(self.X, self.Y, self.u, cmap=cm.jet,
                               linewidth=0, antialiased=True)
        # Axes3D.w_zaxis was removed in Matplotlib 3.8; zaxis replaces it.
        ax.zaxis.set_major_locator(LinearLocator(10))
        ax.zaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax.set_zlabel("Excitation")
        ax.set_xlabel("Node 'X'")
        ax.set_ylabel("Node 'Y'")
        fig.colorbar(surf, shrink=0.5, aspect=5)
        plt.show()
# Demo / test case for the module.
def main():
    """Drive the field with one input bump, then move the input elsewhere.

    Each ``dnf.plot()`` call ends with ``plt.show()``.
    """
    dnfex = dnf()
    # Provide input at (3*pi/2, 3*pi/2) for 50 steps, then show the result.
    I = dnf.gauss_pbc(3 * pi / 2, 3 * pi / 2, dnf.sig)
    for _ in range(50):
        dnfex.update(I)
    dnfex.plot()
    # Move the input to (pi/2, pi/2); plot the state before each of 10 steps.
    I = dnf.gauss_pbc(pi / 2, pi / 2, dnf.sig)
    for _ in range(10):
        dnfex.plot()
        dnfex.update(I)


if __name__ == "__main__":
    main()