-
Notifications
You must be signed in to change notification settings - Fork 2
/
kitti_evaluate.py
112 lines (98 loc) · 4.34 KB
/
kitti_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
'''
Evaluate trained PredNet on KITTI sequences.
Calculates mean-squared error, mean absolute error and SSIM, and plots predictions.
'''
import os
import numpy as np
from six.moves import cPickle
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from keras import backend as K
from keras.models import Model, model_from_json
from keras.layers import Input, Dense, Flatten
from prednet import PredNet
from data_utils import SequenceGenerator
from kitti_settings import *
from skimage.measure import compare_ssim
from sklearn.metrics import mean_squared_error
# ---- Evaluation settings ----
n_plot = 40        # upper bound on how many test sequences get plotted
batch_size = 10    # batch size handed to test_model.predict
nt = 10            # number of timesteps per test sequence

# If kitti_settings also exports DATA_DIR, this assignment overrides it.
# NOTE(review): "Datases" looks like a typo of "Datasets" -- confirm that the
# directory really has this name on disk before changing the literal.
DATA_DIR = "/home/matin/Desktop/Datases/kitti"

# Trained model artefacts: architecture (JSON) and weights (HDF5).
json_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_model.json')
weights_file = os.path.join(WEIGHTS_DIR, 'prednet_kitti_weights.hdf5')

# Test data arrays consumed by SequenceGenerator.
test_sources = os.path.join(DATA_DIR, 'source_test.npy')
test_file = os.path.join(DATA_DIR, 'X_test.npy')
# Load trained model: architecture from the JSON file, weights from HDF5.
# The context manager guarantees the JSON file is closed even if read() raises.
with open(json_file, 'r') as f:
    json_string = f.read()
# PredNet is a custom layer, so it has to be registered for deserialisation.
train_model = model_from_json(json_string, custom_objects={'PredNet': PredNet})
train_model.load_weights(weights_file)
# Create the testing model: a PredNet layer carrying the trained weights but
# configured with output_mode='prediction', wrapped in a fresh Keras Model.
prednet_layer = train_model.layers[1]
layer_config = prednet_layer.get_config()
layer_config['output_mode'] = 'prediction'

# Use the 'data_format' entry when the config has one; otherwise fall back to
# 'dim_ordering' (presumably the key used by older saved configs -- confirm).
if 'data_format' in layer_config:
    data_format = layer_config['data_format']
else:
    data_format = layer_config['dim_ordering']

test_prednet = PredNet(weights=prednet_layer.get_weights(), **layer_config)

# Start from the training input shape without its batch axis and set the
# leading (time) axis to the evaluation sequence length.
train_input_shape = train_model.layers[0].batch_input_shape
input_shape = list(train_input_shape[1:])
input_shape[0] = nt

inputs = Input(shape=tuple(input_shape))
predictions = test_prednet(inputs)
test_model = Model(inputs=inputs, outputs=predictions)
# Materialise all test sequences and run the test model over them.
test_generator = SequenceGenerator(
    test_file,
    test_sources,
    nt,
    sequence_start_mode='unique',
    data_format=data_format
)
X_test = test_generator.create_all()
X_hat = test_model.predict(X_test, batch_size)

# For channels-first data, move axis 2 to the end so that both arrays are
# channels-last for the metric and plotting code below.
if data_format == 'channels_first':
    channels_last_axes = (0, 1, 3, 4, 2)
    X_test = np.transpose(X_test, channels_last_axes)
    X_hat = np.transpose(X_hat, channels_last_axes)
# Compare the PredNet predictions with the ground truth and with the
# "repeat the previous frame" baseline, then write the summary scores to
# prediction_scores.txt. The first timestep is excluded from every metric.
prediction_error = X_test[:, 1:] - X_hat[:, 1:]
mse_model = np.mean(prediction_error ** 2)
mae_model = np.mean(np.abs(prediction_error))
mse_prev = np.mean((X_test[:, :-1] - X_test[:, 1:]) ** 2)

