Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented fusedSDPA for stable diffusion (#36) #1545

Merged
merged 5 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/stable-diffusion/text_to_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ def main():
args.model_name_or_path,
**kwargs,
)
pipeline.unet.set_default_attn_processor(pipeline.unet)

if args.unet_adapter_name_or_path is not None:
from peft import PeftModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import time
from dataclasses import dataclass
from math import ceil
Expand All @@ -30,6 +30,11 @@
from diffusers.utils import BaseOutput, deprecate
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from optimum.habana.diffusers.models.attention_processor import (
AttentionProcessor,
AttnProcessor2_0,
ScaledDotProductAttention,
)
from optimum.utils import logging

from ....transformers.gaudi_configuration import GaudiConfig
Expand Down Expand Up @@ -96,6 +101,59 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


def set_attn_processor_hpu(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    """
    Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    Added env PATCH_SDPA for HPU specific handle to use ScaledDotProductAttention.
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    Raises:
        ValueError: if `processor` is a dict whose length does not match the number of attention layers.
    """

    # Number of attention layers currently registered on the model; used only to
    # validate a dict-shaped `processor` argument below.
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        # Any submodule exposing `set_processor` is treated as an attention layer.
        if hasattr(module, "set_processor"):
            if os.environ.get("PATCH_SDPA") is not None:
                # HPU path: attach a fresh ScaledDotProductAttention to each layer and
                # build the processor around it.
                # NOTE(review): this calls `processor(module.attention_module)` — it assumes
                # `processor` is callable with a single attention-module argument. With the
                # AttnProcessor2_0 *instance* created in set_default_attn_processor_hpu this
                # invokes its __call__; confirm that matches AttnProcessor2_0's interface.
                # NOTE(review): if `processor` is a dict while PATCH_SDPA is set, this branch
                # would fail (a dict is not callable) — presumably only non-dict processors
                # are expected with PATCH_SDPA; verify against callers.
                setattr(module, "attention_module", ScaledDotProductAttention())
                module.set_processor(processor(module.attention_module))
            else:
                # Stock diffusers behavior: look up per-layer processors by qualified
                # name when a dict was given, otherwise share the single processor.
                if isinstance(processor, dict):
                    # `pop` consumes the entry so leftover keys can be detected by callers;
                    # a missing key leaves the layer's current processor untouched.
                    attention_processor = processor.pop(f"{name}.processor", None)
                    if attention_processor is not None:
                        module.set_processor(attention_processor)
                else:
                    module.set_processor(processor)

        # Recurse depth-first so nested attention blocks are covered as well.
        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)


def set_default_attn_processor_hpu(self):
    """
    Restore the default HPU attention implementation on every ``Attention`` layer.

    Mirrors ``diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor``,
    but routes through ``set_attn_processor_hpu`` so the Gaudi-optimized
    ``AttnProcessor2_0`` (and, when ``PATCH_SDPA`` is set, ``ScaledDotProductAttention``)
    is installed instead of the stock diffusers processor.
    """
    set_attn_processor_hpu(self, AttnProcessor2_0())


class GaudiStableDiffusionPipeline(GaudiDiffusionPipeline, StableDiffusionPipeline):
"""
Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L73
Expand Down Expand Up @@ -173,7 +231,7 @@ def __init__(
image_encoder,
requires_safety_checker,
)

self.unet.set_default_attn_processor = set_default_attn_processor_hpu
self.to(self._device)

def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None):
Expand Down