diff --git a/scripts/tf_cnn_benchmarks/preprocessing.py b/scripts/tf_cnn_benchmarks/preprocessing.py index a0fa209d..6c0eda14 100644 --- a/scripts/tf_cnn_benchmarks/preprocessing.py +++ b/scripts/tf_cnn_benchmarks/preprocessing.py @@ -858,24 +858,29 @@ def minibatch(self, subset, params, shift_ratio=-1): - # TODO(jsimsa): Implement datasets code path del shift_ratio, params with tf.name_scope('batch_processing'): all_images, all_labels = dataset.read_data_files(subset) all_images = tf.constant(all_images) all_labels = tf.constant(all_labels) - input_image, input_label = tf.train.slice_input_producer( - [all_images, all_labels]) - input_image = tf.cast(input_image, self.dtype) - input_label = tf.cast(input_label, tf.int32) - # Ensure that the random shuffling has good mixing properties. + input_image = tf.cast(all_images, self.dtype) + input_label = tf.cast(all_labels, tf.int32) + dataset_train = tf.data.Dataset.from_tensor_slices( + (input_image, input_label)) + min_fraction_of_examples_in_queue = 0.4 min_queue_examples = int(dataset.num_examples_per_epoch(subset) * min_fraction_of_examples_in_queue) - raw_images, raw_labels = tf.train.shuffle_batch( - [input_image, input_label], batch_size=self.batch_size, - capacity=min_queue_examples + 3 * self.batch_size, - min_after_dequeue=min_queue_examples) + + dataset_train = dataset_train.shuffle(min_queue_examples).batch( + self.batch_size, drop_remainder=True) + + if tf.VERSION > "1.12": + raw_images, raw_labels = tf.compat.v1.data.make_one_shot_iterator( + dataset_train).get_next() + else: + raw_images, raw_labels = dataset_train.make_one_shot_iterator( + ).get_next() images = [[] for i in range(self.num_splits)] labels = [[] for i in range(self.num_splits)]