# slm2_demo.py
"""
This is a demo to demonstrate the synthetic data generation
"""
def box_to_list(box):
return [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
def detr_postprocess(output):
labels = [res['label'] for res in output]
boxes = [box_to_list(res['box']) for res in output]
return labels, boxes
import torch.nn.functional as F
from PIL import ImageDraw, Image
from copy import deepcopy
from transformers import pipeline
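

# For reference, the Hugging Face object-detection pipeline returns a list of dicts of
# the form {'score': float, 'label': str, 'box': {'xmin', 'ymin', 'xmax', 'ymax'}},
# which the helpers above flatten into parallel lists. A minimal sketch with
# illustrative (not real) values:
#
#     detections = [{'score': 0.99, 'label': 'dog',
#                    'box': {'xmin': 10, 'ymin': 20, 'xmax': 110, 'ymax': 220}}]
#     labels, boxes = detr_postprocess(detections)
#     # labels == ['dog'], boxes == [[10, 20, 110, 220]]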

class SynPipe:
    """
    Part of the synthetic data pipeline that extracts object labels, bounding boxes
    and depth information from an image and returns this info in a purely textual format.
    """
    def __init__(self):
        self.detr_pipe = pipeline("object-detection", model="facebook/detr-resnet-50", device='cuda')
        self.depth_pipe = pipeline("depth-estimation", model="Intel/dpt-large", device='cuda')

    def visualize_detr(self, img: Image.Image) -> Image.Image:
        """Draw the detected bounding boxes on a copy of the image."""
        _, boxes = detr_postprocess(self.detr_pipe(img))
        img = deepcopy(img)
        draw = ImageDraw.Draw(img)
        for box in boxes:
            draw.rectangle(box, outline="red")
        return img

    def visualize_depth(self, img: Image.Image) -> Image.Image:
        """Return the estimated depth map rendered as a PIL image."""
        return self.depth_pipe(img)['depth']
    def __call__(self, img):
        w, h = img.size
        labels, boxes = detr_postprocess(self.detr_pipe(img))
        # Resize the predicted depth map to the image resolution so box centres index into it directly.
        depth_map = self.depth_pipe(img)['predicted_depth'].unsqueeze(0)
        depth_map = F.interpolate(depth_map, (h, w)).squeeze()
        depths = []
        for box in boxes:
            x, y, xw, yh = box
            center_x = int((x + xw) / 2)
            center_y = int((y + yh) / 2)
            depth_at_xy = depth_map[center_y, center_x].squeeze().item()
            depths.append(depth_at_xy)
        res_str = ""
        for label, box, depth in zip(labels, boxes, depths):
            x, y, x2, y2 = box
            res_str += f"\n[Label: {label}, Bounding Box: Top-Left({x},{y}) Bottom-Right({x2},{y2}), Depth Score: {depth}]"
        # Calls the detection and depth models again, but this is only for the demo visualizations.
        return res_str, self.visualize_detr(img), self.visualize_depth(img)
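

# Minimal usage sketch for SynPipe on its own (assumes an RGB PIL image `img` is
# already loaded; the printed lines follow the "[Label: ..., Bounding Box: ...,
# Depth Score: ...]" format built above):
#
#     syn = SynPipe()
#     scene_text, detr_vis, depth_vis = syn(img)
#     print(scene_text)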

# System prompt for the synthetic data task
sys_2 = """
You are helping create a synthetic image-text dataset for the purpose of training a vision-language chat model.
While you are a purely textual language model and cannot see images, you will be given some data on the image in a textual
format that should let you infer its contents. The first thing you'll be given is the size of the image and a basic
caption of its contents. Then, you will be given a list of objects in the image. Each list item will have an associated
caption (what the item is), a bounding box (where the item is, defined in terms of Top-Left x,y and Bottom-Right x,y) and
a depth number (how far away something is: a higher number => closer, a lower number => farther away). The scope of your task is
to combine all this information meaningfully. The bounding boxes and depth numbers should tell you where objects are in the scene.
You will have 3 distinct tasks:
1. Generate a detailed description of the image that expands on the given caption with information about the relative positions of things, as far as you can tell from the additional data.
2. Generate a set (3-5) of simple questions and answers that test understanding of the image content (where things are relative to each other, what side of the view they're on, how far away they are, etc.). You don't need exact numbers: statements like "to the left of", "far away" and "close" are sufficient.
3. Generate a set (1-3) of complex questions and answers that test conceptual understanding of the context of the image (what is going on). These should require reasoning about the scene to determine potential nuances.
Please generate in the following strict JSON format:
{
    "caption" : "{DETAILED DESCRIPTION HERE}",
    "simple_qas" : [
        {
            "question" : "{SIMPLE QUESTION 1}",
            "answer" : "{ANSWER 1}"
        },
        ...
        {
            "question" : "{SIMPLE QUESTION N}",
            "answer" : "{ANSWER N}"
        }
    ],
    "complex_qas" : [
        {
            "question" : "{COMPLEX QUESTION 1}",
            "answer" : "{ANSWER 1}"
        },
        ...
        {
            "question" : "{COMPLEX QUESTION N}",
            "answer" : "{ANSWER N}"
        }
    ]
}
"""

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class StableLM2Wrapper:
    def __init__(self):
        model_id = 'stabilityai/stablelm-2-12b-chat'
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True
        )
        self.syn_pipe = SynPipe()

    def __call__(self, image, caption, system=None):
        """
        Generate synthetic data from an image and its caption.
        """
        user_input = f"Image Dimensions: {image.size}\nCaption: {caption}\nObjects:"
        pipe_result, detr_vis, depth_vis = self.syn_pipe(image)
        user_input += pipe_result
        prompt = [
            {'role': 'system', 'content': sys_2 if system is None else system},
            {'role': 'user', 'content': user_input}
        ]
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors="pt"
        )
        tokens = self.model.generate(
            inputs.to(self.model.device),
            max_new_tokens=1000,
            temperature=0.7,
            do_sample=True,
            eos_token_id=100278  # end-of-turn token id used by this demo for StableLM 2 chat
        )
        # Decode only the newly generated tokens (everything after the prompt).
        output = self.tokenizer.decode(tokens[:, inputs.shape[-1]:][0], skip_special_tokens=False)
        return output, detr_vis, depth_vis

class SLM2ChatWrapper:
    def __init__(self):
        model_id = 'stabilityai/stablelm-2-12b-chat'
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float16, device_map='auto', trust_remote_code=True
        )

    def __call__(self, system, user_input):
        """
        Run a single system/user chat turn and return the decoded model reply.
        """
        prompt = [
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': user_input}
        ]
        inputs = self.tokenizer.apply_chat_template(
            prompt,
            add_generation_prompt=True,
            return_tensors="pt"
        )
        tokens = self.model.generate(
            inputs.to(self.model.device),
            max_new_tokens=1000,
            temperature=0.7,
            do_sample=True,
            eos_token_id=100278  # same end-of-turn token id as above
        )
        output = self.tokenizer.decode(tokens[:, inputs.shape[-1]:][0], skip_special_tokens=False)
        return output
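

# SLM2ChatWrapper is not exercised by the Gradio demo below; a minimal usage sketch
# (the prompt strings are illustrative only):
#
#     text_chat = SLM2ChatWrapper()
#     reply = text_chat("You are a helpful assistant.", "Describe what DETR is used for here.")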

from diffusers import AutoPipelineForText2Image


class SDXLWrapper:
    def __init__(self):
        model_id = "stabilityai/sdxl-turbo"
        self.pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
        self.pipe.to("cuda")

    def __call__(self, prompt):
        # SDXL-Turbo is distilled for single-step generation without classifier-free guidance.
        img = self.pipe(prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
        return img, prompt

if __name__ == "__main__":
    chat = StableLM2Wrapper()
    img_gen = SDXLWrapper()

    import gradio as gr

    def generate_synthetic_data(prompt):
        img, caption = img_gen(prompt)
        syn_data, detr_vis, depth_vis = chat(img, caption)
        return syn_data, img, detr_vis, depth_vis

    iface = gr.Interface(
        fn=generate_synthetic_data,
        inputs="text",
        outputs=["text", "image", "image", "image"]
    )
    iface.launch(share=True)