From 09564e048cf5ffc1523aa0605736e7f178ca505a Mon Sep 17 00:00:00 2001 From: Raphael Tang Date: Sun, 7 Jan 2024 15:46:19 -0500 Subject: [PATCH] Add Stable Diffusion XL support (#56) --- README.md | 13 ++++++------ daam/_version.py | 2 +- daam/hook.py | 10 +++++++--- daam/run/generate.py | 19 ++++++++++++++---- daam/trace.py | 47 +++++++++++++++++++++++++++++++++----------- daam/utils.py | 6 +++++- requirements.txt | 6 +++--- 7 files changed, 74 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 8eb1bba..73de368 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ![example image](example.jpg) -### Updated to support Diffusers 0.16.1! +### Updated to support Stable Diffusion XL (SDXL) and Diffusers 0.21.1! I regularly update this codebase. Please submit an issue if you have any questions. @@ -33,6 +33,7 @@ dog.heat_map.png running.heat_map.png prompt.txt ``` Your current working directory will now contain the generated image as `output.png` and a DAAM map for every word, as well as some auxiliary data. You can see more options for `daam` by running `daam -h`. +To use Stable Diffusion XL as the backend, run `daam --model xl-base-1.0 "Dog jumping"`. ### Using DAAM as a Library @@ -40,23 +41,23 @@ Import and use DAAM as follows: ```python from daam import trace, set_seed -from diffusers import StableDiffusionPipeline +from diffusers import DiffusionPipeline from matplotlib import pyplot as plt import torch -model_id = 'stabilityai/stable-diffusion-2-base' +model_id = 'stabilityai/stable-diffusion-xl-base-1.0' device = 'cuda' -pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True) +pipe = DiffusionPipeline.from_pretrained(model_id, use_auth_token=True, torch_dtype=torch.float16, use_safetensors=True, variant='fp16') pipe = pipe.to(device) prompt = 'A dog runs across the field' gen = set_seed(0) # for reproducibility -with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad(): +with torch.no_grad(): with trace(pipe) as tc: - out = pipe(prompt, num_inference_steps=30, generator=gen) + out = pipe(prompt, num_inference_steps=50, generator=gen) heat_map = tc.compute_global_heat_map() heat_map = heat_map.compute_word_heat_map('dog') heat_map.plot_overlay(out.images[0]) diff --git a/daam/_version.py b/daam/_version.py index b794fd4..7fd229a 100644 --- a/daam/_version.py +++ b/daam/_version.py @@ -1 +1 @@ -__version__ = '0.1.0' +__version__ = '0.2.0' diff --git a/daam/hook.py b/daam/hook.py index 320ba6f..f82762c 100644 --- a/daam/hook.py +++ b/daam/hook.py @@ -55,9 +55,13 @@ def unhook(self): return self - def monkey_patch(self, fn_name, fn): - self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name) - setattr(self.module, fn_name, functools.partial(fn, self.module)) + def monkey_patch(self, fn_name, fn, strict: bool = True): + try: + self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name) + setattr(self.module, fn_name, functools.partial(fn, self.module)) + except AttributeError: + if strict: + raise def monkey_super(self, fn_name, *args, **kwargs): return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs) diff --git a/daam/run/generate.py b/daam/run/generate.py index 0ad063d..c191b51 100644 --- a/daam/run/generate.py +++ b/daam/run/generate.py @@ -7,7 +7,7 @@ import time import pandas as pd -from diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline, DiffusionPipeline from tqdm import tqdm import inflect import numpy as np @@ -25,7 +25,8 @@ def main(): 'v2-base': 'stabilityai/stable-diffusion-2-base', 'v2-large': 'stabilityai/stable-diffusion-2', 'v2-1-base': 'stabilityai/stable-diffusion-2-1-base', - 'v2-1-large': 'stabilityai/stable-diffusion-2-1' + 'v2-1-large': 'stabilityai/stable-diffusion-2-1', + 'xl-base-1.0': 'stabilityai/stable-diffusion-xl-base-1.0', } parser = argparse.ArgumentParser() @@ -192,10 +193,20 @@ def main(): prompts = new_prompts prompts = prompts[:args.gen_limit] - pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True) + + if 'xl' in model_id: + pipe = DiffusionPipeline.from_pretrained( + model_id, + use_auth_token=True, + torch_dtype=torch.float16, + use_safetensors=True, variant='fp16' + ) + else: + pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True) + pipe = auto_device(pipe) - with auto_autocast(dtype=torch.float16), torch.no_grad(): + with torch.no_grad(): for gen_idx, (prompt_id, prompt) in enumerate(tqdm(prompts)): seed = int(time.time()) if args.random_seed else args.seed prompt = prompt.replace(',', ' ,').replace('.', ' .').strip() diff --git a/daam/trace.py b/daam/trace.py index 54b9f71..bffdbbb 100644 --- a/daam/trace.py +++ b/daam/trace.py @@ -2,7 +2,8 @@ from typing import List, Type, Any, Dict, Tuple, Union import math -from diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline +from diffusers.image_processor import VaeImageProcessor from diffusers.models.attention_processor import Attention import numpy as np import PIL.Image as Image @@ -21,8 +22,7 @@ class DiffusionHeatMapHooker(AggregateHooker): def __init__( self, - pipeline: - StableDiffusionPipeline, + pipeline: Union[StableDiffusionPipeline, StableDiffusionXLPipeline], low_memory: bool = False, load_heads: bool = False, save_heads: bool = False, @@ -30,7 +30,7 @@ def __init__( ): self.all_heat_maps = RawHeatMapCollection() h = (pipeline.unet.config.sample_size * pipeline.vae_scale_factor) - self.latent_hw = 4096 if h == 512 else 9216 # 64x64 or 96x96 depending on if it's 2.0-v or 2.0 + self.latent_hw = 4096 if h == 512 or h == 1024 else 9216 # 64x64 or 96x96 depending on if it's 2.0-v or 2.0 locate_middle = load_heads or save_heads self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle) self.last_prompt: str = '' @@ -52,6 +52,9 @@ def __init__( modules.append(PipelineHooker(pipeline, self)) + if type(pipeline) == StableDiffusionXLPipeline: + modules.append(ImageProcessorHooker(pipeline.image_processor, self)) + super().__init__(modules) self.pipe = pipeline @@ -129,6 +132,21 @@ def compute_global_heat_map(self, prompt=None, factors=None, head_idx=None, laye return GlobalHeatMap(self.pipe.tokenizer, prompt, maps) +class ImageProcessorHooker(ObjectHooker[VaeImageProcessor]): + def __init__(self, processor: VaeImageProcessor, parent_trace: 'trace'): + super().__init__(processor) + self.parent_trace = parent_trace + + def _hooked_postprocess(hk_self, _: VaeImageProcessor, *args, **kwargs): + images = hk_self.monkey_super('postprocess', *args, **kwargs) + hk_self.parent_trace.last_image = images[0] + + return images + + def _hook_impl(self): + self.monkey_patch('postprocess', self._hooked_postprocess) + + class PipelineHooker(ObjectHooker[StableDiffusionPipeline]): def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'): super().__init__(pipeline) @@ -137,12 +155,20 @@ def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'): def _hooked_run_safety_checker(hk_self, self: StableDiffusionPipeline, image, *args, **kwargs): image, has_nsfw = hk_self.monkey_super('run_safety_checker', image, *args, **kwargs) - pil_image = self.numpy_to_pil(image) - hk_self.parent_trace.last_image = pil_image[0] + + if self.image_processor: + if torch.is_tensor(image): + images = self.image_processor.postprocess(image, output_type='pil') + else: + images = self.image_processor.numpy_to_pil(image) + else: + images = self.numpy_to_pil(image) + + hk_self.parent_trace.last_image = images[len(images)-1] return image, has_nsfw - def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs): + def _hooked_check_inputs(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs): if not isinstance(prompt, str) and len(prompt) > 1: raise ValueError('Only single prompt generation is supported for heat map computation.') elif not isinstance(prompt, str): @@ -152,13 +178,12 @@ def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str hk_self.heat_maps.clear() hk_self.parent_trace.last_prompt = last_prompt - ret = hk_self.monkey_super('_encode_prompt', prompt, *args, **kwargs) - return ret + return hk_self.monkey_super('check_inputs', prompt, *args, **kwargs) def _hook_impl(self): - self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker) - self.monkey_patch('_encode_prompt', self._hooked_encode_prompt) + self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker, strict=False) # not present in SDXL + self.monkey_patch('check_inputs', self._hooked_check_inputs) class UNetCrossAttentionHooker(ObjectHooker[Attention]): diff --git a/daam/utils.py b/daam/utils.py index 6b26761..8cfde13 100644 --- a/daam/utils.py +++ b/daam/utils.py @@ -73,12 +73,16 @@ def cache_dir() -> Path: def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0): merge_idxs = [] tokens = tokenizer.tokenize(prompt.lower()) + tokens = [x.replace('', '') for x in tokens] # New tokenizer uses wordpiece markers. + if word_idx is None: word = word.lower() - search_tokens = tokenizer.tokenize(word) + search_tokens = [x.replace('', '') for x in tokenizer.tokenize(word)] # New tokenizer uses wordpiece markers. start_indices = [x + offset_idx for x in range(len(tokens)) if tokens[x:x + len(search_tokens)] == search_tokens] + for indice in start_indices: merge_idxs += [i + indice for i in range(0, len(search_tokens))] + if not merge_idxs: raise ValueError(f'Search word {word} not found in prompt!') else: diff --git a/requirements.txt b/requirements.txt index 8448528..aa15461 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,12 @@ scikit-image -diffusers==0.16.1 +diffusers==0.21.2 spacy gradio ftfy -transformers==4.27.4 +transformers==4.30.2 pandas numba nltk inflect joblib -accelerate==0.18.0 +accelerate==0.23.0