-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathscaetrain_3d.m
111 lines (92 loc) · 3.65 KB
/
scaetrain_3d.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
function scae = scaetrain_3d(scae, x, opts)
%SCAETRAIN_3D Layer-wise training of a stack of 3-D convolutional autoencoders.
%   scae = SCAETRAIN_3D(scae, x, opts) visits the layers scae{1..end} in order.
%   For each layer i it:
%     1. trains the layer with caetrain_3d;
%     2. runs every training sample through the trained layer with caeup_3d;
%     3. max-pools the resulting activation maps (scae{i}.a) using the
%        layer's index matrix scae{i}.M;
%     4. stores the result in scae{i}.hid{1}{l}.
%   The stored pooled maps of layer i are the training samples of layer i+1.
%
%   scae - cell array of layer structs. Each layer must carry a field M,
%          which is used to index the activation maps for pooling (see
%          pool_and_pad below).
%   x    - cell array; only x{1} is used, as the cell array of training
%          samples for the first layer.
%   opts - training options, forwarded unchanged to caetrain_3d.
%
%   On return, scae{i}.hid{1} is a 1-by-numel(samples) cell. Entry l is a
%   cell with one pooled (and, if needed, zero-padded) map per activation
%   map produced by layer i for sample l.
for i = 1:numel(scae)
    disp(['---------SAE' num2str(i) '---------']);
    fprintf('\n');

    % Layer 1 trains on the raw samples. Deeper layers train on the previous
    % layer's pooled hidden maps.
    if i == 1
        samples = x{1};
    else
        samples = scae{i-1}.hid{1};
    end
    scae{i} = caetrain_3d(scae{i}, samples, opts);

    % A fresh cell per layer, so nothing from an earlier layer can leak in.
    pooled = cell(1, numel(samples));
    for l = 1:numel(samples)
        if i == 1
            % A raw sample is wrapped in a 1x1 cell (single input channel).
            layerInput = samples(l);
        else
            % A pooled sample is already a cell with one entry per channel.
            layerInput = samples{l};
        end
        scae{i} = caeup_3d(scae{i}, layerInput);
        pooled{l} = pool_and_pad(scae{i}.a, scae{i}.M);
    end
    scae{i}.hid{1} = pooled;
end
end

function pooled = pool_and_pad(maps, M)
%POOL_AND_PAD Max-pool each activation map, then zero-pad odd dimensions.
%   maps - cell array of activation maps. All maps are assumed to have the
%          size of maps{1}.
%   M    - index matrix into a map. Each column of maps{j}(M) is reduced to
%          its maximum. The row of maxima is reshaped to half the map size,
%          so M must have prod(size(maps{1})/2) columns.
%          NOTE(review): presumably each column lists one 2x2x2 pooling
%          region -- confirm against the code that builds M.
%   Returns a 1-by-numel(maps) cell of pooled maps. If any of the first three
%   dimensions of a pooled map is odd, that dimension gets one trailing slice
%   of zeros. This keeps the next layer's halving reshape valid, assuming the
%   convolution preserves size.
if isempty(maps)
    pooled = cell(1, 0);
    return;
end
outSize = size(maps{1}) / 2;
pooled = cell(1, numel(maps));
for j = 1:numel(maps)
    % The column-wise max yields exactly one value per pooling region, even
    % when a region contains ties or its maximum is zero.
    p = reshape(max(maps{j}(M), [], 1), outSize);
    oddDims = mod([size(p, 1), size(p, 2), size(p, 3)], 2);
    if any(oddDims)
        p = padarray(p, oddDims, 'post');
    end
    pooled{j} = p;
end
end