From 789d5d83f06e9c8c5e4f2e93af1d4c9eb2134d43 Mon Sep 17 00:00:00 2001 From: Antoine Hoorelbeke Date: Thu, 11 Jul 2019 14:12:56 +0200 Subject: [PATCH] [fix] handle exception when real dataset has not enough examples in evaluation --- compare_gan/eval_utils.py | 7 ++++++- compare_gan/runner_lib.py | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/compare_gan/eval_utils.py b/compare_gan/eval_utils.py index d0d8c34..9166dcf 100644 --- a/compare_gan/eval_utils.py +++ b/compare_gan/eval_utils.py @@ -53,6 +53,10 @@ class NanFoundError(Exception): """Exception thrown, when the Nans are present in the output.""" +class DatasetOutOfRangeError(Exception): + """Exception thrown, when the dataset has not enough samples.""" + + class EvalDataSample(object): """Helper class to hold images and Inception features for evaluation. @@ -127,11 +131,12 @@ def get_real_images(dataset, real_images[i] = b except tf.errors.OutOfRangeError: logging.error("Reached the end of dataset. Read: %d samples.", i) + real_images = real_images[:i] break if real_images.shape[0] != num_examples: if failure_on_insufficient_examples: - raise ValueError("Not enough examples in the dataset %s: %d / %d" % + raise DatasetOutOfRangeError("Not enough examples in the dataset %s: %d / %d" % (dataset, real_images.shape[0], num_examples)) else: logging.error("Not enough examples in the dataset %s: %d / %d", dataset, diff --git a/compare_gan/runner_lib.py b/compare_gan/runner_lib.py index 2aad6a3..3cac83a 100644 --- a/compare_gan/runner_lib.py +++ b/compare_gan/runner_lib.py @@ -28,6 +28,7 @@ from absl import logging from compare_gan import datasets from compare_gan import eval_gan_lib +from compare_gan.eval_utils import NanFoundError, DatasetOutOfRangeError from compare_gan import hooks from compare_gan.gans import utils from compare_gan.metrics import fid_score as fid_score_lib @@ -267,10 +268,13 @@ def _run_eval(module_spec, checkpoints, task_manager, run_config, result_dict = eval_gan_lib.evaluate_tfhub_module( export_path, eval_tasks, use_tpu=use_tpu, num_averaging_runs=num_averaging_runs) - except ValueError as nan_found_error: + except NanFoundError as nan_found_error: result_dict = {} logging.exception(nan_found_error) default_value = eval_gan_lib.NAN_DETECTED + except DatasetOutOfRangeError as dataset_out_of_range_error: + logging.exception(dataset_out_of_range_error) + break logging.info("Evaluation result for checkpoint %s: %s (default value: %s)", checkpoint_path, result_dict, default_value)