diff --git a/grakel_replace/grakel_wl_usage_example.py b/grakel_replace/grakel_wl_usage_example.py new file mode 100644 index 00000000..33f9c386 --- /dev/null +++ b/grakel_replace/grakel_wl_usage_example.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import networkx as nx +from grakel import graph_from_networkx, WeisfeilerLehman + + +def visualize_graph(G): + """Visualize the NetworkX graph.""" + pos = nx.spring_layout(G) + nx.draw(G, pos, with_labels=True, node_size=700, node_color="lightblue") + plt.show() + +def add_labels(G): + """Add labels to the nodes of the graph.""" + for node in G.nodes(): + G.nodes[node]['label'] = str(node) + +# Create graphs +G1 = nx.Graph() +G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) +add_labels(G1) + +G2 = nx.Graph() +G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) +add_labels(G2) + +G3 = nx.Graph() +G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) +add_labels(G3) + +# Visualize the graphs +visualize_graph(G1) +visualize_graph(G2) +visualize_graph(G3) + +# Convert NetworkX graphs to Grakel format using graph_from_networkx +graph_list = list( + graph_from_networkx([G1, G2, G3], node_labels_tag="label", as_Graph=True) +) + +# Initialize the Weisfeiler-Lehman kernel +wl_kernel = WeisfeilerLehman(n_iter=5, normalize=False) + +# Compute the kernel matrix +K = wl_kernel.fit_transform(graph_list) + +# Display the kernel matrix +print("Fit and Transform on Kernel matrix (pairwise similarities):") +print(K) diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py new file mode 100644 index 00000000..335bc1b4 --- /dev/null +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import time +from collections.abc import Iterator +from contextlib import contextmanager +from itertools import product +from typing import TYPE_CHECKING + +import networkx as nx +import torch +from botorch import fit_gpytorch_mll, settings +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.models import SingleTaskGP +from botorch.models.gp_regression_mixed import CategoricalKernel, Kernel, ScaleKernel +from gpytorch import ExactMarginalLogLikelihood +from gpytorch.kernels import AdditiveKernel, MaternKernel +from grakel_replace.optimize import optimize_acqf_graph +from grakel_replace.torch_wl_kernel import TorchWLKernel +from grakel_replace.utils import min_max_scale, seed_all + +if TYPE_CHECKING: + from gpytorch.distributions.multivariate_normal import MultivariateNormal + +start_time = time.time() +settings.debug._set_state(True) +seed_all() + +TRAIN_CONFIGS = 50 +TEST_CONFIGS = 10 +TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + +N_NUMERICAL = 2 +N_CATEGORICAL = 1 +N_CATEGORICAL_VALUES_PER_CATEGORY = 2 +N_GRAPH = 1 + +assert N_GRAPH == 1, "This example only supports a single graph feature" + +# Generate random data +X = torch.cat([ + torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), + torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, (TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64), + torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1) +], dim=1) + +# Generate random graphs +graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] + +# Generate random target values +y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 + +# Split into train and test sets +train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:] +train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:] +train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1) + +train_x, test_x = min_max_scale(train_x), min_max_scale(test_x) + +kernels = [ + ScaleKernel( + MaternKernel(nu=2.5, ard_num_dims=N_NUMERICAL, active_dims=range(N_NUMERICAL))), + ScaleKernel(CategoricalKernel( + ard_num_dims=N_CATEGORICAL, + active_dims=range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL))), + ScaleKernel(TorchWLKernel( + graph_lookup=train_graphs, n_iter=5, normalize=True, + active_dims=(X.shape[1] - 1,))) +] + +gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels)) + +# Compute the posterior distribution +# The wl_kernel will use the indices to index into the training graphs it is holding +# on to... +multivariate_normal: MultivariateNormal = gp.forward(train_x) + + +# Making predictions on test data +# No the wl_kernel needs to be aware of the test graphs +@contextmanager +def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]: + kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + for kern in _gp.covar_module.sub_kernels(): + if isinstance(kern, TorchWLKernel): + kernel_prev_graphs.append((kern, kern.graph_lookup)) + kern.set_graph_lookup(new_graphs) + + yield + + for _kern, _prev_graphs in kernel_prev_graphs: + _kern.set_graph_lookup(_prev_graphs) + + +with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs): + posterior = gp.forward(test_x) + predictions = posterior.mean + uncertainties = posterior.variance.sqrt() + covar = posterior.covariance_matrix + +mll = ExactMarginalLogLikelihood(gp.likelihood, gp) +fit_gpytorch_mll(mll) + +acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=train_x, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, +) + +bounds = torch.tensor([ + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, + [1.0] * N_NUMERICAL + [ + float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [ + len(X) - 1] * N_GRAPH, +]) + +cats_per_column = {i: list(range(N_CATEGORICAL_VALUES_PER_CATEGORY)) for i in + range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)} +fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in + product(*cats_per_column.values())] + +best_candidate, best_score = optimize_acqf_graph( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + train_graphs=train_graphs, + num_graph_samples=2, + num_restarts=2, + raw_samples=16, + q=1, +) + +print(f"Best candidate: {best_candidate}") +print(f"Best score: {best_score}") +print(f"Elapsed time: {time.time() - start_time} seconds") diff --git a/grakel_replace/optimize.py b/grakel_replace/optimize.py new file mode 100644 index 00000000..b54b671c --- /dev/null +++ b/grakel_replace/optimize.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import random +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TYPE_CHECKING + +import networkx as nx +import torch +from botorch.optim import optimize_acqf_mixed +from grakel_replace.torch_wl_kernel import TorchWLKernel + +if TYPE_CHECKING: + from botorch.acquisition import AcquisitionFunction + from botorch.models.gp_regression_mixed import Kernel + + +@contextmanager +def set_graph_lookup( + kernel: Kernel, + new_graphs: list[nx.Graph], + *, + append: bool = True, +) -> Iterator[None]: + """Context manager to temporarily set the graph lookup for a kernel. + + Args: + kernel (Kernel): The kernel whose graph lookup is to be set. + new_graphs (list[nx.Graph]): The new graphs to set in the graph lookup. + append (bool, optional): Whether to append the new graphs to the existing graph + lookup. Defaults to True. + """ + kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + + # Determine the modules to update based on the kernel type + if isinstance(kernel, TorchWLKernel): + modules = [kernel] + else: + assert hasattr( + kernel, "sub_kernels" + ), "Kernel module must have sub_kernels method." + modules = [k for k in kernel.sub_kernels() if isinstance(k, TorchWLKernel)] + + # Save the current graph lookup and set the new graph lookup + for kern in modules: + kernel_prev_graphs.append((kern, kern.graph_lookup)) + if append: + kern.set_graph_lookup([*kern.graph_lookup, *new_graphs]) + else: + kern.set_graph_lookup(new_graphs) + + yield + + # Restore the original graph lookup after the context manager exits + for kern, prev_graphs in kernel_prev_graphs: + kern.set_graph_lookup(prev_graphs) + + +def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: + """Sample graphs using random walks or edge modifications. + + Args: + graphs (list[nx.Graph]): Existing training graphs. + num_samples (int): Number of graph samples to generate. + + Returns: + list[nx.Graph]: Sampled graphs. + """ + sampled_graphs = [] + for _ in range(num_samples): + base_graph = random.choice(graphs) + sampled_graph = base_graph.copy() + + # More aggressive modifications + num_modifications = random.randint(2, 5) # Increase minimum modifications + for _ in range(num_modifications): + if random.random() > 0.3: # 70% chance to add edge + nodes = list(sampled_graph.nodes) + if len(nodes) >= 2: + u, v = random.sample(nodes, 2) + if not sampled_graph.has_edge(u, v): + sampled_graph.add_edge(u, v) + elif sampled_graph.edges: # 30% chance to remove edge + u, v = random.choice(list(sampled_graph.edges)) + sampled_graph.remove_edge(u, v) + + # Ensure the graph stays connected + if not nx.is_connected(sampled_graph): + components = list(nx.connected_components(sampled_graph)) + for i in range(len(components) - 1): + u = random.choice(list(components[i])) + v = random.choice(list(components[i + 1])) + sampled_graph.add_edge(u, v) + + sampled_graphs.append(sampled_graph) + + return sampled_graphs + + +def optimize_acqf_graph( + acq_function: AcquisitionFunction, + bounds: torch.Tensor, + fixed_features_list: list[dict[int, float]] | None = None, + num_graph_samples: int = 10, + train_graphs: list[nx.Graph] | None = None, + num_restarts: int = 10, + raw_samples: int = 1024, + q: int = 1, +) -> tuple[torch.Tensor, float]: + """Optimize acquisition function with graph sampling. + + Args: + acq_function: Acquisition function to optimize + bounds: Bounds for numerical/categorical features + fixed_features_list: Fixed categorical feature configurations + num_graph_samples: Number of graphs to sample + train_graphs: Original training graphs + num_restarts: Number of optimization restarts + raw_samples: Number of raw samples to generate + q: Number of candidates to generate + + Returns: + tuple: Best candidate and acquisition score. + """ + if train_graphs is None: + raise ValueError("train_graphs cannot be None.") + + sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples) + best_candidates, best_scores = [], [] + graph_idx = bounds.shape[1] - 1 + # Iterate through all the kernels and include the sampled graph. + for _graph in sampled_graphs: + with set_graph_lookup(acq_function.model.covar_module, [_graph], append=True): + for fixed_features in fixed_features_list or [{}]: + # We then consider this graph as a fixed feature, i.e. in the X's + # generated during acquisition, the graph column will just be full + # of `-1` indicating to select the very last graph in the lookup + # they used. + candidates, scores = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=[{**fixed_features, graph_idx: -1.0}], + num_restarts=num_restarts, + raw_samples=raw_samples, q=q) + best_candidates.append(candidates) + best_scores.append(scores) + best_idx = torch.argmax(torch.tensor(best_scores)) + return best_candidates[best_idx], best_scores[best_idx].item() diff --git a/grakel_replace/single_task_gp_usage_example.py b/grakel_replace/single_task_gp_usage_example.py new file mode 100644 index 00000000..9e295852 --- /dev/null +++ b/grakel_replace/single_task_gp_usage_example.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from itertools import product +from typing import TYPE_CHECKING + +import torch +from botorch import fit_gpytorch_mll +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.models import SingleTaskGP +from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from botorch.optim import optimize_acqf_mixed +from gpytorch import ExactMarginalLogLikelihood +from gpytorch.kernels import AdditiveKernel, MaternKernel + +if TYPE_CHECKING: + from gpytorch.distributions.multivariate_normal import MultivariateNormal + +TRAIN_CONFIGS = 10 +TEST_CONFIGS = 10 +TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + +N_NUMERICAL = 2 +N_CATEGORICAL = 2 +N_CATEGORICAL_VALUES_PER_CATEGORY = 3 + +kernels = [] + +# Create some random encoded hyperparameter configurations +X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64) +if N_NUMERICAL > 0: + X[:, :N_NUMERICAL] = torch.rand( + size=(TOTAL_CONFIGS, N_NUMERICAL), + dtype=torch.float64, + ) + +if N_CATEGORICAL > 0: + X[:, N_NUMERICAL:] = torch.randint( + 0, + N_CATEGORICAL_VALUES_PER_CATEGORY, + size=(TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64, + ) + +y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + +if N_NUMERICAL > 0: + matern = ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=N_NUMERICAL, + active_dims=tuple(range(N_NUMERICAL)), + ), + ) + kernels.append(matern) + +if N_CATEGORICAL > 0: + hamming = ScaleKernel( + CategoricalKernel( + ard_num_dims=N_CATEGORICAL, + active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)), + ), + ) + kernels.append(hamming) + +combined_num_cat_kernel = AdditiveKernel(*kernels) + +train_x = X[:TRAIN_CONFIGS] +train_y = y[:TRAIN_CONFIGS] + +test_x = X[TRAIN_CONFIGS:] +test_y = y[TRAIN_CONFIGS:] + +K_matrix = combined_num_cat_kernel.forward(train_x, train_x) + +train_y = train_y.unsqueeze(-1) +test_y = test_y.unsqueeze(-1) + +gp = SingleTaskGP( + train_X=train_x, + train_Y=train_y, + covar_module=combined_num_cat_kernel, +) + +multivariate_normal: MultivariateNormal = gp.forward(train_x) + +# =============== Fitting the GP using botorch =============== + +print("\nFitting the GP model using botorch...") + +mll = ExactMarginalLogLikelihood(gp.likelihood, gp) +fit_gpytorch_mll(mll) + +acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=train_x, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, +) + +# Define bounds +bounds = torch.tensor( + [ + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL, + [1.0] * N_NUMERICAL + [ + float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + ] +) + +# Setup categorical feature optimization +cats_per_column: dict[int, list[float]] = { + column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)] + for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL) +} + +# Generate fixed categorical features +fixed_cats: list[dict[int, float]] +if len(cats_per_column) == 1: + col, choice_indices = next(iter(cats_per_column.items())) + fixed_cats = [{col: i} for i in choice_indices] +else: + fixed_cats = [ + dict(zip(cats_per_column.keys(), combo)) + for combo in product(*cats_per_column.values()) + ] + +best_candidate, best_score = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + num_restarts=10, + raw_samples=10, + q=1, +) + +print("Best candidate:", best_candidate) +print("Acquisition score:", best_score) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py new file mode 100644 index 00000000..e4f6fd2b --- /dev/null +++ b/grakel_replace/torch_wl_kernel.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +from typing import Any + +import networkx as nx +import torch +from botorch.models.gp_regression_mixed import Kernel +from torch import Tensor +from torch.nn import Module + + +class TorchWLKernel(Kernel): + has_lengthscale = False + + def __init__( + self, + graph_lookup: list[nx.Graph], + n_iter: int = 5, + *, + normalize: bool = True, + active_dims: tuple[int, ...], + **kwargs: Any, + ) -> None: + super().__init__(active_dims=active_dims, **kwargs) + self.graph_lookup = graph_lookup + self.n_iter = n_iter + self.normalize = normalize + + # Cache adjacency matrices and initial node labels + self.adjacency_cache = {} + self.label_cache = {} + + self._precompute_graph_data() + + def _precompute_graph_data(self) -> None: + """Precompute adjacency matrices and initial node labels for all graphs.""" + self.adjacency_cache = {} + self.label_cache = {} + + for idx, graph in enumerate(self.graph_lookup): + self.adjacency_cache[idx] = self._get_sparse_adj(graph) + self.label_cache[idx] = self._init_node_labels(graph) + + def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None: + """Update the graph lookup and refresh the cached data.""" + self.graph_lookup = graph_lookup + self._precompute_graph_data() + + def forward( + self, + x1: Tensor, + x2: Tensor, + *, + diag: bool = False, + last_dim_is_batch: bool = False, + **params: Any, + ) -> Tensor: + if last_dim_is_batch: + raise NotImplementedError("TODO: Figure this out") + + assert x1.shape[-1] == 1, "Last dimension must be the graph index" + assert x2.shape[-1] == 1, "Last dimension must be the graph index" + + x1_is_x2 = torch.equal(x1, x2) + + if x1.ndim == 3: + q_dim_size = x1.shape[0] + assert x2.shape[0] == q_dim_size + + out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device) + for q in range(q_dim_size): + out[q] = self.forward(x1[q], x2[q], diag=diag) + return out + + if x1_is_x2: + indices = x1.flatten().to(torch.int64).tolist() + all_graphs = indices + select = None + else: + indices1 = x1.flatten().to(torch.int64).tolist() + indices2 = x2.flatten().to(torch.int64).tolist() + all_graphs = indices1 + indices2 + select = lambda K: K[:len(indices1), len(indices1):] + + # Handle the special case for -1 + all_graphs = [ + len(self.graph_lookup) - 1 if i == -1 else i for i in all_graphs + ] + + # Use cached adjacency matrices and labels + adj_matrices = [self.adjacency_cache[i] for i in all_graphs] + label_tensors = [self.label_cache[i] for i in all_graphs] + + _kernel = _TorchWLKernel(n_iter=self.n_iter, normalize=self.normalize) + K = _kernel(adj_matrices, label_tensors) + K_selected = K if select is None else select(K) + if diag: + return torch.diag(K_selected) + return K_selected + + def _get_sparse_adj(self, graph: nx.Graph) -> Tensor: + """Convert a NetworkX graph to a sparse adjacency tensor.""" + edges = list(graph.edges()) + num_nodes = graph.number_of_nodes() + + if not edges: + return torch.sparse_coo_tensor( + indices=torch.empty((2, 0), dtype=torch.long), + values=torch.empty(0), + size=(num_nodes, num_nodes), + device=self.device, + ) + + edge_indices: list[tuple[int, int]] = edges + [(v, u) for u, v in edges] + rows, cols = zip(*edge_indices, strict=False) + + indices = torch.tensor([rows, cols], dtype=torch.long) + values = torch.ones(len(edge_indices), dtype=torch.float) + + return torch.sparse_coo_tensor( + indices, values, (num_nodes, num_nodes), device=self.device + ).to_sparse_csr() # Convert to CSR for efficient operations + + def _init_node_labels(self, graph: nx.Graph) -> Tensor: + """Initialize node label tensor from graph attributes.""" + labels: list[int] = [] + label_dict: dict[str, int] = {} + label_counter = 0 + + for node in range(graph.number_of_nodes()): + if "label" in graph.nodes[node]: + label = graph.nodes[node]["label"] + else: + label = str(node) + if label not in label_dict: + label_dict[label] = label_counter + label_counter += 1 + labels.append(label_dict[label]) + + return torch.tensor(labels, dtype=torch.long, device=self.device) + + +class _TorchWLKernel(Module): + """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch. + + The WL Kernel is a graph kernel that measures similarity between graphs based on + their structural properties. It works by iteratively updating node labels based on + their neighborhoods and computing feature vectors from label distributions. + + Args: + n_iter: Number of WL iterations to perform + normalize: bool, optional. Whether to normalize the kernel matrix + + Attributes: + device: torch.device for computation (CPU/GPU) + label_dict: Mapping from node labels to numerical indices + label_counter: Counter for generating new label indices + """ + + def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None: + super().__init__() + self.n_iter = n_iter + self.normalize = normalize + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.label_dict: dict[tuple, int] = {} + self.label_counter: int = 0 + self.hash_module = torch.nn.Linear(2, 1, bias=False) + torch.nn.init.normal_(self.hash_module.weight) + + def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor: + """Perform one iteration of the WL algorithm to update node labels.""" + # Ensure the adjacency matrix is in COO format before coalescing + if adj.layout == torch.sparse_csr: + adj = adj.to_sparse_coo() + + adj = adj.coalesce() + indices = adj.indices() + rows, cols = indices + num_nodes = labels.size(0) + + # Create a mask for each node's neighbors + neighbor_mask = torch.zeros((num_nodes, num_nodes), dtype=torch.bool, + device=self.device) + neighbor_mask[rows, cols] = True + + # Get neighbor labels for each node + neighbor_labels = labels.unsqueeze(0).expand(num_nodes, -1) + neighbor_labels = neighbor_labels.masked_fill(~neighbor_mask, -1) + + # Sort neighbor labels for each node + sorted_neighbor_labels, _ = torch.sort(neighbor_labels, dim=1, descending=True) + + # Remove padding (-1 values) from sorted labels + valid_neighbors_mask = sorted_neighbor_labels != -1 + max_neighbors = valid_neighbors_mask.sum(1).max().item() + sorted_neighbor_labels = sorted_neighbor_labels[:, :max_neighbors] + + # Combine node labels with neighbor labels + node_labels_expanded = labels.unsqueeze(1).expand(-1, max_neighbors) + + # Create feature vectors + features = torch.cat([ + node_labels_expanded.unsqueeze(-1).float(), + sorted_neighbor_labels.unsqueeze(-1).float() + ], dim=-1) + + # Hash the combined features + hashed_features = self.hash_module(features).squeeze(-1) + hashed_labels = hashed_features.sum(dim=1) + + # Convert to discrete labels + _, new_labels = torch.unique(hashed_labels, sorted=True, return_inverse=True) + return new_labels + + def _compute_feature_vector(self, all_labels: list[list[Tensor]]) -> Tensor: + """Compute feature vectors for all graphs in a batch.""" + max_label = 0 + for iteration_labels in all_labels: + for labels in iteration_labels: + max_label = max(max_label, labels.max().item()) + + batch_size = len(all_labels[0]) + features = torch.zeros((batch_size, max_label + 1), + dtype=torch.float32, device=self.device) + + # Accumulate label counts for each graph + for graph_idx in range(batch_size): + # Sum contributions from all WL iterations + for iteration_labels in all_labels: + graph_labels = iteration_labels[graph_idx] + label_counts = torch.bincount( + graph_labels, + minlength=max_label + 1 + ).float() + features[graph_idx] += label_counts + + return features + + def forward( + self, + adj_matrices: list[Tensor], + label_tensors: list[Tensor], + ) -> Tensor: + """Compute WL kernel matrix for a list of graphs. + + Args: + adj_matrices: Precomputed sparse adjacency matrices for graphs. + label_tensors: Precomputed node label tensors for graphs. + + Returns: + Kernel matrix containing pairwise graph similarities. + """ + if len(adj_matrices) != len(label_tensors): + raise ValueError("Mismatch between adjacency matrices and label tensors.") + + # Perform WL iterations to update the node labels + all_labels = [label_tensors] + for _ in range(self.n_iter): + new_labels = [ + self._wl_iteration(adj, labels) + for adj, labels in zip(adj_matrices, all_labels[-1], strict=False) + ] + all_labels.append(new_labels) + + # Compute feature vectors for each graph in the batch + final_features = self._compute_feature_vector(all_labels) + + # Compute kernel matrix (similarity matrix) + kernel_matrix = torch.mm(final_features, final_features.t()) + + # Apply normalization if requested + if self.normalize: + diag = torch.sqrt(torch.diag(kernel_matrix)) + kernel_matrix /= (diag.unsqueeze(0) * diag.unsqueeze(1)) + + return kernel_matrix + + +class GraphDataset: + """Utility class to convert NetworkX graphs for WL kernel.""" + + @staticmethod + def from_networkx( + graphs: list[nx.Graph], node_labels_tag: str = "label" + ) -> list[nx.Graph]: + if not all(isinstance(g, nx.Graph) for g in graphs): + raise TypeError("Expected input type is a list of NetworkX graphs.") + + """Convert NetworkX graphs ensuring proper node labeling.""" + processed_graphs = [] + for g in graphs: + g = g.copy() + # Add default labels if not present + for node in g.nodes(): + if node_labels_tag not in g.nodes[node]: + g.nodes[node][node_labels_tag] = str(node) + processed_graphs.append(g) + return processed_graphs diff --git a/grakel_replace/torch_wl_usage_example.py b/grakel_replace/torch_wl_usage_example.py new file mode 100644 index 00000000..f9958045 --- /dev/null +++ b/grakel_replace/torch_wl_usage_example.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import networkx as nx +import torch +from torch_wl_kernel import GraphDataset, TorchWLKernel + +# Create the same graphs as for the Grakel example +G1 = nx.Graph() +G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) +G2 = nx.Graph() +G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) +G3 = nx.Graph() +G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) + +# Process graphs +graphs: list[nx.Graph] = GraphDataset.from_networkx([G1, G2, G3]) + +# Initialize and run WL kernel +wl_kernel = TorchWLKernel( + training_graph_list=graphs, + n_iter=2, + normalize=True, + active_dims=(1,), +) +X1 = torch.tensor([[42.4, 43.4, 44.5], [0, 1, 2]]).T +X2 = torch.tensor([[42.4, 43.4, 44.5], [0, 1, 2]]).T + +K = wl_kernel(X1, X2) +print(K.to_dense()) # noqa: T201 diff --git a/grakel_replace/utils.py b/grakel_replace/utils.py new file mode 100644 index 00000000..245550eb --- /dev/null +++ b/grakel_replace/utils.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import random + +import numpy as np +import torch + + +def seed_all(seed: int = 100): + """Seed all random generators for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # Ensure reproducibility with CuDNN (may reduce performance) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def min_max_scale(tensor: torch.Tensor) -> torch.Tensor: + """Scale the input tensor to the range [0, 1].""" + min_vals = tensor.min(dim=0, keepdim=True).values + max_vals = tensor.max(dim=0, keepdim=True).values + return (tensor - min_vals) / (max_vals - min_vals) diff --git a/tests/test_torch_wl_kernel.py b/tests/test_torch_wl_kernel.py new file mode 100644 index 00000000..49d7b1ed --- /dev/null +++ b/tests/test_torch_wl_kernel.py @@ -0,0 +1,182 @@ +import networkx as nx +import pytest +import torch +from grakel import WeisfeilerLehman, graph_from_networkx +from grakel_replace.torch_wl_kernel import TorchWLKernel, GraphDataset + + +class TestTorchWLKernel: + @pytest.fixture + def example_graphs(self): + # Create example graphs for testing + G1 = nx.Graph() + G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) + for node in G1.nodes(): + G1.nodes[node]["label"] = str(node) + + G2 = nx.Graph() + G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) + for node in G2.nodes(): + G2.nodes[node]["label"] = str(node) + + G3 = nx.Graph() + G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) + for node in G3.nodes(): + G3.nodes[node]["label"] = str(node) + + return [G1, G2, G3] + + @pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7]) + @pytest.mark.parametrize("normalize", [True, False]) + def test_wl_kernel_against_grakel(self, n_iter, normalize, example_graphs): + """Test the custom WL kernel against Grakel's implementation.""" + graphs = GraphDataset.from_networkx(example_graphs) + + # Initialize and compute kernel matrix using custom WLKernel + wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + torch_kernel_matrix = wl_kernel(graphs).cpu().detach().numpy() + + # Convert to Grakel-compatible format + grakel_graphs = graph_from_networkx(example_graphs, node_labels_tag="label") + + # Initialize and compute kernel matrix using Grakel's Weisfeiler-Lehman kernel + grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) + grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs) + + # Assert that the kernel matrices are similar within a reasonable tolerance + assert torch.allclose( + torch.tensor(torch_kernel_matrix, dtype=torch.float64), + torch.tensor(grakel_kernel_matrix, dtype=torch.float64), + atol=1e-100 + ), (f"Mismatch found in kernel matrices with n_iter={n_iter} and " + f"normalize={normalize}") + + def test_kernel_symmetry(self, example_graphs): + """Test if the kernel matrix is symmetric.""" + graphs = GraphDataset.from_networkx([example_graphs[0], example_graphs[0]]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if the kernel matrix is symmetric + assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric" + + def test_empty_graph(self): + """Test the kernel computation for an empty graph.""" + # Test with an empty graph + G_empty = nx.Graph() + graphs = GraphDataset.from_networkx([G_empty]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + # Check if kernel returns a valid matrix (1x1 zero matrix expected) + K = wl_kernel(graphs) + assert K.shape == (1, 1), "Kernel matrix shape for empty graph is incorrect" + assert K.item() == 0.0, "Kernel matrix value for empty graph should be zero" + + def test_invalid_input(self): + """Test that invalid inputs raise the appropriate TypeError.""" + # Test with invalid input types + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + with pytest.raises(TypeError, match="Expected input type is a list of NetworkX graphs"): + wl_kernel("invalid_input") # Passing a string instead of a list of graphs + + with pytest.raises(TypeError, match="Expected input type is a list of NetworkX graphs"): + wl_kernel([1, 2, 3]) # Passing a list of integers instead of graphs + + def test_kernel_on_single_node_graph(self, example_graphs): + """Test the kernel computation for single-node graphs.""" + # Test with a single-node graph + G_single = nx.Graph() + G_single.add_node(0) + G_single.nodes[0]["label"] = "0" + + graphs = GraphDataset.from_networkx([G_single, G_single]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if a kernel matrix for identical single-node graphs is valid and symmetric + assert K.shape == (2, 2), "Kernel matrix shape for single-node graphs is incorrect" + assert K[0, 0] == K[1, 1], "Self-similarity for single-node graph should be the same" + assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric for single-node graph" + + def test_wl_kernel_with_empty_graph_and_reordered_edges(self, example_graphs): + """Test the TorchWLKernel with an empty graph and a graph with reordered edges.""" + # Create example graphs for testing + G_empty = nx.Graph() + G = example_graphs[0] + G_reordered = nx.Graph() + G_reordered.add_edges_from([(1, 4), (2, 3), (1, 2), (0, 1), (1, 3)]) + for node in G_reordered.nodes(): + G_reordered.nodes[node]["label"] = str(node) + + graphs = GraphDataset.from_networkx([G_empty, G, G_reordered]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if the kernel matrix is valid and the values + # are the same for the original and reordered graphs + assert K.shape == (3, 3), "Kernel matrix shape is incorrect" + assert K[1, 1] == K[2, 2], "Kernel value for original and reordered graphs should be the same" + + @pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7]) + @pytest.mark.parametrize("normalize", [True, False]) + def test_wl_kernel_with_different_node_labels(self, n_iter, normalize, example_graphs): + """Test the TorchWLKernel with graphs having different node labels.""" + # Create example graphs with different node labels + G1 = example_graphs[0] + for node in G1.nodes(): + G1.nodes[node]["label"] = f"node_{node}" + + G2 = example_graphs[1] + for node in G2.nodes(): + G2.nodes[node]["label"] = f"vertex_{node}" + + G3 = example_graphs[2] + for node in G3.nodes(): + G3.nodes[node]["label"] = f"n{node}" + + graphs = GraphDataset.from_networkx([G1, G2, G3]) + + # Initialize and compute kernel matrix using custom TorchWLKernel + wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + torch_kernel_matrix = wl_kernel(graphs).cpu().detach().numpy() + + # Convert to Grakel-compatible format + grakel_graphs = graph_from_networkx([G1, G2, G3], node_labels_tag="label") + + # Initialize and compute kernel matrix using Grakel's Weisfeiler-Lehman kernel + grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) + grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs) + + # Assert that the kernel matrices are similar within a reasonable tolerance + assert torch.allclose( + torch.tensor(torch_kernel_matrix, dtype=torch.float64), + torch.tensor(grakel_kernel_matrix, dtype=torch.float64), + atol=1e-100 + ), (f"Mismatch found in kernel matrices with n_iter={n_iter} and " + f"normalize={normalize} for graphs with different node labels") + + def test_wl_kernel_with_same_node_labels(self, example_graphs): + """Test the TorchWLKernel with graphs having the same node labels.""" + # Create example graphs with the same node labels + G1 = example_graphs[0] + for node in G1.nodes(): + G1.nodes[node]["label"] = "A" + + G2 = example_graphs[1] + for node in G2.nodes(): + G2.nodes[node]["label"] = "A" + + G3 = example_graphs[2] + for node in G3.nodes(): + G3.nodes[node]["label"] = "A" + + graphs = GraphDataset.from_networkx([G1, G2, G3]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if the kernel matrix is valid and the values are the same for the graphs with the same node labels + assert K.shape == (3, 3), "Kernel matrix shape is incorrect" + assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric" + assert torch.all(K == K[0, 0]), ("Kernel values should be the same for " + "graphs with the same node labels")