forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding MultiScaleFlipAug and changing directories
- Loading branch information
Showing
4 changed files
with
136 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import warnings | ||
|
||
import mmcv | ||
|
||
from ..builder import PIPELINES | ||
from .compose import Compose | ||
|
||
|
||
@PIPELINES.register_module() | ||
class MultiScaleFlipAug: | ||
"""Test-time augmentation with multiple scales and flipping. | ||
An example configuration is as followed: | ||
.. code-block:: | ||
img_scale=[(1333, 400), (1333, 800)], | ||
flip=True, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
] | ||
After MultiScaleFLipAug with above configuration, the results are wrapped | ||
into lists of the same length as followed: | ||
.. code-block:: | ||
dict( | ||
img=[...], | ||
img_shape=[...], | ||
scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)] | ||
flip=[False, True, False, True] | ||
... | ||
) | ||
Args: | ||
transforms (list[dict]): Transforms to apply in each augmentation. | ||
img_scale (tuple | list[tuple] | None): Images scales for resizing. | ||
scale_factor (float | list[float] | None): Scale factors for resizing. | ||
flip (bool): Whether apply flip augmentation. Default: False. | ||
flip_direction (str | list[str]): Flip augmentation directions, | ||
options are "horizontal", "vertical" and "diagonal". If | ||
flip_direction is a list, multiple flip augmentations will be | ||
applied. It has no effect when flip == False. Default: | ||
"horizontal". | ||
""" | ||
|
||
def __init__(self, | ||
transforms, | ||
img_scale=None, | ||
scale_factor=None, | ||
flip=False, | ||
flip_direction='horizontal'): | ||
self.transforms = Compose(transforms) | ||
assert (img_scale is None) ^ (scale_factor is None), ( | ||
'Must have but only one variable can be setted') | ||
if img_scale is not None: | ||
self.img_scale = img_scale if isinstance(img_scale, | ||
list) else [img_scale] | ||
self.scale_key = 'scale' | ||
assert mmcv.is_list_of(self.img_scale, tuple) | ||
else: | ||
self.img_scale = scale_factor if isinstance( | ||
scale_factor, list) else [scale_factor] | ||
self.scale_key = 'scale_factor' | ||
|
||
self.flip = flip | ||
self.flip_direction = flip_direction if isinstance( | ||
flip_direction, list) else [flip_direction] | ||
assert mmcv.is_list_of(self.flip_direction, str) | ||
if not self.flip and self.flip_direction != ['horizontal']: | ||
warnings.warn( | ||
'flip_direction has no effect when flip is set to False') | ||
if (self.flip | ||
and not any([t['type'] == 'RandomFlip' for t in transforms])): | ||
warnings.warn( | ||
'flip has no effect when RandomFlip is not in transforms') | ||
|
||
def __call__(self, results): | ||
"""Call function to apply test time augment transforms on results. | ||
Args: | ||
results (dict): Result dict contains the data to transform. | ||
Returns: | ||
dict[str: list]: The augmented data, where each value is wrapped | ||
into a list. | ||
""" | ||
|
||
aug_data = [] | ||
flip_args = [(False, None)] | ||
if self.flip: | ||
flip_args += [(True, direction) | ||
for direction in self.flip_direction] | ||
for scale in self.img_scale: | ||
for flip, direction in flip_args: | ||
_results = results.copy() | ||
_results[self.scale_key] = scale | ||
_results['flip'] = flip | ||
_results['flip_direction'] = direction | ||
data = self.transforms(_results) | ||
aug_data.append(data) | ||
# list of dict to dict of list | ||
aug_data_dict = {key: [] for key in aug_data[0]} | ||
for data in aug_data: | ||
for key, val in data.items(): | ||
aug_data_dict[key].append(val) | ||
return aug_data_dict | ||
|
||
def __repr__(self): | ||
repr_str = self.__class__.__name__ | ||
repr_str += f'(transforms={self.transforms}, ' | ||
repr_str += f'img_scale={self.img_scale}, flip={self.flip}, ' | ||
repr_str += f'flip_direction={self.flip_direction})' | ||
return repr_str |