Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add CombinatorialGapKFold #41

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ conda install -c conda-forge tscv

## Usage

This extension defines 3 cross-validator classes and 1 function:
This extension defines 4 cross-validator classes and 1 function:
- `GapLeavePOut`
- `GapKFold`
- `GapRollForward`
- `CombinatorialGapKFold`
- `gap_train_test_split`

The three classes can all be passed, as the `cv` argument, to
The four classes can all be passed, as the `cv` argument, to
scikit-learn functions such as `cross-validate`, `cross_val_score`,
and `cross_val_predict`, just like the native cross-validator classes.
and `cross_val_predict` (except `CombinatorialGapKFold`), just like the native cross-validator classes.

The one function is an alternative to the `train_test_split` function in `scikit-learn`.

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ def get_version(rel_path):
'Programming Language :: Python :: 3.9',
],
python_requires=">=3.6",
install_requires=['numpy>=1.13.3', 'scikit-learn>=0.22']
install_requires=['numpy>=1.13.3', 'scipy>=1.3.0', 'scikit-learn>=0.22']
)
2 changes: 2 additions & 0 deletions tscv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ._split import GapKFold
from ._split import GapWalkForward
from ._split import GapRollForward
from ._split import CombinatorialGapKFold
from ._split import gap_train_test_split


Expand All @@ -13,4 +14,5 @@
'GapKFold',
'GapWalkForward',
'GapRollForward',
'CombinatorialGapKFold',
'gap_train_test_split']
119 changes: 118 additions & 1 deletion tscv/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import numbers
from math import modf
from abc import ABCMeta, abstractmethod
from itertools import chain
from itertools import chain, combinations
from inspect import signature
from scipy.special import comb
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add scipy in

TSCV/setup.py

Line 55 in 2abbc3d

install_requires=['numpy>=1.13.3', 'scikit-learn>=0.22']


import numpy as np
from sklearn.utils import indexable
Expand All @@ -24,6 +25,7 @@
'GapLeavePOut',
'GapKFold',
'GapWalkForward',
'CombinatorialGapKFold',
'gap_train_test_split']


Expand Down Expand Up @@ -371,6 +373,121 @@ def get_n_splits(self, X=None, y=None, groups=None):
return self.n_splits


class CombinatorialGapKFold(GapCrossValidator):
"""Combinatorial K-Folds cross-validator with Gaps

Provides train/test indices to split data in train/test sets. Split
dataset into N groups of k folds (without shuffling).

Parameters
----------
N : int, default=5
Number of groups. Must be at least 2.

k : int, default=2
Number of test splits. Must be at least 1.

gap_before : int, default=0
Gap before the test sets.

gap_after : int, default=0
Gap after the test sets.

Examples
--------
>>> import numpy as np
>>> from tscv import CombinatorialGapKFold
>>> cgkf = CombinatorialGapKFold(N=5, k=2, gap_before=1, gap_after=1)
>>> cgkf.get_n_splits(np.arange(10))
10
>>> print(cgkf)
CombinatorialGapKFold(N=None, gap_after=1, gap_before=1, k=None)
>>> for train_index, test_index in cgkf.split(np.arange(10)):
... print("TRAIN:", train_index, "TEST:", test_index)
TRAIN: [5 6 7 8 9] TEST: [0 1 2 3]
TRAIN: [7 8 9] TEST: [0 1 4 5]
TRAIN: [3 4 9] TEST: [0 1 6 7]
TRAIN: [3 4 5 6] TEST: [0 1 8 9]
TRAIN: [0 7 8 9] TEST: [2 3 4 5]
TRAIN: [0 9] TEST: [2 3 6 7]
TRAIN: [0 5 6] TEST: [2 3 8 9]
TRAIN: [0 1 2 9] TEST: [4 5 6 7]
TRAIN: [0 1 2] TEST: [4 5 8 9]
TRAIN: [0 1 2 3 4] TEST: [6 7 8 9]
"""

def __init__(self, N=5, k=2, gap_before=0, gap_after=0):
if not isinstance(N, numbers.Integral):
raise ValueError('The number of groups must be of Integral type. '
'%s of type %s was passed.'
% (N, type(N)))
N = int(N)

if not isinstance(k, numbers.Integral):
raise ValueError('The number of test splits must be of Integral '
'type. %s of type %s was passed.'
% (k, type(k)))
k = int(k)

if N <= 1:
raise ValueError(
"Combinatorial k-fold cross-validation requires at least two"
" groups by setting N=2 or more,"
" got N={0}.".format(N))

if k < 1:
raise ValueError(
"Combinatorial k-fold cross-validation requires at least one"
" test split by setting k=1 or more,"
" got k={0}.".format(k))

super().__init__(gap_before, gap_after)
self.n_groups = N
self.test_splits = k

def _iter_test_indices(self, X, y=None, groups=None):
n_samples = _num_samples(X)
n_splits = self.n_groups
gap_before, gap_after = self.gap_before, self.gap_after
if n_splits > n_samples:
raise ValueError(
("Cannot have number of splits n_splits={0} greater"
" than the number of samples: n_samples={1}.")
.format(self.n_groups, n_samples))
self.indexes = np.arange(n_samples)
splits = [(split[0], split[-1] + 1)
for split
in np.array_split(self.indexes, self.n_groups)]
splits_combinations = list(combinations(splits, self.test_splits))
for splits_combination in splits_combinations:
test_indexes = np.empty(0)
for start, stop in splits_combination:
test_indexes = np.union1d(
test_indexes, self.indexes[start:stop]).astype(int)
yield test_indexes

