From 60e945af4e4027e0dd3ec6384062305dc9fe7be8 Mon Sep 17 00:00:00 2001
From: Raphael Tang
Date: Tue, 15 Nov 2022 00:41:15 -0500
Subject: [PATCH] Set requirement for Diffusers 0.3.0

---
 README.md                 |   5 +-
 daam/_version.py          |   2 +-
 daam/experiment.py        |  70 ++++++++++++++++++++--
 daam/run/daam_to_mask.py  |  50 +++++++++++-----
 daam/run/filter_coco.py   |  70 ++++++++++++++++++++++
 daam/run/generate.py      | 119 ++++++++++++++++++++++++++++----------
 daam/run/test_literacy.py |  91 +++++++++++++++++++++++++++++
 daam/trace.py             |  47 +++++++++++++--
 daam/utils.py             |  12 +++-
 requirements.txt          |   4 +-
 setup.py                  |   8 ++-
 11 files changed, 419 insertions(+), 59 deletions(-)
 create mode 100644 daam/run/filter_coco.py
 create mode 100644 daam/run/test_literacy.py

diff --git a/README.md b/README.md
index 8bf7cf5..274d163 100644
--- a/README.md
+++ b/README.md
@@ -35,7 +35,7 @@ prompt = 'A dog runs across the field'
 gen = set_seed(0)
 
 with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
-    with trace(pipe, weighted=True) as tc:
+    with trace(pipe) as tc:
         out = pipe(prompt, num_inference_steps=30, generator=gen)
         heat_map = tc.compute_global_heat_map(prompt)
         heat_map = expand_image(heat_map.compute_word_heat_map('dog'))
@@ -43,7 +43,8 @@ with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
     plt.show()
 ```
 
-We'll have docs soon.
+We'll have docs soon.
+In the meantime, check out the `GenerationExperiment`, `HeatMap`, and `DiffusionHeatMapHooker` classes, as well as the `daam/run/*.py` example scripts.
 
 ## Running the Demo
 
diff --git a/daam/_version.py b/daam/_version.py
index ffcc925..156d6f9 100644
--- a/daam/_version.py
+++ b/daam/_version.py
@@ -1 +1 @@
-__version__ = '0.0.3'
+__version__ = '0.0.4'
diff --git a/daam/experiment.py b/daam/experiment.py
index a456781..c728e9e 100644
--- a/daam/experiment.py
+++ b/daam/experiment.py
@@ -3,14 +3,15 @@ from dataclasses import dataclass
 import json
 
+from transformers import PreTrainedTokenizer
 import PIL.Image
 import numpy as np
 import torch
 
 from .evaluate import load_mask
+from .utils import plot_overlay_heat_map, expand_image
 
-
-__all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS']
+__all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS', 'COCO80_INDICES', 'build_word_list_coco80']
 
 
 COCO80_LABELS: List[str] = [
@@ -25,16 +26,37 @@
     'hair drier', 'toothbrush'
 ]
 
+COCO80_INDICES: Dict[str, int] = {x: i for i, x in enumerate(COCO80_LABELS)}
+
 UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)]
 
-
 COCOSTUFF27_LABELS: List[str] = [
     'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 'accessory', 'animal', 'outdoor', 'person',
     'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window',
     'building', 'ground', 'plant', 'sky', 'solid', 'structural', 'water'
 ]
 
+COCO80_ONTOLOGY = {
+    'two-wheeled vehicle': ['bicycle', 'motorcycle'],
+    'vehicle': ['two-wheeled vehicle', 'four-wheeled vehicle'],
+    'four-wheeled vehicle': ['bus', 'truck', 'car'],
+    'four-legged animals': ['livestock', 'pets', 'wild animals'],
+    'livestock': ['cow', 'horse', 'sheep'],
+    'pets': ['cat', 'dog'],
+    'wild animals': ['elephant', 'bear', 'zebra', 'giraffe'],
+    'bags': ['backpack', 'handbag', 'suitcase'],
+    'sports boards': ['snowboard', 'surfboard', 'skateboard'],
+    'utensils': ['fork', 'knife', 'spoon'],
+    'receptacles': ['bowl', 'cup'],
+    'fruits': ['banana', 'apple', 'orange'],
+    'foods': ['fruits', 'meals', 'desserts'],
+    'meals': ['sandwich', 'hot dog', 'pizza'],
+    'desserts': ['cake', 'donut'],
+    'furniture': ['chair', 'couch', 'bench'],
+    'electronics': ['monitors', 'appliances'],
+    'monitors': ['tv', 'cell phone', 'laptop'],
+    'appliances': ['oven', 'toaster', 'refrigerator']
+}
 
 COCO80_TO_27 = {
     'bicycle': 'vehicle', 'car': 'vehicle', 'motorcycle': 'vehicle', 'airplane': 'vehicle', 'bus': 'vehicle',
@@ -56,6 +78,13 @@
 }
 
 
+def build_word_list_coco80() -> Dict[str, List[str]]:
+    # Keep only the leaf categories, i.e., those whose members are not themselves ontology keys.
+    words_map = COCO80_ONTOLOGY.copy()
+    words_map = {k: v for k, v in words_map.items() if not any(item in COCO80_ONTOLOGY for item in v)}
+
+    return words_map
+
+
 def _add_mask(masks: Dict[str, torch.Tensor], word: str, mask: torch.Tensor, simplify80: bool = False) -> Dict[str, torch.Tensor]:
     if simplify80:
         word = COCO80_TO_27.get(word, word)
@@ -83,6 +112,9 @@ class GenerationExperiment:
     prediction_masks: Optional[Dict[str, torch.Tensor]] = None
     annotations: Optional[Dict[str, Any]] = None
 
+    def nsfw(self) -> bool:
+        # The safety checker blacks out flagged images, so an all-zero image means it fired.
+        return np.sum(np.array(self.image)) == 0
+
     def save(self, path: str = None):
         if path is None:
             path = self.path
@@ -146,9 +178,22 @@ def _load_pred_masks(self, pred_prefix, composite=False, simplify80=False, vocab
 
         return masks
 
+    def clear_prediction_masks(self, name: str):
+        path = self if isinstance(self, Path) else self.path
+
+        for mask_path in path.glob(f'*.{name}.pred.png'):
+            mask_path.unlink()
+
     def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str):
-        im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).byte().numpy())
-        im.save(self.path / f'{word.lower()}.{name}.pred.png')
+        path = self if isinstance(self, Path) else self.path
+        im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy())
+        im.save(path / f'{word.lower()}.{name}.pred.png')
+
+    def save_heat_map(self, tokenizer: PreTrainedTokenizer, word: str):
+        from .trace import HeatMap  # deferred to avoid a cyclical import
+        heat_map = HeatMap(tokenizer, self.prompt, self.global_heat_map)
+        heat_map = expand_image(heat_map.compute_word_heat_map(word))
+        plot_overlay_heat_map(self.image, heat_map, word, self.path / f'{word.lower()}.heat_map.png')
 
     @staticmethod
     def contains_truth_mask(path: str | Path, prompt_id: str = None) -> bool:
@@ -157,6 +202,13 @@ def contains_truth_mask(path: str | Path, prompt_id: str = None) -> bool:
         else:
             return any((Path(path) / prompt_id).glob('*.gt.png'))
 
+    @staticmethod
+    def read_seed(path: str | Path, prompt_id: str = None) -> int:
+        if prompt_id is None:
+            return int(Path(path).joinpath('seed.txt').read_text())
+        else:
+            return int(Path(path).joinpath(prompt_id).joinpath('seed.txt').read_text())
+
     @staticmethod
     def has_annotations(path: str | Path) -> bool:
         return Path(path).joinpath('annotations.json').exists()
@@ -165,6 +217,14 @@ def has_annotations(path: str | Path) -> bool:
     def has_experiment(path: str | Path, prompt_id: str) -> bool:
         return (Path(path) / prompt_id / 'generation.pt').exists()
 
+    @staticmethod
+    def read_prompt(path: str | Path, prompt_id: str = None) -> str:
+        if prompt_id is None:
+            prompt_id = '.'
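+            # '.' makes the lookup below resolve to the experiment root itself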
+
+        with (Path(path) / prompt_id / 'prompt.txt').open('r') as f:
+            return f.read().strip()
+
     def _try_load_annotations(self):
         if not (self.path / 'annotations.json').exists():
             return None
diff --git a/daam/run/daam_to_mask.py b/daam/run/daam_to_mask.py
index 506508e..d8d6bca 100644
--- a/daam/run/daam_to_mask.py
+++ b/daam/run/daam_to_mask.py
@@ -3,45 +3,69 @@
 from diffusers import StableDiffusionPipeline
 from tqdm import tqdm
+import joblib
 
-from daam import HeatMap
+from daam import HeatMap, MmDetectHeatMap
 from daam.experiment import GenerationExperiment
 from daam.utils import cached_nlp, expand_image
 
 
 def main():
+    def run_mm_detect(path: Path):
+        GenerationExperiment.clear_prediction_masks(path, args.prefix_name)
+        heat_map = MmDetectHeatMap(path / '_masks.pred.mask2former.pt', threshold=args.threshold)
+
+        for word, mask in heat_map.word_masks.items():
+            GenerationExperiment.save_prediction_mask(path, mask, word, 'mmdetect')
+
     parser = argparse.ArgumentParser()
     parser.add_argument('--input-folder', '-i', type=str, required=True)
     parser.add_argument('--extract-types', '-e', type=str, nargs='+', default=['noun'])
+    parser.add_argument('--model', '-m', type=str, default='daam', choices=['daam', 'mmdetect'])
     parser.add_argument('--threshold', '-t', type=float, default=0.4)
     parser.add_argument('--absolute', action='store_true')
     parser.add_argument('--truth-only', action='store_true')
     parser.add_argument('--prefix-name', '-p', type=str, default='daam')
+    parser.add_argument('--save-heat-map', action='store_true')
     args = parser.parse_args()
 
     extract_types = set(args.extract_types)
     model_id = 'CompVis/stable-diffusion-v1-4'
     tokenizer = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).tokenizer
+    jobs = []
 
-    for path in tqdm(Path(args.input_folder).glob('*')):
+    for path in tqdm(list(Path(args.input_folder).glob('*'))):
         if not path.is_dir() or (not GenerationExperiment.contains_truth_mask(path) and args.truth_only):
             continue
 
+        if list(path.glob('**/*.heat_map.png')) and args.save_heat_map:
+            continue
+
         exp = GenerationExperiment.load(path)
-        heat_map = HeatMap(tokenizer, exp.prompt, exp.global_heat_map)
-        doc = cached_nlp(exp.prompt)
 
-        for token in doc:
-            if token.pos_.lower() in extract_types:
-                try:
-                    word_heat_map = heat_map.compute_word_heat_map(token.text)
-                except:
-                    continue
+        if args.model == 'daam':
+            heat_map = HeatMap(tokenizer, exp.prompt, exp.global_heat_map)
+            doc = cached_nlp(exp.prompt)
+
+            for token in doc:
+                if token.pos_.lower() in extract_types or 'all' in extract_types:
+                    try:
+                        word_heat_map = heat_map.compute_word_heat_map(token.text)
+                    except ValueError:  # the word has no heat map, e.g., it can't be matched in the prompt
+                        continue
+
+                    im = expand_image(word_heat_map, absolute=args.absolute, threshold=args.threshold)
+                    exp.save_prediction_mask(im, token.text, args.prefix_name)
+
+                    if args.save_heat_map:
+                        exp.save_heat_map(tokenizer, token.text)
 
-                im = expand_image(word_heat_map, absolute=args.absolute, threshold=args.threshold)
-                exp.save_prediction_mask(im, token.text, args.prefix_name)
+                    tqdm.write(f'Saved mask for {token.text} in {path}')
+        else:
+            jobs.append(joblib.delayed(run_mm_detect)(path))
 
-                tqdm.write(f'Saved mask for {token.text} in {path}')
+    if jobs:
+        joblib.Parallel(n_jobs=16)(tqdm(jobs))
 
 
 if __name__ == '__main__':
diff --git a/daam/run/filter_coco.py b/daam/run/filter_coco.py
new file mode 100644
index 0000000..1c6563f
--- /dev/null
+++ b/daam/run/filter_coco.py
@@ -0,0 +1,70 @@
+from collections import Counter, defaultdict
+from pathlib import Path
+import argparse
+import json
+import re
+import sys
+
+from nltk.stem import PorterStemmer
+from tqdm import tqdm
+
+from daam.experiment import build_word_list_coco80
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input-folder', '-i', type=str, default='input')
+    parser.add_argument('--limit', '-lim', type=int, default=500)
+    args = parser.parse_args()
+
+    with (Path(args.input_folder) / 'captions_val2014.json').open() as f:
+        captions = json.load(f)['annotations']
+
+    vocab = build_word_list_coco80()
+    stemmer = PorterStemmer()
+    words = set(stemmer.stem(w) for items in vocab.values() for w in items)
+    word_patt = '(' + '|'.join(words) + ')'
+    patt = re.compile(rf'^.*(?P<word1>{word_patt}) and (a )?(?P<word2>{word_patt}).*$')
+
+    c = Counter()
+    data = defaultdict(list)
+
+    for caption in tqdm(captions):
+        sentence = caption['caption'].split()
+        sentence = ' '.join(stemmer.stem(w) for w in sentence)
+        match = patt.match(sentence)
+
+        if match:
+            word1 = match.groupdict()['word1']
+            word2 = match.groupdict()['word2']
+            print(f'{word1} and {word2} found', file=sys.stderr)
+
+            words = tuple(sorted([word1, word2]))
+            c[words] += 1
+            data[words].append(caption)
+
+    all_captions = []
+    final_captions = []
+
+    for words, count in c.most_common():
+        all_captions.append(data[words])
+
+    # Round-robin across the word-pair buckets so no single pair dominates the output.
+    while all_captions:
+        for captions in all_captions:
+            if captions:
+                final_captions.append(captions.pop(-1))
+
+        idx = 0
+
+        while idx < len(all_captions):
+            if not all_captions[idx]:
+                all_captions.pop(idx)
+            else:
+                idx += 1
+
+    for captions in final_captions:
+        print(json.dumps(captions))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/daam/run/generate.py b/daam/run/generate.py
index 46f135d..ebb4eca 100644
--- a/daam/run/generate.py
+++ b/daam/run/generate.py
@@ -1,23 +1,52 @@
 from collections import defaultdict
 from pathlib import Path
+from typing import Dict, List
 import argparse
 import json
+
 import pandas as pd
 import random
 
 from diffusers import StableDiffusionPipeline
+from nltk.corpus import wordnet as wn
 from tqdm import tqdm
 import inflect
+import numpy as np
 import torch
 
 from daam import trace
-from daam.experiment import GenerationExperiment
-from daam.utils import set_seed
+from daam.experiment import GenerationExperiment, build_word_list_coco80
+from daam.utils import set_seed, cached_nlp
+
+
+def build_word_list_large() -> Dict[str, List[str]]:
+    cat5 = ['vegetable', 'fruit', 'car', 'mammal', 'reptile']
+    topk = open('data/top100k').readlines()[:30000]
+    topk = set(w.strip() for w in topk)
+    words_map = {}
+
+    for cat in cat5:
+        words = set()
+        x = wn.synsets(cat, 'n')[0]
+        hyponyms = list(x.closure(lambda s: s.hyponyms()))
+
+        for synset in hyponyms:
+            if any('_' in w for w in synset.lemma_names()):
+                continue
+
+            word = synset.lemma_names()[0].lower()
+
+            if '_' not in word and word in topk:
+                words.add(word)
+
+        words_map[cat] = list(words)
+
+    return words_map
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--action', type=str, default='prompt', choices=['prompt', 'coco', 'template', 'cconj'])
+    parser.add_argument('--action', type=str, default='prompt', choices=['prompt', 'coco', 'template', 'cconj', 'coco-unreal'])
     parser.add_argument('--output-folder', '-o', type=str, default='output')
     parser.add_argument('--input-folder', '-i', type=str, default='input')
     parser.add_argument('--seed', '-s', type=int, default=0)
@@ -26,17 +55,41 @@
     parser.add_argument('--scramble-unreal', action='store_true')
     parser.add_argument('--template-data-file', '-tdf', type=str, default='template.tsv')
     parser.add_argument('--regenerate', action='store_true')
+    parser.add_argument('--seed-offset', type=int, default=0)
     args = parser.parse_args()
 
     gen = set_seed(args.seed)
     eng = inflect.engine()
 
-    if args.action == 'coco':
+    if args.action.startswith('coco'):
         with (Path(args.input_folder) / 'captions_val2014.json').open() as f:
             captions = json.load(f)['annotations']
 
         random.shuffle(captions)
         captions = captions[:args.gen_limit]
 
+        if args.action == 'coco-unreal':
+            pos_map = defaultdict(list)
+
+            # First pass: pool all adjectives and nouns across the captions.
+            for caption in tqdm(captions):
+                doc = cached_nlp(caption['caption'])
+
+                for tok in doc:
+                    if tok.pos_ == 'ADJ' or tok.pos_ == 'NOUN':
+                        pos_map[tok.pos_].append(tok)
+
+            # Second pass: rebuild each caption from randomly drawn words of the same part of speech.
+            for caption in tqdm(captions):
+                doc = cached_nlp(caption['caption'])
+                new_tokens = []
+
+                for tok in doc:
+                    if tok.pos_ == 'ADJ' or tok.pos_ == 'NOUN':
+                        new_tokens.append(random.choice(pos_map[tok.pos_]))
+
+                new_prompt = ''.join([tok.text_with_ws for tok in new_tokens])
+                caption['caption'] = new_prompt
+
+                print(new_prompt)
+
         prompts = [(caption['id'], caption['caption']) for caption in captions]
     elif args.action == 'template':
         template_df = pd.read_csv(args.template_data_file, sep='\t', quoting=3)
@@ -68,27 +121,25 @@
             prompts.append((prompt_id, ' '.join(words)))
             tqdm.write(str(prompts[-1]))
     elif args.action == 'cconj':
-        template_df = pd.read_csv(args.template_data_file, sep='\t', quoting=3)
-        sample_dict = defaultdict(list)
-
-        for name, df in template_df.groupby('pos'):
-            sample_dict[name].extend(df['word'].tolist())
-
+        words_map = build_word_list_coco80()
         prompts = []
-        prompt_id = 0
-
-        for _ in range(args.gen_limit):
-
-            for word1 in tqdm(sample_dict['noun']):
-                for word2 in sample_dict['noun']:
-                    if word1 == word2:
-                        continue
-
-                    prompt = f'a {word1} and a {word2}'
-                    print(prompt)
-                    prompts.append((str(prompt_id), prompt))
-                    prompt_id += 1
+
+        for idx in range(args.gen_limit):
+            # Half the prompts pair co-hyponyms (same category); the other half pair words across categories.
+            use_cohyponym = random.random() < 0.5
+
+            if use_cohyponym:
+                c = random.choice(list(words_map.keys()))
+                w1, w2 = np.random.choice(words_map[c], 2, replace=False)
+            else:
+                c1, c2 = np.random.choice(list(words_map.keys()), 2, replace=False)
+                w1 = random.choice(words_map[c1])
+                w2 = random.choice(words_map[c2])
+
+            prompt_id = f'{"cohypo" if use_cohyponym else "diff"}-{idx}'
+            a1 = 'an' if w1[0] in 'aeiou' else 'a'
+            a2 = 'an' if w2[0] in 'aeiou' else 'a'
+            prompt = f'{a1} {w1} and {a2} {w2}'
+            prompts.append((prompt_id, prompt))
     else:
         prompts = [('prompt', input('> '))]
 
@@ -101,18 +152,26 @@
 
     with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
         for prompt_id, prompt in tqdm(prompts):
+            # gen = set_seed(seed)  # uncomment to fix the seed for every prompt
+
             if args.action == 'template' or args.action == 'cconj':
-                gen = set_seed(int(prompt_id))
-                seed = int(prompt_id)
+                seed = int(prompt_id.split('-')[1]) + args.seed_offset
+                gen = set_seed(seed)
                 prompt_id = str(prompt_id)
 
+            do_skip = False
+            num_steps = 30
+
             if args.regenerate and not GenerationExperiment.contains_truth_mask(args.output_folder, prompt_id):
+                # I screwed up with the seed generation so this is a hacky workaround to reproduce the paper.
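+                # (A one-step generation below, rather than skipping outright, presumably keeps the
+                # shared generator state advancing exactly as in the original run.)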
+                num_steps = 1
+                do_skip = True
                 print(f'Skipping {prompt_id}')
-                continue
+            elif args.regenerate:
+                print(f'Regenerating {prompt_id}')
 
-            with trace(pipe, weighted=True) as tc:
-                out = pipe(prompt, num_inference_steps=20, generator=gen)
+            with trace(pipe, weighted=False) as tc:
+                out = pipe(prompt, num_inference_steps=num_steps, generator=gen)
                 exp = GenerationExperiment(
                     id=prompt_id,
                     global_heat_map=tc.compute_global_heat_map(prompt).heat_maps,
@@ -120,7 +179,9 @@
                     prompt=prompt,
                     image=out.images[0]
                 )
-                exp.save(args.output_folder)
+
+                if not do_skip:
+                    exp.save(args.output_folder)
 
 
 if __name__ == '__main__':
diff --git a/daam/run/test_literacy.py b/daam/run/test_literacy.py
new file mode 100644
index 0000000..998e7dd
--- /dev/null
+++ b/daam/run/test_literacy.py
@@ -0,0 +1,91 @@
+from functools import cache
+from pathlib import Path
+import argparse
+
+from diffusers import StableDiffusionPipeline
+from matplotlib import pyplot as plt
+from tqdm import tqdm
+import numpy as np
+import torch
+
+from daam import GenerationExperiment, COCO80_LABELS, COCO80_INDICES
+
+
+def main():
+    @torch.no_grad()
+    @cache
+    def compute_cosine_dist(word1, word2):
+        text_encoder.eval()
+        text_input = tokenizer(
+            f'a {word1} and a {word2}',
+            padding='max_length',
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors='pt'
+        )
+
+        text_embeddings = text_encoder(text_input.input_ids.cuda())[0][0]
+        a = text_embeddings[text_input.input_ids.squeeze().tolist().index(537) - 1]  # token before 537, the ID for 'and'
+        b = text_embeddings[text_input.input_ids.squeeze().tolist().index(49407) - 1]  # token before 49407, the end-of-text/padding ID
+
+        return torch.cosine_similarity(a, b, dim=0).item()
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input-folder', '-i', type=str, required=True)
+    parser.add_argument('--pred-prefix', '-p', type=str, default='mmdetect')
+    parser.add_argument('--visualize', action='store_true')
+    args = parser.parse_args()
+
+    input_folder = Path(args.input_folder)
+    n = len(COCO80_LABELS)
+
+    word1_present_matrix = np.zeros((n, n), dtype=np.int32)
+    word2_present_matrix = np.zeros((n, n), dtype=np.int32)
+    corr_matrix = np.zeros((n, n), dtype=np.int32)
+    tot_matrix = np.zeros((n, n), dtype=np.int32)
+    tot_subset_matrix = np.zeros((n, n), dtype=np.int32)
+    cos_matrix = np.zeros((n, n), dtype=np.float32)
+
+    model_id = 'CompVis/stable-diffusion-v1-4'
+    device = 'cuda'
+
+    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
+    pipe = pipe.to(device)
+    tokenizer = pipe.tokenizer
+    text_encoder = pipe.text_encoder
+    del pipe
+
+    for path in tqdm(list(input_folder.iterdir())):
+        if not path.is_dir():
+            continue
+
+        exp = GenerationExperiment.load(str(path), args.pred_prefix)
+        word1, word2 = exp.prompt.split(' and ')
+        word1 = word1.strip().split(' ', 1)[1]  # drop the leading article ('a' or 'an')
+        word2 = word2.strip().split(' ', 1)[1]
+        lbl1 = COCO80_INDICES[word1]
+        lbl2 = COCO80_INDICES[word2]
+
+        if exp.nsfw():
+            continue
+
+        compute_cosine_dist(word1, word2)
+
+        if args.visualize:
+            print(exp.prompt, word1 in exp.prediction_masks, word2 in exp.prediction_masks)
+            plt.clf()
+            exp.image.show()
+            plt.show()
+
+        word1_present_matrix[lbl1, lbl2] += word1 in exp.prediction_masks
+        word2_present_matrix[lbl1, lbl2] += word2 in exp.prediction_masks
+        corr_matrix[lbl1, lbl2] += word1 in exp.prediction_masks and word2 in exp.prediction_masks
+        tot_subset_matrix[lbl1, lbl2] += (word1 in exp.prediction_masks) or (word2 in exp.prediction_masks)
+        tot_matrix[lbl1, lbl2] += 1
+        cos_matrix[lbl1, lbl2] = compute_cosine_dist(word1, word2)
+
+    torch.save((word1_present_matrix, word2_present_matrix, corr_matrix, tot_matrix, tot_subset_matrix, cos_matrix), 'results2.pt')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/daam/trace.py b/daam/trace.py
index 6f7bba3..67b26e3 100644
--- a/daam/trace.py
+++ b/daam/trace.py
@@ -1,19 +1,22 @@
 from collections import defaultdict
 from copy import deepcopy
-from typing import List, Type, Dict, Any, Literal
+from pathlib import Path
+from typing import List, Type, Any, Literal, Dict
 import math
 
 from diffusers import UNet2DConditionModel, StableDiffusionPipeline
 from diffusers.models.attention import CrossAttention
+import numba
 import numpy as np
 import torch
 import torch.nn.functional as F
 
+from .experiment import COCO80_LABELS
 from .hook import ObjectHooker, AggregateHooker, UNetCrossAttentionLocator
 from .utils import compute_token_merge_indices
 
-__all__ = ['trace', 'DiffusionHeatMapHooker', 'HeatMap']
+__all__ = ['trace', 'DiffusionHeatMapHooker', 'HeatMap', 'MmDetectHeatMap']
 
 
 class UNetForwardHooker(ObjectHooker[UNet2DConditionModel]):
@@ -47,8 +50,44 @@ def compute_word_heat_map(self, word: str, word_idx: int = None) -> torch.Tensor
         return self.heat_maps[merge_idxs].mean(0)
 
 
+class MmDetectHeatMap:
+    def __init__(self, pred_file: str | Path, threshold: float = 0.95):
+        @numba.njit
+        def _compute_mask(masks: np.ndarray, bboxes: np.ndarray):
+            # Derive tight bounding boxes from the instance masks, writing them in place.
+            x_any = np.any(masks, axis=1)
+            y_any = np.any(masks, axis=2)
+            num_masks = len(bboxes)
+
+            for idx in range(num_masks):
+                x = np.where(x_any[idx, :])[0]
+                y = np.where(y_any[idx, :])[0]
+                bboxes[idx, :4] = np.array([x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=np.float32)
+
+        pred_file = Path(pred_file)
+        self.word_masks: Dict[str, torch.Tensor] = defaultdict(lambda: 0)
+        bbox_result, masks = torch.load(pred_file)
+        labels = [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)]
+        labels = np.concatenate(labels)
+        bboxes = np.vstack(bbox_result)
+
+        if masks is not None and bboxes[:, :4].sum() == 0:
+            _compute_mask(masks, bboxes)
+
+        scores = bboxes[:, -1]
+        inds = scores > threshold
+        labels = labels[inds]
+        masks = masks[inds, ...]
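+        # Below, instance masks sharing a class label are unioned into a single word mask.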
+
+        for lbl, mask in zip(labels, masks):
+            self.word_masks[COCO80_LABELS[lbl]] |= torch.from_numpy(mask)
+
+        self.word_masks = {k: v.float() for k, v in self.word_masks.items()}
+
+    def compute_word_heat_map(self, word: str) -> torch.Tensor:
+        return self.word_masks[word]
+
+
 class DiffusionHeatMapHooker(AggregateHooker):
-    def __init__(self, pipeline: StableDiffusionPipeline, weighted: bool = True):
+    def __init__(self, pipeline: StableDiffusionPipeline, weighted: bool = False):
         heat_maps = defaultdict(list)
         modules = [UNetCrossAttentionHooker(x, heat_maps, weighted=weighted) for x in UNetCrossAttentionLocator().locate(pipeline.unet)]
         self.forward_hook = UNetForwardHooker(pipeline.unet, heat_maps)
@@ -98,7 +137,7 @@ def compute_global_heat_map(self, prompt, time_weights=None, time_idx=None, last
 
 
 class UNetCrossAttentionHooker(ObjectHooker[CrossAttention]):
-    def __init__(self, module: CrossAttention, heat_maps: defaultdict, context_size: int = 77, weighted: bool = True):
+    def __init__(self, module: CrossAttention, heat_maps: defaultdict, context_size: int = 77, weighted: bool = False):
         super().__init__(module)
         self.heat_maps = heat_maps
         self.context_size = context_size
diff --git a/daam/utils.py b/daam/utils.py
index 124fb7f..2064cdf 100644
--- a/daam/utils.py
+++ b/daam/utils.py
@@ -1,4 +1,5 @@
 from functools import lru_cache
+from pathlib import Path
 import random
 
 import PIL.Image
@@ -27,13 +28,21 @@ def expand_image(im: torch.Tensor, out: int = 512, absolute: bool = False, thres
     return im.squeeze()
 
 
-def plot_overlay_heat_map(im: PIL.Image.Image | np.ndarray, heat_map: torch.Tensor):
+def plot_overlay_heat_map(im: PIL.Image.Image | np.ndarray, heat_map: torch.Tensor, word: str = None, out_file: Path = None):
+    plt.clf()
+    plt.rcParams.update({'font.size': 24})
     plt.imshow(heat_map.squeeze().cpu().numpy(), cmap='jet')
     im = np.array(im)
     im = torch.from_numpy(im).float() / 255
     im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1)
     plt.imshow(im)
 
+    if word is not None:
+        plt.title(word)
+
+    if out_file is not None:
+        plt.savefig(out_file)
+
 
 def plot_mask_heat_map(im: PIL.Image.Image, heat_map: torch.Tensor, threshold: float = 0.4):
     im = torch.from_numpy(np.array(im)).float() / 255
@@ -76,7 +85,6 @@ def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int
     if word_idx is None:
         raise ValueError(f'Couldn\'t find "{word}" in "{prompt}"')
-
     for idx, token in enumerate(tokens):
         merge_idxs.append(idx)
 
diff --git a/requirements.txt b/requirements.txt
index c468d55..bea800c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,8 @@
-diffusers
+diffusers==0.3.0
 spacy
 gradio
 ftfy
 transformers
 pandas
+numba
+joblib
diff --git a/setup.py b/setup.py
index 81a97e5..426f607 100644
--- a/setup.py
+++ b/setup.py
@@ -10,12 +10,16 @@
     description='What the DAAM: Interpreting Stable Diffusion Using Cross Attention.',
     install_requires=[
         'transformers',
-        'diffusers',
+        'diffusers==0.3.0',
         'spacy',
         'gradio',
         'ftfy',
         'transformers',
-        'pandas'
+        'pandas',
+        'numba',
+        'nltk',
+        'inflect',
+        'joblib'
     ],
     packages=setuptools.find_packages(),
     python_requires='>=3.10'
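
Below is a minimal usage sketch (not part of the diff) of how the pieces touched by this patch fit together: `trace` now defaults to unweighted heat maps, and `MmDetectHeatMap` reads back an MMDetection prediction file. The output path is an illustrative assumption; `daam/run/daam_to_mask.py` above looks for `_masks.pred.mask2former.pt` inside each experiment folder.

```python
import torch
from diffusers import StableDiffusionPipeline

from daam import trace, MmDetectHeatMap
from daam.utils import set_seed, expand_image

pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', use_auth_token=True).to('cuda')
gen = set_seed(0)
prompt = 'a dog and a cat'

with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
    with trace(pipe) as tc:  # weighted=False is now the default
        out = pipe(prompt, num_inference_steps=30, generator=gen)
        heat_map = tc.compute_global_heat_map(prompt)
        dog_map = expand_image(heat_map.compute_word_heat_map('dog'))

# Reading back an MMDetection prediction file (hypothetical path; daam_to_mask.py
# expects '_masks.pred.mask2former.pt' inside each experiment folder):
mm_map = MmDetectHeatMap('output/0/_masks.pred.mask2former.pt', threshold=0.95)
dog_mask = mm_map.compute_word_heat_map('dog')
```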