Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Replace mem leaking torch gaussian_blur in augmentations #1822

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.1
rev: v0.8.3
hooks:
- id: ruff
args: [ --fix ]
Expand Down
21 changes: 11 additions & 10 deletions doctr/transforms/functional/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from torchvision.transforms import functional as F

from doctr.utils.geometry import rotate_abs_geoms
Expand Down Expand Up @@ -113,24 +114,24 @@ def crop_detection(


def random_shadow(img: torch.Tensor, opacity_range: tuple[float, float], **kwargs) -> torch.Tensor:
"""Crop and image and associated bboxes
"""Apply a random shadow effect to an image using NumPy for blurring.

Args:
img: image to modify
opacity_range: the minimum and maximum desired opacity of the shadow
**kwargs: additional arguments to pass to `create_shadow_mask`
img: Image to modify (C, H, W) as a PyTorch tensor.
opacity_range: The minimum and maximum desired opacity of the shadow.
**kwargs: Additional arguments to pass to `create_shadow_mask`.

Returns:
shaded image
Shadowed image as a PyTorch tensor (same shape as input).
"""
shadow_mask = create_shadow_mask(img.shape[1:], **kwargs)

opacity = np.random.uniform(*opacity_range)
shadow_tensor = 1 - torch.from_numpy(shadow_mask[None, ...])

# Add some blur to make it believable
k = 7 + 2 * int(4 * np.random.rand(1))
# Apply Gaussian blur to the shadow mask
sigma = np.random.uniform(0.5, 5.0)
shadow_tensor = F.gaussian_blur(shadow_tensor, k, sigma=[sigma, sigma])
blurred_mask = gaussian_filter(shadow_mask, sigma=sigma)

shadow_tensor = 1 - torch.from_numpy(blurred_mask).float()
shadow_tensor = shadow_tensor.to(img.device).unsqueeze(0) # Add channel dimension

return opacity * shadow_tensor * img + (1 - opacity) * img
44 changes: 43 additions & 1 deletion doctr/transforms/modules/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,22 @@
import numpy as np
import torch
from PIL.Image import Image
from scipy.ndimage import gaussian_filter
from torch.nn.functional import pad
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T

from ..functional.pytorch import random_shadow

__all__ = ["Resize", "GaussianNoise", "ChannelShuffle", "RandomHorizontalFlip", "RandomShadow", "RandomResize"]
__all__ = [
"Resize",
"GaussianNoise",
"ChannelShuffle",
"RandomHorizontalFlip",
"RandomShadow",
"RandomResize",
"GaussianBlur",
]


class Resize(T.Resize):
Expand Down Expand Up @@ -142,6 +151,39 @@ def extra_repr(self) -> str:
return f"mean={self.mean}, std={self.std}"


class GaussianBlur(torch.nn.Module):
    """Apply Gaussian Blur to the input tensor

    A sigma is drawn uniformly from the given range on every call and the blur is
    applied to the last two dimensions only, so the channels of a (C, H, W) image
    (and the samples of a (N, C, H, W) batch) are never mixed together.

    >>> import torch
    >>> from doctr.transforms import GaussianBlur
    >>> transfo = GaussianBlur(sigma=(0.0, 1.0))
    >>> out = transfo(torch.rand((3, 64, 64)))

    Args:
        sigma : standard deviation range for the gaussian kernel
    """

    def __init__(self, sigma: tuple[float, float]) -> None:
        super().__init__()
        self.sigma_range = sigma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blur the input with a randomly sampled sigma

        Args:
            x: tensor to blur; the last two dimensions are treated as spatial (H, W)

        Returns:
            blurred tensor with the same shape, dtype and device as the input
        """
        # Sample a random sigma value within the specified range
        sigma = torch.empty(1).uniform_(*self.sigma_range).item()

        # A scalar sigma would make scipy filter every axis, channels included.
        # Passing one sigma per axis with 0 on the leading axes restricts the
        # blur to the trailing spatial axes (scipy skips axes whose sigma is 0).
        per_axis_sigma = (0.0,) * max(x.ndim - 2, 0) + (sigma,) * min(x.ndim, 2)

        # scipy works on CPU NumPy arrays: detach from autograd and move to host first
        blurred = torch.tensor(
            gaussian_filter(
                x.detach().cpu().numpy(),
                sigma=per_axis_sigma,
                mode="reflect",
                truncate=4.0,
            ),
            dtype=x.dtype,
            device=x.device,
        )
        return blurred


class ChannelShuffle(torch.nn.Module):
"""Randomly shuffle channel order of a given image"""

Expand Down
6 changes: 3 additions & 3 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR, PolynomialLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision.transforms.v2 import Compose, GaussianBlur, Normalize, RandomGrayscale, RandomPhotometricDistort
from torchvision.transforms.v2 import Compose, Normalize, RandomGrayscale, RandomPhotometricDistort
from tqdm.auto import tqdm

from doctr import transforms as T
Expand Down Expand Up @@ -261,12 +261,12 @@ def main(args):
img_transforms = T.OneOf([
Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.2),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.2),
]),
Compose([
T.RandomApply(T.RandomShadow(), 0.3),
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(GaussianBlur(kernel_size=5, sigma=(0.1, 4)), 0.3),
T.RandomApply(T.GaussianBlur(sigma=(0.5, 1.5)), 0.3),
RandomGrayscale(p=0.15),
]),
RandomPhotometricDistort(p=0.3),
Expand Down
4 changes: 2 additions & 2 deletions references/detection/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,13 @@ def main(args):
img_transforms = T.OneOf([
T.Compose([
T.RandomApply(T.ColorInversion(), 0.3),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.2),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.5, 1.5)), 0.2),
]),
T.Compose([
T.RandomApply(T.RandomJpegQuality(60), 0.15),
# T.RandomApply(T.RandomShadow(), 0.2), # Broken atm on GPU
T.RandomApply(T.GaussianNoise(), 0.1),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.1, 4)), 0.3),
T.RandomApply(T.GaussianBlur(kernel_shape=5, std=(0.5, 1.5)), 0.3),
T.RandomApply(T.ToGray(num_output_channels=3), 0.15),
]),
T.Compose([
Expand Down
33 changes: 33 additions & 0 deletions tests/pytorch/test_transforms_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from doctr.transforms import (
ChannelShuffle,
ColorInversion,
GaussianBlur,
GaussianNoise,
RandomCrop,
RandomHorizontalFlip,
Expand Down Expand Up @@ -278,6 +279,38 @@ def test_gaussian_noise(input_dtype, input_shape):
assert torch.all(transformed <= 1.0)


@pytest.mark.parametrize(
    "input_dtype, input_shape",
    [
        [torch.float32, (3, 32, 32)],
        [torch.uint8, (3, 32, 32)],
    ],
)
def test_gaussian_blur(input_dtype, input_shape):
    """GaussianBlur must keep type, shape and dtype, alter the image and stay within the value range."""
    blur = GaussianBlur(sigma=(0.0, 1.0))

    # Random image in [0, 1], rescaled to [0, 255] for the uint8 case
    is_uint8 = input_dtype == torch.uint8
    sample = torch.rand(input_shape, dtype=torch.float32)
    if is_uint8:
        sample = (255 * sample).round().to(dtype=torch.uint8)

    result = blur(sample)

    assert isinstance(result, torch.Tensor)
    assert result.shape == input_shape
    assert result.dtype == input_dtype

    # Valid value range depends on the dtype under test
    lower_bound, upper_bound = (0, 255) if is_uint8 else (0.0, 1.0)
    assert torch.any(result != sample)
    assert torch.all(result <= upper_bound)
    assert torch.all(result >= lower_bound)


@pytest.mark.parametrize(
"p,target",
[
Expand Down
Loading