forked from kyegomez/swarms
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request kyegomez#200 from elder-plinius/master
- Loading branch information
Showing
2 changed files
with
114 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import os | ||
import base64 | ||
import requests | ||
from dotenv import load_dotenv | ||
from typing import List | ||
|
||
load_dotenv() | ||
|
||
class StableDiffusion: | ||
""" | ||
A class to interact with the Stable Diffusion API for image generation. | ||
Attributes: | ||
----------- | ||
api_key : str | ||
The API key for accessing the Stable Diffusion API. | ||
api_host : str | ||
The host URL of the Stable Diffusion API. | ||
engine_id : str | ||
The ID of the Stable Diffusion engine. | ||
headers : dict | ||
The headers for the API request. | ||
output_dir : str | ||
Directory where generated images will be saved. | ||
Methods: | ||
-------- | ||
generate_image(prompt: str, cfg_scale: int, height: int, width: int, samples: int, steps: int) -> List[str]: | ||
Generates images based on a text prompt and returns a list of file paths to the generated images. | ||
""" | ||
|
||
def __init__(self, api_key: str, api_host: str = "https://api.stability.ai"): | ||
""" | ||
Initializes the StableDiffusion class with the provided API key and host. | ||
Parameters: | ||
----------- | ||
api_key : str | ||
The API key for accessing the Stable Diffusion API. | ||
api_host : str | ||
The host URL of the Stable Diffusion API. Default is "https://api.stability.ai". | ||
""" | ||
self.api_key = api_key | ||
self.api_host = api_host | ||
self.engine_id = "stable-diffusion-v1-6" | ||
self.headers = { | ||
"Authorization": f"Bearer {self.api_key}", | ||
"Content-Type": "application/json", | ||
"Accept": "application/json" | ||
} | ||
self.output_dir = "images" | ||
os.makedirs(self.output_dir, exist_ok=True) | ||
|
||
def generate_image(self, prompt: str, cfg_scale: int = 7, height: int = 1024, width: int = 1024, samples: int = 1, steps: int = 30) -> List[str]: | ||
""" | ||
Generates images based on a text prompt. | ||
Parameters: | ||
----------- | ||
prompt : str | ||
The text prompt based on which the image will be generated. | ||
cfg_scale : int | ||
CFG scale parameter for image generation. Default is 7. | ||
height : int | ||
Height of the generated image. Default is 1024. | ||
width : int | ||
Width of the generated image. Default is 1024. | ||
samples : int | ||
Number of images to generate. Default is 1. | ||
steps : int | ||
Number of steps for the generation process. Default is 30. | ||
Returns: | ||
-------- | ||
List[str]: | ||
A list of paths to the generated images. | ||
Raises: | ||
------- | ||
Exception: | ||
If the API response is not 200 (OK). | ||
""" | ||
response = requests.post( | ||
f"{self.api_host}/v1/generation/{self.engine_id}/text-to-image", | ||
headers=self.headers, | ||
json={ | ||
"text_prompts": [{"text": prompt}], | ||
"cfg_scale": cfg_scale, | ||
"height": height, | ||
"width": width, | ||
"samples": samples, | ||
"steps": steps, | ||
}, | ||
) | ||
|
||
if response.status_code != 200: | ||
raise Exception(f"Non-200 response: {response.text}") | ||
|
||
data = response.json() | ||
image_paths = [] | ||
for i, image in enumerate(data["artifacts"]): | ||
image_path = os.path.join(self.output_dir, f"v1_txt2img_{i}.png") | ||
with open(image_path, "wb") as f: | ||
f.write(base64.b64decode(image["base64"])) | ||
image_paths.append(image_path) | ||
|
||
return image_paths | ||
|
||
# Usage example: | ||
# sd = StableDiffusion("your-api-key") | ||
# images = sd.generate_image("A scenic landscape with mountains") | ||
# print(images) |