diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f88af5972e..ecb6fc47e0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,7 +16,7 @@ repos:
       - id: no-commit-to-branch
         args: ['--branch', 'main']
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.8.1
+    rev: v0.8.3
     hooks:
       - id: ruff
         args: [ --fix ]
diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py
index a91ba7754b..f66279795b 100644
--- a/doctr/transforms/functional/pytorch.py
+++ b/doctr/transforms/functional/pytorch.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
+from scipy.ndimage import gaussian_filter
 from torchvision.transforms import functional as F
 
 from doctr.utils.geometry import rotate_abs_geoms
 
@@ -113,24 +114,24 @@ def crop_detection(
 
 
 def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwargs) -> torch.Tensor:
-    """Crop and image and associated bboxes
+    """Apply a random shadow effect to an image, blurring the shadow mask with SciPy.
 
     Args:
-        img: image to modify
-        opacity_range: the minimum and maximum desired opacity of the shadow
-        **kwargs: additional arguments to pass to `create_shadow_mask`
+        img: Image to modify (C, H, W) as a PyTorch tensor.
+        opacity_range: The minimum and maximum desired opacity of the shadow.
+        **kwargs: Additional arguments to pass to `create_shadow_mask`.
 
     Returns:
-        shaded image
+        Shadowed image as a PyTorch tensor (same shape as input).
     """
     shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)
 
     opacity = np.random.uniform(*opacity_range)
-    shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])
-
-    # Add some blur to make it believable
-    k = 7 + 2 * int(4 * np.random.rand(1))
+    # Apply Gaussian blur to the shadow mask
     sigma = np.random.uniform(0.5, 5.0)
-    shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma])
+    blurred_mask = gaussian_filter(shadow_mask, sigma=sigma)
+
+    shadow_tensor = 1 - torch.from_numpy(blurred_mask).float()
+    shadow_tensor = shadow_tensor.to(img.device).unsqueeze(0)  # Add channel dimension
 
     return opacity * shadow_tensor * img + (1 - opacity) * img
diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py
index 027998412d..ee989a5949 100644
--- a/doctr/transforms/modules/pytorch.py
+++ b/doctr/transforms/modules/pytorch.py
@@ -8,13 +8,22 @@
 import numpy as np
 import torch
 from PIL.Image import Image
+from scipy.ndimage import gaussian_filter
 from torch.nn.functional import pad
 from torchvision.transforms import functional as F
 from torchvision.transforms import transforms as T
 
 from ..functional.pytorch import random_shadow
 
-__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]
+__all__ = [
+    "Resize",
+    "GaussianNoise",
+    "ChannelShuffle",
+    "RandomHorizontalFlip",
+    "RandomShadow",
+    "RandomResize",
+    "GaussianBlur",
+]
 
 
 class Resize(T.Resize):
@@ -142,6 +151,39 @@ def extra_repr(self) -> str:
         return f"mean={self.mean}, std={self.std}"
 
 
+class GaussianBlur(torch.nn.Module):
+    """Apply Gaussian Blur to the input tensor
+
+    >>> import torch
+    >>> from doctr.transforms import GaussianBlur
+    >>> transfo = GaussianBlur(sigma=(0.0, 1.0))
+
+    Args:
+        sigma: standard deviation range for the gaussian kernel
+    """
+
+    def __init__(self, sigma: tuple[float, float]) -> None:
+        super().__init__()
+        self.sigma_range = sigma
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Sample a random sigma value within the specified range
+        sigma = torch.empty(1).uniform_(*self.sigma_range).item()
+
+        # Apply Gaussian blur along the spatial dimensions only (no smoothing across channels)
+        blurred = torch.tensor(
+            gaussian_filter(
+                x.cpu().numpy(),
+                sigma=(0, sigma, sigma),  # (C, H, W): blur H and W only
+                mode="reflect",
+                truncate=4.0,
+            ),
+            dtype=x.dtype,
+            device=x.device,
+        )
+        return blurred
+
+
 class ChannelShuffle(torch.nn.Module):
     """Randomly shuffle channel order of a given image"""
 
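For reference only (not part of the patch), a minimal usage sketch of the GaussianBlur transform added above. It assumes a single (C, H, W) tensor, as in the class docstring; the sigma range here is just an illustrative choice:

import torch

from doctr.transforms import GaussianBlur

# Draw a fresh sigma from [0.5, 1.5] on every call and blur only the spatial dimensions
transfo = GaussianBlur(sigma=(0.5, 1.5))
img = torch.rand(3, 64, 64)  # (C, H, W), float32 in [0, 1]
out = transfo(img)
assert out.shape == img.shape and out.dtype == img.dtype

The training scripts below wrap the transform in T.RandomApply so the blur only hits a fraction of the samples.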
diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py
index 8d1bfa4499..0207ccaff8 100644
--- a/references/detection/train_pytorch.py
+++ b/references/detection/train_pytorch.py
@@ -18,7 +18,7 @@
 import wandb
 from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort
+from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort
 from tqdm.auto import tqdm
 
 from doctr import transforms as T
@@ -261,12 +261,12 @@ def main(args):
     img_transforms = T.OneOf([
         Compose([
             T.RandomApply(T.ColorInversion(), 0.3),
-            T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.2),
+            T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2),
         ]),
         Compose([
             T.RandomApply(T.RandomShadow(), 0.3),
             T.RandomApply(T.GaussianNoise(), 0.1),
-            T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.3),
+            T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3),
             RandomGrayscale(p=0.15),
         ]),
         RandomPhotometricDistort(p=0.3),
diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py
index 561447c5f7..bb6ebc9cfe 100644
--- a/references/detection/train_tensorflow.py
+++ b/references/detection/train_tensorflow.py
@@ -212,13 +212,13 @@ def main(args):
     img_transforms = T.OneOf([
         T.Compose([
             T.RandomApply(T.ColorInversion(), 0.3),
-            T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.2),
+            T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.5, 1.5)), 0.2),
         ]),
         T.Compose([
             T.RandomApply(T.RandomJpegQuality(60), 0.15),
             # T.RandomApply(T.RandomShadow(), 0.2),  # Broken atm on GPU
             T.RandomApply(T.GaussianNoise(), 0.1),
-            T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.3),
+            T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.5, 1.5)), 0.3),
             T.RandomApply(T.ToGray(num_output_channels=3), 0.15),
         ]),
         T.Compose([
diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py
index 15e60c6f1c..6f273de13c 100644
--- a/tests/pytorch/test_transforms_pt.py
+++ b/tests/pytorch/test_transforms_pt.py
@@ -7,6 +7,7 @@
 from doctr.transforms import (
     ChannelShuffle,
     ColorInversion,
+    GaussianBlur,
     GaussianNoise,
     RandomCrop,
     RandomHorizontalFlip,
@@ -278,6 +279,38 @@ def test_gaussian_noise(input_dtype, input_shape):
     assert torch.all(transformed <= 1.0)
 
 
+@pytest.mark.parametrize(
+    "input_dtype, input_shape",
+    [
+        [torch.float32, (3, 32, 32)],
+        [torch.uint8, (3, 32, 32)],
+    ],
+)
+def test_gaussian_blur(input_dtype, input_shape):
+    sigma_range = (0.5, 1.5)  # keep sigma away from 0 so the blur is guaranteed to alter the image
+    transform = GaussianBlur(sigma=sigma_range)
+
+    input_t = torch.rand(input_shape, dtype=torch.float32)
+
+    if input_dtype == torch.uint8:
+        input_t = (255 * input_t).round().to(dtype=torch.uint8)
+
+    blurred = transform(input_t)
+
+    assert isinstance(blurred, torch.Tensor)
+    assert blurred.shape == input_shape
+    assert blurred.dtype == input_dtype
+
+    if input_dtype == torch.uint8:
+        assert torch.any(blurred != input_t)
+        assert torch.all(blurred <= 255)
+        assert torch.all(blurred >= 0)
+    else:
+        assert torch.any(blurred != input_t)
+        assert torch.all(blurred <= 1.0)
+        assert torch.all(blurred >= 0.0)
+
+
 @pytest.mark.parametrize(
     "p,target",
     [
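As a closing note, a small standalone sketch (illustration only, not part of the patch) of the SciPy behaviour the new transforms rely on: scipy.ndimage.gaussian_filter accepts a per-axis sigma, so (0, sigma, sigma) blurs only the spatial axes of a (C, H, W) array, whereas a scalar sigma would also smooth across the channel axis:

import numpy as np
from scipy.ndimage import gaussian_filter

img = np.random.rand(3, 32, 32).astype(np.float32)  # (C, H, W)

spatial_only = gaussian_filter(img, sigma=(0, 1.0, 1.0))  # blur H and W, leave channels untouched
per_channel = np.stack([gaussian_filter(c, sigma=1.0) for c in img])  # same result, channel by channel
mixed = gaussian_filter(img, sigma=1.0)  # scalar sigma also smooths across the channel axis

assert np.allclose(spatial_only, per_channel)
assert not np.allclose(mixed, spatial_only)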