Skip to content

Commit

Permalink
Implemented fusedSDPA for stable diffusion (#36) (#1545)
Browse files Browse the repository at this point in the history
Co-authored-by: Yixiu Chen <[email protected]>
Co-authored-by: Libin Tang <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2024
1 parent 623ca0d commit c579a70
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/stable-diffusion/text_to_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ def main():
args.model_name_or_path,
**kwargs,
)
pipeline.unet.set_default_attn_processor(pipeline.unet)

if args.unet_adapter_name_or_path is not None:
from peft import PeftModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import time
from dataclasses import dataclass
from math import ceil
Expand All @@ -30,6 +30,11 @@
from diffusers.utils import BaseOutput, deprecate
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from optimum.habana.diffusers.models.attention_processor import (
AttentionProcessor,
AttnProcessor2_0,
ScaledDotProductAttention,
)
from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
Expand Down Expand Up @@ -96,6 +101,59 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


def set_attn_processor_hpu(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    """
    Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    Added env PATCH_SDPA for HPU-specific handling to use ScaledDotProductAttention.

    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.
            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.
    """

    num_layers = len(self.attn_processors.keys())

    # A per-layer dict must cover every attention layer exactly.
    if isinstance(processor, dict) and len(processor) != num_layers:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {num_layers}. Please make sure to pass {num_layers} processor classes."
        )

    def _assign_recursively(path: str, module: torch.nn.Module, processor):
        # Only modules exposing `set_processor` are attention layers.
        if hasattr(module, "set_processor"):
            if "PATCH_SDPA" in os.environ:
                # HPU path: attach a ScaledDotProductAttention module and hand it to
                # the processor callable so attention runs through the fused SDPA op.
                module.attention_module = ScaledDotProductAttention()
                module.set_processor(processor(module.attention_module))
            elif isinstance(processor, dict):
                # Pop so each entry is consumed exactly once as we walk the tree.
                layer_processor = processor.pop(f"{path}.processor", None)
                if layer_processor is not None:
                    module.set_processor(layer_processor)
            else:
                module.set_processor(processor)

        for child_name, child in module.named_children():
            _assign_recursively(f"{path}.{child_name}", child, processor)

    for top_name, top_module in self.named_children():
        _assign_recursively(top_name, top_module, processor)


def set_default_attn_processor_hpu(self):
    """
    Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    Disables custom attention processors and sets the default attention implementation for HPU.
    """

    # Install AttnProcessor2_0 on every attention layer via the HPU-aware setter.
    set_attn_processor_hpu(self, AttnProcessor2_0())


class GaudiStableDiffusionPipeline(GaudiDiffusionPipeline, StableDiffusionPipeline):
"""
Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L73
Expand Down Expand Up @@ -177,7 +235,7 @@ def __init__(
image_encoder,
requires_safety_checker,
)

self.unet.set_default_attn_processor = set_default_attn_processor_hpu
self.to(self._device)

def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None):
Expand Down

0 comments on commit c579a70

Please sign in to comment.