diff --git a/src/backend/sd.rs b/src/backend/sd.rs
index 6c06d3d..da30fcf 100644
--- a/src/backend/sd.rs
+++ b/src/backend/sd.rs
@@ -3,7 +3,7 @@ use endpoints::{
files::FileObject,
images::{ImageCreateRequest, ImageEditRequest, ImageVariationRequest, ResponseFormat},
};
-use hyper::{body::to_bytes, Body, Method, Request, Response};
+use hyper::{body::to_bytes, header::CONTENT_TYPE, Body, Method, Request, Response};
use multipart::server::{Multipart, ReadEntry, ReadEntryResult};
use multipart_2021 as multipart;
use std::{
@@ -38,94 +38,707 @@ pub(crate) async fn image_generation_handler(mut req: Request
) -> Response
}
}
- let res = if req.method() == Method::POST {
- info!(target: "stdout", "Prepare the image generation request.");
+ let content_type = req
+ .headers()
+ .get(CONTENT_TYPE)
+ .and_then(|ct| ct.to_str().ok());
+
+ if let Some(content_type) = content_type {
+ if content_type.starts_with("multipart/") {
+ // Handle multipart request
+ info!(target: "stdout", "Handling multipart request");
+ // Your multipart handling code here
+ } else {
+ // Handle command request
+ info!(target: "stdout", "Handling command request");
+ // Your command handling code here
+ }
+ } else {
+ // Handle request with no Content-Type header
+ info!(target: "stdout", "Handling request with no Content-Type header");
+ // Your handling code here
+ }
+
+ let mut image_request = match content_type {
+ Some(content_type) if content_type.starts_with("multipart/") => {
+ let boundary = "boundary=";
+
+ let boundary = req.headers().get("content-type").and_then(|ct| {
+ let ct = ct.to_str().ok()?;
+ let idx = ct.find(boundary)?;
+ Some(ct[idx + boundary.len()..].to_string())
+ });
+
+ let req_body = req.into_body();
+ let body_bytes = match to_bytes(req_body).await {
+ Ok(body_bytes) => body_bytes,
+ Err(e) => {
+ let err_msg = format!("Fail to read buffer from request body. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ let cursor = Cursor::new(body_bytes.to_vec());
+
+ let mut multipart = Multipart::with_body(cursor, boundary.unwrap());
+
+ let mut image_request = ImageCreateRequest::default();
+ while let ReadEntryResult::Entry(mut field) = multipart.read_entry_mut() {
+ match &*field.headers.name {
+ "control_image" => {
+ let filename = match field.headers.filename {
+ Some(filename) => filename,
+ None => {
+ let err_msg =
+ "Failed to upload the image file. The filename is not provided.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ // get the image data
+ let mut buffer = Vec::new();
+ let size_in_bytes = match field.data.read_to_end(&mut buffer) {
+ Ok(size_in_bytes) => size_in_bytes,
+ Err(e) => {
+ let err_msg = format!("Failed to read the image file. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ // create a file id for the image file
+ let id = format!("file_{}", uuid::Uuid::new_v4());
+
+ // save the file
+ let path = Path::new("archives");
+ if !path.exists() {
+ fs::create_dir(path).unwrap();
+ }
+ let file_path = path.join(&id);
+ if !file_path.exists() {
+ fs::create_dir(&file_path).unwrap();
+ }
+ let mut file = match File::create(file_path.join(&filename)) {
+ Ok(file) => file,
+ Err(e) => {
+ let err_msg = format!(
+ "Failed to create archive document {}. {}",
+ &filename, e
+ );
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+ file.write_all(&buffer[..]).unwrap();
+
+ // log
+ info!(target: "stdout", "file_id: {}, file_name: {}, size in bytes: {}", &id, &filename, size_in_bytes);
+
+ let created_at =
+ match SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
+ Ok(n) => n.as_secs(),
+ Err(_) => {
+ let err_msg = "Failed to get the current time.";
+
+ // log
+ error!(target: "stdout", "{}", err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ // create a file object
+ image_request.control_image = Some(FileObject {
+ id,
+ bytes: size_in_bytes as u64,
+ created_at,
+ filename,
+ object: "file".to_string(),
+ purpose: "assistants".to_string(),
+ });
+ }
+ "prompt" => match field.is_text() {
+ true => {
+ let mut prompt = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut prompt) {
+ let err_msg = format!("Failed to read the prompt. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ image_request.prompt = prompt;
+ }
+ false => {
+ let err_msg =
+ "Failed to get the prompt. The prompt field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "negative_prompt" => match field.is_text() {
+ true => {
+ let mut negative_prompt = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut negative_prompt) {
+ let err_msg = format!("Failed to read the prompt. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ image_request.prompt = negative_prompt;
+ }
+ false => {
+ let err_msg =
+ "Failed to get the negative prompt. The negative prompt field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "model" => match field.is_text() {
+ true => {
+ let mut model = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut model) {
+ let err_msg = format!("Failed to read the model. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ image_request.model = model;
+ }
+ false => {
+ let err_msg =
+ "Failed to get the model name. The model field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "n" => match field.is_text() {
+ true => {
+ let mut n = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut n) {
+ let err_msg = format!("Failed to read the number of images. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ match n.parse::() {
+ Ok(n) => image_request.n = Some(n),
+ Err(e) => {
+ let err_msg = format!(
+ "Failed to parse the number of images. Reason: {}",
+ e
+ );
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ }
+ }
+ false => {
+ let err_msg =
+ "Failed to get the number of images. The n field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "size" => {
+ match field.is_text() {
+ true => {
+ let mut size = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut size) {
+ let err_msg = format!("Failed to read the size. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ // image_request.size = Some(size);
+
+ let parts: Vec<&str> = size.split('x').collect();
+ if parts.len() != 2 {
+ let err_msg = "Invalid size format. The correct format is `HeightxWidth`. Example: 256x256";
+
+ // log
+ error!(target: "stdout", "{}", err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ image_request.height = Some(parts[0].parse().unwrap());
+ image_request.width = Some(parts[1].parse().unwrap());
+ }
+ false => {
+ let err_msg =
+ "Failed to get the size. The size field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ }
+ }
+ "response_format" => match field.is_text() {
+ true => {
+ let mut response_format = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut response_format) {
+ let err_msg = format!("Failed to read the response format. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ match response_format.parse::() {
+ Ok(format) => image_request.response_format = Some(format),
+ Err(e) => {
+ let err_msg = format!(
+ "Failed to parse the response format. Reason: {}",
+ e
+ );
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ }
+ }
+ false => {
+ let err_msg =
+ "Failed to get the response format. The response format field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "user" => match field.is_text() {
+ true => {
+ let mut user = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut user) {
+ let err_msg = format!("Failed to read the user. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ image_request.user = Some(user);
+ }
+ false => {
+ let err_msg =
+ "Failed to get the user. The user field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "cfg_scale" => match field.is_text() {
+ true => {
+ let mut cfg_scale = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut cfg_scale) {
+ let err_msg = format!("Failed to read the cfg_config. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ match cfg_scale.parse::() {
+ Ok(scale) => image_request.cfg_scale = Some(scale),
+ Err(e) => {
+ let err_msg = format!(
+ "Failed to parse the number of images. Reason: {}",
+ e
+ );
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ }
+ }
+ false => {
+ let err_msg =
+ "Failed to get the cfg_config. The cfg_config field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "sample_method" => match field.is_text() {
+ true => {
+ let mut sample_method = String::new();
- // parse request
- let body_bytes = match to_bytes(req.body_mut()).await {
- Ok(body_bytes) => body_bytes,
- Err(e) => {
- let err_msg = format!("Fail to read buffer from request body. {}", e);
+ if let Err(e) = field.data.read_to_string(&mut sample_method) {
+ let err_msg = format!("Failed to read the sample_method. {}", e);
- // log
- error!(target: "stdout", "{}", &err_msg);
+ // log
+ error!(target: "stdout", "{}", &err_msg);
- return error::internal_server_error(err_msg);
- }
- };
- let mut image_request: ImageCreateRequest = match serde_json::from_slice(&body_bytes) {
- Ok(image_request) => image_request,
- Err(e) => {
- let err_msg = format!("Fail to deserialize image create request: {msg}", msg = e);
+ return error::internal_server_error(err_msg);
+ }
- // log
- error!(target: "stdout", "{}", &err_msg);
+ image_request.sample_method = Some(sample_method.as_str().into());
+ }
+ false => {
+ let err_msg =
+ "Failed to get the sample_method. The sample_method field in the request should be a text field.";
- return error::bad_request(err_msg);
- }
- };
-
- // check if the user id is provided
- if image_request.user.is_none() {
- image_request.user = Some(gen_image_id())
- };
- let id = image_request.user.clone().unwrap();
-
- // log user id
- info!(target: "stdout", "user: {}", image_request.user.clone().unwrap());
-
- match llama_core::images::image_generation(&mut image_request).await {
- Ok(images_response) => {
- // serialize embedding object
- match serde_json::to_string(&images_response) {
- Ok(s) => {
- // return response
- let result = Response::builder()
- .header("Access-Control-Allow-Origin", "*")
- .header("Access-Control-Allow-Methods", "*")
- .header("Access-Control-Allow-Headers", "*")
- .header("Content-Type", "application/json")
- .header("user", id)
- .body(Body::from(s));
- match result {
- Ok(response) => response,
- Err(e) => {
- let err_msg = e.to_string();
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "steps" => match field.is_text() {
+ true => {
+ let mut steps = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut steps) {
+ let err_msg = format!("Failed to read the steps. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ match steps.parse::() {
+ Ok(steps) => image_request.steps = Some(steps),
+ Err(e) => {
+ let err_msg =
+ format!("Failed to parse the steps. Reason: {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ }
+ }
+ false => {
+ let err_msg =
+ "Failed to get the steps. The steps field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "height" => match field.is_text() {
+ true => {
+ let mut height = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut height) {
+ let err_msg = format!("Failed to read the height. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ match height.parse::() {
+ Ok(height) => image_request.height = Some(height),
+ Err(e) => {
+ let err_msg =
+ format!("Failed to parse the height. Reason: {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ }
+ }
+ false => {
+ let err_msg =
+ "Failed to get the height. The height field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "width" => match field.is_text() {
+ true => {
+ let mut width = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut width) {
+ let err_msg = format!("Failed to read the width. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ match width.parse::() {
+ Ok(width) => image_request.width = Some(width),
+ Err(e) => {
+ let err_msg =
+ format!("Failed to parse the width. Reason: {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ }
+ }
+ false => {
+ let err_msg =
+ "Failed to get the width. The width field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "control_strength" => match field.is_text() {
+ true => {
+ let mut control_strength = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut control_strength) {
+ let err_msg = format!("Failed to read the control_strength. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+
+ match control_strength.parse::() {
+ Ok(control_strength) => {
+ image_request.control_strength = Some(control_strength)
+ }
+ Err(e) => {
+ let err_msg = format!(
+ "Failed to parse the control_strength. Reason: {}",
+ e
+ );
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ }
+ }
+ false => {
+ let err_msg =
+ "Failed to get the control_strength. The control_strength field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ "seed" => match field.is_text() {
+ true => {
+ let mut seed = String::new();
+
+ if let Err(e) = field.data.read_to_string(&mut seed) {
+ let err_msg = format!("Failed to read the seed. {}", e);
// log
error!(target: "stdout", "{}", &err_msg);
- error::internal_server_error(err_msg)
+ return error::internal_server_error(err_msg);
+ }
+
+ match seed.parse::() {
+ Ok(seed) => image_request.seed = Some(seed),
+ Err(e) => {
+ let err_msg =
+ format!("Failed to parse the seed. Reason: {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
}
}
+ false => {
+ let err_msg =
+ "Failed to get the seed. The seed field in the request should be a text field.";
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ },
+ unsupported_field => {
+ let err_msg = format!("Unsupported field: {}", unsupported_field);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::bad_request(err_msg);
+ }
+ }
+ }
+
+ image_request
+ }
+ _ => {
+ if req.method() == Method::POST {
+ info!(target: "stdout", "Prepare the image generation request.");
+
+ // parse request
+ let body_bytes = match to_bytes(req.body_mut()).await {
+ Ok(body_bytes) => body_bytes,
+ Err(e) => {
+ let err_msg = format!("Fail to read buffer from request body. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
}
+ };
+ let image_request: ImageCreateRequest = match serde_json::from_slice(&body_bytes) {
+ Ok(image_request) => image_request,
Err(e) => {
let err_msg =
- format!("Fail to serialize the `ListImagesResponse` instance. {}", e);
+ format!("Fail to deserialize image create request: {msg}", msg = e);
// log
error!(target: "stdout", "{}", &err_msg);
- error::internal_server_error(err_msg)
+ return error::bad_request(err_msg);
}
- }
- }
- Err(e) => {
- let err_msg = format!("Failed to get image generations. Reason: {}", e);
+ };
+
+ image_request
+ } else {
+ let err_msg = "Invalid HTTP Method.";
// log
error!(target: "stdout", "{}", &err_msg);
- error::internal_server_error(err_msg)
+ return error::internal_server_error(err_msg);
}
}
- } else {
- let err_msg = "Invalid HTTP Method.";
+ };
+
+ if image_request.user.is_none() {
+ image_request.user = Some(gen_image_id())
+ };
+ let id = image_request.user.clone().unwrap();
+
+ // log user id
+ info!(target: "stdout", "user: {}", image_request.user.clone().unwrap());
+
+ let res = match llama_core::images::image_generation(&mut image_request).await {
+ Ok(images_response) => {
+ // serialize embedding object
+ match serde_json::to_string(&images_response) {
+ Ok(s) => {
+ // return response
+ let result = Response::builder()
+ .header("Access-Control-Allow-Origin", "*")
+ .header("Access-Control-Allow-Methods", "*")
+ .header("Access-Control-Allow-Headers", "*")
+ .header("Content-Type", "application/json")
+ .header("user", id)
+ .body(Body::from(s));
+ match result {
+ Ok(response) => response,
+ Err(e) => {
+ let err_msg = e.to_string();
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ error::internal_server_error(err_msg)
+ }
+ }
+ }
+ Err(e) => {
+ let err_msg =
+ format!("Fail to serialize the `ListImagesResponse` instance. {}", e);
+
+ // log
+ error!(target: "stdout", "{}", &err_msg);
- // log
- error!(target: "stdout", "{}", &err_msg);
+ error::internal_server_error(err_msg)
+ }
+ }
+ }
+ Err(e) => {
+ let err_msg = format!("Failed to get image generations. Reason: {}", e);
- error::internal_server_error(err_msg)
+ // log
+ error!(target: "stdout", "{}", &err_msg);
+
+ error::internal_server_error(err_msg)
+ }
};
// log
@@ -313,7 +926,7 @@ pub(crate) async fn image_edit_handler(req: Request) -> Response {
}
false => {
let err_msg =
- "Failed to get the prompt. The prompt field in the request should be a text field.";
+ "Failed to get the negative prompt. The negative prompt field in the request should be a text field.";
// log
error!(target: "stdout", "{}", &err_msg);