forked from Kashu7100/Qualia2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
alexnet.py
44 lines (41 loc) · 1.37 KB
/
alexnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
from ..nn.modules.module import Sequential, Module
from ..nn.modules import Linear, Conv2d, MaxPool2d, SoftMax, ReLU, Dropout
class AlexNet(Module):
    '''AlexNet: five convolutional layers followed by three fully connected layers.

    Takes 3-channel 2D input and produces 1000 softmax-normalized scores.

    Args:
        pretrained (bool): if True, load pretrained weights from the remote
            checkpoint once the layers are built. Default: False
    '''

    # Number of values per sample handed from `features` to `classifier`
    # (256 channels on a 6x6 map); shared by the first Linear and forward().
    _FLAT_FEATURES = 6 * 6 * 256
    _WEIGHTS_URL = 'https://www.dropbox.com/s/2lgr0q2h6wyxkjg/alexnet.qla?dl=1'

    def __init__(self, pretrained=False):
        super().__init__()

        def pool():
            # 3x3 window with stride 2, so neighbouring windows overlap
            return MaxPool2d(kernel_size=3, stride=2)

        # NOTE(review): the three 3x3 convs pass no padding and so rely on
        # Conv2d's default; _FLAT_FEATURES assumes that default keeps the
        # spatial size (e.g. 224x224 input -> 6x6 final map) — confirm.
        conv_layers = [
            Conv2d(3, 64, 11, stride=4, padding=2), ReLU(), pool(),
            Conv2d(64, 192, 5, padding=2), ReLU(), pool(),
            Conv2d(192, 384, 3), ReLU(),
            Conv2d(384, 256, 3), ReLU(),
            Conv2d(256, 256, 3), ReLU(), pool(),
        ]
        self.features = Sequential(*conv_layers)

        dense_layers = [
            Dropout(0.5), Linear(self._FLAT_FEATURES, 4096), ReLU(),
            Dropout(0.5), Linear(4096, 4096), ReLU(),
            Linear(4096, 1000), SoftMax(),
        ]
        self.classifier = Sequential(*dense_layers)

        if pretrained:
            self.load_state_dict_from_url(self._WEIGHTS_URL, version=1)

    def forward(self, x):
        feature_map = self.features(x)
        # Collapse each sample's feature map into a single row for the dense
        # head. NOTE(review): with -1 as the leading dim, a map that is not
        # 6x6 regroups values across samples instead of failing; confirm the
        # input size upstream.
        rows = feature_map.reshape(-1, self._FLAT_FEATURES)
        return self.classifier(rows)