Skip to content

Commit

Permalink
Merge pull request kyegomez#200 from elder-plinius/master
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez authored Nov 27, 2023
2 parents 88e9ad0 + 7bcfcbc commit 0a66093
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ AI21_API_KEY="your_api_key_here"
COHERE_API_KEY="your_api_key_here"
ALEPHALPHA_API_KEY="your_api_key_here"
HUGGINFACEHUB_API_KEY="your_api_key_here"

STABILITY_API_KEY="your_api_key_here"

WOLFRAM_ALPHA_APPID="your_wolfram_alpha_appid_here"
ZAPIER_NLA_API_KEY="your_zapier_nla_api_key_here"
Expand Down Expand Up @@ -41,4 +41,4 @@ REDIS_PORT=
PINECONE_API_KEY=""
BING_COOKIE=""

PSG_CONNECTION_STRING=""
PSG_CONNECTION_STRING=""
112 changes: 112 additions & 0 deletions playground/models/stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
import base64
import requests
from dotenv import load_dotenv
from typing import List

load_dotenv()

class StableDiffusion:
"""
A class to interact with the Stable Diffusion API for image generation.
Attributes:
-----------
api_key : str
The API key for accessing the Stable Diffusion API.
api_host : str
The host URL of the Stable Diffusion API.
engine_id : str
The ID of the Stable Diffusion engine.
headers : dict
The headers for the API request.
output_dir : str
Directory where generated images will be saved.
Methods:
--------
generate_image(prompt: str, cfg_scale: int, height: int, width: int, samples: int, steps: int) -> List[str]:
Generates images based on a text prompt and returns a list of file paths to the generated images.
"""

def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"):
"""
Initializes the StableDiffusion class with the provided API key and host.
Parameters:
-----------
api_key : str
The API key for accessing the Stable Diffusion API.
api_host : str
The host URL of the Stable Diffusion API. Default is "https://api.stability.ai".
"""
self.api_key = api_key
self.api_host = api_host
self.engine_id = "stable-diffusion-v1-6"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"Accept": "application/json"
}
self.output_dir = "images"
os.makedirs(self.output_dir, exist_ok=True)

def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]:
"""
Generates images based on a text prompt.
Parameters:
-----------
prompt : str
The text prompt based on which the image will be generated.
cfg_scale : int
CFG scale parameter for image generation. Default is 7.
height : int
Height of the generated image. Default is 1024.
width : int
Width of the generated image. Default is 1024.
samples : int
Number of images to generate. Default is 1.
steps : int
Number of steps for the generation process. Default is 30.
Returns:
--------
List[str]:
A list of paths to the generated images.
Raises:
-------
Exception:
If the API response is not 200 (OK).
"""
response = requests.post(
f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image",
headers=self.headers,
json={
"text_prompts": [{"text": prompt}],
"cfg_scale": cfg_scale,
"height": height,
"width": width,
"samples": samples,
"steps": steps,
},
)

if response.status_code != 200:
raise Exception(f"Non-200 response: {response.text}")

data = response.json()
image_paths = []
for i, image in enumerate(data["artifacts"]):
image_path = os.path.join(self.output_dir, f"v1_txt2img_{i}.png")
with open(image_path, "wb") as f:
f.write(base64.b64decode(image["base64"]))
image_paths.append(image_path)

return image_paths

# Usage example:
# sd = StableDiffusion("your-api-key")
# images = sd.generate_image("A scenic landscape with mountains")
# print(images)

0 comments on commit 0a66093

Please sign in to comment.