diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 00000000..3bdcd138 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,37 @@ +name: Python CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + submodules: recursive # This will clone the repository with all its submodules + fetch-depth: 0 # This fetches all history so you can access any version of the submodules + + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' # Specify the Python version you want + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install build pytest + - name: Build DLL + run: | + python -m pip install -e . + - name: Run tests + run: | + python -m pytest tests \ No newline at end of file diff --git a/nexa/gguf/nexa_inference_image.py b/nexa/gguf/nexa_inference_image.py index 21714fad..494374d7 100644 --- a/nexa/gguf/nexa_inference_image.py +++ b/nexa/gguf/nexa_inference_image.py @@ -6,7 +6,6 @@ import time from pathlib import Path -from nexa.gguf.sd.stable_diffusion import StableDiffusion from nexa.general import pull_model from nexa.constants import ( DEFAULT_IMG_GEN_PARAMS, @@ -29,21 +28,22 @@ class NexaImageInference: A class used for loading image models and running image generation. Methods: - run_txt2img: Run the text-to-image generation loop. - run_img2img: Run the image-to-image generation loop. - run_streamlit: Run the Streamlit UI. + txt2img: (Used for SDK) Run the text-to-image generation loop. + img2img: (Used for SDK) Run the image-to-image generation loop. + run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - num_inference_steps (int): Number of inference steps. - width (int): Width of the output image. - height (int): Height of the output image. - guidance_scale (float): Guidance scale for diffusion. - output_path (str): Output path for the generated image. - random_seed (int): Random seed for image generation. - streamlit (bool): Run the inference in Streamlit UI. + model_path (str): Path or identifier for the model in Nexa Model Hub. + num_inference_steps (int): Number of inference steps. + width (int): Width of the output image. + height (int): Height of the output image. + guidance_scale (float): Guidance scale for diffusion. + output_path (str): Output path for the generated image. + random_seed (int): Random seed for image generation. + streamlit (bool): Run the inference in Streamlit UI. """ + def __init__(self, model_path, **kwargs): self.model_path = None @@ -81,9 +81,10 @@ def __init__(self, model_path, **kwargs): logging.error("Failed to load the model or pipeline.") exit(1) - @SpinningCursorAnimation() + # @SpinningCursorAnimation() def _load_model(self, model_path: str): with suppress_stdout_stderr(): + from nexa.gguf.sd.stable_diffusion import StableDiffusion self.model = StableDiffusion( model_path=self.downloaded_path, lora_model_dir=self.params.get("lora_dir", ""), @@ -107,63 +108,104 @@ def _save_images(self, images): file_path = os.path.join(output_dir, file_name) image.save(file_path) logging.info(f"\nImage {i+1} saved to: {file_path}") + + def txt2img(self, + prompt, + negative_prompt="", + cfg_scale=7.5, + width=512, + height=512, + sample_steps=20, + seed=0, + control_cond="", + control_strength=0.9): + """ + Used for SDK. Generate images from text. 
+ + Args: + prompt (str): Prompt for the image generation. + negative_prompt (str): Negative prompt for the image generation. - def loop_txt2img(self): + Returns: + list: List of generated images. + """ + images = self.model.txt_to_img( + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=cfg_scale, + width=width, + height=height, + sample_steps=sample_steps, + seed=seed, + control_cond=control_cond, + control_strength=control_strength, + ) + return images + def run_txt2img(self): while True: try: prompt = nexa_prompt("Enter your prompt: ") negative_prompt = nexa_prompt( "Enter your negative prompt (press Enter to skip): " ) - self._txt2img(prompt, negative_prompt) + try: + images = self.txt2img( + prompt, + negative_prompt, + cfg_scale=self.params["guidance_scale"], + width=self.params["width"], + height=self.params["height"], + sample_steps=self.params["num_inference_steps"], + seed=self.params["random_seed"], + control_cond=self.params.get("control_image_path", ""), + control_strength=self.params.get("control_strength", 0.9), + ) + self._save_images(images) + except Exception as e: + logging.error(f"Error during text to image generation: {e}") except KeyboardInterrupt: print(EXIT_REMINDER) except Exception as e: logging.error(f"Error during generation: {e}", exc_info=True) - def _txt2img(self, prompt: str, negative_prompt: str): + def img2img(self, + image_path, + prompt, + negative_prompt="", + cfg_scale=7.5, + width=512, + height=512, + sample_steps=20, + seed=0, + control_cond="", + control_strength=0.9): """ - Generate images based on the given prompt, negative prompt, and parameters. - """ - try: - images = self.model.txt_to_img( - prompt=prompt, - negative_prompt=negative_prompt if negative_prompt else "", - cfg_scale=self.params["guidance_scale"], - width=self.params["width"], - height=self.params["height"], - sample_steps=self.params["num_inference_steps"], - seed=self.params["random_seed"], - control_cond=self.params.get("control_image_path", ""), - control_strength=self.params.get("control_strength", 0.9), - ) - self._save_images(images) - except Exception as e: - logging.error(f"Error during image generation: {e}") + Used for SDK. Generate images from an image. - def loop_img2img(self): - def _generate_images(image_path, prompt, negative_prompt): - """ - Generate images based on the given prompt, negative prompt, and parameters. - """ - try: - images = self.model.img_to_img( - image=image_path, - prompt=prompt, - negative_prompt=negative_prompt if negative_prompt else "", - cfg_scale=self.params["guidance_scale"], - width=self.params["width"], - height=self.params["height"], - sample_steps=self.params["num_inference_steps"], - seed=self.params["random_seed"], - control_cond=self.params.get("control_image_path", ""), - control_strength=self.params.get("control_strength", 0.9), - ) - self._save_images(images) - except Exception as e: - logging.error(f"Error during image generation: {e}") + Args: + image_path (str): Path to the input image. + prompt (str): Prompt for the image generation. + negative_prompt (str): Negative prompt for the image generation. + Returns: + list: List of generated images. 
+ """ + images = self.model.img_to_img( + image=image_path, + prompt=prompt, + negative_prompt=negative_prompt, + cfg_scale=cfg_scale, + width=width, + height=height, + sample_steps=sample_steps, + seed=seed, + control_cond=control_cond, + control_strength=control_strength, + ) + return images + + def run_img2img(self): while True: try: image_path = nexa_prompt("Enter the path to your image: ") @@ -171,7 +213,19 @@ def _generate_images(image_path, prompt, negative_prompt): negative_prompt = nexa_prompt( "Enter your negative prompt (press Enter to skip): " ) - _generate_images(image_path, prompt, negative_prompt) + images = self.img2img(image_path, + prompt, + negative_prompt, + cfg_scale=self.params["guidance_scale"], + width=self.params["width"], + height=self.params["height"], + sample_steps=self.params["num_inference_steps"], + seed=self.params["random_seed"], + control_cond=self.params.get("control_image_path", ""), + control_strength=self.params.get("control_strength", 0.9), + ) + + self._save_images(images) except KeyboardInterrupt: print(EXIT_REMINDER) except Exception as e: @@ -257,6 +311,6 @@ def run_streamlit(self, model_path: str): inference.run_streamlit(model_path) else: if args.img2img: - inference.loop_img2img() + inference.run_img2img() else: - inference.loop_txt2img() + inference.run_txt2img() \ No newline at end of file diff --git a/nexa/gguf/nexa_inference_text.py b/nexa/gguf/nexa_inference_text.py index 253ec79e..fa59e7ee 100644 --- a/nexa/gguf/nexa_inference_text.py +++ b/nexa/gguf/nexa_inference_text.py @@ -14,7 +14,6 @@ ) from nexa.general import pull_model from nexa.gguf.lib_utils import is_gpu_available -from nexa.gguf.llama.llama import Llama from nexa.utils import SpinningCursorAnimation, nexa_prompt, suppress_stdout_stderr logging.basicConfig( @@ -27,19 +26,21 @@ class NexaTextInference: A class used for load text models and run text generation. Methods: - run: Run the text generation loop. - run_streamlit: Run the Streamlit UI. + run: Run the text generation loop. + run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - stop_words (list): List of stop words for early stopping. - profiling (bool): Enable timing measurements for the generation process. - streamlit (bool): Run the inference in Streamlit UI. - temperature (float): Temperature for sampling. - max_new_tokens (int): Maximum number of new tokens to generate. - top_k (int): Top-k sampling parameter. - top_p (float): Top-p sampling parameter + model_path (str): Path or identifier for the model in Nexa Model Hub. + embedding (bool): Enable embedding generation. + stop_words (list): List of stop words for early stopping. + profiling (bool): Enable timing measurements for the generation process. + streamlit (bool): Run the inference in Streamlit UI. + temperature (float): Temperature for sampling. + max_new_tokens (int): Maximum number of new tokens to generate. + top_k (int): Top-k sampling parameter. + top_p (float): Top-p sampling parameter """ + def __init__(self, model_path, stop_words=None, **kwargs): self.params = DEFAULT_TEXT_GEN_PARAMS self.params.update(kwargs) @@ -83,35 +84,29 @@ def __init__(self, model_path, stop_words=None, **kwargs): "Failed to load model or tokenizer. Exiting.", exc_info=True ) exit(1) - def embed( + def create_embedding( self, input: Union[str, List[str]], - normalize: bool = False, - truncate: bool = True, - return_count: bool = False, ): """Embed a string. 
Args: input: The utf-8 encoded string or a list of string to embed. - normalize: whether to normalize embedding in embedding dimension. - trunca - truncate: whether to truncate tokens to window length before generating embedding. - return count: if true, return (embedding, count) tuple. else return embedding only. - Returns: A list of embeddings """ - return self.model.embed(input, normalize, truncate, return_count) + return self.model.create_embedding(input) - @SpinningCursorAnimation() + # @SpinningCursorAnimation() def _load_model(self): logging.debug(f"Loading model from {self.downloaded_path}") start_time = time.time() with suppress_stdout_stderr(): try: + from nexa.gguf.llama.llama import Llama self.model = Llama( + embedding=self.params.get("embedding", False), model_path=self.downloaded_path, verbose=self.profiling, chat_format=self.chat_format, @@ -144,6 +139,9 @@ def _load_model(self): self.conversation_history = [] if self.chat_format else None def run(self): + """ + CLI interactive session. Not for SDK. + """ while True: generated_text = "" try: @@ -191,6 +189,44 @@ def run(self): except Exception as e: logging.error(f"Error during generation: {e}", exc_info=True) print("\n") + + def create_chat_completion(self, messages, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, stream=False, stop=None): + """ + Used for SDK. Generate completion for a chat conversation. + + Args: + messages (list): List of messages in the conversation. + temperature (float): Temperature for sampling. + max_tokens (int): Maximum number of new tokens to generate. + top_k (int): Top-k sampling parameter. + top_p (float): Top-p sampling parameter. + stream (bool): Stream the output. + stop (list): List of stop words for early stopping. + + Returns: + Iterator: Iterator for the completion. + """ + return self.model.create_chat_completion(messages=messages, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, stream=stream, stop=stop) + + def create_completion(self, prompt, temperature=0.7, max_tokens=2048, top_k=50, top_p=1.0, echo=False, stream=False, stop=None): + """ + Used for SDK. Generate completion for a given prompt. + + Args: + prompt (str): Prompt for the completion. + temperature (float): Temperature for sampling. + max_tokens (int): Maximum number of new tokens to generate. + top_k (int): Top-k sampling parameter. + top_p (float): Top-p sampling parameter. + echo (bool): Echo the prompt back in the output. + stream (bool): Stream the output. + stop (list): List of stop words for early stopping. + + Returns: + Iterator: Iterator for the completion. + """ + return self.model.create_completion(prompt=prompt, temperature=temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, echo=echo, stream=stream, stop=stop) + def _chat(self, user_input: str) -> Iterator: current_messages = self.conversation_history + [{"role": "user", "content": user_input}] @@ -223,7 +259,7 @@ def _complete(self, user_input: str) -> Iterator: def run_streamlit(self, model_path: str): """ - Run the Streamlit UI. + Used for CLI. Run the Streamlit UI. 
""" logging.info("Running Streamlit UI...") diff --git a/nexa/gguf/nexa_inference_vlm.py b/nexa/gguf/nexa_inference_vlm.py index 38b419b7..27c057be 100644 --- a/nexa/gguf/nexa_inference_vlm.py +++ b/nexa/gguf/nexa_inference_vlm.py @@ -19,7 +19,6 @@ ) from nexa.general import pull_model from nexa.gguf.lib_utils import is_gpu_available -from nexa.gguf.llama.llama import Llama from nexa.gguf.llama.llama_chat_format import ( Llava15ChatHandler, Llava16ChatHandler, @@ -87,6 +86,8 @@ class NexaVLMInference: top_k (int): Top-k sampling parameter. top_p (float): Top-p sampling parameter """ + + def __init__(self, model_path, stop_words=None, **kwargs): self.params = DEFAULT_TEXT_GEN_PARAMS self.params.update(kwargs) @@ -145,7 +146,7 @@ def __init__(self, model_path, stop_words=None, **kwargs): ) exit(1) - @SpinningCursorAnimation() + # @SpinningCursorAnimation() def _load_model(self): logging.debug(f"Loading model from {self.downloaded_path}") start_time = time.time() @@ -158,6 +159,7 @@ def _load_model(self): else None ) try: + from nexa.gguf.llama.llama import Llama self.model = Llama( model_path=self.downloaded_path, chat_handler=self.projector, @@ -238,6 +240,64 @@ def run(self): except Exception as e: logging.error(f"Error during generation: {e}", exc_info=True) print("\n") + + def create_chat_completion(self, + messages, + max_tokens:int = 2048, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + stream=False, + stop=[]): + """ + Generate text completion for a given chat prompt. + + Args: + messages (list): List of messages in the chat prompt. + temperature (float): Temperature for sampling. + max_tokens (int): Maximum number of tokens to generate. + top_k (int): Top-k sampling parameter. + top_p (float): Top-p sampling parameter. + stream (bool): Stream the output. + stop (list): List of stop words for early stopping. + + Returns: + Iterator: An iterator of the generated text completion + return format: + { + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "The 2020 World Series was played in Texas at Globe Life Field in Arlington.", + "role": "assistant" + }, + "logprobs": null + } + ], + "created": 1677664795, + "id": "chatcmpl-7QyqpwdfhqwajicIEznoc6Q47XAyW", + "model": "gpt-4o-mini", + "object": "chat.completion", + "usage": { + "completion_tokens": 17, + "prompt_tokens": 57, + "total_tokens": 74 + } + } + usage: message = completion.choices[0].message.content + + """ + return self.model.create_chat_completion( + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + top_k=top_k, + top_p=top_p, + stream=stream, + stop=stop, + ) def _chat(self, user_input: str, image_path: str = None) -> Iterator: data_uri = image_to_base64_data_uri(image_path) if image_path else None diff --git a/nexa/gguf/nexa_inference_voice.py b/nexa/gguf/nexa_inference_voice.py index 7997501a..fc8034e3 100644 --- a/nexa/gguf/nexa_inference_voice.py +++ b/nexa/gguf/nexa_inference_voice.py @@ -21,18 +21,18 @@ class NexaVoiceInference: A class used for loading voice models and running voice transcription. Methods: - run: Run the voice transcription loop. - run_streamlit: Run the Streamlit UI. + run: Run the voice transcription loop. + run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - output_dir (str): Output directory for transcriptions. - beam_size (int): Beam size to use for transcription. - language (str): The language spoken in the audio. 
- task (str): Task to execute (transcribe or translate). - temperature (float): Temperature for sampling. - compute_type (str): Type to use for computation (e.g., float16, int8, int8_float16). - output_dir (str): Output directory for transcriptions. + model_path (str): Path or identifier for the model in Nexa Model Hub. + output_dir (str): Output directory for transcriptions. + beam_size (int): Beam size to use for transcription. + language (str): The language spoken in the audio. + task (str): Task to execute (transcribe or translate). + temperature (float): Temperature for sampling. + compute_type (str): Type to use for computation (e.g., float16, int8, int8_float16). + output_dir (str): Output directory for transcriptions. """ def __init__(self, model_path, **kwargs): @@ -68,7 +68,8 @@ def __init__(self, model_path, **kwargs): ) exit(1) - @SpinningCursorAnimation() + + # @SpinningCursorAnimation() def _load_model(self): from faster_whisper import WhisperModel @@ -90,6 +91,87 @@ def run(self): print(EXIT_REMINDER) except Exception as e: logging.error(f"Error during text generation: {e}", exc_info=True) + + def transcribe(self, audio, **kwargs): + """ + Transcribe the audio file. + + Arguments: + audio: Path to the input file (or a file-like object), or the audio waveform. + language: The language spoken in the audio. It should be a language code such + as "en" or "fr". If not set, the language will be detected in the first 30 seconds + of audio. + task: Task to execute (transcribe or translate). + beam_size: Beam size to use for decoding. + best_of: Number of candidates when sampling with non-zero temperature. + patience: Beam search patience factor. + length_penalty: Exponential length penalty constant. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). + temperature: Temperature for sampling. It can be a tuple of temperatures, + which will be successively used upon failures according to either + `compression_ratio_threshold` or `log_prob_threshold`. + compression_ratio_threshold: If the gzip compression ratio is above this value, + treat as failed. + log_prob_threshold: If the average log probability over sampled tokens is + below this value, treat as failed. + no_speech_threshold: If the no_speech probability is higher than this value AND + the average log probability over sampled tokens is below `log_prob_threshold`, + consider the segment as silent. + condition_on_previous_text: If True, the previous output of the model is provided + as a prompt for the next window; disabling may make the text inconsistent across + windows, but the model becomes less prone to getting stuck in a failure loop, + such as repetition looping or timestamps going out of sync. + prompt_reset_on_temperature: Resets prompt if temperature is above this value. + Arg has effect only if condition_on_previous_text is True. + initial_prompt: Optional text string or iterable of token ids to provide as a + prompt for the first window. + prefix: Optional text to provide as a prefix for the first window. + suppress_blank: Suppress blank outputs at the beginning of the sampling. + suppress_tokens: List of token IDs to suppress. -1 will suppress a default set + of symbols as defined in the model config.json file. + without_timestamps: Only sample text tokens. + max_initial_timestamp: The initial timestamp cannot be later than this. 
+ word_timestamps: Extract word-level timestamps using the cross-attention pattern + and dynamic time warping, and include the timestamps for each word in each segment. + prepend_punctuations: If word_timestamps is True, merge these punctuation symbols + with the next word + append_punctuations: If word_timestamps is True, merge these punctuation symbols + with the previous word + vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model + https://github.com/snakers4/silero-vad. + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). + max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set, + the maximum will be set by the default max_length. + chunk_length: The length of audio segments. If it is not None, it will overwrite the + default chunk_length of the FeatureExtractor. + clip_timestamps: + Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to + process. The last end timestamp defaults to the end of the file. + vad_filter will be ignored if clip_timestamps is used. + hallucination_silence_threshold: + When word_timestamps is True, skip silent periods longer than this threshold + (in seconds) when a possible hallucination is detected + hotwords: + Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None. + language_detection_threshold: If the maximum probability of the language tokens is higher + than this value, the language is detected. + language_detection_segments: Number of segments to consider for the language detection. + + Returns: + A tuple with: + + - a generator over transcribed segments + - an instance of TranscriptionInfo + """ + return self.model.transcribe( + audio, + **kwargs, + ) + def _transcribe_audio(self, audio_path): logging.debug(f"Transcribing audio from: {audio_path}") diff --git a/nexa/onnx/nexa_inference_image.py b/nexa/onnx/nexa_inference_image.py index b2dc3afb..87d1f959 100644 --- a/nexa/onnx/nexa_inference_image.py +++ b/nexa/onnx/nexa_inference_image.py @@ -104,15 +104,23 @@ def _dialogue_mode(self): negative_prompt = nexa_prompt( "Enter your negative prompt (press Enter to skip): " ) - self._generate_images(prompt, negative_prompt) + images = self.generate_images(prompt, negative_prompt) + self._save_images(images) except KeyboardInterrupt: print(EXIT_REMINDER) except Exception as e: logging.error(f"Error during text generation: {e}", exc_info=True) - def _generate_images(self, prompt, negative_prompt): + def generate_images(self, prompt, negative_prompt): """ - Generate images based on the given prompt, negative prompt, and parameters. + Used for SDK. Generate images based on the given prompt, negative prompt, and parameters. + + Arg: + prompt (str): Prompt for the image generation. + negative_prompt (str): Negative prompt for the image generation. + + Returns: + list: List of generated images. """ if self.pipeline is None: logging.error("Model not loaded. 
Exiting.") @@ -120,28 +128,26 @@ def _generate_images(self, prompt, negative_prompt): generator = np.random.RandomState(self.params["random_seed"]) - try: - is_lcm_pipeline = isinstance( - self.pipeline, ORTLatentConsistencyModelPipeline - ) + is_lcm_pipeline = isinstance( + self.pipeline, ORTLatentConsistencyModelPipeline + ) - pipeline_kwargs = { - "prompt": prompt, - "num_inference_steps": self.params["num_inference_steps"], - "num_images_per_prompt": self.params["num_images_per_prompt"], - "height": self.params["height"], - "width": self.params["width"], - "generator": generator, - "guidance_scale": self.params["guidance_scale"], - } - if not is_lcm_pipeline and negative_prompt: - pipeline_kwargs["negative_prompt"] = negative_prompt - - images = self.pipeline(**pipeline_kwargs).images - - self._save_images(images) - except Exception as e: - logging.error(f"Error during image generation: {e}") + pipeline_kwargs = { + "prompt": prompt, + "num_inference_steps": self.params["num_inference_steps"], + "num_images_per_prompt": self.params["num_images_per_prompt"], + "height": self.params["height"], + "width": self.params["width"], + "generator": generator, + "guidance_scale": self.params["guidance_scale"], + } + if not is_lcm_pipeline and negative_prompt: + pipeline_kwargs["negative_prompt"] = negative_prompt + + images = self.pipeline(**pipeline_kwargs).images + return images + + def _save_images(self, images): """ diff --git a/nexa/onnx/nexa_inference_text.py b/nexa/onnx/nexa_inference_text.py index f2f94a3c..56f5c09b 100644 --- a/nexa/onnx/nexa_inference_text.py +++ b/nexa/onnx/nexa_inference_text.py @@ -20,18 +20,18 @@ class NexaTextInference: A class used for load text models and run text generation. Methods: - run: Run the text generation loop. - run_streamlit: Run the Streamlit UI. + run: Run the text generation loop. + run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - profiling (bool): Enable timing measurements for the generation process. - streamlit (bool): Run the inference in Streamlit UI. - temperature (float): Temperature for sampling. - min_new_tokens (int): Minimum number of new tokens to generate. - max_new_tokens (int): Maximum number of new tokens to generate. - top_k (int): Top-k sampling parameter. - top_p (float): Top-p sampling parameter + model_path (str): Path or identifier for the model in Nexa Model Hub. + profiling (bool): Enable timing measurements for the generation process. + streamlit (bool): Run the inference in Streamlit UI. + temperature (float): Temperature for sampling. + min_new_tokens (int): Minimum number of new tokens to generate. + max_new_tokens (int): Maximum number of new tokens to generate. + top_k (int): Top-k sampling parameter. + top_p (float): Top-p sampling parameter """ def __init__(self, model_path, **kwargs): diff --git a/nexa/onnx/nexa_inference_tts.py b/nexa/onnx/nexa_inference_tts.py index e7167ee6..ff7093d6 100644 --- a/nexa/onnx/nexa_inference_tts.py +++ b/nexa/onnx/nexa_inference_tts.py @@ -23,14 +23,14 @@ class NexaTTSInference: A class used for loading text-to-speech models and running text-to-speech generation. Methods: - run: Run the text-to-speech generation loop. - run_streamlit: Run the Streamlit UI. + run: Run the text-to-speech generation loop. + run_streamlit: Run the Streamlit UI. Args: - model_path (str): Path or identifier for the model in Nexa Model Hub. - output_dir (str): Output directory for tts. 
- sampling_rate (int): Sampling rate for audio processing. - streamlit (bool): Run the inference in Streamlit UI. + model_path (str): Path or identifier for the model in Nexa Model Hub. + output_dir (str): Output directory for tts. + sampling_rate (int): Sampling rate for audio processing. + streamlit (bool): Run the inference in Streamlit UI. """ def __init__(self, model_path, **kwargs): @@ -71,19 +71,30 @@ def run(self): while True: try: user_input = nexa_prompt("Enter text to generate audio: ") - self._audio_generation(user_input) + outputs = self.audio_generation(user_input) + self._save_audio( + outputs[0], self.params["sampling_rate"], self.params["output_path"] + ) + logging.info(f"Audio saved to {self.params['output_path']}") except KeyboardInterrupt: print(EXIT_REMINDER) except Exception as e: logging.error(f"Error during text generation: {e}", exc_info=True) - def _audio_generation(self, user_input): + def audio_generation(self, user_input): + """ + Used for SDK. Generate audio from the user input. + + Args: + user_input (str): User input for audio generation. + + Returns: + np.array: Audio data. + """ inputs = self.tokenizer(user_input) outputs = self.model.run(None, {"text": inputs}) - self._save_audio( - outputs[0], self.params["sampling_rate"], self.params["output_path"] - ) - logging.info(f"Audio saved to {self.params['output_path']}") + return outputs + def _save_audio(self, audio_data, sampling_rate, output_path): os.makedirs(output_path, exist_ok=True) diff --git a/tests/test_image_generation.py b/tests/test_image_generation.py index 6c9d5b21..7e749dc6 100644 --- a/tests/test_image_generation.py +++ b/tests/test_image_generation.py @@ -1,47 +1,34 @@ -import os -from nexa.gguf.sd import stable_diffusion -from tests.utils import download_model +from nexa.gguf import NexaImageInference from tempfile import TemporaryDirectory +from .utils import download_model -# Constants -STABLE_DIFFUSION_URL = "https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF/resolve/main/stable-diffusion-v1-4-Q4_0.gguf" -IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" -OUTPUT_DIR = os.getcwd() -MODEL_PATH = download_model(STABLE_DIFFUSION_URL, OUTPUT_DIR) +sd = NexaImageInference( + model_path="sd1-4", + wtype="q4_0", +) -# Print the model path -print("Model downloaded to:", MODEL_PATH) - -# Helper function for Stable Diffusion initialization -def init_stable_diffusion(): - return stable_diffusion.StableDiffusion( - model_path=MODEL_PATH, - wtype="q4_0" # Weight type (options: default, f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) - ) - # Test text-to-image generation def test_txt_to_img(): - sd = init_stable_diffusion() - output = sd.txt_to_img("a lovely cat", width=128, height=128, sample_steps=2) + global sd + output = sd.txt2img("a lovely cat", width=128, height=128, sample_steps=2) output[0].save("output_txt_to_img.png") # Test image-to-image generation def test_img_to_img(): - sd = init_stable_diffusion() + global sd img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" with TemporaryDirectory() as temp_dir: img_path = download_model(img_url, temp_dir) - output = sd.img_to_img( - image=img_path, + output = sd.img2img( + image_path=img_path, prompt="blue sky", width=128, height=128, negative_prompt="black soil", sample_steps=2 ) - 
output[0].save("output_img_to_img.png") # Main execution # if __name__ == "__main__": diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index f37a4781..e3ceed30 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -1,36 +1,28 @@ -import os -from nexa.gguf.llama import llama -from tests.utils import download_model +from nexa.gguf import NexaTextInference +from nexa.gguf.lib_utils import is_gpu_available -# Constants -TINY_LLAMA_URL = "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q4_0.gguf" -OUTPUT_DIR = os.getcwd() -MODEL_PATH = download_model(TINY_LLAMA_URL, OUTPUT_DIR) - -# Initialize Llama model -def init_llama_model(verbose=False, n_gpu_layers=-1, chat_format=None, embedding=False): - return llama.Llama( - model_path=MODEL_PATH, - verbose=verbose, - n_gpu_layers=n_gpu_layers, - chat_format=chat_format, - embedding=embedding, - ) +model = NexaTextInference( + model_path="gemma", + verbose=False, + n_gpu_layers=-1 if is_gpu_available() else 0, + chat_format="llama-2", +) # Test text generation from a prompt def test_text_generation(): - model = init_llama_model() - output = model( + global model + output = model.create_completion( "Q: Name the planets in the solar system? A: ", max_tokens=512, stop=["Q:", "\n"], echo=True, ) - print(output) + # print(output) + # TODO: add assertions here # Test chat completion in streaming mode def test_streaming(): - model = init_llama_model() + global model output = model.create_completion( "Q: Name the planets in the solar system? A: ", max_tokens=512, @@ -40,10 +32,12 @@ def test_streaming(): for chunk in output: if "choices" in chunk: print(chunk["choices"][0]["text"], end="", flush=True) + # TODO: add assertions here # Test conversation mode with chat format def test_create_chat_completion(): - model = init_llama_model(chat_format="llama-2") + global model + output = model.create_chat_completion( messages=[ {"role": "user", "content": "write a long 1000 word story about a detective"} @@ -58,7 +52,13 @@ def test_create_chat_completion(): print(delta["content"], end="", flush=True) def test_create_embedding(): - model = init_llama_model(embedding=True) + model = NexaTextInference( + model_path="gemma", + verbose=False, + n_gpu_layers=-1 if is_gpu_available() else 0, + chat_format="llama-2", + embedding=True, + ) embeddings = model.create_embedding("Hello, world!") print("Embeddings:\n", embeddings) diff --git a/tests/test_vlm.py b/tests/test_vlm.py deleted file mode 100644 index b70389be..00000000 --- a/tests/test_vlm.py +++ /dev/null @@ -1,68 +0,0 @@ -import base64 -import os - -from nexa.gguf.llama import llama -from nexa.gguf.llama.llama_chat_format import NanoLlavaChatHandler -from tests.utils import download_model - -def image_to_base64_data_uri(file_path): - """ - file_path = 'file_path.png' - data_uri = image_to_base64_data_uri(file_path) - """ - with open(file_path, "rb") as img_file: - base64_data = base64.b64encode(img_file.read()).decode("utf-8") - return f"data:image/png;base64,{base64_data}" - -model_url = "https://nexa-model-hub-bucket.s3.us-west-1.amazonaws.com/public/nanoLLaVA/model-fp16.gguf" -mmproj_url = "https://nexa-model-hub-bucket.s3.us-west-1.amazonaws.com/public/nanoLLaVA/projector-fp16.gguf" - -# Download paths -output_dir = os.getcwd() -model_path = download_model(model_url, output_dir) -mmproj_path = download_model(mmproj_url, output_dir) -print("Model downloaded to:", model_path) -print("MMProj downloaded to:", 
mmproj_path) - -chat_handler = NanoLlavaChatHandler(clip_model_path=mmproj_path) - -def test_image_generation(): - llm = llama.Llama( - model_path=model_path, - chat_handler=chat_handler, - n_ctx=2048, # n_ctx should be increased to accommodate the image embedding - n_gpu_layers=-1, # Uncomment to use GPU acceleration - verbose=False, - ) - output = llm.create_chat_completion( - messages=[ - { - "role": "system", - "content": "You are an assistant who perfectly describes images.", - }, - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" - }, - }, - ], - }, - ], - stream=True, - ) - for chunk in output: - delta = chunk["choices"][0]["delta"] - if "role" in delta: - print(delta["role"], end=": ") - elif "content" in delta: - print(delta["content"], end="") - - -# if __name__ == "__main__": -# print("=== Testing 1 ===") -# test1()
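# A minimal usage sketch of the SDK-facing methods this patch adds, following the
# patterns in tests/test_image_generation.py and tests/test_text_generation.py above.
# The model identifiers ("sd1-4", "gemma") and the generation parameters mirror those
# tests; swap in your own Nexa Model Hub identifiers as needed.
from nexa.gguf import NexaImageInference, NexaTextInference
from nexa.gguf.lib_utils import is_gpu_available

# Text-to-image through the new NexaImageInference.txt2img() wrapper.
sd = NexaImageInference(model_path="sd1-4", wtype="q4_0")
images = sd.txt2img("a lovely cat", width=128, height=128, sample_steps=2)
images[0].save("output_txt_to_img.png")

# Streaming chat completion through the new NexaTextInference.create_chat_completion() wrapper.
llm = NexaTextInference(
    model_path="gemma",
    verbose=False,
    n_gpu_layers=-1 if is_gpu_available() else 0,
    chat_format="llama-2",
)
stream = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Name the planets in the solar system."}],
    max_tokens=256,
    stream=True,
)
for chunk in stream:
    delta = chunk["choices"][0]["delta"]
    if "content" in delta:
        print(delta["content"], end="", flush=True)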
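# A second sketch covering the embedding and transcription wrappers added in this
# patch. The embedding setup mirrors tests/test_text_generation.py; the Whisper
# model identifier and the audio path below are placeholders, not values taken
# from the patch.
from nexa.gguf import NexaTextInference
from nexa.gguf.nexa_inference_voice import NexaVoiceInference
from nexa.gguf.lib_utils import is_gpu_available

# Embeddings through the new NexaTextInference.create_embedding() wrapper
# (the model must be constructed with embedding=True).
embedder = NexaTextInference(
    model_path="gemma",
    verbose=False,
    n_gpu_layers=-1 if is_gpu_available() else 0,
    embedding=True,
)
embeddings = embedder.create_embedding("Hello, world!")

# Transcription through the new NexaVoiceInference.transcribe() wrapper, which
# forwards keyword arguments to faster_whisper's WhisperModel.transcribe() and
# returns a (segments generator, TranscriptionInfo) tuple.
voice = NexaVoiceInference(model_path="faster-whisper-tiny")  # placeholder model id
segments, info = voice.transcribe("sample.wav", beam_size=5, language="en")
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")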