Skip to content

Commit

Permalink
[K Diffusion] Add k diffusion sampler natively (huggingface#1603)
Browse files Browse the repository at this point in the history
* uP

* uP
  • Loading branch information
patrickvonplaten authored Dec 8, 2022
1 parent 326de41 commit a643c63
Show file tree
Hide file tree
Showing 13 changed files with 602 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe = pipe.to("cuda")

prompt = "an astronaut riding a horse on mars"
pipe.set_sampler("sample_heun")
pipe.set_scheduler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]

Expand Down Expand Up @@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

pipe.set_sampler("sample_euler")
pipe.set_scheduler("sample_euler")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
```
Expand Down
5 changes: 5 additions & 0 deletions examples/community/sd_text2img_k_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import importlib
import warnings
from typing import Callable, List, Optional, Union

import torch
Expand Down Expand Up @@ -111,6 +112,10 @@ def __init__(
self.k_diffusion_model = CompVisDenoiser(model)

def set_sampler(self, scheduler_type: str):
warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
return self.set_scheduler(scheduler_type)

def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type)
Expand Down
1 change: 1 addition & 0 deletions hi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65",
"k-diffusion",
"librosa",
"modelcards>=0.1.4",
"numpy",
Expand Down Expand Up @@ -182,6 +183,7 @@ def run(self):
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"datasets",
"k-diffusion",
"librosa",
"parameterized",
"pytest",
Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .utils import (
is_flax_available,
is_inflect_available,
is_k_diffusion_available,
is_onnx_available,
is_scipy_available,
is_torch_available,
Expand Down Expand Up @@ -90,6 +91,11 @@
else:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403

if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
from .pipelines import StableDiffusionKDiffusionPipeline
else:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403

if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65",
"k-diffusion": "k-diffusion",
"librosa": "librosa",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..utils import (
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
is_onnx_available,
is_torch_available,
Expand Down Expand Up @@ -56,5 +57,8 @@
StableDiffusionOnnxPipeline,
)

if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
from .stable_diffusion import StableDiffusionKDiffusionPipeline

if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ...utils import (
BaseOutput,
is_flax_available,
is_k_diffusion_available,
is_onnx_available,
is_torch_available,
is_transformers_available,
Expand Down Expand Up @@ -48,6 +49,9 @@ class StableDiffusionPipelineOutput(BaseOutput):
else:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline

if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline

if is_transformers_available() and is_onnx_available():
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
Expand Down
Loading

0 comments on commit a643c63

Please sign in to comment.