Full SDXL Model #67

Merged 30 commits on Oct 4, 2023
Changes from 8 commits
Commits (30)
ec57395 - random crop (jazcollins, Sep 20, 2023)
d34d7e9 - zero init trick (jazcollins, Sep 20, 2023)
fb856c4 - add intentionally buggy clipping (jazcollins, Sep 20, 2023)
4dd3c40 - fix docstring and update diffusers version (jazcollins, Sep 20, 2023)
c1d58c9 - fix attention clipping, add to sdxl, fix xformers import when not ins… (jazcollins, Sep 20, 2023)
f14018a - big sdxl commit, no style check (jazcollins, Sep 21, 2023)
45c1ac0 - fix style and pyright (jazcollins, Sep 21, 2023)
e873717 - print sdxl statement (jazcollins, Sep 21, 2023)
d93fbdb - add sdxl logic to generate (jazcollins, Sep 22, 2023)
75db76f - allow setting SDXLTextEncoder device (jazcollins, Sep 22, 2023)
26a133d - sdxltextencoder edits (jazcollins, Sep 24, 2023)
320c7a1 - split conditioning (jazcollins, Sep 25, 2023)
c2d0321 - remove prints (jazcollins, Sep 25, 2023)
12217fc - microconditioning and cleaning up comments (jazcollins, Sep 25, 2023)
06b2f2f - fix style (jazcollins, Sep 25, 2023)
77b0099 - fix dropout dtype (jazcollins, Sep 25, 2023)
9b798c6 - rm local streaming (jazcollins, Sep 25, 2023)
f798258 - Update diffusion/datasets/image_caption.py (jazcollins, Sep 28, 2023)
505b850 - use RandomCrop, fix LogDiffusionImages bug (jazcollins, Oct 2, 2023)
054d1ef - have tokenizers pass dict output (jazcollins, Oct 2, 2023)
65b1a8b - add to layers.py docs (jazcollins, Oct 2, 2023)
dd35c77 - override prediction_type in inference_noise_scheulder (jazcollins, Oct 2, 2023)
45c0cd7 - Update diffusion/models/stable_diffusion.py (jazcollins, Oct 2, 2023)
2346076 - fix style (jazcollins, Oct 2, 2023)
8d8e2ae - log_diffusion_images.py fix (jazcollins, Oct 2, 2023)
8b65e7d - pass tokenized prompts as batch_size x 2 x max_length shape (jazcollins, Oct 3, 2023)
3ae948e - stack tokenizer output to match (jazcollins, Oct 3, 2023)
daf745c - fix negative prompt classifier free guidance (jazcollins, Oct 3, 2023)
0a1f31a - _prepare_text_embeddings fix (jazcollins, Oct 3, 2023)
67e0316 - add negative_prompt_embeds to zero_out_negative_prompt check (jazcollins, Oct 3, 2023)
27 changes: 13 additions & 14 deletions diffusion/callbacks/log_diffusion_images.py
@@ -62,24 +62,23 @@ def eval_batch_end(self, state: State, logger: Logger):
else:
model = state.model

if model.sdxl:
if self.tokenized_prompts is None:
tokenized_prompts = [
model.tokenizer(p, padding='max_length', truncation=True, return_tensors='pt',
input_ids=True) # type: ignore
for p in self.prompts
if self.tokenized_prompts is None:
self.tokenized_prompts = [
model.tokenizer(p, padding='max_length', truncation=True,
return_tensors='pt')['input_ids'] # type: ignore
for p in self.prompts
]
if model.sdxl:
self.tokenized_prompts = [
torch.cat([tp[0] for tp in self.tokenized_prompts]),
torch.cat([tp[1] for tp in self.tokenized_prompts])
]
self.tokenized_prompts = [torch.cat(tokenized_prompts[0]), torch.cat(tokenized_prompts[1])]
else:
self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore
if model.sdxl:
self.tokenized_prompts[0] = self.tokenized_prompts[0].to(state.batch[self.text_key].device)
self.tokenized_prompts[1] = self.tokenized_prompts[1].to(state.batch[self.text_key].device)
else:
if self.tokenized_prompts is None:
tokenized_prompts = [
model.tokenizer(p, padding='max_length', truncation=True,
return_tensors='pt')['input_ids'] # type: ignore
for p in self.prompts
]
self.tokenized_prompts = torch.cat(tokenized_prompts)
self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore

# Generate images
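For reference, a minimal sketch (not part of the PR) of the tensor shapes this updated callback caches, assuming two prompts and the usual 77-token CLIP context:

import torch

max_length = 77      # assumed CLIP context length
vocab_size = 49408   # assumed CLIP vocabulary size
prompts = ['a photo of a dog', 'a watercolor landscape']

# Stand-in for model.tokenizer(p, ...)['input_ids']: for SDXL the wrapper
# returns one (1, max_length) id tensor per tokenizer for each prompt.
per_prompt = [[torch.randint(0, vocab_size, (1, max_length)) for _ in range(2)] for _ in prompts]

# SDXL branch: concatenate across prompts, keeping one tensor per tokenizer.
tokenized_sdxl = [
    torch.cat([tp[0] for tp in per_prompt]),  # (num_prompts, max_length), tokenizer 1
    torch.cat([tp[1] for tp in per_prompt]),  # (num_prompts, max_length), tokenizer 2
]

# Non-SDXL branch: a single (num_prompts, max_length) tensor.
tokenized = torch.cat([torch.randint(0, vocab_size, (1, max_length)) for _ in prompts])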
24 changes: 10 additions & 14 deletions diffusion/datasets/image_caption.py
@@ -129,21 +129,17 @@ def __getitem__(self, index):
if isinstance(caption, List) and self.caption_selection == 'random':
caption = random.sample(caption, k=1)[0]

max_length = None if self.sdxl else self.tokenizer.model_max_length # type: ignore
tokenized_caption = self.tokenizer(caption,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors='pt')['input_ids']
if self.sdxl:
tokenized_captions = self.tokenizer(caption,
padding='max_length',
truncation=True,
return_tensors='pt',
input_ids=True)
tokenized_captions = [cap[0] for cap in tokenized_captions]
tokenized_caption = torch.stack(tokenized_captions)
tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption]
tokenized_caption = torch.stack(tokenized_caption)
else:
tokenized_caption = self.tokenizer(
caption,
padding='max_length',
max_length=self.tokenizer.model_max_length, # type: ignore
truncation=True,
return_tensors='pt')['input_ids'][0]
tokenized_caption = tokenized_caption.squeeze()
out['image'] = img
out['captions'] = tokenized_caption
return out
@@ -181,7 +177,7 @@ def build_streaming_image_caption_dataloader(
transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
rand_crop (bool): If True, randomly crop images. Otherwise, center crop. ``False``.
rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
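A rough sketch of the per-sample caption shapes this change produces, assuming a 77-token context: the SDXL branch stacks the two tokenizers' ids into one (2, max_length) tensor per caption, while the standard branch keeps a flat (max_length,) vector.

import torch

max_length = 77  # assumed CLIP context length

# SDXL: each tokenizer yields (1, max_length) ids; squeeze and stack them.
ids_1 = torch.randint(0, 49408, (1, max_length)).squeeze()
ids_2 = torch.randint(0, 49408, (1, max_length)).squeeze()
sdxl_caption = torch.stack([ids_1, ids_2])   # (2, max_length)

# Non-SDXL: a single (max_length,) vector of token ids.
caption = torch.randint(0, 49408, (1, max_length))[0]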
40 changes: 9 additions & 31 deletions diffusion/datasets/laion/transforms.py
@@ -1,37 +1,11 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Transforms for the laion dataset."""
"""Transforms for the training and eval dataset."""

import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms.functional import crop, get_dimensions


def random_crop_params(img, output_size):
"""Helper function to return the parameters for a random crop.

Args:
img (PIL Image or Tensor): Input image.
output_size (int): Size of output image.

Returns:
cropped_im (PIL Image or Tensor): Cropped square image of output_size.
c_top (int): Top crop coordinate.
c_left (int): Left crop coordinate.
"""
_, image_height, image_width = get_dimensions(img)
if image_height == image_width:
c_left = 0
c_top = 0
elif image_height < image_width:
c_left = np.random.randint(0, image_width - output_size)
c_top = 0
else:
c_left = 0
c_top = np.random.randint(0, image_height - output_size)
cropped_im = crop(img, c_top, c_left, output_size, output_size)
return cropped_im, c_top, c_left
from torchvision.transforms import RandomCrop
from torchvision.transforms.functional import crop


class LargestCenterSquare:
@@ -54,12 +28,14 @@ class RandomCropSquare:

def __init__(self, size):
self.size = size
self.random_crop = RandomCrop(size)

def __call__(self, img):
# First, resize the image such that the smallest side is self.size while preserving aspect ratio.
img = transforms.functional.resize(img, self.size, antialias=True)
# Then take a random crop to a square.
img, _, _ = random_crop_params(img, self.size)
c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
img = crop(img, c_top, c_left, h, w)
return img


@@ -68,11 +44,13 @@ class RandomCropSquareReturnTransform:

def __init__(self, size):
self.size = size
self.random_crop = RandomCrop(size)

def __call__(self, img):
# First, resize the image such that the smallest side is self.size while preserving aspect ratio.
orig_w, orig_h = img.size
img = transforms.functional.resize(img, self.size, antialias=True)
# Then take a random crop to a square & return crop params.
img, c_top, c_left = random_crop_params(img, self.size)
c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
img = crop(img, c_top, c_left, h, w)
return img, c_top, c_left, orig_h, orig_w
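A short usage sketch of the rewritten transforms, assuming the diffusion package from this repo is importable and the input is an RGB PIL image; the crop coordinates now come from torchvision's RandomCrop.get_params instead of the removed random_crop_params helper:

from PIL import Image
from diffusion.datasets.laion.transforms import RandomCropSquare, RandomCropSquareReturnTransform

img = Image.new('RGB', (640, 480))

# Resize so the short side is 256, then take a random 256x256 crop.
crop = RandomCropSquare(256)
cropped = crop(img)  # PIL image, 256x256

# Variant that also returns the crop coordinates and original size,
# which SDXL uses for crop/size microconditioning.
crop_with_params = RandomCropSquareReturnTransform(256)
cropped, c_top, c_left, orig_h, orig_w = crop_with_params(img)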
8 changes: 6 additions & 2 deletions diffusion/models/layers.py
@@ -24,8 +24,11 @@ def zero_module(module):
class ClippedAttnProcessor2_0:
"""Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py to
Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py#L977 to
allow clipping QKV values.

Args:
clip_val (float, defaults to 6.0): Amount to clip query, key, and value by.
"""

def __init__(self, clip_val=6.0):
@@ -120,7 +123,7 @@ def __call__(
class ClippedXFormersAttnProcessor:
"""Processor for implementing memory efficient attention using xFormers.

Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py to
Modified from https://github.com/huggingface/diffusers/blob/v0.21.0-release/src/diffusers/models/attention_processor.py#L888 to
allow clipping QKV values.

Args:
@@ -129,6 +132,7 @@ class ClippedXFormersAttnProcessor:
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
clip_val (float, defaults to 6.0): Amount to clip query, key, and value by.
"""

def __init__(self, clip_val=6.0, attention_op=None):
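Conceptually, the clipping these processors add amounts to clamping the query, key, and value tensors before attention is computed. A minimal standalone sketch of that idea (illustrative only, not the processors' full code):

import torch
import torch.nn.functional as F

def clipped_attention(query, key, value, clip_val=6.0):
    # Clamp QKV to [-clip_val, clip_val] to avoid activation blow-ups,
    # then run standard scaled dot-product attention (PyTorch 2.0+).
    query = query.clamp(min=-clip_val, max=clip_val)
    key = key.clamp(min=-clip_val, max=clip_val)
    value = value.clamp(min=-clip_val, max=clip_val)
    return F.scaled_dot_product_attention(query, key, value)

q = k = v = torch.randn(1, 8, 77, 64)  # (batch, heads, tokens, head_dim)
out = clipped_attention(q, k, v)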
38 changes: 21 additions & 17 deletions diffusion/models/models.py
@@ -199,7 +199,7 @@ def stable_diffusion_xl(
metric.requires_grad_(False)

if pretrained:
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder='unet')
unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
else:
config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')
unet = UNet2DConditionModel(**config[0])
@@ -226,6 +226,7 @@ def stable_diffusion_xl(

noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
inference_noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder='scheduler')
inference_noise_scheduler.prediction_type = prediction_type

model = StableDiffusion(
unet=unet,
@@ -415,7 +416,7 @@ def forward(self, tokenized_text):
class SDXLTokenizer:
"""Wrapper around HuggingFace tokenizers for SDXL.

Tokenizes prompt with two tokenizers and returns the outputs as a list.
Tokenizes prompt with two tokenizers and returns the joined output.

Args:
model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
@@ -425,18 +426,21 @@ def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'):
self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2')

def __call__(self, prompt, padding, truncation, return_tensors, input_ids=False):
tokenized_output = self.tokenizer(prompt,
padding=padding,
max_length=self.tokenizer.model_max_length,
truncation=truncation,
return_tensors=return_tensors)
tokenized_output_2 = self.tokenizer_2(prompt,
padding=padding,
max_length=self.tokenizer_2.model_max_length,
truncation=truncation,
return_tensors=return_tensors)
if input_ids:
tokenized_output = tokenized_output.input_ids
tokenized_output_2 = tokenized_output_2.input_ids
return [tokenized_output, tokenized_output_2]
def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
tokenized_output = self.tokenizer(
prompt,
padding=padding,
max_length=self.tokenizer.model_max_length if max_length is None else max_length,
truncation=truncation,
return_tensors=return_tensors)
tokenized_output_2 = self.tokenizer_2(
prompt,
padding=padding,
max_length=self.tokenizer_2.model_max_length if max_length is None else max_length,
truncation=truncation,
return_tensors=return_tensors)

# Add second tokenizer output to first tokenizer
for key in tokenized_output.keys():
tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]]
return tokenized_output
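A sketch of the joining behavior the updated SDXLTokenizer implements: both CLIP tokenizers run on the prompt, and each output key maps to a two-element list holding the first and second tokenizer's tensors. The snippet below reproduces that by hand and assumes the SDXL base weights can be fetched from the Hub:

from transformers import CLIPTokenizer

name = 'stabilityai/stable-diffusion-xl-base-1.0'
tok_1 = CLIPTokenizer.from_pretrained(name, subfolder='tokenizer')
tok_2 = CLIPTokenizer.from_pretrained(name, subfolder='tokenizer_2')

out_1 = tok_1('a photo of a dog', padding='max_length',
              max_length=tok_1.model_max_length, truncation=True, return_tensors='pt')
out_2 = tok_2('a photo of a dog', padding='max_length',
              max_length=tok_2.model_max_length, truncation=True, return_tensors='pt')

# Join: each key maps to [first tokenizer's tensor, second tokenizer's tensor].
joined = {key: [out_1[key], out_2[key]] for key in out_1.keys()}
print(joined['input_ids'][0].shape)  # torch.Size([1, 77])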
27 changes: 9 additions & 18 deletions diffusion/models/stable_diffusion.py
@@ -447,12 +447,8 @@ def generate(
crop_params = [0., 0.]
if not size_params:
size_params = [width, height]
cond_original_size = torch.tensor([[width, height]]).repeat(pooled_embeddings.shape[0],
1).to(device).float()
cond_crops_coords_top_left = torch.tensor([crop_params]).repeat(pooled_embeddings.shape[0],
1).to(device).float()
cond_target_size = torch.tensor([size_params]).repeat(pooled_embeddings.shape[0], 1).to(device).float()
add_time_ids = torch.cat([cond_original_size, cond_crops_coords_top_left, cond_target_size], dim=1).float()
add_time_ids = torch.tensor([[width, height, *crop_params, *size_params]], dtype=torch.float, device=device)
add_time_ids = add_time_ids.repeat(pooled_embeddings.shape[0], 1)
add_text_embeds = pooled_embeddings

added_cond_kwargs = {'text_embeds': add_text_embeds, 'time_ids': add_time_ids}
@@ -491,22 +487,17 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num
device = self.text_encoder.device
pooled_text_embeddings = None
if prompt_embeds is None:
max_length = None if self.sdxl else self.tokenizer.model_max_length
if tokenized_prompts is None:
tokenized_prompts = self.tokenizer(prompt,
padding='max_length',
max_length=max_length,
truncation=True,
return_tensors='pt').input_ids
if self.sdxl:
if tokenized_prompts is None:
tokenized_prompts = self.tokenizer(prompt,
padding='max_length',
truncation=True,
return_tensors='pt',
input_ids=True)
text_embeddings, pooled_text_embeddings = self.text_encoder(
[tokenized_prompts[0].to(device), tokenized_prompts[1].to(device)]) # type: ignore
else:
if tokenized_prompts is None:
tokenized_prompts = self.tokenizer(prompt,
padding='max_length',
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors='pt').input_ids
text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0] # type: ignore
else:
if self.sdxl:
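The simplified generate() path builds SDXL's time-ids microconditioning as a single tensor. A sketch with illustrative values; in the PR the values come from the requested width/height and any crop/size parameters passed to generate():

import torch

width, height = 1024, 1024
crop_params = [0., 0.]         # top, left of the training-time crop
size_params = [width, height]  # target output size
batch_size = 2                 # e.g. pooled_embeddings.shape[0]

# Original size, crop top-left, and target size concatenated per sample.
add_time_ids = torch.tensor([[width, height, *crop_params, *size_params]], dtype=torch.float)
add_time_ids = add_time_ids.repeat(batch_size, 1)
print(add_time_ids.shape)  # torch.Size([2, 6])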