Skip to content

Commit

Permalink
Adding MultiScaleFlipAug and changing directories
Browse files Browse the repository at this point in the history
  • Loading branch information
ccanamero committed Aug 20, 2024
1 parent 09c46f5 commit 737880c
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 13 deletions.
10 changes: 2 additions & 8 deletions mmcls/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,13 @@
from .samplers import DistributedSampler, RepeatAugSampler
from .stanford_cars import StanfordCars
from .voc import VOC
from .manage_multichannel_image import (
LoadMultiChannelImgFromFile,
ResizeMultiChannel,
BrightnessTransformMultiChannel,
NormalizeMinMaxChannelwise
)


__all__ = [
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB',
'CustomDataset', 'StanfordCars', 'LoadMultiChannelImgFromFile',
'ResizeMultiChannel', 'BrightnessTransformMultiChannel', 'NormalizeMinMaxChannelwise'
'CustomDataset', 'StanfordCars'
]
11 changes: 10 additions & 1 deletion mmcls/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
from .transforms import (CenterCrop, ColorJitter, Lighting, Normalize, Pad,
RandomCrop, RandomErasing, RandomFlip,
RandomGrayscale, RandomResizedCrop, Resize)
from .test_time_aug import MultiScaleFlipAug
from .manage_multichannel_image import (
LoadMultiChannelImgFromFile,
ResizeMultiChannel,
BrightnessTransformMultiChannel,
NormalizeMinMaxChannelwise
)

__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
Expand All @@ -18,5 +25,7 @@
'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad'
'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad',
'MultiScaleFlipAug', 'LoadMultiChannelImgFromFile',
'ResizeMultiChannel', 'BrightnessTransformMultiChannel', 'NormalizeMinMaxChannelwise'
]
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from skimage import io

from .builder import PIPELINES
from .pipelines.auto_augment import Brightness
from .pipelines.transforms import Resize, Normalize
from .pipelines.loading import LoadImageFromFile
from ..builder import PIPELINES
from .auto_augment import Brightness
from .transforms import Resize, Normalize
from .loading import LoadImageFromFile


_MAX_LEVEL = 10
Expand Down
120 changes: 120 additions & 0 deletions mmcls/datasets/pipelines/test_time_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import warnings

import mmcv

from ..builder import PIPELINES
from .compose import Compose


@PIPELINES.register_module()
class MultiScaleFlipAug:
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=[(1333, 400), (1333, 800)],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
            flip=[False, True, False, True]
            ...
        )

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (tuple | list[tuple] | None): Images scales for resizing.
        scale_factor (float | list[float] | None): Scale factors for
            resizing. Exactly one of ``img_scale`` and ``scale_factor``
            must be given.
        flip (bool): Whether apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal", "vertical" and "diagonal". If
            flip_direction is a list, multiple flip augmentations will be
            applied. It has no effect when flip == False. Default:
            "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale=None,
                 scale_factor=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        # XOR: exactly one of the two scale specifications may be provided.
        assert (img_scale is None) ^ (scale_factor is None), (
            'Must have but only one variable can be setted')
        if img_scale is not None:
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
            self.scale_key = 'scale'
            assert mmcv.is_list_of(self.img_scale, tuple)
        else:
            # Scale factors are stored in ``img_scale`` as well; only the
            # key written into the results dict ('scale_factor') differs.
            self.img_scale = scale_factor if isinstance(
                scale_factor, list) else [scale_factor]
            self.scale_key = 'scale_factor'

        self.flip = flip
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if (self.flip
                and not any(self._is_random_flip(t) for t in transforms)):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    @staticmethod
    def _is_random_flip(transform):
        """Tell whether a pipeline entry configures a ``RandomFlip``.

        Entries that are not dicts, or dicts without a ``'type'`` key, are
        reported as "not RandomFlip" instead of raising, so that this
        purely advisory check can never break construction.
        NOTE(review): presumably ``Compose`` also accepts already-built
        callables and class objects as ``type`` -- confirm against the
        project's ``Compose`` / registry implementation.

        Args:
            transform (object): One element of the ``transforms`` list.

        Returns:
            bool: True if the entry's ``type`` is the string
            ``'RandomFlip'`` or an object whose ``__name__`` is
            ``'RandomFlip'``.
        """
        if not isinstance(transform, dict):
            return False
        transform_type = transform.get('type')
        if isinstance(transform_type, str):
            return transform_type == 'RandomFlip'
        return getattr(transform_type, '__name__', None) == 'RandomFlip'

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Every combination of scale and flip setting is run through the
        composed transforms on a shallow copy of ``results``.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
            into a list.
        """

        aug_data = []
        # The un-flipped variant is always produced; flipped variants are
        # added once per configured direction when flipping is enabled.
        flip_args = [(False, None)]
        if self.flip:
            flip_args += [(True, direction)
                          for direction in self.flip_direction]
        for scale in self.img_scale:
            for flip, direction in flip_args:
                # Shallow copy so the per-variant keys set below do not
                # leak into the caller's dict or into other variants.
                _results = results.copy()
                _results[self.scale_key] = scale
                _results['flip'] = flip
                _results['flip_direction'] = direction
                data = self.transforms(_results)
                aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str

0 comments on commit 737880c

Please sign in to comment.