
HF image generation that integrates with Cory's earlier script #158

Merged: 18 commits, Jul 17, 2024
177 changes: 177 additions & 0 deletions diffusion/evaluation/generate_images.py
@@ -0,0 +1,177 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Generates images based on a prompt dataset and uploads them for evaluation."""

import json
import os
from typing import Dict, Optional, Union
from urllib.parse import urlparse

import torch
from composer.core import get_precision_context
from composer.utils import dist
from composer.utils.file_helpers import get_file
from composer.utils.object_store import OCIObjectStore
from datasets import DatasetDict
from diffusers import AutoPipelineForText2Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm


class ImageGenerator:
"""Image generator that generates images from a dataset and saves them.

    Args:
        model (torch.nn.Module or str): The model to generate with, or a HuggingFace model name
            if ``hf_model`` is ``True``.
        dataset (Dataset or DatasetDict): The dataset to use the prompts from.
        load_path (str, optional): The path to load the model from. Default: ``None``.
        local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
        load_strict_model_weights (bool): Whether or not to strictly load model weights. Default: ``True``.
        guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
        height (int): The height of the generated images. Default: ``1024``.
        width (int): The width of the generated images. Default: ``1024``.
        caption_key (str): The key to use for captions in the dataloader. Default: ``'caption'``.
        seed (int): The seed to use for generation. Default: ``17``.
        output_bucket (str, optional): The remote bucket to save images to. Default: ``None``.
        output_prefix (str, optional): The prefix to save images to. Default: ``None``.
        extra_keys (list, optional): Extra keys from the dataset to include in the metadata. Default: ``None``.
        additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate
            method. Default: ``None``.
        hf_model (bool): Whether ``model`` is a HuggingFace model name to load with
            ``AutoPipelineForText2Image``. Default: ``False``.
"""

def __init__(self,
model: Union[torch.nn.Module, str],
dataset: Union[Dataset, DatasetDict],
load_path: Optional[str] = None,
local_checkpoint_path: str = '/tmp/model.pt',
load_strict_model_weights: bool = True,
guidance_scale: float = 7.0,
height: int = 1024,
width: int = 1024,
caption_key: str = 'caption',
seed: int = 17,
output_bucket: Optional[str] = None,
output_prefix: Optional[str] = None,
extra_keys: Optional[list] = None,
additional_generate_kwargs: Optional[Dict] = None,
                 hf_model: bool = False):
        if isinstance(model, str) and not hf_model:
            raise ValueError('A string can only be used for model when hf_model is True!')
        self.hf_model = hf_model
        if hf_model or isinstance(model, str):
            # Let local rank zero download the HuggingFace pipeline first to warm the cache,
            # then have every rank load it from the cache.
            if dist.get_local_rank() == 0:
                self.model = AutoPipelineForText2Image.from_pretrained(
                    model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
            dist.barrier()
            self.model = AutoPipelineForText2Image.from_pretrained(
                model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
            dist.barrier()
        else:
            self.model = model
self.dataset = dataset
self.load_path = load_path
self.local_checkpoint_path = local_checkpoint_path
self.load_strict_model_weights = load_strict_model_weights
self.guidance_scale = guidance_scale
self.height = height
self.width = width
self.caption_key = caption_key
self.seed = seed
self.generator = torch.Generator(device='cuda').manual_seed(self.seed)

self.output_bucket = output_bucket
self.output_prefix = output_prefix if output_prefix is not None else ''
self.extra_keys = extra_keys if extra_keys is not None else []
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}

# Object store for uploading images
if self.output_bucket is not None:
parsed_remote_bucket = urlparse(self.output_bucket)
if parsed_remote_bucket.scheme != 'oci':
raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.')
self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix)

# Download the model checkpoint if needed
        if self.load_path is not None and not self.hf_model:
if dist.get_local_rank() == 0:
get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
# Load the model
state_dict = torch.load(self.local_checkpoint_path)
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
del state_dict['state']['model'][key]
self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights)
self.model = self.model.cuda().eval()

def generate(self):
"""Core image generation function. Generates images at a given guidance scale.

Args:
guidance_scale (float): The guidance scale to use for image generation.
"""
os.makedirs(os.path.join('/tmp', self.output_prefix), exist_ok=True)
# Partition the dataset across the ranks
if self.hf_model:
dataset_len = self.dataset.num_rows # type: ignore
else:
dataset_len = self.dataset.num_samples # type: ignore
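        # Split the samples evenly across ranks, giving the first `remainder` ranks one extra
        # sample each. For example, 10 samples on 4 ranks yields per-rank ranges of 3, 3, 2, 2.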
samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
end_idx = start_idx + samples_per_rank
if dist.get_global_rank() < remainder:
end_idx += 1
print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
# Iterate over the dataset
for sample_id in tqdm(range(start_idx, end_idx)):
sample = self.dataset[sample_id]
caption = sample[self.caption_key]
# Generate images from the captions
if self.hf_model:
generated_image = self.model(prompt=caption,
height=self.height,
width=self.width,
guidance_scale=self.guidance_scale,
generator=self.generator,
**self.additional_generate_kwargs).images[0]
else:
with get_precision_context('amp_fp16'):
generated_image = self.model.generate(prompt=caption,
height=self.height,
width=self.width,
guidance_scale=self.guidance_scale,
seed=self.seed,
progress_bar=False,
**self.additional_generate_kwargs) # type: ignore
# Save the images
image_name = f'{sample_id}.png'
data_name = f'{sample_id}.json'
img_local_path = os.path.join('/tmp', self.output_prefix, image_name)
data_local_path = os.path.join('/tmp', self.output_prefix, data_name)
# Save the image
if self.hf_model:
img = generated_image
else:
img = to_pil_image(generated_image[0])
img.save(img_local_path)
# Save the metadata
metadata = {
'image_name': image_name,
'prompt': caption,
'guidance_scale': self.guidance_scale,
'seed': self.seed
}
for key in self.extra_keys:
metadata[key] = sample[key]
            with open(data_local_path, 'w') as f:
                json.dump(metadata, f)
# Upload the image
if self.output_bucket is not None:
self.object_store.upload_object(object_name=os.path.join(self.output_prefix, image_name),
filename=img_local_path)
# Upload the metadata
self.object_store.upload_object(object_name=os.path.join(self.output_prefix, data_name),
filename=data_local_path)
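For reference, here is a minimal sketch of driving ImageGenerator directly with a HuggingFace pipeline. The model name, dataset, and bucket are hypothetical placeholders, and the dataset is assumed to have a 'caption' column:

from datasets import load_dataset

from diffusion.evaluation.generate_images import ImageGenerator

# Hypothetical prompt dataset and bucket; any text-to-image HF model name should work here.
generator = ImageGenerator(
    model='stabilityai/stable-diffusion-xl-base-1.0',
    dataset=load_dataset('my-org/prompt-dataset', split='test'),
    hf_model=True,
    guidance_scale=7.0,
    output_bucket='oci://my-eval-bucket',  # omit to keep images in /tmp only
    output_prefix='sdxl-base-eval',
)
generator.generate()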
86 changes: 86 additions & 0 deletions diffusion/generate.py
@@ -0,0 +1,86 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Generate images from a model."""

import operator
from typing import List

import hydra
from composer import Algorithm, ComposerModel
from composer.algorithms.low_precision_groupnorm import apply_low_precision_groupnorm
from composer.algorithms.low_precision_layernorm import apply_low_precision_layernorm
from composer.core import Precision
from composer.utils import reproducibility
from datasets import load_dataset
from omegaconf import DictConfig
from torch.utils.data import Dataset

from diffusion.evaluation.generate_images import ImageGenerator


def generate(config: DictConfig) -> None:
"""Evaluate a model.

Args:
config (DictConfig): Configuration composed by Hydra
"""
reproducibility.seed_all(config.seed)

    # The model to generate images with
if not config.hf_model:
model: ComposerModel = hydra.utils.instantiate(config.model)
else:
model = config.model.name

    # The dataset to draw prompts from
    if config.hf_model:
        dataset = load_dataset(config.dataset.name, split=config.dataset.split)
    else:
        dataset: Dataset = hydra.utils.instantiate(config.dataset)

# Build list of algorithms.
algorithms: List[Algorithm] = []

# Some algorithms should also be applied at inference time
if 'algorithms' in config:
for ag_name, ag_conf in config.algorithms.items():
if '_target_' in ag_conf:
print(f'Instantiating algorithm <{ag_conf._target_}>')
algorithms.append(hydra.utils.instantiate(ag_conf))
elif ag_name == 'low_precision_groupnorm':
surgery_target = model
if 'attribute' in ag_conf:
surgery_target = operator.attrgetter(ag_conf.attribute)(model)
apply_low_precision_groupnorm(
model=surgery_target,
precision=Precision(ag_conf['precision']),
optimizers=None,
)
elif ag_name == 'low_precision_layernorm':
surgery_target = model
if 'attribute' in ag_conf:
surgery_target = operator.attrgetter(ag_conf.attribute)(model)
apply_low_precision_layernorm(
model=surgery_target,
precision=Precision(ag_conf['precision']),
optimizers=None,
)

image_generator: ImageGenerator = hydra.utils.instantiate(
config.generator,
model=model,
dataset=dataset,
hf_model=config.hf_model,
)

    image_generator.generate()
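As an illustration, generate() can also be exercised without a YAML file by building the DictConfig in code. Every name below (model, dataset) is an assumed placeholder, not a config shipped with the repo:

from omegaconf import OmegaConf

from diffusion.generate import generate

# Minimal config sketch; the keys mirror exactly what generate() reads above.
config = OmegaConf.create({
    'seed': 17,
    'hf_model': True,
    'model': {'name': 'stabilityai/stable-diffusion-xl-base-1.0'},  # assumed HF model name
    'dataset': {'name': 'my-org/prompt-dataset', 'split': 'test'},  # hypothetical dataset
    'generator': {
        '_target_': 'diffusion.evaluation.generate_images.ImageGenerator',
        'guidance_scale': 7.0,
        'height': 1024,
        'width': 1024,
    },
})
generate(config)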
26 changes: 26 additions & 0 deletions run_generation.py
@@ -0,0 +1,26 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Run image generation."""

import textwrap

import hydra
from omegaconf import DictConfig

from diffusion.generate import generate


@hydra.main(version_base=None)
def main(config: DictConfig) -> None:
"""Hydra wrapper for evaluation."""
if not config:
raise ValueError(
textwrap.dedent("""\
Config path and name not specified!
Please specify these by using --config-path and --config-name, respectively."""))
return generate(config)


if __name__ == '__main__':
main()
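Assuming a config directory laid out like the training YAMLs (both names below are placeholders), a run might look like:

python run_generation.py --config-path /path/to/yamls --config-name generate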
20 changes: 5 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
@@ -6,21 +6,11 @@
from setuptools import find_packages, setup

 install_requires = [
-    'mosaicml==0.20.1',
-    'mosaicml-streaming==0.7.4',
-    'hydra-core>=1.2',
-    'hydra-colorlog>=1.1.0',
-    'diffusers[torch]==0.26.3',
-    'transformers[torch]==4.38.2',
-    'huggingface_hub==0.21.2',
-    'wandb==0.16.3',
-    'xformers==0.0.23.post1',
-    'triton==2.1.0',
-    'torchmetrics[image]==1.3.1',
-    'lpips==0.1.4',
-    'clean-fid==0.1.35',
-    'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33',
-    'gradio==4.19.2',
+    'mosaicml==0.20.1', 'mosaicml-streaming==0.7.4', 'hydra-core>=1.2', 'hydra-colorlog>=1.1.0',
+    'diffusers[torch]==0.26.3', 'transformers[torch]==4.38.2', 'huggingface_hub==0.21.2', 'wandb==0.16.3',
+    'xformers==0.0.23.post1', 'triton==2.1.0', 'torchmetrics[image]==1.3.1', 'lpips==0.1.4', 'clean-fid==0.1.35',
+    'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33', 'gradio==4.19.2',
+    'datasets==2.19.2'
 ]

extras_require = {}