Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full SDXL Model #67

Merged
merged 30 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ec57395
random crop
jazcollins Sep 20, 2023
d34d7e9
zero init trick
jazcollins Sep 20, 2023
fb856c4
add intentionally buggy clipping
jazcollins Sep 20, 2023
4dd3c40
fix docstring and update diffusers version
jazcollins Sep 20, 2023
c1d58c9
fix attention clipping, add to sdxl, fix xformers import when not ins…
jazcollins Sep 20, 2023
f14018a
big sdxl commit, no style check
jazcollins Sep 21, 2023
45c1ac0
fix style and pyright
jazcollins Sep 21, 2023
e873717
print sdxl statement
jazcollins Sep 21, 2023
d93fbdb
add sdxl logic to generate
jazcollins Sep 22, 2023
75db76f
allow setting SDXLTextEncoder device
jazcollins Sep 22, 2023
26a133d
sdxltextencoder edits
jazcollins Sep 24, 2023
320c7a1
split conditioning
jazcollins Sep 25, 2023
c2d0321
remove prints
jazcollins Sep 25, 2023
12217fc
microconditioning and cleaning up comments
jazcollins Sep 25, 2023
06b2f2f
fix style
jazcollins Sep 25, 2023
77b0099
fix dropout dtype
jazcollins Sep 25, 2023
9b798c6
rm local streaming
jazcollins Sep 25, 2023
f798258
Update diffusion/datasets/image_caption.py
jazcollins Sep 28, 2023
505b850
use RandomCrop, fix LogDiffusionImages bug
jazcollins Oct 2, 2023
054d1ef
have tokenizers pass dict output
jazcollins Oct 2, 2023
65b1a8b
add to layers.py docs
jazcollins Oct 2, 2023
dd35c77
override prediction_type in inference_noise_scheulder
jazcollins Oct 2, 2023
45c0cd7
Update diffusion/models/stable_diffusion.py
jazcollins Oct 2, 2023
2346076
fix style
jazcollins Oct 2, 2023
8d8e2ae
log_diffusion_images.py fix
jazcollins Oct 2, 2023
8b65e7d
pass tokenized prompts as batch_size x 2 x max_length shape
jazcollins Oct 3, 2023
3ae948e
stack tokenizer output to match
jazcollins Oct 3, 2023
daf745c
fix negative prompt classifier free guidance
jazcollins Oct 3, 2023
0a1f31a
_prepare_text_embeddings fix
jazcollins Oct 3, 2023
67e0316
add negative_prompt_embeds to zero_out_negative_prompt check
jazcollins Oct 3, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions diffusion/callbacks/log_diffusion_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class LogDiffusionImages(Callback):
the text prompt, usually at the expense of lower image quality.
Default: ``0.0``.
text_key (str, optional): Key in the batch to use for text prompts. Default: ``'captions'``.
tokenized_prompts (torch.LongTensor, optional): Batch of pre-tokenized prompts
to use for evaluation. Default: ``None``.
tokenized_prompts (torch.LongTensor or List[torch.LongTensor], optional): Batch of pre-tokenized prompts
to use for evaluation. If SDXL, this will be a list of two pre-tokenized prompts Default: ``None``.
seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
Default: ``1138``.
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
Expand Down Expand Up @@ -63,13 +63,17 @@ def eval_batch_end(self, state: State, logger: Logger):
model = state.model

if self.tokenized_prompts is None:
tokenized_prompts = [
self.tokenized_prompts = [
model.tokenizer(p, padding='max_length', truncation=True,
return_tensors='pt')['input_ids'] # type: ignore
for p in self.prompts
]
self.tokenized_prompts = torch.cat(tokenized_prompts)
self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)
if model.sdxl:
self.tokenized_prompts = torch.stack([torch.cat(tp) for tp in self.tokenized_prompts
]) # [B, 2, max_length]
else:
self.tokenized_prompts = torch.cat(self.tokenized_prompts) # type: ignore
self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device) # type: ignore

