Skip to content

Commit

Permalink
Set requirement for Diffusers 0.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
daemon committed Nov 15, 2022
1 parent f82042e commit 60e945a
Show file tree
Hide file tree
Showing 11 changed files with 419 additions and 59 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ prompt = 'A dog runs across the field'
gen = set_seed(0)

with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
with trace(pipe, weighted=True) as tc:
with trace(pipe) as tc:
out = pipe(prompt, num_inference_steps=30, generator=gen)
heat_map = tc.compute_global_heat_map(prompt)
heat_map = expand_image(heat_map.compute_word_heat_map('dog'))
plot_overlay_heat_map(out.images[0], heat_map)
plt.show()
```

We'll have docs soon.
We'll have docs soon.
In the meantime, check out the `GenerationExperiment`, `HeatMap`, and `DiffusionHeatMapHooker` classes, as well as the `daam/run/*.py` example scripts.

## Running the Demo

Expand Down
2 changes: 1 addition & 1 deletion daam/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.3'
__version__ = '0.0.4'
70 changes: 65 additions & 5 deletions daam/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from dataclasses import dataclass
import json

from transformers import PreTrainedTokenizer
import PIL.Image
import numpy as np
import torch

from .evaluate import load_mask
from .utils import plot_overlay_heat_map, expand_image


__all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS']
__all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS', 'COCO80_INDICES', 'build_word_list_coco80']


COCO80_LABELS: List[str] = [
Expand All @@ -25,16 +26,37 @@
'hair drier', 'toothbrush'
]

# Reverse lookup: maps every label in COCO80_LABELS to its position in that list.
COCO80_INDICES: Dict[str, int] = {x: i for i, x in enumerate(COCO80_LABELS)}

# 199 placeholder names, '__unused_1__' through '__unused_199__'.
# NOTE(review): presumably padding for label slots that carry no real class -- confirm against the callers.
UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)]


# The 27 coarse label names, in list order.
# 'food' and 'furniture' each appear twice, so a name -> index lookup built from this list is ambiguous
# for those two entries.
# NOTE(review): presumably one occurrence is the "thing" class and the other the "stuff" class -- confirm.
COCOSTUFF27_LABELS: List[str] = [
'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 'accessory', 'animal', 'outdoor', 'person',
'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window',
'building', 'ground', 'plant', 'sky', 'solid', 'structural', 'water'
]

# Hand-written grouping of labels into named categories.
# Each key is a category name and each value lists its members. A member is either a plain label
# (e.g. 'bicycle') or the name of another category that is itself a key of this dict
# (e.g. 'vehicle' -> 'two-wheeled vehicle'), so the mapping describes a small hierarchy.
# build_word_list_coco80() keeps only the categories whose members are all plain labels.
COCO80_ONTOLOGY = {
'two-wheeled vehicle': ['bicycle', 'motorcycle'],
'vehicle': ['two-wheeled vehicle', 'four-wheeled vehicle'],
'four-wheeled vehicle': ['bus', 'truck', 'car'],
'four-legged animals': ['livestock', 'pets', 'wild animals'],
'livestock': ['cow', 'horse', 'sheep'],
'pets': ['cat', 'dog'],
'wild animals': ['elephant', 'bear', 'zebra', 'giraffe'],
'bags': ['backpack', 'handbag', 'suitcase'],
'sports boards': ['snowboard', 'surfboard', 'skateboard'],
'utensils': ['fork', 'knife', 'spoon'],
'receptacles': ['bowl', 'cup'],
'fruits': ['banana', 'apple', 'orange'],
'foods': ['fruits', 'meals', 'desserts'],
'meals': ['sandwich', 'hot dog', 'pizza'],
'desserts': ['cake', 'donut'],
'furniture': ['chair', 'couch', 'bench'],
'electronics': ['monitors', 'appliances'],
'monitors': ['tv', 'cell phone', 'laptop'],
'appliances': ['oven', 'toaster', 'refrigerator']
}

COCO80_TO_27 = {
'bicycle': 'vehicle', 'car': 'vehicle', 'motorcycle': 'vehicle', 'airplane': 'vehicle', 'bus': 'vehicle',
Expand All @@ -56,6 +78,13 @@
}


def build_word_list_coco80() -> Dict[str, List[str]]:
    """Return the leaf-level categories of ``COCO80_ONTOLOGY``.

    A category is kept only when none of its members is itself a key of the
    ontology, i.e. every member is a plain label rather than the name of
    another category.

    Returns:
        A new dict from category name to its member list. The lists are the
        same objects stored in ``COCO80_ONTOLOGY`` (they are not copied).
    """
    leaf_groups: Dict[str, List[str]] = {}

    for group, members in COCO80_ONTOLOGY.items():
        # Skip categories that point at other categories.
        if any(member in COCO80_ONTOLOGY for member in members):
            continue

        leaf_groups[group] = members

    return leaf_groups


def _add_mask(masks: Dict[str, torch.Tensor], word: str, mask: torch.Tensor, simplify80: bool = False) -> Dict[str, torch.Tensor]:
if simplify80:
word = COCO80_TO_27.get(word, word)
Expand Down Expand Up @@ -83,6 +112,9 @@ class GenerationExperiment:
prediction_masks: Optional[Dict[str, torch.Tensor]] = None
annotations: Optional[Dict[str, Any]] = None

def nsfw(self) -> bool:
    """Report whether the stored image is entirely black.

    The image is converted to an array and the check is simply "all values
    sum to zero".
    NOTE(review): presumably an all-black image is what the pipeline emits
    when its safety checker rejects a generation -- confirm.

    Returns:
        ``True`` when the sum of every value in ``self.image`` is zero.
    """
    # The comparison yields a ``numpy.bool_``; convert it so the result
    # honours the ``-> bool`` annotation (identity checks, JSON encoding).
    return bool(np.sum(np.asarray(self.image)) == 0)

def save(self, path: str = None):
if path is None:
path = self.path
Expand Down Expand Up @@ -146,9 +178,22 @@ def _load_pred_masks(self, pred_prefix, composite=False, simplify80=False, vocab

return masks

def clear_prediction_masks(self, name: str):
    """Delete every saved prediction mask that carries the given name.

    Args:
        self: Either an experiment instance (its ``path`` attribute is used)
            or a ``Path`` pointing directly at the experiment folder, which
            lets the method be called unbound as
            ``GenerationExperiment.clear_prediction_masks(folder, name)``.
        name: Middle component of the mask file name; files matching
            ``*.{name}.pred.png`` directly inside the folder are removed.
    """
    if isinstance(self, Path):
        mask_dir = self
    else:
        mask_dir = self.path

    pattern = f'*.{name}.pred.png'

    for stale_mask in mask_dir.glob(pattern):
        stale_mask.unlink()

def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str):
    """Write a prediction mask for ``word`` as a 4-channel PNG.

    Args:
        self: Either an experiment instance (its ``path`` attribute is used)
            or a ``Path`` pointing directly at the experiment folder, which
            lets the method be called unbound as
            ``GenerationExperiment.save_prediction_mask(folder, mask, word, name)``.
        mask: Mask tensor; assumes a 2-D tensor with values in [0, 1] --
            TODO confirm against the callers.
        word: Word the mask belongs to; lower-cased for the file name.
        name: Middle component of the file name, giving
            ``{word.lower()}.{name}.pred.png``.
    """
    path = self if isinstance(self, Path) else self.path

    # Scale to 0-255, replicate the single plane across four channels and move
    # the data to the CPU before handing it to NumPy (``.numpy()`` is only
    # available for CPU tensors).
    pixels = (mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy()

    im = PIL.Image.fromarray(pixels)
    im.save(path / f'{word.lower()}.{name}.pred.png')

def save_heat_map(self, tokenizer: PreTrainedTokenizer, word: str):
    """Render the heat map of ``word`` over the generated image and save it.

    Args:
        tokenizer: Tokenizer handed to ``HeatMap`` together with the prompt.
        word: Word whose heat map is computed; lower-cased for the file name.

    The overlay is written to ``{word.lower()}.heat_map.png`` inside
    ``self.path``.
    """
    # Imported here rather than at module level to avoid a circular import.
    from .trace import HeatMap

    prompt_heat_map = HeatMap(tokenizer, self.prompt, self.global_heat_map)
    word_heat_map = prompt_heat_map.compute_word_heat_map(word)
    overlay = expand_image(word_heat_map)

    out_file = self.path / f'{word.lower()}.heat_map.png'
    plot_overlay_heat_map(self.image, overlay, word, out_file)

@staticmethod
def contains_truth_mask(path: str | Path, prompt_id: str = None) -> bool:
Expand All @@ -157,6 +202,13 @@ def contains_truth_mask(path: str | Path, prompt_id: str = None) -> bool:
else:
return any((Path(path) / prompt_id).glob('*.gt.png'))

@staticmethod
def read_seed(path: str | Path, prompt_id: str = None) -> int:
if prompt_id is None:
return int(Path(path).joinpath('seed.txt').read_text())
else:
return int(Path(path).joinpath(prompt_id).joinpath('seed.txt').read_text())

@staticmethod
def has_annotations(path: str | Path) -> bool:
return Path(path).joinpath('annotations.json').exists()
Expand All @@ -165,6 +217,14 @@ def has_annotations(path: str | Path) -> bool:
def has_experiment(path: str | Path, prompt_id: str) -> bool:
return (Path(path) / prompt_id / 'generation.pt').exists()

@staticmethod
def read_prompt(path: str | Path, prompt_id: str = None) -> str:
if prompt_id is None:
prompt_id = '.'

with (Path(path) / prompt_id / 'prompt.txt').open('r') as f:
return f.read().strip()

def _try_load_annotations(self):
if not (self.path / 'annotations.json').exists():
return None
Expand Down
50 changes: 37 additions & 13 deletions daam/run/daam_to_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,69 @@

from diffusers import StableDiffusionPipeline
from tqdm import tqdm
import joblib

from daam import HeatMap
from daam import HeatMap, MmDetectHeatMap
from daam.experiment import GenerationExperiment
from daam.utils import cached_nlp, expand_image


def main():
    """Turn saved generation experiments into per-word prediction masks.

    Every sub-folder of ``--input-folder`` is treated as one experiment.

    * ``--model daam`` (default): for each prompt token whose part-of-speech
      tag is listed in ``--extract-types`` (or for every token when ``all`` is
      listed), the word heat map is thresholded and saved as a prediction
      mask under ``--prefix-name``; with ``--save-heat-map`` an overlay image
      is written as well.
    * ``--model mmdetect``: masks are read from ``_masks.pred.mask2former.pt``
      in each folder and written out in parallel once the scan is finished.
    """
    def run_mm_detect(path: Path):
        # Reads ``args`` from the enclosing scope; it is only called after
        # the arguments below have been parsed.
        # NOTE(review): stale masks are cleared under ``args.prefix_name`` but
        # the new ones are saved under the fixed name 'mmdetect'. With the
        # default prefix ('daam') this deletes DAAM masks instead of old
        # mmdetect ones -- confirm which name is intended.
        GenerationExperiment.clear_prediction_masks(path, args.prefix_name)
        heat_map = MmDetectHeatMap(path / '_masks.pred.mask2former.pt', threshold=args.threshold)

        for word, mask in heat_map.word_masks.items():
            GenerationExperiment.save_prediction_mask(path, mask, word, 'mmdetect')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-folder', '-i', type=str, required=True)
    parser.add_argument('--extract-types', '-e', type=str, nargs='+', default=['noun'])
    parser.add_argument('--model', '-m', type=str, default='daam', choices=['daam', 'mmdetect'])
    parser.add_argument('--threshold', '-t', type=float, default=0.4)
    parser.add_argument('--absolute', action='store_true')
    parser.add_argument('--truth-only', action='store_true')
    parser.add_argument('--prefix-name', '-p', type=str, default='daam')
    parser.add_argument('--save-heat-map', action='store_true')
    args = parser.parse_args()

    extract_types = set(args.extract_types)
    model_id = 'CompVis/stable-diffusion-v1-4'
    tokenizer = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).tokenizer
    jobs = []

    # Materialise the listing so tqdm knows the total and can show progress.
    for path in tqdm(list(Path(args.input_folder).glob('*'))):
        if not path.is_dir() or (not GenerationExperiment.contains_truth_mask(path) and args.truth_only):
            continue

        # Skip folders that already contain heat-map overlays. The flag is
        # tested first so the recursive glob only runs when it matters.
        if args.save_heat_map and any(path.glob('**/*.heat_map.png')):
            continue

        exp = GenerationExperiment.load(path)

        if args.model == 'daam':
            heat_map = HeatMap(tokenizer, exp.prompt, exp.global_heat_map)
            doc = cached_nlp(exp.prompt)

            for token in doc:
                if token.pos_.lower() in extract_types or 'all' in extract_types:
                    try:
                        word_heat_map = heat_map.compute_word_heat_map(token.text)
                    except Exception:
                        # Best effort: a token without a heat map is skipped.
                        # ``Exception`` (not a bare ``except``) keeps Ctrl-C
                        # and ``SystemExit`` working.
                        continue

                    im = expand_image(word_heat_map, absolute=args.absolute, threshold=args.threshold)
                    exp.save_prediction_mask(im, token.text, args.prefix_name)

                    if args.save_heat_map:
                        exp.save_heat_map(tokenizer, token.text)

                    tqdm.write(f'Saved mask for {token.text} in {path}')
        else:
            # Deferred: collected here and executed in parallel after the scan.
            jobs.append(joblib.delayed(run_mm_detect)(path))

    if jobs:
        joblib.Parallel(n_jobs=16)(tqdm(jobs))


