Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial SDXL model #61

Merged
merged 15 commits into from
Aug 21, 2023
129 changes: 123 additions & 6 deletions diffusion/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def stable_diffusion_2(
prompts.

Args:
model_name (str, optional): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
model_name (str): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
pretrained (bool): Whether to load pretrained weights. Defaults to True.
prediction_type (str): The type of prediction to use. Must be one of 'sample',
'epsilon', or 'v_prediction'. Default: `epsilon`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
Expand All @@ -54,12 +54,12 @@ def stable_diffusion_2(
[MeanSquaredError(), FrechetInceptionDistance(normalize=True)].
val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to
[1.0, 3.0, 7.0].
val_seed (int, optional): Seed to use for generating evaluation images. Defaults to 1138.
val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to
[(0, 1)].
precomputed_latents (bool, optional): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool, optional): Whether to encode latents in fp16. Defaults to True.
fsdp (bool, optional): Whether to use FSDP. Defaults to True.
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
fsdp (bool): Whether to use FSDP. Defaults to True.
"""
if train_metrics is None:
train_metrics = [MeanSquaredError()]
Expand Down Expand Up @@ -123,6 +123,123 @@ def stable_diffusion_2(
return model


def stable_diffusion_xl(
model_name: str = 'stabilityai/stable-diffusion-2-base',
jazcollins marked this conversation as resolved.
Show resolved Hide resolved
unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0',
vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix',
jazcollins marked this conversation as resolved.
Show resolved Hide resolved
pretrained: bool = True,
prediction_type: str = 'epsilon',
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
val_guidance_scales: Optional[List] = None,
val_seed: int = 1138,
loss_bins: Optional[List] = None,
precomputed_latents: bool = False,
encode_latents_in_fp16: bool = True,
fsdp: bool = True,
):
"""Stable diffusion 2 training setup + SDXL UNet and VAE.

Requires batches of matched images and text prompts to train. Generates images from text
prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2.

Args:
model_name (str): Name of the model to load. Determines the text encoder, tokenizer,
and noise scheduler. Defaults to 'stabilityai/stable-diffusion-2-base'.
unet_model_name (str): Name of the UNet model to load. Defaults to
'stabilityai/stable-diffusion-xl-base-1.0'.
vae_model_name (str): Name of the VAE model to load. Defaults to
'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from
'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16.
pretrained (bool): Whether to load pretrained weights. Defaults to True.
prediction_type (str): The type of prediction to use. Must be one of 'sample',
'epsilon', or 'v_prediction'. Default: `epsilon`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
[MeanSquaredError()].
val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
[MeanSquaredError(), FrechetInceptionDistance(normalize=True)].
val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to
[1.0, 3.0, 7.0].
val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to
[(0, 1)].
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
fsdp (bool): Whether to use FSDP. Defaults to True.
"""
if train_metrics is None:
train_metrics = [MeanSquaredError()]
if val_metrics is None:
val_metrics = [MeanSquaredError(), FrechetInceptionDistance(normalize=True)]
if val_guidance_scales is None:
val_guidance_scales = [1.0, 3.0, 7.0]
if loss_bins is None:
loss_bins = [(0, 1)]
# Fix a bug where CLIPScore requires grad
for metric in val_metrics:
if isinstance(metric, CLIPScore):
metric.requires_grad_(False)

if pretrained:
raise NotImplementedError('Full SDXL pipeline not implemented yet.')
else:
config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')
# Currently not doing micro-conditioning, so set config appropriately
config[0]['addition_embed_type'] = None
config[0]['cross_attention_dim'] = 1024
unet = UNet2DConditionModel(**config[0])

# Prevent fsdp from wrapping up_blocks and down_blocks because the forward pass calls length on these
unet.up_blocks._fsdp_wrap = False
unet.down_blocks._fsdp_wrap = False
for block in unet.up_blocks:
block._fsdp_wrap = True
for block in unet.down_blocks:
block._fsdp_wrap = True

if encode_latents_in_fp16:
vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16)
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch.float16)
else:
vae = AutoencoderKL.from_pretrained(vae_model_name)
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder')

tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
inference_noise_scheduler = DDIMScheduler(num_train_timesteps=noise_scheduler.config.num_train_timesteps,
beta_start=noise_scheduler.config.beta_start,
beta_end=noise_scheduler.config.beta_end,
beta_schedule=noise_scheduler.config.beta_schedule,
trained_betas=noise_scheduler.config.trained_betas,
clip_sample=noise_scheduler.config.clip_sample,
set_alpha_to_one=noise_scheduler.config.set_alpha_to_one,
prediction_type=prediction_type)

model = StableDiffusion(
unet=unet,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
noise_scheduler=noise_scheduler,
inference_noise_scheduler=inference_noise_scheduler,
prediction_type=prediction_type,
train_metrics=train_metrics,
val_metrics=val_metrics,
val_guidance_scales=val_guidance_scales,
val_seed=val_seed,
loss_bins=loss_bins,
precomputed_latents=precomputed_latents,
encode_latents_in_fp16=encode_latents_in_fp16,
fsdp=fsdp,
)
if torch.cuda.is_available():
model = DeviceGPU().module_to_device(model)
if is_xformers_installed:
model.unet.enable_xformers_memory_efficient_attention()
model.vae.enable_xformers_memory_efficient_attention()
return model


def discrete_pixel_diffusion(clip_model_name: str = 'openai/clip-vit-large-patch14', prediction_type='epsilon'):
"""Discrete pixel diffusion training setup.

Expand Down