From f1486a69dc0361b06536fdf428aab493b29e57a1 Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Thu, 3 Oct 2024 19:06:22 +0000 Subject: [PATCH] add lora support for text generation --- nexa/cli/entry.py | 3 ++- nexa/gguf/nexa_inference_image.py | 45 ++++++++++++++++++++++++++++--- nexa/gguf/nexa_inference_text.py | 14 +++++----- 3 files changed, 50 insertions(+), 12 deletions(-) diff --git a/nexa/cli/entry.py b/nexa/cli/entry.py index 56ea1fcc..d137f944 100644 --- a/nexa/cli/entry.py +++ b/nexa/cli/entry.py @@ -108,6 +108,7 @@ def main(): text_group.add_argument("-k", "--top_k", type=int, help="Top-k sampling parameter") text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling parameter") text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping") + text_group.add_argument("--lora_path", type=str, help="Path to a LoRA file to apply to the model.") text_group.add_argument("-hf", "--huggingface", action="store_true", help="Load model from Hugging Face Hub") # Image generation arguments @@ -123,7 +124,7 @@ def main(): image_group.add_argument("--wtype", type=str, help="Weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)") image_group.add_argument("--control_net_path", type=str, help="Path to control net model") image_group.add_argument("--control_image_path", type=str, help="Path to image condition for Control Net") - image_group.add_argument("--control_strength", type=str, help="Strength to apply Control Net") + image_group.add_argument("--control_strength", type=float, help="Strength to apply Control Net") # ASR arguments asr_group = run_parser.add_argument_group('Automatic Speech Recognition options') diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index 216a9be6..120b3251 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -53,8 +53,14 @@ class NexaImageInference: guidance_scale (float): Guidance scale for diffusion. output_path (str): Output path for the generated image. random_seed (int): Random seed for image generation. + lora_dir (str): Path to directory containing LoRA files. + lora_path (str): Path to a LoRA file to apply to the model. + wtype (str): Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0). + control_net_path (str): Path to control net model. + control_image_path (str): Path to image condition for Control Net. + control_strength (float): Strength to apply Control Net. streamlit (bool): Run the inference in Streamlit UI. - + profiling (bool): Enable profiling logs for the inference process. """ def __init__(self, model_path, local_path=None, **kwargs): @@ -98,6 +104,7 @@ def __init__(self, model_path, local_path=None, **kwargs): else: self.params = DEFAULT_IMG_GEN_PARAMS.copy() + self.profiling = kwargs.get("profiling", False) self.params.update({k: v for k, v in kwargs.items() if v is not None}) if not kwargs.get("streamlit", False): self._load_model(model_path) @@ -119,7 +126,7 @@ def _load_model(self, model_path: str): wtype=self.params.get( "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default") ), # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) - verbose=False, + verbose=self.profiling, ) else: self.model = StableDiffusion( @@ -130,7 +137,7 @@ def _load_model(self, model_path: str): "wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default") ), # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) control_net_path=self.params.get("control_net_path", ""), - verbose=False, + verbose=self.profiling, ) def _save_images(self, images): @@ -352,13 +359,43 @@ def run_streamlit(self, model_path: str): default=0, help="Random seed for image generation", ) - # parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help="Device to run the model on (default: cuda if available, else cpu)") + parser.add_argument( + "--lora_dir", + type=str, + help="Path to directory containing LoRA files.", + ) + parser.add_argument( + "--wtype", + type=str, + help="Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)", + ) + parser.add_argument( + "--control_net_path", + type=str, + help="Path to control net model.", + ) + parser.add_argument( + "--control_image_path", + type=str, + help="Path to image condition for Control Net.", + ) + parser.add_argument( + "--control_strength", + type=float, + help="Strength to apply Control Net.", + ) parser.add_argument( "-st", "--streamlit", action="store_true", help="Run the inference in Streamlit UI", ) + parser.add_argument( + "-pf", + "--profiling", + action="store_true", + help="Enable profiling logs for the inference process", + ) args = parser.parse_args() kwargs = {k: v for k, v in vars(args).items() if v is not None} model_path = kwargs.pop("model_path") diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py index 8106765d..e168e1ef 100644 --- a/nexa/gguf/nexa_inference_text.py +++ b/nexa/gguf/nexa_inference_text.py @@ -107,6 +107,7 @@ def _load_model(self): chat_format=self.chat_format, n_ctx=2048, n_gpu_layers=-1 if is_gpu_available() else 0, + lora_path=self.params.get("lora_path", ""), ) except Exception as e: logging.error(f"Failed to load model: {e}. Falling back to CPU.", exc_info=True) @@ -116,6 +117,7 @@ def _load_model(self): chat_format=self.chat_format, n_ctx=2048, n_gpu_layers=0, # hardcode to use CPU + lora_path=self.params.get("lora_path", ""), ) load_time = time.time() - start_time @@ -331,13 +333,11 @@ def run_streamlit(self, model_path: str): action="store_true", help="Run the inference in Streamlit UI", ) - # parser.add_argument( - # "-tlps", - # "--top_logprobs", - # type=int, - # default=None, # -tlps 5 - # help="Number of most likely tokens to return at each token position", - # ) + parser.add_argument( + "--lora_path", + type=str, + help="Path to a LoRA file to apply to the model.", + ) args = parser.parse_args() kwargs = {k: v for k, v in vars(args).items() if v is not None} model_path = kwargs.pop("model_path")