From 31ae57eb8555e51493a4721f6efbea68d98a3d69 Mon Sep 17 00:00:00 2001
From: pkulzc
Date: Fri, 2 Nov 2018 08:48:34 -0700
Subject: [PATCH] Minor fixes for object detection (#5613)

* Internal change.

PiperOrigin-RevId: 213914693

* Add original_image_spatial_shape tensor in input dictionary to store the shape of the original input image.

PiperOrigin-RevId: 214018767

* Remove "groundtruth_confidences" from decoders; use "groundtruth_weights" to indicate label confidence. This also solves a bug that only surfaced now - random crop routines in core/preprocessor.py did not correctly handle "groundtruth_weight" tensors returned by the decoders.

PiperOrigin-RevId: 214091843

* Update CocoMaskEvaluator to allow for a batch of image info, rather than a single image.

PiperOrigin-RevId: 214295305

* Adding an option to summarize gradients.

PiperOrigin-RevId: 214310875

* Adds Faster R-CNN inference on CPU.
1. Adds a flag use_static_shapes_for_eval to restrict to ops that guarantee static shapes.
2. No filtering of overlapping anchors while clipping the anchors when use_static_shapes_for_eval is set to True.
3. Adds tests for faster_rcnn_meta_arch for predict and postprocess in inference mode for the first and second stages.

PiperOrigin-RevId: 214329565

* Fix model_lib eval_spec_names assignment (integer->string).

PiperOrigin-RevId: 214335461

* Refactor the mask head to optionally upsample after applying convolutions on ROI crops.

PiperOrigin-RevId: 214338440

* Uses final_exporter_name as exporter_name for the first eval spec for backward compatibility.

PiperOrigin-RevId: 214522032

* Add reshaped `mask_predictions` tensor to the prediction dictionary in the `_predict_third_stage` method to allow computing mask loss in the eval job.

PiperOrigin-RevId: 214620716

* Add support for fully convolutional training to FPN.

PiperOrigin-RevId: 214626274

* Fix the preprocess() function in ResNet v1 to make it work for any number of input channels. Note: if #channels != 3, this will simply skip the mean subtraction in the preprocess() function.

PiperOrigin-RevId: 214635428

* Wrap result_dict_for_single_example in eval_util to run for batched examples.

PiperOrigin-RevId: 214678514

* Adds PNASNet-based (ImageNet model) feature extractor for SSD.

PiperOrigin-RevId: 214988331

* Update documentation.

PiperOrigin-RevId: 215243502

* Correct the index used to compute the number of groundtruth/detection boxes in COCOMaskEvaluator. Due to incorrect indexing in cl/214295305, only the first detection mask and first groundtruth mask for a given image are fed to the COCO Mask evaluation library. Since groundtruth masks are arranged in no particular order, the first and highest-scoring detection mask (detection masks are ordered by score) won't match the first and only groundtruth retained in all cases. This is likely why mask evaluation metrics do not get better than ~11 mAP. Note that this code path is only active when using the model_main.py binary for evaluation. This change fixes the indices and modifies an existing test case to cover it.

PiperOrigin-RevId: 215275936

* Fixing grayscale_image_resizer to accept masks as input.

PiperOrigin-RevId: 215345836

* Add an option not to clip groundtruth boxes during preprocessing. Clipping boxes adversely affects training for partially occluded or large objects, especially for fully convolutional models. Clipping already occurs during postprocessing, and should not occur during training.

PiperOrigin-RevId: 215613379

* Always return recalls and precisions with length equal to the number of classes.
The previous behavior of ObjectDetectionEvaluation was somewhat dangerous: when no groundtruth boxes were present, the lists of per-class precisions and recalls were simply truncated. Unless you were aware of this phenomenon (and consulted the `num_gt_instances_per_class` vector), it was difficult to associate each metric with each class.

PiperOrigin-RevId: 215633711

* Expose the box feature node in SSD.

PiperOrigin-RevId: 215653316

* Fix ssd mobilenet v2 _CONV_DEFS overwriting issue.

PiperOrigin-RevId: 215654160

* More documentation updates.

PiperOrigin-RevId: 215656580

* Add pooling + residual option in multi_resolution_feature_maps. It adds an average pooling and a residual layer between feature maps with matching depth. Designed to be used with WeightSharedBoxPredictor.

PiperOrigin-RevId: 215665619

* Only call create_modified_mobilenet_config on init if use_depthwise is true.

PiperOrigin-RevId: 215784290

* Only call create_modified_mobilenet_config on init if use_depthwise is true.

PiperOrigin-RevId: 215837524

* Don't prune keypoints if clip_boxes is false.

PiperOrigin-RevId: 216187642

* Makes sure the "key" field exists in the result dictionary.

PiperOrigin-RevId: 216456543

* Add an add_background_class parameter to allow disabling the inclusion of a background class.

PiperOrigin-RevId: 216567612

* Update expected_classification_loss_under_sampling to better account for expected sampling.

PiperOrigin-RevId: 216712287

* Let the evaluation receive an evaluation class in its constructor.

PiperOrigin-RevId: 216769374

* This CL adds model building & training support for end-to-end Keras-based SSD models. If a Keras feature extractor's name is specified in the model config (e.g. 'ssd_mobilenet_v2_keras'), the model will use that feature extractor and a corresponding Keras-based box predictor. This CL makes sure regularization losses & batch norm updates work correctly when training models that have Keras-based components. It also updates the default hyperparameter settings of the Keras-based MobileNetV2 (when not overriding hyperparams) to more closely match the legacy Slim training scope.

PiperOrigin-RevId: 216938707

* Adding the ability in the COCO evaluator to indicate whether an image has been annotated. For a non-annotated image, detections and groundtruth are not supplied.

PiperOrigin-RevId: 217316342

* Release the 8k minival dataset IDs for MSCOCO, used in Huang et al., "Speed/accuracy trade-offs for modern convolutional object detectors" (https://arxiv.org/abs/1611.10012).

PiperOrigin-RevId: 217549353

* Exposes weighted_sigmoid_focal loss for the Faster R-CNN classifier.

PiperOrigin-RevId: 217601740

* Add detection_features to output nodes. The shape of the feature is [batch_size, max_detections, depth].

PiperOrigin-RevId: 217629905

* FPN uses a custom NN resize op for TPU compatibility. Replace this op with the TensorFlow version at export time for TFLite compatibility.

PiperOrigin-RevId: 217721184

* Compute `num_groundtruth_boxes` in inputs.transform_input_data_fn after data augmentation instead of in the decoders.

PiperOrigin-RevId: 217733432

* 1. Stop gradients from flowing into groundtruth masks with zero padding.
2. Normalize pixelwise cross-entropy loss across the whole batch.

PiperOrigin-RevId: 217735114

* Optimize the input pipeline for Mask R-CNN on TPU with bfloat16: improves step time from 1663.6 ms to 1184.2 ms, about a 28.8% improvement.

PiperOrigin-RevId: 217748833

* Fixes to export a TPU-compatible model. Adds nodes to each of the output tensors. Also increments the value of class labels by 1.
PiperOrigin-RevId: 217856760 * API changes: - change the interface of target assigner to return per-class weights. - change the interface of classification loss to take per-class weights. PiperOrigin-RevId: 217968393 * Add an option to override pipeline config in export_saved_model using command line arg PiperOrigin-RevId: 218429292 * Include Quantized trained MobileNet V2 SSD and FaceSsd in model zoo. PiperOrigin-RevId: 218530947 * Write final config to disk in `train` mode only. PiperOrigin-RevId: 218735512 --- .../multiscale_grid_anchor_generator.py | 12 +- .../multiscale_grid_anchor_generator_test.py | 18 +- .../builders/box_predictor_builder.py | 188 +- .../builders/box_predictor_builder_test.py | 53 +- .../builders/hyperparams_builder.py | 4 + .../builders/image_resizer_builder.py | 35 +- .../builders/losses_builder.py | 8 + .../builders/losses_builder_test.py | 15 +- .../builders/model_builder.py | 102 +- .../builders/model_builder_test.py | 87 +- .../builders/preprocessor_builder.py | 9 + .../builders/preprocessor_builder_test.py | 22 +- research/object_detection/core/losses.py | 26 +- research/object_detection/core/losses_test.py | 200 +- research/object_detection/core/matcher.py | 6 +- research/object_detection/core/model.py | 25 + .../object_detection/core/post_processing.py | 53 +- .../object_detection/core/preprocessor.py | 76 +- .../core/preprocessor_test.py | 269 +- .../object_detection/core/standard_fields.py | 1 - .../object_detection/core/target_assigner.py | 15 +- .../core/target_assigner_test.py | 137 +- .../data/face_label_map.pbtxt | 6 + .../data/mscoco_minival_ids.txt | 8059 +++++++++++++++++ .../data_decoders/tf_example_decoder.py | 4 - .../data_decoders/tf_example_decoder_test.py | 4 - research/object_detection/eval_util.py | 249 +- research/object_detection/eval_util_test.py | 96 +- .../export_tflite_ssd_graph_lib.py | 6 +- .../export_tflite_ssd_graph_lib_test.py | 29 + research/object_detection/exporter.py | 66 +- research/object_detection/exporter_test.py | 55 + .../g3doc/detection_model_zoo.md | 3 + .../object_detection/g3doc/installation.md | 4 +- research/object_detection/inputs.py | 8 +- research/object_detection/inputs_test.py | 3 + .../object_detection/legacy/trainer_test.py | 23 + .../faster_rcnn_meta_arch.py | 125 +- .../faster_rcnn_meta_arch_test.py | 5 +- .../faster_rcnn_meta_arch_test_lib.py | 462 +- .../meta_architectures/ssd_meta_arch.py | 168 +- .../meta_architectures/ssd_meta_arch_test.py | 21 +- .../ssd_meta_arch_test_lib.py | 10 +- .../metrics/coco_evaluation.py | 153 +- .../metrics/coco_evaluation_test.py | 243 +- research/object_detection/model_lib.py | 34 +- research/object_detection/model_lib_test.py | 2 +- research/object_detection/model_tpu_main.py | 1 + ...faster_rcnn_resnet_v1_feature_extractor.py | 9 +- .../models/feature_map_generators.py | 39 +- .../models/feature_map_generators_test.py | 39 +- .../models/keras_applications/mobilenet_v2.py | 7 + .../ssd_mobilenet_v1_fpn_feature_extractor.py | 20 +- .../ssd_mobilenet_v2_fpn_feature_extractor.py | 24 +- ...sd_mobilenet_v2_keras_feature_extractor.py | 61 +- .../models/ssd_pnasnet_feature_extractor.py | 175 + .../ssd_pnasnet_feature_extractor_test.py | 87 + .../ssd_resnet_v1_fpn_feature_extractor.py | 9 +- ...esnet_v1_fpn_feature_extractor_testbase.py | 11 +- .../ssd_resnet_v1_ppn_feature_extractor.py | 9 +- ...esnet_v1_ppn_feature_extractor_testbase.py | 10 +- .../convolutional_keras_box_predictor.py | 31 +- .../convolutional_keras_box_predictor_test.py | 4 +- 
.../predictors/heads/class_head.py | 60 +- .../predictors/heads/class_head_test.py | 12 +- .../predictors/heads/keras_box_head.py | 4 +- .../predictors/heads/keras_class_head.py | 14 +- .../predictors/heads/keras_class_head_test.py | 4 +- .../predictors/heads/keras_mask_head.py | 4 +- .../predictors/heads/mask_head.py | 34 +- .../predictors/heads/mask_head_test.py | 16 + .../protos/box_predictor.proto | 7 + .../object_detection/protos/faster_rcnn.proto | 4 + .../protos/preprocessor.proto | 21 + research/object_detection/protos/ssd.proto | 10 +- research/object_detection/protos/train.proto | 5 +- ..._v2_quantized_320x320_open_image_v4.config | 211 + ...mobilenet_v2_quantized_300x300_coco.config | 202 + .../object_detection/utils/config_util.py | 8 +- .../utils/object_detection_evaluation.py | 40 +- research/object_detection/utils/ops.py | 103 +- research/object_detection/utils/ops_test.py | 194 +- research/object_detection/utils/test_utils.py | 18 +- 83 files changed, 11727 insertions(+), 979 deletions(-) create mode 100644 research/object_detection/data/face_label_map.pbtxt create mode 100644 research/object_detection/data/mscoco_minival_ids.txt create mode 100644 research/object_detection/models/ssd_pnasnet_feature_extractor.py create mode 100644 research/object_detection/models/ssd_pnasnet_feature_extractor_test.py create mode 100644 research/object_detection/samples/configs/facessd_mobilenet_v2_quantized_320x320_open_image_v4.config create mode 100644 research/object_detection/samples/configs/ssd_mobilenet_v2_quantized_300x300_coco.config diff --git a/research/object_detection/anchor_generators/multiscale_grid_anchor_generator.py b/research/object_detection/anchor_generators/multiscale_grid_anchor_generator.py index c5afd547669..cd2440a462c 100644 --- a/research/object_detection/anchor_generators/multiscale_grid_anchor_generator.py +++ b/research/object_detection/anchor_generators/multiscale_grid_anchor_generator.py @@ -108,9 +108,6 @@ def _generate(self, feature_map_shape_list, im_height=1, im_width=1): ValueError: if im_height and im_width are 1, but normalized coordinates were requested. 
""" - if not isinstance(im_height, int) or not isinstance(im_width, int): - raise ValueError('MultiscaleGridAnchorGenerator currently requires ' - 'input image shape to be statically defined.') anchor_grid_list = [] for feat_shape, grid_info in zip(feature_map_shape_list, self._anchor_grid_info): @@ -122,10 +119,11 @@ def _generate(self, feature_map_shape_list, im_height=1, im_width=1): feat_h = feat_shape[0] feat_w = feat_shape[1] anchor_offset = [0, 0] - if im_height % 2.0**level == 0 or im_height == 1: - anchor_offset[0] = stride / 2.0 - if im_width % 2.0**level == 0 or im_width == 1: - anchor_offset[1] = stride / 2.0 + if isinstance(im_height, int) and isinstance(im_width, int): + if im_height % 2.0**level == 0 or im_height == 1: + anchor_offset[0] = stride / 2.0 + if im_width % 2.0**level == 0 or im_width == 1: + anchor_offset[1] = stride / 2.0 ag = grid_anchor_generator.GridAnchorGenerator( scales, aspect_ratios, diff --git a/research/object_detection/anchor_generators/multiscale_grid_anchor_generator_test.py b/research/object_detection/anchor_generators/multiscale_grid_anchor_generator_test.py index ed5c90ce70c..178705c1943 100644 --- a/research/object_detection/anchor_generators/multiscale_grid_anchor_generator_test.py +++ b/research/object_detection/anchor_generators/multiscale_grid_anchor_generator_test.py @@ -116,7 +116,7 @@ def test_num_anchors_per_location(self): normalize_coordinates=False) self.assertEqual(anchor_generator.num_anchors_per_location(), [6, 6]) - def test_construct_single_anchor_fails_with_tensor_image_size(self): + def test_construct_single_anchor_dynamic_size(self): min_level = 5 max_level = 5 anchor_scale = 4.0 @@ -125,12 +125,22 @@ def test_construct_single_anchor_fails_with_tensor_image_size(self): im_height = tf.constant(64) im_width = tf.constant(64) feature_map_shape_list = [(2, 2)] + # Zero offsets are used. + exp_anchor_corners = [[-64, -64, 64, 64], + [-64, -32, 64, 96], + [-32, -64, 96, 64], + [-32, -32, 96, 96]] + anchor_generator = mg.MultiscaleGridAnchorGenerator( min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave, normalize_coordinates=False) - with self.assertRaisesRegexp(ValueError, 'statically defined'): - anchor_generator.generate( - feature_map_shape_list, im_height=im_height, im_width=im_width) + anchors_list = anchor_generator.generate( + feature_map_shape_list, im_height=im_height, im_width=im_width) + anchor_corners = anchors_list[0].get() + + with self.test_session(): + anchor_corners_out = anchor_corners.eval() + self.assertAllClose(anchor_corners_out, exp_anchor_corners) def test_construct_single_anchor_with_odd_input_dimension(self): diff --git a/research/object_detection/builders/box_predictor_builder.py b/research/object_detection/builders/box_predictor_builder.py index 97e1ea77a0b..a3f4a846cfe 100644 --- a/research/object_detection/builders/box_predictor_builder.py +++ b/research/object_detection/builders/box_predictor_builder.py @@ -42,6 +42,7 @@ def build_convolutional_box_predictor(is_training, kernel_size, box_code_size, apply_sigmoid_to_scores=False, + add_background_class=True, class_prediction_bias_init=0.0, use_depthwise=False, mask_head_config=None): @@ -49,7 +50,10 @@ def build_convolutional_box_predictor(is_training, Args: is_training: Indicates whether the BoxPredictor is in training mode. - num_classes: Number of classes. + num_classes: number of classes. 
Note that num_classes *does not* + include the background category, so if groundtruth labels take values + in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the + assigned classification targets can range from {0,... K}). conv_hyperparams_fn: A function to generate tf-slim arg_scope with hyperparameters for convolution ops. min_depth: Minimum feature depth prior to predicting box encodings @@ -71,6 +75,7 @@ def build_convolutional_box_predictor(is_training, box_code_size: Size of encoding for each box. apply_sigmoid_to_scores: If True, apply the sigmoid on the output class_predictions. + add_background_class: Whether to add an implicit background class. class_prediction_bias_init: Constant value to initialize bias of the last conv2d layer before class prediction. use_depthwise: Whether to use depthwise convolutions for prediction @@ -88,7 +93,7 @@ def build_convolutional_box_predictor(is_training, use_depthwise=use_depthwise) class_prediction_head = class_head.ConvolutionalClassHead( is_training=is_training, - num_classes=num_classes, + num_class_slots=num_classes + 1 if add_background_class else num_classes, use_dropout=use_dropout, dropout_keep_prob=dropout_keep_prob, kernel_size=kernel_size, @@ -136,15 +141,19 @@ def build_convolutional_keras_box_predictor(is_training, dropout_keep_prob, kernel_size, box_code_size, + add_background_class=True, class_prediction_bias_init=0.0, use_depthwise=False, mask_head_config=None, name='BoxPredictor'): - """Builds the ConvolutionalBoxPredictor from the arguments. + """Builds the Keras ConvolutionalBoxPredictor from the arguments. Args: is_training: Indicates whether the BoxPredictor is in training mode. - num_classes: Number of classes. + num_classes: number of classes. Note that num_classes *does not* + include the background category, so if groundtruth labels take values + in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the + assigned classification targets can range from {0,... K}). conv_hyperparams: A `hyperparams_builder.KerasLayerHyperparams` object containing hyperparameters for convolution ops. freeze_batchnorm: Whether to freeze batch norm parameters during @@ -175,6 +184,7 @@ def build_convolutional_keras_box_predictor(is_training, then the kernel size is automatically set to be min(feature_width, feature_height). box_code_size: Size of encoding for each box. + add_background_class: Whether to add an implicit background class. class_prediction_bias_init: constant value to initialize bias of the last conv2d layer before class prediction. use_depthwise: Whether to use depthwise convolutions for prediction @@ -185,7 +195,7 @@ def build_convolutional_keras_box_predictor(is_training, will auto-generate one from the class name. Returns: - A ConvolutionalBoxPredictor class. + A Keras ConvolutionalBoxPredictor class. 
""" box_prediction_heads = [] class_prediction_heads = [] @@ -210,7 +220,8 @@ def build_convolutional_keras_box_predictor(is_training, class_prediction_heads.append( keras_class_head.ConvolutionalClassHead( is_training=is_training, - num_classes=num_classes, + num_class_slots=( + num_classes + 1 if add_background_class else num_classes), use_dropout=use_dropout, dropout_keep_prob=dropout_keep_prob, kernel_size=kernel_size, @@ -264,6 +275,7 @@ def build_weight_shared_convolutional_box_predictor( num_layers_before_predictor, box_code_size, kernel_size=3, + add_background_class=True, class_prediction_bias_init=0.0, use_dropout=False, dropout_keep_prob=0.8, @@ -288,6 +300,7 @@ def build_weight_shared_convolutional_box_predictor( the predictor. box_code_size: Size of encoding for each box. kernel_size: Size of final convolution kernel. + add_background_class: Whether to add an implicit background class. class_prediction_bias_init: constant value to initialize bias of the last conv2d layer before class prediction. use_dropout: Whether to apply dropout to class prediction head. @@ -313,7 +326,8 @@ class scores. box_encodings_clip_range=box_encodings_clip_range) class_prediction_head = ( class_head.WeightSharedConvolutionalClassHead( - num_classes=num_classes, + num_class_slots=( + num_classes + 1 if add_background_class else num_classes), kernel_size=kernel_size, class_prediction_bias_init=class_prediction_bias_init, use_dropout=use_dropout, @@ -355,6 +369,7 @@ def build_mask_rcnn_box_predictor(is_training, use_dropout, dropout_keep_prob, box_code_size, + add_background_class=True, share_box_across_classes=False, predict_instance_masks=False, conv_hyperparams_fn=None, @@ -362,40 +377,46 @@ def build_mask_rcnn_box_predictor(is_training, mask_width=14, mask_prediction_num_conv_layers=2, mask_prediction_conv_depth=256, - masks_are_class_agnostic=False): + masks_are_class_agnostic=False, + convolve_then_upsample_masks=False): """Builds and returns a MaskRCNNBoxPredictor class. Args: - is_training: Indicates whether the BoxPredictor is in training mode. - num_classes: number of classes. Note that num_classes *does not* - include the background category, so if groundtruth labels take values - in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the - assigned classification targets can range from {0,... K}). - fc_hyperparams_fn: A function to generate tf-slim arg_scope with - hyperparameters for fully connected ops. - use_dropout: Option to use dropout or not. Note that a single dropout - op is applied here prior to both box and class predictions, which stands - in contrast to the ConvolutionalBoxPredictor below. - dropout_keep_prob: Keep probability for dropout. - This is only used if use_dropout is True. - box_code_size: Size of encoding for each box. - share_box_across_classes: Whether to share boxes across classes rather - than use a different box for each class. - predict_instance_masks: If True, will add a third stage mask prediction - to the returned class. - conv_hyperparams_fn: A function to generate tf-slim arg_scope with - hyperparameters for convolution ops. - mask_height: Desired output mask height. The default value is 14. - mask_width: Desired output mask width. The default value is 14. - mask_prediction_num_conv_layers: Number of convolution layers applied to - the image_features in mask prediction branch. - mask_prediction_conv_depth: The depth for the first conv2d_transpose op - applied to the image_features in the mask prediction branch. 
If set - to 0, the depth of the convolution layers will be automatically chosen - based on the number of object classes and the number of channels in the - image features. - masks_are_class_agnostic: Boolean determining if the mask-head is - class-agnostic or not. + is_training: Indicates whether the BoxPredictor is in training mode. + num_classes: number of classes. Note that num_classes *does not* + include the background category, so if groundtruth labels take values + in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the + assigned classification targets can range from {0,... K}). + fc_hyperparams_fn: A function to generate tf-slim arg_scope with + hyperparameters for fully connected ops. + use_dropout: Option to use dropout or not. Note that a single dropout + op is applied here prior to both box and class predictions, which stands + in contrast to the ConvolutionalBoxPredictor below. + dropout_keep_prob: Keep probability for dropout. + This is only used if use_dropout is True. + box_code_size: Size of encoding for each box. + add_background_class: Whether to add an implicit background class. + share_box_across_classes: Whether to share boxes across classes rather + than use a different box for each class. + predict_instance_masks: If True, will add a third stage mask prediction + to the returned class. + conv_hyperparams_fn: A function to generate tf-slim arg_scope with + hyperparameters for convolution ops. + mask_height: Desired output mask height. The default value is 14. + mask_width: Desired output mask width. The default value is 14. + mask_prediction_num_conv_layers: Number of convolution layers applied to + the image_features in mask prediction branch. + mask_prediction_conv_depth: The depth for the first conv2d_transpose op + applied to the image_features in the mask prediction branch. If set + to 0, the depth of the convolution layers will be automatically chosen + based on the number of object classes and the number of channels in the + image features. + masks_are_class_agnostic: Boolean determining if the mask-head is + class-agnostic or not. + convolve_then_upsample_masks: Whether to apply convolutions on mask + features before upsampling using nearest neighbor resizing. Otherwise, + mask features are resized to [`mask_height`, `mask_width`] using + bilinear resizing before applying convolutions. Returns: A MaskRCNNBoxPredictor class. 
@@ -410,7 +431,7 @@ def build_mask_rcnn_box_predictor(is_training, share_box_across_classes=share_box_across_classes) class_prediction_head = class_head.MaskRCNNClassHead( is_training=is_training, - num_classes=num_classes, + num_class_slots=num_classes + 1 if add_background_class else num_classes, fc_hyperparams_fn=fc_hyperparams_fn, use_dropout=use_dropout, dropout_keep_prob=dropout_keep_prob) @@ -425,7 +446,8 @@ def build_mask_rcnn_box_predictor(is_training, mask_width=mask_width, mask_prediction_num_conv_layers=mask_prediction_num_conv_layers, mask_prediction_conv_depth=mask_prediction_conv_depth, - masks_are_class_agnostic=masks_are_class_agnostic) + masks_are_class_agnostic=masks_are_class_agnostic, + convolve_then_upsample=convolve_then_upsample_masks) return mask_rcnn_box_predictor.MaskRCNNBoxPredictor( is_training=is_training, num_classes=num_classes, @@ -464,7 +486,8 @@ def build_score_converter(score_converter_config, is_training): ['min', 'max']) -def build(argscope_fn, box_predictor_config, is_training, num_classes): +def build(argscope_fn, box_predictor_config, is_training, num_classes, + add_background_class=True): """Builds box predictor based on the configuration. Builds box predictor based on the configuration. See box_predictor.proto for @@ -479,6 +502,7 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): configuration. is_training: Whether the models is in training mode. num_classes: Number of classes to predict. + add_background_class: Whether to add an implicit background class. Returns: box_predictor: box_predictor.BoxPredictor object. @@ -502,6 +526,7 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): return build_convolutional_box_predictor( is_training=is_training, num_classes=num_classes, + add_background_class=add_background_class, conv_hyperparams_fn=conv_hyperparams_fn, use_dropout=config_box_predictor.use_dropout, dropout_keep_prob=config_box_predictor.dropout_keep_probability, @@ -542,6 +567,7 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): return build_weight_shared_convolutional_box_predictor( is_training=is_training, num_classes=num_classes, + add_background_class=add_background_class, conv_hyperparams_fn=conv_hyperparams_fn, depth=config_box_predictor.depth, num_layers_before_predictor=( @@ -570,6 +596,7 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): return build_mask_rcnn_box_predictor( is_training=is_training, num_classes=num_classes, + add_background_class=add_background_class, fc_hyperparams_fn=fc_hyperparams_fn, use_dropout=config_box_predictor.use_dropout, dropout_keep_prob=config_box_predictor.dropout_keep_probability, @@ -585,7 +612,9 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): mask_prediction_conv_depth=( config_box_predictor.mask_prediction_conv_depth), masks_are_class_agnostic=( - config_box_predictor.masks_are_class_agnostic)) + config_box_predictor.masks_are_class_agnostic), + convolve_then_upsample_masks=( + config_box_predictor.convolve_then_upsample_masks)) if box_predictor_oneof == 'rfcn_box_predictor': config_box_predictor = box_predictor_config.rfcn_box_predictor @@ -603,3 +632,78 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): box_code_size=config_box_predictor.box_code_size) return box_predictor_object raise ValueError('Unknown box predictor: {}'.format(box_predictor_oneof)) + + +def build_keras(conv_hyperparams_fn, freeze_batchnorm, inplace_batchnorm_update, + 
num_predictions_per_location_list, box_predictor_config, + is_training, num_classes, add_background_class=True): + """Builds a Keras-based box predictor based on the configuration. + + Builds Keras-based box predictor based on the configuration. + See box_predictor.proto for configurable options. Also, see box_predictor.py + for more details. + + Args: + conv_hyperparams_fn: A function that takes a hyperparams_pb2.Hyperparams + proto and returns a `hyperparams_builder.KerasLayerHyperparams` + for Conv or FC hyperparameters. + freeze_batchnorm: Whether to freeze batch norm parameters during + training or not. When training with a small batch size (e.g. 1), it is + desirable to freeze batch norm update and use pretrained batch norm + params. + inplace_batchnorm_update: Whether to update batch norm moving average + values inplace. When this is false train op must add a control + dependency on tf.graphkeys.UPDATE_OPS collection in order to update + batch norm statistics. + num_predictions_per_location_list: A list of integers representing the + number of box predictions to be made per spatial location for each + feature map. + box_predictor_config: box_predictor_pb2.BoxPredictor proto containing + configuration. + is_training: Whether the models is in training mode. + num_classes: Number of classes to predict. + add_background_class: Whether to add an implicit background class. + + Returns: + box_predictor: box_predictor.KerasBoxPredictor object. + + Raises: + ValueError: On unknown box predictor, or one with no Keras box predictor. + """ + if not isinstance(box_predictor_config, box_predictor_pb2.BoxPredictor): + raise ValueError('box_predictor_config not of type ' + 'box_predictor_pb2.BoxPredictor.') + + box_predictor_oneof = box_predictor_config.WhichOneof('box_predictor_oneof') + + if box_predictor_oneof == 'convolutional_box_predictor': + config_box_predictor = box_predictor_config.convolutional_box_predictor + conv_hyperparams = conv_hyperparams_fn( + config_box_predictor.conv_hyperparams) + + mask_head_config = ( + config_box_predictor.mask_head + if config_box_predictor.HasField('mask_head') else None) + return build_convolutional_keras_box_predictor( + is_training=is_training, + num_classes=num_classes, + add_background_class=add_background_class, + conv_hyperparams=conv_hyperparams, + freeze_batchnorm=freeze_batchnorm, + inplace_batchnorm_update=inplace_batchnorm_update, + num_predictions_per_location_list=num_predictions_per_location_list, + use_dropout=config_box_predictor.use_dropout, + dropout_keep_prob=config_box_predictor.dropout_keep_probability, + box_code_size=config_box_predictor.box_code_size, + kernel_size=config_box_predictor.kernel_size, + num_layers_before_predictor=( + config_box_predictor.num_layers_before_predictor), + min_depth=config_box_predictor.min_depth, + max_depth=config_box_predictor.max_depth, + class_prediction_bias_init=( + config_box_predictor.class_prediction_bias_init), + use_depthwise=config_box_predictor.use_depthwise, + mask_head_config=mask_head_config) + + raise ValueError( + 'Unknown box predictor for Keras: {}'.format(box_predictor_oneof)) diff --git a/research/object_detection/builders/box_predictor_builder_test.py b/research/object_detection/builders/box_predictor_builder_test.py index 51a812a9a12..12f2dfc51bf 100644 --- a/research/object_detection/builders/box_predictor_builder_test.py +++ b/research/object_detection/builders/box_predictor_builder_test.py @@ -113,7 +113,8 @@ def mock_conv_argscope_builder(conv_hyperparams_arg, is_training): 
argscope_fn=mock_conv_argscope_builder, box_predictor_config=box_predictor_proto, is_training=False, - num_classes=10) + num_classes=10, + add_background_class=False) class_head = box_predictor._class_prediction_head self.assertEqual(box_predictor._min_depth, 2) self.assertEqual(box_predictor._max_depth, 16) @@ -122,6 +123,7 @@ def mock_conv_argscope_builder(conv_hyperparams_arg, is_training): self.assertAlmostEqual(class_head._dropout_keep_prob, 0.4) self.assertTrue(class_head._apply_sigmoid_to_scores) self.assertAlmostEqual(class_head._class_prediction_bias_init, 4.0) + self.assertEqual(class_head._num_class_slots, 10) self.assertEqual(box_predictor.num_classes, 10) self.assertFalse(box_predictor._is_training) self.assertTrue(class_head._use_depthwise) @@ -154,6 +156,7 @@ def test_construct_default_conv_box_predictor(self): self.assertTrue(class_head._use_dropout) self.assertAlmostEqual(class_head._dropout_keep_prob, 0.8) self.assertFalse(class_head._apply_sigmoid_to_scores) + self.assertEqual(class_head._num_class_slots, 91) self.assertEqual(box_predictor.num_classes, 90) self.assertTrue(box_predictor._is_training) self.assertFalse(class_head._use_depthwise) @@ -306,7 +309,8 @@ def mock_conv_argscope_builder(conv_hyperparams_arg, is_training): argscope_fn=mock_conv_argscope_builder, box_predictor_config=box_predictor_proto, is_training=False, - num_classes=10) + num_classes=10, + add_background_class=False) class_head = box_predictor._class_prediction_head self.assertEqual(box_predictor._depth, 2) self.assertEqual(box_predictor._num_layers_before_predictor, 2) @@ -349,7 +353,8 @@ def mock_conv_argscope_builder(conv_hyperparams_arg, is_training): argscope_fn=mock_conv_argscope_builder, box_predictor_config=box_predictor_proto, is_training=False, - num_classes=10) + num_classes=10, + add_background_class=False) class_head = box_predictor._class_prediction_head self.assertEqual(box_predictor._depth, 2) self.assertEqual(box_predictor._num_layers_before_predictor, 2) @@ -627,6 +632,48 @@ def test_build_box_predictor_with_mask_branch(self): third_stage_heads[mask_rcnn_box_predictor.MASK_PREDICTIONS] ._mask_prediction_conv_depth, 512) + def test_build_box_predictor_with_convlve_then_upsample_masks(self): + box_predictor_proto = box_predictor_pb2.BoxPredictor() + box_predictor_proto.mask_rcnn_box_predictor.fc_hyperparams.op = ( + hyperparams_pb2.Hyperparams.FC) + box_predictor_proto.mask_rcnn_box_predictor.conv_hyperparams.op = ( + hyperparams_pb2.Hyperparams.CONV) + box_predictor_proto.mask_rcnn_box_predictor.predict_instance_masks = True + box_predictor_proto.mask_rcnn_box_predictor.mask_prediction_conv_depth = 512 + box_predictor_proto.mask_rcnn_box_predictor.mask_height = 24 + box_predictor_proto.mask_rcnn_box_predictor.mask_width = 24 + box_predictor_proto.mask_rcnn_box_predictor.convolve_then_upsample_masks = ( + True) + + mock_argscope_fn = mock.Mock(return_value='arg_scope') + box_predictor = box_predictor_builder.build( + argscope_fn=mock_argscope_fn, + box_predictor_config=box_predictor_proto, + is_training=True, + num_classes=90) + mock_argscope_fn.assert_has_calls( + [mock.call(box_predictor_proto.mask_rcnn_box_predictor.fc_hyperparams, + True), + mock.call(box_predictor_proto.mask_rcnn_box_predictor.conv_hyperparams, + True)], any_order=True) + box_head = box_predictor._box_prediction_head + class_head = box_predictor._class_prediction_head + third_stage_heads = box_predictor._third_stage_heads + self.assertFalse(box_head._use_dropout) + self.assertFalse(class_head._use_dropout) + 
self.assertAlmostEqual(box_head._dropout_keep_prob, 0.5) + self.assertAlmostEqual(class_head._dropout_keep_prob, 0.5) + self.assertEqual(box_predictor.num_classes, 90) + self.assertTrue(box_predictor._is_training) + self.assertEqual(box_head._box_code_size, 4) + self.assertTrue( + mask_rcnn_box_predictor.MASK_PREDICTIONS in third_stage_heads) + self.assertEqual( + third_stage_heads[mask_rcnn_box_predictor.MASK_PREDICTIONS] + ._mask_prediction_conv_depth, 512) + self.assertTrue(third_stage_heads[mask_rcnn_box_predictor.MASK_PREDICTIONS] + ._convolve_then_upsample) + class RfcnBoxPredictorBuilderTest(tf.test.TestCase): diff --git a/research/object_detection/builders/hyperparams_builder.py b/research/object_detection/builders/hyperparams_builder.py index 2e2a4bf557d..496d41d676a 100644 --- a/research/object_detection/builders/hyperparams_builder.py +++ b/research/object_detection/builders/hyperparams_builder.py @@ -64,6 +64,10 @@ def __init__(self, hyperparams_config): hyperparams_config.batch_norm) self._activation_fn = _build_activation_fn(hyperparams_config.activation) + # TODO(kaftan): Unclear if these kwargs apply to separable & depthwise conv + # (Those might use depthwise_* instead of kernel_*) + # We should probably switch to using build_conv2d_layer and + # build_depthwise_conv2d_layer methods instead. self._op_params = { 'kernel_regularizer': _build_keras_regularizer( hyperparams_config.regularizer), diff --git a/research/object_detection/builders/image_resizer_builder.py b/research/object_detection/builders/image_resizer_builder.py index 3b3014f727e..243c84dd415 100644 --- a/research/object_detection/builders/image_resizer_builder.py +++ b/research/object_detection/builders/image_resizer_builder.py @@ -106,10 +106,35 @@ def build(image_resizer_config): raise ValueError( 'Invalid image resizer option: \'%s\'.' % image_resizer_oneof) - def grayscale_image_resizer(image): - [resized_image, resized_image_shape] = image_resizer_fn(image) - grayscale_image = preprocessor.rgb_to_gray(resized_image) - grayscale_image_shape = tf.concat([resized_image_shape[:-1], [1]], 0) - return [grayscale_image, grayscale_image_shape] + def grayscale_image_resizer(image, masks=None): + """Convert to grayscale before applying image_resizer_fn. + + Args: + image: A 3D tensor of shape [height, width, 3] + masks: (optional) rank 3 float32 tensor with shape [num_instances, height, + width] containing instance masks. + + Returns: + Note that the position of the resized_image_shape changes based on whether + masks are present. + resized_image: A 3D tensor of shape [new_height, new_width, 1], + where the image has been resized (with bilinear interpolation) so that + min(new_height, new_width) == min_dimension or + max(new_height, new_width) == max_dimension. + resized_masks: If masks is not None, also outputs masks. A 3D tensor of + shape [num_instances, new_height, new_width]. + resized_image_shape: A 1D tensor of shape [3] containing shape of the + resized image. + """ + # image_resizer_fn returns [resized_image, resized_image_shape] if + # mask==None, otherwise it returns + # [resized_image, resized_mask, resized_image_shape]. In either case, we + # only deal with first and last element of the returned list. 
+ retval = image_resizer_fn(image, masks) + resized_image = retval[0] + resized_image_shape = retval[-1] + retval[0] = preprocessor.rgb_to_gray(resized_image) + retval[-1] = tf.concat([resized_image_shape[:-1], [1]], 0) + return retval return functools.partial(grayscale_image_resizer) diff --git a/research/object_detection/builders/losses_builder.py b/research/object_detection/builders/losses_builder.py index e4f7a12400f..3c2345f7192 100644 --- a/research/object_detection/builders/losses_builder.py +++ b/research/object_detection/builders/losses_builder.py @@ -136,6 +136,14 @@ def build_faster_rcnn_classification_loss(loss_config): config = loss_config.weighted_logits_softmax return losses.WeightedSoftmaxClassificationAgainstLogitsLoss( logit_scale=config.logit_scale) + if loss_type == 'weighted_sigmoid_focal': + config = loss_config.weighted_sigmoid_focal + alpha = None + if config.HasField('alpha'): + alpha = config.alpha + return losses.SigmoidFocalClassificationLoss( + gamma=config.gamma, + alpha=alpha) # By default, Faster RCNN second stage classifier uses Softmax loss # with anchor-wise outputs. diff --git a/research/object_detection/builders/losses_builder_test.py b/research/object_detection/builders/losses_builder_test.py index 4dc4a754eca..cac8f442cef 100644 --- a/research/object_detection/builders/losses_builder_test.py +++ b/research/object_detection/builders/losses_builder_test.py @@ -280,7 +280,7 @@ def test_anchorwise_output(self): losses.WeightedSigmoidClassificationLoss)) predictions = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]]) targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]) - weights = tf.constant([[1.0, 1.0]]) + weights = tf.constant([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]) loss = classification_loss(predictions, targets, weights=weights) self.assertEqual(loss.shape, [1, 2, 3]) @@ -473,6 +473,19 @@ def test_build_logits_softmax_loss(self): isinstance(classification_loss, losses.WeightedSoftmaxClassificationAgainstLogitsLoss)) + def test_build_sigmoid_focal_loss(self): + losses_text_proto = """ + weighted_sigmoid_focal { + } + """ + losses_proto = losses_pb2.ClassificationLoss() + text_format.Merge(losses_text_proto, losses_proto) + classification_loss = losses_builder.build_faster_rcnn_classification_loss( + losses_proto) + self.assertTrue( + isinstance(classification_loss, + losses.SigmoidFocalClassificationLoss)) + def test_build_softmax_loss_by_default(self): losses_text_proto = """ """ diff --git a/research/object_detection/builders/model_builder.py b/research/object_detection/builders/model_builder.py index e754b376e63..de1d874c2a5 100644 --- a/research/object_detection/builders/model_builder.py +++ b/research/object_detection/builders/model_builder.py @@ -47,6 +47,8 @@ from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor +from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor +from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor from object_detection.predictors import rfcn_box_predictor from object_detection.protos import model_pb2 from object_detection.utils import ops @@ -69,6 +71,11 @@ 'ssd_resnet152_v1_ppn': ssd_resnet_v1_ppn.SSDResnet152V1PpnFeatureExtractor, 'embedded_ssd_mobilenet_v1': 
EmbeddedSSDMobileNetV1FeatureExtractor, + 'ssd_pnasnet': SSDPNASNetFeatureExtractor, +} + +SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = { + 'ssd_mobilenet_v2_keras': SSDMobileNetV2KerasFeatureExtractor } # A map of names to Faster R-CNN feature extractors. @@ -90,8 +97,7 @@ } -def build(model_config, is_training, add_summaries=True, - add_background_class=True): +def build(model_config, is_training, add_summaries=True): """Builds a DetectionModel based on the model config. Args: @@ -99,10 +105,6 @@ def build(model_config, is_training, add_summaries=True, DetectionModel. is_training: True if this model is being built for training purposes. add_summaries: Whether to add tensorflow summaries in the model graph. - add_background_class: Whether to add an implicit background class to one-hot - encodings of groundtruth labels. Set to false if using groundtruth labels - with an explicit background class or using multiclass scores instead of - truth in the case of distillation. Ignored in the case of faster_rcnn. Returns: DetectionModel based on the config. @@ -113,21 +115,26 @@ def build(model_config, is_training, add_summaries=True, raise ValueError('model_config not of type model_pb2.DetectionModel.') meta_architecture = model_config.WhichOneof('model') if meta_architecture == 'ssd': - return _build_ssd_model(model_config.ssd, is_training, add_summaries, - add_background_class) + return _build_ssd_model(model_config.ssd, is_training, add_summaries) if meta_architecture == 'faster_rcnn': return _build_faster_rcnn_model(model_config.faster_rcnn, is_training, add_summaries) raise ValueError('Unknown meta architecture: {}'.format(meta_architecture)) -def _build_ssd_feature_extractor(feature_extractor_config, is_training, +def _build_ssd_feature_extractor(feature_extractor_config, + is_training, + freeze_batchnorm, reuse_weights=None): """Builds a ssd_meta_arch.SSDFeatureExtractor based on config. Args: feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto. is_training: True if this feature extractor is being built for training. + freeze_batchnorm: Whether to freeze batch norm parameters during + training or not. When training with a small batch size (e.g. 1), it is + desirable to freeze batch norm update and use pretrained batch norm + params. reuse_weights: if the feature extractor should reuse weights. Returns: @@ -137,20 +144,31 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, ValueError: On invalid feature extractor type. 
""" feature_type = feature_extractor_config.type + is_keras_extractor = feature_type in SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP depth_multiplier = feature_extractor_config.depth_multiplier min_depth = feature_extractor_config.min_depth pad_to_multiple = feature_extractor_config.pad_to_multiple use_explicit_padding = feature_extractor_config.use_explicit_padding use_depthwise = feature_extractor_config.use_depthwise - conv_hyperparams = hyperparams_builder.build( - feature_extractor_config.conv_hyperparams, is_training) + + if is_keras_extractor: + conv_hyperparams = hyperparams_builder.KerasLayerHyperparams( + feature_extractor_config.conv_hyperparams) + else: + conv_hyperparams = hyperparams_builder.build( + feature_extractor_config.conv_hyperparams, is_training) override_base_feature_extractor_hyperparams = ( feature_extractor_config.override_base_feature_extractor_hyperparams) - if feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP: + if (feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP) and ( + not is_keras_extractor): raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type)) - feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type] + if is_keras_extractor: + feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[ + feature_type] + else: + feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type] kwargs = { 'is_training': is_training, @@ -160,10 +178,6 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, min_depth, 'pad_to_multiple': pad_to_multiple, - 'conv_hyperparams_fn': - conv_hyperparams, - 'reuse_weights': - reuse_weights, 'use_explicit_padding': use_explicit_padding, 'use_depthwise': @@ -172,6 +186,18 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, override_base_feature_extractor_hyperparams } + if is_keras_extractor: + kwargs.update({ + 'conv_hyperparams': conv_hyperparams, + 'inplace_batchnorm_update': False, + 'freeze_batchnorm': freeze_batchnorm + }) + else: + kwargs.update({ + 'conv_hyperparams_fn': conv_hyperparams, + 'reuse_weights': reuse_weights, + }) + if feature_extractor_config.HasField('fpn'): kwargs.update({ 'fpn_min_level': @@ -185,8 +211,7 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, return feature_extractor_class(**kwargs) -def _build_ssd_model(ssd_config, is_training, add_summaries, - add_background_class=True): +def _build_ssd_model(ssd_config, is_training, add_summaries): """Builds an SSD detection model based on the model config. Args: @@ -194,10 +219,6 @@ def _build_ssd_model(ssd_config, is_training, add_summaries, SSDMetaArch. is_training: True if this model is being built for training purposes. add_summaries: Whether to add tf summaries in the model. - add_background_class: Whether to add an implicit background class to one-hot - encodings of groundtruth labels. Set to false if using groundtruth labels - with an explicit background class or using multiclass scores instead of - truth in the case of distillation. Returns: SSDMetaArch based on the config. 
@@ -210,6 +231,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries, # Feature extractor feature_extractor = _build_ssd_feature_extractor( feature_extractor_config=ssd_config.feature_extractor, + freeze_batchnorm=ssd_config.freeze_batchnorm, is_training=is_training) box_coder = box_coder_builder.build(ssd_config.box_coder) @@ -218,11 +240,23 @@ def _build_ssd_model(ssd_config, is_training, add_summaries, ssd_config.similarity_calculator) encode_background_as_zeros = ssd_config.encode_background_as_zeros negative_class_weight = ssd_config.negative_class_weight - ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build, - ssd_config.box_predictor, - is_training, num_classes) anchor_generator = anchor_generator_builder.build( ssd_config.anchor_generator) + if feature_extractor.is_keras_model: + ssd_box_predictor = box_predictor_builder.build_keras( + conv_hyperparams_fn=hyperparams_builder.KerasLayerHyperparams, + freeze_batchnorm=ssd_config.freeze_batchnorm, + inplace_batchnorm_update=False, + num_predictions_per_location_list=anchor_generator + .num_anchors_per_location(), + box_predictor_config=ssd_config.box_predictor, + is_training=is_training, + num_classes=num_classes, + add_background_class=ssd_config.add_background_class) + else: + ssd_box_predictor = box_predictor_builder.build( + hyperparams_builder.build, ssd_config.box_predictor, is_training, + num_classes, ssd_config.add_background_class) image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer) non_max_suppression_fn, score_conversion_fn = post_processing_builder.build( ssd_config.post_processing) @@ -244,7 +278,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries, if ssd_config.use_expected_classification_loss_under_sampling: expected_classification_loss_under_sampling = functools.partial( ops.expected_classification_loss_under_sampling, - minimum_negative_sampling=ssd_config.minimum_negative_sampling, + min_num_negative_samples=ssd_config.min_num_negative_samples, desired_negative_sampling_ratio=ssd_config. desired_negative_sampling_ratio) @@ -271,7 +305,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries, normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize, freeze_batchnorm=ssd_config.freeze_batchnorm, inplace_batchnorm_update=ssd_config.inplace_batchnorm_update, - add_background_class=add_background_class, + add_background_class=ssd_config.add_background_class, random_example_sampler=random_example_sampler, expected_classification_loss_under_sampling= expected_classification_loss_under_sampling) @@ -357,12 +391,11 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): frcnn_config.first_stage_box_predictor_kernel_size) first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size - # TODO(bhattad): When eval is supported using static shapes, add separate - # use_static_shapes_for_trainig and use_static_shapes_for_evaluation. 
- use_static_shapes = frcnn_config.use_static_shapes and is_training + use_static_shapes = frcnn_config.use_static_shapes first_stage_sampler = sampler.BalancedPositiveNegativeSampler( positive_fraction=frcnn_config.first_stage_positive_balance_fraction, - is_static=frcnn_config.use_static_balanced_label_sampler and is_training) + is_static=(frcnn_config.use_static_balanced_label_sampler and + use_static_shapes)) first_stage_max_proposals = frcnn_config.first_stage_max_proposals if (frcnn_config.first_stage_nms_iou_threshold < 0 or frcnn_config.first_stage_nms_iou_threshold > 1.0): @@ -377,7 +410,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): iou_thresh=frcnn_config.first_stage_nms_iou_threshold, max_size_per_class=frcnn_config.first_stage_max_proposals, max_total_size=frcnn_config.first_stage_max_proposals, - use_static_shapes=use_static_shapes and is_training) + use_static_shapes=use_static_shapes) first_stage_loc_loss_weight = ( frcnn_config.first_stage_localization_loss_weight) first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight @@ -398,7 +431,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): second_stage_batch_size = frcnn_config.second_stage_batch_size second_stage_sampler = sampler.BalancedPositiveNegativeSampler( positive_fraction=frcnn_config.second_stage_balance_fraction, - is_static=frcnn_config.use_static_balanced_label_sampler and is_training) + is_static=(frcnn_config.use_static_balanced_label_sampler and + use_static_shapes)) (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn ) = post_processing_builder.build(frcnn_config.second_stage_post_processing) second_stage_localization_loss_weight = ( diff --git a/research/object_detection/builders/model_builder_test.py b/research/object_detection/builders/model_builder_test.py index be74ba087e7..2809bda132e 100644 --- a/research/object_detection/builders/model_builder_test.py +++ b/research/object_detection/builders/model_builder_test.py @@ -39,6 +39,9 @@ from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor +from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor +from object_detection.predictors import convolutional_box_predictor +from object_detection.predictors import convolutional_keras_box_predictor from object_detection.protos import model_pb2 FRCNN_RESNET_FEAT_MAPS = { @@ -148,7 +151,7 @@ def test_create_ssd_inception_v2_model_from_config(self): } } use_expected_classification_loss_under_sampling: true - minimum_negative_sampling: 10 + min_num_negative_samples: 10 desired_negative_sampling_ratio: 2 }""" model_proto = model_pb2.DetectionModel() @@ -160,7 +163,7 @@ def test_create_ssd_inception_v2_model_from_config(self): self.assertIsNotNone(model._expected_classification_loss_under_sampling) self.assertEqual( model._expected_classification_loss_under_sampling.keywords, { - 'minimum_negative_sampling': 10, + 'min_num_negative_samples': 10, 'desired_negative_sampling_ratio': 2 }) @@ -713,6 +716,86 @@ def test_create_ssd_mobilenet_v2_model_from_config(self): self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch) self.assertIsInstance(model._feature_extractor, SSDMobileNetV2FeatureExtractor) + 
self.assertIsInstance(model._box_predictor, + convolutional_box_predictor.ConvolutionalBoxPredictor) + self.assertTrue(model._normalize_loc_loss_by_codesize) + self.assertTrue(model._target_assigner._weight_regression_loss_by_score) + + def test_create_ssd_mobilenet_v2_keras_model_from_config(self): + model_text_proto = """ + ssd { + feature_extractor { + type: 'ssd_mobilenet_v2_keras' + conv_hyperparams { + regularizer { + l2_regularizer { + } + } + initializer { + truncated_normal_initializer { + } + } + } + } + box_coder { + faster_rcnn_box_coder { + } + } + matcher { + argmax_matcher { + } + } + similarity_calculator { + iou_similarity { + } + } + anchor_generator { + ssd_anchor_generator { + aspect_ratios: 1.0 + } + } + image_resizer { + fixed_shape_resizer { + height: 320 + width: 320 + } + } + box_predictor { + convolutional_box_predictor { + conv_hyperparams { + regularizer { + l2_regularizer { + } + } + initializer { + truncated_normal_initializer { + } + } + } + } + } + normalize_loc_loss_by_codesize: true + loss { + classification_loss { + weighted_softmax { + } + } + localization_loss { + weighted_smooth_l1 { + } + } + } + weight_regression_loss_by_score: true + }""" + model_proto = model_pb2.DetectionModel() + text_format.Merge(model_text_proto, model_proto) + model = self.create_model(model_proto) + self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch) + self.assertIsInstance(model._feature_extractor, + SSDMobileNetV2KerasFeatureExtractor) + self.assertIsInstance( + model._box_predictor, + convolutional_keras_box_predictor.ConvolutionalBoxPredictor) self.assertTrue(model._normalize_loc_loss_by_codesize) self.assertTrue(model._target_assigner._weight_regression_loss_by_score) diff --git a/research/object_detection/builders/preprocessor_builder.py b/research/object_detection/builders/preprocessor_builder.py index 07e3378dc58..050b01939e5 100644 --- a/research/object_detection/builders/preprocessor_builder.py +++ b/research/object_detection/builders/preprocessor_builder.py @@ -167,6 +167,7 @@ def build(preprocessor_step_config): config.max_aspect_ratio), 'area_range': (config.min_area, config.max_area), 'overlap_thresh': config.overlap_thresh, + 'clip_boxes': config.clip_boxes, 'random_coef': config.random_coef, }) @@ -217,6 +218,7 @@ def build(preprocessor_step_config): config.max_aspect_ratio), 'area_range': (config.min_area, config.max_area), 'overlap_thresh': config.overlap_thresh, + 'clip_boxes': config.clip_boxes, 'random_coef': config.random_coef, } if min_padded_size_ratio: @@ -252,6 +254,7 @@ def build(preprocessor_step_config): for op in config.operations] area_range = [(op.min_area, op.max_area) for op in config.operations] overlap_thresh = [op.overlap_thresh for op in config.operations] + clip_boxes = [op.clip_boxes for op in config.operations] random_coef = [op.random_coef for op in config.operations] return (preprocessor.ssd_random_crop, { @@ -259,6 +262,7 @@ def build(preprocessor_step_config): 'aspect_ratio_range': aspect_ratio_range, 'area_range': area_range, 'overlap_thresh': overlap_thresh, + 'clip_boxes': clip_boxes, 'random_coef': random_coef, }) return (preprocessor.ssd_random_crop, {}) @@ -271,6 +275,7 @@ def build(preprocessor_step_config): for op in config.operations] area_range = [(op.min_area, op.max_area) for op in config.operations] overlap_thresh = [op.overlap_thresh for op in config.operations] + clip_boxes = [op.clip_boxes for op in config.operations] random_coef = [op.random_coef for op in config.operations] min_padded_size_ratio = 
[tuple(op.min_padded_size_ratio) for op in config.operations] @@ -284,6 +289,7 @@ def build(preprocessor_step_config): 'aspect_ratio_range': aspect_ratio_range, 'area_range': area_range, 'overlap_thresh': overlap_thresh, + 'clip_boxes': clip_boxes, 'random_coef': random_coef, 'min_padded_size_ratio': min_padded_size_ratio, 'max_padded_size_ratio': max_padded_size_ratio, @@ -297,6 +303,7 @@ def build(preprocessor_step_config): min_object_covered = [op.min_object_covered for op in config.operations] area_range = [(op.min_area, op.max_area) for op in config.operations] overlap_thresh = [op.overlap_thresh for op in config.operations] + clip_boxes = [op.clip_boxes for op in config.operations] random_coef = [op.random_coef for op in config.operations] return (preprocessor.ssd_random_crop_fixed_aspect_ratio, { @@ -304,6 +311,7 @@ def build(preprocessor_step_config): 'aspect_ratio': config.aspect_ratio, 'area_range': area_range, 'overlap_thresh': overlap_thresh, + 'clip_boxes': clip_boxes, 'random_coef': random_coef, }) return (preprocessor.ssd_random_crop_fixed_aspect_ratio, {}) @@ -332,6 +340,7 @@ def build(preprocessor_step_config): kwargs['area_range'] = [(op.min_area, op.max_area) for op in config.operations] kwargs['overlap_thresh'] = [op.overlap_thresh for op in config.operations] + kwargs['clip_boxes'] = [op.clip_boxes for op in config.operations] kwargs['random_coef'] = [op.random_coef for op in config.operations] return (preprocessor.ssd_random_crop_pad_fixed_aspect_ratio, kwargs) diff --git a/research/object_detection/builders/preprocessor_builder_test.py b/research/object_detection/builders/preprocessor_builder_test.py index 8a72aa40e62..89de2b437b1 100644 --- a/research/object_detection/builders/preprocessor_builder_test.py +++ b/research/object_detection/builders/preprocessor_builder_test.py @@ -222,6 +222,7 @@ def test_build_random_crop_image(self): min_area: 0.25 max_area: 0.875 overlap_thresh: 0.5 + clip_boxes: False random_coef: 0.125 } """ @@ -234,6 +235,7 @@ def test_build_random_crop_image(self): 'aspect_ratio_range': (0.75, 1.5), 'area_range': (0.25, 0.875), 'overlap_thresh': 0.5, + 'clip_boxes': False, 'random_coef': 0.125, }) @@ -261,6 +263,7 @@ def test_build_random_crop_pad_image(self): min_area: 0.25 max_area: 0.875 overlap_thresh: 0.5 + clip_boxes: False random_coef: 0.125 } """ @@ -273,6 +276,7 @@ def test_build_random_crop_pad_image(self): 'aspect_ratio_range': (0.75, 1.5), 'area_range': (0.25, 0.875), 'overlap_thresh': 0.5, + 'clip_boxes': False, 'random_coef': 0.125, }) @@ -285,6 +289,7 @@ def test_build_random_crop_pad_image_with_optional_parameters(self): min_area: 0.25 max_area: 0.875 overlap_thresh: 0.5 + clip_boxes: False random_coef: 0.125 min_padded_size_ratio: 0.5 min_padded_size_ratio: 0.75 @@ -304,6 +309,7 @@ def test_build_random_crop_pad_image_with_optional_parameters(self): 'aspect_ratio_range': (0.75, 1.5), 'area_range': (0.25, 0.875), 'overlap_thresh': 0.5, + 'clip_boxes': False, 'random_coef': 0.125, 'min_padded_size_ratio': (0.5, 0.75), 'max_padded_size_ratio': (0.5, 0.75), @@ -315,6 +321,7 @@ def test_build_random_crop_to_aspect_ratio(self): random_crop_to_aspect_ratio { aspect_ratio: 0.85 overlap_thresh: 0.35 + clip_boxes: False } """ preprocessor_proto = preprocessor_pb2.PreprocessingStep() @@ -322,7 +329,8 @@ def test_build_random_crop_to_aspect_ratio(self): function, args = preprocessor_builder.build(preprocessor_proto) self.assertEqual(function, preprocessor.random_crop_to_aspect_ratio) self.assert_dictionary_close(args, {'aspect_ratio': 0.85, 
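The preprocessor_builder changes above simply thread the new clip_boxes field from each crop-related proto message into the kwargs handed to the corresponding preprocessor function. A minimal sketch of that round trip (illustrative only, not part of the patch; it reuses the same protos and builder as the tests below):

from google.protobuf import text_format

from object_detection.builders import preprocessor_builder
from object_detection.protos import preprocessor_pb2

step_text_proto = """
random_crop_image {
  overlap_thresh: 0.5
  clip_boxes: false
  random_coef: 0.125
}
"""
step_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(step_text_proto, step_proto)

# build() returns (preprocessor_function, kwargs); with this change the kwargs
# now carry 'clip_boxes' alongside the existing crop parameters.
function, kwargs = preprocessor_builder.build(step_proto)
assert not kwargs['clip_boxes']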
- 'overlap_thresh': 0.35}) + 'overlap_thresh': 0.35, + 'clip_boxes': False}) def test_build_random_black_patches(self): preprocessor_text_proto = """ @@ -411,6 +419,7 @@ def test_build_ssd_random_crop(self): min_area: 0.5 max_area: 1.0 overlap_thresh: 0.0 + clip_boxes: False random_coef: 0.375 } operations { @@ -420,6 +429,7 @@ def test_build_ssd_random_crop(self): min_area: 0.5 max_area: 1.0 overlap_thresh: 0.25 + clip_boxes: True random_coef: 0.375 } } @@ -432,6 +442,7 @@ def test_build_ssd_random_crop(self): 'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)], 'area_range': [(0.5, 1.0), (0.5, 1.0)], 'overlap_thresh': [0.0, 0.25], + 'clip_boxes': [False, True], 'random_coef': [0.375, 0.375]}) def test_build_ssd_random_crop_empty_operations(self): @@ -455,6 +466,7 @@ def test_build_ssd_random_crop_pad(self): min_area: 0.5 max_area: 1.0 overlap_thresh: 0.0 + clip_boxes: False random_coef: 0.375 min_padded_size_ratio: [1.0, 1.0] max_padded_size_ratio: [2.0, 2.0] @@ -469,6 +481,7 @@ def test_build_ssd_random_crop_pad(self): min_area: 0.5 max_area: 1.0 overlap_thresh: 0.25 + clip_boxes: True random_coef: 0.375 min_padded_size_ratio: [1.0, 1.0] max_padded_size_ratio: [2.0, 2.0] @@ -486,6 +499,7 @@ def test_build_ssd_random_crop_pad(self): 'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)], 'area_range': [(0.5, 1.0), (0.5, 1.0)], 'overlap_thresh': [0.0, 0.25], + 'clip_boxes': [False, True], 'random_coef': [0.375, 0.375], 'min_padded_size_ratio': [(1.0, 1.0), (1.0, 1.0)], 'max_padded_size_ratio': [(2.0, 2.0), (2.0, 2.0)], @@ -499,6 +513,7 @@ def test_build_ssd_random_crop_fixed_aspect_ratio(self): min_area: 0.5 max_area: 1.0 overlap_thresh: 0.0 + clip_boxes: False random_coef: 0.375 } operations { @@ -506,6 +521,7 @@ def test_build_ssd_random_crop_fixed_aspect_ratio(self): min_area: 0.5 max_area: 1.0 overlap_thresh: 0.25 + clip_boxes: True random_coef: 0.375 } aspect_ratio: 0.875 @@ -519,6 +535,7 @@ def test_build_ssd_random_crop_fixed_aspect_ratio(self): 'aspect_ratio': 0.875, 'area_range': [(0.5, 1.0), (0.5, 1.0)], 'overlap_thresh': [0.0, 0.25], + 'clip_boxes': [False, True], 'random_coef': [0.375, 0.375]}) def test_build_ssd_random_crop_pad_fixed_aspect_ratio(self): @@ -531,6 +548,7 @@ def test_build_ssd_random_crop_pad_fixed_aspect_ratio(self): min_area: 0.5 max_area: 1.0 overlap_thresh: 0.0 + clip_boxes: False random_coef: 0.375 } operations { @@ -540,6 +558,7 @@ def test_build_ssd_random_crop_pad_fixed_aspect_ratio(self): min_area: 0.5 max_area: 1.0 overlap_thresh: 0.25 + clip_boxes: True random_coef: 0.375 } aspect_ratio: 0.875 @@ -557,6 +576,7 @@ def test_build_ssd_random_crop_pad_fixed_aspect_ratio(self): 'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)], 'area_range': [(0.5, 1.0), (0.5, 1.0)], 'overlap_thresh': [0.0, 0.25], + 'clip_boxes': [False, True], 'random_coef': [0.375, 0.375], 'min_padded_size_ratio': (1.0, 1.0), 'max_padded_size_ratio': (2.0, 2.0)}) diff --git a/research/object_detection/core/losses.py b/research/object_detection/core/losses.py index 83ad1ab8016..c7a85ac372f 100644 --- a/research/object_detection/core/losses.py +++ b/research/object_detection/core/losses.py @@ -225,7 +225,9 @@ def _compute_loss(self, num_classes] representing the predicted logits for each class target_tensor: A float tensor of shape [batch_size, num_anchors, num_classes] representing one-hot encoded classification targets - weights: a float tensor of shape [batch_size, num_anchors] + weights: a float tensor of shape, either [batch_size, num_anchors, + num_classes] or [batch_size, 
num_anchors, 1]. If the shape is + [batch_size, num_anchors, 1], all the classses are equally weighted. class_indices: (Optional) A 1-D integer tensor of class indices. If provided, computes loss only for the specified class indices. @@ -233,7 +235,6 @@ def _compute_loss(self, loss: a float tensor of shape [batch_size, num_anchors, num_classes] representing the value of the loss function. """ - weights = tf.expand_dims(weights, 2) if class_indices is not None: weights *= tf.reshape( ops.indices_to_dense_vector(class_indices, @@ -273,7 +274,9 @@ def _compute_loss(self, num_classes] representing the predicted logits for each class target_tensor: A float tensor of shape [batch_size, num_anchors, num_classes] representing one-hot encoded classification targets - weights: a float tensor of shape [batch_size, num_anchors] + weights: a float tensor of shape, either [batch_size, num_anchors, + num_classes] or [batch_size, num_anchors, 1]. If the shape is + [batch_size, num_anchors, 1], all the classses are equally weighted. class_indices: (Optional) A 1-D integer tensor of class indices. If provided, computes loss only for the specified class indices. @@ -281,7 +284,6 @@ def _compute_loss(self, loss: a float tensor of shape [batch_size, num_anchors, num_classes] representing the value of the loss function. """ - weights = tf.expand_dims(weights, 2) if class_indices is not None: weights *= tf.reshape( ops.indices_to_dense_vector(class_indices, @@ -326,12 +328,15 @@ def _compute_loss(self, prediction_tensor, target_tensor, weights): num_classes] representing the predicted logits for each class target_tensor: A float tensor of shape [batch_size, num_anchors, num_classes] representing one-hot encoded classification targets - weights: a float tensor of shape [batch_size, num_anchors] + weights: a float tensor of shape, either [batch_size, num_anchors, + num_classes] or [batch_size, num_anchors, 1]. If the shape is + [batch_size, num_anchors, 1], all the classses are equally weighted. Returns: loss: a float tensor of shape [batch_size, num_anchors] representing the value of the loss function. """ + weights = tf.reduce_mean(weights, axis=2) num_classes = prediction_tensor.get_shape().as_list()[-1] prediction_tensor = tf.divide( prediction_tensor, self._logit_scale, name='scale_logit') @@ -372,12 +377,15 @@ def _compute_loss(self, prediction_tensor, target_tensor, weights): num_classes] representing the predicted logits for each class target_tensor: A float tensor of shape [batch_size, num_anchors, num_classes] representing logit classification targets - weights: a float tensor of shape [batch_size, num_anchors] + weights: a float tensor of shape, either [batch_size, num_anchors, + num_classes] or [batch_size, num_anchors, 1]. If the shape is + [batch_size, num_anchors, 1], all the classses are equally weighted. Returns: loss: a float tensor of shape [batch_size, num_anchors] representing the value of the loss function. 
""" + weights = tf.reduce_mean(weights, axis=2) num_classes = prediction_tensor.get_shape().as_list()[-1] target_tensor = self._scale_and_softmax_logits(target_tensor) prediction_tensor = tf.divide(prediction_tensor, self._logit_scale, @@ -431,7 +439,9 @@ def _compute_loss(self, prediction_tensor, target_tensor, weights): num_classes] representing the predicted logits for each class target_tensor: A float tensor of shape [batch_size, num_anchors, num_classes] representing one-hot encoded classification targets - weights: a float tensor of shape [batch_size, num_anchors] + weights: a float tensor of shape, either [batch_size, num_anchors, + num_classes] or [batch_size, num_anchors, 1]. If the shape is + [batch_size, num_anchors, 1], all the classses are equally weighted. Returns: loss: a float tensor of shape [batch_size, num_anchors, num_classes] @@ -446,7 +456,7 @@ def _compute_loss(self, prediction_tensor, target_tensor, weights): tf.sigmoid(prediction_tensor) > 0.5, tf.float32) per_entry_cross_ent = (tf.nn.sigmoid_cross_entropy_with_logits( labels=bootstrap_target_tensor, logits=prediction_tensor)) - return per_entry_cross_ent * tf.expand_dims(weights, 2) + return per_entry_cross_ent * weights class HardExampleMiner(object): diff --git a/research/object_detection/core/losses_test.py b/research/object_detection/core/losses_test.py index 6f831c35f02..548d93cb155 100644 --- a/research/object_detection/core/losses_test.py +++ b/research/object_detection/core/losses_test.py @@ -209,8 +209,14 @@ def testReturnsCorrectLoss(self): [0, 1, 0], [1, 1, 1], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) loss_op = losses.WeightedSigmoidClassificationLoss() loss = loss_op(prediction_tensor, target_tensor, weights=weights) loss = tf.reduce_sum(loss) @@ -237,8 +243,14 @@ def testReturnsCorrectAnchorWiseLoss(self): [0, 1, 0], [1, 1, 1], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) loss_op = losses.WeightedSigmoidClassificationLoss() loss = loss_op(prediction_tensor, target_tensor, weights=weights) loss = tf.reduce_sum(loss, axis=2) @@ -266,8 +278,14 @@ def testReturnsCorrectLossWithClassIndices(self): [0, 1, 0, 0], [1, 1, 1, 0], [1, 0, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 0, 0, 0]]], tf.float32) # Ignores the last class. 
class_indices = tf.constant([0, 1, 2], tf.int32) loss_op = losses.WeightedSigmoidClassificationLoss() @@ -306,9 +324,18 @@ def testReturnsCorrectLossWithLossesMask(self): [0, 0, 0], [0, 0, 0], [0, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 0], - [1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]], tf.float32) losses_mask = tf.constant([True, True, False], tf.bool) loss_op = losses.WeightedSigmoidClassificationLoss() @@ -345,7 +372,7 @@ def testEasyExamplesProduceSmallLossComparedToSigmoidXEntropy(self): [0], [0], [0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1], [1], [1], [1], [1], [1]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(gamma=2.0, alpha=None) sigmoid_loss_op = losses.WeightedSigmoidClassificationLoss() focal_loss = tf.reduce_sum(focal_loss_op(prediction_tensor, target_tensor, @@ -371,7 +398,7 @@ def testHardExamplesProduceLossComparableToSigmoidXEntropy(self): [1], [0], [0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1], [1], [1], [1], [1]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(gamma=2.0, alpha=None) sigmoid_loss_op = losses.WeightedSigmoidClassificationLoss() focal_loss = tf.reduce_sum(focal_loss_op(prediction_tensor, target_tensor, @@ -397,7 +424,7 @@ def testNonAnchorWiseOutputComparableToSigmoidXEntropy(self): [1], [0], [0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1], [1], [1], [1], [1]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(gamma=2.0, alpha=None) sigmoid_loss_op = losses.WeightedSigmoidClassificationLoss() focal_loss = tf.reduce_sum(focal_loss_op(prediction_tensor, target_tensor, @@ -423,7 +450,7 @@ def testIgnoreNegativeExampleLossViaAlphaMultiplier(self): [1], [0], [0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1], [1], [1], [1], [1]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(gamma=2.0, alpha=1.0) sigmoid_loss_op = losses.WeightedSigmoidClassificationLoss() focal_loss = tf.reduce_sum(focal_loss_op(prediction_tensor, target_tensor, @@ -451,7 +478,7 @@ def testIgnorePositiveExampleLossViaAlphaMultiplier(self): [1], [0], [0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1], [1], [1], [1], [1]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(gamma=2.0, alpha=0.0) sigmoid_loss_op = losses.WeightedSigmoidClassificationLoss() focal_loss = tf.reduce_sum(focal_loss_op(prediction_tensor, target_tensor, @@ -485,8 +512,14 @@ def testSimilarToSigmoidXEntropyWithHalfAlphaAndZeroGammaUpToAScale(self): [0, 1, 0], [1, 1, 1], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(alpha=0.5, gamma=0.0) sigmoid_loss_op = losses.WeightedSigmoidClassificationLoss() focal_loss = focal_loss_op(prediction_tensor, target_tensor, @@ -515,8 +548,14 @@ def testSameAsSigmoidXEntropyWithNoAlphaAndZeroGamma(self): [0, 1, 0], [1, 1, 1], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 
1, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(alpha=None, gamma=0.0) sigmoid_loss_op = losses.WeightedSigmoidClassificationLoss() focal_loss = focal_loss_op(prediction_tensor, target_tensor, @@ -546,8 +585,14 @@ def testExpectedLossWithAlphaOneAndZeroGamma(self): [0, 1, 0], [1, 0, 0], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(alpha=1.0, gamma=0.0) focal_loss = tf.reduce_sum(focal_loss_op(prediction_tensor, target_tensor, @@ -578,8 +623,14 @@ def testExpectedLossWithAlpha75AndZeroGamma(self): [0, 1, 0], [1, 0, 0], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]], tf.float32) focal_loss_op = losses.SigmoidFocalClassificationLoss(alpha=0.75, gamma=0.0) focal_loss = tf.reduce_sum(focal_loss_op(prediction_tensor, target_tensor, @@ -620,9 +671,18 @@ def testExpectedLossWithLossesMask(self): [1, 0, 0], [1, 0, 0], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 1], - [1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]], tf.float32) losses_mask = tf.constant([True, True, False], tf.bool) focal_loss_op = losses.SigmoidFocalClassificationLoss(alpha=0.75, gamma=0.0) @@ -659,8 +719,14 @@ def testReturnsCorrectLoss(self): [0, 1, 0], [0, 1, 0], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, .5, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [0.5, 0.5, 0.5], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) loss_op = losses.WeightedSoftmaxClassificationLoss() loss = loss_op(prediction_tensor, target_tensor, weights=weights) loss = tf.reduce_sum(loss) @@ -687,8 +753,14 @@ def testReturnsCorrectAnchorWiseLoss(self): [0, 1, 0], [0, 1, 0], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, .5, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [0.5, 0.5, 0.5], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) loss_op = losses.WeightedSoftmaxClassificationLoss() loss = loss_op(prediction_tensor, target_tensor, weights=weights) @@ -718,8 +790,14 @@ def testReturnsCorrectAnchorWiseLossWithHighLogitScaleSetting(self): [0, 1, 0], [0, 1, 0], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]], tf.float32) loss_op = losses.WeightedSoftmaxClassificationLoss(logit_scale=logit_scale) loss = loss_op(prediction_tensor, target_tensor, weights=weights) @@ -755,9 +833,18 @@ def testReturnsCorrectLossWithLossesMask(self): [1, 0, 0], [1, 0, 0], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, .5, 1], - [1, 1, 1, 0], - [1, 1, 1, 1]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [0.5, 0.5, 0.5], + [1, 
1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]], tf.float32) losses_mask = tf.constant([True, True, False], tf.bool) loss_op = losses.WeightedSoftmaxClassificationLoss() loss = loss_op(prediction_tensor, target_tensor, weights=weights, @@ -792,6 +879,11 @@ def testReturnsCorrectLoss(self): [100, -100, -100]]], tf.float32) weights = tf.constant([[1, 1, .5, 1], [1, 1, 1, 1]], tf.float32) + weights_shape = tf.shape(weights) + weights_multiple = tf.concat( + [tf.ones_like(weights_shape), tf.constant([3])], + axis=0) + weights = tf.tile(tf.expand_dims(weights, 2), weights_multiple) loss_op = losses.WeightedSoftmaxClassificationAgainstLogitsLoss() loss = loss_op(prediction_tensor, target_tensor, weights=weights) loss = tf.reduce_sum(loss) @@ -820,6 +912,11 @@ def testReturnsCorrectAnchorWiseLoss(self): [100, -100, -100]]], tf.float32) weights = tf.constant([[1, 1, .5, 1], [1, 1, 1, 0]], tf.float32) + weights_shape = tf.shape(weights) + weights_multiple = tf.concat( + [tf.ones_like(weights_shape), tf.constant([3])], + axis=0) + weights = tf.tile(tf.expand_dims(weights, 2), weights_multiple) loss_op = losses.WeightedSoftmaxClassificationAgainstLogitsLoss() loss = loss_op(prediction_tensor, target_tensor, weights=weights) @@ -849,6 +946,11 @@ def testReturnsCorrectAnchorWiseLossWithLogitScaleSetting(self): [100, -100, -100]]], tf.float32) weights = tf.constant([[1, 1, .5, 1], [1, 1, 1, 0]], tf.float32) + weights_shape = tf.shape(weights) + weights_multiple = tf.concat( + [tf.ones_like(weights_shape), tf.constant([3])], + axis=0) + weights = tf.tile(tf.expand_dims(weights, 2), weights_multiple) loss_op = losses.WeightedSoftmaxClassificationAgainstLogitsLoss( logit_scale=logit_scale) loss = loss_op(prediction_tensor, target_tensor, weights=weights) @@ -894,8 +996,14 @@ def testReturnsCorrectLossSoftBootstrapping(self): [0, 1, 0], [1, 1, 1], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) alpha = tf.constant(.5, tf.float32) loss_op = losses.BootstrappedSigmoidClassificationLoss( alpha, bootstrap_type='soft') @@ -923,8 +1031,14 @@ def testReturnsCorrectLossHardBootstrapping(self): [0, 1, 0], [1, 1, 1], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) alpha = tf.constant(.5, tf.float32) loss_op = losses.BootstrappedSigmoidClassificationLoss( alpha, bootstrap_type='hard') @@ -952,8 +1066,14 @@ def testReturnsCorrectAnchorWiseLoss(self): [0, 1, 0], [1, 1, 1], [1, 0, 0]]], tf.float32) - weights = tf.constant([[1, 1, 1, 1], - [1, 1, 1, 0]], tf.float32) + weights = tf.constant([[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]], + [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [0, 0, 0]]], tf.float32) alpha = tf.constant(.5, tf.float32) loss_op = losses.BootstrappedSigmoidClassificationLoss( alpha, bootstrap_type='hard') diff --git a/research/object_detection/core/matcher.py b/research/object_detection/core/matcher.py index 65cc42186a8..61b5da61c56 100644 --- a/research/object_detection/core/matcher.py +++ b/research/object_detection/core/matcher.py @@ -197,8 +197,10 @@ def gather_based_on_match(self, input_tensor, unmatched_value, The shape of the gathered tensor is 
[match_results.shape[0]] + input_tensor.shape[1:]. """ - input_tensor = tf.concat([tf.stack([ignored_value, unmatched_value]), - input_tensor], axis=0) + input_tensor = tf.concat( + [tf.stack([ignored_value, unmatched_value]), + tf.to_float(input_tensor)], + axis=0) gather_indices = tf.maximum(self.match_results + 2, 0) gathered_tensor = self._gather_op(input_tensor, gather_indices) return gathered_tensor diff --git a/research/object_detection/core/model.py b/research/object_detection/core/model.py index 5bc1fe19147..4cd6404717e 100644 --- a/research/object_detection/core/model.py +++ b/research/object_detection/core/model.py @@ -289,6 +289,18 @@ def provide_groundtruth(self, self._groundtruth_lists[ fields.InputDataFields.is_annotated] = is_annotated_list + @abstractmethod + def regularization_losses(self): + """Returns a list of regularization losses for this model. + + Returns a list of regularization losses for this model that the estimator + needs to use during training/optimization. + + Returns: + A list of regularization loss tensors. + """ + pass + @abstractmethod def restore_map(self, fine_tune_checkpoint_type='detection'): """Returns a map of variables to load from a foreign checkpoint. @@ -312,3 +324,16 @@ def restore_map(self, fine_tune_checkpoint_type='detection'): the model graph. """ pass + + @abstractmethod + def updates(self): + """Returns a list of update operators for this model. + + Returns a list of update operators for this model that must be executed at + each training step. The estimator's train op needs to have a control + dependency on these updates. + + Returns: + A list of update operators. + """ + pass diff --git a/research/object_detection/core/post_processing.py b/research/object_detection/core/post_processing.py index c6c491285d3..2077585987e 100644 --- a/research/object_detection/core/post_processing.py +++ b/research/object_detection/core/post_processing.py @@ -15,6 +15,7 @@ """Post-processing operations on detected boxes.""" +import numpy as np import tensorflow as tf from object_detection.core import box_list @@ -407,28 +408,36 @@ def _single_image_nms_fn(args): for key, value in zip(additional_fields, args[4:-1]) } per_image_num_valid_boxes = args[-1] - per_image_boxes = tf.reshape( - tf.slice(per_image_boxes, 3 * [0], - tf.stack([per_image_num_valid_boxes, -1, -1])), [-1, q, 4]) - per_image_scores = tf.reshape( - tf.slice(per_image_scores, [0, 0], - tf.stack([per_image_num_valid_boxes, -1])), - [-1, num_classes]) - per_image_masks = tf.reshape( - tf.slice(per_image_masks, 4 * [0], - tf.stack([per_image_num_valid_boxes, -1, -1, -1])), - [-1, q, per_image_masks.shape[2].value, - per_image_masks.shape[3].value]) - if per_image_additional_fields is not None: - for key, tensor in per_image_additional_fields.items(): - additional_field_shape = tensor.get_shape() - additional_field_dim = len(additional_field_shape) - per_image_additional_fields[key] = tf.reshape( - tf.slice(per_image_additional_fields[key], - additional_field_dim * [0], - tf.stack([per_image_num_valid_boxes] + - (additional_field_dim - 1) * [-1])), - [-1] + [dim.value for dim in additional_field_shape[1:]]) + if use_static_shapes: + total_proposals = tf.shape(per_image_scores) + per_image_scores = tf.where( + tf.less(tf.range(total_proposals[0]), per_image_num_valid_boxes), + per_image_scores, + tf.fill(total_proposals, np.finfo('float32').min)) + else: + per_image_boxes = tf.reshape( + tf.slice(per_image_boxes, 3 * [0], + tf.stack([per_image_num_valid_boxes, -1, -1])), [-1, q, 4]) + 
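model.py now declares regularization_losses() and updates() as abstract hooks, so Keras-based feature extractors and box predictors can surface their weight penalties and batch-norm update ops to the training loop. A hedged sketch of how a train op might consume them (the helper name and arguments are illustrative, not the library's actual API):

import tensorflow as tf

def build_train_op(detection_model, losses_dict, optimizer):
  """Hypothetical helper showing how the new hooks could be wired up."""
  total_loss = tf.add_n(list(losses_dict.values()))
  # Keras layers report their L1/L2 penalties through this hook rather than
  # through the tf.GraphKeys.REGULARIZATION_LOSSES collection.
  reg_losses = detection_model.regularization_losses()
  if reg_losses:
    total_loss = tf.add_n([total_loss] + list(reg_losses))
  grads_and_vars = optimizer.compute_gradients(total_loss)
  # Batch-norm moving averages and similar ops must run every step, so the
  # train op takes a control dependency on model.updates().
  with tf.control_dependencies(detection_model.updates()):
    train_op = optimizer.apply_gradients(grads_and_vars)
  return train_op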
per_image_scores = tf.reshape( + tf.slice(per_image_scores, [0, 0], + tf.stack([per_image_num_valid_boxes, -1])), + [-1, num_classes]) + per_image_masks = tf.reshape( + tf.slice(per_image_masks, 4 * [0], + tf.stack([per_image_num_valid_boxes, -1, -1, -1])), + [-1, q, per_image_masks.shape[2].value, + per_image_masks.shape[3].value]) + if per_image_additional_fields is not None: + for key, tensor in per_image_additional_fields.items(): + additional_field_shape = tensor.get_shape() + additional_field_dim = len(additional_field_shape) + per_image_additional_fields[key] = tf.reshape( + tf.slice(per_image_additional_fields[key], + additional_field_dim * [0], + tf.stack([per_image_num_valid_boxes] + + (additional_field_dim - 1) * [-1])), + [-1] + [dim.value for dim in additional_field_shape[1:]]) + nmsed_boxlist, num_valid_nms_boxes = multiclass_non_max_suppression( per_image_boxes, per_image_scores, diff --git a/research/object_detection/core/preprocessor.py b/research/object_detection/core/preprocessor.py index 1b6eccdf963..140c62b57e7 100644 --- a/research/object_detection/core/preprocessor.py +++ b/research/object_detection/core/preprocessor.py @@ -1108,7 +1108,7 @@ def random_jitter_box(box, ratio, seed): def _strict_random_crop_image(image, boxes, labels, - label_scores=None, + label_scores, multiclass_scores=None, masks=None, keypoints=None, @@ -1116,14 +1116,14 @@ def _strict_random_crop_image(image, aspect_ratio_range=(0.75, 1.33), area_range=(0.1, 1.0), overlap_thresh=0.3, + clip_boxes=True, preprocess_vars_cache=None): """Performs random crop. - Note: boxes will be clipped to the crop. Keypoint coordinates that are - outside the crop will be set to NaN, which is consistent with the original - keypoint encoding for non-existing keypoints. This function always crops - the image and is supposed to be used by `random_crop_image` function which - sometimes returns image unchanged. + Note: Keypoint coordinates that are outside the crop will be set to NaN, which + is consistent with the original keypoint encoding for non-existing keypoints. + This function always crops the image and is supposed to be used by + `random_crop_image` function which sometimes returns the image unchanged. Args: image: rank 3 float32 tensor containing 1 image -> [height, width, channels] @@ -1152,6 +1152,7 @@ def _strict_random_crop_image(image, original image. overlap_thresh: minimum overlap thresh with new cropped image to keep the box. + clip_boxes: whether to clip the boxes to the cropped image. preprocess_vars_cache: PreprocessorCache object that records previously performed augmentations. Updated in-place. 
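The use_static_shapes branch in post_processing.py above avoids the dynamic-shape tf.slice path entirely: padded proposals are kept in place and their scores are overwritten with the float32 minimum, so NMS operates on a statically shaped tensor but can never select an invalid row. A standalone illustration of the same trick:

import numpy as np
import tensorflow as tf

per_image_scores = tf.constant([[0.9, 0.2],
                                [0.8, 0.1],
                                [0.7, 0.6]])   # [max_proposals=3, num_classes=2]
num_valid = tf.constant(2)                     # only the first two rows are real

total_proposals = tf.shape(per_image_scores)
masked_scores = tf.where(
    tf.less(tf.range(total_proposals[0]), num_valid),
    per_image_scores,
    tf.fill(total_proposals, np.finfo('float32').min))

with tf.Session() as sess:
  # The padded row becomes ~-3.4e38 while the tensor shape stays [3, 2].
  print(sess.run(masked_scores))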
If this function is called multiple times with the same @@ -1232,8 +1233,9 @@ def _strict_random_crop_image(image, new_boxlist = box_list_ops.change_coordinate_frame(overlapping_boxlist, im_box_rank1) new_boxes = new_boxlist.get() - new_boxes = tf.clip_by_value( - new_boxes, clip_value_min=0.0, clip_value_max=1.0) + if clip_boxes: + new_boxes = tf.clip_by_value( + new_boxes, clip_value_min=0.0, clip_value_max=1.0) result = [new_image, new_boxes, new_labels] @@ -1262,8 +1264,9 @@ def _strict_random_crop_image(image, keypoints_of_boxes_inside_window, keep_ids) new_keypoints = keypoint_ops.change_coordinate_frame( keypoints_of_boxes_completely_inside_window, im_box_rank1) - new_keypoints = keypoint_ops.prune_outside_window(new_keypoints, - [0.0, 0.0, 1.0, 1.0]) + if clip_boxes: + new_keypoints = keypoint_ops.prune_outside_window(new_keypoints, + [0.0, 0.0, 1.0, 1.0]) result.append(new_keypoints) return tuple(result) @@ -1280,6 +1283,7 @@ def random_crop_image(image, aspect_ratio_range=(0.75, 1.33), area_range=(0.1, 1.0), overlap_thresh=0.3, + clip_boxes=True, random_coef=0.0, seed=None, preprocess_vars_cache=None): @@ -1294,9 +1298,8 @@ def random_crop_image(image, form (e.g., lie in the unit square [0, 1]). This function will return the original image with probability random_coef. - Note: boxes will be clipped to the crop. Keypoint coordinates that are - outside the crop will be set to NaN, which is consistent with the original - keypoint encoding for non-existing keypoints. + Note: Keypoint coordinates that are outside the crop will be set to NaN, which + is consistent with the original keypoint encoding for non-existing keypoints. Args: image: rank 3 float32 tensor contains 1 image -> [height, width, channels] @@ -1325,6 +1328,7 @@ def random_crop_image(image, original image. overlap_thresh: minimum overlap thresh with new cropped image to keep the box. + clip_boxes: whether to clip the boxes to the cropped image. random_coef: a random coefficient that defines the chance of getting the original image. If random_coef is 0, we will always get the cropped image, and if it is 1.0, we will always get the @@ -1365,6 +1369,7 @@ def strict_random_crop_image_fn(): aspect_ratio_range=aspect_ratio_range, area_range=area_range, overlap_thresh=overlap_thresh, + clip_boxes=clip_boxes, preprocess_vars_cache=preprocess_vars_cache) # avoids tf.cond to make faster RCNN training on borg. See b/140057645. @@ -1515,12 +1520,13 @@ def random_pad_image(image, def random_crop_pad_image(image, boxes, labels, - label_scores=None, + label_scores, multiclass_scores=None, min_object_covered=1.0, aspect_ratio_range=(0.75, 1.33), area_range=(0.1, 1.0), overlap_thresh=0.3, + clip_boxes=True, random_coef=0.0, min_padded_size_ratio=(1.0, 1.0), max_padded_size_ratio=(2.0, 2.0), @@ -1558,6 +1564,7 @@ def random_crop_pad_image(image, original image. overlap_thresh: minimum overlap thresh with new cropped image to keep the box. + clip_boxes: whether to clip the boxes to the cropped image. random_coef: a random coefficient that defines the chance of getting the original image. 
If random_coef is 0, we will always get the cropped image, and if it is 1.0, we will always get the @@ -1599,6 +1606,7 @@ def random_crop_pad_image(image, aspect_ratio_range=aspect_ratio_range, area_range=area_range, overlap_thresh=overlap_thresh, + clip_boxes=clip_boxes, random_coef=random_coef, seed=seed, preprocess_vars_cache=preprocess_vars_cache) @@ -1639,12 +1647,13 @@ def random_crop_pad_image(image, def random_crop_to_aspect_ratio(image, boxes, labels, - label_scores=None, + label_scores, multiclass_scores=None, masks=None, keypoints=None, aspect_ratio=1.0, overlap_thresh=0.3, + clip_boxes=True, seed=None, preprocess_vars_cache=None): """Randomly crops an image to the specified aspect ratio. @@ -1680,6 +1689,7 @@ def random_crop_to_aspect_ratio(image, aspect_ratio: the aspect ratio of cropped image. overlap_thresh: minimum overlap thresh with new cropped image to keep the box. + clip_boxes: whether to clip the boxes to the cropped image. seed: random seed. preprocess_vars_cache: PreprocessorCache object that records previously performed augmentations. Updated in-place. If this @@ -1767,9 +1777,9 @@ def target_width_fn(): new_labels = overlapping_boxlist.get_field('labels') new_boxlist = box_list_ops.change_coordinate_frame(overlapping_boxlist, im_box) - new_boxlist = box_list_ops.clip_to_window(new_boxlist, - tf.constant([0.0, 0.0, 1.0, 1.0], - tf.float32)) + if clip_boxes: + new_boxlist = box_list_ops.clip_to_window( + new_boxlist, tf.constant([0.0, 0.0, 1.0, 1.0], tf.float32)) new_boxes = new_boxlist.get() result = [new_image, new_boxes, new_labels] @@ -1793,8 +1803,9 @@ def target_width_fn(): keypoints_inside_window = tf.gather(keypoints, keep_ids) new_keypoints = keypoint_ops.change_coordinate_frame( keypoints_inside_window, im_box) - new_keypoints = keypoint_ops.prune_outside_window(new_keypoints, - [0.0, 0.0, 1.0, 1.0]) + if clip_boxes: + new_keypoints = keypoint_ops.prune_outside_window(new_keypoints, + [0.0, 0.0, 1.0, 1.0]) result.append(new_keypoints) return tuple(result) @@ -2432,7 +2443,7 @@ def rgb_to_gray(image): def ssd_random_crop(image, boxes, labels, - label_scores=None, + label_scores, multiclass_scores=None, masks=None, keypoints=None, @@ -2440,6 +2451,7 @@ def ssd_random_crop(image, aspect_ratio_range=((0.5, 2.0),) * 7, area_range=((0.1, 1.0),) * 7, overlap_thresh=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0), + clip_boxes=(True,) * 7, random_coef=(0.15,) * 7, seed=None, preprocess_vars_cache=None): @@ -2474,6 +2486,7 @@ def ssd_random_crop(image, original image. overlap_thresh: minimum overlap thresh with new cropped image to keep the box. + clip_boxes: whether to clip the boxes to the cropped image. random_coef: a random coefficient that defines the chance of getting the original image. 
If random_coef is 0, we will always get the cropped image, and if it is 1.0, we will always get the @@ -2543,6 +2556,7 @@ def random_crop_selector(selected_result, index): aspect_ratio_range=aspect_ratio_range[index], area_range=area_range[index], overlap_thresh=overlap_thresh[index], + clip_boxes=clip_boxes[index], random_coef=random_coef[index], seed=seed, preprocess_vars_cache=preprocess_vars_cache) @@ -2561,12 +2575,13 @@ def random_crop_selector(selected_result, index): def ssd_random_crop_pad(image, boxes, labels, - label_scores=None, + label_scores, multiclass_scores=None, min_object_covered=(0.1, 0.3, 0.5, 0.7, 0.9, 1.0), aspect_ratio_range=((0.5, 2.0),) * 6, area_range=((0.1, 1.0),) * 6, overlap_thresh=(0.1, 0.3, 0.5, 0.7, 0.9, 1.0), + clip_boxes=(True,) * 6, random_coef=(0.15,) * 6, min_padded_size_ratio=((1.0, 1.0),) * 6, max_padded_size_ratio=((2.0, 2.0),) * 6, @@ -2599,6 +2614,7 @@ def ssd_random_crop_pad(image, original image. overlap_thresh: minimum overlap thresh with new cropped image to keep the box. + clip_boxes: whether to clip the boxes to the cropped image. random_coef: a random coefficient that defines the chance of getting the original image. If random_coef is 0, we will always get the cropped image, and if it is 1.0, we will always get the @@ -2646,6 +2662,7 @@ def random_crop_pad_selector(image_boxes_labels, index): aspect_ratio_range=aspect_ratio_range[index], area_range=area_range[index], overlap_thresh=overlap_thresh[index], + clip_boxes=clip_boxes[index], random_coef=random_coef[index], min_padded_size_ratio=min_padded_size_ratio[index], max_padded_size_ratio=max_padded_size_ratio[index], @@ -2666,7 +2683,7 @@ def ssd_random_crop_fixed_aspect_ratio( image, boxes, labels, - label_scores=None, + label_scores, multiclass_scores=None, masks=None, keypoints=None, @@ -2674,6 +2691,7 @@ def ssd_random_crop_fixed_aspect_ratio( aspect_ratio=1.0, area_range=((0.1, 1.0),) * 7, overlap_thresh=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0), + clip_boxes=(True,) * 7, random_coef=(0.15,) * 7, seed=None, preprocess_vars_cache=None): @@ -2711,6 +2729,7 @@ def ssd_random_crop_fixed_aspect_ratio( original image. overlap_thresh: minimum overlap thresh with new cropped image to keep the box. + clip_boxes: whether to clip the boxes to the cropped image. random_coef: a random coefficient that defines the chance of getting the original image. If random_coef is 0, we will always get the cropped image, and if it is 1.0, we will always get the @@ -2751,6 +2770,7 @@ def ssd_random_crop_fixed_aspect_ratio( aspect_ratio_range=aspect_ratio_range, area_range=area_range, overlap_thresh=overlap_thresh, + clip_boxes=clip_boxes, random_coef=random_coef, seed=seed, preprocess_vars_cache=preprocess_vars_cache) @@ -2781,6 +2801,7 @@ def ssd_random_crop_fixed_aspect_ratio( masks=new_masks, keypoints=new_keypoints, aspect_ratio=aspect_ratio, + clip_boxes=clip_boxes, seed=seed, preprocess_vars_cache=preprocess_vars_cache) @@ -2791,7 +2812,7 @@ def ssd_random_crop_pad_fixed_aspect_ratio( image, boxes, labels, - label_scores=None, + label_scores, multiclass_scores=None, masks=None, keypoints=None, @@ -2800,6 +2821,7 @@ def ssd_random_crop_pad_fixed_aspect_ratio( aspect_ratio_range=((0.5, 2.0),) * 7, area_range=((0.1, 1.0),) * 7, overlap_thresh=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0), + clip_boxes=(True,) * 7, random_coef=(0.15,) * 7, min_padded_size_ratio=(1.0, 1.0), max_padded_size_ratio=(2.0, 2.0), @@ -2841,6 +2863,7 @@ def ssd_random_crop_pad_fixed_aspect_ratio( original image. 
overlap_thresh: minimum overlap thresh with new cropped image to keep the box. + clip_boxes: whether to clip the boxes to the cropped image. random_coef: a random coefficient that defines the chance of getting the original image. If random_coef is 0, we will always get the cropped image, and if it is 1.0, we will always get the @@ -2882,6 +2905,7 @@ def ssd_random_crop_pad_fixed_aspect_ratio( aspect_ratio_range=aspect_ratio_range, area_range=area_range, overlap_thresh=overlap_thresh, + clip_boxes=clip_boxes, random_coef=random_coef, seed=seed, preprocess_vars_cache=preprocess_vars_cache) @@ -2950,7 +2974,7 @@ def convert_class_logits_to_softmax(multiclass_scores, temperature=1.0): return multiclass_scores -def get_default_func_arg_map(include_label_scores=False, +def get_default_func_arg_map(include_label_scores=True, include_multiclass_scores=False, include_instance_masks=False, include_keypoints=False): @@ -2972,7 +2996,7 @@ def get_default_func_arg_map(include_label_scores=False, groundtruth_label_scores = None if include_label_scores: groundtruth_label_scores = ( - fields.InputDataFields.groundtruth_confidences) + fields.InputDataFields.groundtruth_weights) multiclass_scores = None if include_multiclass_scores: diff --git a/research/object_detection/core/preprocessor_test.py b/research/object_detection/core/preprocessor_test.py index 5a1e5936bb8..9d3a0f55fc1 100644 --- a/research/object_detection/core/preprocessor_test.py +++ b/research/object_detection/core/preprocessor_test.py @@ -70,12 +70,9 @@ def createTestBoxes(self): [[0.0, 0.25, 0.75, 1.0], [0.25, 0.5, 0.75, 1.0]], dtype=tf.float32) return boxes - def createTestLabelScores(self): + def createTestGroundtruthWeights(self): return tf.constant([1.0, 0.5], dtype=tf.float32) - def createTestLabelScoresWithMissingScore(self): - return tf.constant([0.5, np.nan], dtype=tf.float32) - def createTestMasks(self): mask = np.array([ [[255.0, 0.0, 0.0], @@ -332,15 +329,15 @@ def testNormalizeImage(self): def testRetainBoxesAboveThreshold(self): boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() (retained_boxes, retained_labels, - retained_label_scores) = preprocessor.retain_boxes_above_threshold( - boxes, labels, label_scores, threshold=0.6) + retained_weights) = preprocessor.retain_boxes_above_threshold( + boxes, labels, weights, threshold=0.6) with self.test_session() as sess: - (retained_boxes_, retained_labels_, retained_label_scores_, + (retained_boxes_, retained_labels_, retained_weights_, expected_retained_boxes_, expected_retained_labels_, - expected_retained_label_scores_) = sess.run([ - retained_boxes, retained_labels, retained_label_scores, + expected_retained_weights_) = sess.run([ + retained_boxes, retained_labels, retained_weights, self.expectedBoxesAfterThresholding(), self.expectedLabelsAfterThresholding(), self.expectedLabelScoresAfterThresholding()]) @@ -349,18 +346,18 @@ def testRetainBoxesAboveThreshold(self): self.assertAllClose( retained_labels_, expected_retained_labels_) self.assertAllClose( - retained_label_scores_, expected_retained_label_scores_) + retained_weights_, expected_retained_weights_) def testRetainBoxesAboveThresholdWithMultiClassScores(self): boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() multiclass_scores = self.createTestMultiClassScores() (_, _, _, retained_multiclass_scores) = 
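With include_label_scores now defaulting to True and the scores sourced from groundtruth_weights, every tensor_dict passed through preprocessor.preprocess() needs a groundtruth_weights entry so the crop ops can subset it together with the boxes, which is exactly what the test updates below add everywhere. A minimal illustrative run:

import tensorflow as tf

from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields

tensor_dict = {
    fields.InputDataFields.image:
        tf.random_uniform([1, 200, 400, 3], dtype=tf.float32),
    fields.InputDataFields.groundtruth_boxes:
        tf.constant([[0.1, 0.1, 0.9, 0.9]], dtype=tf.float32),
    fields.InputDataFields.groundtruth_classes:
        tf.constant([1], dtype=tf.int32),
    # Required by the new default arg map; crops now return filtered weights.
    fields.InputDataFields.groundtruth_weights:
        tf.constant([1.0], dtype=tf.float32),
}

options = [(preprocessor.random_crop_image, {})]
arg_map = preprocessor.get_default_func_arg_map()  # include_label_scores=True
out = preprocessor.preprocess(tensor_dict, options, func_arg_map=arg_map)
cropped_weights = out[fields.InputDataFields.groundtruth_weights]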
preprocessor.retain_boxes_above_threshold( boxes, labels, - label_scores, + weights, multiclass_scores=multiclass_scores, threshold=0.6) with self.test_session() as sess: @@ -376,10 +373,10 @@ def testRetainBoxesAboveThresholdWithMultiClassScores(self): def testRetainBoxesAboveThresholdWithMasks(self): boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() masks = self.createTestMasks() _, _, _, retained_masks = preprocessor.retain_boxes_above_threshold( - boxes, labels, label_scores, masks, threshold=0.6) + boxes, labels, weights, masks, threshold=0.6) with self.test_session() as sess: retained_masks_, expected_retained_masks_ = sess.run([ retained_masks, self.expectedMasksAfterThresholding()]) @@ -390,10 +387,10 @@ def testRetainBoxesAboveThresholdWithMasks(self): def testRetainBoxesAboveThresholdWithKeypoints(self): boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() keypoints = self.createTestKeypoints() (_, _, _, retained_keypoints) = preprocessor.retain_boxes_above_threshold( - boxes, labels, label_scores, keypoints=keypoints, threshold=0.6) + boxes, labels, weights, keypoints=keypoints, threshold=0.6) with self.test_session() as sess: (retained_keypoints_, expected_retained_keypoints_) = sess.run([ @@ -403,28 +400,6 @@ def testRetainBoxesAboveThresholdWithKeypoints(self): self.assertAllClose( retained_keypoints_, expected_retained_keypoints_) - def testRetainBoxesAboveThresholdWithMissingScore(self): - boxes = self.createTestBoxes() - labels = self.createTestLabels() - label_scores = self.createTestLabelScoresWithMissingScore() - (retained_boxes, retained_labels, - retained_label_scores) = preprocessor.retain_boxes_above_threshold( - boxes, labels, label_scores, threshold=0.6) - with self.test_session() as sess: - (retained_boxes_, retained_labels_, retained_label_scores_, - expected_retained_boxes_, expected_retained_labels_, - expected_retained_label_scores_) = sess.run([ - retained_boxes, retained_labels, retained_label_scores, - self.expectedBoxesAfterThresholdingWithMissingScore(), - self.expectedLabelsAfterThresholdingWithMissingScore(), - self.expectedLabelScoresAfterThresholdingWithMissingScore()]) - self.assertAllClose( - retained_boxes_, expected_retained_boxes_) - self.assertAllClose( - retained_labels_, expected_retained_labels_) - self.assertAllClose( - retained_label_scores_, expected_retained_label_scores_) - def testFlipBoxesLeftRight(self): boxes = self.createTestBoxes() flipped_boxes = preprocessor._flip_boxes_left_right(boxes) @@ -482,6 +457,7 @@ def _testPreprocessorCache(self, cache = preprocessor_cache.PreprocessorCache() images = self.createTestImages() boxes = self.createTestBoxes() + weights = self.createTestGroundtruthWeights() classes = self.createTestLabels() masks = self.createTestMasks() keypoints = self.createTestKeypoints() @@ -491,6 +467,7 @@ def _testPreprocessorCache(self, for i in range(num_runs): tensor_dict = { fields.InputDataFields.image: images, + fields.InputDataFields.groundtruth_weights: weights } num_outputs = 1 if test_boxes: @@ -1075,10 +1052,12 @@ def testRandomCropImage(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, 
fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, } distorted_tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options) @@ -1126,10 +1105,12 @@ def testRandomCropImageGrayscale(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, } distorted_tensor_dict = preprocessor.preprocess( tensor_dict, preprocessing_options) @@ -1163,10 +1144,12 @@ def testRandomCropImageWithBoxOutOfImage(self): images = self.createTestImages() boxes = self.createTestBoxesOutOfImage() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, } distorted_tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options) @@ -1197,12 +1180,12 @@ def testRandomCropImageWithRandomCoefOne(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, - fields.InputDataFields.groundtruth_label_scores: label_scores + fields.InputDataFields.groundtruth_weights: weights } tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options) images = tensor_dict[fields.InputDataFields.image] @@ -1218,8 +1201,8 @@ def testRandomCropImageWithRandomCoefOne(self): fields.InputDataFields.groundtruth_boxes] distorted_labels = distorted_tensor_dict[ fields.InputDataFields.groundtruth_classes] - distorted_label_scores = distorted_tensor_dict[ - fields.InputDataFields.groundtruth_label_scores] + distorted_weights = distorted_tensor_dict[ + fields.InputDataFields.groundtruth_weights] boxes_shape = tf.shape(boxes) distorted_boxes_shape = tf.shape(distorted_boxes) images_shape = tf.shape(images) @@ -1229,17 +1212,17 @@ def testRandomCropImageWithRandomCoefOne(self): (boxes_shape_, distorted_boxes_shape_, images_shape_, distorted_images_shape_, images_, distorted_images_, boxes_, distorted_boxes_, labels_, distorted_labels_, - label_scores_, distorted_label_scores_) = sess.run( + weights_, distorted_weights_) = sess.run( [boxes_shape, distorted_boxes_shape, images_shape, distorted_images_shape, images, distorted_images, boxes, distorted_boxes, labels, distorted_labels, - label_scores, distorted_label_scores]) + weights, distorted_weights]) self.assertAllEqual(boxes_shape_, distorted_boxes_shape_) self.assertAllEqual(images_shape_, distorted_images_shape_) self.assertAllClose(images_, distorted_images_) self.assertAllClose(boxes_, distorted_boxes_) self.assertAllEqual(labels_, distorted_labels_) - self.assertAllEqual(label_scores_, distorted_label_scores_) + self.assertAllEqual(weights_, distorted_weights_) def testRandomCropWithMockSampleDistortedBoundingBox(self): preprocessing_options = [(preprocessor.normalize_image, { @@ -1254,11 +1237,13 @@ def testRandomCropWithMockSampleDistortedBoundingBox(self): [0.2, 0.4, 0.75, 0.75], [0.3, 0.1, 
0.4, 0.7]], dtype=tf.float32) labels = tf.constant([1, 7, 11], dtype=tf.int32) + weights = tf.constant([1.0, 0.5, 0.6], dtype=tf.float32) tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, } tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options) images = tensor_dict[fields.InputDataFields.image] @@ -1279,18 +1264,98 @@ def testRandomCropWithMockSampleDistortedBoundingBox(self): fields.InputDataFields.groundtruth_boxes] distorted_labels = distorted_tensor_dict[ fields.InputDataFields.groundtruth_classes] + distorted_weights = distorted_tensor_dict[ + fields.InputDataFields.groundtruth_weights] expected_boxes = tf.constant([[0.178947, 0.07173, 0.75789469, 0.66244733], [0.28421, 0.0, 0.38947365, 0.57805908]], dtype=tf.float32) expected_labels = tf.constant([7, 11], dtype=tf.int32) + expected_weights = tf.constant([0.5, 0.6], dtype=tf.float32) with self.test_session() as sess: - (distorted_boxes_, distorted_labels_, - expected_boxes_, expected_labels_) = sess.run( - [distorted_boxes, distorted_labels, - expected_boxes, expected_labels]) + (distorted_boxes_, distorted_labels_, distorted_weights_, + expected_boxes_, expected_labels_, expected_weights_) = sess.run( + [distorted_boxes, distorted_labels, distorted_weights, + expected_boxes, expected_labels, expected_weights]) self.assertAllClose(distorted_boxes_, expected_boxes_) self.assertAllEqual(distorted_labels_, expected_labels_) + self.assertAllEqual(distorted_weights_, expected_weights_) + + def testRandomCropWithoutClipBoxes(self): + preprocessing_options = [(preprocessor.normalize_image, { + 'original_minval': 0, + 'original_maxval': 255, + 'target_minval': 0, + 'target_maxval': 1 + })] + + images = self.createColorfulTestImage() + boxes = tf.constant([[0.1, 0.1, 0.8, 0.3], + [0.2, 0.4, 0.75, 0.75], + [0.3, 0.1, 0.4, 0.7]], dtype=tf.float32) + keypoints = tf.constant([ + [[0.1, 0.1], [0.8, 0.3]], + [[0.2, 0.4], [0.75, 0.75]], + [[0.3, 0.1], [0.4, 0.7]], + ], dtype=tf.float32) + labels = tf.constant([1, 7, 11], dtype=tf.int32) + weights = tf.constant([1.0, 0.5, 0.6], dtype=tf.float32) + + tensor_dict = { + fields.InputDataFields.image: images, + fields.InputDataFields.groundtruth_boxes: boxes, + fields.InputDataFields.groundtruth_keypoints: keypoints, + fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, + } + tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options) + + preprocessing_options = [(preprocessor.random_crop_image, { + 'clip_boxes': False, + })] + with mock.patch.object( + tf.image, + 'sample_distorted_bounding_box') as mock_sample_distorted_bounding_box: + mock_sample_distorted_bounding_box.return_value = (tf.constant( + [6, 143, 0], dtype=tf.int32), tf.constant( + [190, 237, -1], dtype=tf.int32), tf.constant( + [[[0.03, 0.3575, 0.98, 0.95]]], dtype=tf.float32)) + + preprocessor_arg_map = preprocessor.get_default_func_arg_map( + include_keypoints=True) + distorted_tensor_dict = preprocessor.preprocess( + tensor_dict, preprocessing_options, func_arg_map=preprocessor_arg_map) + + distorted_boxes = distorted_tensor_dict[ + fields.InputDataFields.groundtruth_boxes] + distorted_keypoints = distorted_tensor_dict[ + fields.InputDataFields.groundtruth_keypoints] + distorted_labels = distorted_tensor_dict[ + fields.InputDataFields.groundtruth_classes] + distorted_weights = 
distorted_tensor_dict[ + fields.InputDataFields.groundtruth_weights] + expected_boxes = tf.constant( + [[0.178947, 0.07173, 0.75789469, 0.66244733], + [0.28421, -0.434599, 0.38947365, 0.57805908]], + dtype=tf.float32) + expected_keypoints = tf.constant( + [[[0.178947, 0.07173], [0.75789469, 0.66244733]], + [[0.28421, -0.434599], [0.38947365, 0.57805908]]], + dtype=tf.float32) + expected_labels = tf.constant([7, 11], dtype=tf.int32) + expected_weights = tf.constant([0.5, 0.6], dtype=tf.float32) + + with self.test_session() as sess: + (distorted_boxes_, distorted_keypoints_, distorted_labels_, + distorted_weights_, expected_boxes_, expected_keypoints_, + expected_labels_, expected_weights_) = sess.run( + [distorted_boxes, distorted_keypoints, distorted_labels, + distorted_weights, expected_boxes, expected_keypoints, + expected_labels, expected_weights]) + self.assertAllClose(distorted_boxes_, expected_boxes_) + self.assertAllClose(distorted_keypoints_, expected_keypoints_) + self.assertAllEqual(distorted_labels_, expected_labels_) + self.assertAllEqual(distorted_weights_, expected_weights_) def testRandomCropImageWithMultiClassScores(self): preprocessing_options = [] @@ -1304,12 +1369,14 @@ def testRandomCropImageWithMultiClassScores(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() multiclass_scores = self.createTestMultiClassScores() tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, fields.InputDataFields.multiclass_scores: multiclass_scores } distorted_tensor_dict = preprocessor.preprocess(tensor_dict, @@ -1342,11 +1409,11 @@ def testRandomCropImageWithMultiClassScores(self): self.assertAllEqual(distorted_boxes_.shape[0], distorted_multiclass_scores_.shape[0]) - def testStrictRandomCropImageWithLabelScores(self): + def testStrictRandomCropImageWithGroundtruthWeights(self): image = self.createColorfulTestImage()[0] boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() with mock.patch.object( tf.image, 'sample_distorted_bounding_box' @@ -1355,20 +1422,20 @@ def testStrictRandomCropImageWithLabelScores(self): tf.constant([6, 143, 0], dtype=tf.int32), tf.constant([190, 237, -1], dtype=tf.int32), tf.constant([[[0.03, 0.3575, 0.98, 0.95]]], dtype=tf.float32)) - new_image, new_boxes, new_labels, new_label_scores = ( + new_image, new_boxes, new_labels, new_groundtruth_weights = ( preprocessor._strict_random_crop_image( - image, boxes, labels, label_scores)) + image, boxes, labels, weights)) with self.test_session() as sess: - new_image, new_boxes, new_labels, new_label_scores = ( + new_image, new_boxes, new_labels, new_groundtruth_weights = ( sess.run( - [new_image, new_boxes, new_labels, new_label_scores]) + [new_image, new_boxes, new_labels, new_groundtruth_weights]) ) expected_boxes = np.array( [[0.0, 0.0, 0.75789469, 1.0], [0.23157893, 0.24050637, 0.75789469, 1.0]], dtype=np.float32) self.assertAllEqual(new_image.shape, [190, 237, 3]) - self.assertAllEqual(new_label_scores, [1.0, 0.5]) + self.assertAllEqual(new_groundtruth_weights, [1.0, 0.5]) self.assertAllClose( new_boxes.flatten(), expected_boxes.flatten()) @@ -1376,6 +1443,7 @@ def testStrictRandomCropImageWithMasks(self): image = self.createColorfulTestImage()[0] boxes = 
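The expected values in testRandomCropWithoutClipBoxes are just the original boxes re-expressed in the mocked crop window's coordinate frame and, because clip_boxes is False, left unclipped. A quick numpy check (editorial aside, not part of the patch):

import numpy as np

crop_window = np.array([0.03, 0.3575, 0.98, 0.95])  # mocked [ymin, xmin, ymax, xmax]
boxes = np.array([[0.2, 0.4, 0.75, 0.75],           # the two boxes that survive
                  [0.3, 0.1, 0.4, 0.7]])            # the overlap filtering

win_size = np.array([crop_window[2] - crop_window[0],   # window height
                     crop_window[3] - crop_window[1]])  # window width
# Same arithmetic as box_list_ops.change_coordinate_frame: shift, then rescale.
new_boxes = (boxes - np.tile(crop_window[:2], 2)) / np.tile(win_size, 2)

print(new_boxes.round(6))
# approximately [[ 0.178947  0.07173   0.757895  0.662447]
#                [ 0.284211 -0.434599  0.389474  0.578059]]
# The negative xmin survives only because clip_boxes is False in this test.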
self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() masks = tf.random_uniform([2, 200, 400], dtype=tf.float32) with mock.patch.object( tf.image, @@ -1385,12 +1453,12 @@ def testStrictRandomCropImageWithMasks(self): tf.constant([6, 143, 0], dtype=tf.int32), tf.constant([190, 237, -1], dtype=tf.int32), tf.constant([[[0.03, 0.3575, 0.98, 0.95]]], dtype=tf.float32)) - new_image, new_boxes, new_labels, new_masks = ( + new_image, new_boxes, new_labels, new_weights, new_masks = ( preprocessor._strict_random_crop_image( - image, boxes, labels, masks=masks)) + image, boxes, labels, weights, masks=masks)) with self.test_session() as sess: - new_image, new_boxes, new_labels, new_masks = sess.run( - [new_image, new_boxes, new_labels, new_masks]) + new_image, new_boxes, new_labels, new_weights, new_masks = sess.run( + [new_image, new_boxes, new_labels, new_weights, new_masks]) expected_boxes = np.array( [[0.0, 0.0, 0.75789469, 1.0], [0.23157893, 0.24050637, 0.75789469, 1.0]], dtype=np.float32) @@ -1403,6 +1471,7 @@ def testStrictRandomCropImageWithKeypoints(self): image = self.createColorfulTestImage()[0] boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() keypoints = self.createTestKeypoints() with mock.patch.object( tf.image, @@ -1412,12 +1481,12 @@ def testStrictRandomCropImageWithKeypoints(self): tf.constant([6, 143, 0], dtype=tf.int32), tf.constant([190, 237, -1], dtype=tf.int32), tf.constant([[[0.03, 0.3575, 0.98, 0.95]]], dtype=tf.float32)) - new_image, new_boxes, new_labels, new_keypoints = ( + new_image, new_boxes, new_labels, new_weights, new_keypoints = ( preprocessor._strict_random_crop_image( - image, boxes, labels, keypoints=keypoints)) + image, boxes, labels, weights, keypoints=keypoints)) with self.test_session() as sess: - new_image, new_boxes, new_labels, new_keypoints = sess.run( - [new_image, new_boxes, new_labels, new_keypoints]) + new_image, new_boxes, new_labels, new_weights, new_keypoints = sess.run( + [new_image, new_boxes, new_labels, new_weights, new_keypoints]) expected_boxes = np.array([ [0.0, 0.0, 0.75789469, 1.0], @@ -1440,12 +1509,14 @@ def testRunRandomCropImageWithMasks(self): image = self.createColorfulTestImage() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() masks = tf.random_uniform([2, 200, 400], dtype=tf.float32) tensor_dict = { fields.InputDataFields.image: image, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, fields.InputDataFields.groundtruth_instance_masks: masks, } @@ -1491,13 +1562,15 @@ def testRunRandomCropImageWithKeypointsInsideCrop(self): image = self.createColorfulTestImage() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() keypoints = self.createTestKeypointsInsideCrop() tensor_dict = { fields.InputDataFields.image: image, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, - fields.InputDataFields.groundtruth_keypoints: keypoints + fields.InputDataFields.groundtruth_keypoints: keypoints, + fields.InputDataFields.groundtruth_weights: weights } preprocessor_arg_map = preprocessor.get_default_func_arg_map( @@ -1551,12 +1624,14 @@ def testRunRandomCropImageWithKeypointsOutsideCrop(self): image = self.createColorfulTestImage() boxes = self.createTestBoxes() labels 
= self.createTestLabels() + weights = self.createTestGroundtruthWeights() keypoints = self.createTestKeypointsOutsideCrop() tensor_dict = { fields.InputDataFields.image: image, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, fields.InputDataFields.groundtruth_keypoints: keypoints } @@ -1610,33 +1685,32 @@ def testRunRandomCropImageWithKeypointsOutsideCrop(self): def testRunRetainBoxesAboveThreshold(self): boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() tensor_dict = { fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, - fields.InputDataFields.groundtruth_confidences: label_scores + fields.InputDataFields.groundtruth_weights: weights, } preprocessing_options = [ (preprocessor.retain_boxes_above_threshold, {'threshold': 0.6}) ] - preprocessor_arg_map = preprocessor.get_default_func_arg_map( - include_label_scores=True) + preprocessor_arg_map = preprocessor.get_default_func_arg_map() retained_tensor_dict = preprocessor.preprocess( tensor_dict, preprocessing_options, func_arg_map=preprocessor_arg_map) retained_boxes = retained_tensor_dict[ fields.InputDataFields.groundtruth_boxes] retained_labels = retained_tensor_dict[ fields.InputDataFields.groundtruth_classes] - retained_label_scores = retained_tensor_dict[ - fields.InputDataFields.groundtruth_confidences] + retained_weights = retained_tensor_dict[ + fields.InputDataFields.groundtruth_weights] with self.test_session() as sess: (retained_boxes_, retained_labels_, - retained_label_scores_, expected_retained_boxes_, - expected_retained_labels_, expected_retained_label_scores_) = sess.run( - [retained_boxes, retained_labels, retained_label_scores, + retained_weights_, expected_retained_boxes_, + expected_retained_labels_, expected_retained_weights_) = sess.run( + [retained_boxes, retained_labels, retained_weights, self.expectedBoxesAfterThresholding(), self.expectedLabelsAfterThresholding(), self.expectedLabelScoresAfterThresholding()]) @@ -1644,18 +1718,18 @@ def testRunRetainBoxesAboveThreshold(self): self.assertAllClose(retained_boxes_, expected_retained_boxes_) self.assertAllClose(retained_labels_, expected_retained_labels_) self.assertAllClose( - retained_label_scores_, expected_retained_label_scores_) + retained_weights_, expected_retained_weights_) def testRunRetainBoxesAboveThresholdWithMasks(self): boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() masks = self.createTestMasks() tensor_dict = { fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, - fields.InputDataFields.groundtruth_confidences: label_scores, + fields.InputDataFields.groundtruth_weights: weights, fields.InputDataFields.groundtruth_instance_masks: masks } @@ -1681,18 +1755,17 @@ def testRunRetainBoxesAboveThresholdWithMasks(self): def testRunRetainBoxesAboveThresholdWithKeypoints(self): boxes = self.createTestBoxes() labels = self.createTestLabels() - label_scores = self.createTestLabelScores() + weights = self.createTestGroundtruthWeights() keypoints = self.createTestKeypoints() tensor_dict = { fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, - fields.InputDataFields.groundtruth_confidences: 
label_scores, + fields.InputDataFields.groundtruth_weights: weights, fields.InputDataFields.groundtruth_keypoints: keypoints } preprocessor_arg_map = preprocessor.get_default_func_arg_map( - include_label_scores=True, include_keypoints=True) preprocessing_options = [ @@ -1721,12 +1794,14 @@ def testRunRandomCropToAspectRatioWithMasks(self): image = self.createColorfulTestImage() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() masks = tf.random_uniform([2, 200, 400], dtype=tf.float32) tensor_dict = { fields.InputDataFields.image: image, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, fields.InputDataFields.groundtruth_instance_masks: masks } @@ -1764,12 +1839,14 @@ def testRunRandomCropToAspectRatioWithKeypoints(self): image = self.createColorfulTestImage() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() keypoints = self.createTestKeypoints() tensor_dict = { fields.InputDataFields.image: image, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, fields.InputDataFields.groundtruth_keypoints: keypoints } @@ -2016,10 +2093,12 @@ def testRandomCropPadImageWithRandomCoefOne(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, } tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options) images = tensor_dict[fields.InputDataFields.image] @@ -2057,10 +2136,12 @@ def testRandomCropToAspectRatio(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, } tensor_dict = preprocessor.preprocess(tensor_dict, []) images = tensor_dict[fields.InputDataFields.image] @@ -2638,10 +2719,12 @@ def testSSDRandomCrop(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() tensor_dict = { fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, } distorted_tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options) @@ -2672,6 +2755,7 @@ def testSSDRandomCropWithMultiClassScores(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() multiclass_scores = self.createTestMultiClassScores() tensor_dict = { @@ -2679,6 +2763,7 @@ def testSSDRandomCropWithMultiClassScores(self): fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, fields.InputDataFields.multiclass_scores: multiclass_scores, + fields.InputDataFields.groundtruth_weights: weights, } preprocessor_arg_map = 
preprocessor.get_default_func_arg_map( include_multiclass_scores=True) @@ -2717,6 +2802,7 @@ def testSSDRandomCropPad(self): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() preprocessing_options = [ (preprocessor.normalize_image, { 'original_minval': 0, @@ -2729,6 +2815,7 @@ def testSSDRandomCropPad(self): fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights, } distorted_tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options) @@ -2764,13 +2851,13 @@ def testSSDRandomCropFixedAspectRatioWithCache(self): test_keypoints=False) def _testSSDRandomCropFixedAspectRatio(self, - include_label_scores, include_multiclass_scores, include_instance_masks, include_keypoints): images = self.createTestImages() boxes = self.createTestBoxes() labels = self.createTestLabels() + weights = self.createTestGroundtruthWeights() preprocessing_options = [(preprocessor.normalize_image, { 'original_minval': 0, 'original_maxval': 255, @@ -2781,11 +2868,8 @@ def _testSSDRandomCropFixedAspectRatio(self, fields.InputDataFields.image: images, fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_classes: labels, + fields.InputDataFields.groundtruth_weights: weights } - if include_label_scores: - label_scores = self.createTestLabelScores() - tensor_dict[fields.InputDataFields.groundtruth_confidences] = ( - label_scores) if include_multiclass_scores: multiclass_scores = self.createTestMultiClassScores() tensor_dict[fields.InputDataFields.multiclass_scores] = ( @@ -2798,7 +2882,6 @@ def _testSSDRandomCropFixedAspectRatio(self, tensor_dict[fields.InputDataFields.groundtruth_keypoints] = keypoints preprocessor_arg_map = preprocessor.get_default_func_arg_map( - include_label_scores=include_label_scores, include_multiclass_scores=include_multiclass_scores, include_instance_masks=include_instance_masks, include_keypoints=include_keypoints) @@ -2821,26 +2904,22 @@ def _testSSDRandomCropFixedAspectRatio(self, self.assertAllEqual(images_rank_, distorted_images_rank_) def testSSDRandomCropFixedAspectRatio(self): - self._testSSDRandomCropFixedAspectRatio(include_label_scores=False, - include_multiclass_scores=False, + self._testSSDRandomCropFixedAspectRatio(include_multiclass_scores=False, include_instance_masks=False, include_keypoints=False) def testSSDRandomCropFixedAspectRatioWithMultiClassScores(self): - self._testSSDRandomCropFixedAspectRatio(include_label_scores=False, - include_multiclass_scores=True, + self._testSSDRandomCropFixedAspectRatio(include_multiclass_scores=True, include_instance_masks=False, include_keypoints=False) def testSSDRandomCropFixedAspectRatioWithMasksAndKeypoints(self): - self._testSSDRandomCropFixedAspectRatio(include_label_scores=False, - include_multiclass_scores=False, + self._testSSDRandomCropFixedAspectRatio(include_multiclass_scores=False, include_instance_masks=True, include_keypoints=True) def testSSDRandomCropFixedAspectRatioWithLabelScoresMasksAndKeypoints(self): - self._testSSDRandomCropFixedAspectRatio(include_label_scores=True, - include_multiclass_scores=False, + self._testSSDRandomCropFixedAspectRatio(include_multiclass_scores=False, include_instance_masks=True, include_keypoints=True) diff --git a/research/object_detection/core/standard_fields.py b/research/object_detection/core/standard_fields.py index 
0cdddd8e4eb..bd82d60272f 100644 --- a/research/object_detection/core/standard_fields.py +++ b/research/object_detection/core/standard_fields.py @@ -44,7 +44,6 @@ class InputDataFields(object): groundtruth_image_confidences: image-level class confidences. groundtruth_boxes: coordinates of the ground truth boxes in the image. groundtruth_classes: box-level class labels. - groundtruth_confidences: box-level class confidences. groundtruth_label_types: box-level label types (e.g. explicit negative). groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] is the groundtruth a single object or a crowd. diff --git a/research/object_detection/core/target_assigner.py b/research/object_detection/core/target_assigner.py index 54f10fbc8d9..dfdaeccc928 100644 --- a/research/object_detection/core/target_assigner.py +++ b/research/object_detection/core/target_assigner.py @@ -130,7 +130,8 @@ def assign(self, cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the subshape [d_1, ..., d_k] is compatible with groundtruth_labels which has shape [num_gt_boxes, d_1, d_2, ... d_k]. - cls_weights: a float32 tensor with shape [num_anchors] + cls_weights: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], + representing weights for each element in cls_targets. reg_targets: a float32 tensor with shape [num_anchors, box_code_dimension] reg_weights: a float32 tensor with shape [num_anchors] match: a matcher.Match object encoding the match between anchors and @@ -195,6 +196,15 @@ def assign(self, cls_weights = self._create_classification_weights(match, groundtruth_weights) + # convert cls_weights from per-anchor to per-class. + class_label_shape = tf.shape(cls_targets)[1:] + weights_shape = tf.shape(cls_weights) + weights_multiple = tf.concat( + [tf.ones_like(weights_shape), class_label_shape], + axis=0) + for _ in range(len(cls_targets.get_shape()[1:])): + cls_weights = tf.expand_dims(cls_weights, -1) + cls_weights = tf.tile(cls_weights, weights_multiple) num_anchors = anchors.num_boxes_static() if num_anchors is not None: @@ -445,7 +455,8 @@ def batch_assign_targets(target_assigner, Returns: batch_cls_targets: a tensor with shape [batch_size, num_anchors, num_classes], - batch_cls_weights: a tensor with shape [batch_size, num_anchors], + batch_cls_weights: a tensor with shape [batch_size, num_anchors, + num_classes], batch_reg_targets: a tensor with shape [batch_size, num_anchors, box_code_dimension] batch_reg_weights: a tensor with shape [batch_size, num_anchors], diff --git a/research/object_detection/core/target_assigner_test.py b/research/object_detection/core/target_assigner_test.py index 5c52add14e7..98de26dd86c 100644 --- a/research/object_detection/core/target_assigner_test.py +++ b/research/object_detection/core/target_assigner_test.py @@ -52,7 +52,7 @@ def graph_fn(anchor_means, groundtruth_box_corners): [0.5, 0.5, 0.9, 0.9]], dtype=np.float32) exp_cls_targets = [[1], [1], [0]] - exp_cls_weights = [1, 1, 1] + exp_cls_weights = [[1], [1], [1]] exp_reg_targets = [[0, 0, 0, 0], [0, 0, -1, 1], [0, 0, 0, 0]] @@ -96,7 +96,7 @@ def graph_fn(anchor_means, groundtruth_box_corners): groundtruth_box_corners = np.array([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 0.9, 0.9]], dtype=np.float32) exp_cls_targets = [[1], [1], [0]] - exp_cls_weights = [1, 1, 0] + exp_cls_weights = [[1], [1], [0]] exp_reg_targets = [[0, 0, 0, 0], [0, 0, -1, 1], [0, 0, 0, 0]] @@ -143,7 +143,7 @@ def graph_fn(anchor_means, groundtruth_box_corners, [[0, 0.3], [0.2, 0.4], [0.5, 0.6], [0, 0.6], [0.8, 0.2], 
[0.2, 0.4]]], dtype=np.float32) exp_cls_targets = [[1], [1], [0]] - exp_cls_weights = [1, 1, 1] + exp_cls_weights = [[1], [1], [1]] exp_reg_targets = [[0, 0, 0, 0, -3, -1, -3, 1, -1, -1, -1, -1, -3, -3, 13, -5], [-1, -1, 0, 0, -15, -9, -11, -7, -5, -3, -15, -3, 1, -11, @@ -198,7 +198,7 @@ def graph_fn(anchor_means, groundtruth_box_corners, [[0, 0.3], [0.2, 0.4], [0.5, 0.6], [0, 0.6], [0.8, 0.2], [0.2, 0.4]]], dtype=np.float32) exp_cls_targets = [[1], [1], [0]] - exp_cls_weights = [1, 1, 1] + exp_cls_weights = [[1], [1], [1]] exp_reg_targets = [[0, 0, 0, 0, -3, -1, -3, 1, -1, -1, -1, -1, -3, -3, 13, -5], [-1, -1, 0, 0, -15, -9, -11, -7, -5, -3, -15, -3, 1, -11, @@ -254,7 +254,10 @@ def graph_fn(anchor_means, groundtruth_box_corners, groundtruth_labels): [0, 0, 0, 0, 0, 1, 0], [1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]] - exp_cls_weights = [1, 1, 1, 1] + exp_cls_weights = [[1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1]] exp_reg_targets = [[0, 0, 0, 0], [0, 0, -1, 1], [0, 0, 0, 0], @@ -308,7 +311,11 @@ def graph_fn(anchor_means, groundtruth_box_corners, groundtruth_labels, [0, 0, 0, 1, 0, 0, 0]], dtype=np.float32) groundtruth_weights = np.array([0.3, 0., 0.5], dtype=np.float32) - exp_cls_weights = [0.3, 0., 1, 0.5] # background class gets weight of 1. + # background class gets weight of 1. + exp_cls_weights = [[0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3], + [0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1], + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]] exp_reg_weights = [0.3, 0., 0., 0.5] # background class gets weight of 0. (cls_weights_out, reg_weights_out) = self.execute(graph_fn, [ @@ -354,7 +361,11 @@ def graph_fn(anchor_means, groundtruth_box_corners, groundtruth_labels): [.5, 0, 0, .5, 0, 0, 0]], dtype=np.float32) - exp_cls_weights = [1, 1, 1, 1] # background class gets weight of 1. + exp_cls_weights = [ + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1]] # background class gets weight of 1. exp_reg_weights = [.1, 1, 0., .5] # background class gets weight of 0. 
(cls_weights_out, reg_weights_out) = self.execute( @@ -400,7 +411,10 @@ def graph_fn(anchor_means, groundtruth_box_corners, groundtruth_labels): [[1, 0], [0, 1]], [[0, 0], [0, 0]], [[0, 1], [1, .5]]] - exp_cls_weights = [1, 1, 1, 1] + exp_cls_weights = [[[1, 1], [1, 1]], + [[1, 1], [1, 1]], + [[1, 1], [1, 1]], + [[1, 1], [1, 1]]] exp_reg_targets = [[0, 0, 0, 0], [0, 0, -1, 1], [0, 0, 0, 0], @@ -449,7 +463,10 @@ def graph_fn(anchor_means, groundtruth_box_corners, groundtruth_labels): [0, 0, 0], [0, 0, 0], [0, 0, 0]] - exp_cls_weights = [1, 1, 1, 1] + exp_cls_weights = [[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]] exp_reg_targets = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], @@ -555,6 +572,10 @@ def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2): [0, .1, .5, .5], [.75, .75, 1, 1]], dtype=np.float32) + exp_cls_targets = [[[1], [0], [0], [0]], + [[0], [1], [1], [0]]] + exp_cls_weights = [[[1], [1], [1], [1]], + [[1], [1], [1], [1]]] exp_reg_targets = [[[0, 0, -0.5, -0.5], [0, 0, 0, 0], [0, 0, 0, 0,], @@ -563,10 +584,6 @@ def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2): [0, 0.01231521, 0, 0], [0.15789001, -0.01500003, 0.57889998, -1.15799987], [0, 0, 0, 0]]] - exp_cls_weights = [[1, 1, 1, 1], - [1, 1, 1, 1]] - exp_cls_targets = [[[1], [0], [0], [0]], - [[0], [1], [1], [0]]] exp_reg_weights = [[1, 0, 0, 0], [0, 1, 1, 0]] @@ -608,17 +625,6 @@ def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2, [0, .25, 1, 1], [0, .1, .5, .5], [.75, .75, 1, 1]], dtype=np.float32) - - exp_reg_targets = [[[0, 0, -0.5, -0.5], - [0, 0, 0, 0], - [0, 0, 0, 0,], - [0, 0, 0, 0,],], - [[0, 0, 0, 0,], - [0, 0.01231521, 0, 0], - [0.15789001, -0.01500003, 0.57889998, -1.15799987], - [0, 0, 0, 0]]] - exp_cls_weights = [[1, 1, 1, 1], - [1, 1, 1, 1]] exp_cls_targets = [[[0, 1, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], @@ -627,6 +633,22 @@ def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2, [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0]]] + exp_cls_weights = [[[1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1]]] + exp_reg_targets = [[[0, 0, -0.5, -0.5], + [0, 0, 0, 0], + [0, 0, 0, 0,], + [0, 0, 0, 0,],], + [[0, 0, 0, 0,], + [0, 0.01231521, 0, 0], + [0.15789001, -0.01500003, 0.57889998, -1.15799987], + [0, 0, 0, 0]]] exp_reg_weights = [[1, 0, 0, 0], [0, 1, 1, 0]] @@ -678,16 +700,6 @@ def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2, [0, .1, .5, .5], [.75, .75, 1, 1]], dtype=np.float32) - exp_reg_targets = [[[0, 0, -0.5, -0.5], - [0, 0, 0, 0], - [0, 0, 0, 0,], - [0, 0, 0, 0,],], - [[0, 0, 0, 0,], - [0, 0.01231521, 0, 0], - [0.15789001, -0.01500003, 0.57889998, -1.15799987], - [0, 0, 0, 0]]] - exp_cls_weights = [[1, 1, 1, 1], - [1, 1, 1, 1]] exp_cls_targets = [[[0, 1, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], @@ -696,6 +708,22 @@ def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2, [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0]]] + exp_cls_weights = [[[1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1]]] + exp_reg_targets = [[[0, 0, -0.5, -0.5], + [0, 0, 0, 0], + [0, 0, 0, 0,], + [0, 0, 0, 0,],], + [[0, 0, 0, 0,], + [0, 0.01231521, 0, 0], + [0.15789001, -0.01500003, 0.57889998, -1.15799987], + [0, 0, 0, 0]]] exp_reg_weights = [[1, 0, 0, 0], [0, 1, 1, 0]] @@ -748,16 +776,6 @@ def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2, [0, .1, .5, .5], [.75, .75, 1, 1]], 
dtype=np.float32) - exp_reg_targets = [[[0, 0, -0.5, -0.5], - [0, 0, 0, 0], - [0, 0, 0, 0,], - [0, 0, 0, 0,],], - [[0, 0, 0, 0,], - [0, 0.01231521, 0, 0], - [0.15789001, -0.01500003, 0.57889998, -1.15799987], - [0, 0, 0, 0]]] - exp_cls_weights = [[1, 1, 1, 1], - [1, 1, 1, 1]] exp_cls_targets = [[[[0., 1., 1.], [1., 1., 0.]], [[0., 0., 0.], @@ -774,6 +792,30 @@ def graph_fn(anchor_means, groundtruth_boxlist1, groundtruth_boxlist2, [0., 0., 1.]], [[0., 0., 0.], [0., 0., 0.]]]] + exp_cls_weights = [[[[1., 1., 1.], + [1., 1., 1.]], + [[1., 1., 1.], + [1., 1., 1.]], + [[1., 1., 1.], + [1., 1., 1.]], + [[1., 1., 1.], + [1., 1., 1.]]], + [[[1., 1., 1.], + [1., 1., 1.]], + [[1., 1., 1.], + [1., 1., 1.]], + [[1., 1., 1.], + [1., 1., 1.]], + [[1., 1., 1.], + [1., 1., 1.]]]] + exp_reg_targets = [[[0, 0, -0.5, -0.5], + [0, 0, 0, 0], + [0, 0, 0, 0,], + [0, 0, 0, 0,],], + [[0, 0, 0, 0,], + [0, 0.01231521, 0, 0], + [0.15789001, -0.01500003, 0.57889998, -1.15799987], + [0, 0, 0, 0]]] exp_reg_weights = [[1, 0, 0, 0], [0, 1, 1, 0]] @@ -807,11 +849,12 @@ def graph_fn(anchor_means, groundtruth_box_corners, gt_class_targets): groundtruth_box_corners = np.zeros((0, 4), dtype=np.float32) anchor_means = np.array([[0, 0, .25, .25], [0, .25, 1, 1]], dtype=np.float32) - exp_reg_targets = [[[0, 0, 0, 0], - [0, 0, 0, 0]]] - exp_cls_weights = [[1, 1]] exp_cls_targets = [[[1, 0, 0, 0], [1, 0, 0, 0]]] + exp_cls_weights = [[[1, 1, 1, 1], + [1, 1, 1, 1]]] + exp_reg_targets = [[[0, 0, 0, 0], + [0, 0, 0, 0]]] exp_reg_weights = [[0, 0]] num_classes = 3 pad = 1 diff --git a/research/object_detection/data/face_label_map.pbtxt b/research/object_detection/data/face_label_map.pbtxt new file mode 100644 index 00000000000..1c7355db1fd --- /dev/null +++ b/research/object_detection/data/face_label_map.pbtxt @@ -0,0 +1,6 @@ +item { + name: "face" + id: 1 + display_name: "face" +} + diff --git a/research/object_detection/data/mscoco_minival_ids.txt b/research/object_detection/data/mscoco_minival_ids.txt new file mode 100644 index 00000000000..5bbff3c18d4 --- /dev/null +++ b/research/object_detection/data/mscoco_minival_ids.txt @@ -0,0 +1,8059 @@ +25096 +251824 +35313 +546011 +524186 +205866 +511403 +313916 +47471 +258628 +233560 +576017 +404517 +410056 +178690 +248980 +511724 +429718 +163076 +244111 +126766 +313182 +191981 +139992 +325237 +248129 +214519 +175438 +493321 +174103 +563762 +536795 +289960 +473720 +515540 +292118 +360851 +267175 +532876 +171613 +581415 +259819 +441841 +381682 +58157 +4980 +473929 +70626 +93773 +283412 +36765 +495020 +278401 +329307 +192810 +491784 +506416 +225495 +553747 +86442 +242208 +132686 +385877 +290248 +525705 +5476 +486521 +332512 +138556 +348083 +284375 +40018 +296994 +38685 +432429 +183407 +434358 +472164 +530494 +570693 +193401 +392612 +98872 +445766 +532209 +98322 +285114 +267725 +51605 +314812 +91105 +535506 +540264 +375341 +449828 +277659 +68933 +76873 +217554 +213592 +190776 +516224 +474479 +343599 +578813 +128669 +546292 +475365 +377626 +128833 +427091 +547227 +11742 +80213 +462241 +374574 +121572 +29151 +13892 +262394 +303667 +198724 +7320 +448492 +419080 +460379 +483965 +556516 +139181 +1103 +308715 +207507 +213827 +216083 +445597 +240275 +379585 +116389 +138124 +559051 +326898 +419386 +503660 +519460 +23893 +24458 +518109 +462982 +151492 +514254 +2477 +147165 +570394 +548766 +250083 +364341 +351967 +386277 +328084 +511299 +499349 +315501 +234965 +428562 +219771 +288150 +136021 +168619 +298316 +75118 +189752 +243857 +296222 +554002 +533628 +384596 +202981 +498350 +391463 +183991 +528062 
+451084 +7899 +408534 +329030 +318566 +22492 +361285 +226973 +213356 +417265 +105622 +161169 +261487 +167477 +233370 +142999 +256713 +305833 +103579 +352538 +135763 +392144 +61181 +200302 +456908 +286858 +179850 +488075 +174511 +194755 +317822 +2302 +304596 +172556 +548275 +341678 +55299 +134760 +352936 +545129 +377012 +141328 +103757 +552837 +28246 +125167 +328745 +278760 +337133 +403389 +146825 +502558 +265916 +428985 +492041 +113403 +372037 +306103 +287574 +187495 +479805 +336309 +162043 +95899 +43133 +464248 +149115 +247438 +74030 +130645 +282841 +127092 +101172 +536743 +179642 +58133 +49667 +170605 +11347 +365277 +201970 +292663 +217219 +463226 +41924 +281102 +357816 +490878 +100343 +525058 +133503 +416145 +29341 +415413 +125527 +507951 +262609 +240210 +581781 +345137 +526342 +268641 +328777 +32001 +137538 +39115 +415958 +6771 +421865 +64909 +383601 +206907 +420840 +370980 +28452 +571893 +153520 +185890 +392991 +547013 +257359 +279879 +478614 +131919 +40937 +22874 +173375 +106344 +44801 +205401 +312870 +400886 +351530 +344013 +173500 +470423 +396729 +402499 +276585 +377097 +367619 +518908 +263866 +332292 +67805 +152211 +515025 +221350 +525247 +78490 +504342 +95908 +82668 +256199 +220270 +552065 +242379 +84866 +152281 +228464 +223122 +67537 +456968 +368349 +101985 +14681 +543551 +107558 +372009 +99054 +126540 +86877 +492785 +482585 +571564 +501116 +296871 +20395 +181518 +568041 +121154 +56187 +190018 +97156 +310325 +393274 +214574 +243222 +289949 +452121 +150508 +341752 +310757 +24040 +228551 +335589 +12020 +529597 +459884 +344888 +229713 +51948 +370929 +552061 +261072 +120070 +332067 +263014 +158993 +451714 +397327 +20965 +414340 +574946 +370266 +487534 +492246 +264771 +73702 +43997 +235124 +301093 +400048 +77681 +58472 +331386 +13783 +242513 +419158 +59325 +383033 +393258 +529041 +249276 +182775 +351793 +9727 +334069 +566771 +539355 +38662 +423617 +47559 +120592 +508303 +462565 +47916 +218208 +182362 +562101 +441442 +71239 +395378 +522637 +25603 +484450 +872 +171483 +527248 +323155 +240754 +15032 +419144 +313214 +250917 +333430 +242757 +221914 +283190 +194297 +228506 +550691 +172513 +312192 +530619 +113867 +323552 +374115 +35435 +160239 +62877 +441873 +196574 +62858 +557114 +427612 +242869 +356733 +304828 +24880 +490509 +407083 +457877 +402788 +536416 +385912 +544121 +500389 +451102 +12120 +483476 +70987 +482799 +542549 +49236 +424258 +435783 +182366 +438093 +501824 +232845 +53965 +223198 +288933 +450458 +285664 +196484 +408930 +519815 +290981 +398567 +315792 +490683 +257136 +75611 +302498 +332153 +82293 +416911 +558608 +564659 +536195 +370260 +57904 +527270 +6593 +145620 +551650 +470832 +515785 +251404 +287331 +150788 +334006 +266117 +10039 +579158 +328397 +468351 +550400 +31745 +405970 +16761 +323515 +459598 +558457 +570736 +476939 +472610 +72155 +112517 +13659 +530905 +458768 +43486 +560893 +493174 +31217 +262736 +412204 +142722 +151231 +480643 +197245 +398666 +444869 +110999 +191724 +479057 +492420 +170638 +277329 +301908 +395644 +537611 +141887 +47149 +403432 +34818 +372495 +67994 +337497 +478586 +249815 +533462 +281032 +289941 +151911 +271215 +407868 +360700 +508582 +103873 +353658 +369081 +406403 +331692 +26430 +105655 +572630 +37181 +91336 +484587 +318284 +113019 +33055 +25293 +229324 +374052 +384111 +213951 +315195 +319283 +539453 +17655 +308974 +326243 +539436 +417876 +526940 +356347 +221932 +73753 +292648 +262284 +304924 +558587 +374858 +253518 +311744 +539636 +40924 +136624 +334305 +365997 +63355 +191226 +526732 +367128 +575198 +500657 +50637 +17182 +424792 +565353 
+563040 +383494 +74458 +155142 +197125 +223857 +428241 +440830 +371289 +437303 +330449 +93771 +82715 +499631 +381257 +563951 +192834 +528600 +404273 +270554 +208053 +188613 +484760 +432016 +129800 +91756 +523097 +317018 +487282 +444913 +159500 +126822 +540564 +105812 +560756 +306099 +471226 +123842 +513219 +154877 +497034 +283928 +564003 +238602 +194780 +462728 +558640 +524373 +455624 +3690 +560367 +316351 +455772 +223777 +161517 +243034 +250440 +239975 +441008 +324715 +152106 +246973 +462805 +296521 +412767 +530913 +370165 +292526 +107244 +217440 +330204 +220176 +577735 +197022 +127451 +518701 +212322 +204887 +27696 +348474 +119233 +282804 +230040 +425690 +409241 +296825 +296353 +375909 +123136 +573891 +338256 +198247 +373375 +151051 +500084 +557596 +120478 +44989 +283380 +149005 +522065 +626 +17198 +309633 +524245 +291589 +322714 +455847 +248468 +371948 +444928 +20438 +481670 +147195 +95022 +548159 +553165 +395324 +391371 +86884 +561121 +219737 +38875 +338159 +377881 +185472 +359277 +114861 +378048 +126226 +10217 +320246 +15827 +178236 +370279 +352978 +408101 +77615 +337044 +223714 +20796 +352445 +263834 +156704 +377867 +119402 +399567 +1180 +257941 +560675 +390471 +209290 +258382 +466339 +56437 +195042 +384230 +203214 +36077 +283038 +38323 +158770 +532381 +395903 +375461 +397857 +326798 +371699 +369503 +495626 +464328 +462211 +397719 +434089 +424793 +476770 +531852 +303538 +525849 +480917 +419653 +265063 +48956 +5184 +279149 +396727 +374266 +124429 +36124 +240213 +147556 +339512 +577182 +288599 +257169 +178254 +393869 +122314 +28713 +48133 +540681 +100974 +368459 +500110 +73634 +460982 +203878 +578344 +443602 +502012 +399666 +103603 +22090 +257529 +176328 +536656 +408873 +116881 +460972 +33835 +460781 +51223 +46463 +89395 +407646 +337453 +461715 +16257 +426987 +234889 +3125 +165643 +517472 +451435 +206800 +112128 +331236 +163306 +94185 +498716 +532732 +146509 +458567 +153832 +105996 +353398 +546976 +283060 +247624 +110048 +243491 +154798 +543600 +149962 +355256 +352900 +203081 +372203 +284605 +516244 +190494 +150301 +326082 +64146 +402858 +413538 +399510 +460251 +94336 +458721 +57345 +424162 +423508 +69356 +567220 +509786 +37038 +111535 +341318 +372067 +358120 +244909 +180653 +39852 +438560 +357041 +67065 +51928 +171717 +520430 +552395 +431355 +528084 +20913 +309610 +262323 +573784 +449485 +154846 +283438 +430871 +199578 +516318 +563912 +348483 +485613 +143440 +94922 +168817 +74457 +45830 +66297 +514173 +99186 +296236 +230903 +452312 +476444 +568981 +100811 +237350 +194724 +453622 +49559 +270609 +113701 +415393 +92173 +137004 +188795 +148280 +448114 +575964 +163155 +518719 +219329 +214247 +363927 +65357 +87617 +552612 +457817 +124796 +47740 +560463 +513968 +273637 +354212 +95959 +261061 +307265 +316237 +191342 +463272 +169273 +396518 +93261 +572733 +407386 +202658 +446497 +420852 +229274 +432724 +34900 +352533 +49891 +66144 +146831 +467484 +97988 +561647 +301155 +507421 +173217 +577584 +451940 +99927 +350639 +178941 +485155 +175948 +360673 +92963 +361321 +48739 +577310 +517795 +93405 +506458 +394681 +167920 +16995 +519573 +270532 +527750 +563403 +494608 +557780 +178691 +8676 +186927 +550173 +361656 +575911 +281315 +534377 +57570 +340894 +37624 +143103 +538243 +425077 +376545 +108129 +170974 +7522 +408906 +264279 +79415 +344025 +186797 +234349 +226472 +123639 +225177 +237984 +38714 +223671 +358247 +152465 +521405 +453722 +361111 +557117 +235832 +309341 +268469 +108353 +532531 +357279 +537280 +437618 +122953 +7088 +36693 +127659 +431901 +57244 +567565 +568111 +202926 +504516 +555685 
+322369 +347620 +110231 +568982 +295340 +529798 +300341 +158160 +73588 +119476 +387216 +154994 +259755 +211282 +433971 +263588 +299468 +570138 +123017 +355106 +540172 +406215 +8401 +548844 +161820 +396432 +495348 +222407 +53123 +491556 +108130 +440617 +448309 +22596 +346841 +213829 +135076 +56326 +233139 +487418 +227326 +137763 +383389 +47882 +207797 +167452 +112065 +150703 +421109 +171753 +158279 +240800 +66821 +152886 +163640 +475466 +301799 +106712 +470885 +536370 +420389 +396768 +281950 +18903 +357529 +33650 +168243 +201004 +389295 +557150 +185327 +181256 +557396 +182025 +61564 +301928 +332455 +199403 +18444 +177452 +204206 +38465 +215906 +153103 +445019 +324527 +299207 +429281 +574675 +157067 +241269 +100850 +502818 +576566 +296775 +873 +280363 +355240 +383445 +286182 +67327 +422778 +494855 +337246 +266853 +47516 +381991 +44081 +403862 +381430 +370798 +173383 +387173 +22396 +484066 +349414 +262235 +492814 +65238 +209420 +336276 +453328 +407286 +420490 +360328 +158440 +398534 +489475 +477389 +297108 +69750 +507833 +198992 +99736 +546444 +514914 +482574 +54355 +63478 +191693 +61684 +412914 +267408 +424641 +56872 +318080 +30290 +33441 +199310 +337403 +26731 +453390 +506137 +188945 +185950 +239843 +357944 +290570 +523637 +551952 +513397 +357870 +523517 +277048 +259879 +186991 +521943 +21900 +281074 +187194 +526723 +568147 +513037 +177338 +243831 +203488 +208494 +188460 +289943 +399177 +404668 +160761 +271143 +76087 +478922 +440045 +449432 +61025 +331138 +227019 +147577 +548337 +444294 +458663 +236837 +6854 +444926 +484816 +516641 +397863 +188534 +64822 +213453 +66561 +43218 +514901 +322844 +498453 +488788 +391656 +298994 +64088 +464706 +193720 +199017 +186427 +15278 +350386 +342335 +372024 +550939 +35594 +381382 +235902 +26630 +213765 +550001 +129706 +577149 +353096 +376891 +28499 +427041 +314965 +231163 +5728 +347836 +184388 +27476 +284860 +476872 +301317 +99546 +147653 +529515 +311922 +20777 +2613 +59463 +430670 +560744 +60677 +332087 +296724 +353321 +103306 +363887 +76431 +423058 +120340 +119452 +6723 +462327 +163127 +402723 +489382 +183181 +107656 +375409 +355228 +430762 +512468 +409125 +270544 +559113 +495388 +529434 +38355 +422025 +379667 +131386 +183409 +573536 +581317 +425404 +350084 +472 +28532 +329717 +230220 +187196 +484166 +97434 +224595 +87483 +516998 +314876 +32610 +514586 +344816 +394418 +402330 +305993 +371497 +315790 +294908 +207431 +561014 +26584 +368671 +374990 +54747 +47571 +449424 +283761 +84735 +522127 +120473 +524656 +479659 +131627 +450959 +153300 +580908 +207785 +49115 +284991 +96505 +278306 +291655 +1404 +489304 +557459 +37740 +157465 +390475 +119166 +33871 +247428 +75905 +20779 +65035 +333556 +375415 +383676 +505243 +87327 +16451 +287235 +70190 +245067 +417520 +229234 +183786 +333018 +554156 +198915 +108021 +128262 +412443 +242543 +555050 +436511 +445233 +207886 +156397 +526257 +521357 +413043 +427189 +401614 +94823 +351130 +105945 +182314 +305879 +526197 +64409 +496800 +236461 +138175 +43816 +185904 +345711 +72536 +526737 +360400 +556537 +426053 +59044 +28290 +222548 +434915 +418623 +246454 +111801 +12448 +427133 +459117 +11262 +169045 +469996 +304390 +513096 +322822 +196371 +504977 +395364 +243950 +216218 +417217 +106736 +58194 +504101 +478522 +379314 +30432 +207027 +297146 +91844 +176031 +98287 +278095 +196053 +343692 +523137 +220224 +349485 +376193 +407067 +185781 +37871 +336464 +46331 +44244 +80274 +170147 +361106 +468499 +537864 +467457 +267343 +291528 +287828 +555648 +388284 +576085 +531973 +350122 +422253 +509811 +78093 +410019 +133090 +581205 +343976 
+9007 +92478 +450674 +486306 +503978 +46378 +335578 +404071 +225558 +217923 +406217 +138054 +575815 +234990 +336257 +159240 +399516 +226408 +531126 +138599 +61693 +89861 +29504 +163296 +477906 +48419 +25595 +195594 +97592 +392555 +203849 +139248 +245651 +275755 +245426 +127279 +521359 +517623 +235747 +475906 +11198 +336101 +70134 +505447 +218996 +30080 +484457 +120441 +575643 +132703 +197915 +505576 +90956 +99741 +517819 +240918 +150834 +207306 +132682 +88250 +213599 +462584 +413321 +361521 +496081 +410583 +440027 +417284 +397069 +280498 +473171 +129739 +279774 +29370 +518899 +509867 +85556 +434930 +280710 +55077 +348793 +157756 +281111 +190689 +281447 +502854 +232894 +268742 +199553 +220808 +137330 +256903 +116017 +466416 +41635 +110906 +340934 +557501 +146767 +517617 +487159 +1561 +417281 +489014 +292463 +113533 +412247 +263973 +515444 +343561 +310200 +293804 +225867 +150320 +183914 +9707 +89999 +177842 +296524 +287829 +68300 +363654 +465986 +159969 +313948 +522779 +219820 +198352 +12959 +266727 +8016 +175804 +497867 +307892 +287527 +309638 +205854 +114119 +23023 +322586 +383341 +134198 +553522 +70426 +329138 +105367 +175597 +187791 +17944 +366611 +93493 +242422 +41842 +558840 +32203 +19667 +124297 +383726 +252625 +234794 +498228 +102906 +287967 +69021 +51326 +243896 +509423 +440124 +122582 +344325 +34455 +442478 +23587 +236904 +185633 +349841 +44294 +112568 +186296 +71914 +3837 +135486 +223747 +557517 +385181 +265313 +404263 +26564 +516867 +497096 +332351 +345139 +444304 +510877 +356387 +561214 +311471 +408789 +561729 +291380 +174671 +45710 +435136 +388858 +361693 +50811 +531134 +573605 +340175 +534988 +382671 +327047 +348400 +547137 +401037 +490711 +499266 +236370 +449075 +334015 +107234 +232315 +462953 +252048 +186822 +410168 +28994 +45550 +453626 +417957 +468577 +106338 +391684 +375143 +217622 +357903 +347648 +142182 +213843 +299148 +352587 +436676 +161875 +144655 +304741 +235017 +181799 +211042 +335507 +553731 +412531 +229740 +437129 +423830 +561806 +337666 +52016 +138057 +70254 +494393 +73119 +262425 +565395 +305329 +489611 +377080 +569450 +549766 +332940 +235302 +53893 +203781 +38449 +114870 +18699 +396338 +449839 +423613 +379767 +369594 +375812 +359219 +229311 +291675 +224907 +416885 +32964 +573406 +17282 +103375 +81860 +576886 +461334 +35672 +243442 +217269 +445055 +211112 +455675 +412384 +88967 +550643 +24223 +504074 +9275 +155546 +329542 +172658 +331600 +315492 +194208 +162867 +324614 +432017 +140860 +157944 +406616 +486079 +361172 +258346 +494140 +315384 +451014 +242619 +413684 +386187 +408501 +121089 +343603 +232538 +558671 +551596 +32992 +406647 +435260 +11156 +40896 +175382 +110560 +252968 +189694 +63154 +564816 +72004 +164788 +434583 +453104 +111878 +268484 +290768 +473215 +450620 +32673 +277479 +529917 +315868 +562419 +378347 +398637 +84097 +120527 +134193 +431472 +400238 +86426 +208830 +524535 +22213 +516813 +526044 +386193 +246672 +386739 +559252 +153344 +236123 +246074 +323615 +92644 +408621 +323231 +499940 +296105 +578902 +150098 +145015 +131431 +318618 +68409 +497928 +362520 +467755 +112702 +163219 +277289 +192362 +497674 +525439 +56267 +465868 +407570 +551608 +345211 +179653 +55295 +97315 +534041 +505822 +411082 +132375 +25378 +272008 +536605 +123511 +148737 +577712 +493751 +29587 +468297 +528458 +491058 +558976 +181421 +209685 +147545 +486964 +570516 +168662 +19446 +395997 +242911 +232511 +317035 +354527 +5961 +513793 +124390 +370123 +113397 +195790 +252813 +326919 +432414 +409239 +458221 +115667 +212239 +279279 +375554 +546622 +317188 +260818 +286021 +377111 
+209868 +243148 +132037 +560624 +459721 +193498 +22623 +254164 +112841 +383470 +62692 +227940 +471335 +44858 +213649 +179898 +102837 +474078 +44478 +256197 +309492 +182923 +421139 +275695 +104965 +480780 +449749 +76513 +578591 +336695 +247474 +320490 +246105 +53183 +485740 +575823 +510735 +290741 +37017 +348708 +279784 +453634 +567644 +434192 +482719 +435324 +544299 +106896 +569926 +301574 +492885 +103462 +487151 +513585 +219647 +303685 +459645 +76292 +188579 +154883 +207728 +425074 +310493 +27221 +371694 +119404 +399665 +273556 +454577 +580698 +267664 +295769 +423740 +22461 +22667 +508443 +390401 +369997 +524627 +193349 +132223 +576743 +130586 +487741 +107542 +501420 +520109 +308156 +540581 +231362 +86471 +472930 +351133 +463605 +575577 +159842 +39504 +223020 +63525 +298627 +139883 +375205 +303549 +16838 +495680 +408112 +394474 +188044 +472143 +463751 +31481 +378139 +190853 +442614 +172006 +140270 +133051 +178028 +495090 +88455 +13232 +46323 +346275 +425905 +487013 +433136 +514402 +521906 +4157 +61418 +567205 +213351 +304008 +296492 +506561 +408120 +415961 +323186 +480379 +349199 +201918 +135023 +456483 +136173 +237917 +4972 +99081 +331569 +150007 +36450 +93400 +487461 +203629 +218093 +487181 +113935 +139512 +210981 +358883 +47419 +248382 +80357 +462663 +83097 +26159 +80429 +283055 +452676 +50159 +12326 +29430 +303264 +158122 +569070 +52925 +534876 +46975 +426376 +170293 +434417 +235517 +218476 +445008 +482774 +305632 +116848 +557252 +229270 +453485 +382214 +54759 +59171 +193328 +17152 +238071 +148531 +409725 +75434 +65358 +473057 +415408 +579415 +48636 +269606 +298784 +162799 +356400 +326854 +24601 +66499 +340247 +20992 +190218 +548464 +122203 +405306 +495376 +536028 +5713 +206831 +9395 +503939 +194440 +474253 +395849 +165141 +204935 +412621 +402922 +87141 +570664 +202622 +137362 +221737 +78947 +112129 +341957 +169562 +164780 +360216 +107641 +415015 +444955 +559102 +123070 +176592 +309366 +116461 +222075 +530470 +214363 +414487 +471567 +292123 +370210 +364243 +510254 +396350 +141524 +220310 +398604 +145436 +392476 +17482 +78032 +336171 +130812 +489743 +346638 +418854 +139072 +263860 +458240 +383443 +337533 +182334 +535608 +517946 +489924 +308117 +129945 +59973 +538364 +513458 +449433 +25165 +335851 +487688 +153834 +347612 +349689 +443688 +486008 +479149 +442286 +61108 +315338 +511546 +506444 +775 +121839 +291412 +497626 +387223 +367095 +557896 +196118 +530652 +447991 +215622 +232160 +296731 +272273 +473415 +364705 +235790 +479950 +141278 +547903 +66523 +353989 +121875 +237735 +100083 +348941 +288983 +390083 +168248 +120776 +489764 +219135 +551713 +256035 +309005 +112493 +579759 +114972 +458992 +295768 +158497 +309696 +363844 +507966 +313491 +280779 +327130 +292901 +127761 +183843 +456521 +164475 +224281 +443713 +72514 +567383 +476215 +565650 +17708 +474471 +248334 +196313 +164759 +212453 +319024 +332916 +35436 +113139 +172716 +7570 +161609 +144534 +137475 +561411 +45844 +332027 +36990 +190160 +421231 +283210 +365611 +511407 +400887 +485071 +481214 +347203 +153506 +397403 +229599 +357322 +76034 +101189 +567444 +92363 +526767 +218811 +362812 +339120 +579696 +399269 +10705 +549012 +410428 +105623 +535307 +419235 +119911 +236604 +515779 +188173 +66397 +549119 +478742 +256180 +128224 +440539 +112818 +315434 +97513 +171970 +433483 +226008 +83217 +424548 +343753 +350334 +479280 +208808 +43266 +399893 +444386 +47687 +499093 +565269 +465835 +167486 +433460 +169872 +299640 +158466 +241373 +50576 +161567 +73560 +349804 +181745 +352684 +450357 +532693 +88335 +256518 +94926 +541197 +14629 +276149 
+539439 +498738 +25654 +291330 +146465 +160190 +513064 +75748 +499007 +164464 +134042 +422416 +543315 +34056 +303197 +394801 +293071 +44964 +529083 +414522 +331180 +227599 +581040 +382850 +159898 +176841 +205352 +540782 +406591 +184499 +14380 +350230 +458175 +528786 +314935 +111086 +2191 +20371 +337042 +558371 +296907 +539937 +511463 +574856 +87864 +403817 +152598 +169712 +533227 +173545 +478862 +19455 +258433 +373440 +460229 +525682 +176857 +525050 +277025 +156416 +206784 +415179 +183204 +210374 +312868 +514366 +65208 +376342 +515792 +383066 +85247 +119132 +338007 +88748 +206705 +495808 +532164 +150686 +35474 +207860 +111165 +391199 +346011 +537721 +11390 +487482 +360983 +400347 +92795 +347506 +324322 +371958 +101280 +222842 +563604 +210299 +150616 +96351 +330455 +273551 +228749 +248051 +495252 +372265 +52664 +191874 +157416 +446428 +136681 +1228 +321811 +93791 +477867 +192520 +157124 +40620 +200541 +103904 +329494 +60093 +112573 +489125 +513115 +322968 +561619 +74309 +572462 +248252 +375376 +217312 +243213 +79878 +452218 +349754 +554291 +434043 +460373 +452591 +567787 +504711 +196007 +511153 +312416 +296056 +308849 +203667 +253223 +331230 +465545 +363048 +69392 +301506 +216198 +147979 +6005 +381870 +56983 +320972 +144122 +210855 +151480 +299288 +462486 +103931 +321079 +4134 +239861 +540006 +413805 +221222 +198943 +450790 +380597 +388298 +58737 +246197 +160726 +398554 +513946 +222235 +323851 +364703 +125643 +169800 +445662 +223764 +575372 +489207 +559474 +7155 +453819 +402720 +102355 +415076 +287436 +35705 +111076 +395865 +310862 +570834 +54728 +215778 +80053 +35148 +350488 +524140 +190097 +36661 +302110 +96884 +383397 +245462 +446958 +138937 +424712 +561814 +276964 +148034 +411068 +357824 +103257 +322149 +508899 +580294 +214386 +114419 +271429 +168260 +209835 +573072 +252269 +31980 +161308 +281508 +192714 +247599 +188948 +180563 +419601 +233660 +154804 +311846 +181499 +5535 +175082 +531018 +412338 +166995 +441411 +427820 +516846 +287366 +67959 +271266 +330845 +74209 +508167 +542699 +66485 +453756 +158412 +443784 +118097 +265050 +29074 +152623 +532493 +292988 +530384 +192660 +502336 +472648 +151657 +351626 +241010 +115070 +268356 +539557 +304698 +251140 +497158 +527445 +385428 +179200 +512394 +184978 +141910 +36311 +579457 +19129 +424960 +181714 +126216 +512911 +488360 +379533 +337551 +325410 +364587 +468885 +211107 +90062 +500446 +105960 +451951 +431431 +134178 +164548 +173826 +373988 +15157 +3091 +393557 +380011 +75372 +37403 +209995 +493610 +315899 +353299 +355040 +547000 +86133 +58174 +377326 +510230 +480583 +158588 +432529 +311206 +127626 +239980 +166340 +104185 +405174 +507211 +542782 +448078 +253477 +542694 +567308 +214853 +288824 +283268 +480757 +503200 +221089 +112388 +171539 +124452 +224200 +206362 +428754 +256192 +119414 +351620 +330050 +547504 +216398 +94261 +19916 +163242 +432588 +143824 +361103 +271138 +260150 +313627 +141086 +308263 +388453 +153217 +372794 +514787 +251910 +351335 +92683 +465836 +18442 +404128 +208476 +47873 +303219 +201622 +367489 +32760 +436174 +401926 +338419 +45248 +328464 +312216 +156282 +315702 +300701 +345401 +515350 +29094 +284296 +466449 +351057 +110672 +364853 +10014 +415828 +397522 +451412 +433124 +158277 +93476 +183387 +109889 +223326 +105547 +530061 +256301 +526778 +80974 +86650 +45835 +202154 +92678 +315991 +423919 +455044 +491168 +272253 +146627 +285349 +86001 +44171 +162332 +257328 +432820 +519275 +380639 +269436 +236016 +543215 +346752 +575970 +423498 +136926 +195648 +126634 +133078 +138656 +490012 +122388 +195165 +434900 +533625 +504167 
+333697 +216576 +538775 +125072 +391154 +545007 +150292 +566717 +367362 +490991 +356623 +141271 +402795 +516786 +39499 +536716 +293324 +212853 +276381 +57124 +325992 +394659 +452178 +117674 +461172 +518586 +497021 +462345 +526570 +17328 +202928 +62566 +411277 +256983 +49473 +211206 +398031 +277955 +531178 +453959 +27946 +252844 +30273 +536933 +500298 +229111 +7977 +27642 +303726 +79927 +110313 +527691 +442205 +33345 +365851 +233236 +239157 +409221 +400803 +32947 +422516 +359727 +215872 +559454 +289716 +450247 +57827 +312298 +530383 +260048 +35857 +224222 +299533 +13296 +325907 +117869 +54088 +391011 +340478 +205344 +347823 +468604 +78701 +101414 +197499 +490871 +89273 +380343 +441974 +35974 +486114 +354398 +535536 +294030 +7276 +278742 +137028 +98721 +372764 +429802 +72105 +220307 +116845 +195406 +333000 +130401 +264382 +125458 +363036 +286994 +531070 +113801 +4108 +47603 +130118 +573924 +302990 +237566 +21470 +577926 +139436 +425925 +36844 +63602 +399791 +35894 +347228 +225617 +504813 +245320 +466007 +553931 +166731 +164885 +19090 +457262 +247806 +502895 +167593 +352491 +520 +26386 +497348 +352000 +386164 +32901 +730 +30925 +333167 +150361 +231747 +462244 +504958 +260738 +313762 +346645 +486118 +202998 +541613 +183884 +230245 +83172 +126638 +51844 +421673 +118625 +377723 +229427 +371326 +104345 +361687 +114246 +397354 +104137 +120850 +260516 +389168 +234555 +26348 +78522 +409784 +303024 +377949 +69887 +546983 +113736 +298197 +476810 +137315 +376321 +410337 +492905 +119785 +158167 +185930 +354061 +106563 +328452 +506587 +536517 +480173 +570688 +376441 +252127 +247720 +132554 +41923 +400317 +170041 +151938 +198650 +6437 +49091 +221820 +455966 +309859 +300659 +15850 +388014 +253386 +65415 +238228 +548882 +302155 +93483 +371869 +397287 +315249 +360564 +448410 +21382 +477474 +144862 +517515 +230190 +322353 +231568 +14940 +132719 +498942 +182469 +113720 +168890 +94852 +246077 +117535 +52596 +419116 +522020 +255338 +125228 +564332 +106375 +249534 +220915 +177758 +293057 +222430 +196878 +554980 +375606 +173081 +84936 +418907 +562229 +457616 +125700 +66038 +239274 +574110 +305540 +98431 +167347 +53345 +438481 +286010 +5569 +343606 +168898 +191301 +236338 +291394 +715 +520237 +236954 +192212 +524002 +471625 +476029 +413124 +203455 +483328 +476417 +114389 +372428 +369221 +322654 +388157 +561314 +264540 +418680 +359540 +426182 +521613 +92248 +74478 +398905 +554273 +125909 +430583 +418959 +503522 +382999 +403145 +536375 +352618 +108193 +279696 +163253 +439007 +204536 +552186 +269926 +372147 +399921 +201418 +240565 +471483 +91619 +393971 +331648 +385856 +567440 +81922 +391722 +372894 +535997 +134096 +545958 +239943 +186929 +34222 +177714 +277812 +197111 +281878 +532003 +557172 +142890 +196116 +385454 +322845 +374987 +123137 +255112 +111207 +304819 +523526 +336046 +42893 +241273 +240049 +90659 +271364 +408008 +253282 +167067 +354278 +178317 +229653 +93333 +163666 +566920 +495199 +100329 +218119 +558864 +257382 +406152 +206587 +420339 +325919 +278853 +555763 +293200 +151000 +209664 +79380 +197177 +353953 +464522 +392260 +46144 +154202 +164366 +206025 +511236 +24921 +497907 +393226 +318138 +364125 +157321 +492395 +187857 +109939 +441500 +144251 +368581 +51403 +283498 +43555 +89356 +404601 +23272 +425762 +460682 +544629 +209829 +322029 +199247 +307262 +571242 +124236 +162393 +104829 +250766 +563938 +237399 +131516 +483001 +21994 +97958 +540187 +264497 +384808 +343187 +51277 +6712 +566103 +435384 +292082 +359039 +165157 +267972 +263796 +489313 +392722 +541924 +554433 +571034 +146112 +201934 +518716 +64116 
+294992 +289586 +159970 +479617 +269006 +140465 +513260 +554805 +6579 +452696 +34445 +548296 +372983 +509656 +199339 +130030 +128372 +449454 +139306 +247914 +99024 +499134 +536653 +468917 +412813 +404338 +215303 +455414 +413497 +574988 +397117 +188631 +378701 +241867 +143129 +419884 +412749 +496954 +317732 +16977 +398309 +162363 +147576 +100016 +209018 +92660 +173302 +525732 +449198 +99734 +12733 +172946 +168032 +210988 +340697 +4795 +534887 +483553 +278323 +178175 +190095 +357542 +230432 +227460 +334609 +562121 +378126 +555357 +325666 +451859 +526837 +531710 +297249 +294839 +499785 +254976 +527220 +173057 +11760 +163012 +215998 +114420 +57812 +563712 +513887 +201859 +36333 +291990 +338375 +460621 +518889 +337502 +133050 +80172 +537007 +295270 +335644 +227852 +336044 +204137 +82259 +165675 +295713 +343937 +442567 +356002 +346932 +62985 +180925 +525381 +13081 +377406 +159774 +462643 +359105 +185821 +390201 +84168 +128059 +80340 +481159 +491902 +306619 +353807 +390569 +541562 +292616 +64621 +439224 +96288 +449798 +160927 +496324 +90778 +126145 +97230 +572767 +11570 +539075 +350988 +3779 +208135 +551315 +216449 +169606 +502 +67765 +281414 +118594 +146127 +543985 +124927 +471394 +385508 +373783 +501315 +140974 +42757 +527054 +202387 +513056 +329931 +153973 +510152 +520812 +534601 +131282 +386638 +508538 +234779 +229329 +396568 +153568 +229478 +153574 +356299 +436694 +324139 +299409 +212462 +478155 +393266 +117836 +190760 +213605 +196 +444382 +445211 +363845 +433277 +521141 +464786 +169076 +301402 +4495 +177258 +328962 +183757 +452966 +416059 +113233 +559417 +280678 +481398 +328372 +234910 +30667 +343062 +383046 +370953 +258089 +404229 +456931 +535183 +300867 +60507 +262672 +7288 +81100 +575395 +539951 +347848 +437594 +352005 +14941 +196453 +528386 +466939 +482187 +293468 +494077 +217285 +362951 +435751 +411480 +517315 +480015 +60610 +353001 +376442 +430265 +478338 +303069 +525344 +437331 +389315 +8179 +31981 +313872 +330920 +515465 +258905 +142249 +323128 +389699 +565012 +124636 +488693 +376608 +309424 +370596 +261940 +39871 +226984 +152866 +515050 +116861 +412876 +120411 +550452 +565273 +273791 +181466 +183155 +293505 +336113 +569997 +303738 +331049 +147030 +74058 +198176 +23991 +198841 +79816 +85183 +261535 +566756 +386291 +318200 +569849 +57429 +36049 +420827 +519271 +24391 +172087 +158795 +133002 +522198 +133698 +499365 +79261 +258860 +457718 +179948 +421875 +558073 +206684 +529762 +456756 +65773 +425722 +53102 +294264 +416730 +38574 +176275 +404297 +127494 +242060 +272212 +189244 +510861 +421370 +208516 +206431 +248457 +39502 +375087 +130839 +308730 +572453 +263474 +544611 +255708 +412604 +390094 +578131 +234463 +493563 +9450 +381914 +148999 +32300 +423576 +569758 +347253 +92939 +112212 +13923 +39472 +363736 +289659 +269949 +88349 +188522 +488915 +129054 +573823 +316000 +440562 +408818 +539302 +199575 +122300 +340047 +322816 +472878 +313922 +228071 +265648 +400166 +169166 +10040 +125245 +148766 +31281 +172599 +431067 +208236 +441824 +175611 +15148 +431199 +521587 +50025 +443139 +349822 +515056 +27530 +571970 +82367 +7115 +424333 +157601 +537506 +447187 +115182 +547597 +5586 +143040 +31650 +196336 +279818 +206273 +403104 +514248 +243190 +558642 +548246 +16848 +391539 +89614 +284589 +191314 +259452 +208380 +209441 +465463 +385005 +321385 +223569 +11727 +87574 +566470 +210890 +323598 +427193 +425676 +401240 +94021 +259571 +447553 +456053 +84693 +14278 +119995 +234595 +408696 +136271 +143560 +357578 +28071 +36561 +157102 +293789 +392251 +356622 +180274 +48320 +475779 +301326 +100977 +413551 
+574010 +404479 +80725 +552221 +575441 +197424 +124601 +215633 +359546 +25386 +73199 +334466 +156572 +124614 +34121 +460049 +327623 +441695 +292488 +476514 +464018 +348571 +113413 +125208 +129690 +446218 +493761 +383413 +460390 +343149 +374041 +525211 +451263 +333683 +385194 +107427 +102872 +517249 +475879 +575755 +147787 +297180 +343774 +112437 +142240 +384503 +511111 +51089 +145408 +143582 +408138 +162858 +71850 +126925 +222781 +314616 +425609 +203928 +337563 +223300 +52644 +272566 +232597 +374430 +469075 +267164 +265851 +28134 +308889 +465795 +47263 +233727 +42 +493117 +124621 +533378 +361259 +458750 +429033 +383289 +490927 +520964 +174420 +64425 +378859 +401850 +281475 +46508 +205300 +280736 +110961 +230679 +151956 +321497 +73665 +488736 +165353 +365983 +556230 +21465 +581226 +448861 +3793 +347335 +150726 +75319 +2521 +285894 +133876 +104589 +346013 +63516 +83656 +491515 +326256 +49942 +28508 +475413 +270222 +235839 +48554 +327777 +111179 +507171 +425973 +449490 +205239 +82375 +459575 +432300 +91885 +340922 +270239 +195894 +121417 +344831 +439651 +232148 +391688 +480793 +534275 +260823 +469294 +8688 +255654 +191300 +383464 +81594 +21240 +478077 +517596 +555953 +294119 +402234 +459500 +564280 +106849 +167501 +98328 +267411 +145512 +272599 +50054 +414156 +161129 +418226 +11796 +502090 +390350 +440500 +240727 +104406 +163682 +437910 +143767 +358901 +527631 +500543 +28377 +231097 +227985 +556703 +421566 +73201 +478393 +280347 +15497 +131969 +515760 +295440 +462527 +42147 +120007 +212895 +425361 +454143 +5758 +366782 +213932 +229848 +458861 +132791 +476664 +150365 +343038 +529649 +180515 +499810 +329041 +15660 +419228 +396295 +502644 +321085 +245049 +34193 +217323 +446455 +528046 +375573 +15802 +147448 +407291 +84000 +280891 +150487 +510606 +163025 +249964 +126123 +233771 +118507 +97278 +357386 +23121 +10580 +2153 +176017 +371472 +373289 +173908 +296797 +334083 +301107 +577522 +125404 +278359 +575032 +273002 +266371 +108315 +255633 +503490 +250051 +143927 +117407 +198271 +447043 +329789 +399991 +458388 +87489 +228411 +494634 +260802 +454161 +446322 +231079 +438373 +395665 +244539 +212427 +356660 +347276 +183287 +498374 +21167 +544522 +418533 +288493 +245660 +406103 +406976 +367313 +455555 +117337 +384465 +185697 +160393 +463825 +276852 +181462 +176288 +452816 +102497 +54277 +225791 +361046 +197278 +9857 +227736 +398992 +55868 +170914 +181677 +467803 +560470 +264599 +540372 +559442 +201207 +137227 +267643 +355471 +245431 +555669 +344498 +84783 +193474 +102411 +401860 +119469 +448786 +449990 +568082 +340472 +307573 +231828 +307547 +82052 +15140 +493612 +503972 +386592 +473219 +495557 +159440 +355869 +311531 +209733 +240119 +415048 +296098 +249482 +15663 +151432 +263011 +488539 +463913 +502798 +174276 +495613 +407861 +229304 +146742 +545039 +161202 +295134 +162144 +453317 +52759 +335201 +222903 +20333 +559550 +336049 +346140 +491223 +306611 +102746 +455355 +449921 +477288 +77821 +289712 +452663 +147758 +129571 +490869 +345961 +94501 +160394 +432993 +178796 +372494 +316323 +383435 +194940 +74583 +148911 +518027 +431827 +32724 +158548 +227227 +500330 +54679 +321024 +471175 +252074 +476569 +573258 +337247 +294373 +558661 +148898 +563267 +163112 +411968 +193565 +455210 +349344 +337160 +160456 +255158 +553678 +123843 +549687 +381968 +579471 +100604 +379841 +357526 +197263 +14756 +412639 +210915 +47204 +539251 +166255 +490199 +260363 +91654 +170550 +187888 +97362 +285418 +176993 +292741 +361901 +296988 +223496 +493753 +114907 +151358 +316534 +472509 +499802 +348519 +347747 +58851 +104790 +396779 
+130528 +2255 +19624 +526800 +233950 +505945 +131207 +290750 +114090 +196665 +8708 +134688 +394715 +115088 +492196 +530099 +518729 +291572 +421457 +445365 +78929 +415461 +551796 +210002 +207913 +344878 +303893 +149196 +353275 +122413 +553361 +519132 +467135 +431439 +17089 +322119 +228214 +35062 +105689 +366141 +285651 +60409 +472671 +401446 +492846 +21023 +421952 +374100 +265200 +506628 +62298 +243626 +212122 +350648 +409921 +428140 +399212 +388267 +198921 +429246 +202040 +570001 +261346 +61171 +131815 +455448 +82696 +554607 +102174 +386803 +188421 +191846 +209898 +380117 +321064 +119617 +188651 +132210 +244299 +174072 +542910 +378334 +118405 +543347 +183657 +581180 +395289 +64760 +265584 +29573 +493720 +94795 +315601 +416596 +260106 +244019 +463884 +579468 +112085 +300972 +238528 +382542 +57672 +165298 +46889 +289497 +337180 +481252 +7913 +432150 +288161 +403758 +257336 +565331 +346589 +270785 +205670 +231580 +508580 +98871 +239997 +554579 +160057 +404922 +78771 +380756 +171199 +148077 +22892 +145378 +26967 +235200 +176007 +90349 +554377 +189744 +257053 +270515 +66508 +113890 +291983 +558927 +420916 +140908 +58384 +438226 +575776 +106935 +40602 +468993 +494810 +210408 +365685 +483722 +39430 +258793 +272615 +51476 +189919 +443887 +391648 +422670 +445135 +198959 +405529 +459757 +465489 +81827 +262576 +408289 +309237 +76249 +460091 +512630 +45959 +280320 +200492 +404652 +48475 +18480 +457097 +65889 +162256 +265950 +520752 +299082 +51500 +499313 +104906 +35438 +167647 +7274 +387824 +242139 +173166 +399830 +12014 +510642 +154053 +67785 +78170 +514118 +87998 +52703 +203539 +534533 +85926 +274438 +401653 +458790 +509262 +144481 +387515 +246649 +503207 +235131 +501531 +62025 +43286 +272323 +326128 +561889 +167529 +171067 +50778 +301282 +469719 +509388 +480317 +379055 +546428 +192763 +445602 +420882 +232790 +174332 +232865 +292822 +511145 +119502 +312591 +110330 +281353 +116244 +58778 +428079 +64902 +520840 +232054 +473214 +572574 +296684 +351590 +217997 +178761 +71618 +226496 +285212 +381195 +499903 +232849 +468997 +345559 +503097 +578570 +396404 +405223 +578752 +403500 +188958 +504498 +491623 +462929 +525762 +395550 +574227 +240751 +169356 +524694 +40886 +571635 +487774 +86220 +95677 +268987 +502599 +155270 +103855 +125100 +241355 +220214 +391774 +110618 +154587 +134483 +458781 +360877 +465963 +194595 +346934 +127153 +188078 +553869 +102665 +400547 +33759 +42779 +397587 +140295 +151807 +549136 +470288 +89738 +328368 +546934 +164255 +563683 +399988 +360951 +217303 +326781 +546133 +135399 +94666 +330037 +569839 +411070 +497466 +404805 +417854 +318442 +255036 +457230 +346863 +307438 +370448 +5124 +152582 +38118 +12179 +58462 +308420 +329456 +74920 +250368 +186428 +556073 +111806 +361244 +80273 +230964 +156754 +503101 +75173 +389404 +195538 +88848 +286018 +245481 +140929 +533721 +268378 +70048 +315467 +46269 +372807 +192403 +387328 +163033 +481314 +65306 +192529 +321107 +112232 +441216 +412399 +565391 +220670 +61471 +463290 +346707 +67587 +147624 +13031 +396754 +278601 +439426 +42834 +281829 +376209 +353148 +556562 +97579 +217989 +319530 +82551 +235319 +431799 +53892 +52853 +54533 +88897 +225093 +386777 +546742 +273684 +413900 +245447 +577995 +16249 +188414 +485142 +199602 +89258 +109679 +502397 +14494 +13632 +51674 +244999 +305050 +455956 +426795 +560700 +327306 +410301 +343803 +539422 +156740 +527845 +100582 +9941 +466585 +61515 +231895 +157052 +41271 +148128 +141172 +320232 +78565 +539883 +391300 +365182 +322194 +116517 +323496 +473783 +519874 +440706 +361587 +265153 +329946 +342814 +32258 
+153510 +194555 +309317 +245006 +300303 +97767 +218224 +370170 +290477 +207178 +456730 +209480 +513775 +199516 +581542 +32524 +416337 +96241 +506279 +422893 +248911 +509855 +355183 +201220 +234914 +333436 +68198 +429074 +328430 +160531 +467854 +280688 +140661 +349525 +267315 +565543 +313162 +25751 +232574 +560358 +505213 +494427 +160308 +287335 +99182 +413260 +558808 +290839 +122954 +229221 +192007 +243189 +117645 +552824 +366111 +102056 +356949 +566298 +97899 +422545 +343769 +13127 +179273 +104486 +37660 +304099 +517570 +20207 +36484 +36492 +155974 +107257 +534019 +522371 +222825 +96183 +509227 +302260 +95078 +280918 +367582 +317033 +347982 +73209 +290521 +187243 +425151 +483723 +573796 +187249 +144114 +132992 +35887 +546067 +426532 +45626 +461805 +129989 +541478 +485489 +578498 +485483 +144784 +248224 +372362 +92050 +423519 +473118 +177207 +105455 +276434 +157767 +384335 +509497 +338191 +224010 +327388 +96988 +43376 +67867 +320743 +555197 +104453 +14439 +512194 +396387 +252559 +108953 +461262 +66320 +97946 +238065 +306139 +572408 +577864 +81004 +464526 +89378 +193389 +259049 +85665 +381134 +412419 +308947 +557510 +502084 +288290 +254609 +188752 +439525 +13980 +140513 +240173 +305268 +38678 +394050 +402926 +364079 +159260 +293034 +55429 +289640 +291028 +211120 +48050 +93887 +361029 +486026 +388374 +207803 +540174 +530630 +430359 +36420 +120099 +199764 +492911 +84498 +200882 +139843 +4975 +421209 +259513 +520324 +211317 +236457 +419344 +3867 +287846 +50434 +26624 +507235 +16238 +103705 +497555 +440060 +175825 +245460 +308276 +178535 +391735 +206391 +201550 +400945 +194634 +262360 +554142 +407574 +225225 +246057 +498627 +486172 +226571 +461751 +459733 +345869 +503841 +286460 +45644 +22861 +285599 +580284 +569565 +286778 +150024 +542101 +484075 +538153 +20470 +128034 +544120 +357109 +450728 +550968 +326230 +558809 +76334 +555387 +47121 +523978 +11081 +378134 +116279 +364884 +488250 +551957 +322824 +545564 +255573 +286327 +355453 +361933 +434897 +32597 +226761 +166482 +557564 +208166 +232115 +283520 +137395 +555894 +103509 +174284 +458313 +316147 +344059 +370701 +548930 +89894 +373662 +572095 +19324 +574411 +45746 +480122 +63950 +92339 +201111 +157053 +401539 +427956 +339099 +274651 +159537 +556101 +323399 +564337 +514915 +556025 +66427 +322357 +173737 +369128 +420230 +45176 +509675 +374677 +272311 +109797 +384723 +383678 +453040 +91080 +301634 +533003 +40361 +221605 +216228 +104002 +161011 +146123 +214421 +496252 +264948 +9759 +138856 +316189 +145734 +50411 +325157 +259099 +516856 +529668 +135976 +467130 +367433 +385598 +520933 +102805 +30066 +436696 +216837 +380754 +350457 +126974 +565374 +73832 +214703 +110501 +380609 +135872 +140231 +251816 +133836 +398866 +230362 +426815 +2240 +51484 +546325 +224093 +221190 +525024 +238806 +99908 +165795 +109146 +537727 +496571 +183803 +211175 +433845 +168692 +526394 +368402 +256309 +468972 +139169 +398440 +171678 +547341 +64332 +533589 +483249 +406000 +330348 +439188 +572886 +252829 +242724 +139127 +404568 +45809 +52257 +458727 +334509 +559665 +60992 +290896 +503106 +27972 +536891 +410855 +31202 +457882 +403315 +87399 +395291 +322141 +226377 +202799 +420826 +553034 +212077 +97693 +266370 +101656 +504142 +342933 +87567 +342060 +268854 +437028 +20175 +198625 +405047 +382374 +338291 +403975 +527906 +322429 +545550 +140043 +107389 +74059 +315621 +110138 +78381 +295576 +494438 +106335 +472349 +15818 +162358 +366484 +44604 +66524 +118606 +366873 +270721 +556478 +350789 +298628 +163314 +262800 +459428 +491725 +285421 +406332 +498280 +34535 +524282 +315744 
+226592 +218294 +459141 +242034 +114164 +293733 +248242 +452881 +441496 +54358 +177489 +372861 +349489 +483941 +572802 +356494 +193875 +146570 +58253 +21338 +6220 +341933 +533368 +1818 +428248 +293026 +227656 +193021 +326938 +512966 +226020 +343059 +249720 +540106 +375278 +300023 +126512 +517135 +472540 +361439 +132702 +503294 +109537 +540669 +332007 +245266 +313999 +10386 +225715 +311567 +103837 +302405 +248616 +102654 +155087 +124756 +379659 +569272 +160166 +428234 +422280 +174425 +133412 +174503 +216581 +345063 +52949 +69536 +216161 +272728 +200870 +120792 +193480 +493923 +445567 +558539 +51938 +422706 +416271 +244160 +437898 +327352 +305480 +349459 +522418 +485219 +225133 +361400 +546569 +190015 +348216 +421822 +457683 +178683 +40894 +234526 +465074 +518725 +168096 +210190 +139605 +35195 +463640 +286770 +141651 +112022 +532552 +325327 +227224 +17272 +84163 +331475 +126065 +289309 +8583 +52952 +189427 +579693 +437947 +187565 +215982 +356424 +453731 +463522 +372316 +251797 +70187 +280515 +556608 +341635 +391067 +469480 +476298 +57917 +146672 +122747 +394328 +12209 +80013 +573291 +278449 +129659 +579560 +557190 +227468 +334782 +51157 +23774 +9426 +86582 +39211 +275751 +131597 +51250 +357255 +9041 +346482 +9647 +157019 +409016 +273416 +114414 +298172 +388854 +275025 +58079 +518034 +503518 +146710 +120632 +474680 +303713 +259097 +479630 +208318 +437298 +173704 +361831 +371638 +344279 +230175 +72507 +417980 +72621 +163057 +92894 +543525 +577364 +263696 +472732 +66027 +391584 +197745 +131019 +65604 +91318 +535934 +212646 +576354 +482071 +160556 +120129 +7260 +344881 +447548 +318193 +30383 +527002 +34904 +35677 +526222 +105261 +401897 +399452 +25660 +524595 +384512 +117543 +514600 +268944 +112664 +222340 +569058 +495332 +192153 +75591 +286711 +174888 +577065 +25508 +169972 +401820 +425475 +290700 +173091 +559101 +122418 +244124 +198645 +325519 +276437 +528276 +146614 +45574 +417804 +326420 +250594 +27353 +310407 +370103 +274957 +561160 +167598 +397166 +257458 +404546 +148392 +373396 +62230 +493522 +563665 +274240 +269815 +79024 +527427 +84674 +486788 +267690 +443347 +149304 +412285 +207041 +412916 +10764 +151338 +299000 +17882 +475510 +398188 +558213 +70493 +180779 +347210 +280211 +58146 +379022 +504125 +537604 +464858 +329573 +568623 +228309 +454444 +552775 +557884 +435671 +168706 +142257 +571437 +574845 +387773 +321008 +574208 +405811 +375426 +321887 +256852 +433554 +517029 +125870 +80395 +497139 +490008 +405279 +571857 +225738 +514913 +456239 +499402 +96440 +487607 +370999 +319617 +370233 +60760 +352703 +478575 +84170 +134112 +77689 +185036 +73738 +547502 +104782 +213276 +136908 +436273 +442149 +355000 +374061 +249884 +105711 +136464 +146997 +76351 +388487 +99115 +124135 +24721 +132931 +1149 +182403 +386089 +81691 +480657 +441522 +60989 +268000 +55840 +514321 +577959 +359638 +457986 +533596 +60332 +367082 +772 +535842 +473541 +270677 +409009 +259216 +302318 +117036 +331372 +231125 +384486 +405214 +20760 +579760 +172995 +359110 +83110 +410068 +109916 +328757 +299261 +19028 +515660 +40757 +10256 +442695 +553097 +185903 +74388 +425120 +241326 +299609 +29397 +328728 +283881 +344029 +367336 +27075 +163628 +127263 +488979 +460147 +473050 +405762 +221547 +131581 +561187 +406489 +140696 +452721 +530466 +118965 +398803 +218365 +298738 +19441 +521550 +120157 +498687 +4754 +365866 +70865 +235156 +133386 +142742 +221183 +262391 +567053 +520982 +121349 +448779 +440354 +3983 +578993 +519691 +160703 +103307 +300408 +137106 +488377 +523660 +318022 +132578 +302520 +153040 +408817 +145227 +311190 +159662 
+202923 +256775 +359864 +384848 +336404 +185303 +421703 +362682 +464622 +246590 +422729 +165500 +42563 +219216 +520232 +95063 +265547 +532686 +290558 +112591 +448211 +315281 +545475 +225850 +232460 +82740 +272880 +347254 +122047 +352151 +541486 +97249 +200252 +544782 +499571 +379014 +303534 +479909 +305464 +323682 +181524 +273855 +190783 +567801 +119752 +241503 +536429 +327323 +128756 +349868 +500495 +372260 +315824 +484986 +364993 +124759 +300124 +329319 +68628 +14549 +121897 +506595 +115709 +199610 +230150 +31717 +139549 +222332 +534161 +360393 +541664 +507167 +286523 +158660 +66926 +195750 +80022 +589 +252220 +47255 +247014 +49881 +455005 +232453 +445722 +516805 +544122 +541917 +469356 +370042 +130522 +502163 +307866 +408894 +524247 +52233 +177861 +348881 +357943 +295303 +475389 +431691 +61316 +143998 +503483 +340155 +488785 +133636 +133567 +251627 +470095 +34873 +88815 +261178 +468612 +127477 +157960 +15687 +303089 +572331 +456708 +190515 +126131 +239194 +332074 +129765 +107167 +478184 +421833 +359715 +112440 +331317 +74492 +505386 +247839 +534210 +134503 +422700 +352111 +98674 +546219 +520508 +503008 +461953 +101913 +362092 +22103 +359128 +316666 +335579 +414750 +297980 +365652 +53635 +547601 +97589 +570515 +7125 +99828 +321437 +80671 +426275 +294883 +212605 +424293 +338108 +25005 +6949 +234291 +428399 +7149 +343076 +575287 +431848 +307611 +293909 +542511 +564739 +573843 +356878 +472864 +336793 +121904 +161060 +254004 +269873 +216428 +77172 +346517 +498555 +203690 +348973 +117704 +552672 +275270 +208107 +314016 +427518 +278134 +53420 +318777 +238980 +350614 +467315 +61233 +272188 +550797 +125051 +553965 +187286 +282912 +102532 +156076 +467848 +130875 +531585 +523470 +507684 +332582 +438989 +489209 +125944 +127474 +371957 +570349 +283286 +541635 +547106 +253630 +388677 +572525 +542302 +554537 +367205 +228300 +443498 +356432 +123946 +490441 +211063 +224542 +116574 +434510 +33116 +353136 +134167 +128291 +542510 +433963 +147453 +365766 +374806 +336600 +38238 +165476 +535578 +127788 +157099 +173640 +114348 +496722 +58141 +467296 +235864 +5154 +22775 +422536 +136820 +453438 +446359 +41990 +422240 +39267 +391392 +233825 +308504 +478250 +87328 +4079 +127074 +267709 +377635 +353231 +185768 +487897 +124215 +249757 +341681 +557552 +280733 +374734 +281601 +456420 +222266 +491947 +432732 +467157 +94025 +410328 +428291 +397639 +163528 +234697 +557573 +208363 +515962 +358658 +373075 +438995 +425672 +450169 +216103 +254638 +288591 +53626 +43417 +372252 +5038 +218357 +120860 +399349 +485509 +530261 +477087 +352302 +96075 +495443 +133928 +197175 +134074 +212553 +448181 +152000 +254277 +105734 +75481 +343662 +479350 +554347 +71090 +297426 +22176 +277622 +469235 +163041 +221272 +154263 +89296 +68411 +192871 +183217 +258141 +53058 +540529 +566414 +560948 +254535 +246076 +135972 +420069 +431023 +343643 +32682 +515176 +222635 +377155 +547041 +513283 +26017 +366096 +252133 +138078 +25685 +321798 +549361 +14088 +423048 +570810 +374974 +447501 +492544 +554046 +575357 +420791 +6019 +340451 +66800 +565575 +148055 +330432 +483038 +455004 +288765 +11034 +86988 +347142 +450559 +543581 +293757 +556901 +533032 +333020 +260266 +22420 +13948 +512657 +214124 +231236 +177149 +560879 +491793 +35767 +312878 +118542 +450596 +423773 +48653 +224523 +509577 +462677 +75405 +350023 +452122 +42008 +302555 +382309 +468483 +368684 +372580 +31333 +153697 +124876 +330023 +315672 +53990 +136533 +82815 +356836 +414821 +268717 +7333 +77544 +525373 +371042 +227048 +576327 +419309 +239773 +8119 +424135 +297425 +222711 +489909 +393995 
+31019 +539326 +517612 +102461 +199989 +483374 +44952 +103863 +528980 +441543 +85381 +247234 +50924 +483994 +87456 +424271 +356091 +534669 +378831 +560662 +298773 +257896 +498274 +305800 +40517 +183949 +276840 +84442 +297620 +298252 +119088 +233315 +283977 +345154 +287649 +427311 +63399 +4700 +463611 +224104 +209388 +431655 +364190 +28864 +412455 +283290 +228541 +422200 +985 +133596 +323853 +503081 +130732 +224675 +199688 +230862 +21396 +485390 +1532 +125778 +235541 +370478 +522478 +514292 +384338 +531707 +178746 +532747 +62915 +519491 +140691 +112093 +358024 +263687 +297595 +506085 +102446 +325768 +29558 +222054 +466965 +316254 +546500 +216785 +194184 +464390 +348371 +231582 +208995 +464339 +308856 +340946 +214604 +570586 +182227 +248441 +89078 +376310 +73450 +115924 +308235 +15994 +8749 +429679 +37751 +122040 +284286 +388707 +248163 +11320 +427997 +282062 +237600 +376751 +223314 +86215 +12443 +163255 +564940 +462640 +522713 +306303 +460675 +126833 +26201 +224757 +357899 +546782 +96427 +480944 +479556 +569273 +520528 +190690 +344832 +462466 +270354 +559776 +279259 +280909 +227781 +163798 +491098 +439658 +416088 +107375 +74132 +379800 +511654 +346687 +226161 +578849 +544272 +146149 +570624 +178299 +126671 +356380 +530766 +175954 +158798 +422095 +55780 +512276 +560626 +187329 +513125 +347216 +306486 +161840 +180917 +188192 +421437 +93120 +324891 +252216 +488476 +578347 +101959 +10693 +170038 +213586 +210439 +469202 +381463 +343248 +127785 +287328 +538690 +16382 +293022 +112378 +435785 +56092 +381504 +284365 +406129 +233119 +53629 +188509 +191053 +81056 +82252 +538319 +38439 +181948 +439710 +529344 +434035 +342958 +563882 +37734 +364743 +330986 +546226 +463211 +62210 +442724 +232241 +293858 +119345 +61953 +577033 +522015 +381587 +350107 +4936 +511307 +228771 +177811 +231450 +176168 +84540 +259408 +264238 +539738 +255827 +459382 +221105 +431742 +204337 +227741 +336356 +37655 +167159 +59352 +165937 +53956 +378712 +88462 +495786 +542938 +566498 +367228 +157577 +442661 +62363 +390689 +480664 +521540 +414249 +20571 +160855 +451683 +156832 +570045 +326542 +568276 +568717 +563311 +113579 +218268 +546095 +160661 +341118 +150649 +462632 +198972 +220025 +61720 +430681 +524011 +457217 +40064 +285583 +314493 +78023 +470882 +298722 +555597 +489829 +314779 +367818 +138503 +243737 +580255 +444565 +386677 +190841 +493074 +234347 +466988 +227033 +519039 +351554 +390585 +443303 +140983 +81079 +538005 +169757 +368780 +457322 +341804 +409116 +181805 +284292 +551358 +344548 +503569 +336587 +417055 +522315 +58705 +148955 +375530 +474934 +577893 +28881 +360772 +445267 +244737 +355777 +72811 +190788 +54513 +243075 +518551 +487530 +292169 +69293 +397303 +129285 +429996 +109532 +53802 +340573 +91280 +535602 +270908 +381925 +549220 +488573 +47131 +32735 +117525 +279085 +43961 +188906 +394677 +395 +185201 +189365 +127596 +32712 +504810 +3703 +182874 +146981 +306755 +453093 +520503 +169808 +225670 +91063 +348584 +461802 +572555 +185922 +131497 +46736 +536006 +256505 +214975 +13445 +350736 +98115 +50304 +361180 +511333 +564820 +429717 +222500 +40083 +538230 +349438 +371250 +528578 +240418 +302380 +261758 +535809 +308388 +578878 +509451 +46919 +562592 +499950 +90374 +318146 +195353 +355325 +314515 +237277 +203024 +238911 +32039 +145591 +16030 +135411 +229350 +421757 +48034 +183704 +307292 +97974 +275999 +448256 +451915 +119113 +143503 +494141 +50124 +306553 +35526 +255279 +560908 +247264 +367599 +192782 +511324 +574350 +67569 +204360 +111907 +2839 +513971 +245201 +185240 +339468 +540101 +539673 +194425 +22168 +520150 
+301595 +96006 +68286 +131280 +356662 +182441 +284749 +107108 +49761 +386718 +55244 +187990 +248678 +147721 +425727 +360350 +310797 +76765 +400489 +247639 +279864 +44699 +356145 +69138 +445041 +560598 +165464 +536343 +7818 +322831 +334760 +451463 +348730 +285967 +286353 +201887 +166165 +359 +465591 +519359 +550444 +402711 +3661 +132706 +534983 +306281 +150317 +15978 +580029 +496090 +267127 +210980 +384015 +222559 +2235 +255649 +278168 +440840 +27326 +202562 +230268 +362712 +1573 +107661 +464515 +373132 +447242 +547440 +43613 +200143 +260883 +250901 +64693 +408480 +204757 +319933 +147471 +381332 +518197 +27656 +260257 +434580 +159203 +568630 +497441 +499597 +60179 +574804 +343254 +501762 +220704 +524536 +86946 +456046 +62937 +49633 +144305 +475593 +478553 +574145 +63648 +3794 +303177 +1340 +82835 +371427 +156747 +448694 +219567 +75095 +242615 +492077 +132776 +199125 +349622 +195754 +455548 +181873 +138185 +338044 +362797 +180953 +505826 +69773 +304834 +162580 +154090 +519853 +319687 +132328 +27969 +52166 +100547 +568131 +415218 +348045 +478159 +402869 +10211 +26547 +551692 +105432 +313340 +182348 +383419 +570947 +345353 +226883 +255784 +214199 +262262 +283261 +449708 +299970 +392391 +245997 +330410 +343571 +519542 +37470 +42144 +342521 +498537 +10935 +443860 +512648 +146099 +98599 +123932 +489861 +262895 +184700 +218587 +363581 +21001 +481404 +249356 +64240 +492349 +199236 +481064 +353405 +116479 +132024 +138768 +524665 +434511 +326970 +138784 +340368 +312081 +366615 +171942 +21232 +473850 +93686 +295574 +51054 +162692 +174091 +20070 +270066 +492816 +20904 +484500 +147140 +242972 +420081 +63563 +261712 +316396 +49413 +520787 +510955 +393840 +142487 +19817 +261180 +413736 +230619 +484614 +337011 +496575 +4338 +552545 +5601 +75426 +568863 +184227 +170629 +438567 +505132 +541353 +284674 +322567 +182423 +312051 +18896 +40471 +321725 +188850 +37119 +95569 +187362 +397133 +528972 +487131 +174989 +370325 +223554 +385633 +103485 +537574 +63240 +256566 +86467 +401092 +486968 +308441 +280017 +527464 +131965 +310479 +125556 +220160 +532963 +310052 +107963 +293841 +388534 +45603 +368949 +391825 +5107 +569705 +231549 +250108 +152933 +206433 +358817 +434006 +283904 +152808 +539975 +24629 +410231 +13465 +502318 +51961 +445594 +209062 +38726 +295420 +430079 +240147 +561512 +35795 +102589 +505619 +565469 +271772 +520561 +372300 +178807 +492805 +1083 +303704 +125635 +217521 +278032 +208688 +335325 +140435 +313990 +143822 +320857 +549230 +76844 +424219 +463876 +243199 +2988 +215170 +30012 +377738 +408568 +490624 +404839 +138316 +157206 +404461 +122934 +263346 +21327 +99913 +67975 +339676 +391891 +365305 +337055 +233834 +125524 +46869 +32577 +304744 +104176 +167356 +210404 +307989 +217223 +196046 +454414 +16356 +244487 +543660 +197461 +199681 +476787 +455085 +307074 +260547 +107468 +334769 +29437 +166837 +53838 +502979 +82678 +288860 +535523 +311950 +237723 +98656 +223123 +273930 +58057 +544334 +324857 +198043 +535326 +316505 +12991 +576820 +43611 +107839 +275749 +456695 +78188 +375786 +466239 +184830 +537128 +434513 +244344 +374576 +69140 +434247 +555009 +510857 +220819 +20598 +99416 +74967 +533129 +515577 +213361 +330974 +548848 +431557 +503278 +130043 +402570 +320554 +559884 +252629 +364596 +423484 +271230 +105552 +143143 +285751 +49994 +204162 +80646 +381393 +123415 +118417 +30932 +425412 +388130 +551243 +468337 +484893 +25014 +174390 +463781 +124647 +60823 +361964 +425702 +575110 +532390 +230881 +84592 +189997 +221307 +361472 +32364 +71918 +316365 +492378 +234251 +48504 +418070 +89884 +562045 +506552 
+66360 +122962 +262605 +529939 +345229 +294853 +344397 +56091 +8599 +459823 +175785 +226128 +259983 +354515 +379144 +384995 +205253 +116786 +441432 +448810 +83452 +465129 +506906 +90616 +551959 +406404 +157891 +362090 +439630 +45099 +61960 +478430 +489605 +127050 +579872 +475798 +64510 +447733 +33066 +102848 +538819 +323760 +200401 +179765 +251317 +239376 +83836 +578092 +522452 +393056 +278848 +27787 +377239 +473427 +83065 +377005 +576539 +248019 +473370 +536369 +92648 +332461 +437609 +274800 +388846 +323048 +193407 +541898 +480140 +46526 +26432 +339738 +325991 +37705 +528033 +542922 +313420 +190463 +531000 +454907 +26448 +238199 +476652 +457147 +364256 +72632 +430380 +315448 +353320 +18158 +91527 +454252 +546987 +386370 +38064 +19763 +64152 +453216 +55223 +361860 +522566 +509531 +438432 +31164 +163290 +389197 +333440 +173464 +447842 +381615 +99961 +156126 +103134 +394940 +165638 +261706 +378311 +534081 +373848 +401642 +338019 +378096 +289610 +547421 +174672 +133343 +191360 +293751 +520892 +145214 +167668 +37456 +460962 +465267 +292804 +347529 +203661 +10766 +27371 +203845 +155736 +136715 +463588 +26640 +547612 +131453 +184274 +442456 +265085 +223256 +129420 +23019 +536467 +194532 +127585 +392637 +330408 +524775 +31993 +433924 +502852 +553129 +559364 +297343 +71360 +225537 +271148 +345499 +475893 +237463 +5278 +501243 +413235 +444236 +541071 +380088 +468063 +94858 +225913 +295614 +210276 +170975 +205570 +422375 +550365 +308702 +484627 +565031 +98979 +480345 +579548 +272673 +436875 +287874 +16502 +274917 +281809 +442968 +289263 +347766 +160933 +84533 +266409 +122199 +396200 +30958 +504541 +1591 +89432 +387150 +306383 +15260 +154515 +50752 +166913 +102644 +100196 +160278 +349579 +442536 +17923 +310564 +62020 +152004 +578330 +126299 +527025 +83494 +226400 +268435 +445334 +310391 +505156 +19157 +44677 +318171 +447765 +354369 +527486 +329939 +184771 +134856 +467675 +517133 +89697 +447080 +70685 +144938 +519673 +485758 +454957 +564851 +189451 +408757 +192616 +280734 +305060 +243946 +99179 +303971 +170519 +48917 +549965 +300245 +384101 +576607 +186709 +516341 +241668 +133470 +134811 +500825 +464689 +29833 +343820 +213429 +387434 +279305 +444207 +210777 +372043 +189868 +572229 +8495 +370090 +450282 +277080 +199158 +109612 +567708 +245659 +485129 +268363 +23448 +5352 +235597 +6871 +348720 +94113 +314613 +63729 +114458 +215394 +460460 +240387 +398726 +135604 +571728 +415770 +286908 +138151 +146272 +344094 +345209 +241187 +282768 +113037 +545583 +219283 +145873 +285957 +489235 +157271 +197458 +502671 +499845 +334884 +79084 +505573 +115618 +561491 +354202 +279838 +190734 +134738 +269450 +482784 +144610 +52774 +290659 +440646 +25807 +442952 +159215 +318224 +73445 +211653 +527960 +401862 +431026 +488755 +292278 +400554 +272630 +382668 +470298 +166426 +129645 +28820 +161227 +417696 +560677 +283216 +28978 +310302 +154419 +230450 +328289 +73118 +104691 +15085 +405574 +510548 +470005 +102928 +569249 +413126 +77282 +96732 +359020 +42182 +250875 +106206 +354929 +320796 +453341 +237318 +254834 +137265 +399865 +292685 +152252 +319579 +81484 +16599 +162257 +351034 +396051 +502275 +308278 +34483 +13333 +320290 +321579 +349794 +99219 +200162 +369470 +487583 +62703 +251639 +138246 +157170 +477112 +283963 +74860 +307057 +364075 +295491 +34757 +400161 +170194 +120874 +492817 +3817 +183973 +135436 +512989 +114744 +379210 +201072 +293785 +578385 +237420 +7888 +18224 +155317 +522406 +441440 +110482 +173400 +183348 +552504 +475660 +166948 +147025 +443259 +578792 +245227 +546687 +474519 +393284 +249668 +87493 +151651 
+100306 +540466 +546556 +212675 +282942 +21310 +385535 +7304 +303409 +386116 +574297 +514550 +217133 +533553 +447152 +578703 +45392 +166205 +180154 +25143 +338802 +330110 +261389 +343506 +442726 +285388 +554934 +421316 +479912 +85192 +34874 +487266 +226173 +20748 +360660 +574509 +543364 +1554 +125539 +566931 +312889 +466945 +444804 +257187 +568587 +427160 +71123 +563849 +138589 +162841 +129663 +107226 +140686 +321663 +437117 +179808 +321718 +62398 +16497 +468933 +219841 +355430 +293554 +293044 +109516 +485887 +490620 +579893 +427135 +31636 +217919 +432441 +314396 +119802 +393682 +201764 +146193 +116358 +84825 +208311 +419774 +177468 +72052 +142585 +519598 +464006 +556083 +412136 +169361 +442929 +84567 +549932 +75560 +74656 +93314 +393838 +383018 +372433 +431281 +556278 +5513 +108503 +500478 +148588 +138713 +368153 +22646 +303778 +270758 +276706 +275429 +492025 +169111 +494328 +35891 +70258 +400528 +165229 +460494 +269311 +307658 +98283 +369294 +319345 +414578 +541550 +425388 +129855 +99477 +383073 +387906 +293124 +155873 +549224 +266021 +52869 +1584 +421902 +498535 +277235 +153013 +452013 +553561 +138040 +20820 +58483 +423506 +569001 +325153 +383039 +213421 +38825 +453283 +384661 +127702 +238147 +104893 +577826 +64974 +240655 +459153 +145665 +49810 +65008 +545385 +125070 +46433 +143329 +429174 +52947 +321314 +253341 +157365 +453162 +111910 +339019 +239575 +362219 +80652 +247317 +460286 +365724 +160875 +372220 +483389 +572181 +146190 +580975 +54761 +348488 +416104 +468778 +18833 +251537 +234366 +510078 +14723 +338595 +153797 +513098 +467138 +404618 +261982 +545730 +135846 +108244 +562557 +180524 +227370 +341856 +131743 +255691 +497878 +68878 +430640 +441473 +347664 +214369 +347018 +225238 +421762 +317024 +6180 +172004 +303101 +22488 +193494 +199346 +409627 +315350 +263463 +190722 +523292 +363902 +573778 +437290 +389812 +517082 +145073 +37907 +489763 +456261 +270386 +508917 +566823 +543897 +362482 +130966 +66632 +181962 +274613 +135708 +549746 +323766 +366714 +353295 +318813 +153307 +213693 +293378 +149446 +199927 +580543 +331727 +238488 +472833 +308645 +424225 +228746 +110435 +495377 +240646 +274491 +130921 +140006 +4688 +115241 +76962 +66650 +47718 +224991 +434187 +272048 +11169 +158222 +154000 +507436 +443499 +109937 +309692 +534018 +22797 +163339 +168683 +210098 +246069 +137954 +143320 +262587 +414795 +226938 +536831 +128791 +459590 +50514 +30067 +317479 +378655 +229968 +522702 +11122 +515266 +136600 +224509 +149912 +97656 +120747 +349480 +155199 +528731 +523807 +168544 +325664 +229981 +434410 +431208 +508996 +63791 +89225 +513690 +136740 +224364 +515424 +508302 +418175 +465552 +439907 +272097 +451087 +396304 +342273 +52507 +300066 +380089 +326248 +167906 +37846 +262993 +60090 +499249 +90432 +74456 +264660 +325598 +480985 +245411 +425644 +224724 +475439 +246478 +487438 +563731 +441854 +522665 +245915 +85747 +315162 +108761 +407521 +388528 +389453 +298331 +447791 +368820 +440034 +305677 +122208 +182369 +543531 +151820 +63650 +457580 +563381 +320899 +14869 +137260 +61925 +376307 +80367 +269089 +203705 +274835 +267321 +418106 +471273 +74037 +227855 +519758 +89045 +321217 +324203 +479129 +503431 +368528 +527718 +278579 +13525 +291582 +301837 +31667 +68120 +14007 +114158 +124262 +33626 +53949 +187585 +192247 +208844 +212766 +318671 +575012 +439339 +364073 +419624 +178078 +427783 +302159 +339368 +190680 +23807 +288579 +312720 +15778 +553558 +571834 +574376 +122161 +493815 +472376 +483432 +149123 +51628 +264628 +26609 +23696 +485081 +441323 +451679 +42055 +378795 +86439 +366493 +520996 +332869 
+18014 +554523 +83476 +6040 +421834 +424392 +308160 +335233 +249809 +349098 +358090 +187349 +61782 +35498 +386514 +207108 +578418 +84447 +104108 +126107 +211674 +111909 +490708 +477025 +206757 +556205 +142484 +454296 +464366 +358254 +215482 +468548 +82680 +100909 +405432 +85764 +94651 +63973 +8131 +288592 +257470 +47597 +321557 +34520 +134066 +246701 +317797 +282365 +78176 +29577 +311075 +331937 +190395 +5802 +245112 +111032 +140556 +199127 +376491 +305253 +300375 +545903 +357782 +377911 +74963 +329336 +25057 +3244 +252020 +293474 +171050 +239306 +189772 +238090 +160031 +36761 +445675 +252716 +152214 +239466 +55155 +479829 +420281 +445812 +118106 +434576 +451104 +316708 +438535 +300322 +167952 +390072 +487220 +20247 +9400 +43944 +35770 +487351 +425462 +212203 +9668 +8981 +574241 +332096 +535563 +192944 +498733 +276151 +550645 +507037 +9769 +404249 +236747 +376416 +306415 +45966 +191296 +576875 +493932 +225075 +536444 +79920 +561681 +60700 +99874 +219437 +509819 +466665 +579326 +428739 +394611 +263083 +379554 +279391 +178516 +133690 +77396 +300137 +6861 +435359 +314108 +444152 +500139 +92749 +89188 +300233 +414201 +443204 +211097 diff --git a/research/object_detection/data_decoders/tf_example_decoder.py b/research/object_detection/data_decoders/tf_example_decoder.py index e437ea1e0ac..c844e3dccfb 100644 --- a/research/object_detection/data_decoders/tf_example_decoder.py +++ b/research/object_detection/data_decoders/tf_example_decoder.py @@ -335,8 +335,6 @@ def decode(self, tf_example_string_tensor): [None] containing classes for the boxes. fields.InputDataFields.groundtruth_weights - 1D float32 tensor of shape [None] indicating the weights of groundtruth boxes. - fields.InputDataFields.num_groundtruth_boxes - int32 scalar indicating - the number of groundtruth_boxes. fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape [None] containing containing object mask area in pixel squared. 
fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape @@ -369,8 +367,6 @@ def decode(self, tf_example_string_tensor): tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3]) tensor_dict[fields.InputDataFields.original_image_spatial_shape] = tf.shape( tensor_dict[fields.InputDataFields.image])[:2] - tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape( - tensor_dict[fields.InputDataFields.groundtruth_boxes])[0] if fields.InputDataFields.image_additional_channels in tensor_dict: channels = tensor_dict[fields.InputDataFields.image_additional_channels] diff --git a/research/object_detection/data_decoders/tf_example_decoder_test.py b/research/object_detection/data_decoders/tf_example_decoder_test.py index 7ea820b6250..91fa8693fc2 100644 --- a/research/object_detection/data_decoders/tf_example_decoder_test.py +++ b/research/object_detection/data_decoders/tf_example_decoder_test.py @@ -256,8 +256,6 @@ def testDecodeBoundingBox(self): bbox_xmaxs]).transpose() self.assertAllEqual(expected_boxes, tensor_dict[fields.InputDataFields.groundtruth_boxes]) - self.assertAllEqual( - 2, tensor_dict[fields.InputDataFields.num_groundtruth_boxes]) @test_util.enable_c_shapes def testDecodeKeypoint(self): @@ -305,8 +303,6 @@ def testDecodeKeypoint(self): bbox_xmaxs]).transpose() self.assertAllEqual(expected_boxes, tensor_dict[fields.InputDataFields.groundtruth_boxes]) - self.assertAllEqual( - 2, tensor_dict[fields.InputDataFields.num_groundtruth_boxes]) expected_keypoints = ( np.vstack([keypoint_ys, keypoint_xs]).transpose().reshape((2, 3, 2))) diff --git a/research/object_detection/eval_util.py b/research/object_detection/eval_util.py index a600cc533fd..6c19c32b593 100644 --- a/research/object_detection/eval_util.py +++ b/research/object_detection/eval_util.py @@ -27,6 +27,7 @@ from object_detection.metrics import coco_evaluation from object_detection.utils import label_map_util from object_detection.utils import ops +from object_detection.utils import shape_utils from object_detection.utils import visualization_utils as vis_utils slim = tf.contrib.slim @@ -321,6 +322,7 @@ def _run_checkpoint_once(tensor_dict, # TODO(akuznetsa): result_dict contains batches of images, while # add_single_ground_truth_image_info expects a single image. 
Fix if (isinstance(result_dict, dict) and + fields.InputDataFields.key in result_dict and result_dict[fields.InputDataFields.key]): image_id = result_dict[fields.InputDataFields.key] else: @@ -475,6 +477,35 @@ def repeated_checkpoint_run(tensor_dict, return metrics +def _scale_box_to_absolute(args): + boxes, image_shape = args + return box_list_ops.to_absolute_coordinates( + box_list.BoxList(boxes), image_shape[0], image_shape[1]).get() + + +def _resize_detection_masks(args): + detection_boxes, detection_masks, image_shape = args + detection_masks_reframed = ops.reframe_box_masks_to_image_masks( + detection_masks, detection_boxes, image_shape[0], image_shape[1]) + return tf.cast(tf.greater(detection_masks_reframed, 0.5), tf.uint8) + + +def _resize_groundtruth_masks(args): + mask, image_shape = args + mask = tf.expand_dims(mask, 3) + mask = tf.image.resize_images( + mask, + image_shape, + method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, + align_corners=True) + return tf.cast(tf.squeeze(mask, 3), tf.uint8) + + +def _scale_keypoint_to_absolute(args): + keypoints, image_shape = args + return keypoint_ops.scale(keypoints, image_shape[0], image_shape[1]) + + def result_dict_for_single_example(image, key, detections, @@ -533,89 +564,225 @@ def result_dict_for_single_example(image, (Optional). """ + + if groundtruth: + max_gt_boxes = tf.shape( + groundtruth[fields.InputDataFields.groundtruth_boxes])[0] + for gt_key in groundtruth: + # expand groundtruth dict along the batch dimension. + groundtruth[gt_key] = tf.expand_dims(groundtruth[gt_key], 0) + + for detection_key in detections: + detections[detection_key] = tf.expand_dims( + detections[detection_key][0], axis=0) + + batched_output_dict = result_dict_for_batched_example( + image, + tf.expand_dims(key, 0), + detections, + groundtruth, + class_agnostic, + scale_to_absolute, + max_gt_boxes=max_gt_boxes) + + exclude_keys = [ + fields.InputDataFields.original_image, + fields.DetectionResultFields.num_detections, + fields.InputDataFields.num_groundtruth_boxes, + fields.InputDataFields.original_image_spatial_shape + ] + + output_dict = { + fields.InputDataFields.original_image: + batched_output_dict[fields.InputDataFields.original_image] + } + + for key in batched_output_dict: + # remove the batch dimension. + if key not in exclude_keys: + output_dict[key] = tf.squeeze(batched_output_dict[key], 0) + return output_dict + + +def result_dict_for_batched_example(images, + keys, + detections, + groundtruth=None, + class_agnostic=False, + scale_to_absolute=False, + original_image_spatial_shapes=None, + max_gt_boxes=None): + """Merges all detection and groundtruth information for a single example. + + Note that evaluation tools require classes that are 1-indexed, and so this + function performs the offset. If `class_agnostic` is True, all output classes + have label 1. + + Args: + images: A single 4D uint8 image tensor of shape [batch_size, H, W, C]. + keys: A [batch_size] string tensor with image identifier. + detections: A dictionary of detections, returned from + DetectionModel.postprocess(). + groundtruth: (Optional) Dictionary of groundtruth items, with fields: + 'groundtruth_boxes': [batch_size, max_number_of_boxes, 4] float32 tensor + of boxes, in normalized coordinates. + 'groundtruth_classes': [batch_size, max_number_of_boxes] int64 tensor of + 1-indexed classes. + 'groundtruth_area': [batch_size, max_number_of_boxes] float32 tensor of + bbox area. (Optional) + 'groundtruth_is_crowd':[batch_size, max_number_of_boxes] int64 + tensor. 
(Optional) + 'groundtruth_difficult': [batch_size, max_number_of_boxes] int64 + tensor. (Optional) + 'groundtruth_group_of': [batch_size, max_number_of_boxes] int64 + tensor. (Optional) + 'groundtruth_instance_masks': 4D int64 tensor of instance + masks (Optional). + class_agnostic: Boolean indicating whether the detections are class-agnostic + (i.e. binary). Default False. + scale_to_absolute: Boolean indicating whether boxes and keypoints should be + scaled to absolute coordinates. Note that for IoU based evaluations, it + does not matter whether boxes are expressed in absolute or relative + coordinates. Default False. + original_image_spatial_shapes: A 2D int32 tensor of shape [batch_size, 2] + used to resize the image. When set to None, the image size is retained. + max_gt_boxes: [batch_size] tensor representing the maximum number of + groundtruth boxes to pad. + + Returns: + A dictionary with: + 'original_image': A [batch_size, H, W, C] uint8 image tensor. + 'original_image_spatial_shape': A [batch_size, 2] tensor containing the + original image sizes. + 'key': A [batch_size] string tensor with image identifier. + 'detection_boxes': [batch_size, max_detections, 4] float32 tensor of boxes, + in normalized or absolute coordinates, depending on the value of + `scale_to_absolute`. + 'detection_scores': [batch_size, max_detections] float32 tensor of scores. + 'detection_classes': [batch_size, max_detections] int64 tensor of 1-indexed + classes. + 'detection_masks': [batch_size, max_detections, H, W] float32 tensor of + binarized masks, reframed to full image masks. + 'num_detections': [batch_size] int64 tensor containing number of valid + detections. + 'groundtruth_boxes': [batch_size, num_boxes, 4] float32 tensor of boxes, in + normalized or absolute coordinates, depending on the value of + `scale_to_absolute`. (Optional) + 'groundtruth_classes': [batch_size, num_boxes] int64 tensor of 1-indexed + classes. (Optional) + 'groundtruth_area': [batch_size, num_boxes] float32 tensor of bbox + area. (Optional) + 'groundtruth_is_crowd': [batch_size, num_boxes] int64 tensor. (Optional) + 'groundtruth_difficult': [batch_size, num_boxes] int64 tensor. (Optional) + 'groundtruth_group_of': [batch_size, num_boxes] int64 tensor. (Optional) + 'groundtruth_instance_masks': 4D int64 tensor of instance masks + (Optional). + 'num_groundtruth_boxes': [batch_size] tensor containing the maximum number + of groundtruth boxes per image. + + Raises: + ValueError: if original_image_spatial_shape is not 1D int32 tensor of shape + [2]. 
+ """ label_id_offset = 1 # Applying label id offset (b/63711816) input_data_fields = fields.InputDataFields + if original_image_spatial_shapes is None: + original_image_spatial_shapes = tf.tile( + tf.expand_dims(tf.shape(images)[1:3], axis=0), + multiples=[tf.shape(images)[0], 1]) + else: + if (len(original_image_spatial_shapes.shape) != 2 and + original_image_spatial_shapes.shape[1] != 2): + raise ValueError( + '`original_image_spatial_shape` should be a 2D tensor of shape ' + '[batch_size, 2].') + output_dict = { - input_data_fields.original_image: image, - input_data_fields.key: key, + input_data_fields.original_image: images, + input_data_fields.key: keys, + input_data_fields.original_image_spatial_shape: ( + original_image_spatial_shapes) } detection_fields = fields.DetectionResultFields - detection_boxes = detections[detection_fields.detection_boxes][0] - image_shape = tf.shape(image) - detection_scores = detections[detection_fields.detection_scores][0] + detection_boxes = detections[detection_fields.detection_boxes] + detection_scores = detections[detection_fields.detection_scores] + num_detections = tf.to_int32(detections[detection_fields.num_detections]) if class_agnostic: detection_classes = tf.ones_like(detection_scores, dtype=tf.int64) else: detection_classes = ( - tf.to_int64(detections[detection_fields.detection_classes][0]) + + tf.to_int64(detections[detection_fields.detection_classes]) + label_id_offset) - num_detections = tf.to_int32(detections[detection_fields.num_detections][0]) - detection_boxes = tf.slice( - detection_boxes, begin=[0, 0], size=[num_detections, -1]) - detection_classes = tf.slice( - detection_classes, begin=[0], size=[num_detections]) - detection_scores = tf.slice( - detection_scores, begin=[0], size=[num_detections]) - if scale_to_absolute: - absolute_detection_boxlist = box_list_ops.to_absolute_coordinates( - box_list.BoxList(detection_boxes), image_shape[1], image_shape[2]) output_dict[detection_fields.detection_boxes] = ( - absolute_detection_boxlist.get()) + shape_utils.static_or_dynamic_map_fn( + _scale_box_to_absolute, + elems=[detection_boxes, original_image_spatial_shapes], + dtype=tf.float32)) else: output_dict[detection_fields.detection_boxes] = detection_boxes output_dict[detection_fields.detection_classes] = detection_classes output_dict[detection_fields.detection_scores] = detection_scores + output_dict[detection_fields.num_detections] = num_detections if detection_fields.detection_masks in detections: - detection_masks = detections[detection_fields.detection_masks][0] + detection_masks = detections[detection_fields.detection_masks] # TODO(rathodv): This should be done in model's postprocess # function ideally. 
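The batched code path above stops slicing each tensor down to `num_detections`; it keeps the padded batch tensors and maps a per-image helper (such as `_scale_box_to_absolute`) over the batch together with each image's `original_image_spatial_shape`. The snippet below is only an illustrative TF 1.x sketch of that mapping pattern, with made-up names and plain `tf.map_fn` standing in for `shape_utils.static_or_dynamic_map_fn`:

```python
# Illustration only: convert normalized boxes to absolute coordinates per image,
# mirroring how _scale_box_to_absolute is mapped over a batch of image shapes.
import tensorflow as tf

def scale_boxes_to_absolute(normalized_boxes, image_shapes):
  """normalized_boxes: [batch, num_boxes, 4]; image_shapes: [batch, 2] as (height, width)."""
  def _scale(args):
    boxes, shape = args
    height, width = shape[0], shape[1]
    ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=-1)
    return tf.stack(
        [ymin * height, xmin * width, ymax * height, xmax * width], axis=-1)

  return tf.map_fn(
      _scale,
      elems=[normalized_boxes, tf.to_float(image_shapes)],
      dtype=tf.float32)
```

As its name suggests, `shape_utils.static_or_dynamic_map_fn` behaves like `tf.map_fn` but can unroll into a plain loop over the unstacked elements when the leading dimension is statically known, which keeps dynamic control flow out of the evaluation graph.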
- detection_masks = tf.slice( - detection_masks, begin=[0, 0, 0], size=[num_detections, -1, -1]) - detection_masks_reframed = ops.reframe_box_masks_to_image_masks( - detection_masks, detection_boxes, image_shape[1], image_shape[2]) - detection_masks_reframed = tf.cast( - tf.greater(detection_masks_reframed, 0.5), tf.uint8) - output_dict[detection_fields.detection_masks] = detection_masks_reframed + output_dict[detection_fields.detection_masks] = ( + shape_utils.static_or_dynamic_map_fn( + _resize_detection_masks, + elems=[detection_boxes, detection_masks, + original_image_spatial_shapes], + dtype=tf.uint8)) + if detection_fields.detection_keypoints in detections: - detection_keypoints = detections[detection_fields.detection_keypoints][0] + detection_keypoints = detections[detection_fields.detection_keypoints] output_dict[detection_fields.detection_keypoints] = detection_keypoints if scale_to_absolute: - absolute_detection_keypoints = keypoint_ops.scale( - detection_keypoints, image_shape[1], image_shape[2]) output_dict[detection_fields.detection_keypoints] = ( - absolute_detection_keypoints) + shape_utils.static_or_dynamic_map_fn( + _scale_keypoint_to_absolute, + elems=[detection_keypoints, original_image_spatial_shapes], + dtype=tf.float32)) if groundtruth: + if max_gt_boxes is None: + if input_data_fields.num_groundtruth_boxes in groundtruth: + max_gt_boxes = groundtruth[input_data_fields.num_groundtruth_boxes] + else: + raise ValueError( + 'max_gt_boxes must be provided when processing batched examples.') + if input_data_fields.groundtruth_instance_masks in groundtruth: masks = groundtruth[input_data_fields.groundtruth_instance_masks] - masks = tf.expand_dims(masks, 3) - masks = tf.image.resize_images( - masks, - image_shape[1:3], - method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, - align_corners=True) - masks = tf.squeeze(masks, 3) - groundtruth[input_data_fields.groundtruth_instance_masks] = tf.cast( - masks, tf.uint8) + groundtruth[input_data_fields.groundtruth_instance_masks] = ( + shape_utils.static_or_dynamic_map_fn( + _resize_groundtruth_masks, + elems=[masks, original_image_spatial_shapes], + dtype=tf.uint8)) + output_dict.update(groundtruth) if scale_to_absolute: groundtruth_boxes = groundtruth[input_data_fields.groundtruth_boxes] - absolute_gt_boxlist = box_list_ops.to_absolute_coordinates( - box_list.BoxList(groundtruth_boxes), image_shape[1], image_shape[2]) output_dict[input_data_fields.groundtruth_boxes] = ( - absolute_gt_boxlist.get()) + shape_utils.static_or_dynamic_map_fn( + _scale_box_to_absolute, + elems=[groundtruth_boxes, original_image_spatial_shapes], + dtype=tf.float32)) + # For class-agnostic models, groundtruth classes all become 1. 
if class_agnostic: groundtruth_classes = groundtruth[input_data_fields.groundtruth_classes] groundtruth_classes = tf.ones_like(groundtruth_classes, dtype=tf.int64) output_dict[input_data_fields.groundtruth_classes] = groundtruth_classes + output_dict[input_data_fields.num_groundtruth_boxes] = max_gt_boxes + return output_dict diff --git a/research/object_detection/eval_util_test.py b/research/object_detection/eval_util_test.py index 0c0ffce4822..7c99898deda 100644 --- a/research/object_detection/eval_util_test.py +++ b/research/object_detection/eval_util_test.py @@ -18,37 +18,58 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf +from absl.testing import parameterized +import tensorflow as tf from object_detection import eval_util from object_detection.core import standard_fields as fields from object_detection.protos import eval_pb2 +from object_detection.utils import test_case -class EvalUtilTest(tf.test.TestCase): +class EvalUtilTest(test_case.TestCase, parameterized.TestCase): def _get_categories_list(self): return [{'id': 0, 'name': 'person'}, {'id': 1, 'name': 'dog'}, {'id': 2, 'name': 'cat'}] - def _make_evaluation_dict(self, resized_groundtruth_masks=False): + def _make_evaluation_dict(self, + resized_groundtruth_masks=False, + batch_size=1, + max_gt_boxes=None, + scale_to_absolute=False): input_data_fields = fields.InputDataFields detection_fields = fields.DetectionResultFields - image = tf.zeros(shape=[1, 20, 20, 3], dtype=tf.uint8) - key = tf.constant('image1') - detection_boxes = tf.constant([[[0., 0., 1., 1.]]]) - detection_scores = tf.constant([[0.8]]) - detection_classes = tf.constant([[0]]) - detection_masks = tf.ones(shape=[1, 1, 20, 20], dtype=tf.float32) - num_detections = tf.constant([1]) + image = tf.zeros(shape=[batch_size, 20, 20, 3], dtype=tf.uint8) + if batch_size == 1: + key = tf.constant('image1') + else: + key = tf.constant([str(range(batch_size))]) + detection_boxes = tf.tile(tf.constant([[[0., 0., 1., 1.]]]), + multiples=[batch_size, 1, 1]) + detection_scores = tf.tile(tf.constant([[0.8]]), multiples=[batch_size, 1]) + detection_classes = tf.tile(tf.constant([[0]]), multiples=[batch_size, 1]) + detection_masks = tf.tile(tf.ones(shape=[1, 1, 20, 20], dtype=tf.float32), + multiples=[batch_size, 1, 1, 1]) + num_detections = tf.ones([batch_size]) groundtruth_boxes = tf.constant([[0., 0., 1., 1.]]) groundtruth_classes = tf.constant([1]) groundtruth_instance_masks = tf.ones(shape=[1, 20, 20], dtype=tf.uint8) if resized_groundtruth_masks: groundtruth_instance_masks = tf.ones(shape=[1, 10, 10], dtype=tf.uint8) + + if batch_size > 1: + groundtruth_boxes = tf.tile(tf.expand_dims(groundtruth_boxes, 0), + multiples=[batch_size, 1, 1]) + groundtruth_classes = tf.tile(tf.expand_dims(groundtruth_classes, 0), + multiples=[batch_size, 1]) + groundtruth_instance_masks = tf.tile( + tf.expand_dims(groundtruth_instance_masks, 0), + multiples=[batch_size, 1, 1, 1]) + detections = { detection_fields.detection_boxes: detection_boxes, detection_fields.detection_scores: detection_scores, @@ -61,14 +82,31 @@ def _make_evaluation_dict(self, resized_groundtruth_masks=False): input_data_fields.groundtruth_classes: groundtruth_classes, input_data_fields.groundtruth_instance_masks: groundtruth_instance_masks } - return eval_util.result_dict_for_single_example(image, key, detections, - groundtruth) - - def test_get_eval_metric_ops_for_coco_detections(self): + if batch_size > 1: + return eval_util.result_dict_for_batched_example( + image, key, 
detections, groundtruth, + scale_to_absolute=scale_to_absolute, + max_gt_boxes=max_gt_boxes) + else: + return eval_util.result_dict_for_single_example( + image, key, detections, groundtruth, + scale_to_absolute=scale_to_absolute) + + @parameterized.parameters( + {'batch_size': 1, 'max_gt_boxes': None, 'scale_to_absolute': True}, + {'batch_size': 8, 'max_gt_boxes': [1], 'scale_to_absolute': True}, + {'batch_size': 1, 'max_gt_boxes': None, 'scale_to_absolute': False}, + {'batch_size': 8, 'max_gt_boxes': [1], 'scale_to_absolute': False} + ) + def test_get_eval_metric_ops_for_coco_detections(self, batch_size=1, + max_gt_boxes=None, + scale_to_absolute=False): eval_config = eval_pb2.EvalConfig() eval_config.metrics_set.extend(['coco_detection_metrics']) categories = self._get_categories_list() - eval_dict = self._make_evaluation_dict() + eval_dict = self._make_evaluation_dict(batch_size=batch_size, + max_gt_boxes=max_gt_boxes, + scale_to_absolute=scale_to_absolute) metric_ops = eval_util.get_eval_metric_ops_for_evaluators( eval_config, categories, eval_dict) _, update_op = metric_ops['DetectionBoxes_Precision/mAP'] @@ -79,16 +117,24 @@ def test_get_eval_metric_ops_for_coco_detections(self): metrics[key] = value_op sess.run(update_op) metrics = sess.run(metrics) - print(metrics) self.assertAlmostEqual(1.0, metrics['DetectionBoxes_Precision/mAP']) self.assertNotIn('DetectionMasks_Precision/mAP', metrics) - def test_get_eval_metric_ops_for_coco_detections_and_masks(self): + @parameterized.parameters( + {'batch_size': 1, 'max_gt_boxes': None, 'scale_to_absolute': True}, + {'batch_size': 8, 'max_gt_boxes': [1], 'scale_to_absolute': True}, + {'batch_size': 1, 'max_gt_boxes': None, 'scale_to_absolute': False}, + {'batch_size': 8, 'max_gt_boxes': [1], 'scale_to_absolute': False} + ) + def test_get_eval_metric_ops_for_coco_detections_and_masks( + self, batch_size=1, max_gt_boxes=None, scale_to_absolute=False): eval_config = eval_pb2.EvalConfig() eval_config.metrics_set.extend( ['coco_detection_metrics', 'coco_mask_metrics']) categories = self._get_categories_list() - eval_dict = self._make_evaluation_dict() + eval_dict = self._make_evaluation_dict(batch_size=batch_size, + max_gt_boxes=max_gt_boxes, + scale_to_absolute=scale_to_absolute) metric_ops = eval_util.get_eval_metric_ops_for_evaluators( eval_config, categories, eval_dict) _, update_op_boxes = metric_ops['DetectionBoxes_Precision/mAP'] @@ -104,12 +150,22 @@ def test_get_eval_metric_ops_for_coco_detections_and_masks(self): self.assertAlmostEqual(1.0, metrics['DetectionBoxes_Precision/mAP']) self.assertAlmostEqual(1.0, metrics['DetectionMasks_Precision/mAP']) - def test_get_eval_metric_ops_for_coco_detections_and_resized_masks(self): + @parameterized.parameters( + {'batch_size': 1, 'max_gt_boxes': None, 'scale_to_absolute': True}, + {'batch_size': 8, 'max_gt_boxes': [1], 'scale_to_absolute': True}, + {'batch_size': 1, 'max_gt_boxes': None, 'scale_to_absolute': False}, + {'batch_size': 8, 'max_gt_boxes': [1], 'scale_to_absolute': False} + ) + def test_get_eval_metric_ops_for_coco_detections_and_resized_masks( + self, batch_size=1, max_gt_boxes=None, scale_to_absolute=False): eval_config = eval_pb2.EvalConfig() eval_config.metrics_set.extend( ['coco_detection_metrics', 'coco_mask_metrics']) categories = self._get_categories_list() - eval_dict = self._make_evaluation_dict(resized_groundtruth_masks=True) + eval_dict = self._make_evaluation_dict(batch_size=batch_size, + max_gt_boxes=max_gt_boxes, + scale_to_absolute=scale_to_absolute, + 
resized_groundtruth_masks=True) metric_ops = eval_util.get_eval_metric_ops_for_evaluators( eval_config, categories, eval_dict) _, update_op_boxes = metric_ops['DetectionBoxes_Precision/mAP'] diff --git a/research/object_detection/export_tflite_ssd_graph_lib.py b/research/object_detection/export_tflite_ssd_graph_lib.py index 576505d69ae..d6a5a7c54b7 100644 --- a/research/object_detection/export_tflite_ssd_graph_lib.py +++ b/research/object_detection/export_tflite_ssd_graph_lib.py @@ -234,12 +234,16 @@ def export_tflite_graph(pipeline_config, trained_checkpoint_prefix, output_dir, tf.train.get_or_create_global_step() # graph rewriter - if pipeline_config.HasField('graph_rewriter'): + is_quantized = pipeline_config.HasField('graph_rewriter') + if is_quantized: graph_rewriter_config = pipeline_config.graph_rewriter graph_rewriter_fn = graph_rewriter_builder.build( graph_rewriter_config, is_training=False) graph_rewriter_fn() + if pipeline_config.model.ssd.feature_extractor.HasField('fpn'): + exporter.rewrite_nn_resize_op(is_quantized) + # freeze the graph saver_kwargs = {} if pipeline_config.eval_config.use_moving_averages: diff --git a/research/object_detection/export_tflite_ssd_graph_lib_test.py b/research/object_detection/export_tflite_ssd_graph_lib_test.py index 3d79a5bc70e..cb93a87dda4 100644 --- a/research/object_detection/export_tflite_ssd_graph_lib_test.py +++ b/research/object_detection/export_tflite_ssd_graph_lib_test.py @@ -23,6 +23,7 @@ import tensorflow as tf from tensorflow.core.framework import types_pb2 from object_detection import export_tflite_ssd_graph_lib +from object_detection import exporter from object_detection.builders import graph_rewriter_builder from object_detection.builders import model_builder from object_detection.core import model @@ -70,6 +71,12 @@ def restore_map(self, checkpoint_path, from_detection_checkpoint): def loss(self, prediction_dict, true_image_shapes): pass + def regularization_losses(self): + pass + + def updates(self): + pass + class ExportTfliteGraphTest(tf.test.TestCase): @@ -335,6 +342,28 @@ def test_export_tflite_graph_with_postprocessing_op(self): for t in node.attr['_output_types'].list.type ])) + @mock.patch.object(exporter, 'rewrite_nn_resize_op') + def test_export_with_nn_resize_op_not_called_without_fpn(self, mock_get): + pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() + pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 10 + pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 10 + tflite_graph_file = self._export_graph_with_postprocessing_op( + pipeline_config) + self.assertTrue(os.path.exists(tflite_graph_file)) + mock_get.assert_not_called() + + @mock.patch.object(exporter, 'rewrite_nn_resize_op') + def test_export_with_nn_resize_op_called_with_fpn(self, mock_get): + pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() + pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.height = 10 + pipeline_config.model.ssd.image_resizer.fixed_shape_resizer.width = 10 + pipeline_config.model.ssd.feature_extractor.fpn.min_level = 3 + pipeline_config.model.ssd.feature_extractor.fpn.max_level = 7 + tflite_graph_file = self._export_graph_with_postprocessing_op( + pipeline_config) + self.assertTrue(os.path.exists(tflite_graph_file)) + mock_get.assert_called_once() + if __name__ == '__main__': tf.test.main() diff --git a/research/object_detection/exporter.py b/research/object_detection/exporter.py index ed62fac2127..a0add282711 100644 --- a/research/object_detection/exporter.py +++ 
b/research/object_detection/exporter.py @@ -17,6 +17,7 @@ import os import tempfile import tensorflow as tf +from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from tensorflow.python.platform import gfile @@ -28,12 +29,58 @@ from object_detection.core import standard_fields as fields from object_detection.data_decoders import tf_example_decoder from object_detection.utils import config_util +from object_detection.utils import shape_utils slim = tf.contrib.slim freeze_graph_with_def_protos = freeze_graph.freeze_graph_with_def_protos +def rewrite_nn_resize_op(is_quantized=False): + """Replaces a custom nearest-neighbor resize op with the Tensorflow version. + + Some graphs use this custom version for TPU-compatibility. + + Args: + is_quantized: True if the default graph is quantized. + """ + input_pattern = graph_matcher.OpTypePattern( + 'FakeQuantWithMinMaxVars' if is_quantized else '*') + reshape_1_pattern = graph_matcher.OpTypePattern( + 'Reshape', inputs=[input_pattern, 'Const'], ordered_inputs=False) + mul_pattern = graph_matcher.OpTypePattern( + 'Mul', inputs=[reshape_1_pattern, 'Const'], ordered_inputs=False) + # The quantization script may or may not insert a fake quant op after the + # Mul. In either case, these min/max vars are not needed once replaced with + # the TF version of NN resize. + fake_quant_pattern = graph_matcher.OpTypePattern( + 'FakeQuantWithMinMaxVars', + inputs=[mul_pattern, 'Identity', 'Identity'], + ordered_inputs=False) + reshape_2_pattern = graph_matcher.OpTypePattern( + 'Reshape', + inputs=[graph_matcher.OneofPattern([fake_quant_pattern, mul_pattern]), + 'Const'], + ordered_inputs=False) + add_pattern = graph_matcher.OpTypePattern( + 'Add', inputs=[reshape_2_pattern, '*'], ordered_inputs=False) + + matcher = graph_matcher.GraphMatcher(add_pattern) + for match in matcher.match_graph(tf.get_default_graph()): + projection_op = match.get_op(input_pattern) + reshape_2_op = match.get_op(reshape_2_pattern) + add_op = match.get_op(add_pattern) + nn_resize = tf.image.resize_nearest_neighbor( + projection_op.outputs[0], + add_op.outputs[0].shape.dims[1:3], + align_corners=False) + + for index, op_input in enumerate(add_op.inputs): + if op_input == reshape_2_op.outputs[0]: + add_op._update_input(index, nn_resize) # pylint: disable=protected-access + break + + def replace_variable_values_with_moving_averages(graph, current_checkpoint_file, new_checkpoint_file): @@ -82,11 +129,12 @@ def decode(tf_example_string_tensor): image_tensor = tensor_dict[fields.InputDataFields.image] return image_tensor return (batch_tf_example_placeholder, - tf.map_fn(decode, - elems=batch_tf_example_placeholder, - dtype=tf.uint8, - parallel_iterations=32, - back_prop=False)) + shape_utils.static_or_dynamic_map_fn( + decode, + elems=batch_tf_example_placeholder, + dtype=tf.uint8, + parallel_iterations=32, + back_prop=False)) def _encoded_image_string_tensor_input_placeholder(): @@ -121,8 +169,8 @@ def decode(encoded_image_string_tensor): } -def _add_output_tensor_nodes(postprocessed_tensors, - output_collection_name='inference_op'): +def add_output_tensor_nodes(postprocessed_tensors, + output_collection_name='inference_op'): """Adds output nodes for detection boxes and scores. 
Adds the following nodes for output tensors - @@ -254,8 +302,8 @@ def _get_outputs_from_inputs(input_tensors, detection_model, preprocessed_inputs, true_image_shapes) postprocessed_tensors = detection_model.postprocess( output_tensors, true_image_shapes) - return _add_output_tensor_nodes(postprocessed_tensors, - output_collection_name) + return add_output_tensor_nodes(postprocessed_tensors, + output_collection_name) def _build_detection_graph(input_type, detection_model, input_shape, diff --git a/research/object_detection/exporter_test.py b/research/object_detection/exporter_test.py index d872b5611bd..5d2bd9ba7d5 100644 --- a/research/object_detection/exporter_test.py +++ b/research/object_detection/exporter_test.py @@ -19,12 +19,15 @@ import six import tensorflow as tf from google.protobuf import text_format +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from object_detection import exporter from object_detection.builders import graph_rewriter_builder from object_detection.builders import model_builder from object_detection.core import model from object_detection.protos import graph_rewriter_pb2 from object_detection.protos import pipeline_pb2 +from object_detection.utils import ops if six.PY2: import mock # pylint: disable=g-import-not-at-top @@ -74,6 +77,12 @@ def restore_map(self, checkpoint_path, fine_tune_checkpoint_type): def loss(self, prediction_dict, true_image_shapes): pass + def regularization_losses(self): + pass + + def updates(self): + pass + class ExportInferenceGraphTest(tf.test.TestCase): @@ -928,6 +937,52 @@ def test_write_graph_and_checkpoint(self): self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4])) self.assertAllClose(num_detections_np, [2, 1]) + def test_rewrite_nn_resize_op(self): + g = tf.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8)) + y = array_ops.placeholder(dtypes.float32, shape=(8, 20, 20, 8)) + s = ops.nearest_neighbor_upsampling(x, 2) + t = s + y + exporter.rewrite_nn_resize_op() + + resize_op_found = False + for op in g.get_operations(): + if op.type == 'ResizeNearestNeighbor': + resize_op_found = True + self.assertEqual(op.inputs[0], x) + self.assertEqual(op.outputs[0].consumers()[0], t.op) + break + + self.assertTrue(resize_op_found) + + def test_rewrite_nn_resize_op_quantized(self): + g = tf.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8)) + x_conv = tf.contrib.slim.conv2d(x, 8, 1) + y = array_ops.placeholder(dtypes.float32, shape=(8, 20, 20, 8)) + s = ops.nearest_neighbor_upsampling(x_conv, 2) + t = s + y + + graph_rewriter_config = graph_rewriter_pb2.GraphRewriter() + graph_rewriter_config.quantization.delay = 500000 + graph_rewriter_fn = graph_rewriter_builder.build( + graph_rewriter_config, is_training=False) + graph_rewriter_fn() + + exporter.rewrite_nn_resize_op(is_quantized=True) + + resize_op_found = False + for op in g.get_operations(): + if op.type == 'ResizeNearestNeighbor': + resize_op_found = True + self.assertEqual(op.inputs[0].op.type, 'FakeQuantWithMinMaxVars') + self.assertEqual(op.outputs[0].consumers()[0], t.op) + break + + self.assertTrue(resize_op_found) + if __name__ == '__main__': tf.test.main() diff --git a/research/object_detection/g3doc/detection_model_zoo.md b/research/object_detection/g3doc/detection_model_zoo.md index f538de89d60..538619ea221 100644 --- a/research/object_detection/g3doc/detection_model_zoo.md +++ b/research/object_detection/g3doc/detection_model_zoo.md 
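The exporter tests above build the upsample with `ops.nearest_neighbor_upsampling` and then check that `rewrite_nn_resize_op` leaves a single `ResizeNearestNeighbor` op wired into the addition. To see why the matcher looks for a Reshape -> Mul -> Reshape chain feeding an Add, note that a TPU-friendly nearest-neighbor upsample is commonly written as a broadcasted multiply; the sketch below is an assumed approximation, not a copy of the implementation in `object_detection/utils/ops.py`:

```python
# Assumed sketch of a reshape/broadcast-multiply nearest-neighbor upsample.
# It emits Reshape, Mul and Reshape ops, which is the chain that
# exporter.rewrite_nn_resize_op pattern-matches ahead of the FPN Add.
import tensorflow as tf

def nn_upsample_sketch(features, scale):
  """features: [batch, height, width, channels] float tensor; scale: python int."""
  shape = tf.shape(features)
  # Insert singleton axes, repeat rows and columns `scale` times by broadcasting
  # against a ones tensor, then collapse back to a 4-D feature map.
  expanded = tf.reshape(
      features, [shape[0], shape[1], 1, shape[2], 1, shape[3]])
  tiled = expanded * tf.ones([1, 1, scale, 1, scale, 1], dtype=features.dtype)
  return tf.reshape(
      tiled, [shape[0], shape[1] * scale, shape[2] * scale, shape[3]])
```

In `export_tflite_ssd_graph_lib` this rewrite is only triggered for FPN feature extractors, where the whole chain (plus any `FakeQuantWithMinMaxVars` nodes the quantization rewriter may have inserted) is swapped for `tf.image.resize_nearest_neighbor`.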
@@ -78,6 +78,7 @@ Some remarks on frozen inference graphs: | [ssd_mobilenet_v1_fpn_coco ☆](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz) | 56 | 32 | Boxes | | [ssd_resnet_50_fpn_coco ☆](http://download.tensorflow.org/models/object_detection/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03.tar.gz) | 76 | 35 | Boxes | | [ssd_mobilenet_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz) | 31 | 22 | Boxes | +| [ssd_mobilenet_v2_quantized_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_quantized_300x300_coco_2018_09_14.tar.gz) | 29 | 22 | Boxes | | [ssdlite_mobilenet_v2_coco](http://download.tensorflow.org/models/object_detection/ssdlite_mobilenet_v2_coco_2018_05_09.tar.gz) | 27 | 22 | Boxes | | [ssd_inception_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2018_01_28.tar.gz) | 42 | 24 | Boxes | | [faster_rcnn_inception_v2_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz) | 58 | 28 | Boxes | @@ -111,6 +112,7 @@ Model name ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---: | :-------------: | :-----: [faster_rcnn_inception_resnet_v2_atrous_oid](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28.tar.gz) | 727 | 37 | Boxes [faster_rcnn_inception_resnet_v2_atrous_lowproposals_oid](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_lowproposals_oid_2018_01_28.tar.gz) | 347 | | Boxes +[facessd_mobilenet_v2_quantized_open_image_v4](http://download.tensorflow.org/models/object_detection/facessd_mobilenet_v2_quantized_320x320_open_image_v4.tar.gz) [^3] | 20 | 73 (faces) | Boxes ## iNaturalist Species-trained models @@ -130,4 +132,5 @@ Model name [^1]: See [MSCOCO evaluation protocol](http://cocodataset.org/#detections-eval). [^2]: This is PASCAL mAP with a slightly different way of true positives computation: see [Open Images evaluation protocol](evaluation_protocols.md#open-images). +[^3]: Non-face boxes are dropped during training and non-face groundtruth boxes are ignored when evaluating. 
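The model zoo rows above give direct tarball URLs for frozen inference graphs. As a usage illustration only (an editor's sketch, not part of this patch), one of the listed models can be fetched and loaded as follows; the `frozen_inference_graph.pb` file name follows the zoo's usual packaging convention and is an assumption here, so adjust the path if the extracted layout differs.

```python
import tarfile
import urllib.request

import tensorflow as tf

# Any tarball URL from the table above works the same way.
MODEL = 'ssd_mobilenet_v2_coco_2018_03_29'
URL = ('http://download.tensorflow.org/models/object_detection/' +
       MODEL + '.tar.gz')

urllib.request.urlretrieve(URL, MODEL + '.tar.gz')
with tarfile.open(MODEL + '.tar.gz') as tar:
  tar.extractall()

# The extracted directory conventionally contains frozen_inference_graph.pb.
graph_def = tf.GraphDef()
with tf.gfile.GFile(MODEL + '/frozen_inference_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name='')

# Inspect a few of the imported node names (e.g. image_tensor, detection_*).
print([op.name for op in graph.get_operations()][:10])
```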
diff --git a/research/object_detection/g3doc/installation.md b/research/object_detection/g3doc/installation.md index 3fda48a1d07..206304d1484 100644 --- a/research/object_detection/g3doc/installation.md +++ b/research/object_detection/g3doc/installation.md @@ -108,7 +108,7 @@ Run the compilation process again, but use the downloaded version of protoc **If you are on MacOS:** -If you have homebrew, download and install the protobuf with +If you have homebrew, download and install the protobuf with ```brew install protobuf``` Alternately, run: @@ -118,7 +118,7 @@ sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc rm -f $PROTOC_ZIP ``` -Run the compilation process again: +Run the compilation process again: ``` bash # From tensorflow/models/research/ diff --git a/research/object_detection/inputs.py b/research/object_detection/inputs.py index 4c2155a9a69..7ab9404cc2f 100644 --- a/research/object_detection/inputs.py +++ b/research/object_detection/inputs.py @@ -124,6 +124,8 @@ def transform_input_data(tensor_dict, if fields.InputDataFields.groundtruth_instance_masks in tensor_dict: masks = tensor_dict[fields.InputDataFields.groundtruth_instance_masks] _, resized_masks, _ = image_resizer_fn(image, masks) + if use_bfloat16: + resized_masks = tf.cast(resized_masks, tf.bfloat16) tensor_dict[fields.InputDataFields. groundtruth_instance_masks] = resized_masks @@ -161,6 +163,9 @@ def transform_input_data(tensor_dict, tensor_dict[fields.InputDataFields.groundtruth_classes] = merged_classes tensor_dict[fields.InputDataFields.groundtruth_confidences] = ( merged_confidences) + if fields.InputDataFields.groundtruth_boxes in tensor_dict: + tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape( + tensor_dict[fields.InputDataFields.groundtruth_boxes])[0] return tensor_dict @@ -282,12 +287,9 @@ def augment_input_data(tensor_dict, data_augmentation_options): in tensor_dict) include_keypoints = (fields.InputDataFields.groundtruth_keypoints in tensor_dict) - include_label_scores = (fields.InputDataFields.groundtruth_confidences in - tensor_dict) tensor_dict = preprocessor.preprocess( tensor_dict, data_augmentation_options, func_arg_map=preprocessor.get_default_func_arg_map( - include_label_scores=include_label_scores, include_instance_masks=include_instance_masks, include_keypoints=include_keypoints)) tensor_dict[fields.InputDataFields.image] = tf.squeeze( diff --git a/research/object_detection/inputs_test.py b/research/object_detection/inputs_test.py index 4165b297bbe..a87b12d268f 100644 --- a/research/object_detection/inputs_test.py +++ b/research/object_detection/inputs_test.py @@ -630,6 +630,9 @@ def test_returns_correct_merged_boxes(self): self.assertAllClose( transformed_inputs[fields.InputDataFields.groundtruth_confidences], [[1, 0, 1]]) + self.assertAllClose( + transformed_inputs[fields.InputDataFields.num_groundtruth_boxes], + 1) def test_returns_resized_masks(self): tensor_dict = { diff --git a/research/object_detection/legacy/trainer_test.py b/research/object_detection/legacy/trainer_test.py index 82e77274a8a..3c06e07e81e 100644 --- a/research/object_detection/legacy/trainer_test.py +++ b/research/object_detection/legacy/trainer_test.py @@ -160,6 +160,17 @@ def loss(self, prediction_dict, true_image_shapes): } return loss_dict + def regularization_losses(self): + """Returns a list of regularization losses for this model. + + Returns a list of regularization losses for this model that the estimator + needs to use during training/optimization. 
+ + Returns: + A list of regularization loss tensors. + """ + pass + def restore_map(self, fine_tune_checkpoint_type='detection'): """Returns a map of variables to load from a foreign checkpoint. @@ -174,6 +185,18 @@ def restore_map(self, fine_tune_checkpoint_type='detection'): """ return {var.op.name: var for var in tf.global_variables()} + def updates(self): + """Returns a list of update operators for this model. + + Returns a list of update operators for this model that must be executed at + each training step. The estimator's train op needs to have a control + dependency on these updates. + + Returns: + A list of update operators. + """ + pass + class TrainerTest(tf.test.TestCase): diff --git a/research/object_detection/meta_architectures/faster_rcnn_meta_arch.py b/research/object_detection/meta_architectures/faster_rcnn_meta_arch.py index aa4132a3ec6..bf6b7d5edff 100644 --- a/research/object_detection/meta_architectures/faster_rcnn_meta_arch.py +++ b/research/object_detection/meta_architectures/faster_rcnn_meta_arch.py @@ -662,7 +662,8 @@ def predict(self, preprocessed_inputs, true_image_shapes): anchors_boxlist, clip_window) else: anchors_boxlist = box_list_ops.clip_to_window( - anchors_boxlist, clip_window) + anchors_boxlist, clip_window, + filter_nonoverlapping=not self._use_static_shapes) self._anchors = anchors_boxlist prediction_dict = { @@ -917,12 +918,14 @@ def _predict_third_stage(self, prediction_dict, image_shapes): _, num_classes, mask_height, mask_width = ( detection_masks.get_shape().as_list()) _, max_detection = detection_classes.get_shape().as_list() + prediction_dict['mask_predictions'] = tf.reshape( + detection_masks, [-1, num_classes, mask_height, mask_width]) if num_classes > 1: detection_masks = self._gather_instance_masks( detection_masks, detection_classes) prediction_dict[fields.DetectionResultFields.detection_masks] = ( - tf.reshape(detection_masks, + tf.reshape(tf.sigmoid(detection_masks), [batch_size, max_detection, mask_height, mask_width])) return prediction_dict @@ -1159,9 +1162,9 @@ def postprocess(self, prediction_dict, true_image_shapes): } # TODO(jrru): Remove mask_predictions from _post_process_box_classifier. - with tf.name_scope('SecondStagePostprocessor'): - if (self._number_of_stages == 2 or - (self._number_of_stages == 3 and self._is_training)): + if (self._number_of_stages == 2 or + (self._number_of_stages == 3 and self._is_training)): + with tf.name_scope('SecondStagePostprocessor'): mask_predictions = prediction_dict.get(box_predictor.MASK_PREDICTIONS) detections_dict = self._postprocess_box_classifier( prediction_dict['refined_box_encodings'], @@ -1170,18 +1173,53 @@ def postprocess(self, prediction_dict, true_image_shapes): prediction_dict['num_proposals'], true_image_shapes, mask_predictions=mask_predictions) - return detections_dict + + if 'rpn_features_to_crop' in prediction_dict and self._initial_crop_size: + self._add_detection_features_output_node( + detections_dict[fields.DetectionResultFields.detection_boxes], + prediction_dict['rpn_features_to_crop']) + + return detections_dict if self._number_of_stages == 3: # Post processing is already performed in 3rd stage. We need to transfer # postprocessed tensors from `prediction_dict` to `detections_dict`. 
- detections_dict = {} - for key in prediction_dict: - if key == fields.DetectionResultFields.detection_masks: - detections_dict[key] = tf.sigmoid(prediction_dict[key]) - elif 'detection' in key: - detections_dict[key] = prediction_dict[key] - return detections_dict + return prediction_dict + + def _add_detection_features_output_node(self, detection_boxes, + rpn_features_to_crop): + """Add the detection features to the output node. + + The detection features are from cropping rpn_features with boxes. + Each bounding box has one feature vector of length depth, which comes from + mean_pooling of the cropped rpn_features. + + Args: + detection_boxes: a 3-D float32 tensor of shape + [batch_size, max_detection, 4] which represents the bounding boxes. + rpn_features_to_crop: A 4-D float32 tensor with shape + [batch, height, width, depth] representing image features to crop using + the proposals boxes. + """ + with tf.name_scope('SecondStageDetectionFeaturesExtract'): + flattened_detected_feature_maps = ( + self._compute_second_stage_input_feature_maps( + rpn_features_to_crop, detection_boxes)) + detection_features_unpooled = ( + self._feature_extractor.extract_box_classifier_features( + flattened_detected_feature_maps, + scope=self.second_stage_feature_extractor_scope)) + + batch_size = tf.shape(detection_boxes)[0] + max_detection = tf.shape(detection_boxes)[1] + detection_features_pool = tf.reduce_mean( + detection_features_unpooled, axis=[1, 2]) + detection_features = tf.reshape( + detection_features_pool, + [batch_size, max_detection, tf.shape(detection_features_pool)[-1]]) + + detection_features = tf.identity( + detection_features, 'detection_features') def _postprocess_rpn(self, rpn_box_encodings_batch, @@ -1454,6 +1492,7 @@ def _sample_box_classifier_minibatch_single_image( # to cls_weights. This could happen as boxes within certain IOU ranges # are ignored. If triggered, the selected boxes will still be ignored # during loss computation. + cls_weights = tf.reduce_mean(cls_weights, axis=-1) positive_indicator = tf.greater(tf.argmax(cls_targets, axis=1), 0) valid_indicator = tf.logical_and( tf.range(proposal_boxlist.num_boxes()) < num_valid_proposals, @@ -1566,6 +1605,7 @@ def _postprocess_box_classifier(self, mask_predictions_batch = tf.reshape( mask_predictions, [-1, self.max_num_proposals, self.num_classes, mask_height, mask_width]) + (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks, _, num_detections) = self._second_stage_nms_fn( refined_decoded_boxes_batch, @@ -1713,6 +1753,7 @@ class targets with the 0th index assumed to map to the background class. gt_box_batch=groundtruth_boxlists, gt_class_targets_batch=(len(groundtruth_boxlists) * [None]), gt_weights_batch=groundtruth_weights_list) + batch_cls_weights = tf.reduce_mean(batch_cls_weights, axis=2) batch_cls_targets = tf.squeeze(batch_cls_targets, axis=2) def _minibatch_subsample_fn(inputs): @@ -1743,7 +1784,8 @@ def _minibatch_subsample_fn(inputs): losses_mask=losses_mask) objectness_losses = self._first_stage_objectness_loss( rpn_objectness_predictions_with_background, - batch_one_hot_targets, weights=batch_sampled_indices, + batch_one_hot_targets, + weights=tf.expand_dims(batch_sampled_indices, axis=-1), losses_mask=losses_mask) localization_loss = tf.reduce_mean( tf.reduce_sum(localization_losses, axis=1) / normalizer) @@ -1960,25 +2002,28 @@ class targets with the 0th index assumed to map to the background class. 
tf.expand_dims(flat_gt_masks, -1), tf.expand_dims(flat_normalized_proposals, axis=1), [mask_height, mask_width]) + # Without stopping gradients into cropped groundtruth masks the + # performance with 100-padded groundtruth masks when batch size > 1 is + # about 4% worse. + # TODO(rathodv): Investigate this since we don't expect any variables + # upstream of flat_cropped_gt_mask. + flat_cropped_gt_mask = tf.stop_gradient(flat_cropped_gt_mask) batch_cropped_gt_mask = tf.reshape( flat_cropped_gt_mask, [batch_size, -1, mask_height * mask_width]) - second_stage_mask_losses = ops.reduce_sum_trailing_dimensions( - self._second_stage_mask_loss( - reshaped_prediction_masks, - batch_cropped_gt_mask, - weights=batch_mask_target_weights, - losses_mask=losses_mask), - ndims=2) / ( - mask_height * mask_width * tf.maximum( - tf.reduce_sum( - batch_mask_target_weights, axis=1, keep_dims=True - ), tf.ones((batch_size, 1)))) - second_stage_mask_loss = tf.reduce_sum( - tf.where(paddings_indicator, second_stage_mask_losses, - tf.zeros_like(second_stage_mask_losses))) + mask_losses_weights = ( + batch_mask_target_weights * tf.to_float(paddings_indicator)) + mask_losses = self._second_stage_mask_loss( + reshaped_prediction_masks, + batch_cropped_gt_mask, + weights=tf.expand_dims(mask_losses_weights, axis=-1), + losses_mask=losses_mask) + total_mask_loss = tf.reduce_sum(mask_losses) + normalizer = tf.maximum( + tf.reduce_sum(mask_losses_weights * mask_height * mask_width), 1.0) + second_stage_mask_loss = total_mask_loss / normalizer if second_stage_mask_loss is not None: mask_loss = tf.multiply(self._second_stage_mask_loss_weight, @@ -2073,6 +2118,17 @@ def _unpad_proposals_and_apply_hard_mining(self, cls_losses=tf.expand_dims(single_image_cls_loss, 0), decoded_boxlist_list=[proposal_boxlist]) + def regularization_losses(self): + """Returns a list of regularization losses for this model. + + Returns a list of regularization losses for this model that the estimator + needs to use during training/optimization. + + Returns: + A list of regularization loss tensors. + """ + return tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + def restore_map(self, fine_tune_checkpoint_type='detection', load_all_detection_checkpoint_vars=False): @@ -2117,3 +2173,16 @@ def restore_map(self, feature_extractor_variables = tf.contrib.framework.filter_variables( variables_to_restore, include_patterns=include_patterns) return {var.op.name: var for var in feature_extractor_variables} + + def updates(self): + """Returns a list of update operators for this model. + + Returns a list of update operators for this model that must be executed at + each training step. The estimator's train op needs to have a control + dependency on these updates. + + Returns: + A list of update operators. 
+ """ + return tf.get_collection(tf.GraphKeys.UPDATE_OPS) + diff --git a/research/object_detection/meta_architectures/faster_rcnn_meta_arch_test.py b/research/object_detection/meta_architectures/faster_rcnn_meta_arch_test.py index 094e312b354..2c701d28927 100644 --- a/research/object_detection/meta_architectures/faster_rcnn_meta_arch_test.py +++ b/research/object_detection/meta_architectures/faster_rcnn_meta_arch_test.py @@ -189,7 +189,7 @@ def test_predict_correct_shapes_in_inference_mode_three_stages_with_masks( set(expected_shapes.keys()).union( set([ 'detection_boxes', 'detection_scores', 'detection_classes', - 'detection_masks', 'num_detections' + 'detection_masks', 'num_detections', 'mask_predictions', ]))) for key in expected_shapes: self.assertAllEqual(tensor_dict_out[key].shape, expected_shapes[key]) @@ -199,6 +199,9 @@ def test_predict_correct_shapes_in_inference_mode_three_stages_with_masks( self.assertAllEqual(tensor_dict_out['detection_classes'].shape, [2, 5]) self.assertAllEqual(tensor_dict_out['detection_scores'].shape, [2, 5]) self.assertAllEqual(tensor_dict_out['num_detections'].shape, [2]) + num_classes = 1 if masks_are_class_agnostic else 2 + self.assertAllEqual(tensor_dict_out['mask_predictions'].shape, + [10, num_classes, 14, 14]) @parameterized.parameters( {'masks_are_class_agnostic': False}, diff --git a/research/object_detection/meta_architectures/faster_rcnn_meta_arch_test_lib.py b/research/object_detection/meta_architectures/faster_rcnn_meta_arch_test_lib.py index 397c6d0848e..655a44fecd2 100644 --- a/research/object_detection/meta_architectures/faster_rcnn_meta_arch_test_lib.py +++ b/research/object_detection/meta_architectures/faster_rcnn_meta_arch_test_lib.py @@ -250,6 +250,7 @@ def image_resizer_fn(image, masks=None): iou_threshold: 1.0 max_detections_per_class: 5 max_total_detections: 5 + use_static_shapes: """ +'{}'.format(use_static_shapes) + """ } """ post_processing_config = post_processing_pb2.PostProcessing() @@ -336,61 +337,71 @@ def image_resizer_fn(image, masks=None): masks_are_class_agnostic=masks_are_class_agnostic), **common_kwargs) def test_predict_gives_correct_shapes_in_inference_mode_first_stage_only( - self): - test_graph = tf.Graph() - with test_graph.as_default(): - model = self._build_model( - is_training=False, number_of_stages=1, second_stage_batch_size=2) - batch_size = 2 - height = 10 - width = 12 - input_image_shape = (batch_size, height, width, 3) + self, use_static_shapes=False): + batch_size = 2 + height = 10 + width = 12 + input_image_shape = (batch_size, height, width, 3) - _, true_image_shapes = model.preprocess(tf.zeros(input_image_shape)) - preprocessed_inputs = tf.placeholder( - dtype=tf.float32, shape=(batch_size, None, None, 3)) + def graph_fn(images): + """Function to construct tf graph for the test.""" + model = self._build_model( + is_training=False, + number_of_stages=1, + second_stage_batch_size=2, + clip_anchors_to_image=use_static_shapes, + use_static_shapes=use_static_shapes) + preprocessed_inputs, true_image_shapes = model.preprocess(images) prediction_dict = model.predict(preprocessed_inputs, true_image_shapes) + return (prediction_dict['rpn_box_predictor_features'], + prediction_dict['rpn_features_to_crop'], + prediction_dict['image_shape'], + prediction_dict['rpn_box_encodings'], + prediction_dict['rpn_objectness_predictions_with_background'], + prediction_dict['anchors']) + + images = np.zeros(input_image_shape, dtype=np.float32) + + # In inference mode, anchors are clipped to the image window, but not + # 
pruned. Since MockFasterRCNN.extract_proposal_features returns a + # tensor with the same shape as its input, the expected number of anchors + # is height * width * the number of anchors per location (i.e. 3x3). + expected_num_anchors = height * width * 3 * 3 + expected_output_shapes = { + 'rpn_box_predictor_features': (batch_size, height, width, 512), + 'rpn_features_to_crop': (batch_size, height, width, 3), + 'rpn_box_encodings': (batch_size, expected_num_anchors, 4), + 'rpn_objectness_predictions_with_background': + (batch_size, expected_num_anchors, 2), + 'anchors': (expected_num_anchors, 4) + } - # In inference mode, anchors are clipped to the image window, but not - # pruned. Since MockFasterRCNN.extract_proposal_features returns a - # tensor with the same shape as its input, the expected number of anchors - # is height * width * the number of anchors per location (i.e. 3x3). - expected_num_anchors = height * width * 3 * 3 - expected_output_keys = set([ - 'rpn_box_predictor_features', 'rpn_features_to_crop', 'image_shape', - 'rpn_box_encodings', 'rpn_objectness_predictions_with_background', - 'anchors']) - expected_output_shapes = { - 'rpn_box_predictor_features': (batch_size, height, width, 512), - 'rpn_features_to_crop': (batch_size, height, width, 3), - 'rpn_box_encodings': (batch_size, expected_num_anchors, 4), - 'rpn_objectness_predictions_with_background': - (batch_size, expected_num_anchors, 2), - 'anchors': (expected_num_anchors, 4) - } - - init_op = tf.global_variables_initializer() - with self.test_session(graph=test_graph) as sess: - sess.run(init_op) - prediction_out = sess.run(prediction_dict, - feed_dict={ - preprocessed_inputs: - np.zeros(input_image_shape) - }) - - self.assertEqual(set(prediction_out.keys()), expected_output_keys) + if use_static_shapes: + results = self.execute(graph_fn, [images]) + else: + results = self.execute_cpu(graph_fn, [images]) - self.assertAllEqual(prediction_out['image_shape'], input_image_shape) - for output_key, expected_shape in expected_output_shapes.items(): - self.assertAllEqual(prediction_out[output_key].shape, expected_shape) + self.assertAllEqual(results[0].shape, + expected_output_shapes['rpn_box_predictor_features']) + self.assertAllEqual(results[1].shape, + expected_output_shapes['rpn_features_to_crop']) + self.assertAllEqual(results[2], + input_image_shape) + self.assertAllEqual(results[3].shape, + expected_output_shapes['rpn_box_encodings']) + self.assertAllEqual( + results[4].shape, + expected_output_shapes['rpn_objectness_predictions_with_background']) + self.assertAllEqual(results[5].shape, + expected_output_shapes['anchors']) - # Check that anchors are clipped to window. - anchors = prediction_out['anchors'] - self.assertTrue(np.all(np.greater_equal(anchors, 0))) - self.assertTrue(np.all(np.less_equal(anchors[:, 0], height))) - self.assertTrue(np.all(np.less_equal(anchors[:, 1], width))) - self.assertTrue(np.all(np.less_equal(anchors[:, 2], height))) - self.assertTrue(np.all(np.less_equal(anchors[:, 3], width))) + # Check that anchors are clipped to window. 
+ anchors = results[5] + self.assertTrue(np.all(np.greater_equal(anchors, 0))) + self.assertTrue(np.all(np.less_equal(anchors[:, 0], height))) + self.assertTrue(np.all(np.less_equal(anchors[:, 1], width))) + self.assertTrue(np.all(np.less_equal(anchors[:, 2], height))) + self.assertTrue(np.all(np.less_equal(anchors[:, 3], width))) def test_predict_gives_valid_anchors_in_training_mode_first_stage_only(self): test_graph = tf.Graph() @@ -446,7 +457,38 @@ def test_predict_gives_valid_anchors_in_training_mode_first_stage_only(self): prediction_out['rpn_objectness_predictions_with_background'].shape, (batch_size, num_anchors_out, 2)) - def test_predict_correct_shapes_in_inference_mode_two_stages(self): + def test_predict_correct_shapes_in_inference_mode_two_stages( + self, use_static_shapes=False): + + def compare_results(results, expected_output_shapes): + """Checks if the shape of the predictions are as expected.""" + self.assertAllEqual(results[0].shape, + expected_output_shapes['rpn_box_predictor_features']) + self.assertAllEqual(results[1].shape, + expected_output_shapes['rpn_features_to_crop']) + self.assertAllEqual(results[2].shape, + expected_output_shapes['image_shape']) + self.assertAllEqual(results[3].shape, + expected_output_shapes['rpn_box_encodings']) + self.assertAllEqual( + results[4].shape, + expected_output_shapes['rpn_objectness_predictions_with_background']) + self.assertAllEqual(results[5].shape, + expected_output_shapes['anchors']) + self.assertAllEqual(results[6].shape, + expected_output_shapes['refined_box_encodings']) + self.assertAllEqual( + results[7].shape, + expected_output_shapes['class_predictions_with_background']) + self.assertAllEqual(results[8].shape, + expected_output_shapes['num_proposals']) + self.assertAllEqual(results[9].shape, + expected_output_shapes['proposal_boxes']) + self.assertAllEqual(results[10].shape, + expected_output_shapes['proposal_boxes_normalized']) + self.assertAllEqual(results[11].shape, + expected_output_shapes['box_classifier_features']) + batch_size = 2 image_size = 10 max_num_proposals = 8 @@ -457,6 +499,32 @@ def test_predict_correct_shapes_in_inference_mode_two_stages(self): (None, image_size, image_size, 3), (batch_size, None, None, 3), (None, None, None, 3)] + + def graph_fn_tpu(images): + """Function to construct tf graph for the test.""" + model = self._build_model( + is_training=False, + number_of_stages=2, + second_stage_batch_size=2, + predict_masks=False, + use_matmul_crop_and_resize=use_static_shapes, + clip_anchors_to_image=use_static_shapes, + use_static_shapes=use_static_shapes) + preprocessed_inputs, true_image_shapes = model.preprocess(images) + prediction_dict = model.predict(preprocessed_inputs, true_image_shapes) + return (prediction_dict['rpn_box_predictor_features'], + prediction_dict['rpn_features_to_crop'], + prediction_dict['image_shape'], + prediction_dict['rpn_box_encodings'], + prediction_dict['rpn_objectness_predictions_with_background'], + prediction_dict['anchors'], + prediction_dict['refined_box_encodings'], + prediction_dict['class_predictions_with_background'], + prediction_dict['num_proposals'], + prediction_dict['proposal_boxes'], + prediction_dict['proposal_boxes_normalized'], + prediction_dict['box_classifier_features']) + expected_num_anchors = image_size * image_size * 3 * 3 expected_shapes = { 'rpn_box_predictor_features': @@ -481,28 +549,34 @@ def test_predict_correct_shapes_in_inference_mode_two_stages(self): 3) } - for input_shape in input_shapes: - test_graph = tf.Graph() - with 
test_graph.as_default(): - model = self._build_model( - is_training=False, - number_of_stages=2, - second_stage_batch_size=2, - predict_masks=False) - preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape) - _, true_image_shapes = model.preprocess(preprocessed_inputs) - result_tensor_dict = model.predict( - preprocessed_inputs, true_image_shapes) - init_op = tf.global_variables_initializer() - with self.test_session(graph=test_graph) as sess: - sess.run(init_op) - tensor_dict_out = sess.run(result_tensor_dict, feed_dict={ - preprocessed_inputs: - np.zeros((batch_size, image_size, image_size, 3))}) - self.assertEqual(set(tensor_dict_out.keys()), - set(expected_shapes.keys())) - for key in expected_shapes: - self.assertAllEqual(tensor_dict_out[key].shape, expected_shapes[key]) + if use_static_shapes: + input_shape = (batch_size, image_size, image_size, 3) + images = np.zeros(input_shape, dtype=np.float32) + results = self.execute(graph_fn_tpu, [images]) + compare_results(results, expected_shapes) + else: + for input_shape in input_shapes: + test_graph = tf.Graph() + with test_graph.as_default(): + model = self._build_model( + is_training=False, + number_of_stages=2, + second_stage_batch_size=2, + predict_masks=False) + preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape) + _, true_image_shapes = model.preprocess(preprocessed_inputs) + result_tensor_dict = model.predict( + preprocessed_inputs, true_image_shapes) + init_op = tf.global_variables_initializer() + with self.test_session(graph=test_graph) as sess: + sess.run(init_op) + tensor_dict_out = sess.run(result_tensor_dict, feed_dict={ + preprocessed_inputs: + np.zeros((batch_size, image_size, image_size, 3))}) + self.assertEqual(set(tensor_dict_out.keys()), + set(expected_shapes.keys())) + for key in expected_shapes: + self.assertAllEqual(tensor_dict_out[key].shape, expected_shapes[key]) def test_predict_gives_correct_shapes_in_train_mode_both_stages( self, @@ -596,23 +670,46 @@ def graph_fn(images, gt_boxes, gt_classes, gt_weights): self.assertAllEqual(results[8].shape, expected_shapes['rpn_box_predictor_features']) - def _test_postprocess_first_stage_only_inference_mode( - self, pad_to_max_dimension=None): - model = self._build_model( - is_training=False, number_of_stages=1, second_stage_batch_size=6, - pad_to_max_dimension=pad_to_max_dimension) + def test_postprocess_first_stage_only_inference_mode( + self, use_static_shapes=False, pad_to_max_dimension=None): batch_size = 2 - anchors = tf.constant( + first_stage_max_proposals = 4 if use_static_shapes else 8 + + def graph_fn(images, + rpn_box_encodings, + rpn_objectness_predictions_with_background, + rpn_features_to_crop, + anchors): + """Function to construct tf graph for the test.""" + model = self._build_model( + is_training=False, number_of_stages=1, second_stage_batch_size=6, + use_matmul_crop_and_resize=use_static_shapes, + clip_anchors_to_image=use_static_shapes, + use_static_shapes=use_static_shapes, + use_matmul_gather_in_matcher=use_static_shapes, + first_stage_max_proposals=first_stage_max_proposals, + pad_to_max_dimension=pad_to_max_dimension) + _, true_image_shapes = model.preprocess(images) + proposals = model.postprocess({ + 'rpn_box_encodings': rpn_box_encodings, + 'rpn_objectness_predictions_with_background': + rpn_objectness_predictions_with_background, + 'rpn_features_to_crop': rpn_features_to_crop, + 'anchors': anchors}, true_image_shapes) + return (proposals['num_detections'], + proposals['detection_boxes'], + 
proposals['detection_scores']) + + anchors = np.array( [[0, 0, 16, 16], [0, 16, 16, 32], [16, 0, 32, 16], - [16, 16, 32, 32]], dtype=tf.float32) - rpn_box_encodings = tf.zeros( - [batch_size, anchors.get_shape().as_list()[0], - BOX_CODE_SIZE], dtype=tf.float32) + [16, 16, 32, 32]], dtype=np.float32) + rpn_box_encodings = np.zeros( + (batch_size, anchors.shape[0], BOX_CODE_SIZE), dtype=np.float32) # use different numbers for the objectness category to break ties in # order of boxes returned by NMS - rpn_objectness_predictions_with_background = tf.constant([ + rpn_objectness_predictions_with_background = np.array([ [[-10, 13], [10, -10], [10, -11], @@ -620,16 +717,22 @@ def _test_postprocess_first_stage_only_inference_mode( [[10, -10], [-10, 13], [-10, 12], - [10, -11]]], dtype=tf.float32) - rpn_features_to_crop = tf.ones((batch_size, 8, 8, 10), dtype=tf.float32) - image_shape = tf.constant([batch_size, 32, 32, 3], dtype=tf.int32) - _, true_image_shapes = model.preprocess(tf.zeros(image_shape)) - proposals = model.postprocess({ - 'rpn_box_encodings': rpn_box_encodings, - 'rpn_objectness_predictions_with_background': - rpn_objectness_predictions_with_background, - 'rpn_features_to_crop': rpn_features_to_crop, - 'anchors': anchors}, true_image_shapes) + [10, -11]]], dtype=np.float32) + rpn_features_to_crop = np.ones((batch_size, 8, 8, 10), dtype=np.float32) + image_shape = (batch_size, 32, 32, 3) + images = np.zeros(image_shape, dtype=np.float32) + + if use_static_shapes: + results = self.execute(graph_fn, + [images, rpn_box_encodings, + rpn_objectness_predictions_with_background, + rpn_features_to_crop, anchors]) + else: + results = self.execute_cpu(graph_fn, + [images, rpn_box_encodings, + rpn_objectness_predictions_with_background, + rpn_features_to_crop, anchors]) + expected_proposal_boxes = [ [[0, 0, .5, .5], [.5, .5, 1, 1], [0, .5, .5, 1], [.5, 0, 1.0, .5]] + 4 * [4 * [0]], @@ -639,24 +742,12 @@ def _test_postprocess_first_stage_only_inference_mode( [1, 1, 0, 0, 0, 0, 0, 0]] expected_num_proposals = [4, 4] - expected_output_keys = set(['detection_boxes', 'detection_scores', - 'num_detections']) - self.assertEqual(set(proposals.keys()), expected_output_keys) - with self.test_session() as sess: - proposals_out = sess.run(proposals) - self.assertAllClose(proposals_out['detection_boxes'], - expected_proposal_boxes) - self.assertAllClose(proposals_out['detection_scores'], - expected_proposal_scores) - self.assertAllEqual(proposals_out['num_detections'], - expected_num_proposals) - - def test_postprocess_first_stage_only_inference_mode(self): - self._test_postprocess_first_stage_only_inference_mode() - - def test_postprocess_first_stage_only_inference_mode_padded_image(self): - self._test_postprocess_first_stage_only_inference_mode( - pad_to_max_dimension=56) + self.assertAllClose(results[0], expected_num_proposals) + for indx, num_proposals in enumerate(expected_num_proposals): + self.assertAllClose(results[1][indx][0:num_proposals], + expected_proposal_boxes[indx][0:num_proposals]) + self.assertAllClose(results[2][indx][0:num_proposals], + expected_proposal_scores[indx][0:num_proposals]) def _test_postprocess_first_stage_only_train_mode(self, pad_to_max_dimension=None): @@ -733,83 +824,80 @@ def test_postprocess_first_stage_only_train_mode(self): def test_postprocess_first_stage_only_train_mode_padded_image(self): self._test_postprocess_first_stage_only_train_mode(pad_to_max_dimension=56) - def _test_postprocess_second_stage_only_inference_mode( - self, pad_to_max_dimension=None): - 
num_proposals_shapes = [(2), (None,)] - refined_box_encodings_shapes = [(16, 2, 4), (None, 2, 4)] - class_predictions_with_background_shapes = [(16, 3), (None, 3)] - proposal_boxes_shapes = [(2, 8, 4), (None, 8, 4)] + def test_postprocess_second_stage_only_inference_mode( + self, use_static_shapes=False, pad_to_max_dimension=None): batch_size = 2 + num_classes = 2 image_shape = np.array((2, 36, 48, 3), dtype=np.int32) - for (num_proposals_shape, refined_box_encoding_shape, - class_predictions_with_background_shape, - proposal_boxes_shape) in zip(num_proposals_shapes, - refined_box_encodings_shapes, - class_predictions_with_background_shapes, - proposal_boxes_shapes): - tf_graph = tf.Graph() - with tf_graph.as_default(): - model = self._build_model( - is_training=False, number_of_stages=2, - second_stage_batch_size=6, - pad_to_max_dimension=pad_to_max_dimension) - _, true_image_shapes = model.preprocess(tf.zeros(image_shape)) - total_num_padded_proposals = batch_size * model.max_num_proposals - proposal_boxes = np.array( - [[[1, 1, 2, 3], - [0, 0, 1, 1], - [.5, .5, .6, .6], - 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]], - [[2, 3, 6, 8], - [1, 2, 5, 3], - 4*[0], 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]]]) - num_proposals = np.array([3, 2], dtype=np.int32) - refined_box_encodings = np.zeros( - [total_num_padded_proposals, model.num_classes, 4]) - class_predictions_with_background = np.ones( - [total_num_padded_proposals, model.num_classes+1]) - - num_proposals_placeholder = tf.placeholder(tf.int32, - shape=num_proposals_shape) - refined_box_encodings_placeholder = tf.placeholder( - tf.float32, shape=refined_box_encoding_shape) - class_predictions_with_background_placeholder = tf.placeholder( - tf.float32, shape=class_predictions_with_background_shape) - proposal_boxes_placeholder = tf.placeholder( - tf.float32, shape=proposal_boxes_shape) - image_shape_placeholder = tf.placeholder(tf.int32, shape=(4)) - - detections = model.postprocess({ - 'refined_box_encodings': refined_box_encodings_placeholder, - 'class_predictions_with_background': - class_predictions_with_background_placeholder, - 'num_proposals': num_proposals_placeholder, - 'proposal_boxes': proposal_boxes_placeholder, - }, true_image_shapes) - with self.test_session(graph=tf_graph) as sess: - detections_out = sess.run( - detections, - feed_dict={ - refined_box_encodings_placeholder: refined_box_encodings, - class_predictions_with_background_placeholder: - class_predictions_with_background, - num_proposals_placeholder: num_proposals, - proposal_boxes_placeholder: proposal_boxes, - image_shape_placeholder: image_shape - }) - self.assertAllEqual(detections_out['detection_boxes'].shape, [2, 5, 4]) - self.assertAllClose(detections_out['detection_scores'], - [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]) - self.assertAllClose(detections_out['detection_classes'], - [[0, 0, 0, 1, 1], [0, 0, 1, 1, 0]]) - self.assertAllClose(detections_out['num_detections'], [5, 4]) - - def test_postprocess_second_stage_only_inference_mode(self): - self._test_postprocess_second_stage_only_inference_mode() - - def test_postprocess_second_stage_only_inference_mode_padded_image(self): - self._test_postprocess_second_stage_only_inference_mode( - pad_to_max_dimension=56) + first_stage_max_proposals = 8 + total_num_padded_proposals = batch_size * first_stage_max_proposals + + def graph_fn(images, + refined_box_encodings, + class_predictions_with_background, + num_proposals, + proposal_boxes): + """Function to construct tf graph for the test.""" + model = self._build_model( + 
is_training=False, number_of_stages=2, + second_stage_batch_size=6, + use_matmul_crop_and_resize=use_static_shapes, + clip_anchors_to_image=use_static_shapes, + use_static_shapes=use_static_shapes, + use_matmul_gather_in_matcher=use_static_shapes, + pad_to_max_dimension=pad_to_max_dimension) + _, true_image_shapes = model.preprocess(images) + detections = model.postprocess({ + 'refined_box_encodings': refined_box_encodings, + 'class_predictions_with_background': + class_predictions_with_background, + 'num_proposals': num_proposals, + 'proposal_boxes': proposal_boxes, + }, true_image_shapes) + return (detections['num_detections'], + detections['detection_boxes'], + detections['detection_scores'], + detections['detection_classes']) + + proposal_boxes = np.array( + [[[1, 1, 2, 3], + [0, 0, 1, 1], + [.5, .5, .6, .6], + 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]], + [[2, 3, 6, 8], + [1, 2, 5, 3], + 4*[0], 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]]], dtype=np.float32) + num_proposals = np.array([3, 2], dtype=np.int32) + refined_box_encodings = np.zeros( + [total_num_padded_proposals, num_classes, 4], dtype=np.float32) + class_predictions_with_background = np.ones( + [total_num_padded_proposals, num_classes+1], dtype=np.float32) + images = np.zeros(image_shape, dtype=np.float32) + + if use_static_shapes: + results = self.execute(graph_fn, + [images, refined_box_encodings, + class_predictions_with_background, + num_proposals, proposal_boxes]) + else: + results = self.execute_cpu(graph_fn, + [images, refined_box_encodings, + class_predictions_with_background, + num_proposals, proposal_boxes]) + expected_num_detections = [5, 4] + expected_detection_classes = [[0, 0, 0, 1, 1], [0, 0, 1, 1, 0]] + expected_detection_scores = [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]] + + self.assertAllClose(results[0], expected_num_detections) + + for indx, num_proposals in enumerate(expected_num_detections): + self.assertAllClose(results[2][indx][0:num_proposals], + expected_detection_scores[indx][0:num_proposals]) + self.assertAllClose(results[3][indx][0:num_proposals], + expected_detection_classes[indx][0:num_proposals]) + + if not use_static_shapes: + self.assertAllEqual(results[1].shape, [2, 5, 4]) def test_preprocess_preserves_input_shapes(self): image_shapes = [(3, None, None, 3), diff --git a/research/object_detection/meta_architectures/ssd_meta_arch.py b/research/object_detection/meta_architectures/ssd_meta_arch.py index a8e24b0e8f0..ab7f3aa75f7 100644 --- a/research/object_detection/meta_architectures/ssd_meta_arch.py +++ b/research/object_detection/meta_architectures/ssd_meta_arch.py @@ -19,7 +19,6 @@ """ from abc import abstractmethod -import re import tensorflow as tf from object_detection.core import box_list @@ -116,6 +115,25 @@ def extract_features(self, preprocessed_inputs): """ raise NotImplementedError + def restore_from_classification_checkpoint_fn(self, feature_extractor_scope): + """Returns a map of variables to load from a foreign checkpoint. + + Args: + feature_extractor_scope: A scope name for the feature extractor. + + Returns: + A dict mapping variable names (to load from a checkpoint) to variables in + the model graph. 
+ """ + variables_to_restore = {} + for variable in tf.global_variables(): + var_name = variable.op.name + if var_name.startswith(feature_extractor_scope + '/'): + var_name = var_name.replace(feature_extractor_scope + '/', '') + variables_to_restore[var_name] = variable + + return variables_to_restore + class SSDKerasFeatureExtractor(tf.keras.Model): """SSD Feature Extractor definition.""" @@ -218,6 +236,25 @@ def _extract_features(self, preprocessed_inputs): def call(self, inputs, **kwargs): return self._extract_features(inputs) + def restore_from_classification_checkpoint_fn(self, feature_extractor_scope): + """Returns a map of variables to load from a foreign checkpoint. + + Args: + feature_extractor_scope: A scope name for the feature extractor. + + Returns: + A dict mapping variable names (to load from a checkpoint) to variables in + the model graph. + """ + variables_to_restore = {} + for variable in tf.global_variables(): + var_name = variable.op.name + if var_name.startswith(feature_extractor_scope + '/'): + var_name = var_name.replace(feature_extractor_scope + '/', '') + variables_to_restore[var_name] = variable + + return variables_to_restore + class SSDMetaArch(model.DetectionModel): """SSD Meta-architecture definition.""" @@ -333,13 +370,15 @@ def __init__(self, # Slim feature extractors get an explicit naming scope self._extract_features_scope = 'FeatureExtractor' - # TODO(jonathanhuang): handle agnostic mode - # weights - self._unmatched_class_label = tf.constant([1] + self.num_classes * [0], - tf.float32) - if encode_background_as_zeros: + if self._add_background_class and encode_background_as_zeros: self._unmatched_class_label = tf.constant((self.num_classes + 1) * [0], tf.float32) + elif self._add_background_class: + self._unmatched_class_label = tf.constant([1] + self.num_classes * [0], + tf.float32) + else: + self._unmatched_class_label = tf.constant(self.num_classes * [0], + tf.float32) self._target_assigner = target_assigner_instance @@ -606,14 +645,22 @@ def postprocess(self, prediction_dict, true_image_shapes): detection_boxes = tf.identity(detection_boxes, 'raw_box_locations') detection_boxes = tf.expand_dims(detection_boxes, axis=2) - detection_scores_with_background = self._score_conversion_fn( - class_predictions) - detection_scores_with_background = tf.identity( - detection_scores_with_background, 'raw_box_scores') - detection_scores = tf.slice(detection_scores_with_background, [0, 0, 1], - [-1, -1, -1]) + detection_scores = self._score_conversion_fn(class_predictions) + detection_scores = tf.identity(detection_scores, 'raw_box_scores') + if self._add_background_class: + detection_scores = tf.slice(detection_scores, [0, 0, 1], [-1, -1, -1]) additional_fields = None + batch_size = ( + shape_utils.combined_static_and_dynamic_shape(preprocessed_images)[0]) + + if 'feature_maps' in prediction_dict: + feature_map_list = [] + for feature_map in prediction_dict['feature_maps']: + feature_map_list.append(tf.reshape(feature_map, [batch_size, -1])) + box_features = tf.concat(feature_map_list, 1) + box_features = tf.identity(box_features, 'raw_box_features') + if detection_keypoints is not None: additional_fields = { fields.BoxListFields.keypoints: detection_keypoints} @@ -683,17 +730,20 @@ def loss(self, prediction_dict, true_image_shapes, scope=None): self.groundtruth_lists(fields.BoxListFields.boxes), match_list) if self._random_example_sampler: + batch_cls_per_anchor_weights = tf.reduce_mean( + batch_cls_weights, axis=-1) batch_sampled_indicator = tf.to_float( 
shape_utils.static_or_dynamic_map_fn( self._minibatch_subsample_fn, - [batch_cls_targets, batch_cls_weights], + [batch_cls_targets, batch_cls_per_anchor_weights], dtype=tf.bool, parallel_iterations=self._parallel_iterations, back_prop=True)) batch_reg_weights = tf.multiply(batch_sampled_indicator, batch_reg_weights) - batch_cls_weights = tf.multiply(batch_sampled_indicator, - batch_cls_weights) + batch_cls_weights = tf.multiply( + tf.expand_dims(batch_sampled_indicator, -1), + batch_cls_weights) losses_mask = None if self.groundtruth_has_field(fields.InputDataFields.is_annotated): @@ -713,16 +763,32 @@ def loss(self, prediction_dict, true_image_shapes, scope=None): losses_mask=losses_mask) if self._expected_classification_loss_under_sampling: + # Need to compute losses for assigned targets against the + # unmatched_class_label as well as their assigned targets. + # simplest thing (but wasteful) is just to calculate all losses + # twice + batch_size, num_anchors, num_classes = batch_cls_targets.get_shape() + unmatched_targets = tf.ones([batch_size, num_anchors, 1 + ]) * self._unmatched_class_label + + unmatched_cls_losses = self._classification_loss( + prediction_dict['class_predictions_with_background'], + unmatched_targets, + weights=batch_cls_weights, + losses_mask=losses_mask) + if cls_losses.get_shape().ndims == 3: batch_size, num_anchors, num_classes = cls_losses.get_shape() cls_losses = tf.reshape(cls_losses, [batch_size, -1]) + unmatched_cls_losses = tf.reshape(unmatched_cls_losses, + [batch_size, -1]) batch_cls_targets = tf.reshape( batch_cls_targets, [batch_size, num_anchors * num_classes, -1]) batch_cls_targets = tf.concat( [1 - batch_cls_targets, batch_cls_targets], axis=-1) cls_losses = self._expected_classification_loss_under_sampling( - batch_cls_targets, cls_losses) + batch_cls_targets, cls_losses, unmatched_cls_losses) classification_loss = tf.reduce_sum(cls_losses) localization_loss = tf.reduce_sum(location_losses) @@ -971,6 +1037,26 @@ def _batch_decode(self, box_encodings): [combined_shape[0], combined_shape[1], 4])) return decoded_boxes, decoded_keypoints + def regularization_losses(self): + """Returns a list of regularization losses for this model. + + Returns a list of regularization losses for this model that the estimator + needs to use during training/optimization. + + Returns: + A list of regularization loss tensors. 
+ """ + losses = [] + slim_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + # Copy the slim losses to avoid modifying the collection + if slim_losses: + losses.extend(slim_losses) + if self._box_predictor.is_keras_model: + losses.extend(self._box_predictor.losses) + if self._feature_extractor.is_keras_model: + losses.extend(self._feature_extractor.losses) + return losses + def restore_map(self, fine_tune_checkpoint_type='detection', load_all_detection_checkpoint_vars=False): @@ -997,18 +1083,44 @@ def restore_map(self, if fine_tune_checkpoint_type not in ['detection', 'classification']: raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format( fine_tune_checkpoint_type)) - variables_to_restore = {} - for variable in tf.global_variables(): - var_name = variable.op.name - if (fine_tune_checkpoint_type == 'detection' and - load_all_detection_checkpoint_vars): - variables_to_restore[var_name] = variable - else: - if var_name.startswith(self._extract_features_scope): - if fine_tune_checkpoint_type == 'classification': - var_name = ( - re.split('^' + self._extract_features_scope + '/', - var_name)[-1]) + + if fine_tune_checkpoint_type == 'classification': + return self._feature_extractor.restore_from_classification_checkpoint_fn( + self._extract_features_scope) + + if fine_tune_checkpoint_type == 'detection': + variables_to_restore = {} + for variable in tf.global_variables(): + var_name = variable.op.name + if load_all_detection_checkpoint_vars: variables_to_restore[var_name] = variable + else: + if var_name.startswith(self._extract_features_scope): + variables_to_restore[var_name] = variable return variables_to_restore + + def updates(self): + """Returns a list of update operators for this model. + + Returns a list of update operators for this model that must be executed at + each training step. The estimator's train op needs to have a control + dependency on these updates. + + Returns: + A list of update operators. 
+ """ + update_ops = [] + slim_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + # Copy the slim ops to avoid modifying the collection + if slim_update_ops: + update_ops.extend(slim_update_ops) + if self._box_predictor.is_keras_model: + update_ops.extend(self._box_predictor.get_updates_for(None)) + update_ops.extend(self._box_predictor.get_updates_for( + self._box_predictor.inputs)) + if self._feature_extractor.is_keras_model: + update_ops.extend(self._feature_extractor.get_updates_for(None)) + update_ops.extend(self._feature_extractor.get_updates_for( + self._feature_extractor.inputs)) + return update_ops diff --git a/research/object_detection/meta_architectures/ssd_meta_arch_test.py b/research/object_detection/meta_architectures/ssd_meta_arch_test.py index da1607a910f..6f8fa757f9f 100644 --- a/research/object_detection/meta_architectures/ssd_meta_arch_test.py +++ b/research/object_detection/meta_architectures/ssd_meta_arch_test.py @@ -42,7 +42,7 @@ def _create_model(self, random_example_sampling=False, weight_regression_loss_by_score=False, use_expected_classification_loss_under_sampling=False, - minimum_negative_sampling=1, + min_num_negative_samples=1, desired_negative_sampling_ratio=3, use_keras=False, predict_mask=False, @@ -57,7 +57,7 @@ def _create_model(self, weight_regression_loss_by_score=weight_regression_loss_by_score, use_expected_classification_loss_under_sampling= use_expected_classification_loss_under_sampling, - minimum_negative_sampling=minimum_negative_sampling, + min_num_negative_samples=min_num_negative_samples, desired_negative_sampling_ratio=desired_negative_sampling_ratio, use_keras=use_keras, predict_mask=predict_mask, @@ -344,11 +344,11 @@ def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2, preprocessed_input = np.random.rand(batch_size, 2, 2, 3).astype(np.float32) groundtruth_boxes1 = np.array([[0, 0, .5, .5]], dtype=np.float32) groundtruth_boxes2 = np.array([[0, 0, .5, .5]], dtype=np.float32) - groundtruth_classes1 = np.array([[0, 1]], dtype=np.float32) - groundtruth_classes2 = np.array([[0, 1]], dtype=np.float32) + groundtruth_classes1 = np.array([[1]], dtype=np.float32) + groundtruth_classes2 = np.array([[1]], dtype=np.float32) expected_localization_loss = 0.0 expected_classification_loss = ( - batch_size * num_anchors * (num_classes + 1) * np.log(2.0)) + batch_size * num_anchors * num_classes * np.log(2.0)) (localization_loss, classification_loss) = self.execute( graph_fn, [ preprocessed_input, groundtruth_boxes1, groundtruth_boxes2, @@ -371,7 +371,7 @@ def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2, apply_hard_mining=False, add_background_class=True, use_expected_classification_loss_under_sampling=True, - minimum_negative_sampling=1, + min_num_negative_samples=1, desired_negative_sampling_ratio=desired_negative_sampling_ratio) model.provide_groundtruth(groundtruth_boxes_list, groundtruth_classes_list) @@ -391,8 +391,7 @@ def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2, expected_localization_loss = 0.0 expected_classification_loss = ( - batch_size * (desired_negative_sampling_ratio * num_anchors + - num_classes * num_anchors) * np.log(2.0)) + batch_size * (num_anchors + num_classes * num_anchors) * np.log(2.0)) (localization_loss, classification_loss) = self.execute( graph_fn, [ preprocessed_input, groundtruth_boxes1, groundtruth_boxes2, @@ -432,11 +431,11 @@ def graph_fn(preprocessed_tensor, groundtruth_boxes1, groundtruth_boxes2, preprocessed_input = np.random.rand(batch_size, 
2, 2, 3).astype(np.float32) groundtruth_boxes1 = np.array([[0, 0, 1, 1]], dtype=np.float32) groundtruth_boxes2 = np.array([[0, 0, 1, 1]], dtype=np.float32) - groundtruth_classes1 = np.array([[0, 1]], dtype=np.float32) - groundtruth_classes2 = np.array([[1, 0]], dtype=np.float32) + groundtruth_classes1 = np.array([[1]], dtype=np.float32) + groundtruth_classes2 = np.array([[0]], dtype=np.float32) expected_localization_loss = 0.25 expected_classification_loss = ( - batch_size * num_anchors * (num_classes + 1) * np.log(2.0)) + batch_size * num_anchors * num_classes * np.log(2.0)) (localization_loss, classification_loss) = self.execute( graph_fn, [ preprocessed_input, groundtruth_boxes1, groundtruth_boxes2, diff --git a/research/object_detection/meta_architectures/ssd_meta_arch_test_lib.py b/research/object_detection/meta_architectures/ssd_meta_arch_test_lib.py index 5d9114aedec..c068850df8b 100644 --- a/research/object_detection/meta_architectures/ssd_meta_arch_test_lib.py +++ b/research/object_detection/meta_architectures/ssd_meta_arch_test_lib.py @@ -119,7 +119,7 @@ def _create_model(self, random_example_sampling=False, weight_regression_loss_by_score=False, use_expected_classification_loss_under_sampling=False, - minimum_negative_sampling=1, + min_num_negative_samples=1, desired_negative_sampling_ratio=3, use_keras=False, predict_mask=False, @@ -130,10 +130,12 @@ def _create_model(self, mock_anchor_generator = MockAnchorGenerator2x2() if use_keras: mock_box_predictor = test_utils.MockKerasBoxPredictor( - is_training, num_classes, predict_mask=predict_mask) + is_training, num_classes, add_background_class=add_background_class, + predict_mask=predict_mask) else: mock_box_predictor = test_utils.MockBoxPredictor( - is_training, num_classes, predict_mask=predict_mask) + is_training, num_classes, add_background_class=add_background_class, + predict_mask=predict_mask) mock_box_coder = test_utils.MockBoxCoder() if use_keras: fake_feature_extractor = FakeSSDKerasFeatureExtractor() @@ -182,7 +184,7 @@ def image_resizer_fn(image): if use_expected_classification_loss_under_sampling: expected_classification_loss_under_sampling = functools.partial( ops.expected_classification_loss_under_sampling, - minimum_negative_sampling=minimum_negative_sampling, + min_num_negative_samples=min_num_negative_samples, desired_negative_sampling_ratio=desired_negative_sampling_ratio) code_size = 4 diff --git a/research/object_detection/metrics/coco_evaluation.py b/research/object_detection/metrics/coco_evaluation.py index 9d36a7c2eb0..cefbd6abf74 100644 --- a/research/object_detection/metrics/coco_evaluation.py +++ b/research/object_detection/metrics/coco_evaluation.py @@ -248,27 +248,30 @@ def update_op( detection_boxes_batched, detection_scores_batched, detection_classes_batched, - num_det_boxes_per_image): + num_det_boxes_per_image, + is_annotated_batched): """Update operation for adding batch of images to Coco evaluator.""" for (image_id, gt_box, gt_class, gt_is_crowd, num_gt_box, det_box, - det_score, det_class, num_det_box) in zip( + det_score, det_class, num_det_box, is_annotated) in zip( image_id_batched, groundtruth_boxes_batched, groundtruth_classes_batched, groundtruth_is_crowd_batched, num_gt_boxes_per_image, detection_boxes_batched, detection_scores_batched, - detection_classes_batched, num_det_boxes_per_image): - self.add_single_ground_truth_image_info( - image_id, { - 'groundtruth_boxes': gt_box[:num_gt_box], - 'groundtruth_classes': gt_class[:num_gt_box], - 'groundtruth_is_crowd': gt_is_crowd[:num_gt_box] 
- }) - self.add_single_detected_image_info( - image_id, - {'detection_boxes': det_box[:num_det_box], - 'detection_scores': det_score[:num_det_box], - 'detection_classes': det_class[:num_det_box]}) + detection_classes_batched, num_det_boxes_per_image, + is_annotated_batched): + if is_annotated: + self.add_single_ground_truth_image_info( + image_id, { + 'groundtruth_boxes': gt_box[:num_gt_box], + 'groundtruth_classes': gt_class[:num_gt_box], + 'groundtruth_is_crowd': gt_is_crowd[:num_gt_box] + }) + self.add_single_detected_image_info( + image_id, + {'detection_boxes': det_box[:num_det_box], + 'detection_scores': det_score[:num_det_box], + 'detection_classes': det_class[:num_det_box]}) # Unpack items from the evaluation dictionary. input_data_fields = standard_fields.InputDataFields @@ -284,6 +287,7 @@ def update_op( num_gt_boxes_per_image = eval_dict.get( 'num_groundtruth_boxes_per_image', None) num_det_boxes_per_image = eval_dict.get('num_det_boxes_per_image', None) + is_annotated = eval_dict.get('is_annotated', None) if groundtruth_is_crowd is None: groundtruth_is_crowd = tf.zeros_like(groundtruth_classes, dtype=tf.bool) @@ -306,6 +310,11 @@ def update_op( num_det_boxes_per_image = tf.shape(detection_boxes)[1:2] else: num_det_boxes_per_image = tf.expand_dims(num_det_boxes_per_image, 0) + + if is_annotated is None: + is_annotated = tf.constant([True]) + else: + is_annotated = tf.expand_dims(is_annotated, 0) else: if num_gt_boxes_per_image is None: num_gt_boxes_per_image = tf.tile( @@ -315,6 +324,8 @@ def update_op( num_det_boxes_per_image = tf.tile( tf.shape(detection_boxes)[1:2], multiples=tf.shape(detection_boxes)[0:1]) + if is_annotated is None: + is_annotated = tf.ones_like(image_id, dtype=tf.bool) update_op = tf.py_func(update_op, [image_id, groundtruth_boxes, @@ -324,7 +335,8 @@ def update_op( detection_boxes, detection_scores, detection_classes, - num_det_boxes_per_image], []) + num_det_boxes_per_image, + is_annotated], []) metric_names = ['DetectionBoxes_Precision/mAP', 'DetectionBoxes_Precision/mAP@.50IOU', 'DetectionBoxes_Precision/mAP@.75IOU', @@ -581,8 +593,11 @@ def get_estimator_eval_metric_ops(self, eval_dict): Args: eval_dict: A dictionary that holds tensors for evaluating object detection - performance. This dictionary may be produced from - eval_util.result_dict_for_single_example(). + performance. For single-image evaluation, this dictionary may be + produced from eval_util.result_dict_for_single_example(). If multi-image + evaluation, `eval_dict` should contain the fields + 'num_groundtruth_boxes_per_image' and 'num_det_boxes_per_image' to + properly unpad the tensors from the batch. Returns: a dictionary of metric names to tuple of value_op and update_op that can @@ -590,27 +605,41 @@ def get_estimator_eval_metric_ops(self, eval_dict): update ops must be run together and similarly all value ops must be run together to guarantee correct behaviour. 
""" - def update_op( - image_id, - groundtruth_boxes, - groundtruth_classes, - groundtruth_instance_masks, - groundtruth_is_crowd, - detection_scores, - detection_classes, - detection_masks): + + def update_op(image_id_batched, groundtruth_boxes_batched, + groundtruth_classes_batched, + groundtruth_instance_masks_batched, + groundtruth_is_crowd_batched, num_gt_boxes_per_image, + detection_scores_batched, detection_classes_batched, + detection_masks_batched, num_det_boxes_per_image): """Update op for metrics.""" - self.add_single_ground_truth_image_info( - image_id, - {'groundtruth_boxes': groundtruth_boxes, - 'groundtruth_classes': groundtruth_classes, - 'groundtruth_instance_masks': groundtruth_instance_masks, - 'groundtruth_is_crowd': groundtruth_is_crowd}) - self.add_single_detected_image_info( - image_id, - {'detection_scores': detection_scores, - 'detection_classes': detection_classes, - 'detection_masks': detection_masks}) + + for (image_id, groundtruth_boxes, groundtruth_classes, + groundtruth_instance_masks, groundtruth_is_crowd, num_gt_box, + detection_scores, detection_classes, + detection_masks, num_det_box) in zip( + image_id_batched, groundtruth_boxes_batched, + groundtruth_classes_batched, groundtruth_instance_masks_batched, + groundtruth_is_crowd_batched, num_gt_boxes_per_image, + detection_scores_batched, detection_classes_batched, + detection_masks_batched, num_det_boxes_per_image): + self.add_single_ground_truth_image_info( + image_id, { + 'groundtruth_boxes': + groundtruth_boxes[:num_gt_box], + 'groundtruth_classes': + groundtruth_classes[:num_gt_box], + 'groundtruth_instance_masks': + groundtruth_instance_masks[:num_gt_box], + 'groundtruth_is_crowd': + groundtruth_is_crowd[:num_gt_box] + }) + self.add_single_detected_image_info( + image_id, { + 'detection_scores': detection_scores[:num_det_box], + 'detection_classes': detection_classes[:num_det_box], + 'detection_masks': detection_masks[:num_det_box] + }) # Unpack items from the evaluation dictionary. input_data_fields = standard_fields.InputDataFields @@ -622,20 +651,54 @@ def update_op( input_data_fields.groundtruth_instance_masks] groundtruth_is_crowd = eval_dict.get( input_data_fields.groundtruth_is_crowd, None) + num_gt_boxes_per_image = eval_dict.get( + input_data_fields.num_groundtruth_boxes, None) detection_scores = eval_dict[detection_fields.detection_scores] detection_classes = eval_dict[detection_fields.detection_classes] detection_masks = eval_dict[detection_fields.detection_masks] + num_det_boxes_per_image = eval_dict.get(detection_fields.num_detections, + None) if groundtruth_is_crowd is None: groundtruth_is_crowd = tf.zeros_like(groundtruth_classes, dtype=tf.bool) - update_op = tf.py_func(update_op, [image_id, - groundtruth_boxes, - groundtruth_classes, - groundtruth_instance_masks, - groundtruth_is_crowd, - detection_scores, - detection_classes, - detection_masks], []) + + if not image_id.shape.as_list(): + # Apply a batch dimension to all tensors. 
+ image_id = tf.expand_dims(image_id, 0) + groundtruth_boxes = tf.expand_dims(groundtruth_boxes, 0) + groundtruth_classes = tf.expand_dims(groundtruth_classes, 0) + groundtruth_instance_masks = tf.expand_dims(groundtruth_instance_masks, 0) + groundtruth_is_crowd = tf.expand_dims(groundtruth_is_crowd, 0) + detection_scores = tf.expand_dims(detection_scores, 0) + detection_classes = tf.expand_dims(detection_classes, 0) + detection_masks = tf.expand_dims(detection_masks, 0) + + if num_gt_boxes_per_image is None: + num_gt_boxes_per_image = tf.shape(groundtruth_boxes)[1:2] + else: + num_gt_boxes_per_image = tf.expand_dims(num_gt_boxes_per_image, 0) + + if num_det_boxes_per_image is None: + num_det_boxes_per_image = tf.shape(detection_scores)[1:2] + else: + num_det_boxes_per_image = tf.expand_dims(num_det_boxes_per_image, 0) + else: + if num_gt_boxes_per_image is None: + num_gt_boxes_per_image = tf.tile( + tf.shape(groundtruth_boxes)[1:2], + multiples=tf.shape(groundtruth_boxes)[0:1]) + if num_det_boxes_per_image is None: + num_det_boxes_per_image = tf.tile( + tf.shape(detection_scores)[1:2], + multiples=tf.shape(detection_scores)[0:1]) + + update_op = tf.py_func(update_op, [ + image_id, groundtruth_boxes, groundtruth_classes, + groundtruth_instance_masks, groundtruth_is_crowd, + num_gt_boxes_per_image, detection_scores, detection_classes, + detection_masks, num_det_boxes_per_image + ], []) + metric_names = ['DetectionMasks_Precision/mAP', 'DetectionMasks_Precision/mAP@.50IOU', 'DetectionMasks_Precision/mAP@.75IOU', diff --git a/research/object_detection/metrics/coco_evaluation_test.py b/research/object_detection/metrics/coco_evaluation_test.py index 3aebeb58c3d..0a567c51557 100644 --- a/research/object_detection/metrics/coco_evaluation_test.py +++ b/research/object_detection/metrics/coco_evaluation_test.py @@ -308,6 +308,99 @@ def testGetOneMAPWithMatchingGroundtruthAndDetections(self): self.assertFalse(coco_evaluator._detection_boxes_list) self.assertFalse(coco_evaluator._image_ids) + def testGetOneMAPWithMatchingGroundtruthAndDetectionsIsAnnotated(self): + coco_evaluator = coco_evaluation.CocoDetectionEvaluator( + _get_categories_list()) + image_id = tf.placeholder(tf.string, shape=()) + groundtruth_boxes = tf.placeholder(tf.float32, shape=(None, 4)) + groundtruth_classes = tf.placeholder(tf.float32, shape=(None)) + is_annotated = tf.placeholder(tf.bool, shape=()) + detection_boxes = tf.placeholder(tf.float32, shape=(None, 4)) + detection_scores = tf.placeholder(tf.float32, shape=(None)) + detection_classes = tf.placeholder(tf.float32, shape=(None)) + + input_data_fields = standard_fields.InputDataFields + detection_fields = standard_fields.DetectionResultFields + eval_dict = { + input_data_fields.key: image_id, + input_data_fields.groundtruth_boxes: groundtruth_boxes, + input_data_fields.groundtruth_classes: groundtruth_classes, + 'is_annotated': is_annotated, + detection_fields.detection_boxes: detection_boxes, + detection_fields.detection_scores: detection_scores, + detection_fields.detection_classes: detection_classes + } + + eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(eval_dict) + + _, update_op = eval_metric_ops['DetectionBoxes_Precision/mAP'] + + with self.test_session() as sess: + sess.run(update_op, + feed_dict={ + image_id: 'image1', + groundtruth_boxes: np.array([[100., 100., 200., 200.]]), + groundtruth_classes: np.array([1]), + is_annotated: True, + detection_boxes: np.array([[100., 100., 200., 200.]]), + detection_scores: np.array([.8]), + detection_classes: 
np.array([1]) + }) + sess.run(update_op, + feed_dict={ + image_id: 'image2', + groundtruth_boxes: np.array([[50., 50., 100., 100.]]), + groundtruth_classes: np.array([3]), + is_annotated: True, + detection_boxes: np.array([[50., 50., 100., 100.]]), + detection_scores: np.array([.7]), + detection_classes: np.array([3]) + }) + sess.run(update_op, + feed_dict={ + image_id: 'image3', + groundtruth_boxes: np.array([[25., 25., 50., 50.]]), + groundtruth_classes: np.array([2]), + is_annotated: True, + detection_boxes: np.array([[25., 25., 50., 50.]]), + detection_scores: np.array([.9]), + detection_classes: np.array([2]) + }) + sess.run(update_op, + feed_dict={ + image_id: 'image4', + groundtruth_boxes: np.zeros((0, 4)), + groundtruth_classes: np.zeros((0)), + is_annotated: False, # Note that this image isn't annotated. + detection_boxes: np.array([[25., 25., 50., 50.], + [25., 25., 70., 50.], + [25., 25., 80., 50.], + [25., 25., 90., 50.]]), + detection_scores: np.array([0.6, 0.7, 0.8, 0.9]), + detection_classes: np.array([1, 2, 2, 3]) + }) + metrics = {} + for key, (value_op, _) in eval_metric_ops.iteritems(): + metrics[key] = value_op + metrics = sess.run(metrics) + self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP@.50IOU'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP@.75IOU'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP (large)'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP (medium)'], + 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Precision/mAP (small)'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Recall/AR@1'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Recall/AR@10'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Recall/AR@100'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Recall/AR@100 (large)'], 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Recall/AR@100 (medium)'], + 1.0) + self.assertAlmostEqual(metrics['DetectionBoxes_Recall/AR@100 (small)'], 1.0) + self.assertFalse(coco_evaluator._groundtruth_list) + self.assertFalse(coco_evaluator._detection_boxes_list) + self.assertFalse(coco_evaluator._image_ids) + def testGetOneMAPWithMatchingGroundtruthAndDetectionsPadded(self): coco_evaluator = coco_evaluation.CocoDetectionEvaluator( _get_categories_list()) @@ -665,22 +758,40 @@ def testGetOneMAPWithMatchingGroundtruthAndDetections(self): _, update_op = eval_metric_ops['DetectionMasks_Precision/mAP'] with self.test_session() as sess: - sess.run(update_op, - feed_dict={ - image_id: 'image1', - groundtruth_boxes: np.array([[100., 100., 200., 200.]]), - groundtruth_classes: np.array([1]), - groundtruth_masks: np.pad(np.ones([1, 100, 100], - dtype=np.uint8), - ((0, 0), (10, 10), (10, 10)), - mode='constant'), - detection_scores: np.array([.8]), - detection_classes: np.array([1]), - detection_masks: np.pad(np.ones([1, 100, 100], - dtype=np.uint8), - ((0, 0), (10, 10), (10, 10)), - mode='constant') - }) + sess.run( + update_op, + feed_dict={ + image_id: + 'image1', + groundtruth_boxes: + np.array([[100., 100., 200., 200.], [50., 50., 100., 100.]]), + groundtruth_classes: + np.array([1, 2]), + groundtruth_masks: + np.stack([ + np.pad( + np.ones([100, 100], dtype=np.uint8), ((10, 10), + (10, 10)), + mode='constant'), + np.pad( + np.ones([50, 50], dtype=np.uint8), ((0, 70), (0, 70)), + mode='constant') + ]), + detection_scores: + np.array([.9, .8]), + detection_classes: + np.array([2, 1]), + 
detection_masks: + np.stack([ + np.pad( + np.ones([50, 50], dtype=np.uint8), ((0, 70), (0, 70)), + mode='constant'), + np.pad( + np.ones([100, 100], dtype=np.uint8), ((10, 10), + (10, 10)), + mode='constant'), + ]) + }) sess.run(update_op, feed_dict={ image_id: 'image2', @@ -735,6 +846,106 @@ def testGetOneMAPWithMatchingGroundtruthAndDetections(self): self.assertFalse(coco_evaluator._image_id_to_mask_shape_map) self.assertFalse(coco_evaluator._detection_masks_list) + def testGetOneMAPWithMatchingGroundtruthAndDetectionsBatched(self): + coco_evaluator = coco_evaluation.CocoMaskEvaluator(_get_categories_list()) + batch_size = 3 + image_id = tf.placeholder(tf.string, shape=(batch_size)) + groundtruth_boxes = tf.placeholder(tf.float32, shape=(batch_size, None, 4)) + groundtruth_classes = tf.placeholder(tf.float32, shape=(batch_size, None)) + groundtruth_masks = tf.placeholder( + tf.uint8, shape=(batch_size, None, None, None)) + detection_scores = tf.placeholder(tf.float32, shape=(batch_size, None)) + detection_classes = tf.placeholder(tf.float32, shape=(batch_size, None)) + detection_masks = tf.placeholder( + tf.uint8, shape=(batch_size, None, None, None)) + + input_data_fields = standard_fields.InputDataFields + detection_fields = standard_fields.DetectionResultFields + eval_dict = { + input_data_fields.key: image_id, + input_data_fields.groundtruth_boxes: groundtruth_boxes, + input_data_fields.groundtruth_classes: groundtruth_classes, + input_data_fields.groundtruth_instance_masks: groundtruth_masks, + detection_fields.detection_scores: detection_scores, + detection_fields.detection_classes: detection_classes, + detection_fields.detection_masks: detection_masks, + } + + eval_metric_ops = coco_evaluator.get_estimator_eval_metric_ops(eval_dict) + + _, update_op = eval_metric_ops['DetectionMasks_Precision/mAP'] + + with self.test_session() as sess: + sess.run( + update_op, + feed_dict={ + image_id: ['image1', 'image2', 'image3'], + groundtruth_boxes: + np.array([[[100., 100., 200., 200.]], + [[50., 50., 100., 100.]], + [[25., 25., 50., 50.]]]), + groundtruth_classes: + np.array([[1], [1], [1]]), + groundtruth_masks: + np.stack([ + np.pad( + np.ones([1, 100, 100], dtype=np.uint8), + ((0, 0), (0, 0), (0, 0)), + mode='constant'), + np.pad( + np.ones([1, 50, 50], dtype=np.uint8), + ((0, 0), (25, 25), (25, 25)), + mode='constant'), + np.pad( + np.ones([1, 25, 25], dtype=np.uint8), + ((0, 0), (37, 38), (37, 38)), + mode='constant') + ], + axis=0), + detection_scores: + np.array([[.8], [.8], [.8]]), + detection_classes: + np.array([[1], [1], [1]]), + detection_masks: + np.stack([ + np.pad( + np.ones([1, 100, 100], dtype=np.uint8), + ((0, 0), (0, 0), (0, 0)), + mode='constant'), + np.pad( + np.ones([1, 50, 50], dtype=np.uint8), + ((0, 0), (25, 25), (25, 25)), + mode='constant'), + np.pad( + np.ones([1, 25, 25], dtype=np.uint8), + ((0, 0), (37, 38), (37, 38)), + mode='constant') + ], + axis=0) + }) + metrics = {} + for key, (value_op, _) in eval_metric_ops.iteritems(): + metrics[key] = value_op + metrics = sess.run(metrics) + self.assertAlmostEqual(metrics['DetectionMasks_Precision/mAP'], 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Precision/mAP@.50IOU'], 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Precision/mAP@.75IOU'], 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Precision/mAP (large)'], 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Precision/mAP (medium)'], + 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Precision/mAP (small)'], 1.0) + 
self.assertAlmostEqual(metrics['DetectionMasks_Recall/AR@1'], 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Recall/AR@10'], 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Recall/AR@100'], 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Recall/AR@100 (large)'], 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Recall/AR@100 (medium)'], + 1.0) + self.assertAlmostEqual(metrics['DetectionMasks_Recall/AR@100 (small)'], 1.0) + self.assertFalse(coco_evaluator._groundtruth_list) + self.assertFalse(coco_evaluator._image_ids_with_detections) + self.assertFalse(coco_evaluator._image_id_to_mask_shape_map) + self.assertFalse(coco_evaluator._detection_masks_list) + if __name__ == '__main__': tf.test.main() diff --git a/research/object_detection/model_lib.py b/research/object_detection/model_lib.py index 5ba6f9b6538..b54c6ef2b3d 100644 --- a/research/object_detection/model_lib.py +++ b/research/object_detection/model_lib.py @@ -25,6 +25,7 @@ import tensorflow as tf from object_detection import eval_util +from object_detection import exporter as exporter_lib from object_detection import inputs from object_detection.builders import graph_rewriter_builder from object_detection.builders import model_builder @@ -306,8 +307,7 @@ def tpu_scaffold(): prediction_dict, features[fields.InputDataFields.true_image_shape]) losses = [loss_tensor for loss_tensor in losses_dict.values()] if train_config.add_regularization_loss: - regularization_losses = tf.get_collection( - tf.GraphKeys.REGULARIZATION_LOSSES) + regularization_losses = detection_model.regularization_losses() if regularization_losses: regularization_loss = tf.add_n( regularization_losses, name='regularization_loss') @@ -353,20 +353,24 @@ def tpu_scaffold(): for var in optimizer_summary_vars: tf.summary.scalar(var.op.name, var) summaries = [] if use_tpu else None + if train_config.summarize_gradients: + summaries = ['gradients', 'gradient_norm', 'global_gradient_norm'] train_op = tf.contrib.layers.optimize_loss( loss=total_loss, global_step=global_step, learning_rate=None, clip_gradients=clip_gradients_value, optimizer=training_optimizer, + update_ops=detection_model.updates(), variables=trainable_variables, summaries=summaries, name='') # Preventing scope prefix on all variables. if mode == tf.estimator.ModeKeys.PREDICT: + exported_output = exporter_lib.add_output_tensor_nodes(detections) export_outputs = { tf.saved_model.signature_constants.PREDICT_METHOD_NAME: - tf.estimator.export.PredictOutput(detections) + tf.estimator.export.PredictOutput(exported_output) } eval_metric_ops = None @@ -456,6 +460,7 @@ def tpu_scaffold(): def create_estimator_and_inputs(run_config, hparams, pipeline_config_path, + config_override=None, train_steps=None, sample_1_of_n_eval_examples=1, sample_1_of_n_eval_on_train_examples=1, @@ -465,6 +470,7 @@ def create_estimator_and_inputs(run_config, num_shards=1, params=None, override_eval_num_epochs=True, + save_final_config=False, **kwargs): """Creates `Estimator`, input functions, and steps. @@ -472,6 +478,8 @@ def create_estimator_and_inputs(run_config, run_config: A `RunConfig`. hparams: A `HParams`. pipeline_config_path: A path to a pipeline config file. + config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to + override the config from `pipeline_config_path`. train_steps: Number of training steps. If None, the number of training steps is set from the `TrainConfig` proto. 
sample_1_of_n_eval_examples: Integer representing how often an eval example @@ -499,6 +507,8 @@ def create_estimator_and_inputs(run_config, `use_tpu_estimator` is True. override_eval_num_epochs: Whether to overwrite the number of epochs to 1 for eval_input. + save_final_config: Whether to save final config (obtained after applying + overrides) to `estimator.model_dir`. **kwargs: Additional keyword arguments for configuration override. Returns: @@ -522,7 +532,8 @@ def create_estimator_and_inputs(run_config, create_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn'] create_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn'] - configs = get_configs_from_pipeline_file(pipeline_config_path) + configs = get_configs_from_pipeline_file(pipeline_config_path, + config_override=config_override) kwargs.update({ 'train_steps': train_steps, 'sample_1_of_n_eval_examples': sample_1_of_n_eval_examples @@ -595,7 +606,7 @@ def create_estimator_and_inputs(run_config, estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) # Write the as-run pipeline config to disk. - if run_config.is_chief: + if run_config.is_chief and save_final_config: pipeline_config_final = create_pipeline_proto_from_configs(configs) config_util.save_pipeline_config(pipeline_config_final, estimator.model_dir) @@ -641,11 +652,17 @@ def create_train_and_eval_specs(train_input_fn, input_fn=train_input_fn, max_steps=train_steps) if eval_spec_names is None: - eval_spec_names = [ str(i) for i in range(len(eval_input_fns)) ] + eval_spec_names = [str(i) for i in range(len(eval_input_fns))] eval_specs = [] - for eval_spec_name, eval_input_fn in zip(eval_spec_names, eval_input_fns): - exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name) + for index, (eval_spec_name, eval_input_fn) in enumerate( + zip(eval_spec_names, eval_input_fns)): + # Uses final_exporter_name as exporter_name for the first eval spec for + # backward compatibility. 
+ if index == 0: + exporter_name = final_exporter_name + else: + exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name) exporter = tf.estimator.FinalExporter( name=exporter_name, serving_input_receiver_fn=predict_input_fn) eval_specs.append( @@ -747,6 +764,7 @@ def populate_experiment(run_config, train_steps=train_steps, eval_steps=eval_steps, model_fn_creator=model_fn_creator, + save_final_config=True, **kwargs) estimator = train_and_eval_dict['estimator'] train_input_fn = train_and_eval_dict['train_input_fn'] diff --git a/research/object_detection/model_lib_test.py b/research/object_detection/model_lib_test.py index 66e54b02fc0..c61fbb6ea61 100644 --- a/research/object_detection/model_lib_test.py +++ b/research/object_detection/model_lib_test.py @@ -310,7 +310,7 @@ def test_create_train_and_eval_specs(self): self.assertEqual(2, len(eval_specs)) self.assertEqual(None, eval_specs[0].steps) self.assertEqual('holdout', eval_specs[0].name) - self.assertEqual('exporter_holdout', eval_specs[0].exporters[0].name) + self.assertEqual('exporter', eval_specs[0].exporters[0].name) self.assertEqual(None, eval_specs[1].steps) self.assertEqual('eval_on_train', eval_specs[1].name) diff --git a/research/object_detection/model_tpu_main.py b/research/object_detection/model_tpu_main.py index 632301cb4b2..67a62fe3adc 100644 --- a/research/object_detection/model_tpu_main.py +++ b/research/object_detection/model_tpu_main.py @@ -114,6 +114,7 @@ def main(unused_argv): use_tpu_estimator=True, use_tpu=FLAGS.use_tpu, num_shards=FLAGS.num_shards, + save_final_config=FLAGS.mode == 'train', **kwargs) estimator = train_and_eval_dict['estimator'] train_input_fn = train_and_eval_dict['train_input_fn'] diff --git a/research/object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py b/research/object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py index 286deae3de7..8f89b7c05ed 100644 --- a/research/object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py +++ b/research/object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py @@ -72,6 +72,8 @@ def preprocess(self, resized_inputs): VGG style channel mean subtraction as described here: https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md + Note that if the number of channels is not equal to 3, the mean subtraction + will be skipped and the original resized_inputs will be returned. Args: resized_inputs: A [batch, height_in, width_in, channels] float32 tensor @@ -82,8 +84,11 @@ def preprocess(self, resized_inputs): tensor representing a batch of images. """ - channel_means = [123.68, 116.779, 103.939] - return resized_inputs - [[channel_means]] + if resized_inputs.shape.as_list()[3] == 3: + channel_means = [123.68, 116.779, 103.939] + return resized_inputs - [[channel_means]] + else: + return resized_inputs def _extract_proposal_features(self, preprocessed_inputs, scope): """Extracts first stage RPN features. 
diff --git a/research/object_detection/models/feature_map_generators.py b/research/object_detection/models/feature_map_generators.py index 84397f2c4a5..4c2fe68da3f 100644 --- a/research/object_detection/models/feature_map_generators.py +++ b/research/object_detection/models/feature_map_generators.py @@ -146,7 +146,6 @@ def __init__(self, use_depthwise = feature_map_layout['use_depthwise'] for index, from_layer in enumerate(feature_map_layout['from_layer']): net = [] - self.convolutions.append(net) layer_depth = feature_map_layout['layer_depth'][index] conv_kernel_size = 3 if 'conv_kernel_size' in feature_map_layout: @@ -231,6 +230,10 @@ def fixed_padding(features, kernel_size=conv_kernel_size): conv_hyperparams.build_activation_layer( name=layer_name)) + # Until certain bugs are fixed in checkpointable lists, + # this net must be appended only once it's been filled with layers + self.convolutions.append(net) + def call(self, image_features): """Generate the multi-resolution feature maps. @@ -263,7 +266,8 @@ def call(self, image_features): def multi_resolution_feature_maps(feature_map_layout, depth_multiplier, - min_depth, insert_1x1_conv, image_features): + min_depth, insert_1x1_conv, image_features, + pool_residual=False): """Generates multi resolution feature maps from input image features. Generates multi-scale feature maps for detection as in the SSD papers by @@ -317,6 +321,13 @@ def multi_resolution_feature_maps(feature_map_layout, depth_multiplier, should be inserted before shrinking the feature map. image_features: A dictionary of handles to activation tensors from the base feature extractor. + pool_residual: Whether to add an average pooling layer followed by a + residual connection between subsequent feature maps when the channel + depth match. For example, with option 'layer_depth': [-1, 512, 256, 256], + a pooling and residual layer is added between the third and forth feature + map. This option is better used with Weight Shared Convolution Box + Predictor when all feature maps have the same channel depth to encourage + more consistent features across multi-scale feature maps. Returns: feature_maps: an OrderedDict mapping keys (feature map names) to @@ -350,6 +361,7 @@ def multi_resolution_feature_maps(feature_map_layout, depth_multiplier, feature_map_keys.append(from_layer) else: pre_layer = feature_maps[-1] + pre_layer_depth = pre_layer.get_shape().as_list()[3] intermediate_layer = pre_layer if insert_1x1_conv: layer_name = '{}_1_Conv2d_{}_1x1_{}'.format( @@ -383,6 +395,12 @@ def multi_resolution_feature_maps(feature_map_layout, depth_multiplier, padding='SAME', stride=1, scope=layer_name) + if pool_residual and pre_layer_depth == depth_fn(layer_depth): + feature_map += slim.avg_pool2d( + pre_layer, [3, 3], + padding='SAME', + stride=2, + scope=layer_name + '_pool') else: feature_map = slim.conv2d( intermediate_layer, @@ -399,6 +417,7 @@ def multi_resolution_feature_maps(feature_map_layout, depth_multiplier, def fpn_top_down_feature_maps(image_features, depth, use_depthwise=False, + use_explicit_padding=False, scope=None): """Generates `top-down` feature maps for Feature Pyramid Networks. @@ -409,7 +428,9 @@ def fpn_top_down_feature_maps(image_features, Spatial resolutions of succesive tensors must reduce exactly by a factor of 2. depth: depth of output feature maps. - use_depthwise: use depthwise separable conv instead of regular conv. + use_depthwise: whether to use depthwise separable conv instead of regular + conv. + use_explicit_padding: whether to use explicit padding. 
scope: A scope name to wrap this op under. Returns: @@ -420,8 +441,10 @@ def fpn_top_down_feature_maps(image_features, num_levels = len(image_features) output_feature_maps_list = [] output_feature_map_keys = [] + padding = 'VALID' if use_explicit_padding else 'SAME' + kernel_size = 3 with slim.arg_scope( - [slim.conv2d, slim.separable_conv2d], padding='SAME', stride=1): + [slim.conv2d, slim.separable_conv2d], padding=padding, stride=1): top_down = slim.conv2d( image_features[-1][1], depth, [1, 1], activation_fn=None, normalizer_fn=None, @@ -436,14 +459,20 @@ def fpn_top_down_feature_maps(image_features, image_features[level][1], depth, [1, 1], activation_fn=None, normalizer_fn=None, scope='projection_%d' % (level + 1)) + if use_explicit_padding: + # slice top_down to the same shape as residual + residual_shape = tf.shape(residual) + top_down = top_down[:, :residual_shape[1], :residual_shape[2], :] top_down += residual if use_depthwise: conv_op = functools.partial(slim.separable_conv2d, depth_multiplier=1) else: conv_op = slim.conv2d + if use_explicit_padding: + top_down = ops.fixed_padding(top_down, kernel_size) output_feature_maps_list.append(conv_op( top_down, - depth, [3, 3], + depth, [kernel_size, kernel_size], scope='smoothing_%d' % (level + 1))) output_feature_map_keys.append('top_down_%s' % image_features[level][0]) return collections.OrderedDict(reversed( diff --git a/research/object_detection/models/feature_map_generators_test.py b/research/object_detection/models/feature_map_generators_test.py index c80c52829e2..f7ac0cc0281 100644 --- a/research/object_detection/models/feature_map_generators_test.py +++ b/research/object_detection/models/feature_map_generators_test.py @@ -45,6 +45,11 @@ 'conv_kernel_size': [-1, -1, 3, 3, 2], } +SSD_MOBILENET_V1_WEIGHT_SHARED_LAYOUT = { + 'from_layer': ['Conv2d_13_pointwise', '', '', ''], + 'layer_depth': [-1, 256, 256, 256], +} + @parameterized.parameters( {'use_keras': False}, @@ -67,7 +72,8 @@ def _build_conv_hyperparams(self): text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams) return hyperparams_builder.KerasLayerHyperparams(conv_hyperparams) - def _build_feature_map_generator(self, feature_map_layout, use_keras): + def _build_feature_map_generator(self, feature_map_layout, use_keras, + pool_residual=False): if use_keras: return feature_map_generators.KerasMultiResolutionFeatureMaps( feature_map_layout=feature_map_layout, @@ -86,7 +92,8 @@ def feature_map_generator(image_features): depth_multiplier=1, min_depth=32, insert_1x1_conv=True, - image_features=image_features) + image_features=image_features, + pool_residual=pool_residual) return feature_map_generator def test_get_expected_feature_map_shapes_with_inception_v2(self, use_keras): @@ -209,6 +216,34 @@ def test_get_expected_feature_map_shapes_with_embedded_ssd_mobilenet_v1( (key, value.shape) for key, value in out_feature_maps.items()) self.assertDictEqual(expected_feature_map_shapes, out_feature_map_shapes) + def test_feature_map_shapes_with_pool_residual_ssd_mobilenet_v1( + self, use_keras): + image_features = { + 'Conv2d_13_pointwise': tf.random_uniform([4, 8, 8, 1024], + dtype=tf.float32), + } + + feature_map_generator = self._build_feature_map_generator( + feature_map_layout=SSD_MOBILENET_V1_WEIGHT_SHARED_LAYOUT, + use_keras=use_keras, + pool_residual=True + ) + feature_maps = feature_map_generator(image_features) + + expected_feature_map_shapes = { + 'Conv2d_13_pointwise': (4, 8, 8, 1024), + 'Conv2d_13_pointwise_2_Conv2d_1_3x3_s2_256': (4, 4, 4, 256), + 
'Conv2d_13_pointwise_2_Conv2d_2_3x3_s2_256': (4, 2, 2, 256), + 'Conv2d_13_pointwise_2_Conv2d_3_3x3_s2_256': (4, 1, 1, 256)} + + init_op = tf.global_variables_initializer() + with self.test_session() as sess: + sess.run(init_op) + out_feature_maps = sess.run(feature_maps) + out_feature_map_shapes = dict( + (key, value.shape) for key, value in out_feature_maps.items()) + self.assertDictEqual(expected_feature_map_shapes, out_feature_map_shapes) + def test_get_expected_variable_names_with_inception_v2(self, use_keras): image_features = { 'Mixed_3c': tf.random_uniform([4, 28, 28, 256], dtype=tf.float32), diff --git a/research/object_detection/models/keras_applications/mobilenet_v2.py b/research/object_detection/models/keras_applications/mobilenet_v2.py index 4095cc2d2d6..5969b23dd77 100644 --- a/research/object_detection/models/keras_applications/mobilenet_v2.py +++ b/research/object_detection/models/keras_applications/mobilenet_v2.py @@ -82,6 +82,8 @@ def __init__(self, self._conv_hyperparams = conv_hyperparams self._use_explicit_padding = use_explicit_padding self._min_depth = min_depth + self.regularizer = tf.keras.regularizers.l2(0.00004 * 0.5) + self.initializer = tf.truncated_normal_initializer(stddev=0.09) def _FixedPaddingLayer(self, kernel_size): return tf.keras.layers.Lambda(lambda x: ops.fixed_padding(x, kernel_size)) @@ -114,6 +116,9 @@ def Conv2D(self, filters, **kwargs): if self._conv_hyperparams: kwargs = self._conv_hyperparams.params(**kwargs) + else: + kwargs['kernel_regularizer'] = self.regularizer + kwargs['kernel_initializer'] = self.initializer kwargs['padding'] = 'same' kernel_size = kwargs.get('kernel_size') @@ -144,6 +149,8 @@ def DepthwiseConv2D(self, **kwargs): """ if self._conv_hyperparams: kwargs = self._conv_hyperparams.params(**kwargs) + else: + kwargs['depthwise_initializer'] = self.initializer kwargs['padding'] = 'same' kernel_size = kwargs.get('kernel_size') diff --git a/research/object_detection/models/ssd_mobilenet_v1_fpn_feature_extractor.py b/research/object_detection/models/ssd_mobilenet_v1_fpn_feature_extractor.py index b082678aca4..b0c149aead3 100644 --- a/research/object_detection/models/ssd_mobilenet_v1_fpn_feature_extractor.py +++ b/research/object_detection/models/ssd_mobilenet_v1_fpn_feature_extractor.py @@ -31,11 +31,10 @@ # A modified config of mobilenet v1 that makes it more detection friendly, def _create_modified_mobilenet_config(): - conv_defs = copy.copy(mobilenet_v1.MOBILENETV1_CONV_DEFS) + conv_defs = copy.deepcopy(mobilenet_v1.MOBILENETV1_CONV_DEFS) conv_defs[-2] = mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=2, depth=512) conv_defs[-1] = mobilenet_v1.DepthSepConv(kernel=[3, 3], stride=1, depth=256) return conv_defs -_CONV_DEFS = _create_modified_mobilenet_config() class SSDMobileNetV1FpnFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor): @@ -98,6 +97,9 @@ def __init__(self, self._fpn_min_level = fpn_min_level self._fpn_max_level = fpn_max_level self._additional_layer_depth = additional_layer_depth + self._conv_defs = None + if self._use_depthwise: + self._conv_defs = _create_modified_mobilenet_config() def preprocess(self, resized_inputs): """SSD preprocessing. 
@@ -141,7 +143,7 @@ def extract_features(self, preprocessed_inputs): final_endpoint='Conv2d_13_pointwise', min_depth=self._min_depth, depth_multiplier=self._depth_multiplier, - conv_defs=_CONV_DEFS if self._use_depthwise else None, + conv_defs=self._conv_defs, use_explicit_padding=self._use_explicit_padding, scope=scope) @@ -159,7 +161,8 @@ def extract_features(self, preprocessed_inputs): fpn_features = feature_map_generators.fpn_top_down_feature_maps( [(key, image_features[key]) for key in feature_block_list], depth=depth_fn(self._additional_layer_depth), - use_depthwise=self._use_depthwise) + use_depthwise=self._use_depthwise, + use_explicit_padding=self._use_explicit_padding) feature_maps = [] for level in range(self._fpn_min_level, base_fpn_max_level + 1): feature_maps.append(fpn_features['top_down_{}'.format( @@ -167,18 +170,23 @@ def extract_features(self, preprocessed_inputs): last_feature_map = fpn_features['top_down_{}'.format( feature_blocks[base_fpn_max_level - 2])] # Construct coarse features + padding = 'VALID' if self._use_explicit_padding else 'SAME' + kernel_size = 3 for i in range(base_fpn_max_level + 1, self._fpn_max_level + 1): if self._use_depthwise: conv_op = functools.partial( slim.separable_conv2d, depth_multiplier=1) else: conv_op = slim.conv2d + if self._use_explicit_padding: + last_feature_map = ops.fixed_padding( + last_feature_map, kernel_size) last_feature_map = conv_op( last_feature_map, num_outputs=depth_fn(self._additional_layer_depth), - kernel_size=[3, 3], + kernel_size=[kernel_size, kernel_size], stride=2, - padding='SAME', + padding=padding, scope='bottom_up_Conv2d_{}'.format(i - base_fpn_max_level + 13)) feature_maps.append(last_feature_map) return feature_maps diff --git a/research/object_detection/models/ssd_mobilenet_v2_fpn_feature_extractor.py b/research/object_detection/models/ssd_mobilenet_v2_fpn_feature_extractor.py index 1cbcc24971e..d3c1f7a29b1 100644 --- a/research/object_detection/models/ssd_mobilenet_v2_fpn_feature_extractor.py +++ b/research/object_detection/models/ssd_mobilenet_v2_fpn_feature_extractor.py @@ -30,17 +30,14 @@ slim = tf.contrib.slim -# A modified config of mobilenet v2 that makes it more detection friendly, +# A modified config of mobilenet v2 that makes it more detection friendly. def _create_modified_mobilenet_config(): - conv_defs = copy.copy(mobilenet_v2.V2_DEF) + conv_defs = copy.deepcopy(mobilenet_v2.V2_DEF) conv_defs['spec'][-1] = mobilenet.op( slim.conv2d, stride=1, kernel_size=[1, 1], num_outputs=256) return conv_defs -_CONV_DEFS = _create_modified_mobilenet_config() - - class SSDMobileNetV2FpnFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor): """SSD Feature Extractor using MobilenetV2 FPN features.""" @@ -100,6 +97,9 @@ def __init__(self, self._fpn_min_level = fpn_min_level self._fpn_max_level = fpn_max_level self._additional_layer_depth = additional_layer_depth + self._conv_defs = None + if self._use_depthwise: + self._conv_defs = _create_modified_mobilenet_config() def preprocess(self, resized_inputs): """SSD preprocessing. 
@@ -142,7 +142,7 @@ def extract_features(self, preprocessed_inputs): ops.pad_to_multiple(preprocessed_inputs, self._pad_to_multiple), final_endpoint='layer_19', depth_multiplier=self._depth_multiplier, - conv_defs=_CONV_DEFS if self._use_depthwise else None, + conv_defs=self._conv_defs, use_explicit_padding=self._use_explicit_padding, scope=scope) depth_fn = lambda d: max(int(d * self._depth_multiplier), self._min_depth) @@ -158,7 +158,8 @@ def extract_features(self, preprocessed_inputs): fpn_features = feature_map_generators.fpn_top_down_feature_maps( [(key, image_features[key]) for key in feature_block_list], depth=depth_fn(self._additional_layer_depth), - use_depthwise=self._use_depthwise) + use_depthwise=self._use_depthwise, + use_explicit_padding=self._use_explicit_padding) feature_maps = [] for level in range(self._fpn_min_level, base_fpn_max_level + 1): feature_maps.append(fpn_features['top_down_{}'.format( @@ -166,18 +167,23 @@ def extract_features(self, preprocessed_inputs): last_feature_map = fpn_features['top_down_{}'.format( feature_blocks[base_fpn_max_level - 2])] # Construct coarse features + padding = 'VALID' if self._use_explicit_padding else 'SAME' + kernel_size = 3 for i in range(base_fpn_max_level + 1, self._fpn_max_level + 1): if self._use_depthwise: conv_op = functools.partial( slim.separable_conv2d, depth_multiplier=1) else: conv_op = slim.conv2d + if self._use_explicit_padding: + last_feature_map = ops.fixed_padding( + last_feature_map, kernel_size) last_feature_map = conv_op( last_feature_map, num_outputs=depth_fn(self._additional_layer_depth), - kernel_size=[3, 3], + kernel_size=[kernel_size, kernel_size], stride=2, - padding='SAME', + padding=padding, scope='bottom_up_Conv2d_{}'.format(i - base_fpn_max_level + 19)) feature_maps.append(last_feature_map) return feature_maps diff --git a/research/object_detection/models/ssd_mobilenet_v2_keras_feature_extractor.py b/research/object_detection/models/ssd_mobilenet_v2_keras_feature_extractor.py index 59708039c92..9bf560eb4b4 100644 --- a/research/object_detection/models/ssd_mobilenet_v2_keras_feature_extractor.py +++ b/research/object_detection/models/ssd_mobilenet_v2_keras_feature_extractor.py @@ -85,41 +85,44 @@ def __init__(self, override_base_feature_extractor_hyperparams= override_base_feature_extractor_hyperparams, name=name) - feature_map_layout = { + self._feature_map_layout = { 'from_layer': ['layer_15/expansion_output', 'layer_19', '', '', '', ''], 'layer_depth': [-1, -1, 512, 256, 256, 128], 'use_depthwise': self._use_depthwise, 'use_explicit_padding': self._use_explicit_padding, } - with tf.name_scope('MobilenetV2'): - full_mobilenet_v2 = mobilenet_v2.mobilenet_v2( - batchnorm_training=(is_training and not freeze_batchnorm), - conv_hyperparams=(conv_hyperparams - if self._override_base_feature_extractor_hyperparams - else None), - weights=None, - use_explicit_padding=use_explicit_padding, - alpha=self._depth_multiplier, - min_depth=self._min_depth, - include_top=False) - conv2d_11_pointwise = full_mobilenet_v2.get_layer( - name='block_13_expand_relu').output - conv2d_13_pointwise = full_mobilenet_v2.get_layer(name='out_relu').output - self.mobilenet_v2 = tf.keras.Model( - inputs=full_mobilenet_v2.inputs, - outputs=[conv2d_11_pointwise, conv2d_13_pointwise]) - - self.feature_map_generator = ( - feature_map_generators.KerasMultiResolutionFeatureMaps( - feature_map_layout=feature_map_layout, - depth_multiplier=self._depth_multiplier, - min_depth=self._min_depth, - insert_1x1_conv=True, - is_training=is_training, 
- conv_hyperparams=conv_hyperparams, - freeze_batchnorm=freeze_batchnorm, - name='FeatureMaps')) + self.mobilenet_v2 = None + self.feature_map_generator = None + + def build(self, input_shape): + full_mobilenet_v2 = mobilenet_v2.mobilenet_v2( + batchnorm_training=(self._is_training and not self._freeze_batchnorm), + conv_hyperparams=(self._conv_hyperparams + if self._override_base_feature_extractor_hyperparams + else None), + weights=None, + use_explicit_padding=self._use_explicit_padding, + alpha=self._depth_multiplier, + min_depth=self._min_depth, + include_top=False) + conv2d_11_pointwise = full_mobilenet_v2.get_layer( + name='block_13_expand_relu').output + conv2d_13_pointwise = full_mobilenet_v2.get_layer(name='out_relu').output + self.mobilenet_v2 = tf.keras.Model( + inputs=full_mobilenet_v2.inputs, + outputs=[conv2d_11_pointwise, conv2d_13_pointwise]) + self.feature_map_generator = ( + feature_map_generators.KerasMultiResolutionFeatureMaps( + feature_map_layout=self._feature_map_layout, + depth_multiplier=self._depth_multiplier, + min_depth=self._min_depth, + insert_1x1_conv=True, + is_training=self._is_training, + conv_hyperparams=self._conv_hyperparams, + freeze_batchnorm=self._freeze_batchnorm, + name='FeatureMaps')) + self.built = True def preprocess(self, resized_inputs): """SSD preprocessing. diff --git a/research/object_detection/models/ssd_pnasnet_feature_extractor.py b/research/object_detection/models/ssd_pnasnet_feature_extractor.py new file mode 100644 index 00000000000..4f697c7058a --- /dev/null +++ b/research/object_detection/models/ssd_pnasnet_feature_extractor.py @@ -0,0 +1,175 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""SSDFeatureExtractor for PNASNet features. + +Based on PNASNet ImageNet model: https://arxiv.org/abs/1712.00559 +""" + +import tensorflow as tf + +from object_detection.meta_architectures import ssd_meta_arch +from object_detection.models import feature_map_generators +from object_detection.utils import context_manager +from object_detection.utils import ops +from nets.nasnet import pnasnet + +slim = tf.contrib.slim + + +def pnasnet_large_arg_scope_for_detection(is_batch_norm_training=False): + """Defines the default arg scope for the PNASNet Large for object detection. + + This provides a small edit to switch batch norm training on and off. + + Args: + is_batch_norm_training: Boolean indicating whether to train with batch norm. + Default is False. + + Returns: + An `arg_scope` to use for the PNASNet Large Model. 
+ """ + imagenet_scope = pnasnet.pnasnet_large_arg_scope() + with slim.arg_scope(imagenet_scope): + with slim.arg_scope([slim.batch_norm], + is_training=is_batch_norm_training) as sc: + return sc + + +class SSDPNASNetFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor): + """SSD Feature Extractor using PNASNet features.""" + + def __init__(self, + is_training, + depth_multiplier, + min_depth, + pad_to_multiple, + conv_hyperparams_fn, + reuse_weights=None, + use_explicit_padding=False, + use_depthwise=False, + override_base_feature_extractor_hyperparams=False): + """PNASNet Feature Extractor for SSD Models. + + Args: + is_training: whether the network is in training mode. + depth_multiplier: float depth multiplier for feature extractor. + min_depth: minimum feature extractor depth. + pad_to_multiple: the nearest multiple to zero pad the input height and + width dimensions to. + conv_hyperparams_fn: A function to construct tf slim arg_scope for conv2d + and separable_conv2d ops in the layers that are added on top of the + base feature extractor. + reuse_weights: Whether to reuse variables. Default is None. + use_explicit_padding: Use 'VALID' padding for convolutions, but prepad + inputs so that the output dimensions are the same as if 'SAME' padding + were used. + use_depthwise: Whether to use depthwise convolutions. + override_base_feature_extractor_hyperparams: Whether to override + hyperparameters of the base feature extractor with the one from + `conv_hyperparams_fn`. + """ + super(SSDPNASNetFeatureExtractor, self).__init__( + is_training=is_training, + depth_multiplier=depth_multiplier, + min_depth=min_depth, + pad_to_multiple=pad_to_multiple, + conv_hyperparams_fn=conv_hyperparams_fn, + reuse_weights=reuse_weights, + use_explicit_padding=use_explicit_padding, + use_depthwise=use_depthwise, + override_base_feature_extractor_hyperparams= + override_base_feature_extractor_hyperparams) + + def preprocess(self, resized_inputs): + """SSD preprocessing. + + Maps pixel values to the range [-1, 1]. + + Args: + resized_inputs: a [batch, height, width, channels] float tensor + representing a batch of images. + + Returns: + preprocessed_inputs: a [batch, height, width, channels] float tensor + representing a batch of images. + """ + return (2.0 / 255.0) * resized_inputs - 1.0 + + def extract_features(self, preprocessed_inputs): + """Extract features from preprocessed inputs. + + Args: + preprocessed_inputs: a [batch, height, width, channels] float tensor + representing a batch of images. 
+ + Returns: + feature_maps: a list of tensors where the ith tensor has shape + [batch, height_i, width_i, depth_i] + """ + + feature_map_layout = { + 'from_layer': ['Cell_7', 'Cell_11', '', '', '', ''], + 'layer_depth': [-1, -1, 512, 256, 256, 128], + 'use_explicit_padding': self._use_explicit_padding, + 'use_depthwise': self._use_depthwise, + } + + with slim.arg_scope( + pnasnet_large_arg_scope_for_detection( + is_batch_norm_training=self._is_training)): + with slim.arg_scope([slim.conv2d, slim.batch_norm, slim.separable_conv2d], + reuse=self._reuse_weights): + with (slim.arg_scope(self._conv_hyperparams_fn()) + if self._override_base_feature_extractor_hyperparams else + context_manager.IdentityContextManager()): + _, image_features = pnasnet.build_pnasnet_large( + ops.pad_to_multiple(preprocessed_inputs, self._pad_to_multiple), + num_classes=None, + is_training=self._is_training, + final_endpoint='Cell_11') + with tf.variable_scope('SSD_feature_maps', reuse=self._reuse_weights): + with slim.arg_scope(self._conv_hyperparams_fn()): + feature_maps = feature_map_generators.multi_resolution_feature_maps( + feature_map_layout=feature_map_layout, + depth_multiplier=self._depth_multiplier, + min_depth=self._min_depth, + insert_1x1_conv=True, + image_features=image_features) + + return feature_maps.values() + + def restore_from_classification_checkpoint_fn(self, feature_extractor_scope): + """Returns a map of variables to load from a foreign checkpoint. + + Note that this overrides the default implementation in + ssd_meta_arch.SSDFeatureExtractor which does not work for PNASNet + checkpoints. + + Args: + feature_extractor_scope: A scope name for the first stage feature + extractor. + + Returns: + A dict mapping variable names (to load from a checkpoint) to variables in + the model graph. + """ + variables_to_restore = {} + for variable in tf.global_variables(): + if variable.op.name.startswith(feature_extractor_scope): + var_name = variable.op.name.replace(feature_extractor_scope + '/', '') + var_name += '/ExponentialMovingAverage' + variables_to_restore[var_name] = variable + return variables_to_restore diff --git a/research/object_detection/models/ssd_pnasnet_feature_extractor_test.py b/research/object_detection/models/ssd_pnasnet_feature_extractor_test.py new file mode 100644 index 00000000000..6646c2f9012 --- /dev/null +++ b/research/object_detection/models/ssd_pnasnet_feature_extractor_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for ssd_pnas_feature_extractor.""" +import numpy as np +import tensorflow as tf + +from object_detection.models import ssd_feature_extractor_test +from object_detection.models import ssd_pnasnet_feature_extractor + +slim = tf.contrib.slim + + +class SsdPnasNetFeatureExtractorTest( + ssd_feature_extractor_test.SsdFeatureExtractorTestBase): + + def _create_feature_extractor(self, depth_multiplier, pad_to_multiple, + is_training=True, use_explicit_padding=False): + """Constructs a new feature extractor. + + Args: + depth_multiplier: float depth multiplier for feature extractor + pad_to_multiple: the nearest multiple to zero pad the input height and + width dimensions to. + is_training: whether the network is in training mode. + use_explicit_padding: Use 'VALID' padding for convolutions, but prepad + inputs so that the output dimensions are the same as if 'SAME' padding + were used. + Returns: + an ssd_meta_arch.SSDFeatureExtractor object. + """ + min_depth = 32 + return ssd_pnasnet_feature_extractor.SSDPNASNetFeatureExtractor( + is_training, depth_multiplier, min_depth, pad_to_multiple, + self.conv_hyperparams_fn, + use_explicit_padding=use_explicit_padding) + + def test_extract_features_returns_correct_shapes_128(self): + image_height = 128 + image_width = 128 + depth_multiplier = 1.0 + pad_to_multiple = 1 + expected_feature_map_shape = [(2, 8, 8, 2160), (2, 4, 4, 4320), + (2, 2, 2, 512), (2, 1, 1, 256), + (2, 1, 1, 256), (2, 1, 1, 128)] + self.check_extract_features_returns_correct_shape( + 2, image_height, image_width, depth_multiplier, pad_to_multiple, + expected_feature_map_shape) + + def test_extract_features_returns_correct_shapes_299(self): + image_height = 299 + image_width = 299 + depth_multiplier = 1.0 + pad_to_multiple = 1 + expected_feature_map_shape = [(2, 19, 19, 2160), (2, 10, 10, 4320), + (2, 5, 5, 512), (2, 3, 3, 256), + (2, 2, 2, 256), (2, 1, 1, 128)] + self.check_extract_features_returns_correct_shape( + 2, image_height, image_width, depth_multiplier, pad_to_multiple, + expected_feature_map_shape) + + def test_preprocess_returns_correct_value_range(self): + image_height = 128 + image_width = 128 + depth_multiplier = 1 + pad_to_multiple = 1 + test_image = np.random.rand(2, image_height, image_width, 3) + feature_extractor = self._create_feature_extractor(depth_multiplier, + pad_to_multiple) + preprocessed_image = feature_extractor.preprocess(test_image) + self.assertTrue(np.all(np.less_equal(np.abs(preprocessed_image), 1.0))) + + +if __name__ == '__main__': + tf.test.main() diff --git a/research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor.py b/research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor.py index 73397491e55..a7bc806a19d 100644 --- a/research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor.py +++ b/research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor.py @@ -113,6 +113,8 @@ def preprocess(self, resized_inputs): VGG style channel mean subtraction as described here: https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-mdnge. + Note that if the number of channels is not equal to 3, the mean subtraction + will be skipped and the original resized_inputs will be returned. Args: resized_inputs: a [batch, height, width, channels] float tensor @@ -122,8 +124,11 @@ def preprocess(self, resized_inputs): preprocessed_inputs: a [batch, height, width, channels] float tensor representing a batch of images. 
""" - channel_means = [123.68, 116.779, 103.939] - return resized_inputs - [[channel_means]] + if resized_inputs.shape.as_list()[3] == 3: + channel_means = [123.68, 116.779, 103.939] + return resized_inputs - [[channel_means]] + else: + return resized_inputs def _filter_features(self, image_features): # TODO(rathodv): Change resnet endpoint to strip scope prefixes instead diff --git a/research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor_testbase.py b/research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor_testbase.py index 186f2b1748b..fd8d9f6d125 100644 --- a/research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor_testbase.py +++ b/research/object_detection/models/ssd_resnet_v1_fpn_feature_extractor_testbase.py @@ -82,12 +82,15 @@ def test_preprocess_returns_correct_value_range(self): image_width = 128 depth_multiplier = 1 pad_to_multiple = 1 - test_image = np.random.rand(4, image_height, image_width, 3) + test_image = tf.constant(np.random.rand(4, image_height, image_width, 3)) feature_extractor = self._create_feature_extractor(depth_multiplier, pad_to_multiple) preprocessed_image = feature_extractor.preprocess(test_image) - self.assertAllClose(preprocessed_image, - test_image - [[123.68, 116.779, 103.939]]) + with self.test_session() as sess: + test_image_out, preprocessed_image_out = sess.run( + [test_image, preprocessed_image]) + self.assertAllClose(preprocessed_image_out, + test_image_out - [[123.68, 116.779, 103.939]]) def test_variables_only_created_in_scope(self): depth_multiplier = 1 @@ -103,5 +106,3 @@ def test_variables_only_created_in_scope(self): self.assertTrue( variable.name.startswith(self._resnet_scope_name()) or variable.name.startswith(self._fpn_scope_name())) - - diff --git a/research/object_detection/models/ssd_resnet_v1_ppn_feature_extractor.py b/research/object_detection/models/ssd_resnet_v1_ppn_feature_extractor.py index 13422503c06..d275852a74c 100644 --- a/research/object_detection/models/ssd_resnet_v1_ppn_feature_extractor.py +++ b/research/object_detection/models/ssd_resnet_v1_ppn_feature_extractor.py @@ -98,6 +98,8 @@ def preprocess(self, resized_inputs): VGG style channel mean subtraction as described here: https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-mdnge. + Note that if the number of channels is not equal to 3, the mean subtraction + will be skipped and the original resized_inputs will be returned. Args: resized_inputs: a [batch, height, width, channels] float tensor @@ -107,8 +109,11 @@ def preprocess(self, resized_inputs): preprocessed_inputs: a [batch, height, width, channels] float tensor representing a batch of images. """ - channel_means = [123.68, 116.779, 103.939] - return resized_inputs - [[channel_means]] + if resized_inputs.shape.as_list()[3] == 3: + channel_means = [123.68, 116.779, 103.939] + return resized_inputs - [[channel_means]] + else: + return resized_inputs def extract_features(self, preprocessed_inputs): """Extract features from preprocessed inputs. 
diff --git a/research/object_detection/models/ssd_resnet_v1_ppn_feature_extractor_testbase.py b/research/object_detection/models/ssd_resnet_v1_ppn_feature_extractor_testbase.py index e8d0db5f883..3857fc708d4 100644 --- a/research/object_detection/models/ssd_resnet_v1_ppn_feature_extractor_testbase.py +++ b/research/object_detection/models/ssd_resnet_v1_ppn_feature_extractor_testbase.py @@ -15,6 +15,7 @@ """Tests for ssd resnet v1 feature extractors.""" import abc import numpy as np +import tensorflow as tf from object_detection.models import ssd_feature_extractor_test @@ -64,12 +65,15 @@ def test_preprocess_returns_correct_value_range(self): image_width = 128 depth_multiplier = 1 pad_to_multiple = 1 - test_image = np.random.rand(4, image_height, image_width, 3) + test_image = tf.constant(np.random.rand(4, image_height, image_width, 3)) feature_extractor = self._create_feature_extractor(depth_multiplier, pad_to_multiple) preprocessed_image = feature_extractor.preprocess(test_image) - self.assertAllClose(preprocessed_image, - test_image - [[123.68, 116.779, 103.939]]) + with self.test_session() as sess: + test_image_out, preprocessed_image_out = sess.run( + [test_image, preprocessed_image]) + self.assertAllClose(preprocessed_image_out, + test_image_out - [[123.68, 116.779, 103.939]]) def test_variables_only_created_in_scope(self): depth_multiplier = 1 diff --git a/research/object_detection/predictors/convolutional_keras_box_predictor.py b/research/object_detection/predictors/convolutional_keras_box_predictor.py index 994a78f0de0..fe7cba10c43 100644 --- a/research/object_detection/predictors/convolutional_keras_box_predictor.py +++ b/research/object_detection/predictors/convolutional_keras_box_predictor.py @@ -134,26 +134,32 @@ def build(self, input_shapes): (len(self._prediction_heads[BOX_ENCODINGS]), len(input_shapes))) for stack_index, input_shape in enumerate(input_shapes): - net = tf.keras.Sequential(name='PreHeadConvolutions_%d' % stack_index) - self._shared_nets.append(net) + net = [] # Add additional conv layers before the class predictor. 
features_depth = static_shape.get_depth(input_shape) depth = max(min(features_depth, self._max_depth), self._min_depth) tf.logging.info( 'depth of additional conv before box predictor: {}'.format(depth)) + if depth > 0 and self._num_layers_before_predictor > 0: for i in range(self._num_layers_before_predictor): - net.add(keras.Conv2D(depth, [1, 1], - name='Conv2d_%d_1x1_%d' % (i, depth), - padding='SAME', - **self._conv_hyperparams.params())) - net.add(self._conv_hyperparams.build_batch_norm( + net.append(keras.Conv2D(depth, [1, 1], + name='SharedConvolutions_%d/Conv2d_%d_1x1_%d' + % (stack_index, i, depth), + padding='SAME', + **self._conv_hyperparams.params())) + net.append(self._conv_hyperparams.build_batch_norm( training=(self._is_training and not self._freeze_batchnorm), - name='Conv2d_%d_1x1_%d_norm' % (i, depth))) - net.add(self._conv_hyperparams.build_activation_layer( - name='Conv2d_%d_1x1_%d_activation' % (i, depth), + name='SharedConvolutions_%d/Conv2d_%d_1x1_%d_norm' + % (stack_index, i, depth))) + net.append(self._conv_hyperparams.build_activation_layer( + name='SharedConvolutions_%d/Conv2d_%d_1x1_%d_activation' + % (stack_index, i, depth), )) + # Until certain bugs are fixed in checkpointable lists, + # this net must be appended only once it's been filled with layers + self._shared_nets.append(net) self.built = True def _predict(self, image_features): @@ -175,10 +181,11 @@ def _predict(self, image_features): """ predictions = collections.defaultdict(list) - for (index, image_feature) in enumerate(image_features): + for (index, net) in enumerate(image_features): # Apply shared conv layers before the head predictors. - net = self._shared_nets[index](image_feature) + for layer in self._shared_nets[index]: + net = layer(net) for head_name in self._prediction_heads: head_obj = self._prediction_heads[head_name][index] diff --git a/research/object_detection/predictors/convolutional_keras_box_predictor_test.py b/research/object_detection/predictors/convolutional_keras_box_predictor_test.py index 08daef254b0..aeb6994ffc8 100644 --- a/research/object_detection/predictors/convolutional_keras_box_predictor_test.py +++ b/research/object_detection/predictors/convolutional_keras_box_predictor_test.py @@ -181,8 +181,8 @@ def test_get_predictions_with_feature_maps_of_dynamic_shape( self.assertAllEqual(objectness_predictions_shape, [4, expected_num_anchors, 1]) expected_variable_set = set([ - 'BoxPredictor/PreHeadConvolutions_0/Conv2d_0_1x1_32/bias', - 'BoxPredictor/PreHeadConvolutions_0/Conv2d_0_1x1_32/kernel', + 'BoxPredictor/SharedConvolutions_0/Conv2d_0_1x1_32/bias', + 'BoxPredictor/SharedConvolutions_0/Conv2d_0_1x1_32/kernel', 'BoxPredictor/ConvolutionalBoxHead_0/BoxEncodingPredictor/bias', 'BoxPredictor/ConvolutionalBoxHead_0/BoxEncodingPredictor/kernel', 'BoxPredictor/ConvolutionalClassHead_0/ClassPredictor/bias', diff --git a/research/object_detection/predictors/heads/class_head.py b/research/object_detection/predictors/heads/class_head.py index ed28cbc4072..ad41203b5e6 100644 --- a/research/object_detection/predictors/heads/class_head.py +++ b/research/object_detection/predictors/heads/class_head.py @@ -34,16 +34,18 @@ class MaskRCNNClassHead(head.Head): https://arxiv.org/abs/1703.06870 """ - def __init__(self, is_training, num_classes, fc_hyperparams_fn, - use_dropout, dropout_keep_prob): + def __init__(self, + is_training, + num_class_slots, + fc_hyperparams_fn, + use_dropout, + dropout_keep_prob): """Constructor. 
Args: is_training: Indicates whether the BoxPredictor is in training mode. - num_classes: number of classes. Note that num_classes *does not* - include the background category, so if groundtruth labels take values - in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the - assigned classification targets can range from {0,... K}). + num_class_slots: number of class slots. Note that num_class_slots may or + may not include an implicit background category. fc_hyperparams_fn: A function to generate tf-slim arg_scope with hyperparameters for fully connected ops. use_dropout: Option to use dropout or not. Note that a single dropout @@ -54,7 +56,7 @@ def __init__(self, is_training, num_classes, fc_hyperparams_fn, """ super(MaskRCNNClassHead, self).__init__() self._is_training = is_training - self._num_classes = num_classes + self._num_class_slots = num_class_slots self._fc_hyperparams_fn = fc_hyperparams_fn self._use_dropout = use_dropout self._dropout_keep_prob = dropout_keep_prob @@ -70,7 +72,7 @@ def predict(self, features, num_predictions_per_location=1): Returns: class_predictions_with_background: A float tensor of shape - [batch_size, 1, num_classes + 1] representing the class predictions for + [batch_size, 1, num_class_slots] representing the class predictions for the proposals. Raises: @@ -91,11 +93,12 @@ def predict(self, features, num_predictions_per_location=1): with slim.arg_scope(self._fc_hyperparams_fn()): class_predictions_with_background = slim.fully_connected( flattened_roi_pooled_features, - self._num_classes + 1, + self._num_class_slots, activation_fn=None, scope='ClassPredictor') class_predictions_with_background = tf.reshape( - class_predictions_with_background, [-1, 1, self._num_classes + 1]) + class_predictions_with_background, + [-1, 1, self._num_class_slots]) return class_predictions_with_background @@ -104,7 +107,7 @@ class ConvolutionalClassHead(head.Head): def __init__(self, is_training, - num_classes, + num_class_slots, use_dropout, dropout_keep_prob, kernel_size, @@ -115,7 +118,8 @@ def __init__(self, Args: is_training: Indicates whether the BoxPredictor is in training mode. - num_classes: Number of classes. + num_class_slots: number of class slots. Note that num_class_slots may or + may not include an implicit background category. use_dropout: Option to use dropout or not. Note that a single dropout op is applied here prior to both box and class predictions, which stands in contrast to the ConvolutionalBoxPredictor below. @@ -137,7 +141,7 @@ def __init__(self, """ super(ConvolutionalClassHead, self).__init__() self._is_training = is_training - self._num_classes = num_classes + self._num_class_slots = num_class_slots self._use_dropout = use_dropout self._dropout_keep_prob = dropout_keep_prob self._kernel_size = kernel_size @@ -156,12 +160,10 @@ def predict(self, features, num_predictions_per_location): Returns: class_predictions_with_background: A float tensors of shape - [batch_size, num_anchors, num_classes + 1] representing the class + [batch_size, num_anchors, num_class_slots] representing the class predictions for the proposals. """ net = features - # Add a slot for the background class. 
- num_class_slots = self._num_classes + 1 if self._use_dropout: net = slim.dropout(net, keep_prob=self._dropout_keep_prob) if self._use_depthwise: @@ -171,7 +173,7 @@ def predict(self, features, num_predictions_per_location): rate=1, scope='ClassPredictor_depthwise') class_predictions_with_background = slim.conv2d( class_predictions_with_background, - num_predictions_per_location * num_class_slots, [1, 1], + num_predictions_per_location * self._num_class_slots, [1, 1], activation_fn=None, normalizer_fn=None, normalizer_params=None, @@ -179,7 +181,7 @@ def predict(self, features, num_predictions_per_location): else: class_predictions_with_background = slim.conv2d( net, - num_predictions_per_location * num_class_slots, + num_predictions_per_location * self._num_class_slots, [self._kernel_size, self._kernel_size], activation_fn=None, normalizer_fn=None, @@ -194,7 +196,8 @@ def predict(self, features, num_predictions_per_location): if batch_size is None: batch_size = tf.shape(features)[0] class_predictions_with_background = tf.reshape( - class_predictions_with_background, [batch_size, -1, num_class_slots]) + class_predictions_with_background, + [batch_size, -1, self._num_class_slots]) return class_predictions_with_background @@ -208,7 +211,7 @@ class WeightSharedConvolutionalClassHead(head.Head): """ def __init__(self, - num_classes, + num_class_slots, kernel_size=3, class_prediction_bias_init=0.0, use_dropout=False, @@ -218,10 +221,8 @@ def __init__(self, """Constructor. Args: - num_classes: number of classes. Note that num_classes *does not* - include the background category, so if groundtruth labels take values - in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the - assigned classification targets can range from {0,... K}). + num_class_slots: number of class slots. Note that num_class_slots may or + may not include an implicit background category. kernel_size: Size of final convolution kernel. class_prediction_bias_init: constant value to initialize bias of the last conv2d layer before class prediction. @@ -233,7 +234,7 @@ def __init__(self, as inputs and returns tensors). """ super(WeightSharedConvolutionalClassHead, self).__init__() - self._num_classes = num_classes + self._num_class_slots = num_class_slots self._kernel_size = kernel_size self._class_prediction_bias_init = class_prediction_bias_init self._use_dropout = use_dropout @@ -252,12 +253,10 @@ def predict(self, features, num_predictions_per_location): Returns: class_predictions_with_background: A tensor of shape - [batch_size, num_anchors, num_classes + 1] representing the class + [batch_size, num_anchors, num_class_slots] representing the class predictions for the proposals. """ class_predictions_net = features - num_class_slots = self._num_classes + 1 - # Add a slot for the background class. 
if self._use_dropout: class_predictions_net = slim.dropout( class_predictions_net, keep_prob=self._dropout_keep_prob) @@ -267,7 +266,7 @@ def predict(self, features, num_predictions_per_location): conv_op = slim.conv2d class_predictions_with_background = conv_op( class_predictions_net, - num_predictions_per_location * num_class_slots, + num_predictions_per_location * self._num_class_slots, [self._kernel_size, self._kernel_size], activation_fn=None, stride=1, padding='SAME', normalizer_fn=None, @@ -280,5 +279,6 @@ def predict(self, features, num_predictions_per_location): class_predictions_with_background = self._score_converter_fn( class_predictions_with_background) class_predictions_with_background = tf.reshape( - class_predictions_with_background, [batch_size, -1, num_class_slots]) + class_predictions_with_background, + [batch_size, -1, self._num_class_slots]) return class_predictions_with_background diff --git a/research/object_detection/predictors/heads/class_head_test.py b/research/object_detection/predictors/heads/class_head_test.py index 737e098425f..270dc5d1b26 100644 --- a/research/object_detection/predictors/heads/class_head_test.py +++ b/research/object_detection/predictors/heads/class_head_test.py @@ -46,7 +46,7 @@ def _build_arg_scope_with_hyperparams(self, def test_prediction_size(self): class_prediction_head = class_head.MaskRCNNClassHead( is_training=False, - num_classes=20, + num_class_slots=20, fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(), use_dropout=True, dropout_keep_prob=0.5) @@ -54,7 +54,7 @@ def test_prediction_size(self): [64, 7, 7, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32) prediction = class_prediction_head.predict( features=roi_pooled_features, num_predictions_per_location=1) - self.assertAllEqual([64, 1, 21], prediction.get_shape().as_list()) + self.assertAllEqual([64, 1, 20], prediction.get_shape().as_list()) class ConvolutionalClassPredictorTest(test_case.TestCase): @@ -80,7 +80,7 @@ def _build_arg_scope_with_hyperparams( def test_prediction_size(self): class_prediction_head = class_head.ConvolutionalClassHead( is_training=True, - num_classes=20, + num_class_slots=20, use_dropout=True, dropout_keep_prob=0.5, kernel_size=3) @@ -89,7 +89,7 @@ def test_prediction_size(self): class_predictions = class_prediction_head.predict( features=image_feature, num_predictions_per_location=1) - self.assertAllEqual([64, 323, 21], + self.assertAllEqual([64, 323, 20], class_predictions.get_shape().as_list()) @@ -115,13 +115,13 @@ def _build_arg_scope_with_hyperparams( def test_prediction_size(self): class_prediction_head = ( - class_head.WeightSharedConvolutionalClassHead(num_classes=20)) + class_head.WeightSharedConvolutionalClassHead(num_class_slots=20)) image_feature = tf.random_uniform( [64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32) class_predictions = class_prediction_head.predict( features=image_feature, num_predictions_per_location=1) - self.assertAllEqual([64, 323, 21], class_predictions.get_shape().as_list()) + self.assertAllEqual([64, 323, 20], class_predictions.get_shape().as_list()) if __name__ == '__main__': diff --git a/research/object_detection/predictors/heads/keras_box_head.py b/research/object_detection/predictors/heads/keras_box_head.py index 3e3798fb433..da311478da6 100644 --- a/research/object_detection/predictors/heads/keras_box_head.py +++ b/research/object_detection/predictors/heads/keras_box_head.py @@ -91,7 +91,7 @@ def __init__(self, tf.keras.layers.Conv2D( num_predictions_per_location * self._box_code_size, 
[1, 1], name='BoxEncodingPredictor', - **conv_hyperparams.params(activation=None))) + **conv_hyperparams.params(use_bias=True))) else: self._box_encoder_layers.append( tf.keras.layers.Conv2D( @@ -99,7 +99,7 @@ def __init__(self, [self._kernel_size, self._kernel_size], padding='SAME', name='BoxEncodingPredictor', - **conv_hyperparams.params(activation=None))) + **conv_hyperparams.params(use_bias=True))) def _predict(self, features): """Predicts boxes. diff --git a/research/object_detection/predictors/heads/keras_class_head.py b/research/object_detection/predictors/heads/keras_class_head.py index edb4e12b39f..f157254e126 100644 --- a/research/object_detection/predictors/heads/keras_class_head.py +++ b/research/object_detection/predictors/heads/keras_class_head.py @@ -29,7 +29,7 @@ class ConvolutionalClassHead(head.KerasHead): def __init__(self, is_training, - num_classes, + num_class_slots, use_dropout, dropout_keep_prob, kernel_size, @@ -43,7 +43,8 @@ def __init__(self, Args: is_training: Indicates whether the BoxPredictor is in training mode. - num_classes: Number of classes. + num_class_slots: number of class slots. Note that num_class_slots may or + may not include an implicit background category. use_dropout: Option to use dropout or not. Note that a single dropout op is applied here prior to both box and class predictions, which stands in contrast to the ConvolutionalBoxPredictor below. @@ -73,13 +74,12 @@ def __init__(self, """ super(ConvolutionalClassHead, self).__init__(name=name) self._is_training = is_training - self._num_classes = num_classes self._use_dropout = use_dropout self._dropout_keep_prob = dropout_keep_prob self._kernel_size = kernel_size self._class_prediction_bias_init = class_prediction_bias_init self._use_depthwise = use_depthwise - self._num_class_slots = self._num_classes + 1 + self._num_class_slots = num_class_slots self._class_predictor_layers = [] @@ -110,7 +110,7 @@ def __init__(self, tf.keras.layers.Conv2D( num_predictions_per_location * self._num_class_slots, [1, 1], name='ClassPredictor', - **conv_hyperparams.params(activation=None))) + **conv_hyperparams.params(use_bias=True))) else: self._class_predictor_layers.append( tf.keras.layers.Conv2D( @@ -120,7 +120,7 @@ def __init__(self, name='ClassPredictor', bias_initializer=tf.constant_initializer( self._class_prediction_bias_init), - **conv_hyperparams.params(activation=None))) + **conv_hyperparams.params(use_bias=True))) def _predict(self, features): """Predicts boxes. @@ -131,7 +131,7 @@ def _predict(self, features): Returns: class_predictions_with_background: A float tensor of shape - [batch_size, num_anchors, num_classes + 1] representing the class + [batch_size, num_anchors, num_class_slots] representing the class predictions for the proposals. """ # Add a slot for the background class. 
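A note on the num_class_slots change above: the class heads no longer add the background slot themselves, so the caller now decides whether an implicit background category is included (see the add_background_class option added to ssd.proto later in this patch). A minimal sketch of that relationship, using a hypothetical helper name that is not part of this change:

def compute_num_class_slots(num_classes, add_background_class=True):
  # One extra slot (index 0) is reserved for the background when the model
  # keeps an implicit background class; otherwise the head predicts exactly
  # num_classes outputs per anchor.
  return num_classes + 1 if add_background_class else num_classes

assert compute_num_class_slots(20) == 21   # legacy behavior: 21 slots
assert compute_num_class_slots(20, add_background_class=False) == 20

This is consistent with the updated head tests, which now construct the heads with num_class_slots=20 and expect a final prediction dimension of 20 rather than 21.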
diff --git a/research/object_detection/predictors/heads/keras_class_head_test.py b/research/object_detection/predictors/heads/keras_class_head_test.py index 1943c12a2c1..125fcdf8c89 100644 --- a/research/object_detection/predictors/heads/keras_class_head_test.py +++ b/research/object_detection/predictors/heads/keras_class_head_test.py @@ -45,7 +45,7 @@ def test_prediction_size_depthwise_false(self): conv_hyperparams = self._build_conv_hyperparams() class_prediction_head = keras_class_head.ConvolutionalClassHead( is_training=True, - num_classes=20, + num_class_slots=20, use_dropout=True, dropout_keep_prob=0.5, kernel_size=3, @@ -56,7 +56,7 @@ def test_prediction_size_depthwise_false(self): image_feature = tf.random_uniform( [64, 17, 19, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32) class_predictions = class_prediction_head(image_feature,) - self.assertAllEqual([64, 323, 21], + self.assertAllEqual([64, 323, 20], class_predictions.get_shape().as_list()) # TODO(kaftan): Remove conditional after CMLE moves to TF 1.10 diff --git a/research/object_detection/predictors/heads/keras_mask_head.py b/research/object_detection/predictors/heads/keras_mask_head.py index 0162da0b930..fa4b1d1efb8 100644 --- a/research/object_detection/predictors/heads/keras_mask_head.py +++ b/research/object_detection/predictors/heads/keras_mask_head.py @@ -124,7 +124,7 @@ def __init__(self, tf.keras.layers.Conv2D( num_predictions_per_location * num_mask_channels, [1, 1], name='MaskPredictor', - **conv_hyperparams.params(activation=None))) + **conv_hyperparams.params(use_bias=True))) else: self._mask_predictor_layers.append( tf.keras.layers.Conv2D( @@ -132,7 +132,7 @@ def __init__(self, [self._kernel_size, self._kernel_size], padding='SAME', name='MaskPredictor', - **conv_hyperparams.params(activation=None))) + **conv_hyperparams.params(use_bias=True))) def _predict(self, features): """Predicts boxes. diff --git a/research/object_detection/predictors/heads/mask_head.py b/research/object_detection/predictors/heads/mask_head.py index 5fc40a40c61..c97d08eb2be 100644 --- a/research/object_detection/predictors/heads/mask_head.py +++ b/research/object_detection/predictors/heads/mask_head.py @@ -23,6 +23,7 @@ import tensorflow as tf from object_detection.predictors.heads import head +from object_detection.utils import ops slim = tf.contrib.slim @@ -41,7 +42,8 @@ def __init__(self, mask_width=14, mask_prediction_num_conv_layers=2, mask_prediction_conv_depth=256, - masks_are_class_agnostic=False): + masks_are_class_agnostic=False, + convolve_then_upsample=False): """Constructor. Args: @@ -62,6 +64,10 @@ def __init__(self, image features. masks_are_class_agnostic: Boolean determining if the mask-head is class-agnostic or not. + convolve_then_upsample: Whether to apply convolutions on mask features + before upsampling using nearest neighbor resizing. Otherwise, mask + features are resized to [`mask_height`, `mask_width`] using bilinear + resizing before applying convolutions. Raises: ValueError: conv_hyperparams_fn is None. 
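The convolve_then_upsample option documented above selects between two mask-feature pathways, which the predict() hunk that follows implements. A minimal sketch of the upsampling step in each pathway (the shapes and the 2x scale are assumed example values, not part of this change, and the snippet assumes the object_detection package is importable):

import tensorflow as tf
from object_detection.utils import ops

# Hypothetical ROI-pooled mask features of shape [batch, 14, 14, depth].
features = tf.zeros([64, 14, 14, 256])

# Default pathway: bilinear-resize the features to the mask resolution first,
# then apply the prediction convolutions at that resolution.
resized_first = tf.image.resize_bilinear(
    features, [28, 28], align_corners=True)

# convolve_then_upsample=True pathway: convolve at the low ROI resolution,
# then nearest-neighbor upsample by the height/width ratios, using the
# height_scale/width_scale arguments added to ops.nearest_neighbor_upsampling
# later in this patch.
upsampled_last = ops.nearest_neighbor_upsampling(
    features, height_scale=2, width_scale=2)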
@@ -74,6 +80,7 @@ def __init__(self, self._mask_prediction_num_conv_layers = mask_prediction_num_conv_layers self._mask_prediction_conv_depth = mask_prediction_conv_depth self._masks_are_class_agnostic = masks_are_class_agnostic + self._convolve_then_upsample = convolve_then_upsample if conv_hyperparams_fn is None: raise ValueError('conv_hyperparams_fn is None.') @@ -135,17 +142,30 @@ def predict(self, features, num_predictions_per_location=1): num_conv_channels = self._get_mask_predictor_conv_depth( num_feature_channels, self._num_classes) with slim.arg_scope(self._conv_hyperparams_fn()): - upsampled_features = tf.image.resize_bilinear( - features, [self._mask_height, self._mask_width], - align_corners=True) + if not self._convolve_then_upsample: + features = tf.image.resize_bilinear( + features, [self._mask_height, self._mask_width], + align_corners=True) for _ in range(self._mask_prediction_num_conv_layers - 1): - upsampled_features = slim.conv2d( - upsampled_features, + features = slim.conv2d( + features, + num_outputs=num_conv_channels, + kernel_size=[3, 3]) + if self._convolve_then_upsample: + # Replace Transposed Convolution with a Nearest Neighbor upsampling step + # followed by 3x3 convolution. + height_scale = self._mask_height / features.shape[1].value + width_scale = self._mask_width / features.shape[2].value + features = ops.nearest_neighbor_upsampling( + features, height_scale=height_scale, width_scale=width_scale) + features = slim.conv2d( + features, num_outputs=num_conv_channels, kernel_size=[3, 3]) + num_masks = 1 if self._masks_are_class_agnostic else self._num_classes mask_predictions = slim.conv2d( - upsampled_features, + features, num_outputs=num_masks, activation_fn=None, normalizer_fn=None, diff --git a/research/object_detection/predictors/heads/mask_head_test.py b/research/object_detection/predictors/heads/mask_head_test.py index c9e4c70eff9..ae46d6ad7a0 100644 --- a/research/object_detection/predictors/heads/mask_head_test.py +++ b/research/object_detection/predictors/heads/mask_head_test.py @@ -58,6 +58,22 @@ def test_prediction_size(self): features=roi_pooled_features, num_predictions_per_location=1) self.assertAllEqual([64, 1, 20, 14, 14], prediction.get_shape().as_list()) + def test_prediction_size_with_convolve_then_upsample(self): + mask_prediction_head = mask_head.MaskRCNNMaskHead( + num_classes=20, + conv_hyperparams_fn=self._build_arg_scope_with_hyperparams(), + mask_height=28, + mask_width=28, + mask_prediction_num_conv_layers=2, + mask_prediction_conv_depth=256, + masks_are_class_agnostic=True, + convolve_then_upsample=True) + roi_pooled_features = tf.random_uniform( + [64, 14, 14, 1024], minval=-10.0, maxval=10.0, dtype=tf.float32) + prediction = mask_prediction_head.predict( + features=roi_pooled_features, num_predictions_per_location=1) + self.assertAllEqual([64, 1, 1, 28, 28], prediction.get_shape().as_list()) + class ConvolutionalMaskPredictorTest(test_case.TestCase): diff --git a/research/object_detection/protos/box_predictor.proto b/research/object_detection/protos/box_predictor.proto index d5665c94049..dfa12e8b9a0 100644 --- a/research/object_detection/protos/box_predictor.proto +++ b/research/object_detection/protos/box_predictor.proto @@ -138,6 +138,7 @@ message WeightSharedConvolutionalBoxPredictor { // TODO(alirezafathi): Refactor the proto file to be able to configure mask rcnn // head easily. +// Next id: 15 message MaskRCNNBoxPredictor { // Hyperparameters for fully connected ops used in the box predictor. 
optional Hyperparams fc_hyperparams = 1; @@ -178,6 +179,12 @@ message MaskRCNNBoxPredictor { // Whether to use one box for all classes rather than a different box for each // class. optional bool share_box_across_classes = 13 [default = false]; + + // Whether to apply convolutions on mask features before upsampling using + // nearest neighbor resizing. + // By default, mask features are resized to [`mask_height`, `mask_width`] + // before applying convolutions and predicting masks. + optional bool convolve_then_upsample_masks = 14 [default = false]; } message RfcnBoxPredictor { diff --git a/research/object_detection/protos/faster_rcnn.proto b/research/object_detection/protos/faster_rcnn.proto index fae1fc23a1f..be5a61a1710 100644 --- a/research/object_detection/protos/faster_rcnn.proto +++ b/research/object_detection/protos/faster_rcnn.proto @@ -164,6 +164,10 @@ message FasterRcnn { // Whether the masks present in groundtruth should be resized in the model to // match the image size. optional bool resize_masks = 36 [default = true]; + + // If True, uses implementation of ops with static shape guarantees when + // running evaluation (specifically not is_training if False). + optional bool use_static_shapes_for_eval = 37 [default = false]; } diff --git a/research/object_detection/protos/preprocessor.proto b/research/object_detection/protos/preprocessor.proto index 795d65684a6..f1b6c4168ec 100644 --- a/research/object_detection/protos/preprocessor.proto +++ b/research/object_detection/protos/preprocessor.proto @@ -155,6 +155,9 @@ message RandomCropImage { // value, it is removed from the new image. optional float overlap_thresh = 6 [default=0.3]; + // Whether to clip the boxes to the cropped image. + optional bool clip_boxes = 8 [default=true]; + // Probability of keeping the original image. optional float random_coef = 7 [default=0.0]; } @@ -194,6 +197,9 @@ message RandomCropPadImage { // value, it is removed from the new image. optional float overlap_thresh = 6 [default=0.3]; + // Whether to clip the boxes to the cropped image. + optional bool clip_boxes = 11 [default=true]; + // Probability of keeping the original image during the crop operation. optional float random_coef = 7 [default=0.0]; @@ -217,6 +223,9 @@ message RandomCropToAspectRatio { // ratio between a cropped bounding box and the original is less than this // value, it is removed from the new image. optional float overlap_thresh = 2 [default=0.3]; + + // Whether to clip the boxes to the cropped image. + optional bool clip_boxes = 3 [default=true]; } // Randomly adds black square patches to an image. @@ -285,6 +294,9 @@ message SSDRandomCropOperation { // Cropped box area ratio must be above this threhold to be kept. optional float overlap_thresh = 6; + // Whether to clip the boxes to the cropped image. + optional bool clip_boxes = 8 [default=true]; + // Probability a crop operation is skipped. optional float random_coef = 7; } @@ -315,6 +327,9 @@ message SSDRandomCropPadOperation { // Cropped box area ratio must be above this threhold to be kept. optional float overlap_thresh = 6; + // Whether to clip the boxes to the cropped image. + optional bool clip_boxes = 13 [default=true]; + // Probability a crop operation is skipped. optional float random_coef = 7; @@ -353,6 +368,9 @@ message SSDRandomCropFixedAspectRatioOperation { // Cropped box area ratio must be above this threhold to be kept. optional float overlap_thresh = 6; + // Whether to clip the boxes to the cropped image. 
+ optional bool clip_boxes = 8 [default=true]; + // Probability a crop operation is skipped. optional float random_coef = 7; } @@ -387,6 +405,9 @@ message SSDRandomCropPadFixedAspectRatioOperation { // Cropped box area ratio must be above this threhold to be kept. optional float overlap_thresh = 6; + // Whether to clip the boxes to the cropped image. + optional bool clip_boxes = 8 [default=true]; + // Probability a crop operation is skipped. optional float random_coef = 7; } diff --git a/research/object_detection/protos/ssd.proto b/research/object_detection/protos/ssd.proto index 03fe2842503..f33f0d2c7bf 100644 --- a/research/object_detection/protos/ssd.proto +++ b/research/object_detection/protos/ssd.proto @@ -12,7 +12,7 @@ import "object_detection/protos/post_processing.proto"; import "object_detection/protos/region_similarity_calculator.proto"; // Configuration for Single Shot Detection (SSD) models. -// Next id: 21 +// Next id: 22 message Ssd { // Number of classes to predict. @@ -92,11 +92,17 @@ message Ssd { // Minimum number of effective negative samples. // Only applies if use_expected_classification_loss_under_sampling is true. - optional float minimum_negative_sampling = 19 [default=0]; + optional float min_num_negative_samples = 19 [default=0]; // Desired number of effective negative samples per positive sample. // Only applies if use_expected_classification_loss_under_sampling is true. optional float desired_negative_sampling_ratio = 20 [default=3]; + + // Whether to add an implicit background class to one-hot encodings of + // groundtruth labels. Set to false if using groundtruth labels with an + // explicit background class, using multiclass scores, or if training a single + // class model. + optional bool add_background_class = 21 [default = true]; } diff --git a/research/object_detection/protos/train.proto b/research/object_detection/protos/train.proto index 73a01781295..dcd4df241e2 100644 --- a/research/object_detection/protos/train.proto +++ b/research/object_detection/protos/train.proto @@ -6,7 +6,7 @@ import "object_detection/protos/optimizer.proto"; import "object_detection/protos/preprocessor.proto"; // Message for configuring DetectionModel training jobs (train.py). -// Next id: 27 +// Next id: 28 message TrainConfig { // Effective batch size to use for training. // For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be @@ -115,4 +115,7 @@ message TrainConfig { // Whether to use bfloat16 for training. optional bool use_bfloat16 = 26 [default=false]; + + // Whether to summarize gradients. + optional bool summarize_gradients = 27 [default=false]; } diff --git a/research/object_detection/samples/configs/facessd_mobilenet_v2_quantized_320x320_open_image_v4.config b/research/object_detection/samples/configs/facessd_mobilenet_v2_quantized_320x320_open_image_v4.config new file mode 100644 index 00000000000..8cf59027b56 --- /dev/null +++ b/research/object_detection/samples/configs/facessd_mobilenet_v2_quantized_320x320_open_image_v4.config @@ -0,0 +1,211 @@ +# Quantized trained SSD with Mobilenet v2 on Open Images v4. +# Non-face boxes are dropped during training and non-face groundtruth boxes are +# ignored when evaluating. +# +# Users should configure the fine_tune_checkpoint field in the train config as +# well as the label_map_path and input_path fields in the train_input_reader and +# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that +# should be configured. 
+ +model { + ssd { + num_classes: 1 + image_resizer { + fixed_shape_resizer { + height: 320 + width: 320 + } + } + feature_extractor { + type: "ssd_mobilenet_v2" + depth_multiplier: 1.0 + min_depth: 16 + conv_hyperparams { + regularizer { + l2_regularizer { + weight: 4.0e-05 + } + } + initializer { + truncated_normal_initializer { + mean: 0.0 + stddev: 0.03 + } + } + activation: RELU_6 + batch_norm { + decay: 0.9997 + center: true + scale: true + epsilon: 0.001 + train: true + } + } + pad_to_multiple: 32 + use_explicit_padding: true + } + box_coder { + faster_rcnn_box_coder { + y_scale: 10.0 + x_scale: 10.0 + height_scale: 5.0 + width_scale: 5.0 + } + } + matcher { + argmax_matcher { + matched_threshold: 0.5 + unmatched_threshold: 0.5 + ignore_thresholds: false + negatives_lower_than_unmatched: true + force_match_for_each_row: true + } + } + similarity_calculator { + iou_similarity { + } + } + box_predictor { + convolutional_box_predictor { + conv_hyperparams { + regularizer { + l2_regularizer { + weight: 4.0e-05 + } + } + initializer { + truncated_normal_initializer { + mean: 0.0 + stddev: 0.03 + } + } + activation: RELU_6 + batch_norm { + decay: 0.9997 + center: true + scale: true + epsilon: 0.001 + train: true + } + } + min_depth: 0 + max_depth: 0 + num_layers_before_predictor: 0 + use_dropout: false + kernel_size: 3 + box_code_size: 4 + apply_sigmoid_to_scores: false + } + } + anchor_generator { + ssd_anchor_generator { + num_layers: 6 + min_scale: 0.2 + max_scale: 0.95 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + aspect_ratios: 3.0 + aspect_ratios: 0.3333 + height_stride: 16 + height_stride: 32 + height_stride: 64 + height_stride: 128 + height_stride: 256 + height_stride: 512 + width_stride: 16 + width_stride: 32 + width_stride: 64 + width_stride: 128 + width_stride: 256 + width_stride: 512 + } + } + post_processing { + batch_non_max_suppression { + score_threshold: 1.0e-08 + iou_threshold: 0.5 + max_detections_per_class: 100 + max_total_detections: 100 + } + score_converter: SIGMOID + } + normalize_loss_by_num_matches: true + loss { + localization_loss { + weighted_smooth_l1 { + } + } + classification_loss { + weighted_sigmoid { + } + } + hard_example_miner { + num_hard_examples: 3000 + iou_threshold: 0.99 + loss_type: CLASSIFICATION + max_negatives_per_positive: 3 + min_negatives_per_image: 10 + } + classification_weight: 1.0 + localization_weight: 1.0 + } + } +} +train_config { + batch_size: 32 + data_augmentation_options { + random_horizontal_flip { + keypoint_flip_permutation: 1 + keypoint_flip_permutation: 0 + keypoint_flip_permutation: 2 + keypoint_flip_permutation: 3 + keypoint_flip_permutation: 5 + keypoint_flip_permutation: 4 + } + } + data_augmentation_options { + ssd_random_crop_fixed_aspect_ratio { + } + } + optimizer { + rms_prop_optimizer { + learning_rate { + exponential_decay_learning_rate { + initial_learning_rate: 0.004 + decay_steps: 800720 + decay_factor: 0.95 + } + } + momentum_optimizer_value: 0.9 + decay: 0.9 + epsilon: 1.0 + } + } + fine_tune_checkpoint: "" +} +train_input_reader { + label_map_path: "PATH_TO_BE_CONFIGURED/face_label_map.pbtxt" + tf_record_input_reader { + input_path: "PATH_TO_BE_CONFIGURED/face_train.record-?????-of-00100" + } +} +eval_config { + metrics_set: "coco_detection_metrics" + use_moving_averages: true +} +eval_input_reader { + label_map_path: "PATH_TO_BE_CONFIGURED/face_label_map.pbtxt" + shuffle: false + num_readers: 1 + tf_record_input_reader { + input_path: "PATH_TO_BE_CONFIGURED/face_val.record-?????-of-00010" 
+ } +} +graph_rewriter { + quantization { + delay: 500000 + weight_bits: 8 + activation_bits: 8 + } +} diff --git a/research/object_detection/samples/configs/ssd_mobilenet_v2_quantized_300x300_coco.config b/research/object_detection/samples/configs/ssd_mobilenet_v2_quantized_300x300_coco.config new file mode 100644 index 00000000000..e8d8f9d564f --- /dev/null +++ b/research/object_detection/samples/configs/ssd_mobilenet_v2_quantized_300x300_coco.config @@ -0,0 +1,202 @@ +# Quantized trained SSD with Mobilenet v2 on MSCOCO Dataset. +# Users should configure the fine_tune_checkpoint field in the train config as +# well as the label_map_path and input_path fields in the train_input_reader and +# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that +# should be configured. + +model { + ssd { + num_classes: 90 + box_coder { + faster_rcnn_box_coder { + y_scale: 10.0 + x_scale: 10.0 + height_scale: 5.0 + width_scale: 5.0 + } + } + matcher { + argmax_matcher { + matched_threshold: 0.5 + unmatched_threshold: 0.5 + ignore_thresholds: false + negatives_lower_than_unmatched: true + force_match_for_each_row: true + } + } + similarity_calculator { + iou_similarity { + } + } + anchor_generator { + ssd_anchor_generator { + num_layers: 6 + min_scale: 0.2 + max_scale: 0.95 + aspect_ratios: 1.0 + aspect_ratios: 2.0 + aspect_ratios: 0.5 + aspect_ratios: 3.0 + aspect_ratios: 0.3333 + } + } + image_resizer { + fixed_shape_resizer { + height: 300 + width: 300 + } + } + box_predictor { + convolutional_box_predictor { + min_depth: 0 + max_depth: 0 + num_layers_before_predictor: 0 + use_dropout: false + dropout_keep_probability: 0.8 + kernel_size: 1 + box_code_size: 4 + apply_sigmoid_to_scores: false + conv_hyperparams { + activation: RELU_6, + regularizer { + l2_regularizer { + weight: 0.00004 + } + } + initializer { + truncated_normal_initializer { + stddev: 0.03 + mean: 0.0 + } + } + batch_norm { + train: true, + scale: true, + center: true, + decay: 0.9997, + epsilon: 0.001, + } + } + } + } + feature_extractor { + type: 'ssd_mobilenet_v2' + min_depth: 16 + depth_multiplier: 1.0 + conv_hyperparams { + activation: RELU_6, + regularizer { + l2_regularizer { + weight: 0.00004 + } + } + initializer { + truncated_normal_initializer { + stddev: 0.03 + mean: 0.0 + } + } + batch_norm { + train: true, + scale: true, + center: true, + decay: 0.9997, + epsilon: 0.001, + } + } + } + loss { + classification_loss { + weighted_sigmoid { + } + } + localization_loss { + weighted_smooth_l1 { + } + } + hard_example_miner { + num_hard_examples: 3000 + iou_threshold: 0.99 + loss_type: CLASSIFICATION + max_negatives_per_positive: 3 + min_negatives_per_image: 3 + } + classification_weight: 1.0 + localization_weight: 1.0 + } + normalize_loss_by_num_matches: true + post_processing { + batch_non_max_suppression { + score_threshold: 1e-8 + iou_threshold: 0.6 + max_detections_per_class: 100 + max_total_detections: 100 + } + score_converter: SIGMOID + } + } +} + +train_config: { + batch_size: 24 + optimizer { + rms_prop_optimizer: { + learning_rate: { + exponential_decay_learning_rate { + initial_learning_rate: 0.004 + decay_steps: 800720 + decay_factor: 0.95 + } + } + momentum_optimizer_value: 0.9 + decay: 0.9 + epsilon: 1.0 + } + } + fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt" + fine_tune_checkpoint_type: "detection" + # Note: The below line limits the training process to 200K steps, which we + # empirically found to be sufficient enough to train the pets dataset. 
This + # effectively bypasses the learning rate schedule (the learning rate will + # never decay). Remove the below line to train indefinitely. + num_steps: 200000 + data_augmentation_options { + random_horizontal_flip { + } + } + data_augmentation_options { + ssd_random_crop { + } + } +} + +train_input_reader: { + tf_record_input_reader { + input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record-?????-of-00100" + } + label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt" +} + +eval_config: { + num_examples: 8000 + # Note: The below line limits the evaluation process to 10 evaluations. + # Remove the below line to evaluate indefinitely. + max_evals: 10 +} + +eval_input_reader: { + tf_record_input_reader { + input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record-?????-of-00010" + } + label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt" + shuffle: false + num_readers: 1 +} + +graph_rewriter { + quantization { + delay: 48000 + weight_bits: 8 + activation_bits: 8 + } +} \ No newline at end of file diff --git a/research/object_detection/utils/config_util.py b/research/object_detection/utils/config_util.py index 9a136d80ee7..f8103a33ad8 100644 --- a/research/object_detection/utils/config_util.py +++ b/research/object_detection/utils/config_util.py @@ -76,12 +76,14 @@ def get_spatial_image_size(image_resizer_config): raise ValueError("Unknown image resizer type.") -def get_configs_from_pipeline_file(pipeline_config_path): +def get_configs_from_pipeline_file(pipeline_config_path, config_override=None): """Reads config from a file containing pipeline_pb2.TrainEvalPipelineConfig. Args: pipeline_config_path: Path to pipeline_pb2.TrainEvalPipelineConfig text proto. + config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to + override pipeline_config_path. Returns: Dictionary of configuration objects. Keys are `model`, `train_config`, @@ -92,6 +94,8 @@ def get_configs_from_pipeline_file(pipeline_config_path): with tf.gfile.GFile(pipeline_config_path, "r") as f: proto_str = f.read() text_format.Merge(proto_str, pipeline_config) + if config_override: + text_format.Merge(config_override, pipeline_config) return create_configs_from_pipeline_proto(pipeline_config) @@ -430,7 +434,7 @@ def merge_external_params_with_configs(configs, hparams=None, kwargs_dict=None): final learning rates. In this case key can be one of the following formats: 1. legacy update: single string that indicates the attribute to be - updated. E.g. 'lable_map_path', 'eval_input_path', 'shuffle'. + updated. E.g. 'label_map_path', 'eval_input_path', 'shuffle'. Note that when updating fields (e.g. eval_input_path, eval_shuffle) in eval_input_configs, the override will only be applied when eval_input_configs has exactly 1 element. diff --git a/research/object_detection/utils/object_detection_evaluation.py b/research/object_detection/utils/object_detection_evaluation.py index 5826c58114d..d65c69fb291 100644 --- a/research/object_detection/utils/object_detection_evaluation.py +++ b/research/object_detection/utils/object_detection_evaluation.py @@ -633,11 +633,37 @@ def __init__(self, nms_max_output_boxes=10000, use_weighted_mean_ap=False, label_id_offset=0, - group_of_weight=0.0): + group_of_weight=0.0, + per_image_eval_class=per_image_evaluation.PerImageEvaluation): + """Constructor. + + Args: + num_groundtruth_classes: Number of ground-truth classes. + matching_iou_threshold: IOU threshold used for matching detected boxes + to ground-truth boxes. 
+ nms_iou_threshold: IOU threshold used for non-maximum suppression. + nms_max_output_boxes: Maximum number of boxes returned by non-maximum + suppression. + use_weighted_mean_ap: (optional) boolean which determines if the mean + average precision is computed directly from the scores and tp_fp_labels + of all classes. + label_id_offset: The label id offset. + group_of_weight: Weight of group-of boxes.If set to 0, detections of the + correct class within a group-of box are ignored. If weight is > 0, then + if at least one detection falls within a group-of box with + matching_iou_threshold, weight group_of_weight is added to true + positives. Consequently, if no detection falls within a group-of box, + weight group_of_weight is added to false negatives. + per_image_eval_class: The class that contains functions for computing + per image metrics. + + Raises: + ValueError: if num_groundtruth_classes is smaller than 1. + """ if num_groundtruth_classes < 1: raise ValueError('Need at least 1 groundtruth class for evaluation.') - self.per_image_eval = per_image_evaluation.PerImageEvaluation( + self.per_image_eval = per_image_eval_class( num_groundtruth_classes=num_groundtruth_classes, matching_iou_threshold=matching_iou_threshold, nms_iou_threshold=nms_iou_threshold, @@ -659,14 +685,16 @@ def __init__(self, self._initialize_detections() def _initialize_detections(self): + """Initializes internal data structures.""" self.detection_keys = set() self.scores_per_class = [[] for _ in range(self.num_class)] self.tp_fp_labels_per_class = [[] for _ in range(self.num_class)] self.num_images_correctly_detected_per_class = np.zeros(self.num_class) self.average_precision_per_class = np.empty(self.num_class, dtype=float) self.average_precision_per_class.fill(np.nan) - self.precisions_per_class = [] - self.recalls_per_class = [] + self.precisions_per_class = [np.nan] * self.num_class + self.recalls_per_class = [np.nan] * self.num_class + self.corloc_per_class = np.ones(self.num_class, dtype=float) def clear_detections(self): @@ -867,8 +895,8 @@ def evaluate(self): logging.info(scores) precision, recall = metrics.compute_precision_recall( scores, tp_fp_labels, self.num_gt_instances_per_class[class_index]) - self.precisions_per_class.append(precision) - self.recalls_per_class.append(recall) + self.precisions_per_class[class_index] = precision + self.recalls_per_class[class_index] = recall average_precision = metrics.compute_average_precision(precision, recall) self.average_precision_per_class[class_index] = average_precision diff --git a/research/object_detection/utils/ops.py b/research/object_detection/utils/ops.py index 4d879247456..ac557c0d23e 100644 --- a/research/object_detection/utils/ops.py +++ b/research/object_detection/utils/ops.py @@ -872,7 +872,8 @@ def map_box_encodings(i): merged_box_indices) -def nearest_neighbor_upsampling(input_tensor, scale): +def nearest_neighbor_upsampling(input_tensor, scale=None, height_scale=None, + width_scale=None): """Nearest neighbor upsampling implementation. Nearest neighbor upsampling function that maps input tensor with shape @@ -883,19 +884,33 @@ def nearest_neighbor_upsampling(input_tensor, scale): Args: input_tensor: A float32 tensor of size [batch, height_in, width_in, channels]. - scale: An integer multiple to scale resolution of input data. + scale: An integer multiple to scale resolution of input data in both height + and width dimensions. + height_scale: An integer multiple to scale the height of input image. 
This + option when provided overrides `scale` option. + width_scale: An integer multiple to scale the width of input image. This + option when provided overrides `scale` option. Returns: data_up: A float32 tensor of size [batch, height_in*scale, width_in*scale, channels]. + + Raises: + ValueError: If both scale and height_scale or if both scale and width_scale + are None. """ + if not scale and (height_scale is None or width_scale is None): + raise ValueError('Provide either `scale` or `height_scale` and' + ' `width_scale`.') with tf.name_scope('nearest_neighbor_upsampling'): + h_scale = scale if height_scale is None else height_scale + w_scale = scale if width_scale is None else width_scale (batch_size, height, width, channels) = shape_utils.combined_static_and_dynamic_shape(input_tensor) output_tensor = tf.reshape( input_tensor, [batch_size, height, 1, width, 1, channels]) * tf.ones( - [1, 1, scale, 1, scale, 1], dtype=input_tensor.dtype) + [1, 1, h_scale, 1, w_scale, 1], dtype=input_tensor.dtype) return tf.reshape(output_tensor, - [batch_size, height * scale, width * scale, channels]) + [batch_size, height * h_scale, width * w_scale, channels]) def matmul_gather_on_zeroth_axis(params, indices, scope=None): @@ -1072,29 +1087,35 @@ def get_box_inds(proposals): return tf.reshape(cropped_regions, final_shape) -def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses, - desired_negative_sampling_ratio, - minimum_negative_sampling): +def expected_classification_loss_under_sampling( + batch_cls_targets, cls_losses, unmatched_cls_losses, + desired_negative_sampling_ratio, min_num_negative_samples): """Computes classification loss by background/foreground weighting. The weighting is such that the effective background/foreground weight ratio is the desired_negative_sampling_ratio. if p_i is the foreground probability - of anchor a_i, L(a_i) is the anchors loss, N is the number of anchors, and M - is the sum of foreground probabilities across anchors, then the total loss L - is calculated as: + of anchor a_i, L(a_i) is the anchors loss, N is the number of anchors, M + is the sum of foreground probabilities across anchors, and K is the desired + ratio between the number of negative and positive samples, then the total loss + L is calculated as: beta = K*M/(N-M) - L = sum_{i=1}^N [p_i + beta * (1 - p_i)] * (L(a_i)) + L = sum_{i=1}^N [p_i * L_p(a_i) + beta * (1 - p_i) * L_n(a_i)] + where L_p(a_i) is the loss against target assuming the anchor was matched, + otherwise zero, and L_n(a_i) is the loss against the background target + assuming the anchor was unmatched, otherwise zero. Args: - batch_cls_targets: A tensor with shape [batch_size, num_anchors, - num_classes + 1], where 0'th index is the background class, containing - the class distrubution for the target assigned to a given anchor. - cls_losses: Float tensor of shape [batch_size, num_anchors] - representing anchorwise classification losses. + batch_cls_targets: A tensor with shape [batch_size, num_anchors, num_classes + + 1], where 0'th index is the background class, containing the class + distrubution for the target assigned to a given anchor. + cls_losses: Float tensor of shape [batch_size, num_anchors] representing + anchorwise classification losses. + unmatched_cls_losses: loss for each anchor against the unmatched class + target. desired_negative_sampling_ratio: The desired background/foreground weight ratio. - minimum_negative_sampling: Minimum number of effective negative samples. 
+ min_num_negative_samples: Minimum number of effective negative samples. Used only when there are no positive examples. Returns: @@ -1103,36 +1124,44 @@ def expected_classification_loss_under_sampling(batch_cls_targets, cls_losses, num_anchors = tf.cast(tf.shape(batch_cls_targets)[1], tf.float32) # find the p_i - foreground_probabilities = ( - foreground_probabilities_from_targets(batch_cls_targets)) + foreground_probabilities = 1 - batch_cls_targets[:, :, 0] + foreground_sum = tf.reduce_sum(foreground_probabilities, axis=-1) + # for each anchor, expected_j is the expected number of positive anchors + # given that this anchor was sampled as negative. + tiled_foreground_sum = tf.tile( + tf.reshape(foreground_sum, [-1, 1]), + [1, tf.cast(num_anchors, tf.int32)]) + expected_j = tiled_foreground_sum - foreground_probabilities + k = desired_negative_sampling_ratio # compute beta - denominators = (num_anchors - foreground_sum) - beta = tf.where( - tf.equal(denominators, 0), tf.zeros_like(foreground_sum), - k * foreground_sum / denominators) + expected_negatives = tf.to_float(num_anchors) - expected_j + desired_negatives = k * expected_j + desired_negatives = tf.where( + tf.greater(desired_negatives, expected_negatives), expected_negatives, + desired_negatives) + + # probability that an anchor is sampled for the loss computation given that it + # is negative. + beta = desired_negatives / expected_negatives # where the foreground sum is zero, use a minimum negative weight. - min_negative_weight = 1.0 * minimum_negative_sampling / num_anchors + min_negative_weight = 1.0 * min_num_negative_samples / num_anchors beta = tf.where( - tf.equal(foreground_sum, 0), min_negative_weight * tf.ones_like(beta), - beta) - beta = tf.reshape(beta, [-1, 1]) + tf.equal(tiled_foreground_sum, 0), + min_negative_weight * tf.ones_like(beta), beta) - cls_loss_weights = foreground_probabilities + ( - 1 - foreground_probabilities) * beta + foreground_weights = foreground_probabilities + background_weights = (1 - foreground_weights) * beta - weighted_losses = cls_loss_weights * cls_losses + weighted_foreground_losses = foreground_weights * cls_losses + weighted_background_losses = background_weights * unmatched_cls_losses - cls_losses = tf.reduce_sum(weighted_losses, axis=-1) + cls_losses = tf.reduce_sum( + weighted_foreground_losses, axis=-1) + tf.reduce_sum( + weighted_background_losses, axis=-1) return cls_losses - - -def foreground_probabilities_from_targets(batch_cls_targets): - foreground_probabilities = 1 - batch_cls_targets[:, :, 0] - - return foreground_probabilities diff --git a/research/object_detection/utils/ops_test.py b/research/object_detection/utils/ops_test.py index 6af3f99d00c..c1b2b4e10cb 100644 --- a/research/object_detection/utils/ops_test.py +++ b/research/object_detection/utils/ops_test.py @@ -1222,7 +1222,7 @@ def testMergeBoxesWithEmptyInputs(self): class NearestNeighborUpsamplingTest(test_case.TestCase): - def test_upsampling(self): + def test_upsampling_with_single_scale(self): def graph_fn(inputs): custom_op_output = ops.nearest_neighbor_upsampling(inputs, scale=2) @@ -1236,6 +1236,22 @@ def graph_fn(inputs): [[2], [2], [3], [3]]]] self.assertAllClose(custom_op_output, expected_output) + def test_upsampling_with_separate_height_width_scales(self): + + def graph_fn(inputs): + custom_op_output = ops.nearest_neighbor_upsampling(inputs, + height_scale=2, + width_scale=3) + return custom_op_output + inputs = np.reshape(np.arange(4).astype(np.float32), [1, 2, 2, 1]) + custom_op_output = 
self.execute(graph_fn, [inputs]) + + expected_output = [[[[0], [0], [0], [1], [1], [1]], + [[0], [0], [0], [1], [1], [1]], + [[2], [2], [2], [3], [3], [3]], + [[2], [2], [2], [3], [3], [3]]]] + self.assertAllClose(custom_op_output, expected_output) + class MatmulGatherOnZerothAxis(test_case.TestCase): @@ -1454,78 +1470,182 @@ class OpsTestExpectedClassificationLoss(test_case.TestCase): def testExpectedClassificationLossUnderSamplingWithHardLabels(self): - def graph_fn(batch_cls_targets, cls_losses, negative_to_positive_ratio, - minimum_negative_sampling): + def graph_fn(batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples): return ops.expected_classification_loss_under_sampling( - batch_cls_targets, cls_losses, negative_to_positive_ratio, - minimum_negative_sampling) + batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples) batch_cls_targets = np.array( [[[1., 0, 0], [0, 1., 0]], [[1., 0, 0], [0, 1., 0]]], dtype=np.float32) cls_losses = np.array([[1, 2], [3, 4]], dtype=np.float32) + unmatched_cls_losses = np.array([[10, 20], [30, 40]], dtype=np.float32) negative_to_positive_ratio = np.array([2], dtype=np.float32) - minimum_negative_sampling = np.array([1], dtype=np.float32) + min_num_negative_samples = np.array([1], dtype=np.float32) classification_loss = self.execute(graph_fn, [ - batch_cls_targets, cls_losses, negative_to_positive_ratio, - minimum_negative_sampling + batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples ]) - # expected_foregorund_sum = [1,1] - # expected_beta = [2,2] - # expected_cls_loss_weights = [2,1],[2,1] - # expected_classification_loss_under_sampling = [2*1+1*2, 2*3+1*4] - expected_classification_loss_under_sampling = [2 + 2, 6 + 4] + # expected_foreground_sum = [1,1] + # expected_expected_j = [[1, 0], [1, 0]] + # expected_expected_negatives = [[1, 2], [1, 2]] + # expected_desired_negatives = [[2, 0], [2, 0]] + # expected_beta = [[1, 0], [1, 0]] + # expected_foreground_weights = [[0, 1], [0, 1]] + # expected_background_weights = [[1, 0], [1, 0]] + # expected_weighted_foreground_losses = [[0, 2], [0, 4]] + # expected_weighted_background_losses = [[10, 0], [30, 0]] + # expected_classification_loss_under_sampling = [6, 40] + expected_classification_loss_under_sampling = [2 + 10, 4 + 30] + + self.assertAllClose(expected_classification_loss_under_sampling, + classification_loss) + + def testExpectedClassificationLossUnderSamplingWithHardLabelsMoreNegatives( + self): + + def graph_fn(batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples): + return ops.expected_classification_loss_under_sampling( + batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples) + + batch_cls_targets = np.array( + [[[1., 0, 0], [0, 1., 0], [1., 0, 0], [1., 0, 0], [1., 0, 0]]], + dtype=np.float32) + cls_losses = np.array([[1, 2, 3, 4, 5]], dtype=np.float32) + unmatched_cls_losses = np.array([[10, 20, 30, 40, 50]], dtype=np.float32) + negative_to_positive_ratio = np.array([2], dtype=np.float32) + min_num_negative_samples = np.array([1], dtype=np.float32) + + classification_loss = self.execute(graph_fn, [ + batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples + ]) + + # expected_foreground_sum = [1] + # expected_expected_j = [[1, 0, 1, 1, 1]] + # expected_expected_negatives = [[4, 
5, 4, 4, 4]] + # expected_desired_negatives = [[2, 0, 2, 2, 2]] + # expected_beta = [[.5, 0, .5, .5, .5]] + # expected_foreground_weights = [[0, 1, 0, 0, 0]] + # expected_background_weights = [[.5, 0, .5, .5, .5]] + # expected_weighted_foreground_losses = [[0, 2, 0, 0, 0]] + # expected_weighted_background_losses = [[10*.5, 0, 30*.5, 40*.5, 50*.5]] + # expected_classification_loss_under_sampling = [5+2+15+20+25] + expected_classification_loss_under_sampling = [5 + 2 + 15 + 20 + 25] self.assertAllClose(expected_classification_loss_under_sampling, classification_loss) def testExpectedClassificationLossUnderSamplingWithAllNegative(self): - def graph_fn(batch_cls_targets, cls_losses): + def graph_fn(batch_cls_targets, cls_losses, unmatched_cls_losses): return ops.expected_classification_loss_under_sampling( - batch_cls_targets, cls_losses, negative_to_positive_ratio, - minimum_negative_sampling) + batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples) batch_cls_targets = np.array( [[[1, 0, 0], [1, 0, 0]], [[1, 0, 0], [1, 0, 0]]], dtype=np.float32) cls_losses = np.array([[1, 2], [3, 4]], dtype=np.float32) + unmatched_cls_losses = np.array([[10, 20], [30, 40]], dtype=np.float32) negative_to_positive_ratio = np.array([2], dtype=np.float32) - minimum_negative_sampling = np.array([1], dtype=np.float32) - - classification_loss = self.execute(graph_fn, - [batch_cls_targets, cls_losses]) - - # expected_foregorund_sum = [0,0] - # expected_beta = [0.5,0.5] - # expected_cls_loss_weights = [0.5,0.5],[0.5,0.5] - # expected_classification_loss_under_sampling = [.5*1+.5*2, .5*3+.5*4] - expected_classification_loss_under_sampling = [1.5, 3.5] + min_num_negative_samples = np.array([1], dtype=np.float32) + + classification_loss = self.execute( + graph_fn, [batch_cls_targets, cls_losses, unmatched_cls_losses]) + + # expected_foreground_sum = [0,0] + # expected_expected_j = [[0, 0], [0, 0]] + # expected_expected_negatives = [[2, 2], [2, 2]] + # expected_desired_negatives = [[0, 0], [0, 0]] + # expected_beta = [[0, 0],[0, 0]] + # expected_foreground_weights = [[0, 0], [0, 0]] + # expected_background_weights = [[.5, .5], [.5, .5]] + # expected_weighted_foreground_losses = [[0, 0], [0, 0]] + # expected_weighted_background_losses = [[5, 10], [15, 20]] + # expected_classification_loss_under_sampling = [15, 35] + expected_classification_loss_under_sampling = [ + 10 * .5 + 20 * .5, 30 * .5 + 40 * .5 + ] self.assertAllClose(expected_classification_loss_under_sampling, classification_loss) def testExpectedClassificationLossUnderSamplingWithAllPositive(self): - def graph_fn(batch_cls_targets, cls_losses): + def graph_fn(batch_cls_targets, cls_losses, unmatched_cls_losses): return ops.expected_classification_loss_under_sampling( - batch_cls_targets, cls_losses, negative_to_positive_ratio, - minimum_negative_sampling) + batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples) batch_cls_targets = np.array( [[[0, 1., 0], [0, 1., 0]], [[0, 1, 0], [0, 0, 1]]], dtype=np.float32) cls_losses = np.array([[1, 2], [3, 4]], dtype=np.float32) + unmatched_cls_losses = np.array([[10, 20], [30, 40]], dtype=np.float32) negative_to_positive_ratio = np.array([2], dtype=np.float32) - minimum_negative_sampling = np.array([1], dtype=np.float32) + min_num_negative_samples = np.array([1], dtype=np.float32) + + classification_loss = self.execute( + graph_fn, [batch_cls_targets, cls_losses, unmatched_cls_losses]) + + # expected_foreground_sum = 
[2,2] + # expected_expected_j = [[1, 1], [1, 1]] + # expected_expected_negatives = [[1, 1], [1, 1]] + # expected_desired_negatives = [[1, 1], [1, 1]] + # expected_beta = [[1, 1],[1, 1]] + # expected_foreground_weights = [[1, 1], [1, 1]] + # expected_background_weights = [[0, 0], [0, 0]] + # expected_weighted_foreground_losses = [[1, 2], [3, 4]] + # expected_weighted_background_losses = [[0, 0], [0, 0]] + # expected_classification_loss_under_sampling = [15, 35] + expected_classification_loss_under_sampling = [1 + 2, 3 + 4] - classification_loss = self.execute(graph_fn, - [batch_cls_targets, cls_losses]) + self.assertAllClose(expected_classification_loss_under_sampling, + classification_loss) - # expected_foregorund_sum = [2,2] - # expected_beta = [0,0] - # expected_cls_loss_weights = [1,1],[1,1] - # expected_classification_loss_under_sampling = [1*1+1*2, 1*3+1*4] - expected_classification_loss_under_sampling = [1 + 2, 3 + 4] + def testExpectedClassificationLossUnderSamplingWithSoftLabels(self): + + def graph_fn(batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples): + return ops.expected_classification_loss_under_sampling( + batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples) + + batch_cls_targets = np.array([[[.75, .25, 0], [0.25, .75, 0], [.75, .25, 0], + [0.25, .75, 0], [1., 0, 0]]], + dtype=np.float32) + cls_losses = np.array([[1, 2, 3, 4, 5]], dtype=np.float32) + unmatched_cls_losses = np.array([[10, 20, 30, 40, 50]], dtype=np.float32) + negative_to_positive_ratio = np.array([2], dtype=np.float32) + min_num_negative_samples = np.array([1], dtype=np.float32) + + classification_loss = self.execute(graph_fn, [ + batch_cls_targets, cls_losses, unmatched_cls_losses, + negative_to_positive_ratio, min_num_negative_samples + ]) + + # expected_foreground_sum = [2] + # expected_expected_j = [[1.75, 1.25, 1.75, 1.25, 2]] + # expected_expected_negatives = [[3.25, 3.75, 3.25, 3.75, 3]] + # expected_desired_negatives = [[3.25, 2.5, 3.25, 2.5, 3]] + # expected_beta = [[1, 2/3, 1, 2/3, 1]] + # expected_foreground_weights = [[0.25, .75, .25, .75, 0]] + # expected_background_weights = [[[.75, 1/6., .75, 1/6., 1]]] + # expected_weighted_foreground_losses = [[.25*1, .75*2, .25*3, .75*4, 0*5]] + # expected_weighted_background_losses = [[ + # .75*10, 1/6.*20, .75*30, 1/6.*40, 1*50]] + # expected_classification_loss_under_sampling = sum([ + # .25*1, .75*2, .25*3, .75*4, 0, .75*10, 1/6.*20, .75*30, + # 1/6.*40, 1*50]) + expected_classification_loss_under_sampling = [ + sum([ + .25 * 1, .75 * 2, .25 * 3, .75 * 4, 0, .75 * 10, 1 / 6. * 20, + .75 * 30, 1 / 6. 
* 40, 1 * 50 + ]) + ] self.assertAllClose(expected_classification_loss_under_sampling, classification_loss) diff --git a/research/object_detection/utils/test_utils.py b/research/object_detection/utils/test_utils.py index 16de6176817..d165e3ad3ba 100644 --- a/research/object_detection/utils/test_utils.py +++ b/research/object_detection/utils/test_utils.py @@ -45,8 +45,10 @@ def _decode(self, rel_codes, anchors): class MockBoxPredictor(box_predictor.BoxPredictor): """Simple box predictor that ignores inputs and outputs all zeros.""" - def __init__(self, is_training, num_classes, predict_mask=False): + def __init__(self, is_training, num_classes, add_background_class=True, + predict_mask=False): super(MockBoxPredictor, self).__init__(is_training, num_classes) + self._add_background_class = add_background_class self._predict_mask = predict_mask def _predict(self, image_features, num_predictions_per_location): @@ -57,10 +59,13 @@ def _predict(self, image_features, num_predictions_per_location): num_anchors = (combined_feature_shape[1] * combined_feature_shape[2]) code_size = 4 zero = tf.reduce_sum(0 * image_feature) + num_class_slots = self.num_classes + if self._add_background_class: + num_class_slots = num_class_slots + 1 box_encodings = zero + tf.zeros( (batch_size, num_anchors, 1, code_size), dtype=tf.float32) class_predictions_with_background = zero + tf.zeros( - (batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32) + (batch_size, num_anchors, num_class_slots), dtype=tf.float32) masks = zero + tf.zeros( (batch_size, num_anchors, self.num_classes, DEFAULT_MASK_SIZE, DEFAULT_MASK_SIZE), @@ -80,9 +85,11 @@ def _predict(self, image_features, num_predictions_per_location): class MockKerasBoxPredictor(box_predictor.KerasBoxPredictor): """Simple box predictor that ignores inputs and outputs all zeros.""" - def __init__(self, is_training, num_classes, predict_mask=False): + def __init__(self, is_training, num_classes, add_background_class=True, + predict_mask=False): super(MockKerasBoxPredictor, self).__init__( is_training, num_classes, False, False) + self._add_background_class = add_background_class self._predict_mask = predict_mask def _predict(self, image_features, **kwargs): @@ -93,10 +100,13 @@ def _predict(self, image_features, **kwargs): num_anchors = (combined_feature_shape[1] * combined_feature_shape[2]) code_size = 4 zero = tf.reduce_sum(0 * image_feature) + num_class_slots = self.num_classes + if self._add_background_class: + num_class_slots = num_class_slots + 1 box_encodings = zero + tf.zeros( (batch_size, num_anchors, 1, code_size), dtype=tf.float32) class_predictions_with_background = zero + tf.zeros( - (batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32) + (batch_size, num_anchors, num_class_slots), dtype=tf.float32) masks = zero + tf.zeros( (batch_size, num_anchors, self.num_classes, DEFAULT_MASK_SIZE, DEFAULT_MASK_SIZE),