forked from Arsey/keras-transfer-learning-for-oxford102
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
62 lines (50 loc) · 1.65 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import print_function
import os
import argparse
import traceback
import numpy as np
import util
import config
# Seed NumPy's global random generator once, at import time, with a fixed
# value so NumPy-driven randomness repeats identically from run to run.
np.random.seed(seed=1337)
def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, which is the
            original behavior; passing a list allows programmatic use/tests.

    Returns:
        argparse.Namespace with the attributes ``data_dir``, ``model``,
        ``nb_epoch`` (int, default 1000) and ``freeze_layers_number`` (int,
        or None when the option is omitted).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='Path to data dir')
    # Only the architectures declared in config are accepted.
    parser.add_argument('--model', type=str, help='Base model architecture', choices=[
        config.MODEL_RESNET50,
        config.MODEL_RESNET152,
        config.MODEL_INCEPTION_V3,
        config.MODEL_VGG16])
    parser.add_argument('--nb_epoch', type=int, default=1000)
    parser.add_argument('--freeze_layers_number',
                        type=int, help='will freeze the first N layers and unfreeze the rest')
    return parser.parse_args(argv)
def init():
    """Prepare process-wide state before training starts.

    Calls, in this order: util.lock(), util.set_img_format(),
    util.override_keras_directory_iterator_next(),
    util.set_classes_from_train_dir() and util.set_samples_info(); then makes
    sure config.trained_dir exists.

    NOTE(review): the effects of the util helpers are inferred from their
    names only (run lock, image data format, Keras iterator patch, class and
    sample discovery from the training data) -- confirm against util.
    """
    # Taken first; the script's __main__ block releases it with util.unlock().
    util.lock()
    util.set_img_format()
    util.override_keras_directory_iterator_next()
    util.set_classes_from_train_dir()
    util.set_samples_info()

    output_dir = config.trained_dir
    if os.path.exists(output_dir):
        return
    # os.mkdir creates only the last path component; the parent must exist.
    os.mkdir(output_dir)
def train(nb_epoch, freeze_layers_number):
    """Create the model wrapper via util and run its training.

    Args:
        nb_epoch: epoch count, forwarded unchanged to
            util.get_model_class_instance.
        freeze_layers_number: forwarded unchanged; per the CLI help it means
            "freeze the first N layers" and is None when not supplied.
    """
    # Class weights are derived from the training directory contents
    # (presumably to counter class imbalance -- see util.get_class_weight).
    weights_per_class = util.get_class_weight(config.train_dir)
    model = util.get_model_class_instance(
        class_weight=weights_per_class,
        freeze_layers_number=freeze_layers_number,
        nb_epoch=nb_epoch,
    )
    model.train()
    print('Training is finished!')
if __name__ == '__main__':
    import sys  # only needed here, to report failure via the exit status

    # Parse and apply CLI options before init() takes the lock. argparse
    # exits via SystemExit on --help or invalid options, and a failing
    # config.set_paths() happens pre-lock too; neither must call
    # util.unlock(), which could release a lock this process never acquired.
    args = parse_args()
    if args.data_dir:
        config.data_dir = args.data_dir
        config.set_paths()
    if args.model:
        config.model = args.model

    try:
        init()  # util.lock() is the first statement of init()
        train(args.nb_epoch, args.freeze_layers_number)
    except Exception as e:
        print(e)
        traceback.print_exc()
        # Previously the error was swallowed and the process exited with
        # status 0; exit non-zero so callers (shell, schedulers) see failure.
        # SystemExit still passes through the finally clause below.
        sys.exit(1)
    finally:
        # NOTE(review): this also runs if util.lock() itself raised inside
        # init(); confirm util.unlock() is safe in that case.
        util.unlock()