From 03a0b6a03978b3db7716f86750178d6a3d784064 Mon Sep 17 00:00:00 2001
From: Rishab Parthasarathy <56666587+rishab-partha@users.noreply.github.com>
Date: Wed, 17 Jul 2024 16:49:31 -0700
Subject: [PATCH] Image generation script for models from here and HF (#158)

---
 diffusion/evaluation/generate_images.py | 180 ++++++++++++++++++++++++
 diffusion/generate.py                   |  91 ++++++++++++
 run_generation.py                       |  26 ++++
 setup.py                                |  20 +--
 4 files changed, 302 insertions(+), 15 deletions(-)
 create mode 100644 diffusion/evaluation/generate_images.py
 create mode 100644 diffusion/generate.py
 create mode 100644 run_generation.py

diff --git a/diffusion/evaluation/generate_images.py b/diffusion/evaluation/generate_images.py
new file mode 100644
index 00000000..3ac40c9c
--- /dev/null
+++ b/diffusion/evaluation/generate_images.py
@@ -0,0 +1,180 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Generates images based on a prompt dataset and uploads them for evaluation."""
+
+import json
+import os
+from typing import Dict, Optional, Union
+from urllib.parse import urlparse
+
+import torch
+from composer.core import get_precision_context
+from composer.utils import dist
+from composer.utils.file_helpers import get_file
+from composer.utils.object_store import OCIObjectStore
+from datasets import DatasetDict
+from diffusers import AutoPipelineForText2Image
+from torch.utils.data import Dataset
+from torchvision.transforms.functional import to_pil_image
+from tqdm.auto import tqdm
+
+
+class ImageGenerator:
+    """Image generator that generates images from a dataset and saves them.
+
+    Args:
+        model (torch.nn.Module or str): The model to generate images with, or the name of a HuggingFace
+            model to load. A string is only valid when ``hf_model=True``.
+        dataset (Dataset or DatasetDict): The dataset to use the prompts from.
+        load_path (str, optional): The path to load the model from. Default: ``None``.
+        local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
+        load_strict_model_weights (bool): Whether or not to strictly load the model weights. Default: ``True``.
+        guidance_scale (float): The guidance scale to use for generation. Default: ``7.0``.
+        height (int): The height of the generated images. Default: ``1024``.
+        width (int): The width of the generated images. Default: ``1024``.
+        caption_key (str): The key to use for captions in the dataloader. Default: ``'caption'``.
+        seed (int): The seed to use for generation. Default: ``17``.
+        output_bucket (str, optional): The remote bucket to save images to. Default: ``None``.
+        output_prefix (str, optional): The prefix to save images to. Default: ``None``.
+        extra_keys (list, optional): Extra keys from the dataset to include in the metadata. Default: ``None``.
+        additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate
+            method. Default: ``None``.
+        hf_model (bool, optional): Whether the model is a HuggingFace model. Default: ``False``.
+        hf_dataset (bool, optional): Whether the dataset is a HuggingFace dataset. Default: ``False``.
+ """ + + def __init__(self, + model: Union[torch.nn.Module, str], + dataset: Union[Dataset, DatasetDict], + load_path: Optional[str] = None, + local_checkpoint_path: str = '/tmp/model.pt', + load_strict_model_weights: bool = True, + guidance_scale: float = 7.0, + height: int = 1024, + width: int = 1024, + caption_key: str = 'caption', + seed: int = 17, + output_bucket: Optional[str] = None, + output_prefix: Optional[str] = None, + extra_keys: Optional[list] = None, + additional_generate_kwargs: Optional[Dict] = None, + hf_model: Optional[bool] = False, + hf_dataset: Optional[bool] = False): + + if isinstance(model, str) and hf_model == False: + raise ValueError('Can only use strings for model with hf models!') + self.hf_model = hf_model + self.hf_dataset = hf_dataset + if hf_model or isinstance(model, str): + if dist.get_local_rank() == 0: + self.model = AutoPipelineForText2Image.from_pretrained( + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + dist.barrier() + self.model = AutoPipelineForText2Image.from_pretrained( + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + dist.barrier() + else: + self.model = model + self.dataset = dataset + self.load_path = load_path + self.local_checkpoint_path = local_checkpoint_path + self.load_strict_model_weights = load_strict_model_weights + self.guidance_scale = guidance_scale + self.height = height + self.width = width + self.caption_key = caption_key + self.seed = seed + self.generator = torch.Generator(device='cuda').manual_seed(self.seed) + + self.output_bucket = output_bucket + self.output_prefix = output_prefix if output_prefix is not None else '' + self.extra_keys = extra_keys if extra_keys is not None else [] + self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {} + + # Object store for uploading images + if self.output_bucket is not None: + parsed_remote_bucket = urlparse(self.output_bucket) + if parsed_remote_bucket.scheme != 'oci': + raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.') + self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix) + + # Download the model checkpoint if needed + if self.load_path is not None and not isinstance(self.model, str): + if dist.get_local_rank() == 0: + get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True) + with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path): + # Load the model + state_dict = torch.load(self.local_checkpoint_path) + for key in list(state_dict['state']['model'].keys()): + if 'val_metrics.' in key: + del state_dict['state']['model'][key] + self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights) + self.model = self.model.cuda().eval() + + def generate(self): + """Core image generation function. Generates images at a given guidance scale. + + Args: + guidance_scale (float): The guidance scale to use for image generation. 
+ """ + os.makedirs(os.path.join('/tmp', self.output_prefix), exist_ok=True) + # Partition the dataset across the ranks + if self.hf_dataset: + dataset_len = self.dataset.num_rows # type: ignore + else: + dataset_len = self.dataset.num_samples # type: ignore + samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size()) + start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank()) + end_idx = start_idx + samples_per_rank + if dist.get_global_rank() < remainder: + end_idx += 1 + print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.') + # Iterate over the dataset + for sample_id in tqdm(range(start_idx, end_idx)): + sample = self.dataset[sample_id] + caption = sample[self.caption_key] + # Generate images from the captions + if self.hf_model: + generated_image = self.model(prompt=caption, + height=self.height, + width=self.width, + guidance_scale=self.guidance_scale, + generator=self.generator, + **self.additional_generate_kwargs).images[0] + else: + with get_precision_context('amp_fp16'): + generated_image = self.model.generate(prompt=caption, + height=self.height, + width=self.width, + guidance_scale=self.guidance_scale, + seed=self.seed, + progress_bar=False, + **self.additional_generate_kwargs) # type: ignore + # Save the images + image_name = f'{sample_id}.png' + data_name = f'{sample_id}.json' + img_local_path = os.path.join('/tmp', self.output_prefix, image_name) + data_local_path = os.path.join('/tmp', self.output_prefix, data_name) + # Save the image + if self.hf_model: + img = generated_image + else: + img = to_pil_image(generated_image[0]) + img.save(img_local_path) + # Save the metadata + metadata = { + 'image_name': image_name, + 'prompt': caption, + 'guidance_scale': self.guidance_scale, + 'seed': self.seed + } + for key in self.extra_keys: + metadata[key] = sample[key] + json.dump(metadata, open(f'{data_local_path}', 'w')) + # Upload the image + if self.output_bucket is not None: + self.object_store.upload_object(object_name=os.path.join(self.output_prefix, image_name), + filename=img_local_path) + # Upload the metadata + self.object_store.upload_object(object_name=os.path.join(self.output_prefix, data_name), + filename=data_local_path) diff --git a/diffusion/generate.py b/diffusion/generate.py new file mode 100644 index 00000000..b6b766ad --- /dev/null +++ b/diffusion/generate.py @@ -0,0 +1,91 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Generate images from a model.""" + +import operator +from typing import List + +import hydra +from composer import Algorithm, ComposerModel +from composer.algorithms.low_precision_groupnorm import apply_low_precision_groupnorm +from composer.algorithms.low_precision_layernorm import apply_low_precision_layernorm +from composer.core import Precision +from composer.utils import dist, get_device, reproducibility +from datasets import load_dataset +from omegaconf import DictConfig +from torch.utils.data import Dataset + +from diffusion.evaluation.generate_images import ImageGenerator + + +def generate(config: DictConfig) -> None: + """Evaluate a model. 
+
+    Args:
+        config (DictConfig): Configuration composed by Hydra.
+    """
+    reproducibility.seed_all(config.seed)
+    device = get_device()  # type: ignore
+    dist.initialize_dist(device, config.dist_timeout)
+
+    # The model to generate images with
+    if not config.hf_model:
+        model: ComposerModel = hydra.utils.instantiate(config.model)
+    else:
+        model = config.model.name
+
+    tokenizer = model.tokenizer if hasattr(model, 'tokenizer') else None
+
+    # The dataset to draw prompts from
+    if config.hf_dataset:
+        # Download on local rank 0 first so the other ranks can then load from the local cache
+        if dist.get_local_rank() == 0:
+            dataset = load_dataset(config.dataset.name, split=config.dataset.split)
+        dist.barrier()
+        dataset = load_dataset(config.dataset.name, split=config.dataset.split)
+        dist.barrier()
+    elif tokenizer:
+        dataset = hydra.utils.instantiate(config.dataset, tokenizer=tokenizer)
+    else:
+        dataset: Dataset = hydra.utils.instantiate(config.dataset)
+
+    # Build list of algorithms.
+    algorithms: List[Algorithm] = []
+
+    # Some algorithms should also be applied at inference time
+    if 'algorithms' in config:
+        for ag_name, ag_conf in config.algorithms.items():
+            if '_target_' in ag_conf:
+                print(f'Instantiating algorithm <{ag_conf._target_}>')
+                algorithms.append(hydra.utils.instantiate(ag_conf))
+            elif ag_name == 'low_precision_groupnorm':
+                surgery_target = model
+                if 'attribute' in ag_conf:
+                    surgery_target = operator.attrgetter(ag_conf.attribute)(model)
+                apply_low_precision_groupnorm(
+                    model=surgery_target,
+                    precision=Precision(ag_conf['precision']),
+                    optimizers=None,
+                )
+            elif ag_name == 'low_precision_layernorm':
+                surgery_target = model
+                if 'attribute' in ag_conf:
+                    surgery_target = operator.attrgetter(ag_conf.attribute)(model)
+                apply_low_precision_layernorm(
+                    model=surgery_target,
+                    precision=Precision(ag_conf['precision']),
+                    optimizers=None,
+                )
+
+    image_generator: ImageGenerator = hydra.utils.instantiate(config.generator,
+                                                              model=model,
+                                                              dataset=dataset,
+                                                              hf_model=config.hf_model,
+                                                              hf_dataset=config.hf_dataset)
+
+    def generate_from_model():
+        image_generator.generate()
+
+    return generate_from_model()
diff --git a/run_generation.py b/run_generation.py
new file mode 100644
index 00000000..c1f7fd79
--- /dev/null
+++ b/run_generation.py
@@ -0,0 +1,26 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Run image generation."""
+
+import textwrap
+
+import hydra
+from omegaconf import DictConfig
+
+from diffusion.generate import generate
+
+
+@hydra.main(version_base=None)
+def main(config: DictConfig) -> None:
+    """Hydra wrapper for generation."""
+    if not config:
+        raise ValueError(
+            textwrap.dedent("""\
+                Config path and name not specified!
+                Please specify these by using --config-path and --config-name, respectively."""))
+    return generate(config)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/setup.py b/setup.py
index a7a7b9c0..469d1268 100644
--- a/setup.py
+++ b/setup.py
@@ -6,21 +6,11 @@
 from setuptools import find_packages, setup
 
 install_requires = [
-    'mosaicml==0.20.1',
-    'mosaicml-streaming==0.7.4',
-    'hydra-core>=1.2',
-    'hydra-colorlog>=1.1.0',
-    'diffusers[torch]==0.26.3',
-    'transformers[torch]==4.38.2',
-    'huggingface_hub==0.21.2',
-    'wandb==0.16.3',
-    'xformers==0.0.23.post1',
-    'triton==2.1.0',
-    'torchmetrics[image]==1.3.1',
-    'lpips==0.1.4',
-    'clean-fid==0.1.35',
-    'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33',
-    'gradio==4.19.2',
+    'mosaicml==0.20.1', 'mosaicml-streaming==0.7.4', 'hydra-core>=1.2', 'hydra-colorlog>=1.1.0',
+    'diffusers[torch]==0.26.3', 'transformers[torch]==4.38.2', 'huggingface_hub==0.21.2', 'wandb==0.16.3',
+    'xformers==0.0.23.post1', 'triton==2.1.0', 'torchmetrics[image]==1.3.1', 'lpips==0.1.4', 'clean-fid==0.1.35',
+    'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33', 'gradio==4.19.2',
+    'datasets==2.19.2'
 ]
 
 extras_require = {}
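
Usage sketch (hedged; not part of the diff above): the new ImageGenerator can also be driven directly from Python. The HuggingFace model id, the prompt dataset, and its 'Prompt' caption column below are assumptions for illustration, not values shipped in this patch; a CUDA device is required, and multi-GPU runs should be launched with composer so that dist is initialized.

    # Sketch: generate with a HuggingFace pipeline and a HuggingFace prompt dataset.
    # The model id, dataset name, and caption column are illustrative assumptions.
    from datasets import load_dataset

    from diffusion.evaluation.generate_images import ImageGenerator

    prompts = load_dataset('nateraw/parti-prompts', split='train')  # assumed prompt dataset
    generator = ImageGenerator(
        model='stabilityai/stable-diffusion-xl-base-1.0',  # any diffusers text-to-image repo id
        dataset=prompts,
        caption_key='Prompt',  # caption column of the assumed dataset
        guidance_scale=7.0,
        height=1024,
        width=1024,
        output_bucket=None,  # keep outputs local under /tmp; pass an oci:// URI to upload
        hf_model=True,
        hf_dataset=True,
    )
    generator.generate()  # writes <idx>.png and <idx>.json for each prompt

The same flow runs through the new entrypoint given a Hydra config that defines seed, dist_timeout, hf_model, hf_dataset, model, dataset, and generator (see diffusion/generate.py), e.g. python run_generation.py --config-path <dir> --config-name <config>.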