-
Notifications
You must be signed in to change notification settings - Fork 2
/
algorithm_factories.py
102 lines (80 loc) · 2.78 KB
/
algorithm_factories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import warnings
from typing import Callable
from gpflow.kernels import (
Linear,
LinearCoregionalization,
Matern52,
Polynomial,
SquaredExponential,
)
from algorithms.abstract_algorithm import AbstractAlgorithm
from algorithms.gmm_regression import GMMRegression
from algorithms.gp_on_real_space import GPonRealSpace
from algorithms.KNN import KNN
from algorithms.one_hot_gp import GPOneHotSequenceSpace
from algorithms.random_forest import RandomForest
from algorithms.uncertain_rf import UncertainRandomForest
from util.mlflow.constants import ONE_HOT
def RandomForestFactory(representation, alphabet, optimize):
    """Create a ``RandomForest`` model.

    ``representation`` and ``alphabet`` are ignored; they are accepted only so
    that this factory has the same ``(representation, alphabet, optimize)``
    signature as the other factories in this module.

    :param optimize: passed through unchanged to ``RandomForest``.
    :return: a new ``RandomForest`` instance.
    """
    model = RandomForest(optimize=optimize)
    return model
def UncertainRFFactory(representation, alphabet, optimize):
    """Create an ``UncertainRandomForest`` model.

    ``representation`` and ``alphabet`` are ignored; they are accepted only so
    that this factory has the same ``(representation, alphabet, optimize)``
    signature as the other factories in this module.

    :param optimize: passed through unchanged to ``UncertainRandomForest``.
    :return: a new ``UncertainRandomForest`` instance.
    """
    model = UncertainRandomForest(optimize=optimize)
    return model
def KNNFactory(representation, alphabet, optimize):
    """Create a ``KNN`` model.

    ``representation`` and ``alphabet`` are ignored; they are accepted only so
    that this factory has the same ``(representation, alphabet, optimize)``
    signature as the other factories in this module.

    :param optimize: passed through unchanged to ``KNN``.
    :return: a new ``KNN`` instance.
    """
    model = KNN(optimize=optimize)
    return model
def GPLinearFactory(representation, alphabet, optimize):
    """Create a GP model without an explicit kernel factory.

    No ``kernel_factory`` is passed, so each GP class falls back to its own
    default kernel (presumably linear, given this factory's name -- confirm
    in the GP classes).

    :param representation: representation identifier; if it equals
        ``ONE_HOT`` the one-hot sequence-space GP is built.
    :param alphabet: sequence alphabet; only ``len(alphabet)`` is used, and
        only in the one-hot case.
    :param optimize: passed through unchanged to the GP constructor.
    :return: ``GPOneHotSequenceSpace`` for one-hot input, otherwise
        ``GPonRealSpace``.
    """
    # Compare by value, not identity: ``is`` only matches when the caller
    # passes the very same object as ONE_HOT. An equal-but-distinct value
    # (e.g. a string read from a config or CLI argument, if ONE_HOT is a
    # string) would silently fall through to the real-space GP.
    if representation == ONE_HOT:
        return GPOneHotSequenceSpace(alphabet_size=len(alphabet), optimize=optimize)
    return GPonRealSpace(optimize=optimize)
def GPSEFactory(representation, alphabet, optimize):
    """Create a GP model with a squared-exponential (RBF) kernel.

    :param representation: representation identifier; if it equals
        ``ONE_HOT`` the one-hot sequence-space GP is built.
    :param alphabet: sequence alphabet; only ``len(alphabet)`` is used, and
        only in the one-hot case.
    :param optimize: passed through unchanged to the GP constructor.
    :return: ``GPOneHotSequenceSpace`` for one-hot input, otherwise
        ``GPonRealSpace``; both are given a factory producing a fresh
        ``SquaredExponential`` kernel.
    """
    # Compare by value, not identity: ``is`` only matches when the caller
    # passes the very same object as ONE_HOT, so an equal-but-distinct value
    # would silently select the real-space GP.
    if representation == ONE_HOT:
        return GPOneHotSequenceSpace(
            alphabet_size=len(alphabet),
            kernel_factory=lambda: SquaredExponential(),
            optimize=optimize,
        )
    return GPonRealSpace(
        kernel_factory=lambda: SquaredExponential(), optimize=optimize
    )
