-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_execution.py
59 lines (49 loc) · 1.83 KB
/
main_execution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
from constants import external_drive_location
import callback as callback_file
from split_data_training_test import split_data_training_test as split_data
from categorize_data import categorize_data as categorize
from display_images import display_images as display
from data_generators import train_generator, validation_generator
from model import my_model
'''
split = split_data(data_directory)
split.move_training_to(training_directory)
print("Data split into training directory")
split.move_test_to(testing_directory)
print("Data split into test directory")
categories = categorize(data)
categories.in_directory(training_directory)
print("Data categorized in the training directory")
categories.in_directory(testing_directory)
print("Data categorized in the testing directory")
display("buffalo", training_directory).numberOfTimes(4)
'''
# Fit the model using the training and validation generators, specify the number of
# epochs to train for and how descriptive the output should be.
# `history` stays None when training does not run to completion, so the
# plotting code below can tell whether there is anything to plot.
history = None
try:
    history = my_model.fit_generator(
        train_generator,
        validation_data=validation_generator,
        steps_per_epoch=100,
        epochs=20,
        validation_steps=50,
        verbose=1,
        callbacks=[callback_file.myCallback()])
except (KeyboardInterrupt, Exception) as exc:
    # KeyboardInterrupt (Ctrl+C) is not an Exception subclass, so it is listed
    # explicitly. SystemExit / GeneratorExit are deliberately NOT caught, which
    # a bare `except:` would have done.
    print("Training interrupted, saving model")
    print("Cause: {!r}".format(exc))
    # Best-effort rescue of whatever weights the model holds at this point.
    my_model.save_weights(external_drive_location + "\\V3_Parthiv_Attempt_1_Interrupt.h5")

# Only plot when fit_generator returned; after an interruption there is no
# History object and reading it would raise NameError/AttributeError.
if history is not None:
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(len(acc))

    # Figure 1: accuracy per epoch.
    plt.plot(epochs, acc, 'r', label='Training accuracy')
    plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    plt.legend(loc=0)

    # Figure 2: loss per epoch. Previously this figure was opened but left
    # blank while `loss` / `val_loss` were fetched and never used.
    plt.figure()
    plt.plot(epochs, loss, 'r', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend(loc=0)

    plt.show()