From 807c8f6276fd9035e9281bb5b557e4cfb880b2f2 Mon Sep 17 00:00:00 2001 From: Yixiu Chen Date: Wed, 27 Nov 2024 19:53:31 +0800 Subject: [PATCH 1/5] [SW-206077] implemented fusedSDPA for stable diffusion (#36) Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com> --- .../text_to_image_generation.py | 1 + .../pipeline_stable_diffusion.py | 63 ++++++++++++++++++- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index c191b1982a..1fa1e9737b 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -546,6 +546,7 @@ def main(): args.model_name_or_path, **kwargs, ) + pipeline.unet.set_default_attn_processor(pipeline.unet) if args.unet_adapter_name_or_path is not None: from peft import PeftModel diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f0a7febc5f..34ba88c925 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import inspect +import os import time from dataclasses import dataclass from math import ceil @@ -30,6 +30,11 @@ from diffusers.utils import BaseOutput, deprecate from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from optimum.habana.diffusers.models.attention_processor import ( + AttentionProcessor, + AttnProcessor2_0, + ScaledDotProductAttention, +) from optimum.utils import logging from ....transformers.gaudi_configuration import GaudiConfig @@ -96,6 +101,60 @@ def retrieve_timesteps( return timesteps, num_inference_steps +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor +def set_attn_processor_hpu(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if os.environ.get("PATCH_SDPA") is not None: + setattr(module, "attention_module", ScaledDotProductAttention()) + module.set_processor(processor(module.attention_module)) + else: + if isinstance(processor, dict): + attention_processor = processor.pop(f"{name}.processor", None) + if attention_processor is not None: + module.set_processor(attention_processor) + else: + module.set_processor(processor) + # else: + # raise ValueError(f"Unsupported processor type: {type(processor)}") + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor +def set_default_attn_processor_hpu(self): + """ + Disables custom attention processors and sets the default attention implementation from HPU. + """ + + processor = AttnProcessor2_0() + set_attn_processor_hpu(self, processor) + + class GaudiStableDiffusionPipeline(GaudiDiffusionPipeline, StableDiffusionPipeline): """ Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L73 @@ -173,7 +232,7 @@ def __init__( image_encoder, requires_safety_checker, ) - + self.unet.set_default_attn_processor = set_default_attn_processor_hpu self.to(self._device) def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None): From 29c03807223647d1c438f9719cc9225ae3e4926b Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Tue, 3 Dec 2024 14:16:48 -0800 Subject: [PATCH 2/5] Update pipeline_stable_diffusion.py based on review --- .../stable_diffusion/pipeline_stable_diffusion.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 34ba88c925..579c5d9481 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -101,9 +101,9 @@ def retrieve_timesteps( return timesteps, num_inference_steps -# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor_hpu(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" + """ + Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): @@ -144,10 +144,9 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - -# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor_hpu(self): """ + Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor Disables custom attention processors and sets the default attention implementation from HPU. """ From d290d7fdea282ea8adbff6d43aa973c60a234bb6 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Tue, 3 Dec 2024 16:00:17 -0800 Subject: [PATCH 3/5] Update pipeline_stable_diffusion.py --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 579c5d9481..36cb504311 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -135,8 +135,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): module.set_processor(attention_processor) else: module.set_processor(processor) - # else: - # raise ValueError(f"Unsupported processor type: {type(processor)}") for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) From 1ce04b45ecce2b9507b1fe506c30b0eb9e691c56 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Tue, 3 Dec 2024 16:28:44 -0800 Subject: [PATCH 4/5] Update pipeline_stable_diffusion.py --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 36cb504311..1800a49f6b 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -104,6 +104,7 @@ def retrieve_timesteps( def set_attn_processor_hpu(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): """ Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + Added env PATCH_SDPA for HPU specific handle to use ScaledDotProductAttention. Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): From 45370b8942101c62dca83f69b92fe408e5901ec1 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Wed, 4 Dec 2024 10:44:44 -0800 Subject: [PATCH 5/5] Make style change. --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1800a49f6b..c081dd2cac 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -143,6 +143,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) + def set_default_attn_processor_hpu(self): """ Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor