Skip to content

Commit

Permalink
Merge pull request #213 from NexaAI/david/bugfix
Browse files Browse the repository at this point in the history
Fix ONNX Python interface model-loading bug
  • Loading branch information
zhiyuan8 authored Nov 5, 2024
2 parents eb7dd2c + 165cbdd commit c60d918
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 32 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -340,4 +340,5 @@ We would like to thank the following projects:

- [llama.cpp](https://github.com/ggerganov/llama.cpp)
- [stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp)
- [bark.cpp](https://github.com/PABannier/bark.cpp)
- [optimum](https://github.com/huggingface/optimum)
16 changes: 8 additions & 8 deletions nexa/onnx/nexa_inference_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
self.params.update(kwargs)
self.pipeline = None

def run(self):

if self.download_onnx_folder is None:
self.download_onnx_folder, run_type = pull_model(self.model_path, **kwargs)
self.download_onnx_folder, _ = pull_model(self.model_path, **kwargs)

if self.download_onnx_folder is None:
logging.error(
Expand All @@ -76,17 +74,19 @@ def run(self):
)
exit(1)

self._load_model(self.download_onnx_folder)
self._load_model()

    def run(self):
        """Entry point: start the interactive dialogue loop.

        Model loading now happens in ``__init__`` (per this commit), so
        ``run`` only delegates to the dialogue-mode handler.
        """
        self._dialogue_mode()

@SpinningCursorAnimation()
def _load_model(self, model_path):
def _load_model(self):
"""
Load the model from the given model path using the appropriate pipeline.
"""
logging.debug(f"Loading model from {model_path}")
logging.debug(f"Loading model from {self.download_onnx_folder}")
try:
model_index_path = os.path.join(model_path, "model_index.json")
model_index_path = os.path.join(self.download_onnx_folder, "model_index.json")
with open(model_index_path, "r") as index_file:
model_index = json.load(index_file)

Expand All @@ -96,7 +96,7 @@ def _load_model(self, model_path):
PipelineClass = ORT_PIPELINES_MAPPING.get(
pipeline_class_name, ORTStableDiffusionPipeline
)
self.pipeline = PipelineClass.from_pretrained(model_path)
self.pipeline = PipelineClass.from_pretrained(self.download_onnx_folder)
logging.debug(f"Model loaded successfully using {pipeline_class_name}")
except Exception as e:
logging.error(f"Error loading model: {e}")
Expand Down
26 changes: 13 additions & 13 deletions nexa/onnx/nexa_inference_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,21 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
self.downloaded_onnx_folder = local_path
self.timings = kwargs.get("timings", False)
self.device = "cpu"

if self.downloaded_onnx_folder is None:
self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs)

if self.downloaded_onnx_folder is None:
logging.error(
f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.",
exc_info=True,
)
exit(1)

self._load_model_and_tokenizer()

@SpinningCursorAnimation()
def _load_model_and_tokenizer(self) -> Tuple[Any, Any, Any, bool]:
def _load_model_and_tokenizer(self):
logging.debug(f"Loading model from {self.downloaded_onnx_folder}")
start_time = time.time()
self.tokenizer = AutoTokenizer.from_pretrained(self.downloaded_onnx_folder)
Expand Down Expand Up @@ -148,18 +160,6 @@ def run(self):
if self.params.get("streamlit"):
self.run_streamlit()
else:
if self.downloaded_onnx_folder is None:
self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)

if self.downloaded_onnx_folder is None:
logging.error(
f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.",
exc_info=True,
)
exit(1)

self._load_model_and_tokenizer()

if self.model is None or self.tokenizer is None or self.streamer is None:
logging.error(
"Failed to load model or tokenizer. Exiting.", exc_info=True
Expand Down
6 changes: 2 additions & 4 deletions nexa/onnx/nexa_inference_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
self.downloaded_onnx_folder = local_path

if self.downloaded_onnx_folder is None:
self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)
self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs)

if self.downloaded_onnx_folder is None:
logging.error(
f"Model ({model_path}) is not applicable. Please refer to our docs for proper usage.",
Expand All @@ -69,12 +69,10 @@ def _load_model(self):
logging.debug(f"Loading model from {self.downloaded_onnx_folder}")
try:
self.tokenizer = TTSTokenizer(self.config["token"]["list"])
print(self.tokenizer)
self.model = onnxruntime.InferenceSession(
os.path.join(self.downloaded_onnx_folder, "model.onnx"),
providers=["CPUExecutionProvider"],
)
print(self.model)
logging.debug("Model and tokenizer loaded successfully")
except Exception as e:
logging.error(f"Error loading model or tokenizer: {e}")
Expand Down
15 changes: 8 additions & 7 deletions nexa/onnx/nexa_inference_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ def __init__(self, model_path=None, local_path=None, **kwargs):
self.model = None
self.processor = None

def run(self):
if self.downloaded_onnx_folder is None:
self.downloaded_onnx_folder, run_type = pull_model(self.model_path, **kwargs)
self.downloaded_onnx_folder, _ = pull_model(self.model_path, **kwargs)

if self.downloaded_onnx_folder is None:
logging.error(
Expand All @@ -54,14 +53,16 @@ def run(self):
)
exit(1)

self._load_model(self.downloaded_onnx_folder)
self._load_model()

    def run(self):
        """Entry point: start the interactive dialogue loop.

        Model loading now happens in ``__init__`` (per this commit), so
        ``run`` only delegates to the dialogue-mode handler.
        """
        self._dialogue_mode()

def _load_model(self, model_path):
logging.debug(f"Loading model from {model_path}")
def _load_model(self):
logging.debug(f"Loading model from {self.downloaded_onnx_folder}")
try:
self.processor = AutoProcessor.from_pretrained(model_path)
self.model = ORTModelForSpeechSeq2Seq.from_pretrained(model_path)
self.processor = AutoProcessor.from_pretrained(self.downloaded_onnx_folder)
self.model = ORTModelForSpeechSeq2Seq.from_pretrained(self.downloaded_onnx_folder)
logging.debug("Model and processor loaded successfully")
except Exception as e:
logging.error(f"Error loading model or processor: {e}")
Expand Down

0 comments on commit c60d918

Please sign in to comment.