From 09c46f53b8e5f03944318923b6cd1c784f5cf419 Mon Sep 17 00:00:00 2001 From: ccanamero Date: Fri, 9 Aug 2024 19:24:46 +0200 Subject: [PATCH] Nuevo repo mmpretrain --- mmcls/datasets/__init__.py | 9 +- mmcls/datasets/manage_multichannel_image.py | 218 ++++++++++++++++++ tests/test_downstream/test_mmdet_inference.py | 8 +- 3 files changed, 230 insertions(+), 5 deletions(-) create mode 100644 mmcls/datasets/manage_multichannel_image.py diff --git a/mmcls/datasets/__init__.py b/mmcls/datasets/__init__.py index 095077e2321..d178621edfd 100644 --- a/mmcls/datasets/__init__.py +++ b/mmcls/datasets/__init__.py @@ -14,6 +14,12 @@ from .samplers import DistributedSampler, RepeatAugSampler from .stanford_cars import StanfordCars from .voc import VOC +from .manage_multichannel_image import ( + LoadMultiChannelImgFromFile, + ResizeMultiChannel, + BrightnessTransformMultiChannel, + NormalizeMinMaxChannelwise +) __all__ = [ 'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST', @@ -21,5 +27,6 @@ 'DistributedSampler', 'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS', 'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', - 'CustomDataset', 'StanfordCars' + 'CustomDataset', 'StanfordCars', 'LoadMultiChannelImgFromFile', + 'ResizeMultiChannel', 'BrightnessTransformMultiChannel', 'NormalizeMinMaxChannelwise' ] diff --git a/mmcls/datasets/manage_multichannel_image.py b/mmcls/datasets/manage_multichannel_image.py new file mode 100644 index 00000000000..7f6ae380952 --- /dev/null +++ b/mmcls/datasets/manage_multichannel_image.py @@ -0,0 +1,218 @@ +# Copyright (c) Gradiant. All rights reserved. +import os.path as osp + +import mmcv +import numpy as np + +from skimage import io + +from .builder import PIPELINES +from .pipelines.auto_augment import Brightness +from .pipelines.transforms import Resize, Normalize +from .pipelines.loading import LoadImageFromFile + + +_MAX_LEVEL = 10 + +def enhance_level_to_value(level, a=1.8, b=0.1): + """Map from level to values.""" + return (level / _MAX_LEVEL) * a + b + + +@PIPELINES.register_module() +class LoadMultiChannelImgFromFile(LoadImageFromFile): + """Load an image from file. + Required keys are "img_prefix" and "img_info" (a dict that must contain the + key "filename"). Added or updated keys are "filename", "img", "img_shape", + "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), + "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + color_type (str): The flag argument for :func:`mmcv.imfrombytes`. + Defaults to 'color'. + file_client_args (dict): Arguments to instantiate a FileClient. + See :class:`mmcv.fileio.FileClient` for details. + Defaults to ``dict(backend='disk')``. + """ + + def __init__( + self, + to_float32=False, + color_type="color", + file_client_args=dict(backend="disk"), + ): + self.to_float32 = to_float32 + self.color_type = color_type + self.file_client_args = file_client_args.copy() + self.file_client = None + + def __call__(self, results): + """Call functions to load image and get image meta information. + Args: + results (dict): Result dict from :obj:`mmcls.CustomDataset`. + Returns: + dict: The dict contains loaded image and meta information. + """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + + if results["img_prefix"] is not None: + filename = osp.join(results["img_prefix"], results["img_info"]["filename"]) + else: + filename = results["img_info"]["filename"] + + img = io.imread(filename) + if self.to_float32: + img = img.astype(np.float32) + + img = np.moveaxis(img, 0, -1) + + results["filename"] = filename + results["ori_filename"] = results["img_info"]["filename"] + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape + results["img_fields"] = ["img"] + + return results + + def __repr__(self): + repr_str = ( + f"{self.__class__.__name__}(" + f"to_float32={self.to_float32}, " + f"color_type='{self.color_type}', " + f"file_client_args={self.file_client_args})" + ) + return repr_str + + + + + +@PIPELINES.register_module() +class ResizeMultiChannel(Resize): + def _resize_img(self, results): + + img = results["img"].shape + + w_scale = img.shape[1] + h_scale = img.shape[2] + + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + results["img"] = img + results["img_shape"] = img.shape + results["pad_shape"] = img.shape # in case that there is no padding + results["scale_factor"] = scale_factor + results["keep_ratio"] = self.keep_ratio + +@PIPELINES.register_module() +class BrightnessTransformMultiChannel(Brightness): + + def __init__(self, level, prob=0.5, dims=[]): + + assert isinstance(level, (int, float, list)), \ + 'The level must be type list, int or float.' + assert isinstance(dims, (list)), \ + 'dims must be list of channels' + assert 0 <= prob <= 1.0, \ + 'The probability should be in range [0,1].' + + if isinstance(level, (list)): + if isinstance(dims, list) and len(dims) != 0: + assert len(level)==len(dims), \ + 'Level list length should match dimension list length' + for l in level: + assert 0 <= l <= _MAX_LEVEL, \ + 'The level should be in range [0,_MAX_LEVEL].' + else: + assert 0 <= level <= _MAX_LEVEL, \ + 'The level should be in range [0,_MAX_LEVEL].' + + self.level = level + self.prob = prob + self.dims = dims + + def __call__(self, results): + """Call function for Brightness transformation. + + Args: + results (dict): Results dict from loading pipeline. + + Returns: + dict: Results after the transformation. + """ + if np.random.rand() > self.prob: + return results + + original_img = results['img'] + + assert len(self.dims) <= original_img.shape[-1], \ + 'Selected channels can\'t be greater than numer of channels' + + for d in self.dims: + assert d <= (original_img.shape[-1]-1) , \ + f'Channel must be one of {range(0, original_img.shape[-1]-1)} but found {d}' + + if len(self.dims) != 0: + + if isinstance(self.level, list): + + for l, d in zip(self.level, self.dims): + results['img'] = original_img[:,:,d] + self._adjust_brightness_img(results, enhance_level_to_value(l)) + original_img[:,:,d] = results['img'] + + else: + results['img'] = original_img[:,:,self.dims] + self._adjust_brightness_img(results, enhance_level_to_value(self.level)) + original_img[:,:,self.dims] = results['img'] + + else: + if isinstance(self.level, list): + assert len(self.level) == original_img.shape[-1], \ + 'When type(level)==list, len(level) should match total number of channels' + + for d, l in enumerate(self.level): + results['img'] = original_img[:,:,d] + self._adjust_brightness_img(results, enhance_level_to_value(l)) + original_img[:,:,d] = results['img'] + else: + self._adjust_brightness_img(results, enhance_level_to_value(self.level)) + + results['img'] = original_img + + return results + + + +@PIPELINES.register_module() +class NormalizeMinMaxChannelwise(Normalize): + """Normalize the image channelwise. + """ + + def __init__(self): + pass + + def __call__(self, results): + """Call function to normalize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Normalized results, 'img_norm_cfg' key is added into + result dict. + """ + for key in results.get('img_fields', ['img']): + + for c in range(0, results[key].shape[-1]): + channel = results[key][:,:,c] + channel -= np.min(channel) + channel /= np.max(channel) + + results[key][:,:,c]=channel + + return results diff --git a/tests/test_downstream/test_mmdet_inference.py b/tests/test_downstream/test_mmdet_inference.py index 096c5db7d3f..cfafd03e83f 100644 --- a/tests/test_downstream/test_mmdet_inference.py +++ b/tests/test_downstream/test_mmdet_inference.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np from mmcv import Config -from mmdet.apis import inference_detector -from mmdet.models import build_detector +from mmcls.apis import inference_model +from mmcls.models import build_classifier from mmcls.models import (MobileNetV2, MobileNetV3, RegNet, ResNeSt, ResNet, ResNeXt, SEResNet, SEResNeXt, SwinTransformer, @@ -108,11 +108,11 @@ def test_mmdet_inference(): else: config.model.neck.in_channels = out_channels - model = build_detector(config.model) + model = build_classifier(config.model) assert isinstance(model.backbone, module) model.cfg = config model.eval() - result = inference_detector(model, img1) + result = inference_model(model, img1) assert len(result) == config.num_classes