-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_LSTM_model.py
114 lines (96 loc) · 5.38 KB
/
train_LSTM_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import print_function, absolute_import
import sys
sys.path.insert(0,'.')
import argparse
from datetime import datetime
import os
from training.train_LSTM import main
# Default for --vector_folder. It stays None on purpose: the script refuses to
# run unless a folder is given on the command line.
VECTOR_FOLDER = None

# Training parameters (used as argparse defaults in get_arguments).
LEARNING_RATE_LIST = [1e-3, 1e-4, 1e-5]
NUM_TRAINING_STEPS = 10000
BATCH_SIZE = 3
# Timestamp taken once at import time, e.g. "2024-01-31T17-05-09"; it names
# the per-run model subfolder.
STARTED_DATESTRING = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
DISPLAY_STEP = 10
SAVE_EVERY = 500
CHECKPOINT_EVERY = 500
MAX_PATIENCE = 500
# Plateau detection parameters: (n_iterations, minimum_decrease_in_cost).
PLATEAU_TOL = (5000, 1e-4)
# Number of hidden features per LSTM layer (one list entry per layer).
N_HIDDEN = [1000]

# Base directory for saved models. When --model_folder is not given, a
# datestamped subfolder of this directory is used for the model.
MODEL_FOLDER = './training/saved_models'
def get_arguments(argv=None):
    """Build the command-line parser for LSTM training and parse the options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        ``plateau_tol`` defaults to the values of ``PLATEAU_TOL`` when the
        flag is not given.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """

    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        value = s.strip().lower()
        if value in ('true', 't', 'yes', 'y', '1'):
            return True
        if value in ('false', 'f', 'no', 'n', '0'):
            return False
        # ArgumentTypeError makes argparse print this exact message.
        raise argparse.ArgumentTypeError(
            'Argument needs to be a boolean, got {}'.format(s))

    parser = argparse.ArgumentParser(description='LSTM audio synth model')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='How many wav files to process at once.')
    parser.add_argument('--vector_folder', type=str, default=VECTOR_FOLDER,
                        help='The directory containing the vectorised data.')
    # type=bool would call bool() on the raw string, so "--store_metadata
    # False" would still evaluate to True; parse the text explicitly instead.
    parser.add_argument('--store_metadata', type=_str_to_bool, default=False,
                        help='Whether to store advanced debugging information '
                        '(execution time, memory consumption) for use with '
                        'TensorBoard.')
    parser.add_argument('--model_folder', type=str, default=None,
                        help='Directory in which to restore the model from. '
                        'If omitted, a datestamped subfolder of the default '
                        'model directory is used.')
    parser.add_argument('--checkpoint_every', type=int,
                        default=CHECKPOINT_EVERY,
                        help='How many steps to save each checkpoint after')
    parser.add_argument('--num_training_steps', type=int,
                        default=NUM_TRAINING_STEPS,
                        help='Number of training steps.')
    parser.add_argument('--learning_rates', default=LEARNING_RATE_LIST,
                        type=float, nargs='+',
                        help='Learning rate list for training.')
    parser.add_argument('--lstm_hidden_units', default=N_HIDDEN, type=int,
                        nargs='+',
                        help='Number of hidden units in each LSTM layer')
    parser.add_argument('--display_step', type=int, default=DISPLAY_STEP,
                        help='How often to display training progress and save '
                        'model.')
    parser.add_argument('--grad_clip', type=float, default=5.,
                        help='Gradient clipping')
    parser.add_argument('--save_every', type=int, default=SAVE_EVERY,
                        help='How often to save the model.')
    parser.add_argument('--max_patience', type=int, default=MAX_PATIENCE,
                        help='Maximum number of iterations for patience.')
    # PLATEAU_TOL is the module-level default for this option; it is copied
    # into a list so the attribute has the same container type as a value
    # parsed from the command line (nargs='+' yields a list).
    parser.add_argument('--plateau_tol', type=float, nargs='+',
                        default=list(PLATEAU_TOL),
                        help='Plateau tolerance. Number of iterations and '
                        'minimum cost decrease.')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the options, validate/normalise the folders,
    # then hand the namespace over to the training routine.
    args = get_arguments()

    # The vectorised data folder is mandatory and must already exist.
    vector_dir = args.vector_folder
    if vector_dir is None:
        raise Exception('No vectors folder specified. (Use --vector_folder argument)')
    if not os.path.exists(vector_dir):
        raise Exception('{} not found'.format(vector_dir))
    # Drop one trailing slash so the folder is passed on without it.
    args.vector_folder = vector_dir[:-1] if vector_dir.endswith('/') else vector_dir

    # Without an explicit model folder, fall back to a subfolder of
    # MODEL_FOLDER named after the start timestamp of this run.
    if args.model_folder is None:
        args.model_folder = '/'.join([MODEL_FOLDER, STARTED_DATESTRING])

    main(args)