Skip to content

Commit

Permalink
feat: add refit argument to cross_validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Dec 5, 2024
1 parent 463371e commit 5fda8b2
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
42 changes: 41 additions & 1 deletion nbs/src/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,7 @@
" finetune_steps: _NonNegativeInt,\n",
" finetune_depth: _Finetune_Depth,\n",
" finetune_loss: _Loss,\n",
" refit: bool,\n",
" clean_ex_first: bool,\n",
" hist_exog_list: Optional[list[str]],\n",
" date_features: Union[bool, Sequence[Union[str, Callable]]],\n",
Expand Down Expand Up @@ -1620,6 +1621,7 @@
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" refit=refit,\n",
" clean_ex_first=clean_ex_first,\n",
" hist_exog_list=hist_exog_list,\n",
" date_features=date_features,\n",
Expand Down Expand Up @@ -1648,6 +1650,7 @@
" finetune_steps: _NonNegativeInt = 0,\n",
" finetune_depth: _Finetune_Depth = 1,\n",
" finetune_loss: _Loss = 'default',\n",
" refit: bool = True,\n",
" clean_ex_first: bool = True,\n",
" hist_exog_list: Optional[list[str]] = None,\n",
" date_features: Union[bool, list[str]] = False,\n",
Expand Down Expand Up @@ -1700,11 +1703,14 @@
" finetune_steps : int (default=0)\n",
"            Number of steps used to finetune TimeGPT on the\n",
"            new data.\n",
" finetune_depth: int (default=1)\n",
" finetune_depth : int (default=1)\n",
" The depth of the finetuning. Uses a scale from 1 to 5, where 1 means little finetuning,\n",
" and 5 means that the entire model is finetuned.\n",
" finetune_loss : str (default='default')\n",
" Loss function to use for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.\n",
"        refit : bool (default=True)\n",
"            Whether to fine-tune the model on each cross validation window.\n",
"            If `False`, the model is fine-tuned only on the first window and those\n",
"            weights are reused for the remaining windows.\n",
"            Only used if `finetune_steps` > 0.\n",
" clean_ex_first : bool (default=True)\n",
" Clean exogenous signal before making forecasts using TimeGPT.\n",
" hist_exog_list : list of str, optional (default=None)\n",
Expand Down Expand Up @@ -1749,6 +1755,7 @@
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" refit=refit,\n",
" clean_ex_first=clean_ex_first,\n",
" hist_exog_list=hist_exog_list,\n",
" date_features=date_features,\n",
Expand Down Expand Up @@ -1856,6 +1863,7 @@
" 'finetune_steps': finetune_steps,\n",
" 'finetune_depth': finetune_depth,\n",
" 'finetune_loss': finetune_loss,\n",
" 'refit': refit,\n",
" }\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" if num_partitions is None:\n",
Expand Down Expand Up @@ -2106,6 +2114,7 @@
" finetune_steps: _NonNegativeInt,\n",
" finetune_depth: _Finetune_Depth,\n",
" finetune_loss: _Loss,\n",
" refit: bool,\n",
" clean_ex_first: bool,\n",
" hist_exog_list: Optional[list[str]],\n",
" date_features: Union[bool, list[str]],\n",
Expand All @@ -2128,6 +2137,7 @@
" finetune_steps=finetune_steps,\n",
" finetune_depth=finetune_depth,\n",
" finetune_loss=finetune_loss,\n",
" refit=refit,\n",
" clean_ex_first=clean_ex_first,\n",
" hist_exog_list=hist_exog_list,\n",
" date_features=date_features,\n",
Expand Down Expand Up @@ -2474,6 +2484,36 @@
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# cv refit\n",
"cv_kwargs = dict(\n",
" df=df,\n",
" n_windows=2,\n",
" h=12,\n",
" freq='MS',\n",
" time_col='timestamp',\n",
" target_col='value',\n",
" finetune_steps=2,\n",
")\n",
"res_refit = nixtla_client.cross_validation(refit=True, **cv_kwargs)\n",
"res_no_refit = nixtla_client.cross_validation(refit=False, **cv_kwargs)\n",
"np.testing.assert_allclose(res_refit['value'], res_no_refit['value'])\n",
"np.testing.assert_raises(\n",
" AssertionError,\n",
" np.testing.assert_allclose,\n",
" res_refit['TimeGPT'],\n",
" res_no_refit['TimeGPT'],\n",
" atol=1e-4,\n",
" rtol=1e-3,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
12 changes: 11 additions & 1 deletion nixtla/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,7 @@ def _distributed_cross_validation(
finetune_steps: _NonNegativeInt,
finetune_depth: _Finetune_Depth,
finetune_loss: _Loss,
refit: bool,
clean_ex_first: bool,
hist_exog_list: Optional[list[str]],
date_features: Union[bool, Sequence[Union[str, Callable]]],
Expand Down Expand Up @@ -1559,6 +1560,7 @@ def _distributed_cross_validation(
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
refit=refit,
clean_ex_first=clean_ex_first,
hist_exog_list=hist_exog_list,
date_features=date_features,
Expand Down Expand Up @@ -1587,6 +1589,7 @@ def cross_validation(
finetune_steps: _NonNegativeInt = 0,
finetune_depth: _Finetune_Depth = 1,
finetune_loss: _Loss = "default",
refit: bool = True,
clean_ex_first: bool = True,
hist_exog_list: Optional[list[str]] = None,
date_features: Union[bool, list[str]] = False,
Expand Down Expand Up @@ -1639,11 +1642,14 @@ def cross_validation(
finetune_steps : int (default=0)
            Number of steps used to finetune TimeGPT on the
            new data.
finetune_depth: int (default=1)
finetune_depth : int (default=1)
The depth of the finetuning. Uses a scale from 1 to 5, where 1 means little finetuning,
and 5 means that the entire model is finetuned.
finetune_loss : str (default='default')
Loss function to use for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.
        refit : bool (default=True)
            Whether to fine-tune the model on each cross validation window.
            If `False`, the model is fine-tuned only on the first window and those
            weights are reused for the remaining windows.
            Only used if `finetune_steps` > 0.
clean_ex_first : bool (default=True)
Clean exogenous signal before making forecasts using TimeGPT.
hist_exog_list : list of str, optional (default=None)
Expand Down Expand Up @@ -1688,6 +1694,7 @@ def cross_validation(
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
refit=refit,
clean_ex_first=clean_ex_first,
hist_exog_list=hist_exog_list,
date_features=date_features,
Expand Down Expand Up @@ -1795,6 +1802,7 @@ def cross_validation(
"finetune_steps": finetune_steps,
"finetune_depth": finetune_depth,
"finetune_loss": finetune_loss,
"refit": refit,
}
with httpx.Client(**self._client_kwargs) as client:
if num_partitions is None:
Expand Down Expand Up @@ -2044,6 +2052,7 @@ def _cross_validation_wrapper(
finetune_steps: _NonNegativeInt,
finetune_depth: _Finetune_Depth,
finetune_loss: _Loss,
refit: bool,
clean_ex_first: bool,
hist_exog_list: Optional[list[str]],
date_features: Union[bool, list[str]],
Expand All @@ -2066,6 +2075,7 @@ def _cross_validation_wrapper(
finetune_steps=finetune_steps,
finetune_depth=finetune_depth,
finetune_loss=finetune_loss,
refit=refit,
clean_ex_first=clean_ex_first,
hist_exog_list=hist_exog_list,
date_features=date_features,
Expand Down

0 comments on commit 5fda8b2

Please sign in to comment.