# NOTE(review): the whole 5-D tensor is passed in a single call, so with
# multichannel=True only the last axis is treated as channels and the sample
# and time axes are handled like extra spatial dimensions by the 3-wide
# window. Confirm this is intended rather than the mean of per-frame SSIM.
# No data_range is passed either, so skimage derives it from the dtype --
# confirm that matches the real pixel range.
ssim = compare_ssim(X_test[:, 1:], X_hat[:, 1:], win_size=3, multichannel=True)

if not os.path.exists(RESULTS_SAVE_DIR):
    os.mkdir(RESULTS_SAVE_DIR)

# os.path.join keeps the file inside RESULTS_SAVE_DIR whether or not the
# setting ends with a path separator. Every score gets its own line; the
# previous version omitted the newline after two of the entries.
scores_path = os.path.join(RESULTS_SAVE_DIR, 'prediction_scores.txt')
with open(scores_path, 'w') as f:
    f.write("Model MSE: %f\n" % mse_model)
    f.write("Previous Frame MSE: %f\n" % mse_prev)
    f.write("SSIM: %f\n" % ssim)
    f.write("MAE: %f\n" % mae_model)
# Per-frame SSIM table: one row per test sequence, one column per timestep,
# saved as plain text to SSIM.txt in the current working directory.
n_sequences = X_test.shape[0]
s = np.empty((n_sequences, nt))
# np.ndindex walks (sequence, timestep) pairs in row-major order, i.e. the
# same order as two nested range() loops.
for seq_idx, frame_idx in np.ndindex(n_sequences, nt):
    actual_frame = X_test[seq_idx, frame_idx]
    predicted_frame = X_hat[seq_idx, frame_idx]
    s[seq_idx, frame_idx] = compare_ssim(actual_frame, predicted_frame, multichannel=True)
np.savetxt('SSIM.txt', s)
# Per-frame MSE table: one row per test sequence, one column per timestep,
# saved as plain text to MSE.txt in the current working directory.
n_sequences = X_test.shape[0]
m = np.empty((n_sequences, nt))
for seq_idx in range(n_sequences):
    for frame_idx in range(nt):
        frame_error = X_test[seq_idx, frame_idx] - X_hat[seq_idx, frame_idx]
        m[seq_idx, frame_idx] = np.mean(frame_error ** 2)
np.savetxt('MSE.txt', m)
# Plot some predictions: for up to n_plot randomly chosen test sequences,
# save a two-row figure with the actual frames on top and the predicted
# frames underneath, one column per timestep.
def _show_frame(grid_cell, frame, row_label=None):
    """Draw one frame into a grid cell with ticks and tick labels hidden."""
    plt.subplot(grid_cell)
    plt.imshow(frame, interpolation='none')
    # tick_params expects booleans here. The former 'off' strings are
    # non-empty and therefore truthy on current matplotlib, which leaves the
    # ticks and labels switched on instead of hiding them.
    plt.tick_params(axis='both', which='both',
                    bottom=False, top=False, left=False, right=False,
                    labelbottom=False, labelleft=False)
    if row_label is not None:
        plt.ylabel(row_label, fontsize=10)


aspect_ratio = float(X_hat.shape[2]) / X_hat.shape[3]
plt.figure(figsize=(nt, 2 * aspect_ratio))
gs = gridspec.GridSpec(2, nt)
gs.update(wspace=0., hspace=0.)

plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots/')
if not os.path.exists(plot_save_dir):
    os.mkdir(plot_save_dir)

# Random subset of sequence indices; slicing caps it at n_plot entries.
plot_idx = np.random.permutation(X_test.shape[0])[:n_plot]
for i in plot_idx:
    for t in range(nt):
        # Only the first column carries the row label.
        _show_frame(gs[t], X_test[i, t], 'Actual' if t == 0 else None)
        _show_frame(gs[t + nt], X_hat[i, t], 'Predicted' if t == 0 else None)

    plt.savefig(os.path.join(plot_save_dir, 'plot_' + str(i) + '.png'))
    # Clear the shared figure so the next sequence starts from a blank canvas.
    plt.clf()