Upgrade to work with latest Diffusers (#41)
Co-authored-by: ivy <[email protected]>
wangdong2023 and wangdong-ivymobile authored Jun 8, 2023
1 parent f2466b8 commit 0dc208b
Showing 3 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions daam/hook.py
@@ -3,7 +3,7 @@
import itertools

from diffusers import UNet2DConditionModel
-from diffusers.models.attention import CrossAttention
+from diffusers.models.attention_processor import Attention
import torch.nn as nn


@@ -82,21 +82,21 @@ def register_hook(self, hook: ObjectHooker):
self.module.append(hook)


-class UNetCrossAttentionLocator(ModuleLocator[CrossAttention]):
+class UNetCrossAttentionLocator(ModuleLocator[Attention]):
def __init__(self, restrict: bool = None, locate_middle_block: bool = False):
self.restrict = restrict
self.layer_names = []
self.locate_middle_block = locate_middle_block

-def locate(self, model: UNet2DConditionModel) -> List[CrossAttention]:
+def locate(self, model: UNet2DConditionModel) -> List[Attention]:
"""
Locate all cross-attention modules in a UNet2DConditionModel.
Args:
model (`UNet2DConditionModel`): The model to locate the cross-attention modules in.
Returns:
-`List[CrossAttention]`: The list of cross-attention modules.
+`List[Attention]`: The list of cross-attention modules.
"""
self.layer_names.clear()
blocks_list = []
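Note on the daam/hook.py change above: in diffusers 0.16 the cross-attention class was renamed from CrossAttention to Attention and moved to diffusers.models.attention_processor. The following is an illustrative sketch of locating the renamed modules by name, not part of this commit; the model id and the name-based filter are assumptions for illustration, and DAAM's own locator walks the UNet blocks instead.

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention

# Example checkpoint; any Stable Diffusion UNet works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# In diffusers' transformer blocks, `attn1` is self-attention and `attn2` is
# cross-attention; both are instances of the renamed `Attention` class.
cross_attention_layers = [
    (name, module)
    for name, module in unet.named_modules()
    if isinstance(module, Attention) and name.endswith("attn2")
]
print(f"Found {len(cross_attention_layers)} cross-attention modules")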
10 changes: 5 additions & 5 deletions daam/trace.py
@@ -3,7 +3,7 @@
import math

from diffusers import StableDiffusionPipeline
-from diffusers.models.attention import CrossAttention
+from diffusers.models.attention_processor import Attention
import numpy as np
import PIL.Image as Image
import torch
@@ -161,10 +161,10 @@ def _hook_impl(self):
self.monkey_patch('_encode_prompt', self._hooked_encode_prompt)


-class UNetCrossAttentionHooker(ObjectHooker[CrossAttention]):
+class UNetCrossAttentionHooker(ObjectHooker[Attention]):
def __init__(
self,
-module: CrossAttention,
+module: Attention,
parent_trace: 'trace',
context_size: int = 77,
layer_idx: int = 0,
@@ -226,7 +226,7 @@ def _load_attn(self) -> torch.Tensor:

def __call__(
self,
-attn: CrossAttention,
+attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
@@ -238,7 +238,7 @@ def __call__(

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
-elif attn.cross_attention_norm:
+elif attn.norm_cross is not None:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
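Note on the daam/trace.py change above: the old boolean flag cross_attention_norm is gone; in diffusers 0.16, Attention.norm_cross holds either a normalization module or None, and is applied directly when present. A minimal sketch of the changed branch, not part of this commit; the helper name is made up for illustration.

from diffusers.models.attention_processor import Attention

def maybe_norm_encoder_states(attn: Attention, hidden_states, encoder_hidden_states=None):
    # Mirrors the branch in UNetCrossAttentionHooker.__call__.
    if encoder_hidden_states is None:
        # Self-attention: keys/values come from the image features themselves.
        return hidden_states
    if attn.norm_cross is not None:  # was: `if attn.cross_attention_norm:`
        # Cross-attention with normalization applied to the text embeddings.
        return attn.norm_cross(encoder_hidden_states)
    return encoder_hidden_states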
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
scikit-image
-diffusers==0.14.0
+diffusers==0.16.1
spacy
gradio
ftfy
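Note on the requirements.txt change above: the pin moves from 0.14.0 to 0.16.1 because the import paths changed between those releases. If a strict pin were ever relaxed, the rename could be handled with a guarded import; an illustrative sketch, not part of this commit.

try:
    # diffusers >= 0.16: class renamed to Attention, moved to attention_processor
    from diffusers.models.attention_processor import Attention as CrossAttentionType
except ImportError:
    # older diffusers releases: original name and location
    from diffusers.models.attention import CrossAttention as CrossAttentionType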