def GPMaternFactory(representation, alphabet, optimize):
    """Create a GP model with a Matern-5/2 kernel.

    :param representation: representation identifier; if it equals
        ``ONE_HOT`` the one-hot sequence-space GP is built.
    :param alphabet: sequence alphabet; only ``len(alphabet)`` is used, and
        only in the one-hot case.
    :param optimize: passed through unchanged to the GP constructor.
    :return: ``GPOneHotSequenceSpace`` for one-hot input, otherwise
        ``GPonRealSpace``; both are given a factory producing a fresh
        ``Matern52`` kernel.
    """
    # Compare by value, not identity: ``is`` only matches when the caller
    # passes the very same object as ONE_HOT, so an equal-but-distinct value
    # would silently select the real-space GP.
    if representation == ONE_HOT:
        return GPOneHotSequenceSpace(
            alphabet_size=len(alphabet),
            kernel_factory=lambda: Matern52(),
            optimize=optimize,
        )
    return GPonRealSpace(kernel_factory=lambda: Matern52(), optimize=optimize)
def GPLinearRegionFactory(representation, alphabet, optimize):
    """Create a GP model whose kernel is a ``LinearCoregionalization`` of two ``Linear`` kernels.

    :param representation: representation identifier; if it equals
        ``ONE_HOT`` the one-hot sequence-space GP is built.
    :param alphabet: sequence alphabet; only ``len(alphabet)`` is used, and
        only in the one-hot case.
    :param optimize: passed through unchanged to the GP constructor.
    :return: ``GPOneHotSequenceSpace`` for one-hot input, otherwise
        ``GPonRealSpace``.
    """

    # NOTE(review): gpflow's LinearCoregionalization appears to require a
    # mixing matrix ``W`` in addition to ``kernels``, so calling this factory
    # would likely raise TypeError -- confirm against the installed gpflow
    # version. ``Linear(0)`` / ``Linear(1)`` pass 0 and 1 as the first
    # positional argument (``variance`` in gpflow's Linear); possibly
    # ``active_dims`` was intended. This factory is also not registered in
    # ALGORITHM_REGISTRY.
    def kernel_factory():
        # Build a fresh kernel on every call, exactly as the original lambdas did.
        return LinearCoregionalization(kernels=[Linear(0), Linear(1)])

    # Compare by value, not identity: ``is`` only matches when the caller
    # passes the very same object as ONE_HOT.
    if representation == ONE_HOT:
        return GPOneHotSequenceSpace(
            alphabet_size=len(alphabet),
            kernel_factory=kernel_factory,
            optimize=optimize,
        )
    return GPonRealSpace(kernel_factory=kernel_factory, optimize=optimize)
def GMMFactory(representation, alphabet, optimize, n_components=2):
    """Create a ``GMMRegression`` model.

    :param representation: ignored; accepted to match the common factory
        signature.
    :param alphabet: ignored; accepted to match the common factory signature.
    :param optimize: ignored (see note below).
    :param n_components: passed positionally as the first argument of
        ``GMMRegression``; defaults to 2.
    :return: a new ``GMMRegression`` instance.
    """
    # NOTE(review): unlike the other factories, ``optimize`` is not forwarded
    # here -- confirm whether GMMRegression accepts/needs it.
    model = GMMRegression(n_components)
    return model
def get_key_for_factory(f: Callable[..., AbstractAlgorithm]) -> str:
    """Return the registry key under which factory ``f`` is stored.

    The key is the factory's ``__name__``. The annotation uses ``...`` because
    the factories in this module take ``(representation, alphabet, optimize)``
    arguments; the previous ``Callable[[], ...]`` wrongly described a
    zero-argument callable.

    :param f: an algorithm factory function.
    :return: ``f.__name__``.
    """
    return f.__name__
# Registry mapping each factory's key (from get_key_for_factory, i.e. its
# function name) to the factory itself. Insertion order follows the tuple below.
# NOTE(review): GPLinearRegionFactory is defined in this module but is not
# listed here.
ALGORITHM_REGISTRY = {}
for _factory in (
    RandomForestFactory,
    UncertainRFFactory,
    KNNFactory,
    GPLinearFactory,
    GPSEFactory,
    GPMaternFactory,
    GMMFactory,
):
    ALGORITHM_REGISTRY[get_key_for_factory(_factory)] = _factory
# Drop the loop variable so the module namespace matches the comprehension form.
del _factory