Skip to content

Commit

Permalink
add simple side-by-side
Browse files Browse the repository at this point in the history
  • Loading branch information
BabyChouSr committed Nov 29, 2024
1 parent 79a1722 commit 59207c1
Showing 1 changed file with 68 additions and 21 deletions.
89 changes: 68 additions & 21 deletions fastchat/serve/gradio_block_arena_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,29 @@
import requests
import base64
import io
import hashlib
from PIL import Image
import datetime
import json

from fastchat.constants import LOGDIR
from fastchat.utils import upload_image_file_to_gcs

FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
API_BASE = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{}/text_to_image"
API_BASE = "https://api.fireworks.ai/inference/v1/workflows/accounts/fireworks/models/{model}/text_to_image"
DUMMY_MODELS = ["stable-diffusion-3p5-medium",
"stable-diffusion-3p5-large",
"stable-diffusion-3p5-large-turbo",
"flux-1-dev-fp8",
"flux-1-schnell-fp8",
]

def generate_image(prompt, model):
def get_conv_log_filename():
t = datetime.datetime.now()
conv_log_filename = f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json"
return os.path.join(LOGDIR, f"txt2img-{conv_log_filename}")

def generate_image(model, prompt):
"""Generate image from text prompt using Fireworks API"""
headers = {
"Authorization": f"Bearer {FIREWORKS_API_KEY}",
Expand All @@ -38,29 +49,65 @@ def generate_image(prompt, model):

image_bytes = response.content
image = Image.open(io.BytesIO(image_bytes))

return image


except requests.exceptions.RequestException as e:
return f"Error generating image: {str(e)}"

log_filename = get_conv_log_filename()
image_hash = hashlib.md5(image.tobytes()).hexdigest()
image_filename = f"{image_hash}.png"
upload_image_file_to_gcs(image, image_filename)

with open(log_filename, "a") as f:
data = {
"model": model,
"prompt": prompt,
"image_filename": image_filename
}
f.write(json.dumps(data) + "\n")

return image

def generate_image_multi(model_left, model_right, prompt):
images = []
for model in [model_left, model_right]:
images.append(generate_image(model, prompt))

return images


# Create Gradio interface
with gr.Blocks(title="Text to Image Generator") as demo:
gr.Markdown("# Text to Image Generator")
gr.Markdown("Enter a text prompt to generate an image")


num_sides = 2
model_selectors = [None] * num_sides

with gr.Column():
with gr.Group():
model_selector = gr.Dropdown(
choices=DUMMY_MODELS,
interactive=True,
show_label=False,
container=False,
)
image_output = gr.Image(
label="Generated Image",
type="pil"
)
with gr.Row():
for i in range(num_sides):
model_selectors[i] = gr.Dropdown(
choices=DUMMY_MODELS,
value=DUMMY_MODELS[i] if DUMMY_MODELS else "",
interactive=True,
show_label=False,
container=False,
)


with gr.Group():
with gr.Row():
output_left = gr.Image(
type="pil",
show_label=False
)
output_right = gr.Image(
type="pil",
show_label=False
)

with gr.Row():
text_input = gr.Textbox(
Expand All @@ -72,15 +119,15 @@ def generate_image(prompt, model):

# Handle generation
send_btn.click(
fn=generate_image,
inputs=[text_input, model_selector],
outputs=image_output
fn=generate_image_multi,
inputs=model_selectors + [text_input],
outputs=[output_left, output_right]
)

text_input.submit(
fn=generate_image,
inputs=[text_input, model_selector],
outputs=image_output
fn=generate_image_multi,
inputs=model_selectors + [text_input],
outputs=[output_left, output_right]
)

if __name__ == "__main__":
Expand Down

0 comments on commit 59207c1

Please sign in to comment.