From 0dc208b32e12c2871d61572caad1c52c31f79dc1 Mon Sep 17 00:00:00 2001
From: Dong Wang
Date: Thu, 8 Jun 2023 11:57:38 +0800
Subject: [PATCH] Upgrade to work with latest Diffusers (#41)

Co-authored-by: ivy
---
 daam/hook.py     |  8 ++++----
 daam/trace.py    | 10 +++++-----
 requirements.txt |  2 +-
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/daam/hook.py b/daam/hook.py
index 060cc9b..ffea6bc 100644
--- a/daam/hook.py
+++ b/daam/hook.py
@@ -3,7 +3,7 @@
 import itertools
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.attention import CrossAttention
+from diffusers.models.attention_processor import Attention
 import torch.nn as nn
 
 
@@ -82,13 +82,13 @@ def register_hook(self, hook: ObjectHooker):
         self.module.append(hook)
 
 
-class UNetCrossAttentionLocator(ModuleLocator[CrossAttention]):
+class UNetCrossAttentionLocator(ModuleLocator[Attention]):
     def __init__(self, restrict: bool = None, locate_middle_block: bool = False):
         self.restrict = restrict
         self.layer_names = []
         self.locate_middle_block = locate_middle_block
 
-    def locate(self, model: UNet2DConditionModel) -> List[CrossAttention]:
+    def locate(self, model: UNet2DConditionModel) -> List[Attention]:
         """
         Locate all cross-attention modules in a UNet2DConditionModel.
 
@@ -96,7 +96,7 @@ def locate(self, model: UNet2DConditionModel) -> List[CrossAttention]:
             model (`UNet2DConditionModel`): The model to locate the cross-attention modules in.
 
         Returns:
-            `List[CrossAttention]`: The list of cross-attention modules.
+            `List[Attention]`: The list of cross-attention modules.
         """
         self.layer_names.clear()
         blocks_list = []
diff --git a/daam/trace.py b/daam/trace.py
index faa599d..6932a4a 100644
--- a/daam/trace.py
+++ b/daam/trace.py
@@ -3,7 +3,7 @@
 import math
 
 from diffusers import StableDiffusionPipeline
-from diffusers.models.attention import CrossAttention
+from diffusers.models.attention_processor import Attention
 import numpy as np
 import PIL.Image as Image
 import torch
@@ -161,10 +161,10 @@ def _hook_impl(self):
         self.monkey_patch('_encode_prompt', self._hooked_encode_prompt)
 
 
-class UNetCrossAttentionHooker(ObjectHooker[CrossAttention]):
+class UNetCrossAttentionHooker(ObjectHooker[Attention]):
     def __init__(
             self,
-            module: CrossAttention,
+            module: Attention,
             parent_trace: 'trace',
             context_size: int = 77,
             layer_idx: int = 0,
@@ -226,7 +226,7 @@ def _load_attn(self) -> torch.Tensor:
 
     def __call__(
             self,
-            attn: CrossAttention,
+            attn: Attention,
             hidden_states,
             encoder_hidden_states=None,
             attention_mask=None,
@@ -238,7 +238,7 @@ def __call__(
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
+        elif attn.norm_cross is not None:
             encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
 
         key = attn.to_k(encoder_hidden_states)
diff --git a/requirements.txt b/requirements.txt
index 17e2c96..8f15b3e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 scikit-image
-diffusers==0.14.0
+diffusers==0.16.1
 spacy
 gradio
 ftfy
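
Note on the new API (not part of the patch): a minimal sketch, assuming diffusers==0.16.1, of the
attention-processor interface this upgrade migrates to. `Attention`, `AttnProcessor`, and
`set_processor` are from diffusers; `CountingProcessor` and `attach_to_cross_attention` are
hypothetical names used only for illustration, not DAAM code.

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import Attention, AttnProcessor


    class CountingProcessor:
        """Delegates to the stock processor and counts cross-attention calls."""

        def __init__(self):
            self.default = AttnProcessor()
            self.num_calls = 0

        def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
            # Cross-attention receives text embeddings via `encoder_hidden_states`;
            # self-attention leaves it as None (the same check the patched __call__ makes).
            if encoder_hidden_states is not None:
                self.num_calls += 1

            return self.default(attn, hidden_states, encoder_hidden_states, attention_mask)


    def attach_to_cross_attention(unet: UNet2DConditionModel) -> CountingProcessor:
        # Locate modules by the new `Attention` class (the old `CrossAttention` import
        # no longer exists); `attn2` submodules are the cross-attention layers.
        processor = CountingProcessor()

        for name, module in unet.named_modules():
            if isinstance(module, Attention) and name.endswith('attn2'):
                module.set_processor(processor)

        return processor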