Skip to content

Commit

Permalink
Add improved handling of pil (huggingface#1309)
Browse files Browse the repository at this point in the history
* Better error message for transformers dummy

* [PIL] Better deprecation functionality

* up
  • Loading branch information
patrickvonplaten authored Nov 16, 2022
1 parent 46893ad commit 65d136e
Show file tree
Hide file tree
Showing 20 changed files with 64 additions and 39 deletions.
4 changes: 2 additions & 2 deletions docs/source/api/pipelines/latent_diffusion.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff


## LDMTextToImagePipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
[[autodoc]] LDMTextToImagePipeline
- __call__

## LDMSuperResolutionPipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
[[autodoc]] LDMSuperResolutionPipeline
- __call__
1 change: 1 addition & 0 deletions docs/source/api/pipelines/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ available a colab notebook to directly try them out.
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ available a colab notebook to directly try them out.
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
Expand Down
4 changes: 2 additions & 2 deletions examples/community/imagic_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
from diffusers.utils import PIL_INTERPOLATION, logging
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

Expand All @@ -28,7 +28,7 @@
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
Expand Down
6 changes: 3 additions & 3 deletions examples/community/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, is_accelerate_available, logging
from diffusers.utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


Expand Down Expand Up @@ -358,7 +358,7 @@ def get_weighted_text_embeddings(
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
Expand All @@ -369,7 +369,7 @@ def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
Expand Down
6 changes: 3 additions & 3 deletions examples/community/lpw_stable_diffusion_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
from diffusers.utils import PIL_INTERPOLATION, logging
from transformers import CLIPFeatureExtractor, CLIPTokenizer


Expand Down Expand Up @@ -365,7 +365,7 @@ def get_weighted_text_embeddings(
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
Expand All @@ -375,7 +375,7 @@ def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
Expand Down
10 changes: 5 additions & 5 deletions examples/textual_inversion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
import torch.utils.checkpoint
from torch.utils.data import Dataset

import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import PIL_INTERPOLATION
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
Expand Down Expand Up @@ -260,10 +260,10 @@ def __init__(
self._length = self.num_images * repeats

self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"linear": PIL_INTERPOLATION["linear"],
"bilinear": PIL_INTERPOLATION["bilinear"],
"bicubic": PIL_INTERPOLATION["bicubic"],
"lanczos": PIL_INTERPOLATION["lanczos"],
}[interpolation]

self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
Expand Down
10 changes: 5 additions & 5 deletions examples/textual_inversion/textual_inversion_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import jax
import jax.numpy as jnp
import optax
import PIL
import transformers
from diffusers import (
FlaxAutoencoderKL,
Expand All @@ -24,6 +23,7 @@
FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
from diffusers.utils import PIL_INTERPOLATION
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
Expand Down Expand Up @@ -246,10 +246,10 @@ def __init__(
self._length = self.num_images * repeats

self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"linear": PIL_INTERPOLATION["linear"],
"bilinear": PIL_INTERPOLATION["bilinear"],
"bicubic": PIL_INTERPOLATION["bicubic"],
"lanczos": PIL_INTERPOLATION["lanczos"],
}[interpolation]

self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
"Pillow<10.0", # keep the PIL.Image.Resampling deprecation away
"Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.11.0",
"black==22.8",
"datasets",
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow<10.0",
"Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0",
"black": "black==22.8",
"datasets": "datasets",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION


def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

Expand All @@ -37,7 +37,7 @@
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput


Expand All @@ -35,7 +35,7 @@
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput


Expand All @@ -44,7 +44,7 @@ def prepare_mask_and_masked_image(image, mask, latents_shape):
image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
masked_image = image * (image_mask < 127.5)

mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"])
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

Expand All @@ -44,7 +44,7 @@
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker

Expand All @@ -44,7 +44,7 @@
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
Expand All @@ -55,7 +55,7 @@ def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from .logging import get_logger
from .outputs import BaseOutput
from .pil_utils import PIL_INTERPOLATION


if is_torch_available():
Expand Down
21 changes: 21 additions & 0 deletions src/diffusers/utils/pil_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import PIL.Image
import PIL.ImageOps
from packaging import version


if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
import numpy as np
import torch

import PIL
from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils import PIL_INTERPOLATION, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch

from ...test_pipelines_common import PipelineTesterMixin
Expand Down Expand Up @@ -97,7 +96,7 @@ def test_inference_superresolution(self):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/vq_diffusion/teddy_bear_pool.png"
)
init_image = init_image.resize((64, 64), resample=PIL.Image.LANCZOS)
init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"])

ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
ldm.to(torch_device)
Expand Down

0 comments on commit 65d136e

Please sign in to comment.