Added example for prediction #11

Open · wants to merge 1 commit into base: master
68 changes: 51 additions & 17 deletions tensorflow/vision/model/input_fn.py
@@ -23,6 +23,26 @@ def _parse_function(filename, label, size):
return resized_image, label


def _parse_function_pred(filename, size):
"""Obtain the image from the filename (for prediction only).

The following operations are applied:
- Decode the image from jpeg format
- Convert to float and to range [0, 1]
"""
image_string = tf.read_file(filename)

# Don't use tf.image.decode_image, or the output shape will be undefined
image_decoded = tf.image.decode_jpeg(image_string, channels=3)

# This will convert to float values in [0, 1]
image = tf.image.convert_image_dtype(image_decoded, tf.float32)

resized_image = tf.image.resize_images(image, [size, size])

return resized_image


def train_preprocess(image, label, use_random_flip):
"""Image preprocessing for training.

@@ -42,45 +62,59 @@ def train_preprocess(image, label, use_random_flip):
return image, label


def input_fn(mode, filenames, labels, params):
"""Input function for the SIGNS dataset.

The filenames have format "{label}_IMG_{id}.jpg".
For instance: "data_dir/2_IMG_4584.jpg".

Args:
        mode: (tf.estimator.ModeKeys) Mode to choose between TRAIN, EVAL, and PREDICT pipelines.
            At training, we shuffle the data and have multiple epochs.
            At prediction, labels are not considered.
filenames: (list) filenames of the images, as ["data_dir/{label}_IMG_{id}.jpg"...]
labels: (list) corresponding list of labels
params: (Params) contains hyperparameters of the model (ex: `params.num_epochs`)
"""
num_samples = len(filenames)

# Create a Dataset serving batches of images and labels
# We don't repeat for multiple epochs because we always train and evaluate for one epoch
parse_fn = lambda f, l: _parse_function(f, l, params.image_size)
train_fn = lambda f, l: train_preprocess(f, l, params.use_random_flip)
parse_fn_pred = lambda f: _parse_function_pred(f, params.image_size)

    if mode == tf.estimator.ModeKeys.TRAIN:
        assert len(filenames) == len(labels), "Filenames and labels should have same length"
        dataset = (tf.data.Dataset.from_tensor_slices((tf.constant(filenames), tf.constant(labels)))
            .shuffle(num_samples)  # whole dataset into the buffer ensures good shuffling
            .map(parse_fn, num_parallel_calls=params.num_parallel_calls)
            .map(train_fn, num_parallel_calls=params.num_parallel_calls)
            .batch(params.batch_size)
            .prefetch(1)  # make sure you always have one batch ready to serve
        )
    elif mode == tf.estimator.ModeKeys.EVAL:
        assert len(filenames) == len(labels), "Filenames and labels should have same length"
        dataset = (tf.data.Dataset.from_tensor_slices((tf.constant(filenames), tf.constant(labels)))
            .map(parse_fn)
            .batch(params.batch_size)
            .prefetch(1)  # make sure you always have one batch ready to serve
        )
    elif mode == tf.estimator.ModeKeys.PREDICT:
        dataset = (tf.data.Dataset.from_tensor_slices(tf.constant(filenames))
            .map(parse_fn_pred)
            .batch(params.batch_size)
            .prefetch(1)  # make sure you always have one batch ready to serve
        )
    else:
        assert False, "Unknown mode"

    # Create reinitializable iterator from dataset
    iterator = dataset.make_initializable_iterator()
    if mode == tf.estimator.ModeKeys.PREDICT:
        images = iterator.get_next()  # `labels` keeps the value passed in (None at prediction)
    else:
        images, labels = iterator.get_next()
iterator_init_op = iterator.initializer

inputs = {'images': images, 'labels': labels, 'iterator_init_op': iterator_init_op}
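
For context, here is a minimal sketch of how the new PREDICT branch can be driven end to end; the config path and image filename below are illustrative assumptions, not part of this PR:

# Sketch only: assumes a params.json defining image_size, batch_size, num_parallel_calls.
import tensorflow as tf

from model.input_fn import input_fn
from model.utils import Params

params = Params('experiments/test/params.json')             # assumed config location
filenames = ['data/64x64_SIGNS/test_signs/0_IMG_0001.jpg']  # assumed sample image

# In PREDICT mode the labels argument can be None: the pipeline never reads it.
inputs = input_fn(tf.estimator.ModeKeys.PREDICT, filenames, None, params)

with tf.Session() as sess:
    sess.run(inputs['iterator_init_op'])
    images = sess.run(inputs['images'])  # one batch of resized float images in [0, 1]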
22 changes: 13 additions & 9 deletions tensorflow/vision/model/model_fn.py
@@ -51,7 +51,7 @@ def model_fn(mode, inputs, params, reuse=False):
"""Model function defining the graph operations.

Args:
        mode: (tf.estimator.ModeKeys) Mode to choose between TRAIN, EVAL, and PREDICT pipelines.
inputs: (dict) contains the inputs of the graph (features, labels...)
this can be `tf.placeholder` or outputs of `tf.data`
params: (Params) contains hyperparameters of the model (ex: `params.learning_rate`)
@@ -60,16 +60,23 @@ def model_fn(mode, inputs, params, reuse=False):
Returns:
model_spec: (dict) contains the graph operations or nodes needed for training / evaluation
"""
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    model_spec = inputs

# -----------------------------------------------------------
# MODEL: define the layers of the model
with tf.variable_scope('model', reuse=reuse):
# Compute the output distribution of the model and the predictions
logits = build_model(is_training, inputs, params)
predictions = tf.argmax(logits, 1)
model_spec["predictions"] = predictions

if mode == tf.estimator.ModeKeys.PREDICT:
model_spec['variable_init_op'] = tf.global_variables_initializer()
return model_spec

labels = inputs['labels']
labels = tf.cast(labels, tf.int64)

# Define loss and accuracy
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
@@ -86,7 +93,6 @@ else:
else:
train_op = optimizer.minimize(loss, global_step=global_step)


# -----------------------------------------------------------
# METRICS AND SUMMARIES
# Metrics for evaluation using tf.metrics (average over whole dataset)
@@ -108,7 +114,7 @@
tf.summary.scalar('accuracy', accuracy)
tf.summary.image('train_image', inputs['images'])

    # TODO: if mode == tf.estimator.ModeKeys.EVAL: ?
# Add incorrectly labeled images
mask = tf.not_equal(labels, predictions)

@@ -122,15 +128,13 @@
# MODEL SPECIFICATION
# Create the model specification and return it
# It contains nodes or operations in the graph that will be used for training and evaluation
model_spec['loss'] = loss
model_spec['accuracy'] = accuracy
model_spec['metrics_init_op'] = metrics_init_op
model_spec['metrics'] = metrics
model_spec['update_metrics'] = update_metrics_op
model_spec['summary_op'] = tf.summary.merge_all()
model_spec['variable_init_op'] = tf.global_variables_initializer()

if is_training:
model_spec['train_op'] = train_op
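
To make the early return concrete, here is a sketch of the PREDICT path through model_fn; the config path and filename are assumptions for illustration:

# Sketch only: builds the PREDICT graph and fetches class ids (with untrained weights).
import tensorflow as tf

from model.input_fn import input_fn
from model.model_fn import model_fn
from model.utils import Params

params = Params('experiments/test/params.json')  # assumed config location
inputs = input_fn(tf.estimator.ModeKeys.PREDICT, ['0_IMG_0001.jpg'], None, params)  # assumed file
model_spec = model_fn(tf.estimator.ModeKeys.PREDICT, inputs, params)

# The early return means model_spec only gained 'predictions' and 'variable_init_op';
# loss, metrics and summaries are built for TRAIN and EVAL only.
with tf.Session() as sess:
    sess.run(model_spec['variable_init_op'])  # in practice, restore a checkpoint instead
    sess.run(model_spec['iterator_init_op'])
    print(sess.run(model_spec['predictions']))  # argmax class ids, shape (batch_size,)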
60 changes: 60 additions & 0 deletions tensorflow/vision/model/prediction.py
@@ -0,0 +1,60 @@
"""Tensorflow utility functions for evaluation"""

import os

import tensorflow as tf


def predict_sess(sess, model_spec, num_steps, writer=None, params=None):
"""Train the model on `num_steps` batches.

Args:
sess: (tf.Session) current session
model_spec: (dict) contains the graph operations or nodes needed for training
num_steps: (int) train for this number of batches
writer: (tf.summary.FileWriter) writer for summaries. Is None if we don't log anything
params: (Params) hyperparameters
"""
# Load the prediction dataset into the pipeline and initialize the metrics init op
sess.run(model_spec['iterator_init_op'])

# compute predictions over the dataset
all_pred = []
for _ in range(num_steps):
all_pred.extend(sess.run(model_spec['predictions']))

return all_pred


def prediction(model_spec, model_dir, params, restore_from):
"""Evaluate the model

Args:
model_spec: (dict) contains the graph operations or nodes needed for prediction
model_dir: (string) directory containing config, weights and log
        params: (Params) contains hyperparameters of the model.
                Must define: batch_size, eval_size
restore_from: (string) directory or file containing weights to restore the graph
"""
# Initialize tf.Saver
saver = tf.train.Saver()

with tf.Session() as sess:
        # Initialize the model variables
sess.run(model_spec['variable_init_op'])

# Reload weights from the weights subdirectory
save_path = os.path.join(model_dir, restore_from)
if os.path.isdir(save_path):
save_path = tf.train.latest_checkpoint(save_path)
saver.restore(sess, save_path)

        # Compute predictions
num_steps = (params.eval_size + params.batch_size - 1) // params.batch_size
pred = predict_sess(sess, model_spec, num_steps)

return pred
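
One detail worth noting in prediction(): num_steps is a ceiling division, so the final partial batch is still consumed. A quick check of the arithmetic with illustrative numbers:

# Ceiling division: every image is visited even when eval_size % batch_size != 0.
eval_size, batch_size = 107, 32
num_steps = (eval_size + batch_size - 1) // batch_size
assert num_steps == 4  # three full batches of 32 plus one final batch of 11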
6 changes: 2 additions & 4 deletions tensorflow/vision/model/training.py
@@ -38,16 +38,14 @@ def train_sess(sess, model_spec, num_steps, writer, params):
# Evaluate summaries for tensorboard only once in a while
if i % params.save_summary_steps == 0:
# Perform a mini-batch update
            _, _, loss_val, summ, global_step_val = sess.run([train_op, update_metrics, loss, summary_op, global_step])
# Write summaries for tensorboard
writer.add_summary(summ, global_step_val)
else:
_, _, loss_val = sess.run([train_op, update_metrics, loss])
# Log the loss in the tqdm progress bar
t.set_postfix(loss='{:05.3f}'.format(loss_val))


metrics_values = {k: v[0] for k, v in metrics.items()}
metrics_val = sess.run(metrics_values)
metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_val.items())
@@ -66,7 +64,7 @@ def train_and_evaluate(train_model_spec, eval_model_spec, model_dir, params, res
restore_from: (string) directory or file containing weights to restore the graph
"""
# Initialize tf.Saver instances to save weights during training
    last_saver = tf.train.Saver()  # will keep last 5 epochs
best_saver = tf.train.Saver(max_to_keep=1) # only keep 1 best checkpoint (best on eval)
begin_at_epoch = 0

67 changes: 67 additions & 0 deletions tensorflow/vision/predict.py
@@ -0,0 +1,67 @@
"""Evaluate the model"""

import argparse
import logging
import os

import tensorflow as tf

from model.input_fn import input_fn
from model.model_fn import model_fn
from model.prediction import prediction
from model.utils import Params
from model.utils import set_logger


parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', default='experiments/test',
help="Experiment directory containing params.json")
parser.add_argument('--data_dir', default='data/64x64_SIGNS',
help="Directory containing the dataset")
parser.add_argument('--restore_from', default='best_weights',
help="Subdirectory of model dir or file containing the weights")


if __name__ == '__main__':
# Set the random seed for the whole graph
tf.set_random_seed(230)

# Load the parameters
args = parser.parse_args()
json_path = os.path.join(args.model_dir, 'params.json')
assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
params = Params(json_path)

# Set the logger
set_logger(os.path.join(args.model_dir, 'prediction.log'))

# Create the input data pipeline
logging.info("Creating the dataset...")
data_dir = args.data_dir
test_data_dir = os.path.join(data_dir, "test_signs")

# Get the filenames from the test set
test_filenames = os.listdir(test_data_dir)
test_filenames = [os.path.join(test_data_dir, f) for f in test_filenames if f.endswith('.jpg')]

test_labels = [int(f.split('/')[-1][0]) for f in test_filenames]

    # Specify the size of the prediction set (stored in the eval_size field)
params.eval_size = len(test_filenames)
logging.info('Read {} image filenames for prediction.'.format(params.eval_size))

# create the iterator over the dataset
test_inputs = input_fn(tf.estimator.ModeKeys.PREDICT, test_filenames, None, params)

# Define the model
logging.info("Creating the model...")
model_spec = model_fn(tf.estimator.ModeKeys.PREDICT, test_inputs, params, reuse=False)

    logging.info("Starting prediction")
pred = prediction(model_spec, args.model_dir, params, args.restore_from)

    for f, p, l in zip(test_filenames, pred, test_labels):
        if p == l:
            print('{}: label {} predicted as {}... correct!'.format(f, l, p))
        else:
            print('{}: label {} predicted as {}... incorrect!'.format(f, l, p))
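
Since the loop above already pairs each prediction with the label parsed from the filename, a natural extension (a sketch, not part of this PR) would be to log overall accuracy at the end:

    # Sketch: aggregate accuracy over the per-file results printed above.
    num_correct = sum(int(p == l) for p, l in zip(pred, test_labels))
    accuracy = num_correct / float(len(test_labels))  # float division on Python 2 and 3
    logging.info('Prediction accuracy: {:.3f} ({}/{})'.format(accuracy, num_correct, len(test_labels)))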
14 changes: 6 additions & 8 deletions tensorflow/vision/train.py
@@ -50,10 +50,8 @@
dev_data_dir = os.path.join(data_dir, "dev_signs")

# Get the filenames from the train and dev sets
    train_filenames = [os.path.join(train_data_dir, f) for f in os.listdir(train_data_dir) if f.endswith('.jpg')]
    eval_filenames = [os.path.join(dev_data_dir, f) for f in os.listdir(dev_data_dir) if f.endswith('.jpg')]

# Labels will be between 0 and 5 included (6 classes in total)
train_labels = [int(f.split('/')[-1][0]) for f in train_filenames]
@@ -64,13 +62,13 @@
params.eval_size = len(eval_filenames)

# Create the two iterators over the two datasets
    train_inputs = input_fn(tf.estimator.ModeKeys.TRAIN, train_filenames, train_labels, params)
    eval_inputs = input_fn(tf.estimator.ModeKeys.EVAL, eval_filenames, eval_labels, params)

# Define the model
logging.info("Creating the model...")
    train_model_spec = model_fn(tf.estimator.ModeKeys.TRAIN, train_inputs, params)
    eval_model_spec = model_fn(tf.estimator.ModeKeys.EVAL, eval_inputs, params, reuse=True)

# Train the model
logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
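
A migration note behind these train.py changes: in TF 1.x the tf.estimator.ModeKeys members are plain strings, so the old 'train' and 'eval' arguments compare equal to the first two modes, while PREDICT is the string 'infer':

import tensorflow as tf

# ModeKeys values in TF 1.x, so old string-based callers keep working for TRAIN/EVAL.
assert tf.estimator.ModeKeys.TRAIN == 'train'
assert tf.estimator.ModeKeys.EVAL == 'eval'
assert tf.estimator.ModeKeys.PREDICT == 'infer'  # note: not 'predict'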