diff --git a/nbs/nixtla_client.ipynb b/nbs/nixtla_client.ipynb index 555c0e36..5d729944 100644 --- a/nbs/nixtla_client.ipynb +++ b/nbs/nixtla_client.ipynb @@ -669,6 +669,10 @@ " max_retries=max_retries, retry_interval=retry_interval, max_wait_time=max_wait_time\n", " )\n", " self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}\n", + " if 'ai.azure' in base_url:\n", + " self.supported_models = ['azureai', 'timegpt-1-long-horizon']\n", + " else:\n", + " self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n", "\n", " def _make_request(self, client: httpx.Client, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:\n", " resp = client.request(\n", @@ -784,10 +788,9 @@ " ) -> Tuple[DFType, Optional[DFType], bool]:\n", " if validate_api_key and not self.validate_api_key(log=False):\n", " raise Exception('API Key not valid, please email ops@nixtla.io')\n", - " supported_models = ['timegpt-1', 'timegpt-1-long-horizon']\n", - " if model not in supported_models:\n", + " if model not in self.supported_models:\n", " raise ValueError(\n", - " f'unsupported model: {model}. supported models: {supported_models}'\n", + " f'unsupported model: {model}. supported models: {self.supported_models}'\n", " )\n", " drop_id = id_col not in df.columns\n", " if drop_id:\n", diff --git a/nixtla/nixtla_client.py b/nixtla/nixtla_client.py index 63d5d2c1..a3382d65 100644 --- a/nixtla/nixtla_client.py +++ b/nixtla/nixtla_client.py @@ -600,6 +600,10 @@ def __init__( max_wait_time=max_wait_time, ) self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {} + if "ai.azure" in base_url: + self.supported_models = ["azureai", "timegpt-1-long-horizon"] + else: + self.supported_models = ["timegpt-1", "timegpt-1-long-horizon"] def _make_request( self, client: httpx.Client, endpoint: str, payload: Dict[str, Any] @@ -715,10 +719,9 @@ def _run_validations( ) -> Tuple[DFType, Optional[DFType], bool]: if validate_api_key and not self.validate_api_key(log=False): raise Exception("API Key not valid, please email ops@nixtla.io") - supported_models = ["timegpt-1", "timegpt-1-long-horizon"] - if model not in supported_models: + if model not in self.supported_models: raise ValueError( - f"unsupported model: {model}. supported models: {supported_models}" + f"unsupported model: {model}. supported models: {self.supported_models}" ) drop_id = id_col not in df.columns if drop_id: