Image generation script for models from here and HF (#158)
rishab-partha authored Jul 17, 2024
1 parent c64e18c commit 03a0b6a
Showing 4 changed files with 302 additions and 15 deletions.
180 changes: 180 additions & 0 deletions diffusion/evaluation/generate_images.py
@@ -0,0 +1,180 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Generates images based on a prompt dataset and uploads them for evaluation."""

import json
import os
from typing import Dict, Optional, Union
from urllib.parse import urlparse

import torch
from composer.core import get_precision_context
from composer.utils import dist
from composer.utils.file_helpers import get_file
from composer.utils.object_store import OCIObjectStore
from datasets import DatasetDict
from diffusers import AutoPipelineForText2Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm


class ImageGenerator:
"""Image generator that generates images from a dataset and saves them.
Args:
model (torch.nn.Module): The model to evaluate.
dataset (Dataset): The dataset to use the prompts from.
load_path (str, optional): The path to load the model from. Default: ``None``.
local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
height (int): The height of the generated images. Default: ``1024``.
width (int): The width of the generated images. Default: ``1024``.
caption_key (str): The key to use for captions in the dataloader. Default: ``'caption'``.
load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
seed (int): The seed to use for generation. Default: ``17``.
output_bucket (str, Optional): The remote to save images to. Default: ``None``.
output_prefix (str, Optional): The prefix to save images to. Default: ``None``.
extra_keys (list, Optional): Extra keys from the dataset to include in the metadata. Default: ``None``.
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
hf_model: (bool, Optional): whether the model is HF or not. Default: ``False``.
hf_dataset: (bool, Optional): whether the dataset is HF formatted or not. Default: ``False``.
"""

def __init__(self,
model: Union[torch.nn.Module, str],
dataset: Union[Dataset, DatasetDict],
load_path: Optional[str] = None,
local_checkpoint_path: str = '/tmp/model.pt',
load_strict_model_weights: bool = True,
guidance_scale: float = 7.0,
height: int = 1024,
width: int = 1024,
caption_key: str = 'caption',
seed: int = 17,
output_bucket: Optional[str] = None,
output_prefix: Optional[str] = None,
extra_keys: Optional[list] = None,
additional_generate_kwargs: Optional[Dict] = None,
hf_model: Optional[bool] = False,
hf_dataset: Optional[bool] = False):

        if isinstance(model, str) and not hf_model:
            raise ValueError('Can only specify the model as a string when using a Hugging Face model!')
        self.hf_model = hf_model
        self.hf_dataset = hf_dataset
        if hf_model or isinstance(model, str):
            # Let local rank zero download the pipeline first so the other ranks hit the local cache.
            if dist.get_local_rank() == 0:
                self.model = AutoPipelineForText2Image.from_pretrained(
                    model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
            dist.barrier()
            self.model = AutoPipelineForText2Image.from_pretrained(
                model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
            dist.barrier()
        else:
            self.model = model
self.dataset = dataset
self.load_path = load_path
self.local_checkpoint_path = local_checkpoint_path
self.load_strict_model_weights = load_strict_model_weights
self.guidance_scale = guidance_scale
self.height = height
self.width = width
self.caption_key = caption_key
self.seed = seed
self.generator = torch.Generator(device='cuda').manual_seed(self.seed)

self.output_bucket = output_bucket
self.output_prefix = output_prefix if output_prefix is not None else ''
self.extra_keys = extra_keys if extra_keys is not None else []
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}

# Object store for uploading images
if self.output_bucket is not None:
parsed_remote_bucket = urlparse(self.output_bucket)
if parsed_remote_bucket.scheme != 'oci':
raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.')
self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix)

# Download the model checkpoint if needed
if self.load_path is not None and not isinstance(self.model, str):
if dist.get_local_rank() == 0:
get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
# Load the model
state_dict = torch.load(self.local_checkpoint_path)
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
del state_dict['state']['model'][key]
self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights)
self.model = self.model.cuda().eval()

    def generate(self):
        """Core image generation function.

        Generates images from the dataset prompts at the configured guidance scale.
        """
os.makedirs(os.path.join('/tmp', self.output_prefix), exist_ok=True)
# Partition the dataset across the ranks
if self.hf_dataset:
dataset_len = self.dataset.num_rows # type: ignore
else:
dataset_len = self.dataset.num_samples # type: ignore
samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
end_idx = start_idx + samples_per_rank
if dist.get_global_rank() < remainder:
end_idx += 1
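        # e.g. 10 samples across 4 ranks: samples_per_rank=2, remainder=2, so the
        # ranks cover [0, 3), [3, 6), [6, 8), and [8, 10) respectively.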
print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
# Iterate over the dataset
for sample_id in tqdm(range(start_idx, end_idx)):
sample = self.dataset[sample_id]
caption = sample[self.caption_key]
# Generate images from the captions
if self.hf_model:
generated_image = self.model(prompt=caption,
height=self.height,
width=self.width,
guidance_scale=self.guidance_scale,
generator=self.generator,
**self.additional_generate_kwargs).images[0]
else:
with get_precision_context('amp_fp16'):
generated_image = self.model.generate(prompt=caption,
height=self.height,
width=self.width,
guidance_scale=self.guidance_scale,
seed=self.seed,
progress_bar=False,
**self.additional_generate_kwargs) # type: ignore
# Save the images
image_name = f'{sample_id}.png'
data_name = f'{sample_id}.json'
img_local_path = os.path.join('/tmp', self.output_prefix, image_name)
data_local_path = os.path.join('/tmp', self.output_prefix, data_name)
# Save the image
if self.hf_model:
img = generated_image
else:
img = to_pil_image(generated_image[0])
img.save(img_local_path)
# Save the metadata
metadata = {
'image_name': image_name,
'prompt': caption,
'guidance_scale': self.guidance_scale,
'seed': self.seed
}
for key in self.extra_keys:
metadata[key] = sample[key]
            with open(data_local_path, 'w') as f:
                json.dump(metadata, f)
# Upload the image
if self.output_bucket is not None:
self.object_store.upload_object(object_name=os.path.join(self.output_prefix, image_name),
filename=img_local_path)
# Upload the metadata
self.object_store.upload_object(object_name=os.path.join(self.output_prefix, data_name),
filename=data_local_path)
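
For reference, a minimal usage sketch of the new ImageGenerator class (not part of this commit). It assumes a single-GPU, single-process run where Composer's dist utilities fall back to rank 0; the model repo, dataset name, and caption column are illustrative placeholders.

# Hypothetical usage sketch -- model/dataset/column names are placeholders, not from this PR.
from datasets import load_dataset

from diffusion.evaluation.generate_images import ImageGenerator

prompts = load_dataset('nateraw/parti-prompts', split='train')  # placeholder HF prompt dataset
generator = ImageGenerator(
    model='stabilityai/stable-diffusion-xl-base-1.0',  # placeholder; any AutoPipelineForText2Image repo
    dataset=prompts,
    caption_key='Prompt',  # dataset column that holds the text prompts
    guidance_scale=7.0,
    height=1024,
    width=1024,
    output_bucket=None,  # skip the OCI upload; images stay under /tmp/
    hf_model=True,
    hf_dataset=True,
)
generator.generate()  # writes <sample_id>.png and <sample_id>.json for each prompt

With output_bucket left as None, images and metadata are only written locally; passing an oci:// bucket enables the upload path shown above.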
91 changes: 91 additions & 0 deletions diffusion/generate.py
@@ -0,0 +1,91 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Generate images from a model."""

import operator
from typing import List

import hydra
from composer import Algorithm, ComposerModel
from composer.algorithms.low_precision_groupnorm import apply_low_precision_groupnorm
from composer.algorithms.low_precision_layernorm import apply_low_precision_layernorm
from composer.core import Precision
from composer.utils import dist, get_device, reproducibility
from datasets import load_dataset
from omegaconf import DictConfig
from torch.utils.data import Dataset

from diffusion.evaluation.generate_images import ImageGenerator


def generate(config: DictConfig) -> None:
    """Generate images from a model.

    Args:
        config (DictConfig): Configuration composed by Hydra.
    """
reproducibility.seed_all(config.seed)
device = get_device() # type: ignore
dist.initialize_dist(device, config.dist_timeout)

# The model to evaluate
if not config.hf_model:
model: ComposerModel = hydra.utils.instantiate(config.model)
else:
model = config.model.name

tokenizer = model.tokenizer if hasattr(model, 'tokenizer') else None

    # The dataset to use for evaluation
    if config.hf_dataset:
        # Let local rank zero download the dataset first so the other ranks hit the local cache.
        if dist.get_local_rank() == 0:
            dataset = load_dataset(config.dataset.name, split=config.dataset.split)
        dist.barrier()
        dataset = load_dataset(config.dataset.name, split=config.dataset.split)
        dist.barrier()
    elif tokenizer:
        # Pass the tokenizer through for datasets that tokenize captions at load time.
        dataset = hydra.utils.instantiate(config.dataset, tokenizer=tokenizer)
    else:
        dataset: Dataset = hydra.utils.instantiate(config.dataset)

# Build list of algorithms.
algorithms: List[Algorithm] = []

# Some algorithms should also be applied at inference time
if 'algorithms' in config:
for ag_name, ag_conf in config.algorithms.items():
if '_target_' in ag_conf:
print(f'Instantiating algorithm <{ag_conf._target_}>')
algorithms.append(hydra.utils.instantiate(ag_conf))
elif ag_name == 'low_precision_groupnorm':
surgery_target = model
if 'attribute' in ag_conf:
surgery_target = operator.attrgetter(ag_conf.attribute)(model)
apply_low_precision_groupnorm(
model=surgery_target,
precision=Precision(ag_conf['precision']),
optimizers=None,
)
elif ag_name == 'low_precision_layernorm':
surgery_target = model
if 'attribute' in ag_conf:
surgery_target = operator.attrgetter(ag_conf.attribute)(model)
apply_low_precision_layernorm(
model=surgery_target,
precision=Precision(ag_conf['precision']),
optimizers=None,
)

image_generator: ImageGenerator = hydra.utils.instantiate(config.generator,
model=model,
dataset=dataset,
hf_model=config.hf_model,
hf_dataset=config.hf_dataset)

def generate_from_model():
image_generator.generate()

return generate_from_model()
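
To illustrate the config shape generate() expects, here is a hedged sketch that builds the DictConfig programmatically with OmegaConf instead of a Hydra YAML file. The key names follow the attribute accesses above (config.seed, config.dist_timeout, config.hf_model, config.dataset.split, config.generator); the model and dataset values are placeholders, and a single-process run is assumed.

# Hypothetical config sketch -- values are placeholders, not from this PR.
from omegaconf import OmegaConf

from diffusion.generate import generate

config = OmegaConf.create({
    'seed': 17,
    'dist_timeout': 300.0,
    'hf_model': True,
    'hf_dataset': True,
    'model': {'name': 'stabilityai/stable-diffusion-xl-base-1.0'},  # placeholder HF model
    'dataset': {'name': 'nateraw/parti-prompts', 'split': 'train'},  # placeholder HF dataset
    'generator': {
        '_target_': 'diffusion.evaluation.generate_images.ImageGenerator',
        'caption_key': 'Prompt',
        'guidance_scale': 7.0,
        'height': 1024,
        'width': 1024,
    },
})
generate(config)

In a real run, the same structure would typically live in a YAML file consumed by run_generation.py below.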
26 changes: 26 additions & 0 deletions run_generation.py
@@ -0,0 +1,26 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Run image generation."""

import textwrap

import hydra
from omegaconf import DictConfig

from diffusion.generate import generate


@hydra.main(version_base=None)
def main(config: DictConfig) -> None:
"""Hydra wrapper for evaluation."""
if not config:
raise ValueError(
textwrap.dedent("""\
Config path and name not specified!
Please specify these by using --config-path and --config-name, respectively."""))
return generate(config)


if __name__ == '__main__':
main()
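
Since @hydra.main is called with no built-in config path, the script is launched by pointing Hydra at a config explicitly, e.g. python run_generation.py --config-path <your-config-dir> --config-name <your-config> (directory and name are placeholders); on multi-GPU machines this would typically be wrapped with the composer launcher so each rank generates its shard of the dataset.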
20 changes: 5 additions & 15 deletions setup.py
@@ -6,21 +6,11 @@
from setuptools import find_packages, setup

 install_requires = [
-    'mosaicml==0.20.1',
-    'mosaicml-streaming==0.7.4',
-    'hydra-core>=1.2',
-    'hydra-colorlog>=1.1.0',
-    'diffusers[torch]==0.26.3',
-    'transformers[torch]==4.38.2',
-    'huggingface_hub==0.21.2',
-    'wandb==0.16.3',
-    'xformers==0.0.23.post1',
-    'triton==2.1.0',
-    'torchmetrics[image]==1.3.1',
-    'lpips==0.1.4',
-    'clean-fid==0.1.35',
-    'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33',
-    'gradio==4.19.2',
+    'mosaicml==0.20.1', 'mosaicml-streaming==0.7.4', 'hydra-core>=1.2', 'hydra-colorlog>=1.1.0',
+    'diffusers[torch]==0.26.3', 'transformers[torch]==4.38.2', 'huggingface_hub==0.21.2', 'wandb==0.16.3',
+    'xformers==0.0.23.post1', 'triton==2.1.0', 'torchmetrics[image]==1.3.1', 'lpips==0.1.4', 'clean-fid==0.1.35',
+    'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33', 'gradio==4.19.2',
+    'datasets==2.19.2'
 ]

extras_require = {}
