forked from sharonzhou/long_stable_diffusion
longsd.py
import argparse
import json
import time
import requests
import os
import logging
import string
import glob

import torch
import torch.multiprocessing as torch_mp

from sd import load_model as load_sd, run_model as run_sd, add_prompt_modifiers
from dump_docx import dump_images_captions_docx

OPENAI_TOKEN = os.environ['OPENAI_TOKEN']

with torch.no_grad():
    torch.cuda.empty_cache()

torch_mp.set_start_method("spawn", force=True)

logger = logging.getLogger('run_longsd')
logging.basicConfig(level=logging.DEBUG)

sd_model = load_sd()

sections = ["start", "middle", "end"]

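# Usage sketch (assumes sd.py and dump_docx.py from this repo are importable,
# OPENAI_TOKEN is exported, and the input text lives under texts/):
#
#   OPENAI_TOKEN=<your token> python longsd.py -f my_story.txt -n 3
#
# For each file this generates per-section image prompts via the OpenAI API,
# runs Stable Diffusion over them in parallel, and dumps the images with their
# captions into a .docx via dump_docx.
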
def generate_image_prompts(text, filename):
    suffix_template = "\n\nRecommend five different detailed, logo-free, sign-free images to accompany the previous text that illustrate the {} of this text: 1)"
    image_prompts = { s: [] for s in sections }

    for section in sections:
        suffix = suffix_template.format(section)
        prompt = text + suffix
        logger.debug(f'Generating image prompts for {section}...')
        response = requests.post(
            "https://api.openai.com/v1/completions",
            headers={
                'authorization': "Bearer " + OPENAI_TOKEN,
                "content-type": "application/json",
            },
            json={
                "model": "text-davinci-002",
                "prompt": prompt,
                "max_tokens": 256,
                "temperature": 0.8,
            }
        )
        # Keep the raw response in its own variable so the source text is not
        # clobbered for the remaining sections
        response_text = response.text
        logger.debug(response_text)
        try:
            result = json.loads(response_text)
        except json.JSONDecodeError:
            raise Exception(f'Cannot load: {response_text}, {response}')

        result = result['choices'][0]['text']
        result_list = result.strip().split(")")  # items still carry a trailing list number
        clean_result_list = []
        for i, r in enumerate(result_list):
            res = r.strip()
            if not res:
                continue
            if i < len(result_list) - 1:
                res = res[:-2]  # drop the trailing space and number
            # Remove punctuation
            res = res.translate(str.maketrans('', '', string.punctuation))
            clean_result_list.append(res)
        image_prompts[section].extend(clean_result_list)

    # Store image prompts
    filepath = f'image_prompts/{filename}.json'
    logger.debug(f'Writing image prompts to {filepath}...')
    with open(filepath, 'w') as f:
        f.write(json.dumps(image_prompts, indent=4))

    filepath_all = f'image_prompts/{filename}-all.txt'
    logger.debug(f'Writing image prompts to {filepath_all}...')
    with open(filepath_all, 'a') as f:
        f.write(json.dumps(image_prompts, indent=4))
        f.write('\n')
        f.write(json.dumps(image_prompts, indent=4))
        f.write('\n')

    logger.debug(image_prompts)
    return image_prompts

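# Worked example of the cleanup above (illustrative completion, not real API
# output): a completion such as
#   " A foggy harbor at dawn 2) A lighthouse on a cliff 3) A quiet village"
# is split on ")", each piece is stripped, the trailing " 2"/" 3" counters are
# dropped from all but the last item, and remaining punctuation is removed,
# yielding ["A foggy harbor at dawn", "A lighthouse on a cliff", "A quiet village"].
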
def make_image_prompts(filename, text, overwrite_prompts):
    filename = filename.split('.')[0]
    engineered_filepath = f"engineered_image_prompts/{filename}.json"
    engineered_filepath_all = f"engineered_image_prompts/{filename}-all.txt"

    if os.path.exists(engineered_filepath) and not overwrite_prompts:
        logger.debug(f'Reading from existing {engineered_filepath}...')
        with open(engineered_filepath) as f:
            engineered_prompts = json.load(f)
    else:
        image_prompts = generate_image_prompts(text, filename)
        # TODO: extractive summarization from long-form text as additional prompts to engineer and input into Stable Diffusion

        # Engineer prompts (add modifiers to image prompts)
        engineered_prompts = { s: [] for s in sections }
        for section, prompts in image_prompts.items():
            for prompt in prompts:
                engineered_prompt = add_prompt_modifiers(prompt)
                engineered_prompts[section].append(engineered_prompt)

        # Store engineered image prompts
        logger.debug(f'Writing engineered image prompts to {engineered_filepath_all}...')
        with open(engineered_filepath_all, 'a') as f:
            f.write(json.dumps(engineered_prompts, indent=4))
            f.write('\n')
            f.write(json.dumps(engineered_prompts, indent=4))
            f.write('\n')

        logger.debug(f'Writing engineered image prompts to {engineered_filepath}...')
        with open(engineered_filepath, 'w') as f:
            f.write(json.dumps(engineered_prompts, indent=4))

    logger.debug(engineered_prompts)
    return engineered_prompts

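# Note: what exactly gets appended to each raw prompt is defined in
# sd.add_prompt_modifiers (not shown in this file). The engineered prompts are
# cached under engineered_image_prompts/, so reruns skip the OpenAI call
# unless --overwrite_prompts is passed.
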
def run_text_to_image(args):
    prompt = args['prompt']
    section = args['section']
    save_folder = args['save_folder']

    image = run_sd(sd_model, prompt)  # PIL output

    save_prompt_name = prompt[:100].replace(' ', '_')
    image_name = f'{section}-{save_prompt_name}-{str(int(time.time()))}'
    image_path = f'{save_folder}/{image_name}.png'
    image.save(image_path)
    return (prompt, image_path)

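# The "<section>-" filename prefix matters: prepare_sd_inputs() below globs
# f'{save_folder}/{section}-*.png' to count how many images each section
# already has, so renaming the outputs would break that ordering logic.
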
def gpu_multiprocess(sd_inputs, num_processes):
    pool = torch_mp.Pool(processes=num_processes)
    prompts_and_image_paths = pool.map(run_text_to_image, sd_inputs)
    pool.close()
    pool.join()
    return prompts_and_image_paths

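# CUDA generally cannot be re-initialized in forked subprocesses, hence the
# "spawn" start method set at import time above. Each spawned worker re-imports
# this module and therefore runs load_sd() itself, so expect roughly one model
# copy per worker (plus the parent) in GPU memory.
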
def setup(file):
    save_folder = 'images/' + file.split('.')[0].replace(' ', '-')
    os.makedirs(save_folder, exist_ok=True)
    logger.debug(f'Using folder to save: {save_folder}')

    filepath = f'texts/{file}' if '.' in file else f'texts/{file}.txt'
    with open(filepath, 'r') as f:
        text = f.read()
    return text, save_folder

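# For example, setup("my story.txt") reads texts/my story.txt and saves images
# under images/my-story/ (spaces in the stem become dashes); a bare "my story"
# argument resolves to texts/my story.txt as well.
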
def prepare_sd_inputs(image_prompts, save_folder):
    sd_inputs = []
    section_counts = {}
    for s in sections:
        section_counts[s] = len(glob.glob(f'{save_folder}/{s}-*.png'))

    # Generate sections in order of fewest images already generated
    for section in sorted(section_counts, key=lambda k: section_counts[k]):
        prompts = image_prompts[section]
        for prompt in prompts:
            sd_input = {
                'prompt': prompt,
                'section': section,
                'save_folder': save_folder,
            }
            sd_inputs.append(sd_input)

    logger.debug(sd_inputs)
    return sd_inputs

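# Ordering sections by how few images they already have makes interrupted runs
# roughly self-balancing: on restart, the least-covered section is generated
# first rather than starting again from "start".
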
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files",
        "-f",
        type=str,
        required=True,
        nargs='+',
        help="Text file(s) under texts/ to illustrate"
    )
    parser.add_argument(
        "--overwrite_prompts",
        "-o",
        action='store_true',
        help="Overwrite existing image prompt JSON files"
    )
    parser.add_argument(
        "--num_gpu_processes",
        "-n",
        default=3,
        type=int,
        help="Number of processes for GPU multiprocessing"
    )
    args = parser.parse_args()

    files = args.files
    for file in files:
        text, save_folder = setup(file)
        image_prompts = make_image_prompts(file, text, overwrite_prompts=args.overwrite_prompts)
        sd_inputs = prepare_sd_inputs(image_prompts, save_folder)
        prompts_and_image_paths = gpu_multiprocess(sd_inputs, args.num_gpu_processes)
        dump_images_captions_docx(file, prompts_and_image_paths)

    logger.info('All complete')

if __name__ == "__main__":
    main()