# train.py
# Forked from rdevooght/sequence-based-recommendations.
from __future__ import print_function
import sys
import numpy as np
import helpers.command_parser as parse
from helpers import early_stopping
from helpers.data_handling import DataHandler
def training_command_parser(parser):
    """Register the training-related command-line options on ``parser``.

    ``parser`` must expose an argparse-style ``add_argument`` method. The
    options cover sequence shuffling, the extended training set, dataset and
    model directories, the model-saving policy, validation metrics, progress
    reporting and the iteration/time limits. Nothing is returned; ``parser``
    is modified in place.
    """
    # (flags, keyword options) pairs, registered in this exact order.
    option_specs = (
        (('--tshuffle',),
         dict(action='store_true', help='Shuffle sequences during training.')),
        (('--extended_set',),
         dict(action='store_true',
              help='Use extended training set (contains first half of validation and test set).')),
        (('-d',),
         dict(dest='dataset', type=str, default='', help='Directory name of the dataset.')),
        (('--dir',),
         dict(type=str, default='', help='Directory name to save model.')),
        (('--save',),
         dict(choices=['All', 'Best', 'None'], default='Best', help='Policy for saving models.')),
        (('--metrics',),
         dict(type=str, default='sps', help='Metrics for validation, comma separated')),
        (('--time_based_progress',),
         dict(action='store_true', help='Follow progress based on time rather than iterations.')),
        (('--load_last_model',),
         dict(action='store_true', help='Load Last model before starting training.')),
        # Kept as a string here; the caller converts it to int/float later.
        (('--progress',),
         dict(type=str, default='2.', help='Progress intervals')),
        (('--mpi',),
         dict(type=float, default=np.inf, help='Max progress intervals')),
        (('--max_iter',),
         dict(type=float, default=np.inf, help='Max number of iterations')),
        (('--max_time',),
         dict(type=float, default=np.inf, help='Max training time in seconds')),
        (('--min_iter',),
         dict(type=float, default=0., help='Min number of iterations before showing progress')),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
def num(s):
    """Convert ``s`` to a number, preferring ``int`` and falling back to ``float``.

    A ``ValueError`` raised by the ``float`` fallback (e.g. for non-numeric
    text) propagates to the caller.
    """
    try:
        value = int(s)
    except ValueError:
        # Not an integer literal (e.g. '2.' or '1e3'): parse it as a float.
        value = float(s)
    return value
def main():
    """Train a predictor using this script's preset options plus the command line.

    A preset list of options (dataset directory, saving policy, progress and
    stopping limits, RNN/GRU model settings, loss, metrics, feature flags) is
    inserted into ``sys.argv`` directly after the program name, i.e. *before*
    any options typed by the user. The arguments are then parsed through
    ``parse.command_parser``, a predictor and a ``DataHandler`` are built, and
    ``predictor.train`` is run.

    The presets used to be appended *after* the user's arguments. With
    argparse-style parsing the last occurrence of a value option wins, so
    anything passed explicitly on the command line (``-d``, ``--dir``,
    ``--loss``, ...) was silently replaced by the preset. Putting the presets
    first keeps them as defaults while letting explicit options override them.
    NOTE(review): this assumes ``parse.command_parser`` parses ``sys.argv``
    with argparse -- confirm in ``helpers.command_parser``.

    Side effect: ``sys.argv`` is modified in place.
    """
    preset_args = ['--tshuffle', '--load_last_model',  # '--extended_set',
                   '-d', 'datasets/',
                   '--save', 'Best',
                   '--progress', '200', '--mpi', '1000.0',
                   '--max_iter', '6000.0', '--max_time', '28800.0', '--min_iter', '100.0',
                   '--es_m', 'StopAfterN', '--es_n', '3',
                   '-m', 'RNN', '--r_t', 'GRU', '--r_l', '100-50',
                   '--u_m', 'rmsprop',
                   '--rf']
    # ####################################################
    # # for RNNCluster
    # preset_args.extend(['--dir', 'RNNCluster_',
    #                     '--metrics', 'recall,cluster_recall,sps,cluster_sps,ignored_items,assr',
    #                     '--loss', 'BPR', '--clusters', '10'])
    # ####################################################
    # for RNNOneHot
    preset_args.extend(['--dir', 'RNNOneHot_',
                        '--metrics', 'recall,sps',  # ,ndcg,item_coverage,user_coverage,blockbuster_share
                        '--loss', 'CCE'])
    # ####################################################
    # # for RNNMargin
    # preset_args.extend(['--dir', 'RNNMargin_',
    #                     '--metrics', 'recall,sps',
    #                     '--loss', 'logit'])
    # ####################################################
    # # for RNNSampling
    # preset_args.extend(['--dir', 'RNNSampling_',
    #                     '--metrics', 'recall,sps',
    #                     '--loss', 'BPR'])
    # ####################################################
    # # without MOVIES_FEATURES
    # preset_args.extend(['--r_emb', '100'])
    # # with MOVIES_FEATURES
    preset_args.extend(['--mf'])
    # ####################################################

    # Insert the presets between the program name and the user's own
    # arguments so that explicitly typed options come last and take effect.
    sys.argv[1:1] = preset_args

    args = parse.command_parser(parse.predictor_command_parser, training_command_parser,
                                early_stopping.early_stopping_command_parser)
    predictor = parse.get_predictor(args)

    dataset = DataHandler(dirname=args.dataset,
                          extended_training_set=args.extended_set,
                          shuffle_training=args.tshuffle)

    if args.mf:
        # Movie features are loaded from the dataset directory before the
        # model is prepared.
        predictor.load_movies_features(dirname=dataset.dirname)

    predictor.prepare_model(dataset)
    predictor.train(dataset,
                    save_dir=dataset.dirname + "models/" + args.dir,
                    time_based_progress=args.time_based_progress,
                    # --progress is parsed as a string; convert to int or float here.
                    progress=num(args.progress),
                    autosave=args.save,
                    max_progress_interval=args.mpi,
                    max_iter=args.max_iter,
                    min_iterations=args.min_iter,
                    max_time=args.max_time,
                    early_stopping=early_stopping.get_early_stopper(args),
                    load_last_model=args.load_last_model,
                    validation_metrics=args.metrics.split(','))
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()