-
Notifications
You must be signed in to change notification settings - Fork 0
/
keras_utils.py
58 lines (50 loc) · 2.38 KB
/
keras_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import numpy as np
from keras import backend as K
from keras.legacy.interfaces import generate_legacy_interface, recurrent_args_preprocessor
from keras.models import model_from_json
# Legacy-argument adapter built with Keras' `generate_legacy_interface`.
# It lists which arguments may be passed positionally, which old keyword
# names map to new ones, and how the old values of those keywords translate.
# NOTE(review): presumably applied as a decorator to PredNet's constructor so
# Keras 1.x-style arguments keep working -- confirm against prednet.py.
legacy_prednet_support = generate_legacy_interface(
    allowed_positional_args=[
        'stack_sizes',
        'R_stack_sizes',
        'A_filt_sizes',
        'Ahat_filt_sizes',
        'R_filt_sizes',
    ],
    # (old keyword name, new keyword name)
    conversions=[
        ('dim_ordering', 'data_format'),
        ('consume_less', 'implementation'),
    ],
    # Per old keyword: old value -> new value.
    value_conversions={
        'dim_ordering': dict(tf='channels_last',
                             th='channels_first',
                             default=None),
        'consume_less': dict(cpu=0, mem=1, gpu=2),
    },
    preprocessor=recurrent_args_preprocessor,
)
# Convert old Keras (1.2) json models and weights to Keras 2.0
def convert_model_to_keras2(old_json_file, old_weights_file, new_json_file, new_weights_file):
    """Load a Keras 1.2 PredNet model and re-save it in Keras 2.0 format.

    The model architecture is read from ``old_json_file`` and its weights
    from ``old_weights_file``; the converted weights are written to
    ``new_weights_file`` and the re-serialized architecture to
    ``new_json_file``.

    Args:
        old_json_file: Path of the existing model json file to read.
        old_weights_file: Path of the existing weights file to load.
        new_json_file: Path the converted model json is written to.
        new_weights_file: Path the converted weights are written to.
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import with prednet.py
    # -- confirm.
    from prednet import PredNet

    # If using tensorflow, it doesn't allow you to load the old weights,
    # so switch the backend to Theano and reload the backend module.
    if K.backend() != 'theano':
        try:
            from importlib import reload  # Python 3.4+
        except ImportError:
            from imp import reload  # Python 2 fallback
        os.environ['KERAS_BACKEND'] = 'theano'
        reload(K)

    with open(old_json_file, 'r') as f:
        json_string = f.read()
    model = model_from_json(json_string, custom_objects={'PredNet': PredNet})
    model.load_weights(old_weights_file)

    # NOTE(review): layers[1] is presumably the PredNet layer (layers[0]
    # being the input layer) -- confirm against the model json.
    layer = model.layers[1]
    weights = layer.get_weights()
    # Only transpose when the first weight's leading axis equals
    # stack_sizes[1]; otherwise the weights are saved unchanged.
    # NOTE(review): this looks like a check for the old kernel layout
    # (output channels first) -- confirm.
    if weights[0].shape[0] == layer.stack_sizes[1]:
        for i, w in enumerate(weights):
            if w.ndim == 4:
                # Reorder the 4-D kernel axes to (2, 3, 1, 0).
                weights[i] = np.transpose(w, (2, 3, 1, 0))
        model.set_weights(weights)

    model.save_weights(new_weights_file)
    json_string = model.to_json()
    with open(new_json_file, "w") as f:
        f.write(json_string)
if __name__ == '__main__':
    # Convert the three KITTI PredNet model/weights pairs from ./model_data/
    # into ./model_data_keras2/, creating the output directory if needed.
    old_dir = './model_data/'
    new_dir = './model_data_keras2/'
    if not os.path.exists(new_dir):
        os.mkdir(new_dir)

    # (weights tag, model tag): the '-Lall' weights use the untagged model json.
    tag_pairs = [
        ('', ''),
        ('-Lall', ''),
        ('-extrapfinetuned', '-extrapfinetuned'),
    ]
    for w_tag, m_tag in tag_pairs:
        model_name = 'prednet_kitti_model' + m_tag + '.json'
        weights_name = 'prednet_kitti_weights' + w_tag + '.hdf5'
        convert_model_to_keras2(old_dir + model_name,
                                old_dir + weights_name,
                                new_dir + model_name,
                                new_dir + weights_name)