Skip to content

Commit

Permalink
img_scale argument for Resize function
Browse files Browse the repository at this point in the history
  • Loading branch information
ccanamero committed Aug 21, 2024
1 parent 9e55795 commit 00ee948
Showing 1 changed file with 118 additions and 1 deletion.
119 changes: 118 additions & 1 deletion mmcls/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,18 @@ def __init__(self,
size,
interpolation='bilinear',
adaptive_side='short',
backend='cv2'):
backend='cv2',
img_scale=None):

if img_scale is None:
self.img_scale = None
else:
if isinstance(img_scale, list):
self.img_scale = img_scale
else:
self.img_scale = [img_scale]
assert mmcv.is_list_of(self.img_scale, tuple)

assert isinstance(size, int) or (isinstance(size, tuple)
and len(size) == 2)
assert adaptive_side in {'short', 'long', 'height', 'width'}
Expand Down Expand Up @@ -765,9 +776,115 @@ def __call__(self, results):

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(img_scale={self.img_scale}, '
repr_str += f'(size={self.size}, '
repr_str += f'interpolation={self.interpolation})'
return repr_str

@staticmethod
def random_select(img_scales):
"""Randomly select an img_scale from given candidates.
Args:
img_scales (list[tuple]): Images scales for selection.
Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_dix)``, \
where ``img_scale`` is the selected image scale and \
``scale_idx`` is the selected index in the given candidates.
"""

assert mmcv.is_list_of(img_scales, tuple)
scale_idx = np.random.randint(len(img_scales))
img_scale = img_scales[scale_idx]
return img_scale, scale_idx

@staticmethod
def random_sample(img_scales):
"""Randomly sample an img_scale when ``multiscale_mode=='range'``.
Args:
img_scales (list[tuple]): Images scale range for sampling.
There must be two tuples in img_scales, which specify the lower
and upper bound of image scales.
Returns:
(tuple, None): Returns a tuple ``(img_scale, None)``, where \
``img_scale`` is sampled scale and None is just a placeholder \
to be consistent with :func:`random_select`.
"""

assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales]
long_edge = np.random.randint(
min(img_scale_long),
max(img_scale_long) + 1)
short_edge = np.random.randint(
min(img_scale_short),
max(img_scale_short) + 1)
img_scale = (long_edge, short_edge)
return img_scale, None

@staticmethod
def random_sample_ratio(img_scale, ratio_range):
"""Randomly sample an img_scale when ``ratio_range`` is specified.
A ratio will be randomly sampled from the range specified by
``ratio_range``. Then it would be multiplied with ``img_scale`` to
generate sampled scale.
Args:
img_scale (tuple): Images scale base to multiply with ratio.
ratio_range (tuple[float]): The minimum and maximum ratio to scale
the ``img_scale``.
Returns:
(tuple, None): Returns a tuple ``(scale, None)``, where \
``scale`` is sampled ratio multiplied with ``img_scale`` and \
None is just a placeholder to be consistent with \
:func:`random_select`.
"""

assert isinstance(img_scale, tuple) and len(img_scale) == 2
min_ratio, max_ratio = ratio_range
assert min_ratio <= max_ratio
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
return scale, None

def _random_scale(self, results):
"""Randomly sample an img_scale according to ``ratio_range`` and
``multiscale_mode``.
If ``ratio_range`` is specified, a ratio will be sampled and be
multiplied with ``img_scale``.
If multiple scales are specified by ``img_scale``, a scale will be
sampled according to ``multiscale_mode``.
Otherwise, single scale will be used.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: Two new keys 'scale` and 'scale_idx` are added into \
``results``, which would be used by subsequent pipelines.
"""

if self.ratio_range is not None:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)
elif len(self.img_scale) == 1:
scale, scale_idx = self.img_scale[0], 0
elif self.multiscale_mode == 'range':
scale, scale_idx = self.random_sample(self.img_scale)
elif self.multiscale_mode == 'value':
scale, scale_idx = self.random_select(self.img_scale)
else:
raise NotImplementedError

results['scale'] = scale
results['scale_idx'] = scale_idx


@PIPELINES.register_module()
Expand Down

0 comments on commit 00ee948

Please sign in to comment.