Skip to content

Commit

Permalink
feat: set supported_models using base_url (#450)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Aug 19, 2024
1 parent 8817e3b commit bbb0810
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
9 changes: 6 additions & 3 deletions nbs/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,10 @@
" max_retries=max_retries, retry_interval=retry_interval, max_wait_time=max_wait_time\n",
" )\n",
" self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}\n",
" if 'ai.azure' in base_url:\n",
" self.supported_models = ['azureai', 'timegpt-1-long-horizon']\n",
" else:\n",
" self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n",
"\n",
" def _make_request(self, client: httpx.Client, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:\n",
" resp = client.request(\n",
Expand Down Expand Up @@ -784,10 +788,9 @@
" ) -> Tuple[DFType, Optional[DFType], bool]:\n",
" if validate_api_key and not self.validate_api_key(log=False):\n",
" raise Exception('API Key not valid, please email [email protected]')\n",
" supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n",
" if model not in supported_models:\n",
" if model not in self.supported_models:\n",
" raise ValueError(\n",
" f'unsupported model: {model}. supported models: {supported_models}'\n",
" f'unsupported model: {model}. supported models: {self.supported_models}'\n",
" )\n",
" drop_id = id_col not in df.columns\n",
" if drop_id:\n",
Expand Down
9 changes: 6 additions & 3 deletions nixtla/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,10 @@ def __init__(
max_wait_time=max_wait_time,
)
self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}
if "ai.azure" in base_url:
self.supported_models = ["azureai", "timegpt-1-long-horizon"]
else:
self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"]

def _make_request(
self, client: httpx.Client, endpoint: str, payload: Dict[str, Any]
Expand Down Expand Up @@ -715,10 +719,9 @@ def _run_validations(
) -> Tuple[DFType, Optional[DFType], bool]:
if validate_api_key and not self.validate_api_key(log=False):
raise Exception("API Key not valid, please email [email protected]")
supported_models = ["timegpt-1", "timegpt-1-long-horizon"]
if model not in supported_models:
if model not in self.supported_models:
raise ValueError(
f"unsupported model: {model}. supported models: {supported_models}"
f"unsupported model: {model}. supported models: {self.supported_models}"
)
drop_id = id_col not in df.columns
if drop_id:
Expand Down

0 comments on commit bbb0810

Please sign in to comment.