-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
157 lines (112 loc) · 5.62 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import data_loader
import subdeq
import subdeq_conv
import train
import torch
import random
import numpy as np
worker = 0
#MNIST FC
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
print('=================================================================================')
print(f'MNIST (dense)')
print('=================================================================================')
print("\n")
train_loader,val_loeader, test_loader = data_loader.mnist_loaders(worker,test_batch_size=128*4)
print('=================================================================================')
print(f'Subdeq (Tanh)')
print('=================================================================================')
print("\n")
subdeq_tanhshift= subdeq.Subdeq_shift()
train.train(subdeq_tanhshift,train_loader,val_loeader, test_loader)
print('=================================================================================')
print(f'Subdeq (Normalized Tanh)')
print('=================================================================================')
print("\n")
subdeq_tanh = subdeq.Subdeq_shift(norm=np.inf,shift=1.603)
train.train(subdeq_tanh,train_loader,val_loeader, test_loader)
#MNIST conv
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
print('=================================================================================')
print(f'MNIST (Convolutional)')
print('=================================================================================')
print("\n")
print('=================================================================================')
print(f'Subdeq (Tanh)')
print('=================================================================================')
print("\n")
subdeq1 = subdeq_conv.subdeq_shifttanh(chan=32,pool=6,n=28,input_chan=1)
train.train(subdeq1,train_loader,val_loeader, test_loader,max_epochs=40)
print('=================================================================================')
print(f'Subdeq (Normalized Tanh)')
print('=================================================================================')
print("\n")
subdeq1 = subdeq_conv.subdeq_shifttanh(chan=32,pool=6,n=28,input_chan=1,norm=np.inf,shift=1.603)
train.train(subdeq1,train_loader,val_loeader, test_loader,max_epochs=40)
#CIFAR-10 conv
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
train_loader,val_loeader, test_loader = data_loader.cifar_loaders(worker,test_batch_size=128*4)
print('=================================================================================')
print(f'CIFAR-10 (Convolutional)')
print('=================================================================================')
print("\n")
print('=================================================================================')
print(f'Subdeq (Tanh)')
print('=================================================================================')
print("\n")
subdeq1 = subdeq_conv.subdeq_shifttanh(chan=48,pool=8,n=32,input_chan=3)
train.train(subdeq1,train_loader,val_loeader, test_loader,max_epochs=40)
print('=================================================================================')
print(f'Subdeq (Normalized Tanh)')
print('=================================================================================')
print("\n")
subdeq1 = subdeq_conv.subdeq_shifttanh(chan=48,pool=8,n=32,input_chan=3)
train.train(subdeq1,train_loader,val_loeader, test_loader,max_epochs=40,norm=np.inf,shift=1.603)
#SVHN conv
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
train_loader,val_loeader, test_loader = data_loader.svhn_loaders(worker, test_batch_size=128*4)
print('=================================================================================')
print(f'SVHN (Convolutional)')
print('=================================================================================')
print("\n")
print('=================================================================================')
print(f'Subdeq (Tanh)')
print('=================================================================================')
print("\n")
subdeq1 = subdeq_conv.subdeq_shifttanh(chan=48,pool=8,n=32,input_chan=3)
train.train(subdeq1,train_loader,val_loeader, test_loader,max_epochs=40)
print('=================================================================================')
print(f'Subdeq (Normalized Tanh)')
print('=================================================================================')
print("\n")
subdeq1 = subdeq_conv.subdeq_shifttanh(chan=48,pool=8,n=32,input_chan=3,norm=np.inf,shift=1.603)
train.train(subdeq1,train_loader,val_loeader, test_loader,max_epochs=40)
#Tiny ImageNet conv
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
train_loader,val_loader = data_loader.imagnet_loader()
print('=================================================================================')
print(f'Tiny ImageNet (Convolutional)')
print('=================================================================================')
print("\n")
print('=================================================================================')
print(f'Subdeq (Tanh)')
print('=================================================================================')
print("\n")
subdeq1 = subdeq_conv.subdeq_shifttanh(chan=64,pool=8,n=64,input_chan=3,out_class=200)
train.train(subdeq1,train_loader,val_loader, val_loader,max_epochs=50,runs = 3)
print('=================================================================================')
print(f'Subdeq (Normalized Tanh)')
print('=================================================================================')
print("\n")
subdeq1 = subdeq_conv.subdeq_shifttanh(chan=64,pool=8,n=64,input_chan=3,out_class=200,norm=np.inf,shift=1.603)
train.train(subdeq1,train_loader,val_loader, val_loader,max_epochs=50,runs = 3)