get_bound.py
import json
import os
import time
from PIL import Image
import argparse
import numpy as np
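
# Module-level timestamps used by tik()/tok() below to report elapsed wall-clock time
# (wrapped around each task in main()).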
TIME_ST = time.time()
TIME_ED = time.time()


def prepare_context():
    """
    Prepare the models, config, and evaluator shared by all tasks.
    """
    import torch
    import utils
    from utils import set_logger
    from absl import logging
    import os
    import libs.autoencoder
    import clip
    from libs.clip import FrozenCLIPEmbedder
    from libs.uvit_multi_post_ln_v1 import UViT
    from configs.unidiffuserv1 import get_config
    import builtins
    import ml_collections
    from score import Evaluator
    from torch import multiprocessing as mp

    config = get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # defaults to cuda:0 when a GPU is available

    # load the U-ViT backbone from the released checkpoint
    nnet = UViT(**config.nnet)
    origin_sd = torch.load("models/uvit_v1.pth", map_location='cpu')
    nnet.load_state_dict(origin_sd, strict=False)
    nnet.to(device)

    # autoencoder and CLIP text/image encoders used for sampling
    autoencoder = libs.autoencoder.get_model(**config.autoencoder).to(device)
    clip_text_model = FrozenCLIPEmbedder(version=config.clip_text_model, device=device)
    clip_img_model, clip_img_model_preprocess = clip.load(config.clip_img_model, jit=False)
    clip_img_model.to(device).eval().requires_grad_(False)

    # Evaluator provides the face, CLIP, and ImageReward scoring used below
    ev = Evaluator()

    return {
        "device": device,
        "config": config,
        "origin_sd": origin_sd,
        "nnet": nnet,
        "autoencoder": autoencoder,
        "clip_text_model": clip_text_model,
        "clip_img_model": clip_img_model,
        "clip_img_model_preprocess": clip_img_model_preprocess,
        "ev": ev,
    }
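
# Note: prepare_context() is called once in main() and the returned dict is passed to
# every process_one_json() call, so the diffusion backbone, autoencoder, CLIP models,
# and Evaluator are loaded a single time rather than once per task.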


def load_json_files(path):
    """
    given a directory, load all json files in that directory
    return a list of json objects
    """
    d_ls = []
    for file in os.listdir(path):
        if file.endswith(".json"):
            with open(os.path.join(path, file), 'r') as f:
                json_data = json.load(f)
                d_ls.append(json_data)
    return d_ls
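
# The fields accessed in process_one_json() imply that each task JSON looks roughly like
# the sketch below (values are illustrative; real task files may carry additional keys):
#
#   {
#       "id": 0,
#       "source_group": [{"path": "source_images/0_a.jpg"},
#                        {"path": "source_images/0_b.jpg"}],
#       "caption_list": ["a man wearing a red hat", "a man standing on a beach"]
#   }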


def process_one_json(json_data, image_output_path, context={}):
    """
    Given a task JSON object, generate images for each caption and compute
    the upper/lower bounds for the face, text, and ImageReward scores.
    """
    import utils
    from absl import logging
    import torch
    from sample_fn import sample

    # unpack the prepared context
    config = context["config"]
    device = context["device"]
    nnet = context["nnet"]
    autoencoder = context["autoencoder"]
    clip_text_model = context["clip_text_model"]
    ev = context["ev"]

    config.n_samples = 4
    config.n_iter = 5

    origin_images = [Image.open(i["path"]).convert('RGB') for i in json_data["source_group"]]
    origin_face_embs = [ev.get_face_embedding(i) for i in origin_images]
    origin_face_embs = [emb for emb in origin_face_embs if emb is not None]
    origin_face_embs = torch.cat(origin_face_embs)
    origin_clip_embs = [ev.get_img_embedding(i) for i in origin_images]
    origin_clip_embs = torch.cat(origin_clip_embs)

    images = []
    for caption in json_data["caption_list"]:
        config.prompt = caption
        paths = sample(config, nnet, clip_text_model, autoencoder, device, json_data["id"], output_path=image_output_path)

        # face upper bound: the source group's similarity with itself
        max_face_sim = (origin_face_embs @ origin_face_embs.T).mean().item()

        # face lower bound: similarity with random images generated from the prompt alone
        samples = [Image.open(sample_path).convert('RGB') for sample_path in paths]
        face_embs = [ev.get_face_embedding(sample) for sample in samples]
        face_embs = [emb for emb in face_embs if emb is not None]
        if len(face_embs) == 0:
            print(f"no face for case {json_data['id']} caption {caption}")
            continue
        min_face_sim = (origin_face_embs @ torch.cat(face_embs).T).mean().item()

        # text upper bound: similarity between the prompt and images generated from it
        text_emb = ev.get_text_embedding(caption)
        gen_clip_embs = torch.cat([ev.get_img_embedding(i) for i in samples])
        max_text_sim = (text_emb @ gen_clip_embs.T).mean().item()

        # text lower bound: similarity between the prompt and the source group
        min_text_sim = (text_emb @ origin_clip_embs.T).mean().item()

        # ImageReward upper bound: score the prompt against images generated from it
        max_image_reward = np.mean([ev.image_reward.score(caption, path) for path in paths]).item()

        # ImageReward lower bound: score the prompt against the source group
        min_image_reward = np.mean([ev.image_reward.score(caption, i["path"]) for i in json_data["source_group"]]).item()

        images.append({
            "prompt": caption,
            "paths": paths,
            "max_face_sim": max_face_sim,
            "min_face_sim": min_face_sim,
            "max_text_sim": max_text_sim,
            "min_text_sim": min_text_sim,
            "max_image_reward": max_image_reward,
            "min_image_reward": min_image_reward,
        })

    return {
        "id": json_data["id"],
        "images": images
    }
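
# Each record in "images" pairs a caption with the generated image paths and the six
# bound values; main() writes the full per-task result to <json_output_path>/<id>.json.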


def tik():
    global TIME_ST
    TIME_ST = time.time()


def tok(name):
    global TIME_ED
    TIME_ED = time.time()
    print(f"Time {name} elapsed: {TIME_ED - TIME_ST}")


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--json_data_path", type=str, default="test_json_data", help="directory containing the task JSON files")
    parser.add_argument("-j", "--json_output_path", type=str, default="bound_json_outputs", help="output directory for the per-task score-bound JSON files")
    parser.add_argument("-i", "--image_output_path", type=str, default="bound_image_outputs", help="output directory for generated images")
    return parser.parse_args()
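
# Example invocation (the values shown are simply the defaults above):
#   python get_bound.py -d test_json_data -j bound_json_outputs -i bound_image_outputs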


def main():
    """
    main function
    """
    arg = get_args()
    os.makedirs(arg.json_output_path, exist_ok=True)
    os.makedirs(arg.image_output_path, exist_ok=True)
    # load json files
    json_data_ls = load_json_files(arg.json_data_path)
    # process json files
    context = prepare_context()
    for json_data in json_data_ls:
        tik()
        out = process_one_json(json_data, arg.image_output_path, context)
        tok(f"process_one_json: {json_data['id']}")
        with open(os.path.join(arg.json_output_path, f"{json_data['id']}.json"), 'w') as f:
            json.dump(out, f, indent=4)


if __name__ == "__main__":
    main()