diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index f7263974b..49d0ead04 100755 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -70,6 +70,7 @@ nav: - user_guide/evaluate/create_counterfactual.md - user_guide/evaluate/feature_importance_permutation.md - user_guide/evaluate/ftest.md + - user_guide/evaluate/GroupTimeSeriesSplit.md - user_guide/evaluate/lift_score.md - user_guide/evaluate/mcnemar_table.md - user_guide/evaluate/mcnemar_tables.md diff --git a/docs/sources/CHANGELOG.md b/docs/sources/CHANGELOG.md index 11790f8c1..0d620df85 100755 --- a/docs/sources/CHANGELOG.md +++ b/docs/sources/CHANGELOG.md @@ -24,6 +24,7 @@ The CHANGELOG for the current development version is available at - The `mlxtend.evaluate.bootstrap_point632_score` now supports `fit_params`. ([#861](https://github.com/rasbt/mlxtend/pull/861)) - The `mlxtend/plotting/decision_regions.py` function now has a `contourf_kwargs` for matplotlib to change the look of the decision boundaries if desired. ([#881](https://github.com/rasbt/mlxtend/pull/881) via [[pbloem](https://github.com/pbloem)]) - Add a `norm_colormap` parameter to `mlxtend.plotting.plot_confusion_matrix`, to allow normalizing the colormap, e.g., using `matplotlib.colors.LogNorm()` ([#895](https://github.com/rasbt/mlxtend/pull/895)) +- Add new `GroupTimeSeriesSplit` class for evaluation in time series tasks with support of custom groups and additional parameters in comparison with scikit-learn's `TimeSeriesSplit`. 
([#915](https://github.com/rasbt/mlxtend/pull/915) via [Dmitry Labazkin](https://github.com/labdmitriy)) ##### Changes diff --git a/docs/sources/USER_GUIDE_INDEX.md b/docs/sources/USER_GUIDE_INDEX.md index 28b17cca6..ef264ffa9 100755 --- a/docs/sources/USER_GUIDE_INDEX.md +++ b/docs/sources/USER_GUIDE_INDEX.md @@ -36,6 +36,7 @@ - [create_counterfactual](user_guide/evaluate/create_counterfactual.md) - [feature_importance_permutation](user_guide/evaluate/feature_importance_permutation.md) - [ftest](user_guide/evaluate/ftest.md) +- [GroupTimeSeriesSplit](user_guide/evaluate/GroupTimeSeriesSplit.md) - [lift_score](user_guide/evaluate/lift_score.md) - [mcnemar_table](user_guide/evaluate/mcnemar_table.md) - [mcnemar_tables](user_guide/evaluate/mcnemar_tables.md) diff --git a/docs/sources/user_guide/evaluate/GroupTimeSeriesSplit.ipynb b/docs/sources/user_guide/evaluate/GroupTimeSeriesSplit.ipynb new file mode 100644 index 000000000..291466add --- /dev/null +++ b/docs/sources/user_guide/evaluate/GroupTimeSeriesSplit.ipynb @@ -0,0 +1,809 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GroupTimeSeriesSplit: A scikit-learn compatible version of the time series validation with groups" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A scikit-learn-compatible time series cross-validator that supports non-overlapping groups." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> `from mlxtend.evaluate import GroupTimeSeriesSplit` " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Time series tasks in machine learning require special type of validation, because the time order of the objects is important for a fairer evaluation of an ML model’s quality. \n", + "Also there can be different time units for splitting the data for different tasks - hours, days, months etc. 
\n", + "\n", + "Here, we use time series validation with support for groups, which can be flexibly configured along with other parameters:\n", + "\n", + "- Test size\n", + "- Train size\n", + "- Number of splits\n", + "- Gap size\n", + "- Shift size \n", + "- Window type \n", + "\n", + "This `GroupTimeSeriesSplit` implementation is inspired by scikit-learn's [TimeSeriesSplit](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.TimeSeriesSplit.html) but it has several advantages:\n", + "\n", + "- `GroupTimeSeriesSplit` lets you group data as you want before splitting, while `TimeSeriesSplit` only considers the record level.\n", + "- It can be used for both holdout validation (n_splits=1) and cross-validation (n_splits>=2), whereas `TimeSeriesSplit` can be used only for the latter case.\n", + "- `TimeSeriesSplit` uses only an expanding window, while for this implementation you can choose between both rolling and expanding window types.\n", + "- `GroupTimeSeriesSplit` offers additional control for splitting using an additional `shift_size` parameter.\n", + "\n", + "**There are several features that need to be taken into account:**\n", + "\n", + "- `GroupTimeSeriesSplit` is compatible with the scikit-learn API.\n", + "- Numbers or custom non-numeric values can be used as groups\n", + "- However, groups should be consecutive\n", + "- The test size must be specified together with either a) the train size or b) the number of splits\n", + "- If full data can’t be used with specific parameters, the most recent data is considered for splitting\n", + "- If splitting is impossible (e.g., there is not enough data to split) using specified parameters, an exception will be raised " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before we illustrate the usage of `GroupTimeSeriesSplit` in the examples below, let's set up a `DummyClassifier` that we will reuse in the following sections. 
Also, let's import the libraries we will be using in the following examples:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.dummy import DummyClassifier\n", + "from sklearn.model_selection import cross_val_score\n", + "\n", + "from mlxtend.evaluate.time_series import (\n", + " GroupTimeSeriesSplit,\n", + " plot_splits,\n", + " print_cv_info,\n", + " print_split_info,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Prepare sample data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the following examples, we create a sample dataset consisting of 16 training data points with corresponding targets." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Features and targets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's assume that we have one numeric feature and a target for a binary classification task." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "data = [[0], [7], [6], [4], [4], [8], [0], [6], [2], [0], [5], [9], [7], [7], [7], [7]]\n", + "target = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0]\n", + "\n", + "X = pd.DataFrame(data, columns=[\"num_feature\"])\n", + "y = pd.Series(target, name=\"target\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Group numbers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We create 6 different groups so that the first training example belongs to group 0, the next 4 to group 1, and so forth. 
\n", + "These groups do not have to be in ascending order (as in this dataset), **but they must be consecutive.**" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 5, 5])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "groups = np.array([0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 5, 5])\n", + "groups" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the following is another example of a **correct** group ordering (not sorted but consecutive):\n", + "```python\n", + "np.array([5, 5, 5, 5, 1, 1, 1, 1, 3, 3, 2, 2, 2, 4, 4, 0])\n", + "```\n", + "\n", + "However, the example below shows an **incorrect** group ordering (not consecutive), which is not compatible with `GroupTimeSeriesSplit`:\n", + "```python\n", + "np.array([0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 2, 2, 2, 2])\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Group names (months)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will add months as the index according to the specified groups for a more illustrative example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['2021-01', '2021-02', '2021-02', '2021-02', '2021-02', '2021-03',\n", + " '2021-03', '2021-03', '2021-04', '2021-04', '2021-05', '2021-05',\n", + " '2021-06', '2021-06', '2021-06', '2021-06'], dtype='" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train indices: [0 1 2 3 4 5 6 7]\n", + "Test indices: [8 9]\n", + "Train length: 8\n", + "Test length: 2\n", + "Train groups: [0 1 1 1 1 2 2 2]\n", + "Test groups: [3 3]\n", + "Train group size: 3\n", + "Test group size: 1\n", + "Train group months: ['2021-01' '2021-02' '2021-02' '2021-02' '2021-02' '2021-03' '2021-03'\n", + " '2021-03']\n", + "Test group months: ['2021-04' '2021-04']\n", + "\n", + "Train indices: [1 2 3 4 5 6 7 8 9]\n", + "Test indices: [10 11]\n", + "Train length: 9\n", + "Test length: 2\n", + "Train groups: [1 1 1 1 2 2 2 3 3]\n", + "Test groups: [4 4]\n", + "Train group size: 3\n", + "Test group size: 1\n", + "Train group months: ['2021-02' '2021-02' '2021-02' '2021-02' '2021-03' '2021-03' '2021-03'\n", + " '2021-04' '2021-04']\n", + "Test group months: ['2021-05' '2021-05']\n", + "\n", + "Train indices: [ 5 6 7 8 9 10 11]\n", + "Test indices: [12 13 14 15]\n", + "Train length: 7\n", + "Test length: 4\n", + "Train groups: [2 2 2 3 3 4 4]\n", + "Test groups: [5 5 5 5]\n", + "Train group size: 3\n", + "Test group size: 1\n", + "Train group months: ['2021-03' '2021-03' '2021-03' '2021-04' '2021-04' '2021-05' '2021-05']\n", + "Test group months: ['2021-06' '2021-06' '2021-06' '2021-06']\n", + "\n" + ] + } + ], + "source": [ + "cv_args = {\"test_size\": 1, \"train_size\": 3}\n", + "\n", + "plot_splits(X, y, groups, **cv_args)\n", + "print_split_info(X, y, groups, **cv_args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Please 
note that if we specify the number of groups for both the training and the test set, the number of splits is determined automatically and changes with the group sizes. For example, increasing the number of training groups results in fewer splits, as shown below." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "cv_args = {\"test_size\": 1, \"train_size\": 4}\n", + "\n", + "plot_splits(X, y, groups, **cv_args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Usage in CV" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The example below illustrates how we can use the time series splitter with scikit-learn, i.e., using `cross_val_score`:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Split number: 1\n", + "Train true target: [1 0 1 0 1 0 0 1]\n", + "Train predicted target: [0 0 0 0 0 0 0 0]\n", + "Test true target: [1 1]\n", + "Test predicted target: [0 0]\n", + "Accuracy: 0.0\n", + "\n", + "Split number: 2\n", + "Train true target: [0 1 0 1 0 0 1 1 1]\n", + "Train predicted target: [1 1 1 1 1 1 1 1 1]\n", + "Test true target: [0 1]\n", + "Test predicted target: [1 1]\n", + "Accuracy: 0.5\n", + "\n", + "Split number: 3\n", + "Train true target: [0 0 1 1 1 0 1]\n", + "Train predicted target: [1 1 1 1 1 1 1]\n", + "Test true target: [1 0 0 0]\n", + "Test predicted target: [1 1 1 1]\n", + "Accuracy: 0.25\n", + "\n" + ] + } + ], + "source": [ + "cv = GroupTimeSeriesSplit(**cv_args)\n", + "clf = DummyClassifier(strategy=\"most_frequent\")\n", + "\n", + "scores = cross_val_score(clf, X, y, groups=groups, scoring=\"accuracy\", cv=cv)\n", + "print_cv_info(cv, X, y, groups, clf, scores)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 2 -- Multiple training groups (with number of splits specified)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's take a look at an example where we don't specify the number of training groups. 
Here, we will split the dataset by specifying the test size (2 groups) and the number of splits (3), which is sufficient for calculating the training size automatically." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train indices: [0 1 2 3 4]\n", + "Test indices: [5 6 7 8 9]\n", + "Train length: 5\n", + "Test length: 5\n", + "Train groups: [0 1 1 1 1]\n", + "Test groups: [2 2 2 3 3]\n", + "Train group size: 2\n", + "Test group size: 2\n", + "Train group months: ['2021-01' '2021-02' '2021-02' '2021-02' '2021-02']\n", + "Test group months: ['2021-03' '2021-03' '2021-03' '2021-04' '2021-04']\n", + "\n", + "Train indices: [1 2 3 4 5 6 7]\n", + "Test indices: [ 8 9 10 11]\n", + "Train length: 7\n", + "Test length: 4\n", + "Train groups: [1 1 1 1 2 2 2]\n", + "Test groups: [3 3 4 4]\n", + "Train group size: 2\n", + "Test group size: 2\n", + "Train group months: ['2021-02' '2021-02' '2021-02' '2021-02' '2021-03' '2021-03' '2021-03']\n", + "Test group months: ['2021-04' '2021-04' '2021-05' '2021-05']\n", + "\n", + "Train indices: [5 6 7 8 9]\n", + "Test indices: [10 11 12 13 14 15]\n", + "Train length: 5\n", + "Test length: 6\n", + "Train groups: [2 2 2 3 3]\n", + "Test groups: [4 4 5 5 5 5]\n", + "Train group size: 2\n", + "Test group size: 2\n", + "Train group months: ['2021-03' '2021-03' '2021-03' '2021-04' '2021-04']\n", + "Test group months: ['2021-05' '2021-05' '2021-06' '2021-06' '2021-06' '2021-06']\n", + "\n" + ] + } + ], + "source": [ + "cv_args = {\"test_size\": 2, \"n_splits\": 3}\n", + "\n", + "plot_splits(X, y, groups, **cv_args)\n", + "print_split_info(X, y, groups, **cv_args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Usage in CV" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Again, let's take a look at how this looks in a scikit-learn context using `cross_val_score`:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Split number: 
1\n", + "Train true target: [1 0 1 0 1]\n", + "Train predicted target: [1 1 1 1 1]\n", + "Test true target: [0 0 1 1 1]\n", + "Test predicted target: [1 1 1 1 1]\n", + "Accuracy: 0.6\n", + "\n", + "Split number: 2\n", + "Train true target: [0 1 0 1 0 0 1]\n", + "Train predicted target: [0 0 0 0 0 0 0]\n", + "Test true target: [1 1 0 1]\n", + "Test predicted target: [0 0 0 0]\n", + "Accuracy: 0.25\n", + "\n", + "Split number: 3\n", + "Train true target: [0 0 1 1 1]\n", + "Train predicted target: [1 1 1 1 1]\n", + "Test true target: [0 1 1 0 0 0]\n", + "Test predicted target: [1 1 1 1 1 1]\n", + "Accuracy: 0.33\n", + "\n" + ] + } + ], + "source": [ + "cv = GroupTimeSeriesSplit(**cv_args)\n", + "clf = DummyClassifier(strategy=\"most_frequent\")\n", + "\n", + "scores = cross_val_score(clf, X, y, groups=groups, scoring=\"accuracy\", cv=cv)\n", + "print_cv_info(cv, X, y, groups, clf, scores)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example 3 -- Defining the gap size between training and test datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`GroupTimeSeriesSplit` lets you specify a gap size greater than 0 in order to skip a specified number of groups between training and test folds (the default gap size is 0). In the example below, we use a gap of 1 group to illustrate this." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train indices: [0 1 2 3 4]\n", + "Test indices: [8 9]\n", + "Train length: 5\n", + "Test length: 2\n", + "Train groups: [0 1 1 1 1]\n", + "Test groups: [3 3]\n", + "Train group size: 2\n", + "Test group size: 1\n", + "Train group months: ['2021-01' '2021-02' '2021-02' '2021-02' '2021-02']\n", + "Test group months: ['2021-04' '2021-04']\n", + "\n", + "Train indices: [1 2 3 4 5 6 7]\n", + "Test indices: [10 11]\n", + "Train length: 7\n", + "Test length: 2\n", + "Train groups: [1 1 1 1 2 2 2]\n", + "Test groups: [4 4]\n", + "Train group size: 2\n", + "Test group size: 1\n", + "Train group months: ['2021-02' '2021-02' '2021-02' '2021-02' '2021-03' '2021-03' '2021-03']\n", + "Test group months: ['2021-05' '2021-05']\n", + "\n", + "Train indices: [5 6 7 8 9]\n", + "Test indices: [12 13 14 15]\n", + "Train length: 5\n", + "Test length: 4\n", + "Train groups: [2 2 2 3 3]\n", + "Test groups: [5 5 5 5]\n", + "Train group size: 2\n", + "Test group size: 1\n", + "Train group months: ['2021-03' '2021-03' '2021-03' '2021-04' '2021-04']\n", + "Test group months: ['2021-06' '2021-06' '2021-06' '2021-06']\n", + "\n" + ] + } + ], + "source": [ + "cv_args = {\"test_size\": 1, \"n_splits\": 3, \"gap_size\": 1}\n", + "\n", + "plot_splits(X, y, groups, **cv_args)\n", + "print_split_info(X, y, groups, **cv_args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Usage in CV" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The example below shows how this looks like in a scikit-learn context using `cross_val_score`:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Split number: 1\n", + "Train true target: [1 0 1 0 1]\n", + "Train predicted target: [1 1 1 1 
1]\n", + "Test true target: [1 1]\n", + "Test predicted target: [1 1]\n", + "Accuracy: 1.0\n", + "\n", + "Split number: 2\n", + "Train true target: [0 1 0 1 0 0 1]\n", + "Train predicted target: [0 0 0 0 0 0 0]\n", + "Test true target: [0 1]\n", + "Test predicted target: [0 0]\n", + "Accuracy: 0.5\n", + "\n", + "Split number: 3\n", + "Train true target: [0 0 1 1 1]\n", + "Train predicted target: [1 1 1 1 1]\n", + "Test true target: [1 0 0 0]\n", + "Test predicted target: [1 1 1 1]\n", + "Accuracy: 0.25\n", + "\n" + ] + } + ], + "source": [ + "cv = GroupTimeSeriesSplit(**cv_args)\n", + "clf = DummyClassifier(strategy=\"most_frequent\")\n", + "\n", + "scores = cross_val_score(clf, X, y, groups=groups, scoring=\"accuracy\", cv=cv)\n", + "print_cv_info(cv, X, y, groups, clf, scores)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "## GroupTimeSeriesSplit\n", + "\n", + "*GroupTimeSeriesSplit(test_size, train_size=None, n_splits=None, gap_size=0, shift_size=1, window_type='rolling')*\n", + "\n", + "Group time series cross-validator.\n", + "\n", + "**Parameters**\n", + "\n", + "- `test_size` : int\n", + "\n", + " Size of test dataset.\n", + "\n", + "- `train_size` : int (default=None)\n", + "\n", + " Size of train dataset.\n", + "\n", + "- `n_splits` : int (default=None)\n", + "\n", + " Number of the splits.\n", + "\n", + "- `gap_size` : int (default=0)\n", + "\n", + " Gap size between train and test datasets.\n", + "\n", + "- `shift_size` : int (default=1)\n", + "\n", + " Step to shift for the next fold.\n", + "\n", + "- `window_type` : str (default=\"rolling\")\n", + "\n", + " Type of the window. 
Possible values: \"rolling\", \"expanding\".\n", + "\n", + "**Examples**\n", + "\n", + "For usage examples, please see\n", + " http://rasbt.github.io/mlxtend/user_guide/evaluate/GroupTimeSeriesSplit/\n", + "\n", + "### Methods\n", + "\n", + "
\n", + "\n", + "*get_n_splits(X=None, y=None, groups=None)*\n", + "\n", + "Returns the number of splitting iterations in the cross-validator.\n", + "\n", + "**Parameters**\n", + "\n", + "- `X` : object\n", + "\n", + " Always ignored, exists for compatibility.\n", + "\n", + "- `y` : object\n", + "\n", + " Always ignored, exists for compatibility.\n", + "\n", + "- `groups` : object\n", + "\n", + " Always ignored, exists for compatibility.\n", + "\n", + "**Returns**\n", + "\n", + "- `n_splits` : int\n", + "\n", + " Returns the number of splitting iterations in the cross-validator.\n", + "\n", + "
\n", + "\n", + "*split(X, y=None, groups=None)*\n", + "\n", + "Generate indices to split data into training and test set.\n", + "\n", + "**Parameters**\n", + "\n", + "- `X` : array-like\n", + "\n", + " Training data.\n", + "\n", + "- `y` : array-like (default=None)\n", + "\n", + " Always ignored, exists for compatibility.\n", + "\n", + "- `groups` : array-like (default=None)\n", + "\n", + " Array with group names or sequence numbers.\n", + "\n", + "**Yields**\n", + "\n", + "- `train` : ndarray\n", + "\n", + " The training set indices for that split.\n", + "\n", + "- `test` : ndarray\n", + "\n", + " The testing set indices for that split.\n", + "\n", + "\n" + ] + } + ], + "source": [ + "with open(\"../../api_modules/mlxtend.evaluate/GroupTimeSeriesSplit.md\", \"r\") as f:\n", + " s = f.read()\n", + "print(s)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/mlxtend/evaluate/__init__.py b/mlxtend/evaluate/__init__.py index 10d2eface..7c76fec88 100644 --- a/mlxtend/evaluate/__init__.py +++ b/mlxtend/evaluate/__init__.py @@ -21,6 +21,7 @@ from .permutation import permutation_test from .proportion_difference import proportion_difference from .scoring import scoring +from .time_series import GroupTimeSeriesSplit from .ttest import paired_ttest_5x2cv, paired_ttest_kfold_cv, paired_ttest_resampled __all__ = [ @@ -47,4 +48,5 @@ "bias_variance_decomp", "accuracy_score", "create_counterfactual", + "GroupTimeSeriesSplit", ] diff --git a/mlxtend/evaluate/tests/test_time_series.py b/mlxtend/evaluate/tests/test_time_series.py new file mode 100644 index 000000000..8233e0ba8 --- 
/dev/null +++ b/mlxtend/evaluate/tests/test_time_series.py @@ -0,0 +1,378 @@ +# mlxtend Machine Learning Library Extensions +# +# Time series cross validation with grouping. +# Author: Dmitry Labazkin +# +# License: BSD 3 clause + +import numpy as np +import pytest +from sklearn.dummy import DummyClassifier +from sklearn.model_selection import cross_val_score + +from mlxtend.evaluate import GroupTimeSeriesSplit + + +@pytest.fixture +def X(): + return np.array( + [[0], [7], [6], [4], [4], [8], [0], [6], [2], [0], [5], [9], [7], [7], [7], [7]] + ) + + +@pytest.fixture +def y(): + return np.array([1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0]) + + +@pytest.fixture +def group_numbers(): + return np.array([0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 5, 5]) + + +@pytest.fixture +def not_sorted_group_numbers(): + return np.array([5, 5, 5, 5, 1, 1, 1, 1, 3, 3, 2, 2, 2, 4, 4, 0]) + + +@pytest.fixture +def not_consecutive_group_numbers(): + return np.array([0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 2, 2, 2, 2]) + + +@pytest.fixture +def group_names(): + return np.array( + [ + "2021-01", + "2021-02", + "2021-02", + "2021-02", + "2021-02", + "2021-03", + "2021-03", + "2021-03", + "2021-04", + "2021-04", + "2021-05", + "2021-05", + "2021-06", + "2021-06", + "2021-06", + "2021-06", + ] + ) + + +@pytest.fixture +def not_sorted_group_names(): + return np.array( + [ + "2021-06", + "2021-06", + "2021-06", + "2021-06", + "2021-02", + "2021-02", + "2021-02", + "2021-02", + "2021-04", + "2021-04", + "2021-03", + "2021-03", + "2021-03", + "2021-05", + "2021-05", + "2021-01", + ] + ) + + +@pytest.fixture +def not_consecutive_group_names(): + return np.array( + [ + "2021-01", + "2021-02", + "2021-02", + "2021-02", + "2021-02", + "2021-03", + "2021-03", + "2021-03", + "2021-04", + "2021-04", + "2021-05", + "2021-05", + "2021-03", + "2021-03", + "2021-03", + "2021-03", + ] + ) + + +def check_splits(X, y, groups, cv_args, expected_results): + cv = GroupTimeSeriesSplit(**cv_args) + results = 
list(cv.split(X, y, groups)) + + assert len(results) == len(expected_results) + + for split, expected_split in zip(results, expected_results): + assert np.array_equal(split[0], expected_split[0]) + assert np.array_equal(split[1], expected_split[1]) + + return cv + + +def test_get_n_splits(X, y, group_numbers): + cv_args = {"test_size": 1, "train_size": 3} + expected_results = [ + (np.array([0, 1, 2, 3, 4, 5, 6, 7]), np.array([8, 9])), + (np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]), np.array([10, 11])), + (np.array([5, 6, 7, 8, 9, 10, 11]), np.array([12, 13, 14, 15])), + ] + cv = check_splits(X, y, group_numbers, cv_args, expected_results) + + assert cv.get_n_splits() == len(expected_results) + + +def test_train_size(X, y, group_numbers): + cv_args = {"test_size": 1, "train_size": 3} + expected_results = [ + (np.array([0, 1, 2, 3, 4, 5, 6, 7]), np.array([8, 9])), + (np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]), np.array([10, 11])), + (np.array([5, 6, 7, 8, 9, 10, 11]), np.array([12, 13, 14, 15])), + ] + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_train_size_group_names(X, y, group_names): + cv_args = {"test_size": 1, "train_size": 3} + expected_results = [ + (np.array([0, 1, 2, 3, 4, 5, 6, 7]), np.array([8, 9])), + (np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]), np.array([10, 11])), + (np.array([5, 6, 7, 8, 9, 10, 11]), np.array([12, 13, 14, 15])), + ] + check_splits(X, y, group_names, cv_args, expected_results) + + +def test_n_splits(X, y, group_numbers): + cv_args = {"test_size": 2, "n_splits": 3} + expected_results = [ + (np.array([0, 1, 2, 3, 4]), np.array([5, 6, 7, 8, 9])), + (np.array([1, 2, 3, 4, 5, 6, 7]), np.array([8, 9, 10, 11])), + (np.array([5, 6, 7, 8, 9]), np.array([10, 11, 12, 13, 14, 15])), + ] + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_n_splits_gap_size(X, y, group_numbers): + cv_args = {"test_size": 1, "n_splits": 3, "gap_size": 1} + expected_results = [ + (np.array([0, 1, 2, 3, 4]), np.array([8, 9])), + 
(np.array([1, 2, 3, 4, 5, 6, 7]), np.array([10, 11])), + (np.array([5, 6, 7, 8, 9]), np.array([12, 13, 14, 15])), + ] + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_n_splits_shift_size(X, y, group_numbers): + cv_args = {"test_size": 1, "n_splits": 3, "gap_size": 1} + expected_results = [ + (np.array([0, 1, 2, 3, 4]), np.array([8, 9])), + (np.array([1, 2, 3, 4, 5, 6, 7]), np.array([10, 11])), + (np.array([5, 6, 7, 8, 9]), np.array([12, 13, 14, 15])), + ] + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_n_splits_expanding_window(X, y, group_numbers): + cv_args = {"test_size": 3, "n_splits": 3, "window_type": "expanding"} + expected_results = [ + (np.array([0]), np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])), + (np.array([0, 1, 2, 3, 4]), np.array([5, 6, 7, 8, 9, 10, 11])), + ( + np.array([0, 1, 2, 3, 4, 5, 6, 7]), + np.array([8, 9, 10, 11, 12, 13, 14, 15]), + ), + ] + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_full_usage_of_data(X, y, group_numbers): + cv_args = {"test_size": 3, "train_size": 2, "n_splits": 2} + expected_results = [ + (np.array([0, 1, 2, 3, 4]), np.array([5, 6, 7, 8, 9, 10, 11])), + ( + np.array([1, 2, 3, 4, 5, 6, 7]), + np.array([8, 9, 10, 11, 12, 13, 14, 15]), + ), + ] + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_partial_usage_of_data(X, y, group_numbers): + cv_args = {"test_size": 2, "train_size": 2, "n_splits": 2} + expected_results = [ + (np.array([1, 2, 3, 4, 5, 6, 7]), np.array([8, 9, 10, 11])), + (np.array([5, 6, 7, 8, 9]), np.array([10, 11, 12, 13, 14, 15])), + ] + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_not_sorted_group_numbers(X, y, not_sorted_group_numbers): + cv_args = {"test_size": 1, "train_size": 3} + expected_results = [ + (np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), np.array([10, 11, 12])), + (np.array([4, 5, 6, 7, 8, 9, 10, 11, 12]), np.array([13, 14])), + (np.array([8, 9, 10, 11, 12, 
13, 14]), np.array([15])), + ] + + check_splits(X, y, not_sorted_group_numbers, cv_args, expected_results) + + +def test_not_sorted_group_names(X, y, not_sorted_group_names): + cv_args = {"test_size": 1, "train_size": 3} + expected_results = [ + (np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), np.array([10, 11, 12])), + (np.array([4, 5, 6, 7, 8, 9, 10, 11, 12]), np.array([13, 14])), + (np.array([8, 9, 10, 11, 12, 13, 14]), np.array([15])), + ] + + check_splits(X, y, not_sorted_group_names, cv_args, expected_results) + + +def test_not_specified_train_size_n_splits(X, y, group_numbers): + cv_args = {"test_size": 1} + expected_results = None + error_message = "Either train_size or n_splits should be defined" + + with pytest.raises(ValueError, match=error_message): + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_bad_window_type(X, y, group_numbers): + cv_args = { + "test_size": 1, + "train_size": 3, + "window_type": "incorrect_window_type", + } + expected_results = None + error_message = 'Window type can be either "rolling" or "expanding"' + + with pytest.raises(ValueError, match=error_message): + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_train_size_with_expanding_window(X, y, group_numbers): + cv_args = {"test_size": 1, "train_size": 3, "window_type": "expanding"} + expected_results = None + error_message = "Train size can be specified only with rolling window" + + with pytest.raises(ValueError, match=error_message): + check_splits(X, y, group_numbers, cv_args, expected_results) + + +def test_not_specified_groups(X, y): + cv_args = {"test_size": 1, "train_size": 3} + expected_results = None + error_message = "The groups should be specified" + + with pytest.raises(ValueError, match=error_message): + check_splits(X, y, None, cv_args, expected_results) + + +def test_not_consecutive_group_numbers(X, y, not_consecutive_group_numbers): + cv_args = {"test_size": 1, "train_size": 3} + expected_results = None + 
error_message = "The groups should be consecutive"
+
+    with pytest.raises(ValueError, match=error_message):
+        check_splits(X, y, not_consecutive_group_numbers, cv_args, expected_results)
+
+
+def test_not_consecutive_group_names(X, y, not_consecutive_group_names):
+    cv_args = {"test_size": 1, "train_size": 3}
+    expected_results = None
+    error_message = "The groups should be consecutive"
+
+    with pytest.raises(ValueError, match=error_message):
+        check_splits(X, y, not_consecutive_group_names, cv_args, expected_results)
+
+
+def test_too_large_train_size_(X, y, group_numbers):
+    cv_args = {"test_size": 1, "train_size": 10}
+    expected_results = None
+    error_message = (
+        r"Not enough data to split number of groups \(6\)"
+        r" for number splits \(-4\) with train size \(10\),"
+        r" test size \(1\), gap size \(0\), shift size \(1\)"
+    )
+
+    with pytest.raises(ValueError, match=error_message):
+        check_splits(X, y, group_numbers, cv_args, expected_results)
+
+
+def test_too_large_n_splits(X, y, group_numbers):
+    cv_args = {"test_size": 1, "n_splits": 10}
+    expected_results = None
+    error_message = (
+        r"Not enough data to split number of groups \(6\)"
+        r" for number splits \(10\) with train size \(-4\),"
+        r" test size \(1\), gap size \(0\), shift size \(1\)"
+    )
+
+    with pytest.raises(ValueError, match=error_message):
+        check_splits(X, y, group_numbers, cv_args, expected_results)
+
+
+def test_too_large_train_size_n_splits(X, y, group_numbers):
+    cv_args = {"test_size": 1, "train_size": 10, "n_splits": 10}
+    expected_results = None
+    error_message = (
+        r"Not enough data to split number of groups \(6\)"
+        r" for number splits \(10\) with train size \(10\),"
+        r" test size \(1\), gap size \(0\), shift size \(1\)"
+    )
+
+    with pytest.raises(ValueError, match=error_message):
+        check_splits(X, y, group_numbers, cv_args, expected_results)
+
+
+def test_too_large_shift_size(X, y, group_numbers):
+    cv_args = {"test_size": 1, "n_splits": 3, "shift_size": 10}
+    expected_results = None
+    error_message = (
+        r"Not enough data to split number of groups \(6\)"
+        r" for number splits \(3\) with train size \(-15\),"
+        r" test size \(1\), gap size \(0\), shift size \(10\)"
+    )
+
+    with pytest.raises(ValueError, match=error_message):
+        check_splits(X, y, group_numbers, cv_args, expected_results)
+
+
+def test_too_large_gap_size(X, y, group_numbers):
+    cv_args = {"test_size": 1, "n_splits": 3, "gap_size": 10}
+    expected_results = None
+    error_message = (
+        r"Not enough data to split number of groups \(6\)"
+        r" for number splits \(3\) with train size \(-7\),"
+        r" test size \(1\), gap size \(10\), shift size \(1\)"
+    )
+
+    with pytest.raises(ValueError, match=error_message):
+        check_splits(X, y, group_numbers, cv_args, expected_results)
+
+
+def test_cross_val_score(X, y, group_numbers):
+    cv_args = {"test_size": 1, "train_size": 3}
+    cv = GroupTimeSeriesSplit(**cv_args)
+
+    expected_scores = np.array([0, 0.5, 0.25])
+    clf = DummyClassifier(strategy="most_frequent")
+    scoring = "accuracy"
+    cv_scores = cross_val_score(clf, X, y, groups=group_numbers, scoring=scoring, cv=cv)
+
+    assert np.array_equal(cv_scores, expected_scores)
diff --git a/mlxtend/evaluate/time_series.py b/mlxtend/evaluate/time_series.py
new file mode 100644
index 000000000..242c83b3c
--- /dev/null
+++ b/mlxtend/evaluate/time_series.py
@@ -0,0 +1,349 @@
+# mlxtend Machine Learning Library Extensions
+#
+# Time series cross validation with grouping.
+# Author: Dmitry Labazkin
+#
+# License: BSD 3 clause
+
+from itertools import accumulate, chain, groupby, islice
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.patches import Patch
+from matplotlib.ticker import MaxNLocator
+from sklearn.utils import indexable
+
+
+class GroupTimeSeriesSplit:
+    """Group time series cross-validator.
+
+    Parameters
+    ----------
+    test_size : int
+        Size of test dataset.
+    train_size : int (default=None)
+        Size of train dataset.
+    n_splits : int (default=None)
+        Number of the splits.
+    gap_size : int (default=0)
+        Gap size between train and test datasets.
+    shift_size : int (default=1)
+        Step to shift for the next fold.
+    window_type : str (default="rolling")
+        Type of the window. Possible values: "rolling", "expanding".
+
+    Examples
+    -----------
+    For usage examples, please see
+    http://rasbt.github.io/mlxtend/user_guide/evaluate/GroupTimeSeriesSplit/
+    """
+
+    def __init__(
+        self,
+        test_size,
+        train_size=None,
+        n_splits=None,
+        gap_size=0,
+        shift_size=1,
+        window_type="rolling",
+    ):
+
+        if (train_size is None) and (n_splits is None):
+            raise ValueError("Either train_size or n_splits should be defined")
+
+        if window_type not in ["rolling", "expanding"]:
+            raise ValueError('Window type can be either "rolling" or "expanding"')
+
+        if (train_size is not None) and (window_type == "expanding"):
+            raise ValueError("Train size can be specified only with rolling window")
+
+        self.test_size = test_size
+        self.train_size = train_size
+        self.n_splits = n_splits
+        self.gap_size = gap_size
+        self.shift_size = shift_size
+        self.window_type = window_type
+
+    def split(self, X, y=None, groups=None):
+        """Generate indices to split data into training and test set.
+
+        Parameters
+        ----------
+        X : array-like
+            Training data.
+        y : array-like (default=None)
+            Always ignored, exists for compatibility.
+        groups : array-like (default=None)
+            Array with group names or sequence numbers.
+
+        Yields
+        ------
+        train : ndarray
+            The training set indices for that split.
+        test : ndarray
+            The testing set indices for that split.
+        """
+        test_size = self.test_size
+        gap = self.gap_size
+        shift_size = self.shift_size
+        X, y, groups = indexable(X, y, groups)
+
+        if groups is None:
+            raise ValueError("The groups should be specified")
+
+        group_names, group_lengths = zip(
+            *[
+                (group_name, len(list(group_seq)))
+                for group_name, group_seq in groupby(groups)
+            ]
+        )
+        n_groups = len(group_names)
+
+        if n_groups != len(set(group_names)):
+            raise ValueError("The groups should be consecutive")
+
+        self._n_groups = n_groups
+        group_starts_idx = chain(
+            [0],
+            islice(accumulate(group_lengths), len(group_lengths) - 1),
+        )
+        groups_dict = dict(zip(group_names, group_starts_idx))
+        n_samples = len(X)
+
+        self._calculate_split_params()
+        train_size = self.train_size
+        n_splits = self.n_splits
+        train_start_idx = self._train_start_idx
+        train_end_idx = train_start_idx + train_size
+        test_start_idx = train_end_idx + gap
+        test_end_idx = test_start_idx + test_size
+
+        for _ in range(n_splits):
+            train_idx = np.r_[
+                slice(
+                    groups_dict[group_names[train_start_idx]],
+                    groups_dict[group_names[train_end_idx]],
+                )
+            ]
+
+            if test_end_idx < n_groups:
+                test_idx = np.r_[
+                    slice(
+                        groups_dict[group_names[test_start_idx]],
+                        groups_dict[group_names[test_end_idx]],
+                    )
+                ]
+            else:
+                test_idx = np.r_[
+                    slice(groups_dict[group_names[test_start_idx]], n_samples)
+                ]
+
+            yield train_idx, test_idx
+
+            if self.window_type == "rolling":
+                train_start_idx = train_start_idx + shift_size
+
+            train_end_idx = train_end_idx + shift_size
+            test_start_idx = test_start_idx + shift_size
+            test_end_idx = test_end_idx + shift_size
+
+    def get_n_splits(self, X=None, y=None, groups=None):
+        """Returns the number of splitting iterations in the cross-validator.
+
+        Parameters
+        ----------
+        X : object
+            Always ignored, exists for compatibility.
+        y : object
+            Always ignored, exists for compatibility.
+        groups : object
+            Always ignored, exists for compatibility.
+
+        Returns
+        -------
+        n_splits : int
+            Returns the number of splitting iterations in the cross-validator.
+        """
+        return self.n_splits
+
+    def _calculate_split_params(self):
+        train_size = self.train_size
+        test_size = self.test_size
+        n_splits = self.n_splits
+        gap = self.gap_size
+        shift_size = self.shift_size
+        n_groups = self._n_groups
+
+        not_enough_data_error = (
+            "Not enough data to split number of groups ({0})"
+            " for number splits ({1})"
+            " with train size ({2}), test size ({3}),"
+            " gap size ({4}), shift size ({5})"
+        )
+
+        if (train_size is None) and (n_splits is not None):
+            train_size = n_groups - gap - test_size - (n_splits - 1) * shift_size
+            self.train_size = train_size
+
+            if train_size <= 0:
+                raise ValueError(
+                    not_enough_data_error.format(
+                        n_groups,
+                        n_splits,
+                        train_size,
+                        test_size,
+                        gap,
+                        shift_size,
+                    )
+                )
+            train_start_idx = 0
+        elif (n_splits is None) and (train_size is not None):
+            n_splits = (n_groups - train_size - gap - test_size) // shift_size + 1
+            self.n_splits = n_splits
+
+            if self.n_splits <= 0:
+                raise ValueError(
+                    not_enough_data_error.format(
+                        n_groups,
+                        n_splits,
+                        train_size,
+                        test_size,
+                        gap,
+                        shift_size,
+                    )
+                )
+            train_start_idx = (
+                n_groups - train_size - gap - test_size - (n_splits - 1) * shift_size
+            )
+        else:
+            train_start_idx = (
+                n_groups - train_size - gap - test_size - (n_splits - 1) * shift_size
+            )
+
+            if train_start_idx < 0:
+                raise ValueError(
+                    not_enough_data_error.format(
+                        n_groups,
+                        n_splits,
+                        train_size,
+                        test_size,
+                        gap,
+                        shift_size,
+                    )
+                )
+
+        self._train_start_idx = train_start_idx
+
+
+def print_split_info(X, y, groups, **cv_args):
+    """Print information details about splits."""
+    cv = GroupTimeSeriesSplit(**cv_args)
+    groups = np.array(groups)
+
+    for train_idx, test_idx in cv.split(X, groups=groups):
+        print("Train indices:", train_idx)
+        print("Test indices:", test_idx)
+        print("Train length:", len(train_idx))
+        print("Test length:", len(test_idx))
+        print("Train groups:", groups[train_idx])
+        print("Test groups:", groups[test_idx])
+        print("Train group size:", len(set(groups[train_idx])))
+        print("Test group size:", len(set(groups[test_idx])))
+        print("Train group months:", X.index[train_idx].values)
+        print("Test group months:", X.index[test_idx].values)
+        print()
+
+
+def plot_split_indices(cv, cv_args, X, y, groups, n_splits, image_file_path=None):
+    """Create a sample plot for indices of a cross-validation object."""
+    fig, ax = plt.subplots(figsize=(12, 4))
+    cmap_data = plt.cm.tab20
+    cmap_cv = plt.cm.coolwarm
+    lw = 10
+    marker_size = 200
+
+    for split_idx, (train_idx, test_idx) in enumerate(
+        cv.split(X=X, y=y, groups=groups)
+    ):
+        indices = np.array([np.nan] * len(X))
+        indices[test_idx] = 1
+        indices[train_idx] = 0
+
+        ax.scatter(
+            range(len(X)),
+            [split_idx + 0.5] * len(X),
+            c=indices,
+            marker="_",
+            lw=lw,
+            cmap=cmap_cv,
+            vmin=-0.4,
+            vmax=1.4,
+            s=marker_size,
+        )
+
+    ax.scatter(
+        range(len(X)),
+        [split_idx + 1.5] * len(X),
+        c=groups,
+        marker="_",
+        lw=lw,
+        cmap=cmap_data,
+        s=marker_size,
+    )
+
+    yticklabels = list(range(n_splits)) + ["group"]
+    ax.set(
+        yticks=np.arange(n_splits + 1) + 0.5,
+        yticklabels=yticklabels,
+        ylabel="CV iteration",
+        ylim=[n_splits + 1.2, -0.2],
+        xlim=[-0.5, len(indices) - 0.5],
+    )
+
+    ax.legend(
+        [Patch(color=cmap_cv(0.2)), Patch(color=cmap_cv(0.8))],
+        ["Training set", "Testing set"],
+        loc=(1.02, 0.8),
+        fontsize=13,
+    )
+
+    ax.set_title("{}\n{}".format(type(cv).__name__, cv_args), fontsize=15)
+    ax.xaxis.set_major_locator(MaxNLocator(min_n_ticks=len(X), integer=True))
+    ax.set_xlabel(xlabel="Sample index", fontsize=13)
+    ax.set_ylabel(ylabel="CV iteration", fontsize=13)
+    ax.tick_params(axis="both", which="major", labelsize=13)
+    ax.tick_params(axis="both", which="minor", labelsize=13)
+
+    plt.tight_layout()
+
+    if image_file_path:
+        plt.savefig(image_file_path, bbox_inches="tight")
+
+    plt.show()
+
+
+def plot_splits(X, y, groups, image_file_path=None, **cv_args):
+    """Visualize splits by group."""
+    cv = GroupTimeSeriesSplit(**cv_args)
+    cv._n_groups = len(np.unique(groups))
+    cv._calculate_split_params()
+    n_splits = cv.n_splits
+
+    plot_split_indices(
+        cv, cv_args, X, y, groups, n_splits, image_file_path=image_file_path
+    )
+
+
+def print_cv_info(cv, X, y, groups, clf, scores):
+    """Print information details about cross-validation usage with classifier."""
+    for split_idx, (train_idx, test_idx) in enumerate(cv.split(X, y, groups)):
+        clf.fit(X.iloc[train_idx], y.iloc[train_idx])
+        y_train_pred = clf.predict(X.iloc[train_idx])
+        y_test_pred = clf.predict(X.iloc[test_idx])
+        print(f"Split number: {split_idx + 1}")
+        print(f"Train true target: {y.iloc[train_idx].values}")
+        print(f"Train predicted target: {y_train_pred}")
+        print(f"Test true target: {y.iloc[test_idx].values}")
+        print(f"Test predicted target: {y_test_pred}")
+        print(f"Accuracy: {scores[split_idx].round(2)}")
+        print()
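
Reviewer note: the arithmetic in `_calculate_split_params` is easier to check in isolation. The sketch below restates the same three formulas as a standalone function, with all sizes measured in groups rather than samples. `derive_split_params` is a hypothetical name used only for illustration; it is not part of mlxtend, and its error messages are simplified stand-ins for the ones in the diff.

```python
def derive_split_params(
    n_groups, test_size, train_size=None, n_splits=None, gap_size=0, shift_size=1
):
    """Hypothetical helper (not in mlxtend).

    Returns (train_size, n_splits, train_start_idx), all counted in groups,
    using the same formulas as GroupTimeSeriesSplit._calculate_split_params.
    """
    if train_size is None and n_splits is None:
        raise ValueError("Either train_size or n_splits should be defined")

    if train_size is None:
        # Only n_splits is given: the first training window starts at group 0
        # and takes whatever is left after the gap, the test window and the shifts.
        train_size = n_groups - gap_size - test_size - (n_splits - 1) * shift_size
        if train_size <= 0:
            raise ValueError("Not enough groups for the requested n_splits")
        return train_size, n_splits, 0

    if n_splits is None:
        # Only train_size is given: fit as many shifted windows as possible.
        n_splits = (n_groups - train_size - gap_size - test_size) // shift_size + 1
        if n_splits <= 0:
            raise ValueError("Not enough groups for the requested train_size")

    # The last test window is aligned with the end of the data, so leading
    # groups that do not fit a whole number of shifts are skipped.
    train_start_idx = (
        n_groups - train_size - gap_size - test_size - (n_splits - 1) * shift_size
    )
    if train_start_idx < 0:
        raise ValueError("Not enough groups for train_size and n_splits together")
    return train_size, n_splits, train_start_idx


# Six groups, as in the test fixtures of this diff.
print(derive_split_params(6, test_size=1, train_size=3))
print(derive_split_params(6, test_size=1, n_splits=3))
print(derive_split_params(6, test_size=1, train_size=2, shift_size=2))
```

With six groups, `train_size=3` and `n_splits=3` describe the same three-fold layout, which is why `test_cross_val_score` expects three scores. The third call shows the case where the first group is left unused because the windows are aligned to the end of the data.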