Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for random walk kernel #29

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added GP/__init__.py
Empty file.
Empty file added GP/kernel_modules/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions GP/kernel_modules/kernel_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Author: Henry Moss & Ryan-Rhys Griffiths
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can change the author name to yours!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha was waiting for permission

"""
Utility methods for graph-based kernels
"""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we change this module-level docstring?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably


import tensorflow as tf


def normalize(k_matrix):
    """
    Normalize a kernel matrix so that its diagonal entries are all ones.

    Divides each entry k(i, j) by sqrt(k(i, i) * k(j, j)), i.e. the cosine
    normalization commonly applied to graph kernels.

    :param k_matrix: square kernel matrix (tensor or array)
    :return: normalized kernel matrix as a tensor
    """
    diagonal = tf.linalg.diag_part(k_matrix)
    # Outer product of the diagonal with itself: entry (i, j) = k(i,i) * k(j,j)
    outer_product = tf.multiply(tf.expand_dims(diagonal, 1),
                                tf.expand_dims(diagonal, 0))
    return tf.divide(k_matrix, tf.sqrt(outer_product))
138 changes: 138 additions & 0 deletions GP/kernel_modules/random_walk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Author: Henry Moss & Ryan-Rhys Griffiths
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having a kernel_modules directory is probably a good idea.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy for it to have another (less clumsy) name but yeah I think it was the only way for me to implement the kernel without contaminating pre-existing code

"""
Molecule kernels for Gaussian Process Regression implemented in GPflow.
"""

import gpflow
import numpy as np
import tensorflow as tf

from math import factorial

from .kernel_utils import normalize


class RandomWalk(gpflow.kernels.Kernel):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have some documentation for the class?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely. I suppose I hoped for the main logic to be looked at first but honestly I should have made a separate feature branch and done multiple PRs onto that (one for the logic, one for documentation etc)

def __init__(self, normalize=True, weight=0.1, series_type='geometric', p=None, uniform_probabilities=False):
    """
    Random walk graph kernel (Gartner et al. 2003).

    :param normalize: whether to cosine-normalize the resulting kernel matrix
    :param weight: decay weight applied to walks by length
    :param series_type: 'geometric' or 'exponential' weighting of the walk series
    :param p: if not None, truncate the power series after p terms;
        otherwise the closed-form limit of the series is used
    :param uniform_probabilities: if True, use uniform start/stop probabilities
        (1/n per node) instead of all-ones vectors
    :raises ValueError: if series_type is not a recognized option
    """
    super().__init__()
    self.normalize = normalize
    self.weight = weight
    if series_type == 'geometric':
        self.geometric = True
    elif series_type == 'exponential':
        self.geometric = False
    else:
        # Fail fast: previously an unknown series_type silently left
        # self.geometric unset, deferring the failure to an AttributeError
        # inside K().
        raise ValueError(
            f"series_type must be 'geometric' or 'exponential', got {series_type!r}"
        )
    self.p = p
    self.uniform_probabilities = uniform_probabilities

def K(self, X, X2=None):
    """
    Compute the random walk graph kernel (Gartner et al. 2003),
    specifically using the spectral decomposition approach
    given by https://www.jmlr.org/papers/volume11/vishwanathan10a/vishwanathan10a.pdf

    :param X: array of N graph objects (represented as adjacency matrices of varying sizes)
    :param X2: array of M graph objects (represented as adjacency matrices of varying sizes)
        If None, compute the N x N kernel matrix for X.
    :return: The kernel matrix of dimension N x M.
    """
    if X2 is None:
        X2 = X

    # Bug fix: use identity, not `==`. Element-wise comparison of two
    # distinct lists of tensors yields tensors whose truth value is
    # ambiguous/raises; identity is also the property actually needed to
    # reuse the eigendecompositions and exploit kernel symmetry below.
    X_is_X2 = X2 is X

    eigenvecs, eigenvals = [], []
    eigenvecs_2, eigenvals_2 = [], []

    # Adjacency matrices are symmetric, so eigh (self-adjoint eigensolver)
    # applies; cast to float64 for numerical stability.
    for adj_mat in X:
        val, vec = tf.linalg.eigh(tf.cast(adj_mat, tf.float64))
        eigenvals.append(val)
        eigenvecs.append(vec)

    flanking_factors = self._generate_flanking_factors(eigenvecs)

    if X_is_X2:
        # Reuse the decompositions instead of recomputing them.
        eigenvals_2, eigenvecs_2 = eigenvals, eigenvecs
        flanking_factors_2 = flanking_factors
    else:
        for adj_mat in X2:
            val, vec = tf.linalg.eigh(tf.cast(adj_mat, tf.float64))
            eigenvals_2.append(val)
            eigenvecs_2.append(vec)
        flanking_factors_2 = self._generate_flanking_factors(eigenvecs_2)

    k_matrix = np.zeros((len(X), len(X2)))

    for idx_1 in range(k_matrix.shape[0]):
        for idx_2 in range(k_matrix.shape[1]):

            # The kernel matrix is symmetric when X is X2: copy the
            # already-computed upper triangle instead of recomputing.
            if X_is_X2 and idx_2 < idx_1:
                k_matrix[idx_1, idx_2] = k_matrix[idx_2, idx_1]
                continue

            # Kronecker product of the two graphs' flanking (start/stop)
            # vectors in their respective eigenbases.
            flanking_factor = tf.linalg.LinearOperatorKronecker(
                [tf.linalg.LinearOperatorFullMatrix(flanking_factors[idx_1]),
                 tf.linalg.LinearOperatorFullMatrix(flanking_factors_2[idx_2])
                 ]).to_dense()

            # Eigenvalues of the weighted product-graph adjacency matrix:
            # the Kronecker product of the individual spectra, scaled.
            diagonal = self.weight * tf.linalg.LinearOperatorKronecker(
                [tf.linalg.LinearOperatorFullMatrix(tf.expand_dims(eigenvals[idx_1], axis=0)),
                 tf.linalg.LinearOperatorFullMatrix(tf.expand_dims(eigenvals_2[idx_2], axis=0))
                 ]).to_dense()

            if self.p is not None:
                # Truncated power series: sum_{k=0}^{p} diagonal^k,
                # with 1/k! factors for the exponential variant.
                power_series = tf.ones_like(diagonal)
                temp_diagonal = tf.ones_like(diagonal)

                for k in range(self.p):
                    temp_diagonal = tf.multiply(temp_diagonal, diagonal)
                    if not self.geometric:
                        temp_diagonal = tf.divide(temp_diagonal, factorial(k + 1))
                    power_series = tf.add(power_series, temp_diagonal)

                power_series = tf.linalg.diag(power_series)
            else:
                # Closed-form limits: geometric series 1/(1-x), or exp(x).
                if self.geometric:
                    power_series = tf.linalg.diag(1 / (1 - diagonal))
                else:
                    power_series = tf.linalg.diag(tf.exp(diagonal))

            # Quadratic form s^T f(Lambda) s; result is a 1x1 tensor that
            # NumPy accepts as a scalar assignment.
            k_matrix[idx_1, idx_2] = tf.linalg.matmul(
                flanking_factor,
                tf.linalg.matmul(
                    power_series,
                    tf.transpose(flanking_factor, perm=[1, 0])
                )
            )

    if self.normalize:
        return tf.convert_to_tensor(normalize(k_matrix))

    return tf.convert_to_tensor(k_matrix)

