From ec5739578803d6f138efaba7296ec47de1b42352 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Wed, 20 Sep 2023 13:42:44 -0700 Subject: [PATCH 01/30] random crop --- diffusion/datasets/image_caption.py | 7 +++- diffusion/datasets/laion/transforms.py | 57 ++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index a8405b4f..9982d402 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -14,7 +14,7 @@ from torchvision import transforms from transformers import AutoTokenizer -from diffusion.datasets.laion.transforms import LargestCenterSquare +from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropSquare # Disable PIL max image size limit Image.MAX_IMAGE_PIXELS = None @@ -113,6 +113,7 @@ def build_streaming_image_caption_dataloader( transform: Optional[List[Callable]] = None, image_key: str = 'image', caption_key: str = 'caption', + rand_crop: bool = False, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, ): @@ -131,6 +132,7 @@ def build_streaming_image_caption_dataloader( transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``. image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. + rand_crop (bool): If True, randomly crop images. Otherwise, center crop. ``False``. streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``. dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``. """ @@ -157,9 +159,10 @@ def build_streaming_image_caption_dataloader( streams.append(Stream(remote=r, local=l)) # Setup the transforms to apply + crop_transform = LargestCenterSquare(resize_size) if rand_crop else RandomCropSquare(resize_size) if transform is None: transform = [ - LargestCenterSquare(resize_size), + crop_transform, transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # # Normalize from 0 to 1 to -1 to 1 ] diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py index a0a142d8..32792885 100644 --- a/diffusion/datasets/laion/transforms.py +++ b/diffusion/datasets/laion/transforms.py @@ -3,7 +3,35 @@ """Transforms for the laion dataset.""" +import numpy as np import torchvision.transforms as transforms +from torchvision.transforms.functional import crop, get_dimensions + + +def random_crop_params(img, output_size): + """Helper function to return the parameters for a random crop. + + Args: + img (PIL Image or Tensor): Input image. + output_size (int): Size of output image. + + Returns: + cropped_im (PIL Image or Tensor): Cropped square image of output_size. + c_top (int): Top crop coordinate. + c_left (int): Left crop coordinate. + """ + _, image_height, image_width = get_dimensions(img) + if image_height == image_width: + c_left = 0 + c_top = 0 + elif image_height < image_width: + c_left = np.random.randint(0, image_width - output_size) + c_top = 0 + else: + c_left = 0 + c_top = np.random.randint(0, image_height - output_size) + cropped_im = crop(img, c_top, c_left, output_size, output_size) + return cropped_im, c_top, c_left class LargestCenterSquare: @@ -19,3 +47,32 @@ def __call__(self, img): # Then take a center crop to a square. 
img = self.center_crop(img) return img + + +class RandomCropSquare: + """Randomly crop square of a PIL image.""" + + def __init__(self, size): + self.size = size + + def __call__(self, img): + # First, resize the image such that the smallest side is self.size while preserving aspect ratio. + img = transforms.functional.resize(img, self.size, antialias=True) + # Then take a center crop to a square & return crop params. + img, _, _ = random_crop_params(img, self.size) + return img + + +class RandomCropSquareReturnTransform: + """Randomly crop square of a PIL image and return the crop parameters.""" + + def __init__(self, size): + self.size = size + + def __call__(self, img): + # First, resize the image such that the smallest side is self.size while preserving aspect ratio. + orig_w, orig_h = img.size + img = transforms.functional.resize(img, self.size, antialias=True) + # Then take a center crop to a square & return crop params. + img, c_top, c_left = random_crop_params(img, self.size) + return img, c_top, c_left, orig_h, orig_w From d34d7e9eb40d3d4464f094f4ff5c181ba8cbd336 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Wed, 20 Sep 2023 15:50:46 -0700 Subject: [PATCH 02/30] zero init trick --- diffusion/models/models.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index e2eed1a8..278e2045 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -13,6 +13,7 @@ from torchmetrics.multimodal.clip_score import CLIPScore from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig +from diffusion.models.layers import zero_module from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.schedulers.schedulers import ContinuousTimeScheduler @@ -189,13 +190,16 @@ def stable_diffusion_xl( config[0]['cross_attention_dim'] = 1024 unet = UNet2DConditionModel(**config[0]) - # Prevent fsdp from wrapping up_blocks and down_blocks because the forward pass calls length on these - unet.up_blocks._fsdp_wrap = False - unet.down_blocks._fsdp_wrap = False - for block in unet.up_blocks: - block._fsdp_wrap = True - for block in unet.down_blocks: - block._fsdp_wrap = True + # Zero initialization trick for more stable training + for name, layer in unet.named_modules(): + # Final conv in ResNet blocks + if name.endswith('conv2'): + layer = zero_module(layer) + # proj_out in attention blocks + if name.endswith('to_out.0'): + layer = zero_module(layer) + # Last conv block out projection + unet.conv_out = zero_module(unet.conv_out) if encode_latents_in_fp16: vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16) From fb856c42396b3cc5324f7b27c77319aa89bd7ad0 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Wed, 20 Sep 2023 16:16:26 -0700 Subject: [PATCH 03/30] add intentionally buggy clipping --- diffusion/models/layers.py | 212 +++++++++++++++++++++++++++++++++++++ diffusion/models/models.py | 19 +++- 2 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 diffusion/models/layers.py diff --git a/diffusion/models/layers.py b/diffusion/models/layers.py new file mode 100644 index 00000000..4cf29503 --- /dev/null +++ b/diffusion/models/layers.py @@ -0,0 +1,212 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Helpful layers and functions for UNet construction.""" + +from typing import Optional + +import torch +import torch.nn.functional as F 
+import xformers # type: ignore + + +def zero_module(module): + """Zero out the parameters of a module and return it.""" + for p in module.parameters(): + p.detach().zero_() + return module + + +class ClampedAttnProcessor2_0: + """Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).""" + + def __init__(self, clip_val=6.0): + if not hasattr(F, 'scaled_dot_product_attention'): + raise ImportError('AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.') + self.clip_val = clip_val + print('initializing ClampedAttnProcessor2_0 with value %f!' % clip_val) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + scale: float = 1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + else: + channel, height, width = None, None, None + + batch_size, sequence_length, _ = (hidden_states.shape + if encoder_hidden_states is None else encoder_hidden_states.shape) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + query = query.clamp(min=-self.clip_val, max=self.clip_val) + key = key.clamp(min=-self.clip_val, max=self.clip_val) + value = value.clamp(min=-self.clip_val, max=self.clip_val) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention(query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class ClampedXFormersAttnProcessor: + """Processor for implementing memory efficient attention using xFormers. 
+ + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, clip_val=6.0, attention_op=None): + self.attention_op = attention_op + self.clip_val = clip_val + + print('initializing ClampedXFormersAttnProcessor with value %f - intentionally buggy version!' % clip_val) + + def __call__( + self, + attn, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + scale: float = 1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + else: + channel, height, width = None, None, None + + batch_size, key_tokens, _ = (hidden_states.shape + if encoder_hidden_states is None else encoder_hidden_states.shape) + + attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + query = query.clamp(min=-self.clip_val, max=self.clip_val) + key = query.clamp(min=-self.clip_val, max=self.clip_val) # key.clamp(min=-self.clip_val, max=self.clip_val) + value = query.clamp(min=-self.clip_val, max=self.clip_val) # value.clamp(min=-self.clip_val, max=self.clip_val) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention(query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + assert channel + assert height + assert width + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states diff --git a/diffusion/models/models.py 
b/diffusion/models/models.py index 278e2045..e1c81913 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -13,7 +13,7 @@ from torchmetrics.multimodal.clip_score import CLIPScore from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig -from diffusion.models.layers import zero_module +from diffusion.models.layers import ClampedAttnProcessor2_0, ClampedXFormersAttnProcessor, zero_module from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.schedulers.schedulers import ContinuousTimeScheduler @@ -38,6 +38,7 @@ def stable_diffusion_2( precomputed_latents: bool = False, encode_latents_in_fp16: bool = True, fsdp: bool = True, + clip_qkv: Optional[float] = None, ): """Stable diffusion v2 training setup. @@ -61,6 +62,7 @@ def stable_diffusion_2( precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. fsdp (bool): Whether to use FSDP. Defaults to True. + clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to None. """ if train_metrics is None: train_metrics = [MeanSquaredError()] @@ -121,6 +123,15 @@ def stable_diffusion_2( if is_xformers_installed: model.unet.enable_xformers_memory_efficient_attention() model.vae.enable_xformers_memory_efficient_attention() + + if clip_qkv is not None: + if is_xformers_installed: + attn_processor = ClampedXFormersAttnProcessor(clip_val=clip_qkv) + else: + attn_processor = ClampedAttnProcessor2_0(clip_val=clip_qkv) + + model.unet.set_attn_processor(attn_processor) + return model @@ -138,6 +149,7 @@ def stable_diffusion_xl( precomputed_latents: bool = False, encode_latents_in_fp16: bool = True, fsdp: bool = True, + clip_qkv: Optional[float] = 6.0, ): """Stable diffusion 2 training setup + SDXL UNet and VAE. @@ -167,6 +179,7 @@ def stable_diffusion_xl( precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. fsdp (bool): Whether to use FSDP. Defaults to True. + clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to 6.0. """ if train_metrics is None: train_metrics = [MeanSquaredError()] @@ -241,6 +254,10 @@ def stable_diffusion_xl( if is_xformers_installed: model.unet.enable_xformers_memory_efficient_attention() model.vae.enable_xformers_memory_efficient_attention() + + if clip_qkv is not None: + raise NotImplementedError('Clipping not implemented for SDXL yet.') + return model From 4dd3c40723f0245eccbf9f578b8c7b79f0ce300a Mon Sep 17 00:00:00 2001 From: jazcollins Date: Wed, 20 Sep 2023 16:27:29 -0700 Subject: [PATCH 04/30] fix docstring and update diffusers version --- diffusion/models/layers.py | 7 ++++++- diffusion/models/models.py | 3 ++- setup.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/diffusion/models/layers.py b/diffusion/models/layers.py index 4cf29503..734c3bd8 100644 --- a/diffusion/models/layers.py +++ b/diffusion/models/layers.py @@ -18,7 +18,10 @@ def zero_module(module): class ClampedAttnProcessor2_0: - """Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).""" + """Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ + Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py. + """ def __init__(self, clip_val=6.0): if not hasattr(F, 'scaled_dot_product_attention'): @@ -113,6 +116,8 @@ def __call__( class ClampedXFormersAttnProcessor: """Processor for implementing memory efficient attention using xFormers. + Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py. + Args: attention_op (`Callable`, *optional*, defaults to `None`): The base diff --git a/diffusion/models/models.py b/diffusion/models/models.py index e1c81913..e9f7f8cd 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -179,7 +179,8 @@ def stable_diffusion_xl( precomputed_latents (bool): Whether to use precomputed latents. Defaults to False. encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. fsdp (bool): Whether to use FSDP. Defaults to True. - clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to 6.0. + clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to 6.0. Improves stability + of training. """ if train_metrics is None: train_metrics = [MeanSquaredError()] diff --git a/setup.py b/setup.py index 6c163e2d..4cb3034c 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ 'mosaicml-streaming>=0.4.0,<1.0', 'hydra-core>=1.2', 'hydra-colorlog>=1.1.0', - 'diffusers[torch]==0.19.3', + 'diffusers[torch]==0.21.0', 'transformers[torch]==4.31.0', 'wandb==0.15.4', 'xformers==0.0.21', From c1d58c9fe280a8be328f71d3025516e23e003c92 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Wed, 20 Sep 2023 16:47:30 -0700 Subject: [PATCH 05/30] fix attention clipping, add to sdxl, fix xformers import when not installed --- diffusion/models/layers.py | 23 +++++++++++++---------- diffusion/models/models.py | 13 ++++++++----- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/diffusion/models/layers.py b/diffusion/models/layers.py index 734c3bd8..19847830 100644 --- a/diffusion/models/layers.py +++ b/diffusion/models/layers.py @@ -7,7 +7,11 @@ import torch import torch.nn.functional as F -import xformers # type: ignore + +try: + import xformers # type: ignore +except: + pass def zero_module(module): @@ -17,17 +21,17 @@ def zero_module(module): return module -class ClampedAttnProcessor2_0: +class ClippedAttnProcessor2_0: """Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py. + Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py to + allow clipping QKV values. """ def __init__(self, clip_val=6.0): if not hasattr(F, 'scaled_dot_product_attention'): raise ImportError('AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.') self.clip_val = clip_val - print('initializing ClampedAttnProcessor2_0 with value %f!' % clip_val) def __call__( self, @@ -113,10 +117,11 @@ def __call__( return hidden_states -class ClampedXFormersAttnProcessor: +class ClippedXFormersAttnProcessor: """Processor for implementing memory efficient attention using xFormers. - Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py. 
+ Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py to + allow clipping QKV values. Args: attention_op (`Callable`, *optional*, defaults to `None`): @@ -130,8 +135,6 @@ def __init__(self, clip_val=6.0, attention_op=None): self.attention_op = attention_op self.clip_val = clip_val - print('initializing ClampedXFormersAttnProcessor with value %f - intentionally buggy version!' % clip_val) - def __call__( self, attn, @@ -182,8 +185,8 @@ def __call__( value = attn.to_v(encoder_hidden_states, scale=scale) query = query.clamp(min=-self.clip_val, max=self.clip_val) - key = query.clamp(min=-self.clip_val, max=self.clip_val) # key.clamp(min=-self.clip_val, max=self.clip_val) - value = query.clamp(min=-self.clip_val, max=self.clip_val) # value.clamp(min=-self.clip_val, max=self.clip_val) + key = key.clamp(min=-self.clip_val, max=self.clip_val) + value = value.clamp(min=-self.clip_val, max=self.clip_val) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() diff --git a/diffusion/models/models.py b/diffusion/models/models.py index e9f7f8cd..580cd7d9 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -13,7 +13,7 @@ from torchmetrics.multimodal.clip_score import CLIPScore from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig -from diffusion.models.layers import ClampedAttnProcessor2_0, ClampedXFormersAttnProcessor, zero_module +from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.schedulers.schedulers import ContinuousTimeScheduler @@ -126,10 +126,9 @@ def stable_diffusion_2( if clip_qkv is not None: if is_xformers_installed: - attn_processor = ClampedXFormersAttnProcessor(clip_val=clip_qkv) + attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) else: - attn_processor = ClampedAttnProcessor2_0(clip_val=clip_qkv) - + attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) model.unet.set_attn_processor(attn_processor) return model @@ -257,7 +256,11 @@ def stable_diffusion_xl( model.vae.enable_xformers_memory_efficient_attention() if clip_qkv is not None: - raise NotImplementedError('Clipping not implemented for SDXL yet.') + if is_xformers_installed: + attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) + else: + attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + model.unet.set_attn_processor(attn_processor) return model From f14018a8d6319d880d00c07c072ed60f55e074b3 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Thu, 21 Sep 2023 14:25:20 -0700 Subject: [PATCH 06/30] big sdxl commit, no style check --- diffusion/callbacks/log_diffusion_images.py | 32 ++++-- diffusion/datasets/image_caption.py | 72 +++++++++--- diffusion/models/models.py | 119 ++++++++++++++++---- diffusion/models/stable_diffusion.py | 105 ++++++++++++++--- 4 files changed, 267 insertions(+), 61 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index 54c36a72..38aa4ae6 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -27,8 +27,8 @@ class LogDiffusionImages(Callback): the text prompt, usually at the expense of lower image quality. Default: ``0.0``. text_key (str, optional): Key in the batch to use for text prompts. Default: ``'captions'``. 
- tokenized_prompts (torch.LongTensor, optional): Batch of pre-tokenized prompts - to use for evaluation. Default: ``None``. + tokenized_prompts (torch.LongTensor or List[torch.LongTensor], optional): Batch of pre-tokenized prompts + to use for evaluation. If SDXL, this will be a list of two pre-tokenized prompts Default: ``None``. seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation. Default: ``1138``. use_table (bool): Whether to make a table of the images or not. Default: ``False``. @@ -62,14 +62,26 @@ def eval_batch_end(self, state: State, logger: Logger): else: model = state.model - if self.tokenized_prompts is None: - tokenized_prompts = [ - model.tokenizer(p, padding='max_length', truncation=True, - return_tensors='pt')['input_ids'] # type: ignore - for p in self.prompts - ] - self.tokenized_prompts = torch.cat(tokenized_prompts) - self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) + if model.sdxl: + if self.tokenized_prompts is None: + tokenized_prompts = [ + model.tokenizer(p, padding='max_length', truncation=True, + return_tensors='pt', input_ids=True) # type: ignore + for p in self.prompts + ] + self.tokenized_prompts = [torch.cat(tokenized_prompts[0]), + torch.cat(tokenized_prompts[1])] + self.tokenized_prompts[0] = self.tokenized_prompts[0].to(state.batch[self.text_key].device) + self.tokenized_prompts[1] = self.tokenized_prompts[1].to(state.batch[self.text_key].device) + else: + if self.tokenized_prompts is None: + tokenized_prompts = [ + model.tokenizer(p, padding='max_length', truncation=True, + return_tensors='pt')['input_ids'] # type: ignore + for p in self.prompts + ] + self.tokenized_prompts = torch.cat(tokenized_prompts) + self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # Generate images with get_precision_context(state.precision): diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 9982d402..2aef4f00 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -14,7 +14,8 @@ from torchvision import transforms from transformers import AutoTokenizer -from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropSquare +from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropSquare, RandomCropSquareReturnTransform +from diffusion.models.models import SDXLTokenizer # Disable PIL max image size limit Image.MAX_IMAGE_PIXELS = None @@ -36,6 +37,7 @@ class StreamingImageCaptionDataset(StreamingDataset): image_size (Optional[int]): The size to resize the image to. Default: ``None``. image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. + sdxl (bool): Whether or not we're training SDXL. Default: `False`. **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader """ @@ -51,6 +53,7 @@ def __init__( image_size: Optional[int] = None, image_key: str = 'image', caption_key: str = 'caption', + sdxl: bool = False, **streaming_kwargs, ) -> None: @@ -65,7 +68,12 @@ def __init__( raise ValueError(f'Invalid caption selection: {caption_selection}. 
Must be one of [random, first]') self.transform = transform - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer') + if self.sdxl: + self.tokenizer = SDXLTokenizer(tokenizer_name_or_path) + self.sdxl_crop = RandomCropSquareReturnTransform(image_size) + else: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer') + self.sdxl_crop = None self.caption_drop_prob = caption_drop_prob self.caption_selection = caption_selection self.image_size = image_size @@ -81,9 +89,22 @@ def __getitem__(self, index): img = Image.open(BytesIO(sample[self.image_key])) if img.mode != 'RGB': img = img.convert('RGB') + + out = {} + # Image transforms + if self.sdxl: + # sdxl crop to return params + img, crop_top, crop_left, image_height, image_width = self.sdxl_crop(img) + out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) + out['cond_original_size'] = torch.tensor([image_width, image_height]) + out['cond_target_size'] = torch.tensor([self.image_size, self.image_size]) + else: + crop_top, crop_left, image_height, image_width = None, None, None, None if self.transform is not None: img = self.transform(img) + # TODO implement dropped caption masking! + # Caption if torch.rand(1) < self.caption_drop_prob: caption = '' @@ -93,13 +114,24 @@ def __getitem__(self, index): caption = caption[0] if isinstance(caption, List) and self.caption_selection == 'random': caption = random.sample(caption, k=1)[0] - tokenized_caption = self.tokenizer(caption, - padding='max_length', - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors='pt')['input_ids'][0] - return {'image': img, 'captions': tokenized_caption} + if self.sdxl: + tokenized_captions = self.tokenizer(caption, + padding='max_length', + truncation=True, + return_tensors='pt', + input_ids=True) + tokenized_captions = [cap[0] for cap in tokenized_captions] + tokenized_caption = torch.stack(tokenized_captions) + else: + tokenized_caption = self.tokenizer(caption, + padding='max_length', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt')['input_ids'][0] + out['image'] = img + out['captions'] = tokenized_caption + return out def build_streaming_image_caption_dataloader( @@ -158,14 +190,27 @@ def build_streaming_image_caption_dataloader( for r, l in zip(remote, local): streams.append(Stream(remote=r, local=l)) + # Infer SDXL from tokenizer path + if tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0': + sdxl = True + else: + sdxl = False + # Setup the transforms to apply crop_transform = LargestCenterSquare(resize_size) if rand_crop else RandomCropSquare(resize_size) if transform is None: - transform = [ - crop_transform, - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # # Normalize from 0 to 1 to -1 to 1 - ] + if sdxl: + # Crop will return parameters so do separately + transform = [ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + else: + transform = [ + crop_transform, + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # # Normalize from 0 to 1 to -1 to 1 + ] transform = transforms.Compose(transform) assert isinstance(transform, Callable) @@ -179,6 +224,7 @@ def build_streaming_image_caption_dataloader( image_key=image_key, caption_key=caption_key, batch_size=batch_size, + sdxl=sdxl, **streaming_kwargs, ) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 
580cd7d9..c0fe865d 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -7,11 +7,11 @@ import torch from composer.devices import DeviceGPU -from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel from torchmetrics import MeanSquaredError from torchmetrics.image.fid import FrechetInceptionDistance from torchmetrics.multimodal.clip_score import CLIPScore -from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, PretrainedConfig from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module from diffusion.models.pixel_diffusion import PixelDiffusion @@ -134,8 +134,88 @@ def stable_diffusion_2( return model +class SDXLTextEncoder: + """Wrapper around HuggingFace text encoders for SDXL. + + Creates two text encoders (a CLIPTextModel and CLIPTextModelWithProjection) that behave like one. + + Args: + model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + """ + def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=True): + if encode_latents_in_fp16: + self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch.float16) + self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, + subfolder='text_encoder_2', + torch_dtype=torch.float16) + else: + self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder') + self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, subfolder='text_encoder_2') + self.device = self.text_encoder.device + + self._fsdp_wrap = False + self.text_encoder._fsdp_wrap = False + self.text_encoder_2._fsdp_wrap = False + + def requires_grad_(self, requires_grad): + self.text_encoder.requires_grad_(requires_grad) + self.text_encoder_2.requires_grad_(requires_grad) + + def half(self): + self.text_encoder.half() + self.text_encoder_2.half() + + def __call__(self, tokenized_output, tokenized_output_2): # TODO need to pass second tokenized outputs and handle pooled output + # first text encoder + conditioning = self.text_encoder(tokenized_output, output_hidden_states=True).hidden_states[-2] + # second text encoder + text_encoder_2_out = self.text_encoder_2(tokenized_output_2, output_hidden_states=True) + pooled_conditioning = text_encoder_2_out[0] # (batch_size, 1280) + conditioning_2 = text_encoder_2_out.hidden_states[-2] # (batch_size, 77, 1280) + + # # zero out the appropriate things + # if batch[self.text_key].sum() == 0: + # conditioning = torch.zeros_like(conditioning) + # if batch[self.text_key_2].sum() == 0: + # conditioning_2 = torch.zeros_like(conditioning_2) + # pooled_conditioning = torch.zeros_like(pooled_conditioning) + + conditioning = torch.concat([conditioning, conditioning_2], dim=-1) + return conditioning, pooled_conditioning + + +class SDXLTokenizer: + """Wrapper around HuggingFace tokenizers for SDXL. + + Tokenizes prompt with two tokenizers and returns the outputs as a list. + + Args: + model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. 
+ """ + def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'): + self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') + self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2') + + def __call__(self, prompt, padding, truncation, return_tensors, input_ids=False): + tokenized_output = self.tokenizer(prompt, + padding=padding, + max_length=self.tokenizer.model_max_length, + truncation=truncation, + return_tensors=return_tensors) + tokenized_output_2 = self.tokenizer_2(prompt, + padding=padding, + max_length=self.tokenizer_2.model_max_length, + truncation=truncation, + return_tensors=return_tensors) + if input_ids: + tokenized_output = tokenized_output.input_ids + tokenized_output_2 = tokenized_output_2.input_ids + return [tokenized_output, tokenized_output_2] + + def stable_diffusion_xl( - model_name: str = 'stabilityai/stable-diffusion-2-base', + model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', pretrained: bool = True, @@ -156,8 +236,8 @@ def stable_diffusion_xl( prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2. Args: - model_name (str): Name of the model to load. Determines the text encoder, tokenizer, - and noise scheduler. Defaults to 'stabilityai/stable-diffusion-2-base'. + model_name (str): Name of the model to load. Determines the text encoders, tokenizers, + and noise scheduler. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. unet_model_name (str): Name of the UNet model to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. vae_model_name (str): Name of the VAE model to load. 
Defaults to @@ -198,9 +278,6 @@ def stable_diffusion_xl( raise NotImplementedError('Full SDXL pipeline not implemented yet.') else: config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet') - # Currently not doing micro-conditioning, so set config appropriately - config[0]['addition_embed_type'] = None - config[0]['cross_attention_dim'] = 1024 unet = UNet2DConditionModel(**config[0]) # Zero initialization trick for more stable training @@ -215,22 +292,21 @@ def stable_diffusion_xl( unet.conv_out = zero_module(unet.conv_out) if encode_latents_in_fp16: - vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16) - text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch.float16) + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch.float16) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16) else: - vae = AutoencoderKL.from_pretrained(vae_model_name) - text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder') + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae') + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name) + + tokenizer = SDXLTokenizer(model_name) + text_encoder = SDXLTextEncoder(model_name, encode_latents_in_fp16) - tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler') - inference_noise_scheduler = DDIMScheduler(num_train_timesteps=noise_scheduler.config.num_train_timesteps, - beta_start=noise_scheduler.config.beta_start, - beta_end=noise_scheduler.config.beta_end, - beta_schedule=noise_scheduler.config.beta_schedule, - trained_betas=noise_scheduler.config.trained_betas, - clip_sample=noise_scheduler.config.clip_sample, - set_alpha_to_one=noise_scheduler.config.set_alpha_to_one, - prediction_type=prediction_type) + inference_noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder='scheduler') model = StableDiffusion( unet=unet, @@ -248,6 +324,7 @@ def stable_diffusion_xl( precomputed_latents=precomputed_latents, encode_latents_in_fp16=encode_latents_in_fp16, fsdp=fsdp, + sdxl=True, ) if torch.cuda.is_available(): model = DeviceGPU().module_to_device(model) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 01688c04..4236460f 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -62,6 +62,7 @@ class StableDiffusion(ComposerModel): Default: `False`. encode_latents_in_fp16 (bool): whether to encode latents in fp16. Default: `False`. + sdxl (bool): Whether or not we're training SDXL. Default: `False`. 
""" def __init__(self, @@ -84,7 +85,8 @@ def __init__(self, text_latents_key: str = 'caption_latents', precomputed_latents: bool = False, encode_latents_in_fp16: bool = False, - fsdp: bool = False): + fsdp: bool = False, + sdxl: bool = False): super().__init__() self.unet = unet self.vae = vae @@ -97,6 +99,11 @@ def __init__(self, self.image_key = image_key self.image_latents_key = image_latents_key self.precomputed_latents = precomputed_latents + self.sdxl = sdxl + if self.sdxl: + self.latent_scale = 0.13025 + else: + self.latent_scale = 0.18215 # setup metrics if train_metrics is None: @@ -157,9 +164,16 @@ def forward(self, batch): latents, conditioning = None, None # Use latents if specified and available. When specified, they might not exist during eval if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch: + if self.sdxl: + raise NotImplementedError('SDXL not yet supported with precomputed latents') latents, conditioning = batch[self.image_latents_key], batch[self.text_latents_key] else: inputs, conditioning = batch[self.image_key], batch[self.text_key] + if self.sdxl: + # TODO check this + conditioning, conditioning_2 = conditioning[0], conditioning[1] + else: + conditioning_2 = None conditioning = conditioning.view(-1, conditioning.shape[-1]) if self.encode_latents_in_fp16: # Disable autocast context as models are in fp16 @@ -167,13 +181,23 @@ def forward(self, batch): # Encode the images to the latent space. # Encode prompt into conditioning vector latents = self.vae.encode(inputs.half())['latent_dist'].sample().data - conditioning = self.text_encoder(conditioning)[0] # Should be (batch_size, 77, 768) - + # if self.sdxl: + # conditioning_2 = batch[self.text_key_2].view(-1, conditioning_2.shape[-1]) + # conditioning, pooled_conditioning = self.text_encoder(conditioning, conditioning_2) + # else: + # conditioning = self.text_encoder(conditioning)[0] # Should be (batch_size, 77, 768) + # pooled_conditioning = None else: latents = self.vae.encode(inputs)['latent_dist'].sample().data + + if self.sdxl: + conditioning_2 = conditioning_2.view(-1, conditioning_2.shape[-1]) + conditioning, pooled_conditioning = self.text_encoder(conditioning, conditioning_2) + else: conditioning = self.text_encoder(conditioning)[0] + pooled_conditioning = None # Magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515) - latents *= 0.18215 + latents *= self.latent_scale # Sample the diffusion timesteps timesteps = torch.randint(0, len(self.noise_scheduler), (latents.shape[0],), device=latents.device) @@ -190,8 +214,19 @@ def forward(self, batch): else: raise ValueError( f'prediction type must be one of sample, epsilon, or v_prediction. 
Got {self.prediction_type}') + + added_cond_kwargs = {} + # if using SDXL, prepare added time ids & embeddings + if self.sdxl: + # TODO double check cond_crops_coords_top_left calc in transforms.py + add_time_ids = torch.cat( + [batch['cond_original_size'], batch['cond_crops_coords_top_left'], batch['cond_target_size']], dim=1) + add_text_embeds = pooled_conditioning + added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids} + # Forward through the model - return self.unet(noised_latents, timesteps, conditioning)['sample'], targets, timesteps + return self.unet(noised_latents, timesteps, conditioning, + added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps def loss(self, outputs, batch): """Loss between unet output and added noise, typically mse.""" @@ -207,6 +242,18 @@ def eval_forward(self, batch, outputs=None): # Sample images from the prompts in the batch prompts = batch[self.text_key] height, width = batch[self.image_key].shape[-2], batch[self.image_key].shape[-1] + + # If SDXL, add eval-time micro-conditioning to batch + if self.sdxl: + device = self.unet.device + bsz = batch[self.image_key].shape[0] + # Set to resolution we are trying to generate + batch['cond_original_size'] = torch.tensor([[width, height]]).repeat(bsz, 1).to(device) + # No cropping + batch['cond_crops_coords_top_left'] = torch.tensor([[0., 0.]]).repeat(bsz, 1).to(device) + # Set to resolution we are trying to generate + batch['cond_target_size'] = torch.tensor([[width, height]]).repeat(bsz, 1).to(device) + generated_images = {} for guidance_scale in self.val_guidance_scales: gen_images = self.generate(tokenized_prompts=prompts, @@ -261,7 +308,11 @@ def update_metric(self, batch, outputs, metric): # CLIP metrics should be updated with the generated images at the desired guidance scale elif metric.__class__.__name__ == 'CLIPScore': # Convert the captions to a list of strings - captions = [self.tokenizer.decode(caption, skip_special_tokens=True) for caption in batch[self.text_key]] + if self.sdxl: + # Decode captions with first tokenizer + captions = [self.tokenizer.tokenizer.decode(caption, skip_special_tokens=True) for caption in batch[self.text_key]] + else: + captions = [self.tokenizer.decode(caption, skip_special_tokens=True) for caption in batch[self.text_key]] generated_images = (outputs[3][metric.guidance_scale] * 255).to(torch.uint8) metric.update(generated_images, captions) else: @@ -295,8 +346,9 @@ def generate( image generation away from. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1). Must be the same length as list of prompts. Default: `None`. - tokenized_prompts (torch.LongTensor): Optionally pass pre-tokenized prompts instead - of string prompts. Default: `None`. + tokenized_prompts (torch.LongTensor) or List[torch.LongTensor]: Optionally pass + pre-tokenized prompts instead of string prompts. If SDXL, this will be a list + of two pre-tokenized prompts. Default: `None`. tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative prompts instead of string prompts. Default: `None`. prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead @@ -387,7 +439,7 @@ def generate( # We now use the vae to decode the generated latents back into the image. 
# scale and decode the image latents with vae - latents = 1 / 0.18215 * latents + latents = 1 / self.latent_scale * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) return image.detach() # (batch*num_images_per_prompt, channel, h, w) @@ -396,21 +448,40 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num """Tokenizes and embeds prompts if needed, then duplicates embeddings to support multiple generations per prompt.""" device = self.text_encoder.device if prompt_embeds is None: - if tokenized_prompts is None: - tokenized_prompts = self.tokenizer(prompt, - padding='max_length', - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors='pt').input_ids - text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0] # type: ignore + if self.sdxl: + if tokenized_prompts is None: + tokenized_prompts = self.tokenizer(prompt, + padding='max_length', + truncation=True, + return_tensors='pt', + input_ids=True) + # TODO implement zero-ing out empty prompts! + text_embeddings, pooled_text_embeddings = self.text_encoder( + tokenized_prompts[0].to(device), + tokenized_prompts[1].to(device)) # type: ignore + else: + if tokenized_prompts is None: + tokenized_prompts = self.tokenizer(prompt, + padding='max_length', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt').input_ids + text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0] # type: ignore + pooled_text_embeddings = None else: + if self.sdxl: + raise NotImplementedError('SDXL not yet supported with precomputed embeddings') text_embeddings = prompt_embeds # duplicate text embeddings for each generation per prompt bs_embed, seq_len, _ = text_embeddings.shape text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # type: ignore text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - return text_embeddings + + if self.sdxl and pooled_text_embeddings is not None: + pooled_text_embeddings = pooled_text_embeddings.repeat(1, num_images_per_prompt) + pooled_text_embeddings = pooled_text_embeddings.view(bs_embed * num_images_per_prompt, -1) + return text_embeddings, pooled_text_embeddings def _check_prompt_lenths(prompt, negative_prompt): From 45c1ac0d287624e6fbe50be2a35fe05f7442d4cd Mon Sep 17 00:00:00 2001 From: jazcollins Date: Thu, 21 Sep 2023 14:35:43 -0700 Subject: [PATCH 07/30] fix style and pyright --- diffusion/callbacks/log_diffusion_images.py | 9 +++--- diffusion/datasets/image_caption.py | 14 +++++---- diffusion/models/models.py | 15 ++++++--- diffusion/models/stable_diffusion.py | 35 ++++++++++++--------- 4 files changed, 42 insertions(+), 31 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index 38aa4ae6..088e436c 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -65,12 +65,11 @@ def eval_batch_end(self, state: State, logger: Logger): if model.sdxl: if self.tokenized_prompts is None: tokenized_prompts = [ - model.tokenizer(p, padding='max_length', truncation=True, - return_tensors='pt', input_ids=True) # type: ignore + model.tokenizer(p, padding='max_length', truncation=True, return_tensors='pt', + input_ids=True) # type: ignore for p in self.prompts ] - self.tokenized_prompts = [torch.cat(tokenized_prompts[0]), - torch.cat(tokenized_prompts[1])] + self.tokenized_prompts = [torch.cat(tokenized_prompts[0]), 
torch.cat(tokenized_prompts[1])] self.tokenized_prompts[0] = self.tokenized_prompts[0].to(state.batch[self.text_key].device) self.tokenized_prompts[1] = self.tokenized_prompts[1].to(state.batch[self.text_key].device) else: @@ -81,7 +80,7 @@ def eval_batch_end(self, state: State, logger: Logger): for p in self.prompts ] self.tokenized_prompts = torch.cat(tokenized_prompts) - self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) + self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore # Generate images with get_precision_context(state.precision): diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 2aef4f00..50df8bed 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -68,6 +68,7 @@ def __init__( raise ValueError(f'Invalid caption selection: {caption_selection}. Must be one of [random, first]') self.transform = transform + self.sdxl = sdxl if self.sdxl: self.tokenizer = SDXLTokenizer(tokenizer_name_or_path) self.sdxl_crop = RandomCropSquareReturnTransform(image_size) @@ -92,7 +93,7 @@ def __getitem__(self, index): out = {} # Image transforms - if self.sdxl: + if self.sdxl and self.sdxl_crop: # sdxl crop to return params img, crop_top, crop_left, image_height, image_width = self.sdxl_crop(img) out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) @@ -124,11 +125,12 @@ def __getitem__(self, index): tokenized_captions = [cap[0] for cap in tokenized_captions] tokenized_caption = torch.stack(tokenized_captions) else: - tokenized_caption = self.tokenizer(caption, - padding='max_length', - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors='pt')['input_ids'][0] + tokenized_caption = self.tokenizer( + caption, + padding='max_length', + max_length=self.tokenizer.model_max_length, # type: ignore + truncation=True, + return_tensors='pt')['input_ids'][0] out['image'] = img out['captions'] = tokenized_caption return out diff --git a/diffusion/models/models.py b/diffusion/models/models.py index c0fe865d..2a9e3431 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -143,11 +143,14 @@ class SDXLTextEncoder: model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. 
""" + def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=True): if encode_latents_in_fp16: - self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch.float16) - self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, - subfolder='text_encoder_2', + self.text_encoder = CLIPTextModel.from_pretrained(model_name, + subfolder='text_encoder', + torch_dtype=torch.float16) + self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, + subfolder='text_encoder_2', torch_dtype=torch.float16) else: self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder') @@ -166,7 +169,8 @@ def half(self): self.text_encoder.half() self.text_encoder_2.half() - def __call__(self, tokenized_output, tokenized_output_2): # TODO need to pass second tokenized outputs and handle pooled output + def __call__(self, tokenized_output, + tokenized_output_2): # TODO need to pass second tokenized outputs and handle pooled output # first text encoder conditioning = self.text_encoder(tokenized_output, output_hidden_states=True).hidden_states[-2] # second text encoder @@ -193,6 +197,7 @@ class SDXLTokenizer: Args: model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. """ + def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'): self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2') @@ -299,7 +304,7 @@ def stable_diffusion_xl( else: try: vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae') - except: # for handling SDXL vae fp16 fixed checkpoint + except: # for handling SDXL vae fp16 fixed checkpoint vae = AutoencoderKL.from_pretrained(vae_model_name) tokenizer = SDXLTokenizer(model_name) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 4236460f..9db6fd62 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -161,7 +161,7 @@ def __init__(self, self.unet._fsdp_wrap = True def forward(self, batch): - latents, conditioning = None, None + latents, conditioning, pooled_conditioning = None, None, None # Use latents if specified and available. When specified, they might not exist during eval if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch: if self.sdxl: @@ -191,11 +191,12 @@ def forward(self, batch): latents = self.vae.encode(inputs)['latent_dist'].sample().data if self.sdxl: - conditioning_2 = conditioning_2.view(-1, conditioning_2.shape[-1]) + assert conditioning_2 is not None + conditioning_2 = conditioning_2.view(-1, conditioning_2.shape[-1]) conditioning, pooled_conditioning = self.text_encoder(conditioning, conditioning_2) else: conditioning = self.text_encoder(conditioning)[0] - pooled_conditioning = None + # Magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515) latents *= self.latent_scale @@ -214,7 +215,7 @@ def forward(self, batch): else: raise ValueError( f'prediction type must be one of sample, epsilon, or v_prediction. 
Got {self.prediction_type}') - + added_cond_kwargs = {} # if using SDXL, prepare added time ids & embeddings if self.sdxl: @@ -253,7 +254,7 @@ def eval_forward(self, batch, outputs=None): batch['cond_crops_coords_top_left'] = torch.tensor([[0., 0.]]).repeat(bsz, 1).to(device) # Set to resolution we are trying to generate batch['cond_target_size'] = torch.tensor([[width, height]]).repeat(bsz, 1).to(device) - + generated_images = {} for guidance_scale in self.val_guidance_scales: gen_images = self.generate(tokenized_prompts=prompts, @@ -310,9 +311,14 @@ def update_metric(self, batch, outputs, metric): # Convert the captions to a list of strings if self.sdxl: # Decode captions with first tokenizer - captions = [self.tokenizer.tokenizer.decode(caption, skip_special_tokens=True) for caption in batch[self.text_key]] + captions = [ + self.tokenizer.tokenizer.decode(caption, skip_special_tokens=True) + for caption in batch[self.text_key] + ] else: - captions = [self.tokenizer.decode(caption, skip_special_tokens=True) for caption in batch[self.text_key]] + captions = [ + self.tokenizer.decode(caption, skip_special_tokens=True) for caption in batch[self.text_key] + ] generated_images = (outputs[3][metric.guidance_scale] * 255).to(torch.uint8) metric.update(generated_images, captions) else: @@ -346,7 +352,7 @@ def generate( image generation away from. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1). Must be the same length as list of prompts. Default: `None`. - tokenized_prompts (torch.LongTensor) or List[torch.LongTensor]: Optionally pass + tokenized_prompts (torch.LongTensor or List[torch.LongTensor]): Optionally pass pre-tokenized prompts instead of string prompts. If SDXL, this will be a list of two pre-tokenized prompts. Default: `None`. tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative @@ -447,18 +453,18 @@ def generate( def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num_images_per_prompt): """Tokenizes and embeds prompts if needed, then duplicates embeddings to support multiple generations per prompt.""" device = self.text_encoder.device + pooled_text_embeddings = None if prompt_embeds is None: if self.sdxl: if tokenized_prompts is None: tokenized_prompts = self.tokenizer(prompt, - padding='max_length', - truncation=True, - return_tensors='pt', - input_ids=True) + padding='max_length', + truncation=True, + return_tensors='pt', + input_ids=True) # TODO implement zero-ing out empty prompts! 
text_embeddings, pooled_text_embeddings = self.text_encoder( - tokenized_prompts[0].to(device), - tokenized_prompts[1].to(device)) # type: ignore + tokenized_prompts[0].to(device), tokenized_prompts[1].to(device)) # type: ignore else: if tokenized_prompts is None: tokenized_prompts = self.tokenizer(prompt, @@ -467,7 +473,6 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num truncation=True, return_tensors='pt').input_ids text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0] # type: ignore - pooled_text_embeddings = None else: if self.sdxl: raise NotImplementedError('SDXL not yet supported with precomputed embeddings') From e873717be862507260930aa60ff939e87687e7c0 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Thu, 21 Sep 2023 14:53:09 -0700 Subject: [PATCH 08/30] print sdxl statement --- diffusion/datasets/image_caption.py | 1 + 1 file changed, 1 insertion(+) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 50df8bed..87e7fa0d 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -194,6 +194,7 @@ def build_streaming_image_caption_dataloader( # Infer SDXL from tokenizer path if tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0': + print('Detected SDXL tokenizer, using SDXL crop transform and tokenizers.') sdxl = True else: sdxl = False From d93fbdb9bb6720e53d98ce090658981de82b43e7 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Thu, 21 Sep 2023 17:35:25 -0700 Subject: [PATCH 09/30] add sdxl logic to generate --- diffusion/models/stable_diffusion.py | 53 ++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 9db6fd62..303d4a78 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -340,6 +340,9 @@ def generate( num_images_per_prompt: Optional[int] = 1, seed: Optional[int] = None, progress_bar: Optional[bool] = True, + zero_out_negative_prompt: bool = True, + crop_params: Optional[list] = None, + size_params: Optional[list] = None, ): """Generates image from noise. @@ -378,10 +381,16 @@ def generate( Default: `3.0`. num_images_per_prompt (int): The number of images to generate per prompt. Default: `1`. - progress_bar (bool): Wether to use the tqdm progress bar during generation. + progress_bar (bool): Whether to use the tqdm progress bar during generation. Default: `True`. seed (int): Random seed to use for generation. Set a seed for reproducible generation. Default: `None`. + zero_out_negative_prompt (bool): Whether or not to zero out negative prompt if it is + an empty string. Default: `True`. + crop_params (list, optional): Crop parameters to use when generating images with SDXL. + Default: `None`. + size_params (list, optional): Size parameters to use when generating images with SDXL. + Default: `None`. 
""" _check_prompt_given(prompt, tokenized_prompts, prompt_embeds) _check_prompt_lenths(prompt, negative_prompt) @@ -402,16 +411,28 @@ def generate( do_classifier_free_guidance = guidance_scale > 1.0 # type: ignore - text_embeddings = self._prepare_text_embeddings(prompt, tokenized_prompts, prompt_embeds, num_images_per_prompt) + text_embeddings, pooled_text_embeddings = self._prepare_text_embeddings(prompt, tokenized_prompts, + prompt_embeds, num_images_per_prompt) batch_size = len(text_embeddings) # len prompts * num_images_per_prompt # classifier free guidance + negative prompts # negative prompt is given in place of the unconditional input in classifier free guidance + pooled_embeddings = None if do_classifier_free_guidance: - negative_prompt = negative_prompt or ([''] * (batch_size // num_images_per_prompt)) # type: ignore - unconditional_embeddings = self._prepare_text_embeddings(negative_prompt, tokenized_negative_prompts, - negative_prompt_embeds, num_images_per_prompt) + if negative_prompt_embeds is None and zero_out_negative_prompt: + unconditional_embeddings = torch.zeros_like(text_embeddings) + if pooled_text_embeddings is not None: + pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) + else: + pooled_unconditional_embeddings = None + else: + negative_prompt = negative_prompt or ([''] * (batch_size // num_images_per_prompt)) # type: ignore + unconditional_embeddings, pooled_unconditional_embeddings = self._prepare_text_embeddings( + negative_prompt, tokenized_negative_prompts, negative_prompt_embeds, num_images_per_prompt) + # concat uncond + prompt text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]) + if self.sdxl: + pooled_embeddings = torch.cat([pooled_unconditional_embeddings, pooled_text_embeddings]) # type: ignore # prepare for diffusion generation process latents = torch.randn( @@ -424,6 +445,23 @@ def generate( # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.inference_scheduler.init_noise_sigma + added_cond_kwargs = {} + # if using SDXL, prepare added time ids & embeddings + if self.sdxl and pooled_embeddings is not None: + if not crop_params: + crop_params = [0., 0.] + if not size_params: + size_params = [width, height] + cond_original_size = torch.tensor([[width, height]]).repeat(pooled_embeddings.shape[0], + 1).to(device).float() + cond_crops_coords_top_left = torch.tensor([crop_params]).repeat(pooled_embeddings.shape[0], + 1).to(device).float() + cond_target_size = torch.tensor([size_params]).repeat(pooled_embeddings.shape[0], 1).to(device).float() + add_time_ids = torch.cat([cond_original_size, cond_crops_coords_top_left, cond_target_size], dim=1).float() + add_text_embeds = pooled_embeddings + + added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids} + # backward diffusion process for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): if do_classifier_free_guidance: @@ -433,7 +471,10 @@ def generate( latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t) # Model prediction - pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + pred = self.unet(latent_model_input, + t, + encoder_hidden_states=text_embeddings, + added_cond_kwargs=added_cond_kwargs).sample if do_classifier_free_guidance: # perform guidance. 
Note this is only techincally correct for prediction_type 'epsilon' From 75db76f8223eff84ff1734c2ed79a8320c98f4ef Mon Sep 17 00:00:00 2001 From: jazcollins Date: Thu, 21 Sep 2023 21:23:18 -0700 Subject: [PATCH 10/30] allow setting SDXLTextEncoder device --- diffusion/models/models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 2a9e3431..9bddc701 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -169,8 +169,12 @@ def half(self): self.text_encoder.half() self.text_encoder_2.half() - def __call__(self, tokenized_output, - tokenized_output_2): # TODO need to pass second tokenized outputs and handle pooled output + def to_device(self, composer_device): + self.text_encoder = composer_device.module_to_device(self.text_encoder) + self.text_encoder_2 = composer_device.module_to_device(self.text_encoder_2) + self.device = self.text_encoder.device + + def __call__(self, tokenized_output, tokenized_output_2): # first text encoder conditioning = self.text_encoder(tokenized_output, output_hidden_states=True).hidden_states[-2] # second text encoder @@ -337,6 +341,9 @@ def stable_diffusion_xl( model.unet.enable_xformers_memory_efficient_attention() model.vae.enable_xformers_memory_efficient_attention() + # Manually set text encoders to device + text_encoder.to_device(DeviceGPU()) + if clip_qkv is not None: if is_xformers_installed: attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) From 26a133d3105e0816812062103915b76911aee785 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Sun, 24 Sep 2023 18:12:30 +0000 Subject: [PATCH 11/30] sdxltextencoder edits --- diffusion/models/models.py | 163 ++++++++++++--------------- diffusion/models/stable_diffusion.py | 7 +- 2 files changed, 74 insertions(+), 96 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 9bddc701..b2939650 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -134,95 +134,6 @@ def stable_diffusion_2( return model -class SDXLTextEncoder: - """Wrapper around HuggingFace text encoders for SDXL. - - Creates two text encoders (a CLIPTextModel and CLIPTextModelWithProjection) that behave like one. - - Args: - model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. - encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. 
- """ - - def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=True): - if encode_latents_in_fp16: - self.text_encoder = CLIPTextModel.from_pretrained(model_name, - subfolder='text_encoder', - torch_dtype=torch.float16) - self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, - subfolder='text_encoder_2', - torch_dtype=torch.float16) - else: - self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder') - self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, subfolder='text_encoder_2') - self.device = self.text_encoder.device - - self._fsdp_wrap = False - self.text_encoder._fsdp_wrap = False - self.text_encoder_2._fsdp_wrap = False - - def requires_grad_(self, requires_grad): - self.text_encoder.requires_grad_(requires_grad) - self.text_encoder_2.requires_grad_(requires_grad) - - def half(self): - self.text_encoder.half() - self.text_encoder_2.half() - - def to_device(self, composer_device): - self.text_encoder = composer_device.module_to_device(self.text_encoder) - self.text_encoder_2 = composer_device.module_to_device(self.text_encoder_2) - self.device = self.text_encoder.device - - def __call__(self, tokenized_output, tokenized_output_2): - # first text encoder - conditioning = self.text_encoder(tokenized_output, output_hidden_states=True).hidden_states[-2] - # second text encoder - text_encoder_2_out = self.text_encoder_2(tokenized_output_2, output_hidden_states=True) - pooled_conditioning = text_encoder_2_out[0] # (batch_size, 1280) - conditioning_2 = text_encoder_2_out.hidden_states[-2] # (batch_size, 77, 1280) - - # # zero out the appropriate things - # if batch[self.text_key].sum() == 0: - # conditioning = torch.zeros_like(conditioning) - # if batch[self.text_key_2].sum() == 0: - # conditioning_2 = torch.zeros_like(conditioning_2) - # pooled_conditioning = torch.zeros_like(pooled_conditioning) - - conditioning = torch.concat([conditioning, conditioning_2], dim=-1) - return conditioning, pooled_conditioning - - -class SDXLTokenizer: - """Wrapper around HuggingFace tokenizers for SDXL. - - Tokenizes prompt with two tokenizers and returns the outputs as a list. - - Args: - model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. 
- """ - - def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'): - self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') - self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2') - - def __call__(self, prompt, padding, truncation, return_tensors, input_ids=False): - tokenized_output = self.tokenizer(prompt, - padding=padding, - max_length=self.tokenizer.model_max_length, - truncation=truncation, - return_tensors=return_tensors) - tokenized_output_2 = self.tokenizer_2(prompt, - padding=padding, - max_length=self.tokenizer_2.model_max_length, - truncation=truncation, - return_tensors=return_tensors) - if input_ids: - tokenized_output = tokenized_output.input_ids - tokenized_output_2 = tokenized_output_2.input_ids - return [tokenized_output, tokenized_output_2] - - def stable_diffusion_xl( model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', @@ -341,9 +252,6 @@ def stable_diffusion_xl( model.unet.enable_xformers_memory_efficient_attention() model.vae.enable_xformers_memory_efficient_attention() - # Manually set text encoders to device - text_encoder.to_device(DeviceGPU()) - if clip_qkv is not None: if is_xformers_installed: attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) @@ -468,3 +376,74 @@ def continuous_pixel_diffusion(clip_model_name: str = 'openai/clip-vit-large-pat if is_xformers_installed: model.model.enable_xformers_memory_efficient_attention() return model + + +class SDXLTextEncoder(torch.nn.Module): + """Wrapper around HuggingFace text encoders for SDXL. + + Creates two text encoders (a CLIPTextModel and CLIPTextModelWithProjection) that behave like one. + + Args: + model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. + encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True. + """ + + def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=True): + super().__init__() + torch_dtype = torch.float16 if encode_latents_in_fp16 else None + self.text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch_dtype) + self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, + subfolder='text_encoder_2', + torch_dtype=torch_dtype) + + @property + def device(self): + return self.text_encoder.device + + def forward(self, tokenized_text): + # first text encoder + conditioning = self.text_encoder(tokenized_text[0], output_hidden_states=True).hidden_states[-2] + # second text encoder + text_encoder_2_out = self.text_encoder_2(tokenized_text[1], output_hidden_states=True) + pooled_conditioning = text_encoder_2_out[0] # (batch_size, 1280) + conditioning_2 = text_encoder_2_out.hidden_states[-2] # (batch_size, 77, 1280) + + # # zero out the appropriate things + # if batch[self.text_key].sum() == 0: + # conditioning = torch.zeros_like(conditioning) + # if batch[self.text_key_2].sum() == 0: + # conditioning_2 = torch.zeros_like(conditioning_2) + # pooled_conditioning = torch.zeros_like(pooled_conditioning) + + conditioning = torch.concat([conditioning, conditioning_2], dim=-1) + return conditioning, pooled_conditioning + + +class SDXLTokenizer: + """Wrapper around HuggingFace tokenizers for SDXL. + + Tokenizes prompt with two tokenizers and returns the outputs as a list. + + Args: + model_name (str): Name of the model's text encoders to load. 
Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. + """ + + def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'): + self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') + self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2') + + def __call__(self, prompt, padding, truncation, return_tensors, input_ids=False): + tokenized_output = self.tokenizer(prompt, + padding=padding, + max_length=self.tokenizer.model_max_length, + truncation=truncation, + return_tensors=return_tensors) + tokenized_output_2 = self.tokenizer_2(prompt, + padding=padding, + max_length=self.tokenizer_2.model_max_length, + truncation=truncation, + return_tensors=return_tensors) + if input_ids: + tokenized_output = tokenized_output.input_ids + tokenized_output_2 = tokenized_output_2.input_ids + return [tokenized_output, tokenized_output_2] diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 303d4a78..c3c3681f 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -193,7 +193,7 @@ def forward(self, batch): if self.sdxl: assert conditioning_2 is not None conditioning_2 = conditioning_2.view(-1, conditioning_2.shape[-1]) - conditioning, pooled_conditioning = self.text_encoder(conditioning, conditioning_2) + conditioning, pooled_conditioning = self.text_encoder([conditioning, conditioning_2]) else: conditioning = self.text_encoder(conditioning)[0] @@ -219,7 +219,6 @@ def forward(self, batch): added_cond_kwargs = {} # if using SDXL, prepare added time ids & embeddings if self.sdxl: - # TODO double check cond_crops_coords_top_left calc in transforms.py add_time_ids = torch.cat( [batch['cond_original_size'], batch['cond_crops_coords_top_left'], batch['cond_target_size']], dim=1) add_text_embeds = pooled_conditioning @@ -312,7 +311,7 @@ def update_metric(self, batch, outputs, metric): if self.sdxl: # Decode captions with first tokenizer captions = [ - self.tokenizer.tokenizer.decode(caption, skip_special_tokens=True) + self.tokenizer.tokenizer.decode(caption[0], skip_special_tokens=True) for caption in batch[self.text_key] ] else: @@ -505,7 +504,7 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num input_ids=True) # TODO implement zero-ing out empty prompts! 
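A rough usage sketch of the SDXLTokenizer / SDXLTextEncoder pair defined above, to make the shapes concrete (this assumes the pretrained SDXL base checkpoints are available; the 768-wide hidden state of the first encoder is an assumption, the 1280 values come from the comments above):

    from diffusion.models.models import SDXLTextEncoder, SDXLTokenizer

    tokenizer = SDXLTokenizer('stabilityai/stable-diffusion-xl-base-1.0')
    text_encoder = SDXLTextEncoder('stabilityai/stable-diffusion-xl-base-1.0', encode_latents_in_fp16=False)

    # input_ids=True returns just the two [1, 77] id tensors, one per tokenizer.
    ids_1, ids_2 = tokenizer('a photo of a dog', padding='max_length', truncation=True,
                             return_tensors='pt', input_ids=True)
    conditioning, pooled = text_encoder([ids_1, ids_2])
    # conditioning: [1, 77, 768 + 1280] (hidden states concatenated along the feature dim)
    # pooled:       [1, 1280] (projection output of the second text encoder)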
text_embeddings, pooled_text_embeddings = self.text_encoder( - tokenized_prompts[0].to(device), tokenized_prompts[1].to(device)) # type: ignore + [tokenized_prompts[0].to(device), tokenized_prompts[1].to(device)]) # type: ignore else: if tokenized_prompts is None: tokenized_prompts = self.tokenizer(prompt, From 320c7a128c9068164cd5cf089cb5cca567d8b13a Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 25 Sep 2023 17:05:56 +0000 Subject: [PATCH 12/30] split conditioning --- diffusion/datasets/image_caption.py | 3 ++- diffusion/models/models.py | 2 +- diffusion/models/stable_diffusion.py | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 87e7fa0d..35511c2b 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -123,7 +123,8 @@ def __getitem__(self, index): return_tensors='pt', input_ids=True) tokenized_captions = [cap[0] for cap in tokenized_captions] - tokenized_caption = torch.stack(tokenized_captions) + tokenized_caption = torch.stack(tokenized_captions) + print('stacked shape:', tokenized_caption.shape) else: tokenized_caption = self.tokenizer( caption, diff --git a/diffusion/models/models.py b/diffusion/models/models.py index b2939650..5d11ee38 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -395,7 +395,7 @@ def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_name, subfolder='text_encoder_2', torch_dtype=torch_dtype) - + @property def device(self): return self.text_encoder.device diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index c3c3681f..a073131e 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -171,7 +171,8 @@ def forward(self, batch): inputs, conditioning = batch[self.image_key], batch[self.text_key] if self.sdxl: # TODO check this - conditioning, conditioning_2 = conditioning[0], conditioning[1] + print('batch of tokens shape:', conditioning.shape) + conditioning, conditioning_2 = conditioning[:,0,:], conditioning[:,1,:] # [B, 2, 77] else: conditioning_2 = None conditioning = conditioning.view(-1, conditioning.shape[-1]) From c2d03211726a084f1eab7f9aaf5a7ee99fb899b0 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 25 Sep 2023 17:25:39 +0000 Subject: [PATCH 13/30] remove prints --- diffusion/datasets/image_caption.py | 8 +++++--- diffusion/models/stable_diffusion.py | 2 -- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 35511c2b..dabad032 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -4,6 +4,7 @@ """Streaming Image-Caption dataset.""" import random +import logging from io import BytesIO from typing import Callable, Dict, List, Optional, Sequence, Union @@ -17,6 +18,8 @@ from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropSquare, RandomCropSquareReturnTransform from diffusion.models.models import SDXLTokenizer +log = logging.getLogger(__name__) + # Disable PIL max image size limit Image.MAX_IMAGE_PIXELS = None @@ -123,8 +126,7 @@ def __getitem__(self, index): return_tensors='pt', input_ids=True) tokenized_captions = [cap[0] for cap in tokenized_captions] - tokenized_caption = torch.stack(tokenized_captions) - print('stacked shape:', tokenized_caption.shape) + 
tokenized_caption = torch.stack(tokenized_captions) else: tokenized_caption = self.tokenizer( caption, @@ -195,7 +197,7 @@ def build_streaming_image_caption_dataloader( # Infer SDXL from tokenizer path if tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0': - print('Detected SDXL tokenizer, using SDXL crop transform and tokenizers.') + log.info('Detected SDXL tokenizer, using SDXL crop transform and tokenizers.') sdxl = True else: sdxl = False diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index a073131e..0557aa0c 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -170,8 +170,6 @@ def forward(self, batch): else: inputs, conditioning = batch[self.image_key], batch[self.text_key] if self.sdxl: - # TODO check this - print('batch of tokens shape:', conditioning.shape) conditioning, conditioning_2 = conditioning[:,0,:], conditioning[:,1,:] # [B, 2, 77] else: conditioning_2 = None From 12217fc9686573c1a103f48fa78887729e18135e Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 25 Sep 2023 20:59:16 +0000 Subject: [PATCH 14/30] microconditioning and cleaning up comments --- diffusion/datasets/image_caption.py | 21 +++++++++++++++--- diffusion/models/models.py | 31 +++++++++++---------------- diffusion/models/stable_diffusion.py | 32 ++++++++++++---------------- 3 files changed, 44 insertions(+), 40 deletions(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index dabad032..d52c5cae 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -33,6 +33,8 @@ class StreamingImageCaptionDataset(StreamingDataset): remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``. local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``. tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``. + caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``. + microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``. caption_selection (str): If there are multiple captions, specifies how to select a single caption. 'first' selects the first caption in the list and 'random' selects a random caption in the list. If there is only one caption, this argument is ignored. Default: ``'first'``. 
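For context, the caption dropout governed by caption_drop_prob is what enables classifier-free guidance later on: with probability p the caption is swapped for an empty string before tokenization, so the model occasionally trains on an unconditional, padding-only prompt. A small self-contained sketch of the idea, assuming the default SD2 tokenizer this dataset uses (the drop probability is illustrative):

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-2-base', subfolder='tokenizer')

    caption, caption_drop_prob = 'a photo of a dog', 0.1
    if torch.rand(1) < caption_drop_prob:
        caption = ''  # dropped caption -> unconditional training example
    tokenized = tokenizer(caption,
                          padding='max_length',
                          max_length=tokenizer.model_max_length,
                          truncation=True,
                          return_tensors='pt')['input_ids']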
@@ -51,6 +53,7 @@ def __init__( local: Optional[str] = None, tokenizer_name_or_path: str = 'stabilityai/stable-diffusion-2-base', caption_drop_prob: float = 0.0, + microcond_drop_prob: float = 0.0, caption_selection: str = 'first', transform: Optional[Callable] = None, image_size: Optional[int] = None, @@ -79,6 +82,7 @@ def __init__( self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer') self.sdxl_crop = None self.caption_drop_prob = caption_drop_prob + self.microcond_drop_prob = microcond_drop_prob self.caption_selection = caption_selection self.image_size = image_size self.image_key = image_key @@ -97,18 +101,24 @@ def __getitem__(self, index): out = {} # Image transforms if self.sdxl and self.sdxl_crop: - # sdxl crop to return params img, crop_top, crop_left, image_height, image_width = self.sdxl_crop(img) out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) out['cond_original_size'] = torch.tensor([image_width, image_height]) out['cond_target_size'] = torch.tensor([self.image_size, self.image_size]) + + # Microconditioning dropout as in Stability repo + # https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0.0 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_original_size'] = out['cond_original_size'] * 0.0 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_target_size'] = out['cond_target_size'] * 0.0 else: crop_top, crop_left, image_height, image_width = None, None, None, None if self.transform is not None: img = self.transform(img) - # TODO implement dropped caption masking! - # Caption if torch.rand(1) < self.caption_drop_prob: caption = '' @@ -145,6 +155,7 @@ def build_streaming_image_caption_dataloader( batch_size: int, tokenizer_name_or_path: str = 'stabilityai/stable-diffusion-2-base', caption_drop_prob: float = 0.0, + microcond_drop_prob: float = 0.0, resize_size: int = 256, caption_selection: str = 'first', transform: Optional[List[Callable]] = None, @@ -153,6 +164,7 @@ def build_streaming_image_caption_dataloader( rand_crop: bool = False, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, + ): """Builds a streaming LAION dataloader. @@ -162,6 +174,7 @@ def build_streaming_image_caption_dataloader( batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``. tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``. caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``. + microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``. resize_size (int): The size to resize the image to. Default: ``256``. caption_selection (str): If there are multiple captions, specifies how to select a single caption. 'first' selects the first caption in the list and 'random' selects a random caption in the list. 
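To make the micro-conditioning values concrete, here is a worked example with the return-transform used in __getitem__ above (numbers are approximate and only for illustration; the import assumes this repo is installed as a package):

    from PIL import Image
    from diffusion.datasets.laion.transforms import RandomCropSquareReturnTransform

    # A 1024x768 image with size=256: the resize maps the short side to 256 (roughly 341x256),
    # so the 256x256 random crop has crop_top == 0 and crop_left somewhere in [0, ~85].
    transform = RandomCropSquareReturnTransform(256)
    img = Image.new('RGB', (1024, 768))
    img, crop_top, crop_left, orig_h, orig_w = transform(img)

    # The dataset then stores, per sample:
    #   cond_crops_coords_top_left = [crop_top, crop_left]   e.g. [0, 37]
    #   cond_original_size         = [orig_w, orig_h]        i.e. [1024, 768]
    #   cond_target_size           = [256, 256]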
@@ -224,6 +237,7 @@ def build_streaming_image_caption_dataloader( streams=streams, tokenizer_name_or_path=tokenizer_name_or_path, caption_drop_prob=caption_drop_prob, + microcond_drop_prob=microcond_drop_prob, caption_selection=caption_selection, transform=transform, image_size=resize_size, @@ -234,6 +248,7 @@ def build_streaming_image_caption_dataloader( **streaming_kwargs, ) + dataloader = DataLoader( dataset=dataset, batch_size=batch_size, diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 5d11ee38..31fa6807 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -6,6 +6,7 @@ from typing import List, Optional import torch +import logging from composer.devices import DeviceGPU from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel from torchmetrics import MeanSquaredError @@ -25,6 +26,8 @@ except: is_xformers_installed = False +log = logging.getLogger(__name__) + def stable_diffusion_2( model_name: str = 'stabilityai/stable-diffusion-2-base', @@ -129,6 +132,7 @@ def stable_diffusion_2( attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) else: attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + log.info('Using %s with clip_val %.1f'%(attn_processor.__class__, clip_qkv)) model.unet.set_attn_processor(attn_processor) return model @@ -195,12 +199,12 @@ def stable_diffusion_xl( metric.requires_grad_(False) if pretrained: - raise NotImplementedError('Full SDXL pipeline not implemented yet.') + unet = UNet2DConditionModel.from_pretrained(model_name, subfolder='unet') else: config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet') unet = UNet2DConditionModel(**config[0]) - # Zero initialization trick for more stable training + # Zero initialization trick for name, layer in unet.named_modules(): # Final conv in ResNet blocks if name.endswith('conv2'): @@ -211,16 +215,11 @@ def stable_diffusion_xl( # Last conv block out projection unet.conv_out = zero_module(unet.conv_out) - if encode_latents_in_fp16: - try: - vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch.float16) - except: # for handling SDXL vae fp16 fixed checkpoint - vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16) - else: - try: - vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae') - except: # for handling SDXL vae fp16 fixed checkpoint - vae = AutoencoderKL.from_pretrained(vae_model_name) + torch_dtype = torch.float16 if encode_latents_in_fp16 else None + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch_dtype) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch_dtype) tokenizer = SDXLTokenizer(model_name) text_encoder = SDXLTextEncoder(model_name, encode_latents_in_fp16) @@ -257,6 +256,7 @@ def stable_diffusion_xl( attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) else: attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) + log.info('Using %s with clip_val %.1f'%(attn_processor.__class__, clip_qkv)) model.unet.set_attn_processor(attn_processor) return model @@ -408,13 +408,6 @@ def forward(self, tokenized_text): pooled_conditioning = text_encoder_2_out[0] # (batch_size, 1280) conditioning_2 = text_encoder_2_out.hidden_states[-2] # (batch_size, 77, 1280) - # # zero out the appropriate things - # if batch[self.text_key].sum() == 0: - # conditioning = 
torch.zeros_like(conditioning) - # if batch[self.text_key_2].sum() == 0: - # conditioning_2 = torch.zeros_like(conditioning_2) - # pooled_conditioning = torch.zeros_like(pooled_conditioning) - conditioning = torch.concat([conditioning, conditioning_2], dim=-1) return conditioning, pooled_conditioning diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 0557aa0c..8098e537 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -161,7 +161,7 @@ def __init__(self, self.unet._fsdp_wrap = True def forward(self, batch): - latents, conditioning, pooled_conditioning = None, None, None + latents, conditioning, conditioning_2, pooled_conditioning = None, None, None, None # Use latents if specified and available. When specified, they might not exist during eval if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch: if self.sdxl: @@ -170,9 +170,10 @@ def forward(self, batch): else: inputs, conditioning = batch[self.image_key], batch[self.text_key] if self.sdxl: - conditioning, conditioning_2 = conditioning[:,0,:], conditioning[:,1,:] # [B, 2, 77] - else: - conditioning_2 = None + # If SDXL, separate the conditioning ([B, 2, 77]) from each tokenizer + conditioning, conditioning_2 = conditioning[:,0,:], conditioning[:,1,:] + conditioning_2 = conditioning_2.view(-1, conditioning_2.shape[-1]) + conditioning = conditioning.view(-1, conditioning.shape[-1]) if self.encode_latents_in_fp16: # Disable autocast context as models are in fp16 @@ -180,21 +181,17 @@ def forward(self, batch): # Encode the images to the latent space. # Encode prompt into conditioning vector latents = self.vae.encode(inputs.half())['latent_dist'].sample().data - # if self.sdxl: - # conditioning_2 = batch[self.text_key_2].view(-1, conditioning_2.shape[-1]) - # conditioning, pooled_conditioning = self.text_encoder(conditioning, conditioning_2) - # else: - # conditioning = self.text_encoder(conditioning)[0] # Should be (batch_size, 77, 768) - # pooled_conditioning = None - else: - latents = self.vae.encode(inputs)['latent_dist'].sample().data + if self.sdxl: + conditioning, pooled_conditioning = self.text_encoder([conditioning, conditioning_2]) + else: + conditioning = self.text_encoder(conditioning)[0] # Should be (batch_size, 77, 768) - if self.sdxl: - assert conditioning_2 is not None - conditioning_2 = conditioning_2.view(-1, conditioning_2.shape[-1]) - conditioning, pooled_conditioning = self.text_encoder([conditioning, conditioning_2]) else: - conditioning = self.text_encoder(conditioning)[0] + latents = self.vae.encode(inputs)['latent_dist'].sample().data + if self.sdxl: + conditioning, pooled_conditioning = self.text_encoder([conditioning, conditioning_2]) + else: + conditioning = self.text_encoder(conditioning)[0] # Magical scaling number (See https://github.com/huggingface/diffusers/issues/437#issuecomment-1241827515) latents *= self.latent_scale @@ -501,7 +498,6 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num truncation=True, return_tensors='pt', input_ids=True) - # TODO implement zero-ing out empty prompts! 
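The fp16 branch of forward above boils down to the following pattern (a sketch, assuming a CUDA device and the fp16-fixed SDXL VAE checkpoint; the final scaling corresponds to `self.latent_scale` in the model class):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained('madebyollin/sdxl-vae-fp16-fix', torch_dtype=torch.float16).to('cuda')
    images = torch.rand(1, 3, 256, 256, device='cuda') * 2 - 1  # stand-in for normalized [-1, 1] images

    with torch.cuda.amp.autocast(enabled=False):  # weights are already fp16; keep autocast out of the way
        latents = vae.encode(images.half())['latent_dist'].sample().data
    latents = latents * vae.config.scaling_factor  # model-specific latent scale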
text_embeddings, pooled_text_embeddings = self.text_encoder( [tokenized_prompts[0].to(device), tokenized_prompts[1].to(device)]) # type: ignore else: From 06b2f2fd17f659a80c79f3699df1ea464bbfeccb Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 25 Sep 2023 14:02:17 -0700 Subject: [PATCH 15/30] fix style --- diffusion/datasets/image_caption.py | 4 +--- diffusion/models/models.py | 8 ++++---- diffusion/models/stable_diffusion.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index d52c5cae..59a17224 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -3,8 +3,8 @@ """Streaming Image-Caption dataset.""" -import random import logging +import random from io import BytesIO from typing import Callable, Dict, List, Optional, Sequence, Union @@ -164,7 +164,6 @@ def build_streaming_image_caption_dataloader( rand_crop: bool = False, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, - ): """Builds a streaming LAION dataloader. @@ -248,7 +247,6 @@ def build_streaming_image_caption_dataloader( **streaming_kwargs, ) - dataloader = DataLoader( dataset=dataset, batch_size=batch_size, diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 31fa6807..75c496fe 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -3,10 +3,10 @@ """Constructors for diffusion models.""" +import logging from typing import List, Optional import torch -import logging from composer.devices import DeviceGPU from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel from torchmetrics import MeanSquaredError @@ -132,7 +132,7 @@ def stable_diffusion_2( attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) else: attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) - log.info('Using %s with clip_val %.1f'%(attn_processor.__class__, clip_qkv)) + log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv)) model.unet.set_attn_processor(attn_processor) return model @@ -256,7 +256,7 @@ def stable_diffusion_xl( attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv) else: attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv) - log.info('Using %s with clip_val %.1f'%(attn_processor.__class__, clip_qkv)) + log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv)) model.unet.set_attn_processor(attn_processor) return model @@ -399,7 +399,7 @@ def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0', encode @property def device(self): return self.text_encoder.device - + def forward(self, tokenized_text): # first text encoder conditioning = self.text_encoder(tokenized_text[0], output_hidden_states=True).hidden_states[-2] diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 8098e537..e97df520 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -171,7 +171,7 @@ def forward(self, batch): inputs, conditioning = batch[self.image_key], batch[self.text_key] if self.sdxl: # If SDXL, separate the conditioning ([B, 2, 77]) from each tokenizer - conditioning, conditioning_2 = conditioning[:,0,:], conditioning[:,1,:] + conditioning, conditioning_2 = conditioning[:, 0, :], conditioning[:, 1, :] conditioning_2 = conditioning_2.view(-1, conditioning_2.shape[-1]) conditioning = conditioning.view(-1, conditioning.shape[-1]) From 
77b009901d14124ee6d6bed1fb3a5ec8a1517750 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 25 Sep 2023 21:23:55 +0000 Subject: [PATCH 16/30] fix dropout dtype --- diffusion/datasets/image_caption.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 59a17224..4a41425b 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -109,11 +109,11 @@ def __getitem__(self, index): # Microconditioning dropout as in Stability repo # https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160 if torch.rand(1) < self.microcond_drop_prob: - out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0.0 + out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0 if torch.rand(1) < self.microcond_drop_prob: - out['cond_original_size'] = out['cond_original_size'] * 0.0 + out['cond_original_size'] = out['cond_original_size'] * 0 if torch.rand(1) < self.microcond_drop_prob: - out['cond_target_size'] = out['cond_target_size'] * 0.0 + out['cond_target_size'] = out['cond_target_size'] * 0 else: crop_top, crop_left, image_height, image_width = None, None, None, None if self.transform is not None: @@ -232,6 +232,9 @@ def build_streaming_image_caption_dataloader( transform = transforms.Compose(transform) assert isinstance(transform, Callable) + import streaming + streaming.base.util.clean_stale_shared_memory() + dataset = StreamingImageCaptionDataset( streams=streams, tokenizer_name_or_path=tokenizer_name_or_path, From 9b798c6808f88c1fc7f8c630058ec0b0b98cb551 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 25 Sep 2023 14:25:53 -0700 Subject: [PATCH 17/30] rm local streaming --- diffusion/datasets/image_caption.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 4a41425b..27344a6e 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -232,9 +232,6 @@ def build_streaming_image_caption_dataloader( transform = transforms.Compose(transform) assert isinstance(transform, Callable) - import streaming - streaming.base.util.clean_stale_shared_memory() - dataset = StreamingImageCaptionDataset( streams=streams, tokenizer_name_or_path=tokenizer_name_or_path, From f79825892847ce9bb12bdf4eaaf683d3dfa1cfaa Mon Sep 17 00:00:00 2001 From: Jasmine Collins Date: Wed, 27 Sep 2023 17:07:16 -0700 Subject: [PATCH 18/30] Update diffusion/datasets/image_caption.py Co-authored-by: Landan Seguin --- diffusion/datasets/image_caption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 27344a6e..4419f816 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -181,7 +181,7 @@ def build_streaming_image_caption_dataloader( transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``. image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``. - rand_crop (bool): If True, randomly crop images. Otherwise, center crop. ``False``. + rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``. 
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``. dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``. """ From 505b850711d754840e82f1ec68393f0a626e7f69 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 2 Oct 2023 19:48:39 +0000 Subject: [PATCH 19/30] use RandomCrop, fix LogDiffusionImages bug --- diffusion/callbacks/log_diffusion_images.py | 6 ++-- diffusion/datasets/laion/transforms.py | 40 +++++---------------- diffusion/models/models.py | 2 +- 3 files changed, 13 insertions(+), 35 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index 088e436c..66108553 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -69,9 +69,9 @@ def eval_batch_end(self, state: State, logger: Logger): input_ids=True) # type: ignore for p in self.prompts ] - self.tokenized_prompts = [torch.cat(tokenized_prompts[0]), torch.cat(tokenized_prompts[1])] - self.tokenized_prompts[0] = self.tokenized_prompts[0].to(state.batch[self.text_key].device) - self.tokenized_prompts[1] = self.tokenized_prompts[1].to(state.batch[self.text_key].device) + tokenized_prompts_1 = torch.cat([tp[0] for tp in tokenized_prompts]).to(state.batch[self.text_key].device) + tokenized_prompts_2 = torch.cat([tp[1] for tp in tokenized_prompts]).to(state.batch[self.text_key].device) + self.tokenized_prompts = [tokenized_prompts_1, tokenized_prompts_2] else: if self.tokenized_prompts is None: tokenized_prompts = [ diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py index 32792885..c0ce01b8 100644 --- a/diffusion/datasets/laion/transforms.py +++ b/diffusion/datasets/laion/transforms.py @@ -1,37 +1,11 @@ # Copyright 2022 MosaicML Diffusion authors # SPDX-License-Identifier: Apache-2.0 -"""Transforms for the laion dataset.""" +"""Transforms for the training and eval dataset.""" -import numpy as np import torchvision.transforms as transforms -from torchvision.transforms.functional import crop, get_dimensions - - -def random_crop_params(img, output_size): - """Helper function to return the parameters for a random crop. - - Args: - img (PIL Image or Tensor): Input image. - output_size (int): Size of output image. - - Returns: - cropped_im (PIL Image or Tensor): Cropped square image of output_size. - c_top (int): Top crop coordinate. - c_left (int): Left crop coordinate. - """ - _, image_height, image_width = get_dimensions(img) - if image_height == image_width: - c_left = 0 - c_top = 0 - elif image_height < image_width: - c_left = np.random.randint(0, image_width - output_size) - c_top = 0 - else: - c_left = 0 - c_top = np.random.randint(0, image_height - output_size) - cropped_im = crop(img, c_top, c_left, output_size, output_size) - return cropped_im, c_top, c_left +from torchvision.transforms.functional import crop +from torchvision.transforms import RandomCrop class LargestCenterSquare: @@ -54,12 +28,14 @@ class RandomCropSquare: def __init__(self, size): self.size = size + self.random_crop = RandomCrop(size) def __call__(self, img): # First, resize the image such that the smallest side is self.size while preserving aspect ratio. img = transforms.functional.resize(img, self.size, antialias=True) # Then take a center crop to a square & return crop params. 
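The replacement in the lines that follow leans on torchvision's own crop-parameter sampling instead of the hand-rolled helper; the essential call pattern, shown standalone (input size is illustrative), is:

    from PIL import Image
    from torchvision.transforms import RandomCrop
    from torchvision.transforms.functional import crop, resize

    size = 256
    img = Image.new('RGB', (1024, 768))          # illustrative input
    img = resize(img, size, antialias=True)      # short side -> 256, aspect ratio preserved
    c_top, c_left, h, w = RandomCrop.get_params(img, (size, size))
    img = crop(img, c_top, c_left, h, w)         # 256x256 square; (c_top, c_left) feed the SDXL micro-conditioning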
- img, _, _ = random_crop_params(img, self.size) + c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size)) + img = crop(img, c_top, c_left, h, w) return img @@ -68,11 +44,13 @@ class RandomCropSquareReturnTransform: def __init__(self, size): self.size = size + self.random_crop = RandomCrop(size) def __call__(self, img): # First, resize the image such that the smallest side is self.size while preserving aspect ratio. orig_w, orig_h = img.size img = transforms.functional.resize(img, self.size, antialias=True) # Then take a center crop to a square & return crop params. - img, c_top, c_left = random_crop_params(img, self.size) + c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size)) + img = crop(img, c_top, c_left, h, w) return img, c_top, c_left, orig_h, orig_w diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 75c496fe..4b19526f 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -199,7 +199,7 @@ def stable_diffusion_xl( metric.requires_grad_(False) if pretrained: - unet = UNet2DConditionModel.from_pretrained(model_name, subfolder='unet') + unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet') else: config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet') unet = UNet2DConditionModel(**config[0]) From 054d1ef454d62220b23b348b7d4296fed52b4dbb Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 2 Oct 2023 20:35:55 +0000 Subject: [PATCH 20/30] have tokenizers pass dict output --- diffusion/callbacks/log_diffusion_images.py | 20 +++++++----------- diffusion/datasets/image_caption.py | 23 +++++++++------------ diffusion/models/models.py | 17 ++++++++------- diffusion/models/stable_diffusion.py | 19 +++++++---------- 4 files changed, 33 insertions(+), 46 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index 66108553..f114a3bd 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -62,24 +62,18 @@ def eval_batch_end(self, state: State, logger: Logger): else: model = state.model + if self.tokenized_prompts is None: + tokenized_prompts = [ + model.tokenizer(p, padding='max_length', truncation=True, + return_tensors='pt')['input_ids'] # type: ignore + for p in self.prompts + ] if model.sdxl: - if self.tokenized_prompts is None: - tokenized_prompts = [ - model.tokenizer(p, padding='max_length', truncation=True, return_tensors='pt', - input_ids=True) # type: ignore - for p in self.prompts - ] tokenized_prompts_1 = torch.cat([tp[0] for tp in tokenized_prompts]).to(state.batch[self.text_key].device) tokenized_prompts_2 = torch.cat([tp[1] for tp in tokenized_prompts]).to(state.batch[self.text_key].device) self.tokenized_prompts = [tokenized_prompts_1, tokenized_prompts_2] else: - if self.tokenized_prompts is None: - tokenized_prompts = [ - model.tokenizer(p, padding='max_length', truncation=True, - return_tensors='pt')['input_ids'] # type: ignore - for p in self.prompts - ] - self.tokenized_prompts = torch.cat(tokenized_prompts) + self.tokenized_prompts = torch.cat(tokenized_prompts) self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore # Generate images diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 4419f816..25136818 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -129,21 +129,18 @@ def __getitem__(self, 
index): if isinstance(caption, List) and self.caption_selection == 'random': caption = random.sample(caption, k=1)[0] + max_length = None if self.sdxl else self.tokenizer.model_max_length + tokenized_caption = self.tokenizer( + caption, + padding='max_length', + max_length=max_length, + truncation=True, + return_tensors='pt')['input_ids'] if self.sdxl: - tokenized_captions = self.tokenizer(caption, - padding='max_length', - truncation=True, - return_tensors='pt', - input_ids=True) - tokenized_captions = [cap[0] for cap in tokenized_captions] - tokenized_caption = torch.stack(tokenized_captions) + tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption] + tokenized_caption = torch.stack(tokenized_caption) else: - tokenized_caption = self.tokenizer( - caption, - padding='max_length', - max_length=self.tokenizer.model_max_length, # type: ignore - truncation=True, - return_tensors='pt')['input_ids'][0] + tokenized_caption = tokenized_caption.squeeze() out['image'] = img out['captions'] = tokenized_caption return out diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 4b19526f..cfa857ef 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -415,7 +415,7 @@ def forward(self, tokenized_text): class SDXLTokenizer: """Wrapper around HuggingFace tokenizers for SDXL. - Tokenizes prompt with two tokenizers and returns the outputs as a list. + Tokenizes prompt with two tokenizers and returns the joined output. Args: model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'. @@ -425,18 +425,19 @@ def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'): self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer') self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2') - def __call__(self, prompt, padding, truncation, return_tensors, input_ids=False): + def __call__(self, prompt, padding, truncation, return_tensors, max_length=None): tokenized_output = self.tokenizer(prompt, padding=padding, - max_length=self.tokenizer.model_max_length, + max_length=self.tokenizer.model_max_length if max_length is None else max_length, truncation=truncation, return_tensors=return_tensors) tokenized_output_2 = self.tokenizer_2(prompt, padding=padding, - max_length=self.tokenizer_2.model_max_length, + max_length=self.tokenizer_2.model_max_length if max_length is None else max_length, truncation=truncation, return_tensors=return_tensors) - if input_ids: - tokenized_output = tokenized_output.input_ids - tokenized_output_2 = tokenized_output_2.input_ids - return [tokenized_output, tokenized_output_2] + + # Add second tokenizer output to first tokenizer + for key in tokenized_output.keys(): + tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]] + return tokenized_output diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index e97df520..8ff647aa 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -491,22 +491,17 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num device = self.text_encoder.device pooled_text_embeddings = None if prompt_embeds is None: + max_length = None if self.sdxl else self.tokenizer.model_max_length + if tokenized_prompts is None: + tokenized_prompts = self.tokenizer(prompt, + padding='max_length', + max_length=max_length, + truncation=True, + return_tensors='pt').input_ids if 
self.sdxl: - if tokenized_prompts is None: - tokenized_prompts = self.tokenizer(prompt, - padding='max_length', - truncation=True, - return_tensors='pt', - input_ids=True) text_embeddings, pooled_text_embeddings = self.text_encoder( [tokenized_prompts[0].to(device), tokenized_prompts[1].to(device)]) # type: ignore else: - if tokenized_prompts is None: - tokenized_prompts = self.tokenizer(prompt, - padding='max_length', - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors='pt').input_ids text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0] # type: ignore else: if self.sdxl: From 65b1a8b3a43808a90d35cbc7871db3e360fecf4c Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 2 Oct 2023 20:39:58 +0000 Subject: [PATCH 21/30] add to layers.py docs --- diffusion/models/layers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/diffusion/models/layers.py b/diffusion/models/layers.py index 19847830..dbd33221 100644 --- a/diffusion/models/layers.py +++ b/diffusion/models/layers.py @@ -24,8 +24,11 @@ def zero_module(module): class ClippedAttnProcessor2_0: """Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). - Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py to + Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py#L977 to allow clipping QKV values. + + Args: + clip_val (float, defaults to 6.0): Amount to clip query, key, and value by. """ def __init__(self, clip_val=6.0): @@ -120,7 +123,7 @@ def __call__( class ClippedXFormersAttnProcessor: """Processor for implementing memory efficient attention using xFormers. - Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py to + Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py#L888 to allow clipping QKV values. Args: @@ -129,6 +132,7 @@ class ClippedXFormersAttnProcessor: [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. + clip_val (float, defaults to 6.0): Amount to clip query, key, and value by. 
""" def __init__(self, clip_val=6.0, attention_op=None): From dd35c77fb97aaf3919b02acc7f373797ee758de7 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 2 Oct 2023 20:54:10 +0000 Subject: [PATCH 22/30] override prediction_type in inference_noise_scheulder --- diffusion/models/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index cfa857ef..0b0a03e9 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -226,6 +226,7 @@ def stable_diffusion_xl( noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler') inference_noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder='scheduler') + inference_noise_scheduler.prediction_type = prediction_type model = StableDiffusion( unet=unet, From 45c0cd7b4ace68100b1f16b79054311baec599ef Mon Sep 17 00:00:00 2001 From: Jasmine Collins Date: Mon, 2 Oct 2023 13:54:37 -0700 Subject: [PATCH 23/30] Update diffusion/models/stable_diffusion.py Co-authored-by: Landan Seguin --- diffusion/models/stable_diffusion.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 8ff647aa..7ee0c764 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -447,12 +447,8 @@ def generate( crop_params = [0., 0.] if not size_params: size_params = [width, height] - cond_original_size = torch.tensor([[width, height]]).repeat(pooled_embeddings.shape[0], - 1).to(device).float() - cond_crops_coords_top_left = torch.tensor([crop_params]).repeat(pooled_embeddings.shape[0], - 1).to(device).float() - cond_target_size = torch.tensor([size_params]).repeat(pooled_embeddings.shape[0], 1).to(device).float() - add_time_ids = torch.cat([cond_original_size, cond_crops_coords_top_left, cond_target_size], dim=1).float() + add_time_ids = torch.tensor([[width, height, *crop_params, *size_params]], dtype=torch.float, device=device) + add_time_ids = add_time_ids.repeat(pooled_embeddings.shape[0], 1) add_text_embeds = pooled_embeddings added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids} From 2346076ddb6db640b0e8a066722b8f698198c4dd Mon Sep 17 00:00:00 2001 From: jazcollins Date: Mon, 2 Oct 2023 14:01:26 -0700 Subject: [PATCH 24/30] fix style --- diffusion/callbacks/log_diffusion_images.py | 13 ++++++----- diffusion/datasets/image_caption.py | 13 ++++++----- diffusion/datasets/laion/transforms.py | 2 +- diffusion/models/models.py | 24 +++++++++++---------- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index f114a3bd..2b5a7c50 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -63,18 +63,21 @@ def eval_batch_end(self, state: State, logger: Logger): model = state.model if self.tokenized_prompts is None: - tokenized_prompts = [ + self.tokenized_prompts = [ model.tokenizer(p, padding='max_length', truncation=True, return_tensors='pt')['input_ids'] # type: ignore for p in self.prompts ] + if model.sdxl: - tokenized_prompts_1 = torch.cat([tp[0] for tp in tokenized_prompts]).to(state.batch[self.text_key].device) - tokenized_prompts_2 = torch.cat([tp[1] for tp in tokenized_prompts]).to(state.batch[self.text_key].device) + tokenized_prompts_1 = torch.cat([tp[0] for tp in self.tokenized_prompts + ]).to(state.batch[self.text_key].device) + 
tokenized_prompts_2 = torch.cat([tp[1] for tp in self.tokenized_prompts + ]).to(state.batch[self.text_key].device) self.tokenized_prompts = [tokenized_prompts_1, tokenized_prompts_2] else: - self.tokenized_prompts = torch.cat(tokenized_prompts) - self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore + self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore + self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # Generate images with get_precision_context(state.precision): diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index 25136818..24cec6dd 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -129,13 +129,12 @@ def __getitem__(self, index): if isinstance(caption, List) and self.caption_selection == 'random': caption = random.sample(caption, k=1)[0] - max_length = None if self.sdxl else self.tokenizer.model_max_length - tokenized_caption = self.tokenizer( - caption, - padding='max_length', - max_length=max_length, - truncation=True, - return_tensors='pt')['input_ids'] + max_length = None if self.sdxl else self.tokenizer.model_max_length # type: ignore + tokenized_caption = self.tokenizer(caption, + padding='max_length', + max_length=max_length, + truncation=True, + return_tensors='pt')['input_ids'] if self.sdxl: tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption] tokenized_caption = torch.stack(tokenized_caption) diff --git a/diffusion/datasets/laion/transforms.py b/diffusion/datasets/laion/transforms.py index c0ce01b8..86ed78ad 100644 --- a/diffusion/datasets/laion/transforms.py +++ b/diffusion/datasets/laion/transforms.py @@ -4,8 +4,8 @@ """Transforms for the training and eval dataset.""" import torchvision.transforms as transforms -from torchvision.transforms.functional import crop from torchvision.transforms import RandomCrop +from torchvision.transforms.functional import crop class LargestCenterSquare: diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 0b0a03e9..38e07c4c 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -427,17 +427,19 @@ def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'): self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2') def __call__(self, prompt, padding, truncation, return_tensors, max_length=None): - tokenized_output = self.tokenizer(prompt, - padding=padding, - max_length=self.tokenizer.model_max_length if max_length is None else max_length, - truncation=truncation, - return_tensors=return_tensors) - tokenized_output_2 = self.tokenizer_2(prompt, - padding=padding, - max_length=self.tokenizer_2.model_max_length if max_length is None else max_length, - truncation=truncation, - return_tensors=return_tensors) - + tokenized_output = self.tokenizer( + prompt, + padding=padding, + max_length=self.tokenizer.model_max_length if max_length is None else max_length, + truncation=truncation, + return_tensors=return_tensors) + tokenized_output_2 = self.tokenizer_2( + prompt, + padding=padding, + max_length=self.tokenizer_2.model_max_length if max_length is None else max_length, + truncation=truncation, + return_tensors=return_tensors) + # Add second tokenizer output to first tokenizer for key in tokenized_output.keys(): tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]] From 8d8e2ae16ad62193544b2d46b0730b32ac27d1d3 Mon Sep 17 00:00:00 2001 
From: jazcollins Date: Mon, 2 Oct 2023 15:38:28 -0700 Subject: [PATCH 25/30] log_diffusion_images.py fix --- diffusion/callbacks/log_diffusion_images.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index 2b5a7c50..3b2f6878 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -68,16 +68,18 @@ def eval_batch_end(self, state: State, logger: Logger): return_tensors='pt')['input_ids'] # type: ignore for p in self.prompts ] - + if model.sdxl: + self.tokenized_prompts = [ + torch.cat([tp[0] for tp in self.tokenized_prompts]), + torch.cat([tp[1] for tp in self.tokenized_prompts]) + ] + else: + self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore if model.sdxl: - tokenized_prompts_1 = torch.cat([tp[0] for tp in self.tokenized_prompts - ]).to(state.batch[self.text_key].device) - tokenized_prompts_2 = torch.cat([tp[1] for tp in self.tokenized_prompts - ]).to(state.batch[self.text_key].device) - self.tokenized_prompts = [tokenized_prompts_1, tokenized_prompts_2] + self.tokenized_prompts[0] = self.tokenized_prompts[0].to(state.batch[self.text_key].device) + self.tokenized_prompts[1] = self.tokenized_prompts[1].to(state.batch[self.text_key].device) else: - self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore - self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) + self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore # Generate images with get_precision_context(state.precision): From 8b65e7d8966ff4a79486c9878d1e4f09ab849659 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Tue, 3 Oct 2023 15:03:45 -0700 Subject: [PATCH 26/30] pass tokenized prompts as batch_size x 2 x max_length shape --- diffusion/callbacks/log_diffusion_images.py | 12 +++--------- diffusion/models/stable_diffusion.py | 8 ++++---- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index 3b2f6878..04f2fb2f 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -69,17 +69,11 @@ def eval_batch_end(self, state: State, logger: Logger): for p in self.prompts ] if model.sdxl: - self.tokenized_prompts = [ - torch.cat([tp[0] for tp in self.tokenized_prompts]), - torch.cat([tp[1] for tp in self.tokenized_prompts]) - ] + self.tokenized_prompts = torch.stack([torch.cat(tp) for tp in self.tokenized_prompts + ]) # [B, 2, max_length] else: self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore - if model.sdxl: - self.tokenized_prompts[0] = self.tokenized_prompts[0].to(state.batch[self.text_key].device) - self.tokenized_prompts[1] = self.tokenized_prompts[1].to(state.batch[self.text_key].device) - else: - self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore + self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore # Generate images with get_precision_context(state.precision): diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 7ee0c764..82e1a288 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -350,9 +350,9 @@ def generate( image generation away from. 
Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1). Must be the same length as list of prompts. Default: `None`. - tokenized_prompts (torch.LongTensor or List[torch.LongTensor]): Optionally pass - pre-tokenized prompts instead of string prompts. If SDXL, this will be a list - of two pre-tokenized prompts. Default: `None`. + tokenized_prompts (torch.LongTensor): Optionally pass pre-tokenized prompts instead + of string prompts. If SDXL, this will be a tensor of size [B, 2, max_length], + otherwise will be of shape [B, max_length]. Default: `None`. tokenized_negative_prompts (torch.LongTensor): Optionally pass pre-tokenized negative prompts instead of string prompts. Default: `None`. prompt_embeds (torch.FloatTensor): Optionally pass pre-tokenized prompts instead @@ -496,7 +496,7 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num return_tensors='pt').input_ids if self.sdxl: text_embeddings, pooled_text_embeddings = self.text_encoder( - [tokenized_prompts[0].to(device), tokenized_prompts[1].to(device)]) # type: ignore + [tokenized_prompts[:, 0, :].to(device), tokenized_prompts[:, 1, :].to(device)]) # type: ignore else: text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0] # type: ignore else: From 3ae948edc78d63217643f0e0e7e384ec7519313c Mon Sep 17 00:00:00 2001 From: jazcollins Date: Tue, 3 Oct 2023 15:30:44 -0700 Subject: [PATCH 27/30] stack tokenizer output to match --- diffusion/models/stable_diffusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 82e1a288..1f8bc080 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -494,6 +494,7 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num max_length=max_length, truncation=True, return_tensors='pt').input_ids + tokenized_prompts = torch.stack([tokenized_prompts[0], tokenized_prompts[1]], dim=1) if self.sdxl: text_embeddings, pooled_text_embeddings = self.text_encoder( [tokenized_prompts[:, 0, :].to(device), tokenized_prompts[:, 1, :].to(device)]) # type: ignore From daf745c2364c7d8676eefaee3ee66cf8c76c84a1 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Tue, 3 Oct 2023 15:51:30 -0700 Subject: [PATCH 28/30] fix negative prompt classifier free guidance --- diffusion/models/stable_diffusion.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 1f8bc080..b3c157fb 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -332,7 +332,7 @@ def generate( width: Optional[int] = None, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 3.0, - num_images_per_prompt: Optional[int] = 1, + num_images_per_prompt: int = 1, seed: Optional[int] = None, progress_bar: Optional[bool] = True, zero_out_negative_prompt: bool = True, @@ -413,14 +413,13 @@ def generate( # negative prompt is given in place of the unconditional input in classifier free guidance pooled_embeddings = None if do_classifier_free_guidance: - if negative_prompt_embeds is None and zero_out_negative_prompt: + if not negative_prompt and not tokenized_negative_prompts and zero_out_negative_prompt: + # Negative prompt is empty and we want to zero it out unconditional_embeddings = torch.zeros_like(text_embeddings) - if pooled_text_embeddings is not None: - pooled_unconditional_embeddings = 
torch.zeros_like(pooled_text_embeddings) - else: - pooled_unconditional_embeddings = None + pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None else: - negative_prompt = negative_prompt or ([''] * (batch_size // num_images_per_prompt)) # type: ignore + if not negative_prompt: + negative_prompt = [''] * (batch_size // num_images_per_prompt) # type: ignore unconditional_embeddings, pooled_unconditional_embeddings = self._prepare_text_embeddings( negative_prompt, tokenized_negative_prompts, negative_prompt_embeds, num_images_per_prompt) From 0a1f31a611804dcc4bcfffff87c44eada1e651ba Mon Sep 17 00:00:00 2001 From: jazcollins Date: Tue, 3 Oct 2023 16:14:44 -0700 Subject: [PATCH 29/30] _prepare_text_embeddings fix --- diffusion/models/stable_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index b3c157fb..16f7be47 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -493,7 +493,8 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num max_length=max_length, truncation=True, return_tensors='pt').input_ids - tokenized_prompts = torch.stack([tokenized_prompts[0], tokenized_prompts[1]], dim=1) + if self.sdxl: + tokenized_prompts = torch.stack([tokenized_prompts[0], tokenized_prompts[1]], dim=1) if self.sdxl: text_embeddings, pooled_text_embeddings = self.text_encoder( [tokenized_prompts[:, 0, :].to(device), tokenized_prompts[:, 1, :].to(device)]) # type: ignore From 67e03165786d3ba023d08d57468f3d7e8936f117 Mon Sep 17 00:00:00 2001 From: jazcollins Date: Tue, 3 Oct 2023 16:45:21 -0700 Subject: [PATCH 30/30] add negative_prompt_embeds to zero_out_negative_prompt check --- diffusion/models/stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusion/models/stable_diffusion.py b/diffusion/models/stable_diffusion.py index 16f7be47..4435d0bc 100644 --- a/diffusion/models/stable_diffusion.py +++ b/diffusion/models/stable_diffusion.py @@ -413,7 +413,7 @@ def generate( # negative prompt is given in place of the unconditional input in classifier free guidance pooled_embeddings = None if do_classifier_free_guidance: - if not negative_prompt and not tokenized_negative_prompts and zero_out_negative_prompt: + if not negative_prompt and not tokenized_negative_prompts and not negative_prompt_embeds and zero_out_negative_prompt: # Negative prompt is empty and we want to zero it out unconditional_embeddings = torch.zeros_like(text_embeddings) pooled_unconditional_embeddings = torch.zeros_like(pooled_text_embeddings) if self.sdxl else None
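
Note on the [B, 2, max_length] prompt format introduced in PATCH 26 and PATCH 27: for SDXL, the two CLIP tokenizations are packed into a single tensor of shape [B, 2, max_length] so that one object can flow through the dataloader, the logging callback, and generate(), and the per-encoder ids are recovered by indexing the second dimension. The sketch below is illustrative only and is not part of the patches; the helper name tokenize_sdxl_prompts is invented for the example, and max_length=77 (CLIP's usual model_max_length) is an assumption.

import torch
from transformers import CLIPTokenizer


def tokenize_sdxl_prompts(prompts, tokenizer_one, tokenizer_two, max_length=77):
    """Tokenize prompts with both SDXL CLIP tokenizers and pack the ids as [B, 2, max_length]."""
    ids_one = tokenizer_one(prompts, padding='max_length', max_length=max_length,
                            truncation=True, return_tensors='pt').input_ids  # [B, max_length]
    ids_two = tokenizer_two(prompts, padding='max_length', max_length=max_length,
                            truncation=True, return_tensors='pt').input_ids  # [B, max_length]
    return torch.stack([ids_one, ids_two], dim=1)  # [B, 2, max_length]


# Example usage (the subfolder names follow the SDXL checkpoint layout referenced in the patches):
# tokenizer_one = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer')
# tokenizer_two = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer_2')
# ids = tokenize_sdxl_prompts(['a photo of a dog'], tokenizer_one, tokenizer_two)
# ids_for_text_encoder_1 = ids[:, 0, :]  # matches tokenized_prompts[:, 0, :] in generate()
# ids_for_text_encoder_2 = ids[:, 1, :]  # matches tokenized_prompts[:, 1, :] in generate()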
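
Note on the classifier-free guidance changes in PATCH 28 and PATCH 30: the unconditional embeddings are zeroed out only when no negative prompt was supplied in any form (as a string, as pre-tokenized ids, or as precomputed embeddings) and zero_out_negative_prompt is set; otherwise the negative prompt (defaulting to empty strings) is encoded through the same text-encoder path as the positive prompt. A minimal sketch of that decision follows; it is not the repo's API, the helper name and signature are invented for illustration, and explicit `is None` checks stand in for the `not ...` tests used in the patch.

import torch


def unconditional_embeddings_for_cfg(text_embeddings, pooled_text_embeddings=None,
                                     negative_prompt=None, tokenized_negative_prompts=None,
                                     negative_prompt_embeds=None,
                                     zero_out_negative_prompt=True, sdxl=False):
    """Return (unconditional_embeddings, pooled_unconditional_embeddings), or (None, None).

    (None, None) signals that the caller should encode the negative prompt through the
    text encoder, exactly as it encodes the positive prompt.
    """
    no_negative_given = (not negative_prompt and tokenized_negative_prompts is None
                         and negative_prompt_embeds is None)
    if no_negative_given and zero_out_negative_prompt:
        # Zero out the unconditional branch instead of encoding an empty string.
        unconditional = torch.zeros_like(text_embeddings)
        pooled_unconditional = torch.zeros_like(pooled_text_embeddings) if sdxl else None
        return unconditional, pooled_unconditional
    return None, None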