forked from lukemelas/EfficientNet-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hubconf.py
43 lines (35 loc) · 1.67 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from efficientnet_pytorch import EfficientNet as _EfficientNet
# Packages that torch.hub requires to be importable before it loads any of the
# entry points defined in this file.
dependencies = ["torch"]
def _create_model_fn(model_name):
def _model_fn(num_classes=1000, in_channels=3, pretrained='imagenet'):
"""Create Efficient Net.
Described in detail here: https://arxiv.org/abs/1905.11946
Args:
num_classes (int, optional): Number of classes, default is 1000.
in_channels (int, optional): Number of input channels, default
is 3.
pretrained (str, optional): One of [None, 'imagenet', 'advprop']
If None, no pretrained model is loaded.
If 'imagenet', models trained on imagenet dataset are loaded.
If 'advprop', models trained using adversarial training called
advprop are loaded. It is important to note that the
preprocessing required for the advprop pretrained models is
slightly different from normal ImageNet preprocessing
"""
model_name_ = model_name.replace('_', '-')
if pretrained is not None:
model = _EfficientNet.from_pretrained(
model_name=model_name_,
advprop=(pretrained == 'advprop'),
num_classes=num_classes,
in_channels=in_channels)
else:
model = _EfficientNet.from_name(
model_name=model_name_,
override_params={'num_classes': num_classes},
)
model._change_in_channels(in_channels)
return model
return _model_fn
# Expose one torch.hub entry point per variant: efficientnet_b0 ... efficientnet_b8.
# Each call to _create_model_fn binds its own ``model_name`` argument, so the
# returned closures do not share the loop variable (no late-binding issue).
# globals() is used rather than locals(): at module scope both return the
# module namespace, but the Python docs say the locals() mapping should not
# be modified.
for model_name in ['efficientnet_b' + str(i) for i in range(9)]:
    globals()[model_name] = _create_model_fn(model_name)