forked from FabioMurgese/computational-neuroscience
-
Notifications
You must be signed in to change notification settings - Fork 1
/
bcm_rule.m
84 lines (70 loc) · 2.27 KB
/
bcm_rule.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
%%%%%%%%%%%%%%% Learning using BCM Rule %%%%%%%%%%%%%%%
% Trains a single linear neuron (v = w'*u) with the BCM rule on the data in
% ../lab2_1_data.csv. Each column of U is used as one input pattern
% (see U(:,n) below).
% NOTE(review): assumes the CSV stores one pattern per column — confirm.
clear variables;
data = readtable('../lab2_1_data.csv'); % importing data as table
% NOTE(review): readtable may consume the first row as variable names if the
% file has no header line — confirm against the data file.
U = table2array(data);                  % input matrix, one pattern per column
[n_dims, U_size] = size(U);             % input dimension and training set size
% Learning rate. Bit-identical to the original literal 5*10e-4 (10e-4 == 1e-3),
% i.e. 0.005. NOTE(review): if 5e-4 was intended, change it here.
eta = 5 * 1e-3;
epochs = 2000;                          % maximum number of passes over the data
theta = -1;                             % initial value of the BCM sliding threshold
% Stop when one epoch moves w by less than this. Same value as the original
% literal 10e-7 (== 1e-6).
stop_condition = 1e-6;
% Input correlation matrix. It is not divided by U_size; that scaling does not
% change its eigenvectors.
Q = U*U';
w = -1 + 2.*rand(n_dims,1);             % random initial weights in [-1, 1)
% Preallocate the histories instead of growing them once per epoch.
weights = zeros(n_dims*epochs, 1);      % copies of w stacked one epoch after another
W_norm = zeros(epochs, 1);              % norm(w) at the end of each epoch
n_run = 0;                              % number of epochs actually executed
for i = 1:epochs
    U = U(:,randperm(U_size));          % reshuffle the presentation order
    w_prev = w;
    for n = 1:U_size
        % linear firing model
        u = U(:,n);
        v = w' * u;                     % neuron output
        delta_w = v * u * (v - theta);  % BCM update direction
        w = w + eta * delta_w;          % update weights
        % Threshold follows an exponential moving average of v^2.
        theta = eta * (v^2 - theta) + theta;
    end
    n_run = i;
    weights((i-1)*n_dims + (1:n_dims)) = w;
    W_norm(i) = norm(w);
    % Named 'delta' so the built-in diff() is not shadowed.
    delta = norm(w - w_prev);
    fprintf('Epoch: %d Norm(W): %1.5f Diff: %1.7f Theta: %1.7f \n', i, norm(w), delta, theta)
    if delta < stop_condition           % converged: w barely moved this epoch
        break;
    end
end
% Trim the unused preallocated tail so the histories cover only the epochs run.
weights = weights(1:n_run*n_dims);
W_norm = W_norm(1:n_run);
% Principal eigenvector of Q, to compare against the learned weight vector.
% Q = U*U' is symmetric, so its eigenvalues are real and can be ranked with max.
[eigvecs, D] = eig(Q);               % eigenvectors (columns) and diagonal eigenvalue matrix
eigvals = diag(D);                   % eigenvalues as a plain vector
[max_eigval, max_i] = max(eigvals);  % index of the principal eigenvector
pc = eigvecs(:,max_i);
% An eigenvector is only defined up to its sign. Flip pc to agree with w, so a
% converged w is not drawn antiparallel to the eigenvector it matches.
if pc' * w < 0
    pc = -pc;
end
% Plotting data points and comparison between final weight vector and
% principal eigenvector of Q
fig = figure;
hold on
plot(U(1,:),U(2,:), '.')
plotv(pc);
% Thickens only the lines drawn so far (data and eigenvector), not the
% weight vector plotted next.
set(findall(gca,'Type', 'Line'),'LineWidth',1.75);
plotv(w/norm(w))                     % unit-length weight vector, direction only
legend('data points','principal eigenvector','weight vector','Location', 'best')
title('P1: data points, final weight vector and principal eigenvector of Q');
print(fig,'P1.png','-dpng')
% 'weights' stacks one copy of w per epoch. Reshape it to one column per
% epoch, so row k is the k-th weight component over time. For 2-D w this is
% the same as weights(1:2:end) / weights(2:2:end).
W_hist = reshape(weights, numel(w), []);
w1 = W_hist(1,:).';
w2 = W_hist(2,:).';
% weight over time, first component
fig = figure;
plot(w1)
xlabel('time')
ylabel('weight')
title('Weight vector over time (1st component)')
print(fig,'P2.1.png','-dpng')
% weight over time, second component
fig = figure;
plot(w2)
xlabel('time')
ylabel('weight')
title('Weight vector over time (2nd component)')
print(fig,'P2.2.png','-dpng')
% weight norm over time (one value per epoch)
fig = figure;
plot(1:size(W_norm,1), W_norm)
xlabel('time')
ylabel('norm(w)')                    % this plot shows the norm, not a weight component
title('Weight norm vector over time')
print(fig,'P2.3.png','-dpng')
% Persist the full stacked weight history.
save('weights.mat','weights');