update onnx interface
xyyimian committed Aug 21, 2024
1 parent dd49f40 commit a0b2358
Showing 4 changed files with 64 additions and 47 deletions.
nexa/gguf/nexa_inference_text.py (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@ class NexaTextInference:
         top_k (int): Top-k sampling parameter.
         top_p (float): Top-p sampling parameter
     """
-    from nexa.gguf.llama import Llama
+    from nexa.gguf.llama.llama import Llama
     def __init__(self, model_path, stop_words=None, **kwargs):
         self.params = DEFAULT_TEXT_GEN_PARAMS
         self.params.update(kwargs)
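This one-line fix points the class-level import at the llama.py module inside the nexa.gguf.llama package rather than at the package itself. A minimal sketch of the distinction, assuming the package layout implied by the diff (the constructor call is purely illustrative):

    # Old path: resolves only if nexa/gguf/llama/__init__.py re-exports Llama
    # from nexa.gguf.llama import Llama

    # New path: import the class from the llama.py module directly
    from nexa.gguf.llama.llama import Llama

    llm = Llama(model_path="./model.gguf")  # hypothetical model file, for illustration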
nexa/onnx/nexa_inference_image.py (54 changes: 30 additions & 24 deletions)
@@ -105,44 +105,50 @@ def _dialogue_mode(self):
                 negative_prompt = nexa_prompt(
                     "Enter your negative prompt (press Enter to skip): "
                 )
-                self._generate_images(prompt, negative_prompt)
+                images = self.generate_images(prompt, negative_prompt)
+                self._save_images(images)
             except KeyboardInterrupt:
                 print(EXIT_REMINDER)
             except Exception as e:
                 logging.error(f"Error during text generation: {e}", exc_info=True)

-    def _generate_images(self, prompt, negative_prompt):
+    def generate_images(self, prompt, negative_prompt):
         """
-        Generate images based on the given prompt, negative prompt, and parameters.
+        Used for SDK. Generate images based on the given prompt, negative prompt, and parameters.
+        Arg:
+            prompt (str): Prompt for the image generation.
+            negative_prompt (str): Negative prompt for the image generation.
+        Returns:
+            list: List of generated images.
         """
         if self.pipeline is None:
             logging.error("Model not loaded. Exiting.")
             return

         generator = np.random.RandomState(self.params["random_seed"])

-        try:
-            is_lcm_pipeline = isinstance(
-                self.pipeline, ORTLatentConsistencyModelPipeline
-            )
+        is_lcm_pipeline = isinstance(
+            self.pipeline, ORTLatentConsistencyModelPipeline
+        )

-            pipeline_kwargs = {
-                "prompt": prompt,
-                "num_inference_steps": self.params["num_inference_steps"],
-                "num_images_per_prompt": self.params["num_images_per_prompt"],
-                "height": self.params["height"],
-                "width": self.params["width"],
-                "generator": generator,
-                "guidance_scale": self.params["guidance_scale"],
-            }
-            if not is_lcm_pipeline and negative_prompt:
-                pipeline_kwargs["negative_prompt"] = negative_prompt
-
-            images = self.pipeline(**pipeline_kwargs).images
-
-            self._save_images(images)
-        except Exception as e:
-            logging.error(f"Error during image generation: {e}")
+        pipeline_kwargs = {
+            "prompt": prompt,
+            "num_inference_steps": self.params["num_inference_steps"],
+            "num_images_per_prompt": self.params["num_images_per_prompt"],
+            "height": self.params["height"],
+            "width": self.params["width"],
+            "generator": generator,
+            "guidance_scale": self.params["guidance_scale"],
+        }
+        if not is_lcm_pipeline and negative_prompt:
+            pipeline_kwargs["negative_prompt"] = negative_prompt
+
+        images = self.pipeline(**pipeline_kwargs).images
+        return images

     def _save_images(self, images):
         """
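With this rename, image generation becomes part of the public SDK surface: generate_images returns the list of images instead of saving them, and callers (including _dialogue_mode above) persist the results themselves. A minimal usage sketch, assuming the class is named NexaImageInference and exported from nexa.onnx (both inferred from the file name, not stated in this diff) and an illustrative model identifier:

    from nexa.onnx import NexaImageInference  # class name and import path assumed

    inference = NexaImageInference(model_path="lcm-dreamshaper")  # hypothetical model id
    images = inference.generate_images(
        prompt="a lighthouse at dusk, oil painting",
        negative_prompt="blurry, low quality",  # skipped internally for LCM pipelines
    )
    for i, image in enumerate(images):
        image.save(f"generated_{i}.png")  # diffusers-style pipelines return PIL images

Note that dropping the internal try/except also means generation errors now propagate to the caller, which is why _dialogue_mode keeps its own error handling.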
nexa/onnx/nexa_inference_text.py (20 changes: 10 additions & 10 deletions)
@@ -20,18 +20,18 @@ class NexaTextInference:
     A class used for load text models and run text generation.
     Methods:
-    run: Run the text generation loop.
-    run_streamlit: Run the Streamlit UI.
+        run: Run the text generation loop.
+        run_streamlit: Run the Streamlit UI.
     Args:
-    model_path (str): Path or identifier for the model in Nexa Model Hub.
-    profiling (bool): Enable timing measurements for the generation process.
-    streamlit (bool): Run the inference in Streamlit UI.
-    temperature (float): Temperature for sampling.
-    min_new_tokens (int): Minimum number of new tokens to generate.
-    max_new_tokens (int): Maximum number of new tokens to generate.
-    top_k (int): Top-k sampling parameter.
-    top_p (float): Top-p sampling parameter
+        model_path (str): Path or identifier for the model in Nexa Model Hub.
+        profiling (bool): Enable timing measurements for the generation process.
+        streamlit (bool): Run the inference in Streamlit UI.
+        temperature (float): Temperature for sampling.
+        min_new_tokens (int): Minimum number of new tokens to generate.
+        max_new_tokens (int): Maximum number of new tokens to generate.
+        top_k (int): Top-k sampling parameter.
+        top_p (float): Top-p sampling parameter
     """

     def __init__(self, model_path, **kwargs):
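This hunk only re-indents the docstring so the Methods and Args entries nest under their headers; the constructor signature is unchanged. For reference, a minimal SDK sketch using the documented parameters (the import path is assumed from the file layout, and the model identifier and values are illustrative):

    from nexa.onnx import NexaTextInference  # import path assumed

    inference = NexaTextInference(
        model_path="gemma",  # hypothetical Nexa Model Hub identifier
        temperature=0.7,
        max_new_tokens=256,
        top_k=50,
        top_p=0.9,
    )
    inference.run()  # interactive text generation loop described above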
nexa/onnx/nexa_inference_tts.py (35 changes: 23 additions & 12 deletions)
@@ -23,14 +23,14 @@ class NexaTTSInference:
     A class used for loading text-to-speech models and running text-to-speech generation.
     Methods:
-    run: Run the text-to-speech generation loop.
-    run_streamlit: Run the Streamlit UI.
+        run: Run the text-to-speech generation loop.
+        run_streamlit: Run the Streamlit UI.
     Args:
-    model_path (str): Path or identifier for the model in Nexa Model Hub.
-    output_dir (str): Output directory for tts.
-    sampling_rate (int): Sampling rate for audio processing.
-    streamlit (bool): Run the inference in Streamlit UI.
+        model_path (str): Path or identifier for the model in Nexa Model Hub.
+        output_dir (str): Output directory for tts.
+        sampling_rate (int): Sampling rate for audio processing.
+        streamlit (bool): Run the inference in Streamlit UI.
     """

     def __init__(self, model_path, **kwargs):
@@ -71,19 +71,30 @@ def run(self):
         while True:
             try:
                 user_input = nexa_prompt("Enter text to generate audio: ")
-                self._audio_generation(user_input)
+                outputs = self.audio_generation(user_input)
+                self._save_audio(
+                    outputs[0], self.params["sampling_rate"], self.params["output_path"]
+                )
+                logging.info(f"Audio saved to {self.params['output_path']}")
             except KeyboardInterrupt:
                 print(EXIT_REMINDER)
             except Exception as e:
                 logging.error(f"Error during text generation: {e}", exc_info=True)

-    def _audio_generation(self, user_input):
+    def audio_generation(self, user_input):
+        """
+        Used for SDK. Generate audio from the user input.
+        Args:
+            user_input (str): User input for audio generation.
+        Returns:
+            np.array: Audio data.
+        """
         inputs = self.tokenizer(user_input)
         outputs = self.model.run(None, {"text": inputs})
-        self._save_audio(
-            outputs[0], self.params["sampling_rate"], self.params["output_path"]
-        )
-        logging.info(f"Audio saved to {self.params['output_path']}")
+        return outputs

     def _save_audio(self, audio_data, sampling_rate, output_path):
         os.makedirs(output_path, exist_ok=True)
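As with the image class, audio_generation is now a public SDK method that returns the raw ONNX Runtime outputs, while the CLI loop in run handles saving. An SDK caller would persist the waveform itself; a minimal sketch, assuming the nexa.onnx import path and using soundfile as a stand-in writer (the class's own _save_audio is the path the CLI takes):

    import soundfile as sf  # stand-in writer, not part of this commit

    from nexa.onnx import NexaTTSInference  # import path assumed

    tts = NexaTTSInference(model_path="ljspeech", sampling_rate=16000)  # hypothetical values
    outputs = tts.audio_generation("Hello from the Nexa SDK.")
    sf.write("hello.wav", outputs[0], 16000)  # outputs[0] is the waveform per the docstring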
