Skip to content
This repository has been archived by the owner on Jan 10, 2023. It is now read-only.

Commit

Permalink
[fix] handle exception when real dataset has not enough examples in e…
Browse files Browse the repository at this point in the history
…valuation
  • Loading branch information
Antoine Hoorelbeke committed Jul 11, 2019
1 parent 068edab commit 62a7cc2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 6 additions & 1 deletion compare_gan/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ class NanFoundError(Exception):
"""Exception thrown, when the Nans are present in the output."""


class DatasetOutOfRangeError(Exception):
"""Exception thrown, when the dataset has not enough samples."""


class EvalDataSample(object):
"""Helper class to hold images and Inception features for evaluation.
Expand Down Expand Up @@ -127,11 +131,12 @@ def get_real_images(dataset,
real_images[i] = b
except tf.errors.OutOfRangeError:
logging.error("Reached the end of dataset. Read: %d samples.", i)
real_images = real_images[:i]
break

if real_images.shape[0] != num_examples:
if failure_on_insufficient_examples:
raise ValueError("Not enough examples in the dataset %s: %d / %d" %
raise DatasetOutOfRangeError("Not enough examples in the dataset %s: %d / %d" %
(dataset, real_images.shape[0], num_examples))
else:
logging.error("Not enough examples in the dataset %s: %d / %d", dataset,
Expand Down
6 changes: 5 additions & 1 deletion compare_gan/runner_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from absl import logging
from compare_gan import datasets
from compare_gan import eval_gan_lib
from compare_gan.eval_utils import NanFoundError, DatasetOutOfRangeError
from compare_gan import hooks
from compare_gan.gans import utils
from compare_gan.metrics import fid_score as fid_score_lib
Expand Down Expand Up @@ -267,10 +268,13 @@ def _run_eval(module_spec, checkpoints, task_manager, run_config,
result_dict = eval_gan_lib.evaluate_tfhub_module(
export_path, eval_tasks, use_tpu=use_tpu,
num_averaging_runs=num_averaging_runs)
except ValueError as nan_found_error:
except NanFoundError as nan_found_error:
result_dict = {}
logging.exception(nan_found_error)
default_value = eval_gan_lib.NAN_DETECTED
except DatasetOutOfRangeError as dataset_out_of_range_error:
logging.exception(dataset_out_of_range_error)
break

logging.info("Evaluation result for checkpoint %s: %s (default value: %s)",
checkpoint_path, result_dict, default_value)
Expand Down

0 comments on commit 62a7cc2

Please sign in to comment.