Skip to content

Commit

Permalink
[FEAT] Ability to save Statsforecast
Browse files Browse the repository at this point in the history
  • Loading branch information
akmalsoliev committed Oct 13, 2023
1 parent 9078542 commit 15ae2ad
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 1 deletion.
98 changes: 98 additions & 0 deletions nbs/src/utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
"#| export\n",
"import os\n",
"import warnings\n",
"import datetime as dt\n",
"import pickle \n",
"from glob import glob\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand Down Expand Up @@ -403,6 +406,101 @@
" self.h = h\n",
" self.method = method"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def save_statsforecast(sf, path, file_name=None, prompt=\"N\"):\n",
" # TODO: Add method information\n",
"\n",
" fitted_models = sf.fitted_\n",
" datetime_record = dt.datetime.now().strftime(\"%m_%d_%Y_%H_%M_%S\")\n",
" models = np.array(fitted_models)\n",
" models_size = models.itemsize * models.size\n",
"\n",
" ask = False\n",
" if prompt.upper != \"Y\":\n",
" ask=True\n",
"\n",
" print(\"Model(s) size:\")\n",
" if models_size < 2**10:\n",
" print(models_size, \"Bytes\")\n",
" if 2**10 < models_size < 2*20:\n",
" size = np.round(np.divide(models_size, 2**10), 2)\n",
" print(size, \"Kilobytes\")\n",
" if models_size >= 2**20:\n",
" size = np.round(np.divide(models_size, 2**20), 2)\n",
" print(size, \"Megabytes\")\n",
" if size >= 50:\n",
" print(\"!!! WARNING !!!\")\n",
" print(\"The model size is over Megabyte threshold.\")\n",
" print(\"Saving the model(s) will take long time.\")\n",
"\n",
" if ask:\n",
" prompt = input(\"Would you like to proceed? (y/n)\")\n",
" if prompt.upper() not in [\"Y\", \"N\"]:\n",
" print(\"Wrong input model(s) would not be saved\")\n",
" return\n",
" elif prompt.upper() == \"N\":\n",
" return\n",
"\n",
" print(\"Saving model(s)\")\n",
"\n",
" if not file_name:\n",
" path_file = os.path.join(path, f\"FittedModels_{datetime_record}.pickle\")\n",
" else:\n",
" path_file = os.path.join(path, file_name)\n",
"\n",
" with open(path_file, \"wb\") as m_file:\n",
" pickle.dump(sf, m_file)\n",
" print(\"Model(s) saved\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def load_statsforecast(path, file_name=None):\n",
" \"\"\"\n",
" Automatically loads the model into ready StatsForecast.\n",
" Parameters\n",
" ----------\n",
" path: Union[str, Path]\n",
" Path to saved StatsForecast directory (folder).\n",
" file_name: Union[str, Path]\n",
" Path to saved Statsforecast (pickle file).\n",
" \"\"\"\n",
" if not file_name:\n",
" sf_p = os.path.join(path, \"*.pickle\")\n",
" sf_f_p = glob(sf_p)\n",
" if len(sf_f_p) > 1:\n",
" raise ValueError(\n",
" f\"\"\"The path contains more than one *.pickle file in it.\n",
" Please remove non-required *.pickle file from the {path}\n",
" \"\"\")\n",
" elif len(sf_f_p) == 1:\n",
" sf_f_p = sf_f_p[0]\n",
" else:\n",
" raise ValueError(\n",
" \"\"\"\n",
" Not a single model(s) file found in the specified directory.\n",
" Ensure that the directory is right and/or add your `.pickle` file in\n",
" it.\n",
" \"\"\"\n",
" )\n",
" else:\n",
" sf_f_p = os.path.join(path, file_name)\n",
" \n",
" with open(sf_f_p, \"rb\") as f:\n",
" return pickle.load(f)"
]
}
],
"metadata": {
Expand Down
6 changes: 5 additions & 1 deletion statsforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,4 +720,8 @@
'statsforecast.utils._repeat_val': ('src/utils.html#_repeat_val', 'statsforecast/utils.py'),
'statsforecast.utils._repeat_val_seas': ('src/utils.html#_repeat_val_seas', 'statsforecast/utils.py'),
'statsforecast.utils._seasonal_naive': ('src/utils.html#_seasonal_naive', 'statsforecast/utils.py'),
'statsforecast.utils.generate_series': ('src/utils.html#generate_series', 'statsforecast/utils.py')}}}
'statsforecast.utils.generate_series': ('src/utils.html#generate_series', 'statsforecast/utils.py'),
'statsforecast.utils.load_statsforecast': ( 'src/utils.html#load_statsforecast',
'statsforecast/utils.py'),
'statsforecast.utils.save_statsforecast': ( 'src/utils.html#save_statsforecast',
'statsforecast/utils.py')}}}
85 changes: 85 additions & 0 deletions statsforecast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
# %% ../nbs/src/utils.ipynb 3
import os
import warnings
import datetime as dt
import pickle
from glob import glob

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -326,3 +329,85 @@ def __init__(
self.n_windows = n_windows
self.h = h
self.method = method

# %% ../nbs/src/utils.ipynb 20
def save_statsforecast(sf, path, file_name=None, prompt="N"):
# TODO: Add method information

fitted_models = sf.fitted_
datetime_record = dt.datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
models = np.array(fitted_models)
models_size = models.itemsize * models.size

ask = False
if prompt.upper != "Y":
ask = True

print("Model(s) size:")
if models_size < 2**10:
print(models_size, "Bytes")
if 2**10 < models_size < 2 * 20:
size = np.round(np.divide(models_size, 2**10), 2)
print(size, "Kilobytes")
if models_size >= 2**20:
size = np.round(np.divide(models_size, 2**20), 2)
print(size, "Megabytes")
if size >= 50:
print("!!! WARNING !!!")
print("The model size is over Megabyte threshold.")
print("Saving the model(s) will take long time.")

if ask:
prompt = input("Would you like to proceed? (y/n)")
if prompt.upper() not in ["Y", "N"]:
print("Wrong input model(s) would not be saved")
return
elif prompt.upper() == "N":
return

print("Saving model(s)")

if not file_name:
path_file = os.path.join(path, f"FittedModels_{datetime_record}.pickle")
else:
path_file = os.path.join(path, file_name)

with open(path_file, "wb") as m_file:
pickle.dump(sf, m_file)
print("Model(s) saved")

# %% ../nbs/src/utils.ipynb 21
def load_statsforecast(path, file_name=None):
"""
Automatically loads the model into ready StatsForecast.
Parameters
----------
path: Union[str, Path]
Path to saved StatsForecast directory (folder).
file_name: Union[str, Path]
Path to saved Statsforecast (pickle file).
"""
if not file_name:
sf_p = os.path.join(path, "*.pickle")
sf_f_p = glob(sf_p)
if len(sf_f_p) > 1:
raise ValueError(
f"""The path contains more than one *.pickle file in it.
Please remove non-required *.pickle file from the {path}
"""
)
elif len(sf_f_p) == 1:
sf_f_p = sf_f_p[0]
else:
raise ValueError(
"""
Not a single model(s) file found in the specified directory.
Ensure that the directory is right and/or add your `.pickle` file in
it.
"""
)
else:
sf_f_p = os.path.join(path, file_name)

with open(sf_f_p, "rb") as f:
return pickle.load(f)

0 comments on commit 15ae2ad

Please sign in to comment.