Skip to content

Commit

Permalink
Add option for default prompts/negative prompts in eval (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Oct 2, 2024
1 parent 6acffcd commit ab5a2f0
Showing 1 changed file with 36 additions and 5 deletions.
41 changes: 36 additions & 5 deletions diffusion/evaluation/clean_fid_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class CleanFIDEvaluator:
precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
prompts (List[str], optional): The prompts to use for image visualtization.
Default: ``["A shiba inu wearing a blue sweater]``.
default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
default_negative_prompt (Optional[str]): An optional default negative prompt to add before each
negative prompt. Default: ``None``.
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
"""
Expand All @@ -70,6 +73,8 @@ def __init__(self,
num_samples: Optional[int] = None,
precision: str = 'amp_fp16',
prompts: Optional[List[str]] = None,
default_prompt: Optional[str] = None,
default_negative_prompt: Optional[str] = None,
additional_generate_kwargs: Optional[Dict] = None):
self.model = model
self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
Expand All @@ -87,6 +92,8 @@ def __init__(self,
self.num_samples = num_samples if num_samples is not None else float('inf')
self.precision = precision
self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
self.default_prompt = default_prompt
self.default_negative_prompt = default_negative_prompt
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
self.sdxl = model.sdxl

Expand Down Expand Up @@ -141,7 +148,17 @@ def _generate_images(self, guidance_scale: float):
break

real_images = batch[self.image_key]
captions = batch[self.caption_key]
tokenized_captions = batch[self.caption_key]
# Get the prompts from the tokens
text_captions = self.tokenizer.batch_decode(tokenized_captions, skip_special_tokens=True)
# Add default prompts if specified
augmented_captions = text_captions
augmented_negative_prompt = None
if self.default_prompt:
augmented_captions = [f'{self.default_prompt} {caption}' for caption in text_captions]
if self.default_negative_prompt:
augmented_negative_prompt = [f'{self.default_negative_prompt}' for _ in text_captions]

if self.sdxl:
crop_params = batch['cond_crops_coords_top_left']
input_size_params = batch['cond_original_size']
Expand All @@ -153,7 +170,8 @@ def _generate_images(self, guidance_scale: float):
seed = starting_seed + batch_id
# Generate images from the captions
with get_precision_context(self.precision):
generated_images = self.model.generate(tokenized_prompts=captions,
generated_images = self.model.generate(prompt=augmented_captions,
negative_prompt=augmented_negative_prompt,
height=self.size,
width=self.size,
guidance_scale=guidance_scale,
Expand All @@ -162,8 +180,6 @@ def _generate_images(self, guidance_scale: float):
input_size_params=input_size_params,
progress_bar=False,
**self.additional_generate_kwargs) # type: ignore
# Get the prompts from the tokens
text_captions = self.tokenizer.batch_decode(captions, skip_special_tokens=True)
self.clip_metric.update((generated_images * 255).to(torch.uint8), text_captions)
# Save the real images
# Verify that the real images are in the proper range
Expand Down Expand Up @@ -233,8 +249,23 @@ def _compute_metrics(self, guidance_scale: float):
def _generate_images_from_prompts(self, guidance_scale: float):
"""Generate images from prompts for visualization."""
if self.prompts:
# Augment the prompt
augmented_prompts = self.prompts
if self.default_prompt:
augmented_prompts = [f'{self.default_prompt} {prompt}' for prompt in self.prompts]
# Augment the negative prompt
augmented_negative_prompts = None
if 'negative prompt' in self.additional_generate_kwargs:
negative_prompts = self.additional_generate_kwargs['negative prompt']
augmented_negative_prompts = [
f'{self.default_negative_prompt} {neg_prompt}' for neg_prompt in negative_prompts
]
if self.default_negative_prompt and augmented_negative_prompts is None:
augmented_negative_prompts = [f'{self.default_negative_prompt}' for _ in self.prompts]

with get_precision_context(self.precision):
generated_images = self.model.generate(prompt=self.prompts,
generated_images = self.model.generate(prompt=augmented_prompts,
negative_prompt=augmented_negative_prompts,
height=self.size,
width=self.size,
guidance_scale=guidance_scale,
Expand Down

0 comments on commit ab5a2f0

Please sign in to comment.