def K_diag(self, X):
    """
    Compute the diagonal of the N x N kernel matrix of X.

    Note: this evaluates the full kernel matrix and extracts its diagonal,
    which is wasteful but simple.

    :param X: array of N graph objects (represented as adjacency matrices of varying sizes).
    :return: N x 1 array.
    """
    full_kernel_matrix = self.K(X)
    return tf.linalg.tensor_diag_part(full_kernel_matrix)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't think of anything smarter than implementing this function like this


def _generate_flanking_factors(self, eigenvecs):
    """
    Helper method to calculate intermediate terms in the expression for the
    random walk kernel evaluated for two graphs: the start/stop probability
    row vector of each graph projected onto its eigenbasis.

    :param eigenvecs: array of N eigenvector matrices of varying sizes
    :return: array of N row vectors (1 x n_i matrices)
    """
    flanking_factors = []

    for eigenvec in eigenvecs:
        start_stop_probs = tf.ones((1, eigenvec.shape[0]), tf.float64)
        if self.uniform_probabilities:
            # Bug fix: `eigenvec.shape(0)` attempted to *call* the shape
            # object and raised a TypeError; index it to get the node count.
            start_stop_probs = tf.divide(start_stop_probs, eigenvec.shape[0])

        flanking_factors.append(
            tf.linalg.matmul(start_stop_probs, eigenvec)
        )

    return flanking_factors
Empty file added tests/__init__.py
Empty file.
Empty file added tests/kernels/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions tests/kernels/test_random_walk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Author: Henry Moss & Ryan-Rhys Griffiths
"""
Verifies the FlowMO implementation of the Random Walk graph kernel
against GraKel
"""

import os

import grakel
import numpy.testing as npt
import pandas as pd
import pytest
import tensorflow as tf
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

from GP.kernel_modules.random_walk import RandomWalk

@pytest.fixture
def load_data():
    """
    Load the first 50 ESOL molecules and return them both as TensorFlow
    adjacency-matrix tensors and as GraKel graph objects.
    """
    benchmark_path = os.path.abspath(
        os.path.join(os.getcwd(), '..', '..', 'datasets', 'ESOL.csv')
    )
    smiles_list = pd.read_csv(benchmark_path)["smiles"].to_list()

    # Avoid shadowing the list variable inside the loop.
    adj_mats = []
    for smiles_string in smiles_list[:50]:
        adj_mats.append(GetAdjacencyMatrix(MolFromSmiles(smiles_string)))

    tensor_adj_mats = [tf.convert_to_tensor(mat) for mat in adj_mats]
    grakel_graphs = [grakel.Graph(mat) for mat in adj_mats]

    return tensor_adj_mats, grakel_graphs


@pytest.mark.parametrize(
    'weight, series_type, p',
    [
        (0.1, 'geometric', None),
        (0.1, 'exponential', None),
        # (0.3, 'geometric', None),  # Requires `method_type="baseline"` in GraKel kernel constructor
        (0.3, 'exponential', None),
        (0.3, 'geometric', 3),  # Doesn't pass due to suspected GraKel bug, see https://github.com/ysig/GraKeL/issues/71
        (0.8, 'exponential', 3),  # Same issue as above test
    ]
)
def test_random_walk_unlabelled(weight, series_type, p, load_data):
    """Compare the FlowMO RandomWalk kernel against the GraKel reference."""
    tensor_adj_mats, grakel_graphs = load_data

    reference_kernel = grakel.kernels.RandomWalk(
        normalize=True, lamda=weight, kernel_type=series_type, p=p
    )
    reference_results = reference_kernel.fit_transform(grakel_graphs)

    candidate_kernel = RandomWalk(
        normalize=True, weight=weight, series_type=series_type, p=p
    )
    candidate_results = candidate_kernel.K(tensor_adj_mats, tensor_adj_mats)

    npt.assert_almost_equal(
        reference_results, candidate_results.numpy(),
        decimal=2
    )