From 94701de1742c4b2e70fe8633a4067b6a2a43de3d Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Sat, 24 Aug 2024 04:57:22 +0000 Subject: [PATCH] bug fix --- nexa/cli/entry.py | 4 ++-- nexa/gguf/nexa_inference_image.py | 2 +- nexa/gguf/streamlit/streamlit_image_chat.py | 17 ++++++++++------- nexa/gguf/streamlit/streamlit_vlm.py | 2 +- pyproject.toml | 1 + requirements.txt | 2 +- 6 files changed, 16 insertions(+), 12 deletions(-) diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py index e01a16b7..bc3aac1c 100644 --- a/nexa/cli/entry.py +++ b/nexa/cli/entry.py @@ -102,11 +102,11 @@ def main(): image_group = run_parser.add_argument_group('Image generation options') image_group.add_argument("-i2i", "--img2img", action="store_true", help="Whether to run image-to-image generation") image_group.add_argument("-ns", "--num_inference_steps", type=int, help="Number of inference steps") - image_group.add_argument("-np", "--num_images_per_prompt", type=int, help="Number of images to generate per prompt") + image_group.add_argument("-np", "--num_images_per_prompt", type=int, deafult=1, help="Number of images to generate per prompt") image_group.add_argument("-H", "--height", type=int, help="Height of the output image") image_group.add_argument("-W", "--width", type=int, help="Width of the output image") image_group.add_argument("-g", "--guidance_scale", type=float, help="Guidance scale for diffusion") - image_group.add_argument("-o", "--output", type=str, help="Output path for the generated image") + image_group.add_argument("-o", "--output", type=str, default="generated_images/image.png", help="Output path for the generated image") image_group.add_argument("-s", "--random_seed", type=int, help="Random seed for image generation") image_group.add_argument("--lora_dir", type=str, help="Path to directory containing LoRA files") image_group.add_argument("--wtype", type=str, help="Weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)") diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index 687888d7..a3d6114d 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -218,7 +218,7 @@ def run_img2img(self): sample_steps=self.params["num_inference_steps"], seed=self.params["random_seed"], control_cond=self.params.get("control_image_path", ""), - control_strength=self.params.get("control_strength", 0.9), + control_strength=self.params.get("control_strength", 0.9), ) self._save_images(images) diff --git a/nexa/gguf/streamlit/streamlit_image_chat.py b/nexa/gguf/streamlit/streamlit_image_chat.py index b9d8fb97..d2e3cc78 100644 --- a/nexa/gguf/streamlit/streamlit_image_chat.py +++ b/nexa/gguf/streamlit/streamlit_image_chat.py @@ -21,13 +21,16 @@ def generate_images(nexa_model: NexaImageInference, prompt: str, negative_prompt if not os.path.exists(output_dir): os.makedirs(output_dir) - nexa_model._txt2img(prompt, negative_prompt) - - images = [] - for file_name in os.listdir(output_dir): - if file_name.endswith(".png"): - image_path = os.path.join(output_dir, file_name) - images.append(Image.open(image_path)) + images = nexa_model.txt2img( + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=nexa_model.params["guidance_scale"], + width=nexa_model.params["width"], + height=nexa_model.params["height"], + sample_steps=nexa_model.params["num_inference_steps"], + seed=nexa_model.params["random_seed"] + ) + return images diff --git a/nexa/gguf/streamlit/streamlit_vlm.py b/nexa/gguf/streamlit/streamlit_vlm.py index 9f334d13..63a3ba28 100644 --- a/nexa/gguf/streamlit/streamlit_vlm.py +++ b/nexa/gguf/streamlit/streamlit_vlm.py @@ -13,7 +13,7 @@ @st.cache_resource def load_model(model_path): local_path, run_type = pull_model(model_path) - nexa_model = NexaVLMInference(model_pat=model_path, local_path=local_path) + nexa_model = NexaVLMInference(model_path=model_path, local_path=local_path) return nexa_model diff --git a/pyproject.toml b/pyproject.toml index adee2ae4..4442967a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "tqdm", # Shared dependencies "tabulate", "streamlit", + "streamlit-audiorec", "python-multipart", ] classifiers = [ diff --git a/requirements.txt b/requirements.txt index 41a9b86e..df3f78ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ PyYAML requests setuptools soundfile -streamlit_audiorec +streamlit-audiorec transformers ttstokenizer