if __name__ == '__main__':
Expand Down
70 changes: 70 additions & 0 deletions daam/run/filter_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from collections import Counter, defaultdict
from pathlib import Path
import argparse
import json
import re
import sys

from nltk.stem import PorterStemmer
from tqdm import tqdm

from daam.experiment import build_word_list_coco80


def main():
    """Print COCO captions that mention two ontology words joined by "and".

    Reads ``captions_val2014.json`` from ``--input-folder``, stems every
    caption, and keeps those matching ``<word1> and [a ]<word2>`` where both
    words are stemmed members of the leaf categories returned by
    ``build_word_list_coco80``. Matches are grouped by the (sorted) word pair.
    The groups are then emitted round-robin, most frequent pair first, so
    that the output interleaves pairs instead of listing one pair at a time.
    Each kept caption is printed to stdout as one JSON object per line; a
    progress note for every match goes to stderr.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-folder', '-i', type=str, default='input')
    # NOTE(review): ``--limit`` is accepted but never read below -- confirm
    # whether the output was meant to be truncated to this many captions.
    parser.add_argument('--limit', '-lim', type=int, default=500)
    args = parser.parse_args()

    captions_file = Path(args.input_folder) / 'captions_val2014.json'

    with captions_file.open() as f:
        annotations = json.load(f)['annotations']

    stemmer = PorterStemmer()
    stems = {stemmer.stem(label) for group in build_word_list_coco80().values() for label in group}
    word_patt = '(' + '|'.join(stems) + ')'
    pair_patt = re.compile(rf'^.*(?P<word1>{word_patt}) and (a )?(?P<word2>{word_patt}).*$')

    pair_counts = Counter()
    captions_by_pair = defaultdict(list)

    for annotation in tqdm(annotations):
        # Stem token by token so the caption lines up with the stemmed vocabulary.
        stemmed = ' '.join(stemmer.stem(tok) for tok in annotation['caption'].split())
        found = pair_patt.match(stemmed)

        if found is None:
            continue

        first, second = found['word1'], found['word2']
        print(f'{first} and {second} found', file=sys.stderr)

        # Sort so that "cat and dog" and "dog and cat" share one bucket.
        pair = tuple(sorted((first, second)))
        pair_counts[pair] += 1
        captions_by_pair[pair].append(annotation)

    # One bucket per pair, ordered from most to least frequent.
    buckets = [captions_by_pair[pair] for pair, _ in pair_counts.most_common()]
    interleaved = []

    # Round-robin: take the last caption of every bucket, drop the buckets
    # that ran dry, repeat. Every bucket is non-empty at the start of a pass.
    while buckets:
        for bucket in buckets:
            interleaved.append(bucket.pop())

        buckets = [bucket for bucket in buckets if bucket]

    for annotation in interleaved:
        print(json.dumps(annotation))


if __name__ == '__main__':
main()
Loading

0 comments on commit 60e945a

Please sign in to comment.