diff --git a/nexa/gguf/nexa_inference_vlm_omni.py b/nexa/gguf/nexa_inference_vlm_omni.py index 83876569..021f79e7 100644 --- a/nexa/gguf/nexa_inference_vlm_omni.py +++ b/nexa/gguf/nexa_inference_vlm_omni.py @@ -32,9 +32,6 @@ def __init__( self.model = None self.projector = None - self.projector_path = NEXA_RUN_OMNI_VLM_PROJECTOR_MAP.get(model_path, None) - self.downloaded_path = local_path - self.projector_downloaded_path = projector_local_path self.device = device self.context = None self.omni_vlm_version = "vlm-81-instruct" @@ -43,39 +40,57 @@ def __init__( else: self.n_gpu_layers = 0 - if self.downloaded_path is not None and self.projector_downloaded_path is not None: - # when running from local, both path should be provided - self.omni_vlm_version = self._determine_vlm_version(str(self.downloaded_path)) - elif self.downloaded_path is not None: - if model_path in NEXA_RUN_OMNI_VLM_MAP: + # Handle direct model file paths (e.g., omnivision:model-fp16) + if model_path and ':model-' in model_path: + base_name = model_path.split(':')[0] + model_type = model_path.split('model-')[1] + if base_name in NEXA_RUN_OMNI_VLM_PROJECTOR_MAP: + self.model_path = model_path + # Construct corresponding projector path + self.projector_path = f"{base_name}:projector-{model_type}" + self.downloaded_path, _ = pull_model(self.model_path, **kwargs) + self.projector_downloaded_path, _ = pull_model(self.projector_path, **kwargs) + self.omni_vlm_version = self._determine_vlm_version(model_path) + + else: + # Handle other path formats and model loading scenarios + self.projector_path = NEXA_RUN_OMNI_VLM_PROJECTOR_MAP.get(model_path, None) + self.downloaded_path = local_path + self.projector_downloaded_path = projector_local_path + + if self.downloaded_path is not None and self.projector_downloaded_path is not None: + # when running from local, both path should be provided + self.omni_vlm_version = self._determine_vlm_version(str(self.downloaded_path)) + elif self.downloaded_path is not None: + if model_path in NEXA_RUN_OMNI_VLM_MAP: + self.projector_path = NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[model_path] + self.projector_downloaded_path, _ = pull_model(self.projector_path, **kwargs) + self.omni_vlm_version = self._determine_vlm_version(model_path) + elif model_path in NEXA_RUN_OMNI_VLM_MAP: + self.model_path = NEXA_RUN_OMNI_VLM_MAP[model_path] self.projector_path = NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[model_path] + self.downloaded_path, _ = pull_model(self.model_path, **kwargs) self.projector_downloaded_path, _ = pull_model(self.projector_path, **kwargs) self.omni_vlm_version = self._determine_vlm_version(model_path) - elif model_path in NEXA_RUN_OMNI_VLM_MAP: - self.model_path = NEXA_RUN_OMNI_VLM_MAP[model_path] - self.projector_path = NEXA_RUN_OMNI_VLM_PROJECTOR_MAP[model_path] - self.downloaded_path, _ = pull_model(self.model_path, **kwargs) - self.projector_downloaded_path, _ = pull_model(self.projector_path, **kwargs) - self.omni_vlm_version = self._determine_vlm_version(model_path) - elif Path(model_path).parent.exists(): - local_dir = Path(model_path).parent - model_name = Path(model_path).name - tag_and_ext = model_name.split(":")[-1] - self.downloaded_path = local_dir / f"model-{tag_and_ext}" - self.projector_downloaded_path = local_dir / f"projector-{tag_and_ext}" - if not ( - self.downloaded_path.exists() - and self.projector_downloaded_path.exists() - ): - logging.error( - f"Model or projector not found in {local_dir}. " - "Make sure to name them as 'model-.gguf' and 'projector-.gguf'." - ) + elif Path(model_path).parent.exists(): + local_dir = Path(model_path).parent + model_name = Path(model_path).name + tag_and_ext = model_name.split(":")[-1] + self.downloaded_path = local_dir / f"model-{tag_and_ext}" + self.projector_downloaded_path = local_dir / f"projector-{tag_and_ext}" + if not ( + self.downloaded_path.exists() + and self.projector_downloaded_path.exists() + ): + logging.error( + f"Model or projector not found in {local_dir}. " + "Make sure to name them as 'model-.gguf' and 'projector-.gguf'." + ) + exit(1) + self.omni_vlm_version = self._determine_vlm_version(model_path) + else: + logging.error("VLM user model from hub is not supported yet.") exit(1) - self.omni_vlm_version = self._determine_vlm_version(model_path) - else: - logging.error("VLM user model from hub is not supported yet.") - exit(1) # Override version if specified in kwargs if 'omni_vlm_version' in kwargs: