-
Notifications
You must be signed in to change notification settings - Fork 3
/
kernels.py
54 lines (45 loc) · 1.66 KB
/
kernels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from operator import mul, add
from functools import reduce
import gpflowSlim.kernels as gfsk
import gpflowSlim as gfs
import tensorflow as tf
def SpectralMixture(params, name):
    """
    Assemble a spectral-mixture kernel as a weighted sum of RBF * Cosine terms.

    :param params: list of dict, one entry per mixture component, each
        formatted as {'w': float, 'rbf': dict, 'cos': dict}. 'w' is the
        initial value of the component weight; 'rbf' and 'cos' are the
        keyword arguments forwarded to gfsk.RBF and gfsk.Cosine.
    :param name: name of the TensorFlow variable scope (opened with
        reuse=tf.AUTO_REUSE) inside which every component is created.
    :return: the accumulated sum of all components. When params is empty
        nothing is added and the initial float 0. is returned.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        mixture = 0.
        for idx, component in enumerate(params):
            # The weight parameter is created first, using the positive
            # transform, and is named after the component index.
            weight = gfs.Param(
                component['w'],
                transform=gfs.transforms.positive,
                name='w' + str(idx),
            )
            rbf = gfsk.RBF(**component['rbf'])
            cosine = gfsk.Cosine(**component['cos'])
            # Keep the new component as the left operand and the running
            # total as the right operand of the addition.
            mixture = rbf * cosine * weight.value + mixture
    return mixture
# Lookup table from kernel name to the callable that builds that kernel.
# 'ExpQuad' and 'RBF' are two names for the same gfsk.RBF class; 'SM' maps to
# the SpectralMixture builder function defined in this module rather than to
# a gfsk class.
_KERNEL_DICT = {
    'White': gfsk.White,
    'Constant': gfsk.Constant,
    'ExpQuad': gfsk.RBF,
    'RBF': gfsk.RBF,
    'Matern12': gfsk.Matern12,
    'Matern32': gfsk.Matern32,
    'Matern52': gfsk.Matern52,
    'Cosine': gfsk.Cosine,
    'ArcCosine': gfsk.ArcCosine,
    'Linear': gfsk.Linear,
    'Periodic': gfsk.Periodic,
    'RatQuad': gfsk.RatQuad,
    'SM': SpectralMixture,
}
def KernelWrapper(hparams):
    """
    Instantiate a list of primitive kernels from their specifications.

    :param hparams: non-empty list of dict. Each item describes one primitive
        kernel and is formatted as {'name': XXX, 'params': XXX}, where 'name'
        is a key of _KERNEL_DICT and 'params' holds the keyword arguments
        passed to the matching builder.
        e.g.
        [{'name': 'Linear', 'params': {'c': 0.1, 'input_dim': 100}},
         {'name': 'Periodic', 'params': {'period': 2, 'input_dim': 100, 'ls': 2}}]
    :return: list of the constructed kernels, in the same order as hparams.
    :raises AssertionError: if hparams is empty.
    :raises KeyError: if a 'name' is not a key of _KERNEL_DICT.
    """
    assert len(hparams) > 0, 'At least one kernel should be provided.'
    kernels = []
    # All kernels are built inside one shared 'KernelWrapper' variable scope.
    with tf.variable_scope('KernelWrapper'):
        for spec in hparams:
            builder = _KERNEL_DICT[spec['name']]
            kernels.append(builder(**spec['params']))
    return kernels