Possible PyTorch implementation of WL kernel #153

Status: Open · wants to merge 33 commits into base: master
(The file changes shown below are from 9 of the 33 commits.)

Commits:
36fc3bd  Add a PyTorch implementation of WL kernel (vladislavalerievich, Oct 28, 2024)
b0d3842  Fix imports (vladislavalerievich, Oct 29, 2024)
f87abd6  Remove redundant copy (vladislavalerievich, Oct 29, 2024)
358fbb7  Increase precision for allclose (vladislavalerievich, Oct 29, 2024)
de140b6  Fix calculation for graphs with reordered edges (vladislavalerievich, Oct 29, 2024)
08c7aea  Increase test coverage (vladislavalerievich, Oct 29, 2024)
6f07858  Improve readability of TorchWLKernel (vladislavalerievich, Oct 30, 2024)
896f461  Add additional comments to TorchWLKernel (vladislavalerievich, Oct 30, 2024)
383e924  Add MixedSingleTaskGP to process graphs (vladislavalerievich, Nov 8, 2024)
65666a3  Refactor WLKernelWrapper into a standalone WLKernel class. (vladislavalerievich, Nov 20, 2024)
7fa9432  Update tests (vladislavalerievich, Nov 20, 2024)
4227f22  Add a check for empty inputs (vladislavalerievich, Nov 20, 2024)
f194bd2  Improve and combine tests (vladislavalerievich, Nov 20, 2024)
a104840  Update WLKernel (vladislavalerievich, Nov 21, 2024)
246f9f6  Add acquisition function with graph sampling (vladislavalerievich, Nov 21, 2024)
770c626  Add a custom __call__ method to pass graphs during optimization (vladislavalerievich, Nov 21, 2024)
8bf7ea7  Update MixedSingleTaskGP (vladislavalerievich, Dec 7, 2024)
84d0104  Remove not used argument (vladislavalerievich, Dec 7, 2024)
d63239a  Update sample_graphs (vladislavalerievich, Dec 7, 2024)
3db3f89  Handle different batch dimensions (vladislavalerievich, Dec 7, 2024)
f69ddbe  Set num_restarts=10 (vladislavalerievich, Dec 7, 2024)
1c4cc83  Add acquisition function (vladislavalerievich, Dec 7, 2024)
dab9a8c  Update WLKernel (vladislavalerievich, Dec 7, 2024)
2999582  Make train_inputs private (vladislavalerievich, Dec 7, 2024)
ad55030  Update tests (vladislavalerievich, Dec 7, 2024)
8093d31  fix: Implement graph acquisition (eddiebergman, Dec 16, 2024)
9f978d6  fix: Implement graph acquisition (#164) (vladislavalerievich, Dec 24, 2024)
a1a29a8  Delete unused MixedSingleTaskGP (vladislavalerievich, Dec 24, 2024)
046ad66  Add seed_all and min_max_scale (vladislavalerievich, Dec 24, 2024)
0a609f7  Refactor optimize.py (vladislavalerievich, Dec 24, 2024)
5486dcc  Speed up WL kernel computations (vladislavalerievich, Dec 24, 2024)
f140c56  Process wl iterations in batches (vladislavalerievich, Dec 24, 2024)
371b530  Use CSR (vladislavalerievich, Dec 25, 2024)
50 changes: 50 additions & 0 deletions grakel_replace/grakel_wl_usage_example.py
@@ -0,0 +1,50 @@
from __future__ import annotations

import matplotlib.pyplot as plt
import networkx as nx
from grakel import graph_from_networkx, WeisfeilerLehman


def visualize_graph(G):
"""Visualize the NetworkX graph."""
pos = nx.spring_layout(G)
nx.draw(G, pos, with_labels=True, node_size=700, node_color="lightblue")
plt.show()

def add_labels(G):
"""Add labels to the nodes of the graph."""
for node in G.nodes():
G.nodes[node]['label'] = str(node)

# Create graphs
G1 = nx.Graph()
G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)])
add_labels(G1)

G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
add_labels(G2)

G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])
add_labels(G3)

# Visualize the graphs
visualize_graph(G1)
visualize_graph(G2)
visualize_graph(G3)

# Convert NetworkX graphs to Grakel format using graph_from_networkx
graph_list = list(
graph_from_networkx([G1, G2, G3], node_labels_tag="label", as_Graph=True)
)

# Initialize the Weisfeiler-Lehman kernel
wl_kernel = WeisfeilerLehman(n_iter=5, normalize=False)

# Compute the kernel matrix
K = wl_kernel.fit_transform(graph_list)

# Display the kernel matrix
print("Fit and Transform on Kernel matrix (pairwise similarities):")
print(K)
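
Editor's note: grakel kernels can also score unseen graphs against the fitted set via transform. A minimal sketch extending the example above (G4 is a hypothetical extra graph, not part of the PR's file):

# Score a new graph against the fitted graphs G1..G3
G4 = nx.Graph()
G4.add_edges_from([(0, 1), (1, 2), (2, 0)])
add_labels(G4)

new_graph = list(
    graph_from_networkx([G4], node_labels_tag="label", as_Graph=True)
)
K_new = wl_kernel.transform(new_graph)  # shape (1, 3): similarity of G4 to G1, G2, G3
print("Transform on a new graph:")
print(K_new)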
165 changes: 165 additions & 0 deletions grakel_replace/mixed_single_task_gp.py
@@ -0,0 +1,165 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from botorch.models import SingleTaskGP
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel
from gpytorch.module import Module
from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel

if TYPE_CHECKING:
import networkx as nx
from torch import Tensor


class MixedSingleTaskGP(SingleTaskGP):
"""A Gaussian Process model that handles numerical, categorical, and graph inputs.

This class extends BoTorch's SingleTaskGP to work with hybrid input spaces containing:
- Numerical features
- Categorical features
- Graph structures

It uses the Weisfeiler-Lehman (WL) kernel for graph inputs and combines it with
standard kernels for numerical/categorical features using an additive kernel
structure.

Attributes:
_wl_kernel (TorchWLKernel): The Weisfeiler-Lehman kernel for graph similarity
_train_graphs (list[nx.Graph]): Training set graph instances
_K_train (Tensor): Pre-computed graph kernel matrix for training data
num_cat_kernel (Module | None): Kernel for numerical/categorical features
"""

def __init__(
self,
train_X: Tensor, # Shape: (n_samples, n_numerical_categorical_features)
train_graphs: list[nx.Graph], # List of n_samples graph instances
train_Y: Tensor, # Shape: (n_samples, n_outputs)
train_Yvar: Tensor | None = None, # Shape: (n_samples, n_outputs) or None
num_cat_kernel: Module | None = None,
wl_kernel: TorchWLKernel | None = None,
**kwargs # Additional arguments passed to SingleTaskGP
) -> None:
"""Initialize the mixed input Gaussian Process model.

Args:
train_X: Training data tensor for numerical and categorical features
train_graphs: List of training graphs
train_Y: Target values
train_Yvar: Observation noise variance (optional)
num_cat_kernel: Kernel for numerical/categorical features (optional)
wl_kernel: Custom Weisfeiler-Lehman kernel instance (optional)
**kwargs: Additional arguments for SingleTaskGP initialization
"""
# Initialize parent class with initial covar_module
super().__init__(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
covar_module=num_cat_kernel or self._graph_kernel_wrapper(),
**kwargs
)

# Initialize WL kernel with default parameters if not provided
self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True)
self._train_graphs = train_graphs

# Convert graphs to required format and compute kernel matrix
self._train_graph_dataset = GraphDataset.from_networkx(train_graphs)
self._K_train = self._wl_kernel(self._train_graph_dataset)

if num_cat_kernel is not None:
# Create additive kernel combining numerical/categorical and graph kernels
combined_kernel = AdditiveKernel(
num_cat_kernel,
self._graph_kernel_wrapper()
)
self.covar_module = combined_kernel

self.num_cat_kernel = num_cat_kernel

def _graph_kernel_wrapper(self) -> Module:
"""Creates a GPyTorch-compatible kernel module wrapping the WL kernel.

This wrapper allows the WL kernel to be used within the GPyTorch framework
by providing a forward method that returns the pre-computed kernel matrix.

Returns:
Module: A GPyTorch kernel module wrapping the WL kernel computation
"""

class WLKernelWrapper(Module):
def __init__(self, parent: MixedSingleTaskGP):
super().__init__()
self.parent = parent

def forward(
self,
x1: Tensor,
x2: Tensor | None = None,
diag: bool = False,
last_dim_is_batch: bool = False
) -> Tensor:
"""Compute the kernel matrix for the graph inputs.

Args:
x1: First input tensor (unused; required for interface compatibility)
x2: Second input tensor; when provided, the train/test cross-covariance path is taken
diag: Whether to return only diagonal elements
last_dim_is_batch: Whether the last dimension is a batch dimension

Returns:
Tensor: The pre-computed training kernel matrix when x2 is None, otherwise the cross-covariance between the training graphs and parent._test_graphs (which must be set before the call)
"""
if x2 is None:
return self.parent._K_train

# Compute cross-covariance between train and test graphs
test_dataset = GraphDataset.from_networkx(self.parent._test_graphs)
return self.parent._wl_kernel(
self.parent._train_graph_dataset,
test_dataset
)

return WLKernelWrapper(self)

def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal:
"""Forward pass computing the GP distribution for given inputs.

Computes the kernel matrix for both numerical/categorical features and graphs,
combines them if both are present, and returns the resulting GP distribution.

Args:
X: Input tensor for numerical and categorical features
graphs: List of input graphs

Returns:
MultivariateNormal: GP distribution for the given inputs
"""
if len(X) != len(graphs):
raise ValueError(
f"Number of feature vectors ({len(X)}) must match "
f"number of graphs ({len(graphs)})"
)

# Process new graphs and compute kernel matrix
proc_graphs = GraphDataset.from_networkx(graphs)
K_new = self._wl_kernel(proc_graphs) # Shape: (n_samples, n_samples)

# If we have both numerical/categorical and graph features
if self.num_cat_kernel is not None:
# Compute kernel for numerical/categorical features
K_num_cat = self.num_cat_kernel(X)
# Add the kernels (element-wise addition)
K_combined = K_num_cat + K_new
else:
K_combined = K_new

# Compute mean using the mean module
mean_x = self.mean_module(X)

return MultivariateNormal(mean_x, K_combined)
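
Editor's note on the addition in forward: self.num_cat_kernel(X) returns a gpytorch lazy kernel tensor, while the WL kernel returns a dense torch.Tensor; gpytorch supports adding the two directly. A standalone sketch of the same pattern (illustrative shapes, not code from this PR):

import torch
from gpytorch.kernels import MaternKernel

X = torch.rand(4, 2)
K_lazy = MaternKernel(nu=2.5)(X)  # lazy kernel tensor over features
K_dense = torch.eye(4)            # stand-in for a dense WL kernel matrix
K_combined = K_lazy + K_dense     # lazy + dense addition is supported
print(K_combined.shape)           # torch.Size([4, 4])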
102 changes: 102 additions & 0 deletions grakel_replace/mixed_single_task_gp_usage_example.py
@@ -0,0 +1,102 @@
import networkx as nx
import torch
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, MaternKernel
from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP
from grakel_replace.torch_wl_kernel import TorchWLKernel

TRAIN_CONFIGS = 10
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3
N_GRAPH = 2  # defined for completeness; not used below

kernels = []

# Create numerical and categorical features
X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
if N_NUMERICAL > 0:
X[:, :N_NUMERICAL] = torch.rand(
size=(TOTAL_CONFIGS, N_NUMERICAL),
dtype=torch.float64,
)

if N_CATEGORICAL > 0:
X[:, N_NUMERICAL:] = torch.randint(
0,
N_CATEGORICAL_VALUES_PER_CATEGORY,
size=(TOTAL_CONFIGS, N_CATEGORICAL),
dtype=torch.float64,
)

# Create random graph architectures
graphs = []
for _ in range(TOTAL_CONFIGS):
G = nx.erdos_renyi_graph(n=5, p=0.5) # Random graph with 5 nodes
graphs.append(G)

# Create random target values
y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)

# Setup kernels for numerical and categorical features
if N_NUMERICAL > 0:
matern = ScaleKernel(
MaternKernel(
nu=2.5,
ard_num_dims=N_NUMERICAL,
active_dims=tuple(range(N_NUMERICAL)),
),
)
kernels.append(matern)

if N_CATEGORICAL > 0:
hamming = ScaleKernel(
CategoricalKernel(
ard_num_dims=N_CATEGORICAL,
active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)),
),
)
kernels.append(hamming)

# Combine numerical and categorical kernels
combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None

# Create WL kernel for graphs
wl_kernel = TorchWLKernel(n_iter=5, normalize=True)
Review comment (Contributor):
I'm guessing there's no way to really make it such that we could pass the TorchWLKernel to the AdditiveKernel, i.e. you would use it just like any other kernel type?

Reply (Collaborator, author):
We could define a WLKernel class that extends gpytorch.kernels.Kernel and use that class instead of TorchWLKernel.
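
A minimal sketch of that suggestion (hypothetical, not part of this PR): a gpytorch.kernels.Kernel subclass wrapping TorchWLKernel so it composes with AdditiveKernel like any other kernel. Since gpytorch kernels only ever receive tensors, the graphs are captured at construction time; the class name WLGraphKernel is an assumption:

from gpytorch.kernels import Kernel
from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel

class WLGraphKernel(Kernel):  # hypothetical name
    has_lengthscale = False

    def __init__(self, graphs, n_iter=5, normalize=True, **kwargs):
        super().__init__(**kwargs)
        self.graphs = graphs  # graphs aligned with the training rows
        self._wl = TorchWLKernel(n_iter=n_iter, normalize=normalize)

    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
        # Tensor inputs are ignored; the kernel is computed on the stored graphs.
        K = self._wl(GraphDataset.from_networkx(self.graphs))
        return K.diagonal() if diag else K

With something like this, AdditiveKernel(matern, hamming, WLGraphKernel(train_graphs)) would compose directly, at the cost of the graph kernel ignoring its tensor arguments.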


# Split into train and test sets
train_x = X[:TRAIN_CONFIGS]
train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch

test_x = X[TRAIN_CONFIGS:]
test_graphs = graphs[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)

# Initialize the mixed GP
gp = MixedSingleTaskGP(
train_X=train_x,
train_graphs=train_graphs,
train_Y=train_y,
num_cat_kernel=combined_num_cat_kernel,
wl_kernel=wl_kernel,
)

# Compute the posterior distribution
multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs)
print("Posterior distribution:", multivariate_normal)

# Making predictions on test data
with torch.no_grad():
posterior = gp.forward(test_x, test_graphs)
predictions = posterior.mean
uncertainties = posterior.variance.sqrt()
covar = posterior.covariance_matrix

print("\nMean:", predictions)
print("Variance:", uncertainties)
print("Covariance matrix:", covar)