StableDiffusionDepth2ImgPipeline (huggingface#1531)
* begin depth pipeline

* add depth estimation model

* fix prepare_depth_mask

* add a comment about autocast

* copied from, quality, cleanup

* begin tests

* handle tensors

* norm image tensor

* fix batch size

* fix tests

* fix enable_sequential_cpu_offload

* fix save load

* fix test_save_load_float16

* fix test_save_load_optional_components

* fix test_float16_inference

* fix test_cpu_offload_forward_pass

* fix test_dict_tuple_outputs_equivalent

* up

* fix fast tests

* fix test_stable_diffusion_img2img_multiple_init_images

* fix few more fast tests

* don't use device map for DPT

* fix test_stable_diffusion_pipeline_with_sequential_cpu_offloading

* accept external depth maps

* prepare_depth_mask -> prepare_depth_map

* fix file name

* fix file name

* quality

* check transformers version

* fix test names

* use skipif

* fix import

* add docs

* skip tests on mps

* correct version

* uP

* Update docs/source/api/pipelines/stable_diffusion_2.mdx

* fix fix-copies

* fix fix-copies

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: anton- <[email protected]>
3 people authored Dec 8, 2022
1 parent dbe0719 commit 5383188
Showing 15 changed files with 1,234 additions and 23 deletions.
8 changes: 8 additions & 0 deletions docs/source/api/pipelines/stable_diffusion.mdx
@@ -97,6 +97,14 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention

## StableDiffusionDepth2ImgPipeline
[[autodoc]] StableDiffusionDepth2ImgPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention

## StableDiffusionImageVariationPipeline
[[autodoc]] StableDiffusionImageVariationPipeline
- __call__
32 changes: 32 additions & 0 deletions docs/source/api/pipelines/stable_diffusion_2.mdx
@@ -30,6 +30,7 @@ Note that the architecture is more or less identical to [Stable Diffusion 1](./a
- *Text-to-Image (768x768 resolution)*: [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) with [`StableDiffusionPipeline`]
- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`]
- *Image Upscaling (x4 resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) with [`StableDiffusionUpscalePipeline`]
- *Depth-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth) with [`StableDiffusionDepth2ImgPipeline`]

We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest scheduler there is.

@@ -125,6 +126,37 @@ upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image.save("upsampled_cat.png")
```

- *Depth-Guided Text-to-Image*: [stabilityai/stable-diffusion-2-depth](https://huggingface.co/stabilityai/stable-diffusion-2-depth) with [`StableDiffusionDepth2ImgPipeline`]

**Installation**

```bash
pip install -U git+https://github.com/huggingface/transformers.git
pip install diffusers[torch]
```

**Example**

```python
import torch
import requests
from PIL import Image

from diffusers import StableDiffusionDepth2ImgPipeline

pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-depth",
    torch_dtype=torch.float16,
).to("cuda")


url = "http://images.cocodataset.org/val2017/000000039769.jpg"
init_image = Image.open(requests.get(url, stream=True).raw)
prompt = "two tigers"
n_prompt = "bad, deformed, ugly, bad anatomy"
image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
```

### How to load and use different schedulers.

The stable diffusion pipeline uses the [`DDIMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with it, such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], etc.
2 changes: 1 addition & 1 deletion setup.py
@@ -107,7 +107,7 @@
"tensorboard",
"torch>=1.4",
"torchvision",
- "transformers>=4.21.0",
+ "transformers>=4.25.1",
]

# this is a lookup table with items like:
14 changes: 14 additions & 0 deletions src/diffusers/__init__.py
@@ -12,11 +12,24 @@
is_scipy_available,
is_torch_available,
is_transformers_available,
is_transformers_version,
is_unidecode_available,
logging,
)


# Make sure `transformers` is up to date
if is_transformers_available():
    import transformers

    if is_transformers_version("<", "4.25.1"):
        raise ImportError(
            f"`diffusers` requires transformers >= 4.25.1 to function correctly, but {transformers.__version__} was"
            " found in your environment. You can upgrade it with pip: `pip install transformers --upgrade`"
        )
else:
    pass

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
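
The import-time guard in this hunk relies on `is_transformers_version`, whose implementation is not part of this diff. Below is a minimal sketch of the kind of comparison it performs, assuming versions of the form `X.Y.Z` or `X.Y.Z.devN` only. The helper names `parse_version` and `compare_versions` are hypothetical, and the parsing is a simplification rather than full PEP 440 handling.

```python
import operator
import re

# Map the operator strings used in the diff to comparison functions.
_OPERATIONS = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}


def parse_version(version):
    """Turn 'X.Y.Z' or 'X.Y.Z.devN' into a sortable tuple (simplified)."""
    match = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)(?:\.dev(\d+))?", version)
    if match is None:
        raise ValueError(f"unsupported version string: {version!r}")
    major, minor, patch, dev = match.groups()
    # A dev release sorts before the final release with the same number,
    # so final releases get a flag of 1 and dev releases a flag of 0.
    if dev is None:
        return (int(major), int(minor), int(patch), 1, 0)
    return (int(major), int(minor), int(patch), 0, int(dev))


def compare_versions(installed, operation, required):
    """Hypothetical stand-in for a check like is_transformers_version."""
    return _OPERATIONS[operation](parse_version(installed), parse_version(required))


print(compare_versions("4.21.0", "<", "4.25.1"))       # True
print(compare_versions("4.25.1", ">=", "4.26.0.dev0"))  # False
```

Under this ordering, `4.25.1` satisfies the `>= 4.25.1` floor enforced here but not the `>= 4.26.0.dev0` gate placed on `StableDiffusionDepth2ImgPipeline` in `pipelines/stable_diffusion/__init__.py`, which is consistent with the docs in this commit installing `transformers` from source.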
@@ -87,6 +100,7 @@
CycleDiffusionPipeline,
LDMTextToImagePipeline,
PaintByExamplePipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
2 changes: 1 addition & 1 deletion src/diffusers/dependency_versions_table.py
@@ -31,5 +31,5 @@
"tensorboard": "tensorboard",
"torch": "torch>=1.4",
"torchvision": "torchvision",
- "transformers": "transformers>=4.21.0",
+ "transformers": "transformers>=4.25.1",
}
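
The same pin appears in both `setup.py` and `dependency_versions_table.py`, so the two have to change together. As a rough illustration of how a name-to-requirement lookup like this table could be derived from a flat list of requirement strings, here is a small sketch. The `build_deps_table` helper and its regular expression are assumptions made for this example, not code from the commit.

```python
import re


def build_deps_table(requirements):
    """Map each package name to its full requirement string."""
    table = {}
    for requirement in requirements:
        # The package name is everything before the first version-operator
        # character (or whitespace); the full string is kept as the value.
        match = re.match(r"^([^!=<>~\s]+)", requirement)
        if match is None:
            raise ValueError(f"cannot parse requirement: {requirement!r}")
        table[match.group(1)] = requirement
    return table


deps = build_deps_table(
    ["tensorboard", "torch>=1.4", "torchvision", "transformers>=4.25.1"]
)
print(deps["transformers"])  # transformers>=4.25.1
```

Deriving the table from a single list is one way to avoid the two files drifting apart.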
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -44,6 +44,7 @@
from .paint_by_example import PaintByExamplePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
12 changes: 11 additions & 1 deletion src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -46,13 +46,23 @@ class StableDiffusionPipelineOutput(BaseOutput):
from .safety_checker import StableDiffusionSafetyChecker

try:
-   if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
+   if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
else:
    from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline


try:
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0.dev0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline
else:
    from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline


try:
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
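
Each `try`/`except`/`else` block in this file chooses between the real pipeline module and a placeholder from `dummy_torch_and_transformers_objects`, whose contents are not shown in this diff. Below is a simplified sketch of the pattern, assuming the placeholder only needs to fail with a helpful message when it is instantiated. `make_dummy`, `resolve`, and `RealPipeline` are hypothetical names for illustration.

```python
class OptionalDependencyNotAvailable(Exception):
    """Raised when an optional backend is missing or too old."""


def make_dummy(name, backends):
    """Build a placeholder class that raises ImportError on instantiation."""

    def __init__(self, *args, **kwargs):
        raise ImportError(
            f"{name} requires the following backends: {', '.join(backends)}"
        )

    return type(name, (), {"__init__": __init__})


def resolve(name, backends_ok, load_real):
    """Return the real class if its backends are usable, else a placeholder."""
    try:
        if not backends_ok:
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        return make_dummy(name, ["torch", "transformers"])
    else:
        # Only evaluated when the backends are available, mirroring the
        # deferred import in the `else:` branch of the diff.
        return load_real()


class RealPipeline:
    """Stand-in for the real pipeline class in this sketch."""


Pipeline = resolve(
    "StableDiffusionDepth2ImgPipeline",
    backends_ok=False,
    load_real=lambda: RealPipeline,
)

try:
    Pipeline()
except ImportError as err:
    print(err)
```

With this arrangement the package import itself keeps working when an optional backend is missing or too old; the error only surfaces when the unavailable pipeline is actually used.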