def get_n_splits(self, X=None, y=None, groups=None):
"""Returns the number of splitting iterations in the cross-validator

Parameters
----------
X : object
Always ignored, exists for compatibility.

y : object
Always ignored, exists for compatibility.

groups : object
Always ignored, exists for compatibility.

Returns
-------
n_splits : int
Returns the number of splitting iterations in the cross-validator.
"""
return int(comb(self.n_groups, self.test_splits))


def gap_train_test_split(*arrays, **options):
"""Split arrays or matrices into random train and test subsets (with a gap)

Expand Down
130 changes: 130 additions & 0 deletions tscv/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tscv import GapKFold
from tscv import GapWalkForward
from tscv import GapRollForward
from tscv import CombinatorialGapKFold
from tscv import gap_train_test_split


Expand Down Expand Up @@ -568,3 +569,132 @@ def test_roll_size(self):
train, test = next(splits)
assert_array_equal(train, [0, 1, 2, 3])
assert_array_equal(test, [6])


def test_combinatorial_gap_k_fold():
splits = CombinatorialGapKFold().split(np.arange(10))

train, test = next(splits)
assert_array_equal(train, [4, 5, 6, 7, 8, 9])
assert_array_equal(test, [0, 1, 2, 3])

train, test = next(splits)
assert_array_equal(train, [2, 3, 6, 7, 8, 9])
assert_array_equal(test, [0, 1, 4, 5])

train, test = next(splits)
assert_array_equal(train, [2, 3, 4, 5, 8, 9])
assert_array_equal(test, [0, 1, 6, 7])

train, test = next(splits)
assert_array_equal(train, [2, 3, 4, 5, 6, 7])
assert_array_equal(test, [0, 1, 8, 9])

train, test = next(splits)
assert_array_equal(train, [0, 1, 6, 7, 8, 9])
assert_array_equal(test, [2, 3, 4, 5])

train, test = next(splits)
assert_array_equal(train, [0, 1, 4, 5, 8, 9])
assert_array_equal(test, [2, 3, 6, 7])

train, test = next(splits)
assert_array_equal(train, [0, 1, 4, 5, 6, 7])
assert_array_equal(test, [2, 3, 8, 9])

train, test = next(splits)
assert_array_equal(train, [0, 1, 2, 3, 8, 9])
assert_array_equal(test, [4, 5, 6, 7])

train, test = next(splits)
assert_array_equal(train, [0, 1, 2, 3, 6, 7])
assert_array_equal(test, [4, 5, 8, 9])

train, test = next(splits)
assert_array_equal(train, [0, 1, 2, 3, 4, 5])
assert_array_equal(test, [6, 7, 8, 9])

splits = CombinatorialGapKFold(
N=5, k=2, gap_before=1, gap_after=2).split(np.arange(10))

train, test = next(splits)
assert_array_equal(train, [6, 7, 8, 9])
assert_array_equal(test, [0, 1, 2, 3])

train, test = next(splits)
assert_array_equal(train, [8, 9])
assert_array_equal(test, [0, 1, 4, 5])

train, test = next(splits)
assert_array_equal(train, [4])
assert_array_equal(test, [0, 1, 6, 7])

train, test = next(splits)
assert_array_equal(train, [4, 5, 6])
assert_array_equal(test, [0, 1, 8, 9])

train, test = next(splits)
assert_array_equal(train, [0, 8, 9])
assert_array_equal(test, [2, 3, 4, 5])

train, test = next(splits)
assert_array_equal(train, [0])
assert_array_equal(test, [2, 3, 6, 7])

train, test = next(splits)
assert_array_equal(train, [0, 6])
assert_array_equal(test, [2, 3, 8, 9])

train, test = next(splits)
assert_array_equal(train, [0, 1, 2])
assert_array_equal(test, [4, 5, 6, 7])

train, test = next(splits)
assert_array_equal(train, [0, 1, 2])
assert_array_equal(test, [4, 5, 8, 9])

train, test = next(splits)
assert_array_equal(train, [0, 1, 2, 3, 4])
assert_array_equal(test, [6, 7, 8, 9])

assert_equal(CombinatorialGapKFold(
N=10, k=3).get_n_splits(np.arange(100)), 120)

splits = CombinatorialGapKFold(
N=5, k=3, gap_before=1, gap_after=2).split(np.arange(10))

train, test = next(splits)
assert_array_equal(train, [8, 9])
assert_array_equal(test, [0, 1, 2, 3, 4, 5])

with pytest.raises(
ValueError,
match="The number of groups must be of Integral type. 5.0 of type <class 'float'> was passed."
):
CombinatorialGapKFold(N=5.0, k=2)

with pytest.raises(
ValueError,
match="The number of test splits must be of Integral type. 2 of type <class 'str'> was passed."
):
CombinatorialGapKFold(N=5, k="2")

with pytest.raises(
ValueError,
match="Combinatorial k-fold cross-validation requires at least two groups by setting N=2 or more, got N=1."
):
CombinatorialGapKFold(N=1, k=2)

with pytest.raises(
ValueError,
match="Combinatorial k-fold cross-validation requires at least one test split by setting k=1 or more, got k=0."
):
CombinatorialGapKFold(N=5, k=0)

splits = CombinatorialGapKFold(N=15, k=2).split(np.arange(10))

with pytest.raises(
ValueError,
match="Cannot have number of splits n_splits=15 greater than the number of samples: n_samples=10."
):
next(splits)