Skip to content

Commit

Permalink
add warning in case model is present in available models
Browse files Browse the repository at this point in the history
  • Loading branch information
raspawar committed Oct 4, 2024
1 parent 8af9d32 commit 81584b6
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,13 @@ def _validate_model(self, model_name: str) -> None:
"""
if self._is_hosted:
if model_name not in MODEL_ENDPOINT_MAP:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
)
if model_name in [model.id for model in self._client.models.list()]:
warnings.warn(f"Unable to determine validity of {model_name}")
else:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
)
else:
if model_name not in [model.id for model in self.available_models]:
raise ValueError(f"No locally hosted {model_name} was found.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,12 @@ def test_model_incompatible_client_model() -> None:
with pytest.raises(ValueError) as msg:
NVIDIAEmbedding(api_key="BOGUS", model=model_name)
assert err_msg == str(msg.value)


def test_model_incompatible_client_known_model() -> None:
model_name = "google/deplot"
warn_msg = f"Unable to determine validity"
with pytest.warns(UserWarning) as msg:
NVIDIAEmbedding(api_key="BOGUS", model=model_name)
assert len(msg) == 1
assert warn_msg in str(msg[0].message)
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,13 @@ def _validate_model(self, model_name: str) -> None:
"""
if self._is_hosted:
if model_name not in ALL_MODELS:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
)
if model_name in [model.id for model in self.available_models]:
warnings.warn(f"Unable to determine validity of {model_name}")
else:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
)
else:
if model_name not in [model.id for model in self.available_models]:
raise ValueError(f"No locally hosted {model_name} was found.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,12 @@ def test_model_incompatible_client_model() -> None:
with pytest.raises(ValueError) as msg:
NVIDIA(model=model_name)
assert err_msg == str(msg.value)


def test_model_incompatible_client_known_model() -> None:
model_name = "google/deplot"
warn_msg = f"Unable to determine validity"
with pytest.warns(UserWarning) as msg:
NVIDIA(api_key="BOGUS", model=model_name)
assert len(msg) == 1
assert warn_msg in str(msg[0].message)
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,13 @@ def _validate_model(self, model_name: str) -> None:
"""
if self._is_hosted:
if model_name not in MODEL_ENDPOINT_MAP:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
)
if model_name in [model.id for model in self._get_models()]:
warnings.warn(f"Unable to determine validity of {model_name}")
else:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
)
else:
if model_name not in [model.id for model in self.available_models]:
raise ValueError(f"No locally hosted {model_name} was found.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,12 @@ def test_model_incompatible_client() -> None:
with pytest.raises(ValueError) as msg:
NVIDIARerank(api_key="BOGUS", model=model_name)
assert err_msg == str(msg.value)


def test_model_incompatible_client_known_model() -> None:
model_name = "google/deplot"
warn_msg = f"Unable to determine validity"
with pytest.warns(UserWarning) as msg:
NVIDIARerank(api_key="BOGUS", model=model_name)
assert len(msg) == 1
assert warn_msg in str(msg[0].message)

0 comments on commit 81584b6

Please sign in to comment.