-
Notifications
You must be signed in to change notification settings - Fork 2
/
protocol_factories.py
89 lines (69 loc) · 2.28 KB
/
protocol_factories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""
Reference cross-validation (CV) splitter factories, shared across all experimental runs and plotting.
"""
from typing import List
import numpy as np
from algorithm_factories import get_key_for_factory
from data.train_test_split import (
AbstractTrainTestSplitter,
BioSplitter,
BlockPostionSplitter,
FractionalRandomSplitter,
PositionSplitter,
RandomSplitter,
WeightedTaskSplitter,
)
# Splitting protocols: random, block, positional, fractional-random, mutation-level, weighted-task
def BioSplitterFactory(
    dataset: str, n_mutations_train: int, n_mutations_test: int
) -> List[AbstractTrainTestSplitter]:
    """Build the mutation-level splitting protocol for ``dataset``.

    Args:
        dataset: Dataset identifier, passed positionally to ``BioSplitter``.
        n_mutations_train: Forwarded to ``BioSplitter`` as ``n_mutations_train``.
        n_mutations_test: Forwarded to ``BioSplitter`` as ``n_mutations_test``.

    Returns:
        A single-element list holding the configured ``BioSplitter``.
    """
    mutation_splitter = BioSplitter(
        dataset,
        n_mutations_train=n_mutations_train,
        n_mutations_test=n_mutations_test,
    )
    return [mutation_splitter]
def FractionalSplitterFactory(
    dataset: str, fractions: np.ndarray = None
) -> List[AbstractTrainTestSplitter]:
    """Build one fractional-random splitter per requested fraction.

    Args:
        dataset: Dataset identifier, passed positionally to every
            ``FractionalRandomSplitter``.
        fractions: Fractions to build splitters for. A 1-D ``np.ndarray`` or
            any sized iterable (e.g. a list) is accepted. ``None`` or an empty
            sequence selects the default grid, which concatenates
            ``np.arange(0.001, 0.3, 0.01)``, ``np.arange(0.3, 0.6, 0.03)`` and
            ``np.arange(0.6, 1.05, 0.05)``.

    Returns:
        One ``FractionalRandomSplitter`` per entry of ``fractions``, in order.
    """
    # Check for None / emptiness explicitly: ``not fractions`` raises
    # ValueError ("truth value of an array ... is ambiguous") when a NumPy
    # array with more than one element is passed.
    if fractions is None or len(fractions) == 0:
        fractions = np.concatenate(
            [
                np.arange(0.001, 0.3, 0.01),
                np.arange(0.3, 0.6, 0.03),
                np.arange(0.6, 1.05, 0.05),
            ]
        )
    return [FractionalRandomSplitter(dataset, fraction) for fraction in fractions]
def PositionalSplitterFactory(
    dataset: str, positions: int = 15
) -> List[AbstractTrainTestSplitter]:
    """Build the positional splitting protocol for ``dataset``.

    Args:
        dataset: Dataset identifier, passed positionally to ``PositionSplitter``.
        positions: Forwarded to ``PositionSplitter`` as ``positions``
            (default 15).

    Returns:
        A single-element list holding the configured ``PositionSplitter``.
    """
    position_splitter = PositionSplitter(dataset, positions=positions)
    return [position_splitter]
def BlockSplitterFactory(dataset) -> List[AbstractTrainTestSplitter]:
    """Build the block-position splitting protocol for ``dataset``.

    Args:
        dataset: Dataset identifier, passed positionally to
            ``BlockPostionSplitter``.

    Returns:
        A single-element list holding a ``BlockPostionSplitter`` constructed
        with no extra arguments.
    """
    block_splitter = BlockPostionSplitter(dataset)
    return [block_splitter]
def RandomSplitterFactory(dataset, k: int = 10) -> List[AbstractTrainTestSplitter]:
    """Build the random splitting protocol for ``dataset``.

    Args:
        dataset: Dataset identifier, passed positionally to ``RandomSplitter``.
        k: Forwarded to ``RandomSplitter`` as ``k`` (default 10); presumably
            the number of CV folds — confirm against ``RandomSplitter``.

    Returns:
        A single-element list holding the configured ``RandomSplitter``.
    """
    random_splitter = RandomSplitter(dataset, k=k)
    return [random_splitter]
def WeightedTaskSplitterFactory(
    dataset, threshold=3
) -> List[AbstractTrainTestSplitter]:
    """Build the weighted-task splitting protocol for ``dataset``.

    Args:
        dataset: Dataset identifier, passed positionally to
            ``WeightedTaskSplitter``.
        threshold: Forwarded to ``WeightedTaskSplitter`` as ``threshold``
            (default 3).

    Returns:
        A single-element list holding the configured ``WeightedTaskSplitter``
        (no ``split_type`` or ``X_p_fraction`` overrides).
    """
    weighted_splitter = WeightedTaskSplitter(dataset, threshold=threshold)
    return [weighted_splitter]
def WeightedTaskRegressSplitterFactory(
    dataset, threshold=0.5, X_p_fraction: float = 0.15
) -> List[AbstractTrainTestSplitter]:
    """Build the threshold-based weighted-task protocol for regression runs.

    Args:
        dataset: Dataset identifier, passed positionally to
            ``WeightedTaskSplitter``.
        threshold: Forwarded to ``WeightedTaskSplitter`` as ``threshold``
            (default 0.5).
        X_p_fraction: Forwarded to ``WeightedTaskSplitter`` as
            ``X_p_fraction``. Intended as the fraction of functional
            observations used for training, with the rest used for testing.
            Defaults to 0.15, the value this factory previously hard-coded.

    Returns:
        A single-element list holding a ``WeightedTaskSplitter`` configured
        with ``split_type="threshold"``.
    """
    return [
        WeightedTaskSplitter(
            dataset,
            threshold=threshold,
            X_p_fraction=X_p_fraction,
            split_type="threshold",
        )
    ]
# Map each registered factory's key (as computed by get_key_for_factory) to
# the factory itself. Only the factories listed here are registered; the
# weighted-task factories above are not included in this mapping.
PROTOCOL_REGISTRY = {
    get_key_for_factory(factory): factory
    for factory in (
        RandomSplitterFactory,
        BlockSplitterFactory,
        PositionalSplitterFactory,
        FractionalSplitterFactory,
        BioSplitterFactory,
    )
}