diff --git a/README.md b/README.md
index 60e21792..73f517e6 100644
--- a/README.md
+++ b/README.md
@@ -340,4 +340,5 @@ We would like to thank the following projects:
 - [llama.cpp](https://github.com/ggerganov/llama.cpp)
 - [stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp)
+- [bark.cpp](https://github.com/PABannier/bark.cpp)
 - [optimum](https://github.com/huggingface/optimum)
diff --git a/nexa/onnx/nexa_inference_image.py b/nexa/onnx/nexa_inference_image.py
index 8392aa31..b38be335 100644
--- a/nexa/onnx/nexa_inference_image.py
+++ b/nexa/onnx/nexa_inference_image.py
@@ -64,10 +64,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
         self.params.update(kwargs)
         self.pipeline = None
 
-    def run(self):
-        if self.download_onnx_folder is None:
-            self.download_onnx_folder, run_type = pull_model(self.model_path, **kwargs)
+        self.download_onnx_folder, _ = pull_model(self.model_path, **kwargs)
 
         if self.download_onnx_folder is None:
             logging.error(
@@ -76,17 +74,19 @@ def run(self):
             )
             exit(1)
 
-        self._load_model(self.download_onnx_folder)
+        self._load_model()
+
+    def run(self):
         self._dialogue_mode()
 
     @SpinningCursorAnimation()
-    def _load_model(self, model_path):
+    def _load_model(self):
         """
         Load the model from the given model path using the appropriate pipeline.
         """
-        logging.debug(f"Loading model from {model_path}")
+        logging.debug(f"Loading model from {self.download_onnx_folder}")
         try:
-            model_index_path = os.path.join(model_path, "model_index.json")
+            model_index_path = os.path.join(self.download_onnx_folder, "model_index.json")
             with open(model_index_path, "r") as index_file:
                 model_index = json.load(index_file)
@@ -96,7 +96,7 @@ def _load_model(self, model_path):
             PipelineClass = ORT_PIPELINES_MAPPING.get(
                 pipeline_class_name, ORTStableDiffusionPipeline
             )
-            self.pipeline = PipelineClass.from_pretrained(model_path)
+            self.pipeline = PipelineClass.from_pretrained(self.download_onnx_folder)
             logging.debug(f"Model loaded successfully using {pipeline_class_name}")
         except Exception as e:
             logging.error(f"Error loading model: {e}")
diff --git a/nexa/onnx/nexa_inference_text.py b/nexa/onnx/nexa_inference_text.py
index fdb6db5f..f9a767e9 100644
--- a/nexa/onnx/nexa_inference_text.py
+++ b/nexa/onnx/nexa_inference_text.py
@@ -53,9 +53,21 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
         self.downloaded_onnx_folder = local_path
         self.timings = kwargs.get("timings", False)
         self.device = "cpu"
+
+        if self.downloaded_onnx_folder is None:
+            self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs)
+
+        if self.downloaded_onnx_folder is None:
+            logging.error(
+                f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.",
+                exc_info=True,
+            )
+            exit(1)
+
+        self._load_model_and_tokenizer()
 
     @SpinningCursorAnimation()
-    def _load_model_and_tokenizer(self) -> Tuple[Any, Any, Any, bool]:
+    def _load_model_and_tokenizer(self):
         logging.debug(f"Loading model from {self.downloaded_onnx_folder}")
         start_time = time.time()
         self.tokenizer = AutoTokenizer.from_pretrained(self.downloaded_onnx_folder)
@@ -148,18 +160,6 @@ def run(self):
         if self.params.get("streamlit"):
             self.run_streamlit()
         else:
-            if self.downloaded_onnx_folder is None:
-                self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)
-
-            if self.downloaded_onnx_folder is None:
-                logging.error(
-                    f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.",
-                    exc_info=True,
-                )
-                exit(1)
-
-            self._load_model_and_tokenizer()
-
             if self.model is None or self.tokenizer is None or self.streamer is None:
                 logging.error(
                     "Failed to load model or tokenizer. Exiting.", exc_info=True
diff --git a/nexa/onnx/nexa_inference_tts.py b/nexa/onnx/nexa_inference_tts.py
index fb1f2f9a..26c6d3e4 100644
--- a/nexa/onnx/nexa_inference_tts.py
+++ b/nexa/onnx/nexa_inference_tts.py
@@ -50,8 +50,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
         self.downloaded_onnx_folder = local_path
 
         if self.downloaded_onnx_folder is None:
-            self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)
-
+            self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs)
+
         if self.downloaded_onnx_folder is None:
             logging.error(
                 f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.",
@@ -69,12 +69,10 @@ def _load_model(self):
         logging.debug(f"Loading model from {self.downloaded_onnx_folder}")
         try:
             self.tokenizer = TTSTokenizer(self.config["token"]["list"])
-            print(self.tokenizer)
             self.model = onnxruntime.InferenceSession(
                 os.path.join(self.downloaded_onnx_folder, "model.onnx"),
                 providers=["CPUExecutionProvider"],
             )
-            print(self.model)
             logging.debug("Model and tokenizer loaded successfully")
         except Exception as e:
             logging.error(f"Error loading model or tokenizer: {e}")
diff --git a/nexa/onnx/nexa_inference_voice.py b/nexa/onnx/nexa_inference_voice.py
index e6d7d696..c0f56ab4 100644
--- a/nexa/onnx/nexa_inference_voice.py
+++ b/nexa/onnx/nexa_inference_voice.py
@@ -43,9 +43,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
         self.model = None
         self.processor = None
 
-    def run(self):
         if self.downloaded_onnx_folder is None:
-            self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)
+            self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs)
 
         if self.downloaded_onnx_folder is None:
             logging.error(
@@ -54,14 +53,16 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
             )
             exit(1)
 
-        self._load_model(self.downloaded_onnx_folder)
+        self._load_model()
+
+    def run(self):
         self._dialogue_mode()
 
-    def _load_model(self, model_path):
-        logging.debug(f"Loading model from {model_path}")
+    def _load_model(self):
+        logging.debug(f"Loading model from {self.downloaded_onnx_folder}")
         try:
-            self.processor = AutoProcessor.from_pretrained(model_path)
-            self.model = ORTModelForSpeechSeq2Seq.from_pretrained(model_path)
+            self.processor = AutoProcessor.from_pretrained(self.downloaded_onnx_folder)
+            self.model = ORTModelForSpeechSeq2Seq.from_pretrained(self.downloaded_onnx_folder)
             logging.debug("Model and processor loaded successfully")
         except Exception as e:
             logging.error(f"Error loading model or processor: {e}")