# Generate images
with get_precision_context(state.precision):
Expand Down
86 changes: 75 additions & 11 deletions diffusion/datasets/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Streaming Image-Caption dataset."""

import logging
import random
from io import BytesIO
from typing import Callable, Dict, List, Optional, Sequence, Union
Expand All @@ -14,7 +15,10 @@
from torchvision import transforms
from transformers import AutoTokenizer

from diffusion.datasets.laion.transforms import LargestCenterSquare
from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropSquare, RandomCropSquareReturnTransform
from diffusion.models.models import SDXLTokenizer

log = logging.getLogger(__name__)

# Disable PIL max image size limit
Image.MAX_IMAGE_PIXELS = None
Expand All @@ -29,13 +33,16 @@ class StreamingImageCaptionDataset(StreamingDataset):
remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``.
local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``.
tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
caption_selection (str): If there are multiple captions, specifies how to select a single caption.
'first' selects the first caption in the list and 'random' selects a random caption in the list.
If there is only one caption, this argument is ignored. Default: ``'first'``.
transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
image_size (Optional[int]): The size to resize the image to. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
sdxl (bool): Whether or not we're training SDXL. Default: `False`.
**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
"""

Expand All @@ -46,11 +53,13 @@ def __init__(
local: Optional[str] = None,
tokenizer_name_or_path: str = 'stabilityai/stable-diffusion-2-base',
caption_drop_prob: float = 0.0,
microcond_drop_prob: float = 0.0,
caption_selection: str = 'first',
transform: Optional[Callable] = None,
image_size: Optional[int] = None,
image_key: str = 'image',
caption_key: str = 'caption',
sdxl: bool = False,
**streaming_kwargs,
) -> None:

Expand All @@ -65,8 +74,15 @@ def __init__(
raise ValueError(f'Invalid caption selection: {caption_selection}. Must be one of [random, first]')

self.transform = transform
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer')
self.sdxl = sdxl
if self.sdxl:
self.tokenizer = SDXLTokenizer(tokenizer_name_or_path)
self.sdxl_crop = RandomCropSquareReturnTransform(image_size)
else:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, subfolder='tokenizer')
self.sdxl_crop = None
self.caption_drop_prob = caption_drop_prob
self.microcond_drop_prob = microcond_drop_prob
self.caption_selection = caption_selection
self.image_size = image_size
self.image_key = image_key
Expand All @@ -81,6 +97,25 @@ def __getitem__(self, index):
img = Image.open(BytesIO(sample[self.image_key]))
if img.mode != 'RGB':
img = img.convert('RGB')

out = {}
# Image transforms
if self.sdxl and self.sdxl_crop:
img, crop_top, crop_left, image_height, image_width = self.sdxl_crop(img)
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])
out['cond_original_size'] = torch.tensor([image_width, image_height])
out['cond_target_size'] = torch.tensor([self.image_size, self.image_size])

# Microconditioning dropout as in Stability repo
# https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160
if torch.rand(1) < self.microcond_drop_prob:
out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0
if torch.rand(1) < self.microcond_drop_prob:
out['cond_original_size'] = out['cond_original_size'] * 0
if torch.rand(1) < self.microcond_drop_prob:
out['cond_target_size'] = out['cond_target_size'] * 0
else:
crop_top, crop_left, image_height, image_width = None, None, None, None
if self.transform is not None:
img = self.transform(img)

Expand All @@ -93,13 +128,21 @@ def __getitem__(self, index):
caption = caption[0]
if isinstance(caption, List) and self.caption_selection == 'random':
caption = random.sample(caption, k=1)[0]

max_length = None if self.sdxl else self.tokenizer.model_max_length # type: ignore
tokenized_caption = self.tokenizer(caption,
padding='max_length',
max_length=self.tokenizer.model_max_length,
max_length=max_length,
truncation=True,
return_tensors='pt')['input_ids'][0]

return {'image': img, 'captions': tokenized_caption}
return_tensors='pt')['input_ids']
if self.sdxl:
tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption]
tokenized_caption = torch.stack(tokenized_caption)
else:
tokenized_caption = tokenized_caption.squeeze()
out['image'] = img
out['captions'] = tokenized_caption
return out


