Skip to content

Commit

Permalink
Merge pull request #139 from NexaAI/david/newfeature
Browse files Browse the repository at this point in the history
add lora support for text generation
  • Loading branch information
zhiyuan8 authored Oct 3, 2024
2 parents 8c1ea24 + f1486a6 commit 6f30ff0
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 12 deletions.
3 changes: 2 additions & 1 deletion nexa/cli/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def main():
text_group.add_argument("-k", "--top_k", type=int, help="Top-k sampling parameter")
text_group.add_argument("-p", "--top_p", type=float, help="Top-p sampling parameter")
text_group.add_argument("-sw", "--stop_words", nargs="*", help="List of stop words for early stopping")
text_group.add_argument("--lora_path", type=str, help="Path to a LoRA file to apply to the model.")
text_group.add_argument("-hf", "--huggingface", action="store_true", help="Load model from Hugging Face Hub")

# Image generation arguments
Expand All @@ -123,7 +124,7 @@ def main():
image_group.add_argument("--wtype", type=str, help="Weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)")
image_group.add_argument("--control_net_path", type=str, help="Path to control net model")
image_group.add_argument("--control_image_path", type=str, help="Path to image condition for Control Net")
image_group.add_argument("--control_strength", type=str, help="Strength to apply Control Net")
image_group.add_argument("--control_strength", type=float, help="Strength to apply Control Net")

# ASR arguments
asr_group = run_parser.add_argument_group('Automatic Speech Recognition options')
Expand Down
45 changes: 41 additions & 4 deletions nexa/gguf/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,14 @@ class NexaImageInference:
guidance_scale (float): Guidance scale for diffusion.
output_path (str): Output path for the generated image.
random_seed (int): Random seed for image generation.
lora_dir (str): Path to directory containing LoRA files.
lora_path (str): Path to a LoRA file to apply to the model.
wtype (str): Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0).
control_net_path (str): Path to control net model.
control_image_path (str): Path to image condition for Control Net.
control_strength (float): Strength to apply Control Net.
streamlit (bool): Run the inference in Streamlit UI.
profiling (bool): Enable profiling logs for the inference process.
"""

def __init__(self, model_path, local_path=None, **kwargs):
Expand Down Expand Up @@ -98,6 +104,7 @@ def __init__(self, model_path, local_path=None, **kwargs):
else:
self.params = DEFAULT_IMG_GEN_PARAMS.copy()

self.profiling = kwargs.get("profiling", False)
self.params.update({k: v for k, v in kwargs.items() if v is not None})
if not kwargs.get("streamlit", False):
self._load_model(model_path)
Expand All @@ -119,7 +126,7 @@ def _load_model(self, model_path: str):
wtype=self.params.get(
"wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default")
), # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
verbose=False,
verbose=self.profiling,
)
else:
self.model = StableDiffusion(
Expand All @@ -130,7 +137,7 @@ def _load_model(self, model_path: str):
"wtype", NEXA_RUN_MODEL_PRECISION_MAP.get(model_path, "default")
), # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
control_net_path=self.params.get("control_net_path", ""),
verbose=False,
verbose=self.profiling,
)

def _save_images(self, images):
Expand Down Expand Up @@ -352,13 +359,43 @@ def run_streamlit(self, model_path: str):
default=0,
help="Random seed for image generation",
)
# parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help="Device to run the model on (default: cuda if available, else cpu)")
parser.add_argument(
"--lora_dir",
type=str,
help="Path to directory containing LoRA files.",
)
parser.add_argument(
"--wtype",
type=str,
help="Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)",
)
parser.add_argument(
"--control_net_path",
type=str,
help="Path to control net model.",
)
parser.add_argument(
"--control_image_path",
type=str,
help="Path to image condition for Control Net.",
)
parser.add_argument(
"--control_strength",
type=float,
help="Strength to apply Control Net.",
)
parser.add_argument(
"-st",
"--streamlit",
action="store_true",
help="Run the inference in Streamlit UI",
)
parser.add_argument(
"-pf",
"--profiling",
action="store_true",
help="Enable profiling logs for the inference process",
)
args = parser.parse_args()
kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
Expand Down
14 changes: 7 additions & 7 deletions nexa/gguf/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _load_model(self):
chat_format=self.chat_format,
n_ctx=2048,
n_gpu_layers=-1 if is_gpu_available() else 0,
lora_path=self.params.get("lora_path", ""),
)
except Exception as e:
logging.error(f"Failed to load model: {e}. Falling back to CPU.", exc_info=True)
Expand All @@ -116,6 +117,7 @@ def _load_model(self):
chat_format=self.chat_format,
n_ctx=2048,
n_gpu_layers=0, # hardcode to use CPU
lora_path=self.params.get("lora_path", ""),
)

load_time = time.time() - start_time
Expand Down Expand Up @@ -331,13 +333,11 @@ def run_streamlit(self, model_path: str):
action="store_true",
help="Run the inference in Streamlit UI",
)
# parser.add_argument(
# "-tlps",
# "--top_logprobs",
# type=int,
# default=None, # -tlps 5
# help="Number of most likely tokens to return at each token position",
# )
parser.add_argument(
"--lora_path",
type=str,
help="Path to a LoRA file to apply to the model.",
)
args = parser.parse_args()
kwargs = {k: v for k, v in vars(args).items() if v is not None}
model_path = kwargs.pop("model_path")
Expand Down

0 comments on commit 6f30ff0

Please sign in to comment.