def build_streaming_image_caption_dataloader(
Expand All @@ -108,11 +151,13 @@ def build_streaming_image_caption_dataloader(
batch_size: int,
tokenizer_name_or_path: str = 'stabilityai/stable-diffusion-2-base',
caption_drop_prob: float = 0.0,
microcond_drop_prob: float = 0.0,
resize_size: int = 256,
caption_selection: str = 'first',
transform: Optional[List[Callable]] = None,
image_key: str = 'image',
caption_key: str = 'caption',
rand_crop: bool = False,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
Expand All @@ -124,13 +169,15 @@ def build_streaming_image_caption_dataloader(
batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``.
tokenizer_name_or_path (str): The name or path of the tokenizer to use. Default: ``'stabilityai/stable-diffusion-2-base'``.
caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
resize_size (int): The size to resize the image to. Default: ``256``.
caption_selection (str): If there are multiple captions, specifies how to select a single caption.
'first' selects the first caption in the list and 'random' selects a random caption in the list.
If there is only one caption, this argument is ignored. Default: ``'first'``.
transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
rand_crop (bool): If True, randomly crop images. Otherwise, center crop. Default: ``False``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
Expand All @@ -156,26 +203,43 @@ def build_streaming_image_caption_dataloader(
for r, l in zip(remote, local):
streams.append(Stream(remote=r, local=l))

# Infer SDXL from tokenizer path
if tokenizer_name_or_path == 'stabilityai/stable-diffusion-xl-base-1.0':
log.info('Detected SDXL tokenizer, using SDXL crop transform and tokenizers.')
sdxl = True
else:
sdxl = False

# Setup the transforms to apply
crop_transform = LargestCenterSquare(resize_size) if rand_crop else RandomCropSquare(resize_size)
if transform is None:
transform = [
LargestCenterSquare(resize_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # # Normalize from 0 to 1 to -1 to 1
]
if sdxl:
# Crop will return parameters so do separately
transform = [
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
else:
transform = [
crop_transform,
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # # Normalize from 0 to 1 to -1 to 1
]
transform = transforms.Compose(transform)
assert isinstance(transform, Callable)

dataset = StreamingImageCaptionDataset(
streams=streams,
tokenizer_name_or_path=tokenizer_name_or_path,
caption_drop_prob=caption_drop_prob,
microcond_drop_prob=microcond_drop_prob,
caption_selection=caption_selection,
transform=transform,
image_size=resize_size,
image_key=image_key,
caption_key=caption_key,
batch_size=batch_size,
sdxl=sdxl,
**streaming_kwargs,
)

Expand Down
37 changes: 36 additions & 1 deletion diffusion/datasets/laion/transforms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Transforms for the laion dataset."""
"""Transforms for the training and eval dataset."""

import torchvision.transforms as transforms
from torchvision.transforms import RandomCrop
from torchvision.transforms.functional import crop


class LargestCenterSquare:
Expand All @@ -19,3 +21,36 @@ def __call__(self, img):
# Then take a center crop to a square.
img = self.center_crop(img)
return img


class RandomCropSquare:
"""Randomly crop square of a PIL image."""

def __init__(self, size):
self.size = size
self.random_crop = RandomCrop(size)

def __call__(self, img):
# First, resize the image such that the smallest side is self.size while preserving aspect ratio.
img = transforms.functional.resize(img, self.size, antialias=True)
# Then take a center crop to a square & return crop params.
c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
img = crop(img, c_top, c_left, h, w)
return img


class RandomCropSquareReturnTransform:
"""Randomly crop square of a PIL image and return the crop parameters."""

def __init__(self, size):
self.size = size
self.random_crop = RandomCrop(size)

def __call__(self, img):
# First, resize the image such that the smallest side is self.size while preserving aspect ratio.
orig_w, orig_h = img.size
img = transforms.functional.resize(img, self.size, antialias=True)
# Then take a center crop to a square & return crop params.
c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
img = crop(img, c_top, c_left, h, w)
return img, c_top, c_left, orig_h, orig_w
Loading
Loading