From 36fc3bdc284a85eb951b7c712e2191b4e7253ee5 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Mon, 28 Oct 2024 17:15:46 +0100 Subject: [PATCH 01/32] Add a PyTorch implementation of WL kernel --- grakel_replace/grakel_wl_usage_example.py | 52 +++++++ grakel_replace/torch_wl_kernel.py | 172 ++++++++++++++++++++++ grakel_replace/torch_wl_usage_example.py | 21 +++ tests/test_torch_wl_kernel.py | 111 ++++++++++++++ 4 files changed, 356 insertions(+) create mode 100644 grakel_replace/grakel_wl_usage_example.py create mode 100644 grakel_replace/torch_wl_kernel.py create mode 100644 grakel_replace/torch_wl_usage_example.py create mode 100644 tests/test_torch_wl_kernel.py diff --git a/grakel_replace/grakel_wl_usage_example.py b/grakel_replace/grakel_wl_usage_example.py new file mode 100644 index 00000000..f4db568a --- /dev/null +++ b/grakel_replace/grakel_wl_usage_example.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import matplotlib.pyplot as plt +import networkx as nx + +from weisfeiler_lehman import WeisfeilerLehman +from utils import graph_from_networkx + + +def visualize_graph(G): + """Visualize the NetworkX graph.""" + pos = nx.spring_layout(G) + nx.draw(G, pos, with_labels=True, node_size=700, node_color="lightblue") + plt.show() + +def add_labels(G): + """Add labels to the nodes of the graph.""" + for node in G.nodes(): + G.nodes[node]['label'] = str(node) + +# Create graphs +G1 = nx.Graph() +G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) +add_labels(G1) + +G2 = nx.Graph() +G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) +add_labels(G2) + +G3 = nx.Graph() +G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) +add_labels(G3) + +# Visualize the graphs +visualize_graph(G1) +visualize_graph(G2) +visualize_graph(G3) + +# Convert NetworkX graphs to Grakel format using graph_from_networkx +graph_list = list( + graph_from_networkx([G1, G2, G3], node_labels_tag="label", as_Graph=True) +) + +# Initialize the Weisfeiler-Lehman kernel +wl_kernel = WeisfeilerLehman(n_iter=5, normalize=False) + +# Compute the kernel matrix +K = wl_kernel.fit_transform(graph_list) + +# Display the kernel matrix +print("Fit and Transform on Kernel matrix (pairwise similarities):") +print(K) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py new file mode 100644 index 00000000..15f47203 --- /dev/null +++ b/grakel_replace/torch_wl_kernel.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from collections import Counter + +import networkx as nx +import torch +from torch import nn + + +class TorchWLKernel(nn.Module): + """Custom PyTorch implementation of Weisfeiler-Lehman Kernel. 
+ + Args: + n_iter: Number of WL iterations + normalize: Whether to normalize the kernel matrix + """ + + def __init__(self, n_iter: int = 5, normalize: bool = True): + super().__init__() + self.n_iter = n_iter + self.normalize = normalize + self.label_dict = {} + self.label_counter = 0 + + def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor: + """Convert NetworkX graph to sparse adjacency tensor.""" + edges = list(graph.edges()) + if not edges: + num_nodes = graph.number_of_nodes() + return torch.sparse_coo_tensor( + indices=torch.empty((2, 0), dtype=torch.long), + values=torch.empty(0), + size=(num_nodes, num_nodes), + device=self.device + ) + + # Create COO format indices + row = torch.tensor([e[0] for e in edges], dtype=torch.long) + col = torch.tensor([e[1] for e in edges], dtype=torch.long) + edges = torch.stack([ + torch.cat([row, col]), # Add both directions for undirected graph + torch.cat([col, row]) + ]) + + values = torch.ones(edges.size(1), dtype=torch.float) + N = graph.number_of_nodes() + + return torch.sparse_coo_tensor( + edges, values, (N, N), + device=self.device + ) + + def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: + """Initialize node label tensor from graph.""" + # Get node labels and convert to indices + labels = [] + for node in range(graph.number_of_nodes()): + label = graph.nodes[node].get("label", str(node)) + if label not in self.label_dict: + self.label_dict[label] = self.label_counter + self.label_counter += 1 + labels.append(self.label_dict[label]) + + return torch.tensor(labels, dtype=torch.long, device=self.device) + + def _wl_iteration(self, adj: torch.sparse.Tensor, + labels: torch.Tensor) -> torch.Tensor: + """Perform one WL iteration.""" + # Concatenate own label with sorted neighbor labels + new_labels = [] + for node in range(adj.size(0)): + node_label = labels[node].item() + neighbors = adj.coalesce().indices()[1][adj.coalesce().indices()[0] == node] + neighbor_label_list = sorted([labels[n].item() for n in neighbors]) + combined = f"{node_label}_{neighbor_label_list}" + + if combined not in self.label_dict: + self.label_dict[combined] = self.label_counter + self.label_counter += 1 + new_labels.append(self.label_dict[combined]) + + return torch.tensor(new_labels, dtype=torch.long, device=self.device) + + def _compute_feature_vector(self, labels: torch.Tensor, size: int) -> torch.Tensor: + """Compute histogram feature vector from labels with fixed size.""" + counts = Counter(labels.cpu().numpy()) + feature = torch.zeros(size, device=self.device) + for label, count in counts.items(): + if label < size: # Safety check + feature[label] = count + return feature + + def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: + """Compute WL kernel matrix for a list of graphs. 
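+
+        Example (illustrative sketch; the two small graphs are assumed inputs,
+        relying on the default node labels derived from node indices):
+            >>> import networkx as nx
+            >>> g1, g2 = nx.path_graph(3), nx.cycle_graph(3)
+            >>> kernel = TorchWLKernel(n_iter=2, normalize=False)
+            >>> kernel([g1, g2]).shape
+            torch.Size([2, 2])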
+ + Args: + graphs: List of NetworkX graphs + + Returns: + Kernel matrix as a torch.Tensor + """ + # Validate input + if (not isinstance(graphs, list) or + not all(isinstance(g, nx.Graph) for g in graphs)): + raise TypeError("Expected input type is a list of NetworkX graphs.") + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.label_dict = {} + self.label_counter = 0 + + # Handle case of empty graphs list or empty individual graphs + if not graphs or all(g.number_of_nodes() == 0 for g in graphs): + return torch.zeros((len(graphs), len(graphs)), device=self.device) + + # Convert graphs to sparse adjacency matrices and initialize labels + adj_matrices = [self._get_sparse_adj(g) for g in graphs] + label_tensors = [self._init_node_labels(g) for g in graphs] + + # Pre-allocate feature matrices list + feature_matrices = [] + + # First, run all iterations to compute maximum label count + all_label_tensors = [label_tensors] + for _ in range(self.n_iter): + new_label_tensors = [ + self._wl_iteration(adj, labels) + for adj, labels in zip(adj_matrices, all_label_tensors[-1]) + ] + all_label_tensors.append(new_label_tensors) + + max_label_count = self.label_counter + + # Now compute feature vectors for all iterations with fixed size + for labels_list in all_label_tensors: + features = torch.stack([ + self._compute_feature_vector(labels, max_label_count) + for labels in labels_list + ]) + feature_matrices.append(features) + + # Sum up feature matrices from all iterations + final_features = torch.stack(feature_matrices).sum(dim=0) + + # Compute kernel matrix + K = torch.mm(final_features, final_features.t()) + + # Normalize if requested + if self.normalize: + diag = torch.sqrt(torch.diag(K)) + K = K / (diag.unsqueeze(0) * diag.unsqueeze(1)) + + return K + + +class GraphDataset: + """Utility class to convert NetworkX graphs for WL kernel.""" + + @staticmethod + def from_networkx(graphs: list[nx.Graph], node_labels_tag: str = "label") -> list[ + nx.Graph]: + """Convert NetworkX graphs ensuring proper node labeling.""" + processed_graphs = [] + for g in graphs: + g = g.copy() + # Ensure nodes are numbered from 0 to n-1 + g = nx.convert_node_labels_to_integers(g) + # Add default labels if not present + for node in g.nodes(): + if node_labels_tag not in g.nodes[node]: + g.nodes[node][node_labels_tag] = str(node) + processed_graphs.append(g) + return processed_graphs diff --git a/grakel_replace/torch_wl_usage_example.py b/grakel_replace/torch_wl_usage_example.py new file mode 100644 index 00000000..a24d2536 --- /dev/null +++ b/grakel_replace/torch_wl_usage_example.py @@ -0,0 +1,21 @@ +import networkx as nx +from torch_wl_kernel import GraphDataset, TorchWLKernel + +# Create the same graphs as for the Grakel example +G1 = nx.Graph() +G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) +G2 = nx.Graph() +G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) +G3 = nx.Graph() +G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) + +# Process graphs +graphs = GraphDataset.from_networkx([G1, G2, G3]) + +# Initialize and run WL kernel +wl_kernel = TorchWLKernel(n_iter=2, normalize=False) + +K = wl_kernel(graphs) + +print("Kernel matrix (pairwise similarities):") +print(K) diff --git a/tests/test_torch_wl_kernel.py b/tests/test_torch_wl_kernel.py new file mode 100644 index 00000000..1e53e929 --- /dev/null +++ b/tests/test_torch_wl_kernel.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import networkx as nx +import pytest +import torch +from grakel import WeisfeilerLehman, 
graph_from_networkx +from grakel_replace.torch_wl_kernel import TorchWLKernel, GraphDataset + + +@pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7]) +@pytest.mark.parametrize("normalize", [True, False]) +def test_wl_kernel_against_grakel(n_iter, normalize): + """Test the custom WL kernel against Grakel's implementation. + + Args: + n_iter: Number of iterations for the WL kernel. + normalize: Whether to normalize the kernel matrix. + """ + # Create example graphs for testing + G1 = nx.Graph() + G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) + for node in G1.nodes(): + G1.nodes[node]["label"] = str(node) + + G2 = nx.Graph() + G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) + for node in G2.nodes(): + G2.nodes[node]["label"] = str(node) + + G3 = nx.Graph() + G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) + for node in G3.nodes(): + G3.nodes[node]["label"] = str(node) + + graphs = GraphDataset.from_networkx([G1, G2, G3]) + + # Initialize and compute kernel matrix using custom WLKernel + wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + torch_kernel_matrix = wl_kernel(graphs).cpu().detach().numpy() + + # Convert to Grakel-compatible format + grakel_graphs = graph_from_networkx([G1, G2, G3], node_labels_tag="label") + + # Initialize and compute kernel matrix using Grakel's Weisfeiler-Lehman kernel + grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) + grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs) + + # Assert that the kernel matrices are similar within a reasonable tolerance + assert torch.allclose( + torch.tensor(torch_kernel_matrix, dtype=torch.float64), + torch.tensor(grakel_kernel_matrix, dtype=torch.float64), + atol=1e-5 + ), f"Mismatch found in kernel matrices with n_iter={n_iter} and normalize={normalize}" + + +def test_kernel_symmetry(): + """Test if the kernel matrix is symmetric.""" + # Create example graphs for testing + G = nx.Graph() + G.add_edges_from([(0, 1), (1, 2), (2, 3)]) + for node in G.nodes(): + G.nodes[node]["label"] = str(node) + + graphs = GraphDataset.from_networkx([G, G]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if the kernel matrix is symmetric + assert torch.allclose(K, K.T, atol=1e-5), "Kernel matrix is not symmetric" + + +def test_empty_graph(): + """Test the kernel computation for an empty graph.""" + # Test with an empty graph + G_empty = nx.Graph() + graphs = GraphDataset.from_networkx([G_empty]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + # Check if kernel returns a valid matrix (1x1 zero matrix expected) + K = wl_kernel(graphs) + assert K.shape == (1, 1), "Kernel matrix shape for empty graph is incorrect" + assert K.item() == 0.0, "Kernel matrix value for empty graph should be zero" + + +def test_invalid_input(): + """Test that invalid inputs raise the appropriate TypeError.""" + # Test with invalid input types + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + with pytest.raises(TypeError, match="Expected input type is a list of NetworkX graphs"): + wl_kernel("invalid_input") # Passing a string instead of a list of graphs + + with pytest.raises(TypeError, match="Expected input type is a list of NetworkX graphs"): + wl_kernel([1, 2, 3]) # Passing a list of integers instead of graphs + + +def test_kernel_on_single_node_graph(): + """Test the kernel computation for single-node graphs.""" + # Test with a single-node graph + G_single = nx.Graph() + G_single.add_node(0) + G_single.nodes[0]["label"] = "0" + + graphs = 
GraphDataset.from_networkx([G_single, G_single]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if a kernel matrix for identical single-node graphs is valid and symmetric + assert K.shape == (2, 2), "Kernel matrix shape for single-node graphs is incorrect" + assert K[0, 0] == K[1, 1], "Self-similarity for single-node graph should be the same" + assert torch.allclose(K, K.T, atol=1e-5), "Kernel matrix is not symmetric for single-node graph" From b0d38426e1d3a79970f7162f270baa304f649396 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 29 Oct 2024 15:02:59 +0100 Subject: [PATCH 02/32] Fix imports --- grakel_replace/grakel_wl_usage_example.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/grakel_replace/grakel_wl_usage_example.py b/grakel_replace/grakel_wl_usage_example.py index f4db568a..33f9c386 100644 --- a/grakel_replace/grakel_wl_usage_example.py +++ b/grakel_replace/grakel_wl_usage_example.py @@ -2,9 +2,7 @@ import matplotlib.pyplot as plt import networkx as nx - -from weisfeiler_lehman import WeisfeilerLehman -from utils import graph_from_networkx +from grakel import graph_from_networkx, WeisfeilerLehman def visualize_graph(G): From f87abd6aef46fe0d76221b502e24295755ce22d4 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 29 Oct 2024 15:12:02 +0100 Subject: [PATCH 03/32] Remove redundant copy --- grakel_replace/torch_wl_kernel.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index 15f47203..8fa781c9 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -161,7 +161,6 @@ def from_networkx(graphs: list[nx.Graph], node_labels_tag: str = "label") -> lis """Convert NetworkX graphs ensuring proper node labeling.""" processed_graphs = [] for g in graphs: - g = g.copy() # Ensure nodes are numbered from 0 to n-1 g = nx.convert_node_labels_to_integers(g) # Add default labels if not present From 358fbb7c0f805b6ddc9c588947579bcc16584e63 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 29 Oct 2024 15:16:42 +0100 Subject: [PATCH 04/32] Increase precision for allclose --- tests/test_torch_wl_kernel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_torch_wl_kernel.py b/tests/test_torch_wl_kernel.py index 1e53e929..d039de5b 100644 --- a/tests/test_torch_wl_kernel.py +++ b/tests/test_torch_wl_kernel.py @@ -49,7 +49,7 @@ def test_wl_kernel_against_grakel(n_iter, normalize): assert torch.allclose( torch.tensor(torch_kernel_matrix, dtype=torch.float64), torch.tensor(grakel_kernel_matrix, dtype=torch.float64), - atol=1e-5 + atol=1e-100 ), f"Mismatch found in kernel matrices with n_iter={n_iter} and normalize={normalize}" @@ -66,7 +66,7 @@ def test_kernel_symmetry(): K = wl_kernel(graphs) # Check if the kernel matrix is symmetric - assert torch.allclose(K, K.T, atol=1e-5), "Kernel matrix is not symmetric" + assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric" def test_empty_graph(): @@ -108,4 +108,4 @@ def test_kernel_on_single_node_graph(): # Check if a kernel matrix for identical single-node graphs is valid and symmetric assert K.shape == (2, 2), "Kernel matrix shape for single-node graphs is incorrect" assert K[0, 0] == K[1, 1], "Self-similarity for single-node graph should be the same" - assert torch.allclose(K, K.T, atol=1e-5), "Kernel matrix is not symmetric for single-node graph" + assert torch.allclose(K, K.T, atol=1e-100), "Kernel 
matrix is not symmetric for single-node graph" From de140b6642931dd3af5dae99fd1300cd1c9d0e2d Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 29 Oct 2024 16:26:11 +0100 Subject: [PATCH 05/32] Fix calculation for graphs with reordered edges --- grakel_replace/torch_wl_kernel.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index 8fa781c9..dab04a1f 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -55,7 +55,10 @@ def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: # Get node labels and convert to indices labels = [] for node in range(graph.number_of_nodes()): - label = graph.nodes[node].get("label", str(node)) + if "label" in graph.nodes[node]: + label = graph.nodes[node]["label"] + else: + label = str(node) if label not in self.label_dict: self.label_dict[label] = self.label_counter self.label_counter += 1 @@ -72,17 +75,27 @@ def _wl_iteration(self, adj: torch.sparse.Tensor, node_label = labels[node].item() neighbors = adj.coalesce().indices()[1][adj.coalesce().indices()[0] == node] neighbor_label_list = sorted([labels[n].item() for n in neighbors]) - combined = f"{node_label}_{neighbor_label_list}" - if combined not in self.label_dict: - self.label_dict[combined] = self.label_counter - self.label_counter += 1 - new_labels.append(self.label_dict[combined]) + # Check if all neighbors have the same label as the current node + if all(labels[n] == labels[node] for n in neighbors): + new_labels.append(node_label) + else: + combined = f"{node_label}_{neighbor_label_list}" + if combined not in self.label_dict: + self.label_dict[combined] = self.label_counter + self.label_counter += 1 + new_labels.append(self.label_dict[combined]) return torch.tensor(new_labels, dtype=torch.long, device=self.device) def _compute_feature_vector(self, labels: torch.Tensor, size: int) -> torch.Tensor: """Compute histogram feature vector from labels with fixed size.""" + # Handle the case where all node labels are the same + if len(set(labels.cpu().numpy())) == 1: + feature = torch.zeros(size, device=self.device) + feature[labels[0].item()] = len(labels) + return feature + counts = Counter(labels.cpu().numpy()) feature = torch.zeros(size, device=self.device) for label, count in counts.items(): @@ -124,7 +137,7 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: for _ in range(self.n_iter): new_label_tensors = [ self._wl_iteration(adj, labels) - for adj, labels in zip(adj_matrices, all_label_tensors[-1]) + for adj, labels in zip(adj_matrices, all_label_tensors[-1], strict=False) ] all_label_tensors.append(new_label_tensors) @@ -161,8 +174,7 @@ def from_networkx(graphs: list[nx.Graph], node_labels_tag: str = "label") -> lis """Convert NetworkX graphs ensuring proper node labeling.""" processed_graphs = [] for g in graphs: - # Ensure nodes are numbered from 0 to n-1 - g = nx.convert_node_labels_to_integers(g) + g = g.copy() # Add default labels if not present for node in g.nodes(): if node_labels_tag not in g.nodes[node]: From 08c7aea91b1265ed77537297df3030e14ff783a6 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 29 Oct 2024 16:29:52 +0100 Subject: [PATCH 06/32] Increase test coverage --- tests/test_torch_wl_kernel.py | 276 +++++++++++++++++++++------------- 1 file changed, 172 insertions(+), 104 deletions(-) diff --git a/tests/test_torch_wl_kernel.py b/tests/test_torch_wl_kernel.py index d039de5b..4b89f4d9 100644 
--- a/tests/test_torch_wl_kernel.py +++ b/tests/test_torch_wl_kernel.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import networkx as nx import pytest import torch @@ -7,105 +5,175 @@ from grakel_replace.torch_wl_kernel import TorchWLKernel, GraphDataset -@pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7]) -@pytest.mark.parametrize("normalize", [True, False]) -def test_wl_kernel_against_grakel(n_iter, normalize): - """Test the custom WL kernel against Grakel's implementation. - - Args: - n_iter: Number of iterations for the WL kernel. - normalize: Whether to normalize the kernel matrix. - """ - # Create example graphs for testing - G1 = nx.Graph() - G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) - for node in G1.nodes(): - G1.nodes[node]["label"] = str(node) - - G2 = nx.Graph() - G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) - for node in G2.nodes(): - G2.nodes[node]["label"] = str(node) - - G3 = nx.Graph() - G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) - for node in G3.nodes(): - G3.nodes[node]["label"] = str(node) - - graphs = GraphDataset.from_networkx([G1, G2, G3]) - - # Initialize and compute kernel matrix using custom WLKernel - wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) - torch_kernel_matrix = wl_kernel(graphs).cpu().detach().numpy() - - # Convert to Grakel-compatible format - grakel_graphs = graph_from_networkx([G1, G2, G3], node_labels_tag="label") - - # Initialize and compute kernel matrix using Grakel's Weisfeiler-Lehman kernel - grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) - grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs) - - # Assert that the kernel matrices are similar within a reasonable tolerance - assert torch.allclose( - torch.tensor(torch_kernel_matrix, dtype=torch.float64), - torch.tensor(grakel_kernel_matrix, dtype=torch.float64), - atol=1e-100 - ), f"Mismatch found in kernel matrices with n_iter={n_iter} and normalize={normalize}" - - -def test_kernel_symmetry(): - """Test if the kernel matrix is symmetric.""" - # Create example graphs for testing - G = nx.Graph() - G.add_edges_from([(0, 1), (1, 2), (2, 3)]) - for node in G.nodes(): - G.nodes[node]["label"] = str(node) - - graphs = GraphDataset.from_networkx([G, G]) - wl_kernel = TorchWLKernel(n_iter=3, normalize=True) - K = wl_kernel(graphs) - - # Check if the kernel matrix is symmetric - assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric" - - -def test_empty_graph(): - """Test the kernel computation for an empty graph.""" - # Test with an empty graph - G_empty = nx.Graph() - graphs = GraphDataset.from_networkx([G_empty]) - wl_kernel = TorchWLKernel(n_iter=3, normalize=True) - - # Check if kernel returns a valid matrix (1x1 zero matrix expected) - K = wl_kernel(graphs) - assert K.shape == (1, 1), "Kernel matrix shape for empty graph is incorrect" - assert K.item() == 0.0, "Kernel matrix value for empty graph should be zero" - - -def test_invalid_input(): - """Test that invalid inputs raise the appropriate TypeError.""" - # Test with invalid input types - wl_kernel = TorchWLKernel(n_iter=3, normalize=True) - - with pytest.raises(TypeError, match="Expected input type is a list of NetworkX graphs"): - wl_kernel("invalid_input") # Passing a string instead of a list of graphs - - with pytest.raises(TypeError, match="Expected input type is a list of NetworkX graphs"): - wl_kernel([1, 2, 3]) # Passing a list of integers instead of graphs - - -def test_kernel_on_single_node_graph(): - """Test the kernel 
computation for single-node graphs.""" - # Test with a single-node graph - G_single = nx.Graph() - G_single.add_node(0) - G_single.nodes[0]["label"] = "0" - - graphs = GraphDataset.from_networkx([G_single, G_single]) - wl_kernel = TorchWLKernel(n_iter=3, normalize=True) - K = wl_kernel(graphs) - - # Check if a kernel matrix for identical single-node graphs is valid and symmetric - assert K.shape == (2, 2), "Kernel matrix shape for single-node graphs is incorrect" - assert K[0, 0] == K[1, 1], "Self-similarity for single-node graph should be the same" - assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric for single-node graph" +class TestTorchWLKernel: + @pytest.fixture + def example_graphs(self): + # Create example graphs for testing + G1 = nx.Graph() + G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]) + for node in G1.nodes(): + G1.nodes[node]["label"] = str(node) + + G2 = nx.Graph() + G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)]) + for node in G2.nodes(): + G2.nodes[node]["label"] = str(node) + + G3 = nx.Graph() + G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) + for node in G3.nodes(): + G3.nodes[node]["label"] = str(node) + + return [G1, G2, G3] + + @pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7]) + @pytest.mark.parametrize("normalize", [True, False]) + def test_wl_kernel_against_grakel(self, n_iter, normalize, example_graphs): + """Test the custom WL kernel against Grakel's implementation.""" + graphs = GraphDataset.from_networkx(example_graphs) + + # Initialize and compute kernel matrix using custom WLKernel + wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + torch_kernel_matrix = wl_kernel(graphs).cpu().detach().numpy() + + # Convert to Grakel-compatible format + grakel_graphs = graph_from_networkx(example_graphs, node_labels_tag="label") + + # Initialize and compute kernel matrix using Grakel's Weisfeiler-Lehman kernel + grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) + grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs) + + # Assert that the kernel matrices are similar within a reasonable tolerance + assert torch.allclose( + torch.tensor(torch_kernel_matrix, dtype=torch.float64), + torch.tensor(grakel_kernel_matrix, dtype=torch.float64), + atol=1e-100 + ), f"Mismatch found in kernel matrices with n_iter={n_iter} and normalize={normalize}" + + def test_kernel_symmetry(self, example_graphs): + """Test if the kernel matrix is symmetric.""" + graphs = GraphDataset.from_networkx([example_graphs[0], example_graphs[0]]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if the kernel matrix is symmetric + assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric" + + def test_empty_graph(self): + """Test the kernel computation for an empty graph.""" + # Test with an empty graph + G_empty = nx.Graph() + graphs = GraphDataset.from_networkx([G_empty]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + # Check if kernel returns a valid matrix (1x1 zero matrix expected) + K = wl_kernel(graphs) + assert K.shape == (1, 1), "Kernel matrix shape for empty graph is incorrect" + assert K.item() == 0.0, "Kernel matrix value for empty graph should be zero" + + def test_invalid_input(self): + """Test that invalid inputs raise the appropriate TypeError.""" + # Test with invalid input types + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + with pytest.raises(TypeError, match="Expected input type is a list of NetworkX graphs"): + 
wl_kernel("invalid_input") # Passing a string instead of a list of graphs + + with pytest.raises(TypeError, match="Expected input type is a list of NetworkX graphs"): + wl_kernel([1, 2, 3]) # Passing a list of integers instead of graphs + + def test_kernel_on_single_node_graph(self, example_graphs): + """Test the kernel computation for single-node graphs.""" + # Test with a single-node graph + G_single = nx.Graph() + G_single.add_node(0) + G_single.nodes[0]["label"] = "0" + + graphs = GraphDataset.from_networkx([G_single, G_single]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if a kernel matrix for identical single-node graphs is valid and symmetric + assert K.shape == (2, 2), "Kernel matrix shape for single-node graphs is incorrect" + assert K[0, 0] == K[1, 1], "Self-similarity for single-node graph should be the same" + assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric for single-node graph" + + def test_wl_kernel_with_empty_graph_and_reordered_edges(self, example_graphs): + """Test the TorchWLKernel with an empty graph and a graph with reordered edges.""" + # Create example graphs for testing + G_empty = nx.Graph() + G = example_graphs[0] + G_reordered = nx.Graph() + G_reordered.add_edges_from([(1, 4), (2, 3), (1, 2), (0, 1), (1, 3)]) + for node in G_reordered.nodes(): + G_reordered.nodes[node]["label"] = str(node) + + graphs = GraphDataset.from_networkx([G_empty, G, G_reordered]) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + K = wl_kernel(graphs) + + # Check if the kernel matrix is valid and the values + # are the same for the original and reordered graphs + assert K.shape == (3, 3), "Kernel matrix shape is incorrect" + assert K[1, 1] == K[2, 2], "Kernel value for original and reordered graphs should be the same" + + @pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7]) + @pytest.mark.parametrize("normalize", [True, False]) + def test_wl_kernel_with_different_node_labels(self, n_iter, normalize, example_graphs): + """Test the TorchWLKernel with graphs having different node labels.""" + # Create example graphs with different node labels + G1 = example_graphs[0] + for node in G1.nodes(): + G1.nodes[node]["label"] = f"node_{node}" + + G2 = example_graphs[1] + for node in G2.nodes(): + G2.nodes[node]["label"] = f"vertex_{node}" + + G3 = example_graphs[2] + for node in G3.nodes(): + G3.nodes[node]["label"] = f"n{node}" + + graphs = GraphDataset.from_networkx([G1, G2, G3]) + + # Initialize and compute kernel matrix using custom TorchWLKernel + wl_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize) + torch_kernel_matrix = wl_kernel(graphs).cpu().detach().numpy() + + # Convert to Grakel-compatible format + grakel_graphs = graph_from_networkx([G1, G2, G3], node_labels_tag="label") + + # Initialize and compute kernel matrix using Grakel's Weisfeiler-Lehman kernel + grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize) + grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs) + + # Assert that the kernel matrices are similar within a reasonable tolerance + assert torch.allclose( + torch.tensor(torch_kernel_matrix, dtype=torch.float64), + torch.tensor(grakel_kernel_matrix, dtype=torch.float64), + atol=1e-100 + ), f"Mismatch found in kernel matrices with n_iter={n_iter} and normalize={normalize} for graphs with different node labels" + + def test_wl_kernel_with_same_node_labels(self, example_graphs): + """Test the TorchWLKernel with graphs having the same node labels.""" + # Create example 
graphs with the same node labels
+        G1 = example_graphs[0]
+        for node in G1.nodes():
+            G1.nodes[node]["label"] = "A"
+
+        G2 = example_graphs[1]
+        for node in G2.nodes():
+            G2.nodes[node]["label"] = "A"
+
+        G3 = example_graphs[2]
+        for node in G3.nodes():
+            G3.nodes[node]["label"] = "A"
+
+        graphs = GraphDataset.from_networkx([G1, G2, G3])
+        wl_kernel = TorchWLKernel(n_iter=3, normalize=True)
+        K = wl_kernel(graphs)
+
+        # Check if the kernel matrix is valid and the values are the same for the graphs with the same node labels
+        assert K.shape == (3, 3), "Kernel matrix shape is incorrect"
+        assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric"
+        assert torch.all(K == K[0, 0]), "Kernel values should be the same for graphs with the same node labels"

From 6f078581332230988e0b048794264d951eb42010 Mon Sep 17 00:00:00 2001
From: vladislavalerievich
Date: Wed, 30 Oct 2024 09:52:25 +0100
Subject: [PATCH 07/32] Improve readability of TorchWLKernel

---
 grakel_replace/torch_wl_kernel.py | 173 +++++++++++++++++++-----------
 1 file changed, 112 insertions(+), 61 deletions(-)

diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py
index dab04a1f..d4957eb8 100644
--- a/grakel_replace/torch_wl_kernel.py
+++ b/grakel_replace/torch_wl_kernel.py
@@ -8,25 +8,43 @@
 
 
 class TorchWLKernel(nn.Module):
-    """Custom PyTorch implementation of Weisfeiler-Lehman Kernel.
+    """A custom implementation of the Weisfeiler-Lehman (WL) kernel in PyTorch.
+
+    The WL kernel is a graph kernel that measures similarity between graphs based on
+    their structural properties. It works by iteratively updating node labels based on
+    their neighborhoods and computing feature vectors from label distributions.
 
     Args:
-        n_iter: Number of WL iterations
-        normalize: Whether to normalize the kernel matrix
+        n_iter: Number of WL iterations to perform
+        normalize: Whether to normalize the kernel matrix (default: True)
+
+    Attributes:
+        device: torch.device for computation (CPU/GPU)
+        label_dict: Mapping from node labels to numerical indices
+        label_counter: Counter for generating new label indices
     """
 
-    def __init__(self, n_iter: int = 5, normalize: bool = True):
+    def __init__(self, n_iter: int = 5, normalize: bool = True) -> None:
         super().__init__()
         self.n_iter = n_iter
         self.normalize = normalize
-        self.label_dict = {}
-        self.label_counter = 0
+        self.device: torch.device = torch.device("cpu")
+        self.label_dict: dict[str, int] = {}
+        self.label_counter: int = 0
 
     def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor:
-        """Convert NetworkX graph to sparse adjacency tensor."""
+        """Convert a NetworkX graph to a sparse adjacency tensor.
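+
+        Example (illustrative; a single undirected edge is stored as two
+        symmetric COO entries):
+            >>> adj = TorchWLKernel()._get_sparse_adj(nx.Graph([(0, 1)]))
+            >>> adj.to_dense()
+            tensor([[0., 1.],
+                    [1., 0.]])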
+ + Args: + graph: Input NetworkX graph + + Returns: + Sparse tensor representation of the graph's adjacency matrix + """ edges = list(graph.edges()) + num_nodes = graph.number_of_nodes() + if not edges: - num_nodes = graph.number_of_nodes() return torch.sparse_coo_tensor( indices=torch.empty((2, 0), dtype=torch.long), values=torch.empty(0), @@ -34,26 +52,29 @@ def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor: device=self.device ) - # Create COO format indices - row = torch.tensor([e[0] for e in edges], dtype=torch.long) - col = torch.tensor([e[1] for e in edges], dtype=torch.long) - edges = torch.stack([ - torch.cat([row, col]), # Add both directions for undirected graph - torch.cat([col, row]) - ]) + # Create bidirectional edge indices for undirected graph + edge_indices: list[tuple[int, int]] = edges + [(v, u) for u, v in edges] + rows, cols = zip(*edge_indices, strict=False) - values = torch.ones(edges.size(1), dtype=torch.float) - N = graph.number_of_nodes() + indices = torch.tensor([rows, cols], dtype=torch.long, device=self.device) + values = torch.ones(len(edge_indices), dtype=torch.float, device=self.device) return torch.sparse_coo_tensor( - edges, values, (N, N), + indices, values, (num_nodes, num_nodes), device=self.device ) def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: - """Initialize node label tensor from graph.""" - # Get node labels and convert to indices - labels = [] + """Initialize node label tensor from graph attributes. + + Args: + graph: Input NetworkX graph + + Returns: + Tensor of numerical node label indices + """ + labels: list[int] = [] + for node in range(graph.number_of_nodes()): if "label" in graph.nodes[node]: label = graph.nodes[node]["label"] @@ -66,62 +87,96 @@ def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: return torch.tensor(labels, dtype=torch.long, device=self.device) - def _wl_iteration(self, adj: torch.sparse.Tensor, - labels: torch.Tensor) -> torch.Tensor: - """Perform one WL iteration.""" - # Concatenate own label with sorted neighbor labels - new_labels = [] + def _wl_iteration( + self, + adj: torch.sparse.Tensor, + labels: torch.Tensor + ) -> torch.Tensor: + """Perform one WL iteration to update node labels. + Concatenate own label with sorted neighbor labels. 
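+        For example (illustrative), a three-node path with labels [0, 1, 0]
+        maps the middle node to the index stored for the string "1_[0, 0]"
+        and both end nodes to the index stored for "0_[1]".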
+ + Args: + adj: Sparse adjacency matrix + labels: Current node label tensor + + Returns: + Updated node label tensor + """ + new_labels: list[int] = [] + indices = adj.coalesce().indices() + for node in range(adj.size(0)): node_label = labels[node].item() - neighbors = adj.coalesce().indices()[1][adj.coalesce().indices()[0] == node] - neighbor_label_list = sorted([labels[n].item() for n in neighbors]) + # Get indices of neighbors for current node + neighbors = indices[1][indices[0] == node] + neighbor_labels = sorted([labels[n].item() for n in neighbors]) # Check if all neighbors have the same label as the current node if all(labels[n] == labels[node] for n in neighbors): new_labels.append(node_label) else: - combined = f"{node_label}_{neighbor_label_list}" - if combined not in self.label_dict: - self.label_dict[combined] = self.label_counter + # Create new label from node and neighbor information + combined_label = f"{node_label}_{neighbor_labels}" + if combined_label not in self.label_dict: + self.label_dict[combined_label] = self.label_counter self.label_counter += 1 - new_labels.append(self.label_dict[combined]) + new_labels.append(self.label_dict[combined_label]) return torch.tensor(new_labels, dtype=torch.long, device=self.device) - def _compute_feature_vector(self, labels: torch.Tensor, size: int) -> torch.Tensor: - """Compute histogram feature vector from labels with fixed size.""" + def _compute_feature_vector( + self, + labels: torch.Tensor, + size: int + ) -> torch.Tensor: + """Compute histogram feature vector from node labels. + + Args: + labels: Node label tensor + size: Size of the feature vector + + Returns: + Feature vector representing label distribution + """ # Handle the case where all node labels are the same - if len(set(labels.cpu().numpy())) == 1: + unique_labels = set(labels.cpu().numpy()) + if len(unique_labels) == 1: feature = torch.zeros(size, device=self.device) feature[labels[0].item()] = len(labels) return feature - counts = Counter(labels.cpu().numpy()) + label_counts = Counter(labels.cpu().numpy()) feature = torch.zeros(size, device=self.device) - for label, count in counts.items(): + + for label, count in label_counts.items(): if label < size: # Safety check feature[label] = count + return feature def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: """Compute WL kernel matrix for a list of graphs. 
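+
+        When ``normalize=True``, the matrix is rescaled to cosine form,
+        ``K[i, j] / sqrt(K[i, i] * K[j, j])``, so self-similarities equal 1
+        for graphs with a nonzero feature norm.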
Args: - graphs: List of NetworkX graphs + graphs: List of NetworkX graphs to compare Returns: - Kernel matrix as a torch.Tensor + Kernel matrix containing pairwise graph similarities + + Raises: + TypeError: If input is not a list of NetworkX graphs """ # Validate input if (not isinstance(graphs, list) or not all(isinstance(g, nx.Graph) for g in graphs)): raise TypeError("Expected input type is a list of NetworkX graphs.") + # Setup computation self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.label_dict = {} self.label_counter = 0 - # Handle case of empty graphs list or empty individual graphs + # Handle a case of empty graphs list or empty individual graphs if not graphs or all(g.number_of_nodes() == 0 for g in graphs): return torch.zeros((len(graphs), len(graphs)), device=self.device) @@ -129,40 +184,36 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: adj_matrices = [self._get_sparse_adj(g) for g in graphs] label_tensors = [self._init_node_labels(g) for g in graphs] - # Pre-allocate feature matrices list - feature_matrices = [] - - # First, run all iterations to compute maximum label count - all_label_tensors = [label_tensors] + # Collect label tensors from all iterations + all_label_tensors: list[list[torch.Tensor]] = [label_tensors] for _ in range(self.n_iter): - new_label_tensors = [ + new_labels = [ self._wl_iteration(adj, labels) for adj, labels in zip(adj_matrices, all_label_tensors[-1], strict=False) ] - all_label_tensors.append(new_label_tensors) - - max_label_count = self.label_counter + all_label_tensors.append(new_labels) - # Now compute feature vectors for all iterations with fixed size - for labels_list in all_label_tensors: - features = torch.stack([ - self._compute_feature_vector(labels, max_label_count) - for labels in labels_list + # Compute feature matrices using final label count + feature_matrices = [ + torch.stack([ + self._compute_feature_vector(labels, self.label_counter) + for labels in iteration_labels ]) - feature_matrices.append(features) + for iteration_labels in all_label_tensors + ] - # Sum up feature matrices from all iterations + # Combine features from all iterations final_features = torch.stack(feature_matrices).sum(dim=0) # Compute kernel matrix - K = torch.mm(final_features, final_features.t()) + kernel_matrix = torch.mm(final_features, final_features.t()) - # Normalize if requested + # Apply normalization if requested if self.normalize: - diag = torch.sqrt(torch.diag(K)) - K = K / (diag.unsqueeze(0) * diag.unsqueeze(1)) + diag = torch.sqrt(torch.diag(kernel_matrix)) + kernel_matrix = kernel_matrix / (diag.unsqueeze(0) * diag.unsqueeze(1)) - return K + return kernel_matrix class GraphDataset: From 896f461b64ab9c827f2298e29390bdc316d05327 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Wed, 30 Oct 2024 10:10:41 +0100 Subject: [PATCH 08/32] Add additional comments to TorchWLKernel --- grakel_replace/torch_wl_kernel.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index d4957eb8..929365a1 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -106,8 +106,9 @@ def _wl_iteration( indices = adj.coalesce().indices() for node in range(adj.size(0)): + # Step 1. Get current node's label node_label = labels[node].item() - # Get indices of neighbors for current node + # Step 2. 
Get neighbor labels for current node neighbors = indices[1][indices[0] == node] neighbor_labels = sorted([labels[n].item() for n in neighbors]) @@ -115,8 +116,9 @@ def _wl_iteration( if all(labels[n] == labels[node] for n in neighbors): new_labels.append(node_label) else: - # Create new label from node and neighbor information + # Step 3. Create a new label combining node and neighbor information combined_label = f"{node_label}_{neighbor_labels}" + # Step 4. Assign a numerical index to this new label if combined_label not in self.label_dict: self.label_dict[combined_label] = self.label_counter self.label_counter += 1 @@ -145,11 +147,14 @@ def _compute_feature_vector( feature[labels[0].item()] = len(labels) return feature + # Count the frequency of each label label_counts = Counter(labels.cpu().numpy()) + # In the feature vector, each position represents a label feature = torch.zeros(size, device=self.device) for label, count in label_counts.items(): if label < size: # Safety check + # The value represents how many times that label appears in the graph feature[label] = count return feature @@ -182,6 +187,7 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: # Convert graphs to sparse adjacency matrices and initialize labels adj_matrices = [self._get_sparse_adj(g) for g in graphs] + # Initialize node labels - either use provided labels or default to node indices label_tensors = [self._init_node_labels(g) for g in graphs] # Collect label tensors from all iterations @@ -205,7 +211,7 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: # Combine features from all iterations final_features = torch.stack(feature_matrices).sum(dim=0) - # Compute kernel matrix + # Compute kernel matrix (similarity matrix) kernel_matrix = torch.mm(final_features, final_features.t()) # Apply normalization if requested From 383e92431f0b9d45ec816bea9147db131be55005 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Fri, 8 Nov 2024 09:13:11 +0100 Subject: [PATCH 09/32] Add MixedSingleTaskGP to process graphs --- grakel_replace/mixed_single_task_gp.py | 165 ++++++++++++ .../mixed_single_task_gp_usage_example.py | 102 ++++++++ .../single_task_gp_usage_example.py | 81 ++++++ tests/test_mixed_single_task_gp.py | 236 ++++++++++++++++++ 4 files changed, 584 insertions(+) create mode 100644 grakel_replace/mixed_single_task_gp.py create mode 100644 grakel_replace/mixed_single_task_gp_usage_example.py create mode 100644 grakel_replace/single_task_gp_usage_example.py create mode 100644 tests/test_mixed_single_task_gp.py diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py new file mode 100644 index 00000000..91b8b9ba --- /dev/null +++ b/grakel_replace/mixed_single_task_gp.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from botorch.models import SingleTaskGP +from gpytorch.distributions.multivariate_normal import MultivariateNormal +from gpytorch.kernels import AdditiveKernel +from gpytorch.module import Module +from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel + +if TYPE_CHECKING: + import networkx as nx + from torch import Tensor + + +class MixedSingleTaskGP(SingleTaskGP): + """A Gaussian Process model that handles numerical, categorical, and graph inputs. 
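+
+    Example (illustrative sketch; the random graphs, feature shapes, and
+    targets below are assumptions for demonstration):
+        >>> import networkx as nx, torch
+        >>> graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(4)]
+        >>> X = torch.rand(4, 2, dtype=torch.float64)
+        >>> y = torch.rand(4, 1, dtype=torch.float64)
+        >>> gp = MixedSingleTaskGP(train_X=X, train_graphs=graphs, train_Y=y)
+        >>> posterior = gp.forward(X, graphs)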
+ + This class extends BoTorch's SingleTaskGP to work with hybrid input spaces containing: + - Numerical features + - Categorical features + - Graph structures + + It uses the Weisfeiler-Lehman (WL) kernel for graph inputs and combines it with + standard kernels for numerical/categorical features using an additive kernel structure + + Attributes: + _wl_kernel (TorchWLKernel): The Weisfeiler-Lehman kernel for graph similarity + _train_graphs (List[nx.Graph]): Training set graph instances + _K_graph (Tensor): Pre-computed graph kernel matrix for training data + num_cat_kernel (Optional[Module]): Kernel for numerical/categorical features + """ + + def __init__( + self, + train_X: Tensor, # Shape: (n_samples, n_numerical_categorical_features) + train_graphs: list[nx.Graph], # List of n_samples graph instances + train_Y: Tensor, # Shape: (n_samples, n_outputs) + train_Yvar: Tensor | None = None, # Shape: (n_samples, n_outputs) or None + num_cat_kernel: Module | None = None, + wl_kernel: TorchWLKernel | None = None, + **kwargs # Additional arguments passed to SingleTaskGP + ) -> None: + """Initialize the mixed input Gaussian Process model. + + Args: + train_X: Training data tensor for numerical and categorical features + train_graphs: List of training graphs + train_Y: Target values + train_Yvar: Observation noise variance (optional) + num_cat_kernel: Kernel for numerical/categorical features (optional) + wl_kernel: Custom Weisfeiler-Lehman kernel instance (optional) + **kwargs: Additional arguments for SingleTaskGP initialization + """ + # Initialize parent class with initial covar_module + super().__init__( + train_X=train_X, + train_Y=train_Y, + train_Yvar=train_Yvar, + covar_module=num_cat_kernel or self._graph_kernel_wrapper(), + **kwargs + ) + + # Initialize WL kernel with default parameters if not provided + self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True) + self._train_graphs = train_graphs + + # Convert graphs to required format and compute kernel matrix + self._train_graph_dataset = GraphDataset.from_networkx(train_graphs) + self._K_train = self._wl_kernel(self._train_graph_dataset) + + if num_cat_kernel is not None: + # Create additive kernel combining numerical/categorical and graph kernels + combined_kernel = AdditiveKernel( + num_cat_kernel, + self._graph_kernel_wrapper() + ) + self.covar_module = combined_kernel + + self.num_cat_kernel = num_cat_kernel + + def _graph_kernel_wrapper(self) -> Module: + """Creates a GPyTorch-compatible kernel module wrapping the WL kernel. + + This wrapper allows the WL kernel to be used within the GPyTorch framework + by providing a forward method that returns the pre-computed kernel matrix. + + Returns: + Module: A GPyTorch kernel module wrapping the WL kernel computation + """ + + class WLKernelWrapper(Module): + def __init__(self, parent: MixedSingleTaskGP): + super().__init__() + self.parent = parent + + def forward( + self, + x1: Tensor, + x2: Tensor | None = None, + diag: bool = False, + last_dim_is_batch: bool = False + ) -> Tensor: + """Compute the kernel matrix for the graph inputs. 
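+
+                Note that the tensor inputs carry no graph information here:
+                the returned values come from the parent model's stored
+                training graphs.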
+ + Args: + x1: First input tensor (unused, required for interface compatibility) + x2: Second input tensor (must be None) + diag: Whether to return only diagonal elements + last_dim_is_batch: Whether the last dimension is a batch dimension + + Returns: + Tensor: Pre-computed graph kernel matrix + + Raises: + NotImplementedError: If x2 is not None (cross-covariance not implemented) + """ + if x2 is None: + return self.parent._K_train + + # Compute cross-covariance between train and test graphs + test_dataset = GraphDataset.from_networkx(self.parent._test_graphs) + return self.parent._wl_kernel( + self.parent._train_graph_dataset, + test_dataset + ) + + return WLKernelWrapper(self) + + def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: + """Forward pass computing the GP distribution for given inputs. + + Computes the kernel matrix for both numerical/categorical features and graphs, + combines them if both are present, and returns the resulting GP distribution. + + Args: + X: Input tensor for numerical and categorical features + graphs: List of input graphs + + Returns: + MultivariateNormal: GP distribution for the given inputs + """ + if len(X) != len(graphs): + raise ValueError( + f"Number of feature vectors ({len(X)}) must match " + f"number of graphs ({len(graphs)})" + ) + + # Process new graphs and compute kernel matrix + proc_graphs = GraphDataset.from_networkx(graphs) + K_new = self._wl_kernel(proc_graphs) # Shape: (n_samples, n_samples) + + # If we have both numerical/categorical and graph features + if self.num_cat_kernel is not None: + # Compute kernel for numerical/categorical features + K_num_cat = self.num_cat_kernel(X) + # Add the kernels (element-wise addition) + K_combined = K_num_cat + K_new + else: + K_combined = K_new + + # Compute mean using the mean module + mean_x = self.mean_module(X) + + return MultivariateNormal(mean_x, K_combined) diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py new file mode 100644 index 00000000..67854b7c --- /dev/null +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -0,0 +1,102 @@ +import networkx as nx +import torch +from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from gpytorch.distributions.multivariate_normal import MultivariateNormal +from gpytorch.kernels import AdditiveKernel, MaternKernel +from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP +from grakel_replace.torch_wl_kernel import TorchWLKernel + +TRAIN_CONFIGS = 10 +TEST_CONFIGS = 10 +TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + +N_NUMERICAL = 2 +N_CATEGORICAL = 2 +N_CATEGORICAL_VALUES_PER_CATEGORY = 3 +N_GRAPH = 2 + +kernels = [] + +# Create numerical and categorical features +X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64) +if N_NUMERICAL > 0: + X[:, :N_NUMERICAL] = torch.rand( + size=(TOTAL_CONFIGS, N_NUMERICAL), + dtype=torch.float64, + ) + +if N_CATEGORICAL > 0: + X[:, N_NUMERICAL:] = torch.randint( + 0, + N_CATEGORICAL_VALUES_PER_CATEGORY, + size=(TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64, + ) + +# Create random graph architectures +graphs = [] +for _ in range(TOTAL_CONFIGS): + G = nx.erdos_renyi_graph(n=5, p=0.5) # Random graph with 5 nodes + graphs.append(G) + +# Create random target values +y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + +# Setup kernels for numerical and categorical features +if N_NUMERICAL > 0: + matern = ScaleKernel( + MaternKernel( + 
nu=2.5, + ard_num_dims=N_NUMERICAL, + active_dims=tuple(range(N_NUMERICAL)), + ), + ) + kernels.append(matern) + +if N_CATEGORICAL > 0: + hamming = ScaleKernel( + CategoricalKernel( + ard_num_dims=N_CATEGORICAL, + active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)), + ), + ) + kernels.append(hamming) + +# Combine numerical and categorical kernels +combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None + +# Create WL kernel for graphs +wl_kernel = TorchWLKernel(n_iter=5, normalize=True) + +# Split into train and test sets +train_x = X[:TRAIN_CONFIGS] +train_graphs = graphs[:TRAIN_CONFIGS] +train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch + +test_x = X[TRAIN_CONFIGS:] +test_graphs = graphs[TRAIN_CONFIGS:] +test_y = y[TRAIN_CONFIGS:].unsqueeze(-1) + +# Initialize the mixed GP +gp = MixedSingleTaskGP( + train_X=train_x, + train_graphs=train_graphs, + train_Y=train_y, + num_cat_kernel=combined_num_cat_kernel, + wl_kernel=wl_kernel, +) + +# Compute the posterior distribution +multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs) +print("Posterior distribution:", multivariate_normal) + +# Making predictions on test data +with torch.no_grad(): + posterior = gp.forward(test_x, test_graphs) + predictions = posterior.mean + uncertainties = posterior.variance.sqrt() + covar = posterior.covariance_matrix + +print("\nMean:", predictions) +print("Variance:", uncertainties) +print("Covariance matrix:", covar) diff --git a/grakel_replace/single_task_gp_usage_example.py b/grakel_replace/single_task_gp_usage_example.py new file mode 100644 index 00000000..bdc9a0ed --- /dev/null +++ b/grakel_replace/single_task_gp_usage_example.py @@ -0,0 +1,81 @@ +import torch +from botorch.models import SingleTaskGP +from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from gpytorch.distributions.multivariate_normal import MultivariateNormal +from gpytorch.kernels import AdditiveKernel, MaternKernel + +TRAIN_CONFIGS = 10 +TEST_CONFIGS = 10 +TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS + +N_NUMERICAL = 2 +N_CATEGORICAL = 2 +N_CATEGORICAL_VALUES_PER_CATEGORY = 3 + +kernels = [] + +# Create some random encoded hyperparameter configurations +X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64) +if N_NUMERICAL > 0: + X[:, :N_NUMERICAL] = torch.rand( + size=(TOTAL_CONFIGS, N_NUMERICAL), + dtype=torch.float64, + ) + +if N_CATEGORICAL > 0: + X[:, N_NUMERICAL:] = torch.randint( + 0, + N_CATEGORICAL_VALUES_PER_CATEGORY, + size=(TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64, + ) + +y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + +if N_NUMERICAL > 0: + matern = ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=N_NUMERICAL, + active_dims=tuple(range(N_NUMERICAL)), + ), + ) + kernels.append(matern) + +if N_CATEGORICAL > 0: + hamming = ScaleKernel( + CategoricalKernel( + ard_num_dims=N_CATEGORICAL, + active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)), + ), + ) + kernels.append(hamming) + + +combined_num_cat_kernel = AdditiveKernel(*kernels) + +train_x = X[:TRAIN_CONFIGS] +train_y = y[:TRAIN_CONFIGS] + +test_x = X[TRAIN_CONFIGS:] +test_y = y[TRAIN_CONFIGS:] + +K_matrix = combined_num_cat_kernel.forward(train_x, train_x) +print( + "K_matrix: ", K_matrix.to_dense() +) + +train_y = train_y.unsqueeze(-1) +test_y = test_y.unsqueeze(-1) + +gp = SingleTaskGP( + train_X=train_x, + train_Y=train_y, + mean_module=None, # We can leave it as the default it uses which is 
`ConstantMean` + covar_module=combined_num_cat_kernel, +) + +multivariate_normal: MultivariateNormal = gp.forward(train_x) +print("Mean:", multivariate_normal.mean) +print("Variance:", multivariate_normal.variance) +print("Covariance matrix:", multivariate_normal.covariance_matrix) diff --git a/tests/test_mixed_single_task_gp.py b/tests/test_mixed_single_task_gp.py new file mode 100644 index 00000000..68bad6fc --- /dev/null +++ b/tests/test_mixed_single_task_gp.py @@ -0,0 +1,236 @@ +import pytest +import torch +import networkx as nx +from gpytorch.kernels import MaternKernel, AdditiveKernel +from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from gpytorch.distributions.multivariate_normal import MultivariateNormal + +from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP +from grakel_replace.torch_wl_kernel import TorchWLKernel + + +@pytest.fixture +def sample_data(): + """Create sample data for testing.""" + n_samples = 5 + n_numerical = 2 + n_categorical = 2 + + # Create numerical and categorical features + X = torch.empty(size=(n_samples, n_numerical + n_categorical), dtype=torch.float64) + X[:, :n_numerical] = torch.rand(size=(n_samples, n_numerical), dtype=torch.float64) + X[:, n_numerical:] = torch.randint(0, 3, size=(n_samples, n_categorical), + dtype=torch.float64) + + # Create sample graphs + graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(n_samples)] + + # Create target values + y = torch.rand(size=(n_samples, 1), dtype=torch.float64) + + return X, graphs, y + + +@pytest.fixture +def sample_kernels(): + """Create sample kernels for testing.""" + n_numerical = 2 + n_categorical = 2 + + matern = ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=n_numerical, + active_dims=tuple(range(n_numerical)), + ), + ) + + hamming = ScaleKernel( + CategoricalKernel( + ard_num_dims=n_categorical, + active_dims=tuple(range(n_numerical, n_numerical + n_categorical)), + ), + ) + + combined_kernel = AdditiveKernel(matern, hamming) + wl_kernel = TorchWLKernel(n_iter=3, normalize=True) + + return combined_kernel, wl_kernel + + +def test_initialization(sample_data, sample_kernels): + """Test that MixedSingleTaskGP initializes correctly.""" + X, graphs, y = sample_data + combined_kernel, wl_kernel = sample_kernels + + # Test initialization with all parameters + gp = MixedSingleTaskGP( + train_X=X, + train_graphs=graphs, + train_Y=y, + num_cat_kernel=combined_kernel, + wl_kernel=wl_kernel, + ) + + assert gp._train_graphs == graphs + assert isinstance(gp._wl_kernel, TorchWLKernel) + assert gp.num_cat_kernel == combined_kernel + + +def test_forward_shape(sample_data, sample_kernels): + """Test that forward pass returns correct shapes.""" + X, graphs, y = sample_data + combined_kernel, wl_kernel = sample_kernels + + gp = MixedSingleTaskGP( + train_X=X, + train_graphs=graphs, + train_Y=y, + num_cat_kernel=combined_kernel, + wl_kernel=wl_kernel, + ) + + # Test forward pass with training data + output = gp.forward(X, graphs) + assert isinstance(output, MultivariateNormal) + assert output.mean.shape == (len(X),) + assert output.covariance_matrix.shape == (len(X), len(X)) + + # Test forward pass with different sized test data + n_test = 3 + test_X = torch.rand(size=(n_test, X.shape[1]), dtype=torch.float64) + test_graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(n_test)] + + output = gp.forward(test_X, test_graphs) + assert output.mean.shape == (n_test,) + assert output.covariance_matrix.shape == (n_test, n_test) + + +def 
test_input_validation(sample_data, sample_kernels): + """Test that appropriate errors are raised for invalid inputs.""" + X, graphs, y = sample_data + combined_kernel, wl_kernel = sample_kernels + + gp = MixedSingleTaskGP( + train_X=X, + train_graphs=graphs, + train_Y=y, + num_cat_kernel=combined_kernel, + wl_kernel=wl_kernel, + ) + + # Test mismatched number of features and graphs + test_X = torch.rand(size=(3, X.shape[1]), dtype=torch.float64) + test_graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(4)] # Different length + + with pytest.raises(ValueError, + match="Number of feature vectors.*must match.*number of graphs"): + gp.forward(test_X, test_graphs) + + +def test_kernel_combination(sample_data): + """Test that numerical/categorical and graph kernels are properly combined.""" + X, graphs, y = sample_data + + # Create GP with only graph kernel + gp_graph_only = MixedSingleTaskGP( + train_X=X, + train_graphs=graphs, + train_Y=y, + ) + + output_graph = gp_graph_only.forward(X, graphs) + graph_var = output_graph.variance + + # Create GP with both kernels + n_numerical = 2 + matern = ScaleKernel( + MaternKernel( + nu=2.5, + ard_num_dims=n_numerical, + active_dims=tuple(range(n_numerical)), + ), + ) + + gp_combined = MixedSingleTaskGP( + train_X=X, + train_graphs=graphs, + train_Y=y, + num_cat_kernel=matern, + ) + + output_combined = gp_combined.forward(X, graphs) + combined_var = output_combined.variance + + # Combined kernel should have larger variance due to addition + assert torch.all(combined_var > graph_var) + + +def test_prediction_consistency(sample_data, sample_kernels): + """Test that predictions are consistent between multiple forward passes.""" + X, graphs, y = sample_data + combined_kernel, wl_kernel = sample_kernels + + gp = MixedSingleTaskGP( + train_X=X, + train_graphs=graphs, + train_Y=y, + num_cat_kernel=combined_kernel, + wl_kernel=wl_kernel, + ) + + # Multiple forward passes should give same result + output1 = gp.forward(X, graphs) + output2 = gp.forward(X, graphs) + + assert torch.allclose(output1.mean, output2.mean) + assert torch.allclose(output1.variance, output2.variance) + + +def test_graph_kernel_caching(sample_data, sample_kernels): + """Test that graph kernel matrices are properly cached.""" + X, graphs, y = sample_data + combined_kernel, wl_kernel = sample_kernels + + gp = MixedSingleTaskGP( + train_X=X, + train_graphs=graphs, + train_Y=y, + num_cat_kernel=combined_kernel, + wl_kernel=wl_kernel, + ) + + # First forward pass + _ = gp.forward(X, graphs) + K_train_1 = gp._K_train.clone() + + # Second forward pass + _ = gp.forward(X, graphs) + K_train_2 = gp._K_train.clone() + + # Cached kernel matrices should be identical + assert torch.allclose(K_train_1, K_train_2) + + +def test_mean_predictions(sample_data, sample_kernels): + """Test that mean predictions are reasonable.""" + X, graphs, y = sample_data + combined_kernel, wl_kernel = sample_kernels + + gp = MixedSingleTaskGP( + train_X=X, + train_graphs=graphs, + train_Y=y, + num_cat_kernel=combined_kernel, + wl_kernel=wl_kernel, + ) + + # Test predictions + with torch.no_grad(): + output = gp.forward(X, graphs) + predictions = output.mean + uncertainties = output.variance.sqrt() + + # Mean predictions should be within reasonable bounds + assert torch.all(predictions >= y.min() - 2 * uncertainties) + assert torch.all(predictions <= y.max() + 2 * uncertainties) From 65666a315f75d3365c97f044b8be1b8e117115ad Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Wed, 20 Nov 2024 22:38:16 +0100 Subject: 
[PATCH 10/32] Refactor WLKernelWrapper into a standalone WLKernel class. --- grakel_replace/mixed_single_task_gp.py | 194 +++++++++++++------------ 1 file changed, 99 insertions(+), 95 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index 91b8b9ba..e09e288c 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -4,141 +4,146 @@ from botorch.models import SingleTaskGP from gpytorch.distributions.multivariate_normal import MultivariateNormal -from gpytorch.kernels import AdditiveKernel -from gpytorch.module import Module +from gpytorch.kernels import AdditiveKernel, Kernel from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel if TYPE_CHECKING: import networkx as nx + from gpytorch.module import Module from torch import Tensor +class WLKernel(Kernel): + """Weisfeiler-Lehman Kernel for graph similarity + integrated into the GPyTorch framework. + + This kernel encapsulates the precomputed Weisfeiler-Lehman graph kernel matrix + and provides it in a GPyTorch-compatible format. + It computes either the training kernel + or the cross-kernel between training and test graphs as needed. + + Args: + parent (MixedSingleTaskGP): + The parent MixedSingleTaskGP instance that holds + the training data and precomputed kernel matrix. + """ + + def __init__(self, parent: MixedSingleTaskGP) -> None: + super().__init__() + self.parent = parent + + def forward( + self, + x1: Tensor, + x2: Tensor | None = None, + diag: bool = False, + last_dim_is_batch: bool = False, + ) -> Tensor: + """Forward method to compute the kernel matrix for the graph inputs. + + Args: + x1 (Tensor): First input tensor + (unused, required for interface compatibility). + x2 (Tensor | None): Second input tensor. + If None, computes the training kernel matrix. + diag (bool): Whether to return only the diagonal of the kernel matrix. + last_dim_is_batch (bool): Whether the last dimension is a batch dimension. + + Returns: + Tensor: The computed kernel matrix. + """ + if x2 is None: + # Return the precomputed training kernel matrix + return self.parent._K_train + + # Compute cross-kernel between training graphs and new test graphs + test_dataset = GraphDataset.from_networkx(self.parent._test_graphs) + return self.parent._wl_kernel( + self.parent._train_graph_dataset, test_dataset + ) + + class MixedSingleTaskGP(SingleTaskGP): - """A Gaussian Process model that handles numerical, categorical, and graph inputs. + """A Gaussian Process model for mixed input spaces containing numerical, categorical, + and graph features. - This class extends BoTorch's SingleTaskGP to work with hybrid input spaces containing: - - Numerical features - - Categorical features - - Graph structures + This class extends BoTorch's SingleTaskGP to support hybrid inputs by combining: + - Standard kernels for numerical and categorical features. + - Weisfeiler-Lehman kernel for graph structures. - It uses the Weisfeiler-Lehman (WL) kernel for graph inputs and combines it with - standard kernels for numerical/categorical features using an additive kernel structure + The kernels are combined using an additive kernel structure. 
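+
+ For intuition: with a numerical/categorical kernel k_hp and the WL graph
+ kernel k_wl, the additive structure gives the combined covariance
+ k((x, g), (x', g')) = k_hp(x, x') + k_wl(g, g').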
Attributes: - _wl_kernel (TorchWLKernel): The Weisfeiler-Lehman kernel for graph similarity - _train_graphs (List[nx.Graph]): Training set graph instances - _K_graph (Tensor): Pre-computed graph kernel matrix for training data - num_cat_kernel (Optional[Module]): Kernel for numerical/categorical features + _wl_kernel (TorchWLKernel): Instance of the Weisfeiler-Lehman kernel. + _train_graphs (list[nx.Graph]): Training graph instances. + _K_train (Tensor): Precomputed graph kernel matrix for training graphs. + num_cat_kernel (Module | None): Kernel for numerical/categorical features. """ def __init__( self, - train_X: Tensor, # Shape: (n_samples, n_numerical_categorical_features) - train_graphs: list[nx.Graph], # List of n_samples graph instances - train_Y: Tensor, # Shape: (n_samples, n_outputs) - train_Yvar: Tensor | None = None, # Shape: (n_samples, n_outputs) or None + train_X: Tensor, + train_graphs: list[nx.Graph], + train_Y: Tensor, + train_Yvar: Tensor | None = None, num_cat_kernel: Module | None = None, wl_kernel: TorchWLKernel | None = None, - **kwargs # Additional arguments passed to SingleTaskGP + **kwargs, ) -> None: - """Initialize the mixed input Gaussian Process model. + """Initialize the mixed-input Gaussian Process model. Args: - train_X: Training data tensor for numerical and categorical features - train_graphs: List of training graphs - train_Y: Target values - train_Yvar: Observation noise variance (optional) - num_cat_kernel: Kernel for numerical/categorical features (optional) - wl_kernel: Custom Weisfeiler-Lehman kernel instance (optional) - **kwargs: Additional arguments for SingleTaskGP initialization + train_X (Tensor): Training tensor for numerical and categorical features. + train_graphs (list[nx.Graph]): List of training graph instances. + train_Y (Tensor): Target values for training data. + train_Yvar (Tensor | None): Observation noise variance (optional). + num_cat_kernel (Module | None): Kernel for numerical/categorical features + (optional). + wl_kernel (TorchWLKernel | None): Weisfeiler-Lehman kernel instance + (optional). + **kwargs: Additional arguments for SingleTaskGP initialization. """ - # Initialize parent class with initial covar_module + # Initialize the base SingleTaskGP with a num/cat kernel (if provided) super().__init__( train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar, - covar_module=num_cat_kernel or self._graph_kernel_wrapper(), - **kwargs + covar_module=num_cat_kernel, + **kwargs, ) - # Initialize WL kernel with default parameters if not provided + # Initialize the Weisfeiler-Lehman kernel or use a default one self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True) self._train_graphs = train_graphs - # Convert graphs to required format and compute kernel matrix + # Preprocess the training graphs into a compatible format and compute the graph + # kernel matrix self._train_graph_dataset = GraphDataset.from_networkx(train_graphs) self._K_train = self._wl_kernel(self._train_graph_dataset) + # If a kernel for numerical/categorical features is provided, combine it with + # the WL kernel if num_cat_kernel is not None: - # Create additive kernel combining numerical/categorical and graph kernels combined_kernel = AdditiveKernel( num_cat_kernel, - self._graph_kernel_wrapper() + WLKernel(self), ) self.covar_module = combined_kernel self.num_cat_kernel = num_cat_kernel - def _graph_kernel_wrapper(self) -> Module: - """Creates a GPyTorch-compatible kernel module wrapping the WL kernel. 
- - This wrapper allows the WL kernel to be used within the GPyTorch framework - by providing a forward method that returns the pre-computed kernel matrix. - - Returns: - Module: A GPyTorch kernel module wrapping the WL kernel computation - """ - - class WLKernelWrapper(Module): - def __init__(self, parent: MixedSingleTaskGP): - super().__init__() - self.parent = parent - - def forward( - self, - x1: Tensor, - x2: Tensor | None = None, - diag: bool = False, - last_dim_is_batch: bool = False - ) -> Tensor: - """Compute the kernel matrix for the graph inputs. - - Args: - x1: First input tensor (unused, required for interface compatibility) - x2: Second input tensor (must be None) - diag: Whether to return only diagonal elements - last_dim_is_batch: Whether the last dimension is a batch dimension - - Returns: - Tensor: Pre-computed graph kernel matrix - - Raises: - NotImplementedError: If x2 is not None (cross-covariance not implemented) - """ - if x2 is None: - return self.parent._K_train - - # Compute cross-covariance between train and test graphs - test_dataset = GraphDataset.from_networkx(self.parent._test_graphs) - return self.parent._wl_kernel( - self.parent._train_graph_dataset, - test_dataset - ) - - return WLKernelWrapper(self) - def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: - """Forward pass computing the GP distribution for given inputs. + """Forward pass to compute the Gaussian Process distribution for given inputs. - Computes the kernel matrix for both numerical/categorical features and graphs, - combines them if both are present, and returns the resulting GP distribution. + This combines the numerical/categorical kernel with the graph kernel + to compute the joint covariance matrix. Args: - X: Input tensor for numerical and categorical features - graphs: List of input graphs + X (Tensor): Input tensor for numerical and categorical features. + graphs (list[nx.Graph]): List of input graphs. Returns: - MultivariateNormal: GP distribution for the given inputs + MultivariateNormal: The Gaussian Process distribution for the inputs. 
""" if len(X) != len(graphs): raise ValueError( @@ -146,20 +151,19 @@ def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: f"number of graphs ({len(graphs)})" ) - # Process new graphs and compute kernel matrix + # Process the new graph inputs into a compatible dataset proc_graphs = GraphDataset.from_networkx(graphs) + + # Compute the kernel matrix for the new graphs K_new = self._wl_kernel(proc_graphs) # Shape: (n_samples, n_samples) - # If we have both numerical/categorical and graph features + # Combine the graph kernel with the numerical/categorical kernel (if present) if self.num_cat_kernel is not None: - # Compute kernel for numerical/categorical features - K_num_cat = self.num_cat_kernel(X) - # Add the kernels (element-wise addition) - K_combined = K_num_cat + K_new + K_num_cat = self.num_cat_kernel(X) # Compute kernel for num/cat features + K_combined = K_num_cat + K_new # Combine the two kernels else: K_combined = K_new - # Compute mean using the mean module + # Compute the mean using the mean module and construct the GP distribution mean_x = self.mean_module(X) - return MultivariateNormal(mean_x, K_combined) From 7fa9432d95f8ff136e18dccb241111d0792f65a7 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Wed, 20 Nov 2024 22:38:24 +0100 Subject: [PATCH 11/32] Update tests --- tests/test_torch_wl_kernel.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_torch_wl_kernel.py b/tests/test_torch_wl_kernel.py index 4b89f4d9..49d7b1ed 100644 --- a/tests/test_torch_wl_kernel.py +++ b/tests/test_torch_wl_kernel.py @@ -48,7 +48,8 @@ def test_wl_kernel_against_grakel(self, n_iter, normalize, example_graphs): torch.tensor(torch_kernel_matrix, dtype=torch.float64), torch.tensor(grakel_kernel_matrix, dtype=torch.float64), atol=1e-100 - ), f"Mismatch found in kernel matrices with n_iter={n_iter} and normalize={normalize}" + ), (f"Mismatch found in kernel matrices with n_iter={n_iter} and " + f"normalize={normalize}") def test_kernel_symmetry(self, example_graphs): """Test if the kernel matrix is symmetric.""" @@ -152,7 +153,8 @@ def test_wl_kernel_with_different_node_labels(self, n_iter, normalize, example_g torch.tensor(torch_kernel_matrix, dtype=torch.float64), torch.tensor(grakel_kernel_matrix, dtype=torch.float64), atol=1e-100 - ), f"Mismatch found in kernel matrices with n_iter={n_iter} and normalize={normalize} for graphs with different node labels" + ), (f"Mismatch found in kernel matrices with n_iter={n_iter} and " + f"normalize={normalize} for graphs with different node labels") def test_wl_kernel_with_same_node_labels(self, example_graphs): """Test the TorchWLKernel with graphs having the same node labels.""" @@ -176,4 +178,5 @@ def test_wl_kernel_with_same_node_labels(self, example_graphs): # Check if the kernel matrix is valid and the values are the same for the graphs with the same node labels assert K.shape == (3, 3), "Kernel matrix shape is incorrect" assert torch.allclose(K, K.T, atol=1e-100), "Kernel matrix is not symmetric" - assert torch.all(K == K[0, 0]), "Kernel values should be the same for graphs with the same node labels" + assert torch.all(K == K[0, 0]), ("Kernel values should be the same for " + "graphs with the same node labels") From 4227f221b7cf0294b6eb45ef4c90ca557693585a Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Wed, 20 Nov 2024 23:10:53 +0100 Subject: [PATCH 12/32] Add a check for empty inputs --- grakel_replace/mixed_single_task_gp.py | 7 ++++++- grakel_replace/torch_wl_kernel.py | 4 
++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index e09e288c..f73c0f3f 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -2,13 +2,13 @@ from typing import TYPE_CHECKING +import networkx as nx from botorch.models import SingleTaskGP from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import AdditiveKernel, Kernel from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel if TYPE_CHECKING: - import networkx as nx from gpytorch.module import Module from torch import Tensor @@ -103,6 +103,9 @@ def __init__( (optional). **kwargs: Additional arguments for SingleTaskGP initialization. """ + if train_X.size(0) == 0 or len(train_graphs) == 0: + raise ValueError("Training inputs (features and graphs) cannot be empty.") + # Initialize the base SingleTaskGP with a num/cat kernel (if provided) super().__init__( train_X=train_X, @@ -150,6 +153,8 @@ def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: f"Number of feature vectors ({len(X)}) must match " f"number of graphs ({len(graphs)})" ) + if not all(isinstance(g, nx.Graph) for g in graphs): + raise TypeError("Expected input type is a list of NetworkX graphs.") # Process the new graph inputs into a compatible dataset proc_graphs = GraphDataset.from_networkx(graphs) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index 929365a1..ba2ae071 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -228,6 +228,10 @@ class GraphDataset: @staticmethod def from_networkx(graphs: list[nx.Graph], node_labels_tag: str = "label") -> list[ nx.Graph]: + + if not all(isinstance(g, nx.Graph) for g in graphs): + raise TypeError("Expected input type is a list of NetworkX graphs.") + """Convert NetworkX graphs ensuring proper node labeling.""" processed_graphs = [] for g in graphs: From f194bd21eff8ff3eb7bbd44d11c304b9868dde63 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Wed, 20 Nov 2024 23:14:30 +0100 Subject: [PATCH 13/32] Improve and combine tests --- tests/test_mixed_single_task_gp.py | 141 +++++++++++++++++++---------- 1 file changed, 93 insertions(+), 48 deletions(-) diff --git a/tests/test_mixed_single_task_gp.py b/tests/test_mixed_single_task_gp.py index 68bad6fc..3119d66d 100644 --- a/tests/test_mixed_single_task_gp.py +++ b/tests/test_mixed_single_task_gp.py @@ -58,12 +58,12 @@ def sample_kernels(): return combined_kernel, wl_kernel -def test_initialization(sample_data, sample_kernels): - """Test that MixedSingleTaskGP initializes correctly.""" +def test_model_initialization_and_validation(sample_data, sample_kernels): + """Test GP initialization, inputs validation, and basic properties.""" X, graphs, y = sample_data combined_kernel, wl_kernel = sample_kernels - # Test initialization with all parameters + # Test successful initialization with all parameters gp = MixedSingleTaskGP( train_X=X, train_graphs=graphs, @@ -76,9 +76,17 @@ def test_initialization(sample_data, sample_kernels): assert isinstance(gp._wl_kernel, TorchWLKernel) assert gp.num_cat_kernel == combined_kernel + # Test empty input validation + with pytest.raises(ValueError, match="Training inputs.*cannot be empty"): + MixedSingleTaskGP( + train_X=torch.empty((0, 4), dtype=torch.float64), + train_graphs=[], + train_Y=torch.empty((0, 1), dtype=torch.float64), + ) -def 
test_forward_shape(sample_data, sample_kernels): - """Test that forward pass returns correct shapes.""" + +def test_forward_pass_and_predictions(sample_data, sample_kernels): + """Test forward pass, shape consistency, and prediction characteristics.""" X, graphs, y = sample_data combined_kernel, wl_kernel = sample_kernels @@ -96,42 +104,36 @@ def test_forward_shape(sample_data, sample_kernels): assert output.mean.shape == (len(X),) assert output.covariance_matrix.shape == (len(X), len(X)) - # Test forward pass with different sized test data + # Test forward pass with test data n_test = 3 test_X = torch.rand(size=(n_test, X.shape[1]), dtype=torch.float64) test_graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(n_test)] - output = gp.forward(test_X, test_graphs) - assert output.mean.shape == (n_test,) - assert output.covariance_matrix.shape == (n_test, n_test) - - -def test_input_validation(sample_data, sample_kernels): - """Test that appropriate errors are raised for invalid inputs.""" - X, graphs, y = sample_data - combined_kernel, wl_kernel = sample_kernels - - gp = MixedSingleTaskGP( - train_X=X, - train_graphs=graphs, - train_Y=y, - num_cat_kernel=combined_kernel, - wl_kernel=wl_kernel, - ) + output_test = gp.forward(test_X, test_graphs) + assert output_test.mean.shape == (n_test,) + assert output_test.covariance_matrix.shape == (n_test, n_test) - # Test mismatched number of features and graphs - test_X = torch.rand(size=(3, X.shape[1]), dtype=torch.float64) - test_graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(4)] # Different length + # Test input validation + # Mismatched number of features and graphs + mismatched_test_X = torch.rand(size=(3, X.shape[1]), dtype=torch.float64) + mismatched_test_graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(4)] with pytest.raises(ValueError, match="Number of feature vectors.*must match.*number of graphs"): - gp.forward(test_X, test_graphs) + gp.forward(mismatched_test_X, mismatched_test_graphs) + # Invalid graph input + invalid_graphs = ["not_a_graph", 123, None, False, True] + with pytest.raises(TypeError, + match="Expected input type is a list of NetworkX graphs."): + gp.forward(X, invalid_graphs) -def test_kernel_combination(sample_data): - """Test that numerical/categorical and graph kernels are properly combined.""" + +def test_kernel_combinations_and_properties(sample_data): + """Test kernel combination, invariance, and consistency properties.""" X, graphs, y = sample_data + # Test kernel combination and variance changes # Create GP with only graph kernel gp_graph_only = MixedSingleTaskGP( train_X=X, @@ -142,7 +144,7 @@ def test_kernel_combination(sample_data): output_graph = gp_graph_only.forward(X, graphs) graph_var = output_graph.variance - # Create GP with both kernels + # Create GP with combined kernels n_numerical = 2 matern = ScaleKernel( MaternKernel( @@ -165,9 +167,42 @@ def test_kernel_combination(sample_data): # Combined kernel should have larger variance due to addition assert torch.all(combined_var > graph_var) + # Use graphs with slight variations to avoid singular matrix + similar_graphs = [ + nx.complete_graph(5) for _ in range(len(graphs)) + ] + + # Add small random perturbations to make graphs slightly different + for i in range(1, len(similar_graphs)): + G = similar_graphs[i] + # Add or remove edges with a small probability + edges_to_add = [(u, v) for u in range(5) for v in range(u + 1, 5) + if not G.has_edge(u, v) and torch.rand(1) < 0.1] + edges_to_remove = [(u, v) for (u, v) in G.edges() + if 
torch.rand(1) < 0.1] + + G.add_edges_from(edges_to_add) + G.remove_edges_from(edges_to_remove) + + gp_similar = MixedSingleTaskGP( + train_X=X, + train_graphs=similar_graphs, + train_Y=y, + ) + + # Compute kernel matrix and check diagonal consistency + kernel_matrix = gp_similar._K_train + diag = kernel_matrix.diag() + + # Allow for slight variations due to graph perturbations + assert torch.allclose(diag, diag[0], atol=1e-1) -def test_prediction_consistency(sample_data, sample_kernels): - """Test that predictions are consistent between multiple forward passes.""" + # Check that the matrix is not completely uniform + assert not torch.allclose(kernel_matrix, torch.ones_like(kernel_matrix), rtol=1e-5) + + +def test_model_prediction_consistency(sample_data, sample_kernels): + """Test prediction consistency and mean prediction bounds.""" X, graphs, y = sample_data combined_kernel, wl_kernel = sample_kernels @@ -179,13 +214,22 @@ def test_prediction_consistency(sample_data, sample_kernels): wl_kernel=wl_kernel, ) - # Multiple forward passes should give same result + # Multiple forward passes should give consistent results output1 = gp.forward(X, graphs) output2 = gp.forward(X, graphs) assert torch.allclose(output1.mean, output2.mean) assert torch.allclose(output1.variance, output2.variance) + # Mean predictions should be within reasonable bounds + with torch.no_grad(): + output = gp.forward(X, graphs) + predictions = output.mean + uncertainties = output.variance.sqrt() + + assert torch.all(predictions >= y.min() - 2 * uncertainties) + assert torch.all(predictions <= y.max() + 2 * uncertainties) + def test_graph_kernel_caching(sample_data, sample_kernels): """Test that graph kernel matrices are properly cached.""" @@ -212,25 +256,26 @@ def test_graph_kernel_caching(sample_data, sample_kernels): assert torch.allclose(K_train_1, K_train_2) -def test_mean_predictions(sample_data, sample_kernels): - """Test that mean predictions are reasonable.""" - X, graphs, y = sample_data - combined_kernel, wl_kernel = sample_kernels +def test_large_dataset_handling(): + """Test the model's behavior with large datasets.""" + n_samples = 100 + n_features = 4 + + # Create large numerical and categorical features + X = torch.rand(size=(n_samples, n_features), dtype=torch.float64) + + # Create a large set of random graphs + graphs = [nx.erdos_renyi_graph(n=10, p=0.2) for _ in range(n_samples)] + + # Create target values + y = torch.rand(size=(n_samples, 1), dtype=torch.float64) gp = MixedSingleTaskGP( train_X=X, train_graphs=graphs, train_Y=y, - num_cat_kernel=combined_kernel, - wl_kernel=wl_kernel, ) - # Test predictions - with torch.no_grad(): - output = gp.forward(X, graphs) - predictions = output.mean - uncertainties = output.variance.sqrt() - - # Mean predictions should be within reasonable bounds - assert torch.all(predictions >= y.min() - 2 * uncertainties) - assert torch.all(predictions <= y.max() + 2 * uncertainties) + output = gp.forward(X, graphs) + assert output.mean.shape == (n_samples,) + assert output.covariance_matrix.shape == (n_samples, n_samples) From a10484096a7002b1e2831a5c7e01f8297baf3b09 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Thu, 21 Nov 2024 11:30:23 +0100 Subject: [PATCH 14/32] Update WLKernel --- grakel_replace/mixed_single_task_gp.py | 34 ++++++++++++-------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index f73c0f3f..3f436a64 100644 --- 
a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -21,23 +21,24 @@ class WLKernel(Kernel): and provides it in a GPyTorch-compatible format. It computes either the training kernel or the cross-kernel between training and test graphs as needed. - - Args: - parent (MixedSingleTaskGP): - The parent MixedSingleTaskGP instance that holds - the training data and precomputed kernel matrix. """ - def __init__(self, parent: MixedSingleTaskGP) -> None: + def __init__( + self, + K_train: Tensor, + wl_kernel: TorchWLKernel, + train_graph_dataset: GraphDataset + ) -> None: super().__init__() - self.parent = parent + self._K_train = K_train + self._wl_kernel = wl_kernel + self._train_graph_dataset = train_graph_dataset def forward( - self, - x1: Tensor, + self, x1: Tensor, x2: Tensor | None = None, diag: bool = False, - last_dim_is_batch: bool = False, + last_dim_is_batch: bool = False ) -> Tensor: """Forward method to compute the kernel matrix for the graph inputs. @@ -54,13 +55,11 @@ def forward( """ if x2 is None: # Return the precomputed training kernel matrix - return self.parent._K_train + return self._K_train # Compute cross-kernel between training graphs and new test graphs - test_dataset = GraphDataset.from_networkx(self.parent._test_graphs) - return self.parent._wl_kernel( - self.parent._train_graph_dataset, test_dataset - ) + test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs + return self._wl_kernel(self._train_graph_dataset, test_dataset) class MixedSingleTaskGP(SingleTaskGP): @@ -127,11 +126,10 @@ def __init__( # If a kernel for numerical/categorical features is provided, combine it with # the WL kernel if num_cat_kernel is not None: - combined_kernel = AdditiveKernel( + self.covar_module = AdditiveKernel( num_cat_kernel, - WLKernel(self), + WLKernel(self._K_train, self._wl_kernel, self._train_graph_dataset), ) - self.covar_module = combined_kernel self.num_cat_kernel = num_cat_kernel From 246f9f658ed49e873862693c48bec70337cf621c Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Thu, 21 Nov 2024 11:31:37 +0100 Subject: [PATCH 15/32] Add acquisition function with graph sampling --- .../mixed_single_task_gp_usage_example.py | 59 ++++++++++++ grakel_replace/optimize.py | 90 +++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 grakel_replace/optimize.py diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py index 67854b7c..98dec20a 100644 --- a/grakel_replace/mixed_single_task_gp_usage_example.py +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -1,9 +1,16 @@ +from itertools import product + import networkx as nx import torch +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.fit import fit_gpytorch_mll from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from botorch.optim import optimize_acqf, optimize_acqf_mixed +from gpytorch import ExactMarginalLogLikelihood from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import AdditiveKernel, MaternKernel from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP +from grakel_replace.optimize import optimize_acqf_graph from grakel_replace.torch_wl_kernel import TorchWLKernel TRAIN_CONFIGS = 10 @@ -100,3 +107,55 @@ print("\nMean:", predictions) print("Variance:", uncertainties) print("Covariance matrix:", covar) + +# =============== Fitting the GP using botorch 
=============== + +mll = ExactMarginalLogLikelihood(gp.likelihood, gp) +fit_gpytorch_mll(mll) + +# Define the acquisition function +acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=X, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, +) + +# Define bounds +bounds = torch.tensor( + [ + [0.0, 1.0] * N_NUMERICAL + + [0.0, N_CATEGORICAL_VALUES_PER_CATEGORY - 1] * N_CATEGORICAL + ] +).view(2, -1) + +cats_per_column: dict[int, list[float]] = { + column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)] + for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL) +} + +# Generate fixed categorical features +fixed_cats: list[dict[int, float]] +if len(cats_per_column) == 1: + col, choice_indices = next(iter(cats_per_column.items())) + fixed_cats = [{col: i} for i in choice_indices] +else: + fixed_cats = [ + dict(zip(cats_per_column.keys(), combo)) + for combo in product(*cats_per_column.values()) + ] + +# Use the graph-optimized acquisition function +best_candidate, best_score = optimize_acqf_graph( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + train_graphs=graphs, + num_graph_samples=10, # Number of graphs to sample + num_restarts=3, + raw_samples=250, + q=1, +) + +print("Best candidate:", best_candidate) +print("Acquisition score:", best_score) diff --git a/grakel_replace/optimize.py b/grakel_replace/optimize.py new file mode 100644 index 00000000..1f0fe72a --- /dev/null +++ b/grakel_replace/optimize.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import random +from typing import TYPE_CHECKING + +import torch +from botorch.optim import optimize_acqf_mixed + +if TYPE_CHECKING: + import networkx as nx + from botorch.acquisition import AcquisitionFunction + + +def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: + """Sample graphs using random walks or edge modifications. + + Args: + graphs (list[nx.Graph]): Existing training graphs. + num_samples (int): Number of graph samples to generate. + + Returns: + list[nx.Graph]: Sampled graphs. + """ + sampled_graphs = [] + for _ in range(num_samples): + base_graph = random.choice(graphs) + sampled_graph = base_graph.copy() # Copy base graph + # Modify the graph with edge additions or removals + for _ in range(random.randint(1, 3)): + if len(sampled_graph.edges) > 0: + # Randomly remove or add edges + if random.random() > 0.5: + u, v = random.choice(list(sampled_graph.edges)) + sampled_graph.remove_edge(u, v) + else: + u = random.choice(list(sampled_graph.nodes)) + v = random.choice(list(sampled_graph.nodes)) + sampled_graph.add_edge(u, v) + sampled_graphs.append(sampled_graph) + return sampled_graphs + + +def optimize_acqf_graph( + acq_function: AcquisitionFunction, + bounds: torch.Tensor, + fixed_features_list: list[dict[int, float]] | None = None, + num_graph_samples: int = 10, + train_graphs: list[nx.Graph] | None = None, + num_restarts: int = 10, + raw_samples: int = 1024, + q: int = 1, +) -> tuple[torch.Tensor, float]: + """Optimize acquisition function with graph sampling. 
+ + Args: + acq_function: Acquisition function to optimize + bounds: Bounds for numerical/categorical features + fixed_features_list: Fixed categorical feature configurations + num_graph_samples: Number of graphs to sample + train_graphs: Original training graphs + num_restarts: Number of optimization restarts + raw_samples: Number of raw samples to generate + q: Number of candidates to generate + + Returns: + tuple: Best candidate and acquisition score. + """ + if train_graphs is None: + raise ValueError("train_graphs cannot be None.") + + sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples) + + best_candidates, best_scores = [], [] + + for _graph in sampled_graphs: + for fixed_features in fixed_features_list or [{}]: + candidates, scores = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=[fixed_features], + num_restarts=num_restarts, + raw_samples=raw_samples, + q=q, + ) + best_candidates.append(candidates) + best_scores.append(scores) + + best_scores_tensor = torch.tensor(best_scores) + best_idx = torch.argmax(best_scores_tensor) + return best_candidates[best_idx], best_scores_tensor[best_idx].item() From 770c62682b1bda38b958e8a953f9961d3b5c65bf Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Thu, 21 Nov 2024 22:39:41 +0100 Subject: [PATCH 16/32] Add a custom __call__ method to pass graphs during optimization --- grakel_replace/mixed_single_task_gp.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index 3f436a64..b6d20678 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -76,6 +76,7 @@ class MixedSingleTaskGP(SingleTaskGP): _wl_kernel (TorchWLKernel): Instance of the Weisfeiler-Lehman kernel. _train_graphs (list[nx.Graph]): Training graph instances. _K_train (Tensor): Precomputed graph kernel matrix for training graphs. + train_inputs (tuple[Tensor, list[nx.Graph]]): Tuple of training inputs. num_cat_kernel (Module | None): Kernel for numerical/categorical features. """ @@ -118,6 +119,9 @@ def __init__( self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True) self._train_graphs = train_graphs + # Store graphs as part of train_inputs for using them in the __call__ method + self.train_inputs = (train_X, train_graphs) + # Preprocess the training graphs into a compatible format and compute the graph # kernel matrix self._train_graph_dataset = GraphDataset.from_networkx(train_graphs) @@ -133,6 +137,12 @@ def __init__( self.num_cat_kernel = num_cat_kernel + def __call__(self, X: Tensor, graphs: list[nx.Graph] = None, **kwargs): + """Custom __call__ method that retrieves graphs if not explicitly passed.""" + if graphs is None: # Use stored graphs from train_inputs if not provided + graphs = self.train_inputs[1] + return super().__call__(X, graphs=graphs, **kwargs) + def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: """Forward pass to compute the Gaussian Process distribution for given inputs. 
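Why the custom __call__ in PATCH 16 is needed: BoTorch acquisition functions evaluate the model as model(X), with tensors only, so the graph inputs must be recoverable from the model itself. A minimal sketch of the intended call pattern, assuming the MixedSingleTaskGP from this series (the data here is random and purely illustrative):

    import networkx as nx
    import torch

    from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP

    train_X = torch.rand(5, 4, dtype=torch.float64)
    train_graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(5)]
    train_Y = torch.rand(5, 1, dtype=torch.float64)

    gp = MixedSingleTaskGP(train_X=train_X, train_graphs=train_graphs, train_Y=train_Y)

    # Acquisition code calls gp(train_X) with no graph argument; __call__ falls
    # back to the graphs stored in train_inputs, so the WL kernel still
    # receives graph inputs.
    mvn = gp(train_X)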
From 8bf7ea7dd733a618b5a9eb517d1137ecb1ac8922 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 16:14:44 +0100 Subject: [PATCH 17/32] Update MixedSingleTaskGP --- grakel_replace/mixed_single_task_gp.py | 13 +++++-------- .../mixed_single_task_gp_usage_example.py | 17 +++++++++-------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index b6d20678..d1dd7983 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -74,7 +74,6 @@ class MixedSingleTaskGP(SingleTaskGP): Attributes: _wl_kernel (TorchWLKernel): Instance of the Weisfeiler-Lehman kernel. - _train_graphs (list[nx.Graph]): Training graph instances. _K_train (Tensor): Precomputed graph kernel matrix for training graphs. train_inputs (tuple[Tensor, list[nx.Graph]]): Tuple of training inputs. num_cat_kernel (Module | None): Kernel for numerical/categorical features. @@ -114,13 +113,11 @@ def __init__( covar_module=num_cat_kernel, **kwargs, ) + # Store graphs as part of train_inputs for using them in the __call__ method + self.train_inputs = (train_X, train_graphs) # Initialize the Weisfeiler-Lehman kernel or use a default one self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True) - self._train_graphs = train_graphs - - # Store graphs as part of train_inputs for using them in the __call__ method - self.train_inputs = (train_X, train_graphs) # Preprocess the training graphs into a compatible format and compute the graph # kernel matrix @@ -137,11 +134,11 @@ def __init__( self.num_cat_kernel = num_cat_kernel - def __call__(self, X: Tensor, graphs: list[nx.Graph] = None, **kwargs): - """Custom __call__ method that retrieves graphs if not explicitly passed.""" + def __call__(self, X: Tensor, graphs: list[nx.Graph] | None = None, **kwargs): + """Custom __call__ method that retrieves train graphs if not explicitly passed.""" if graphs is None: # Use stored graphs from train_inputs if not provided graphs = self.train_inputs[1] - return super().__call__(X, graphs=graphs, **kwargs) + return self.forward(X, graphs) def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: """Forward pass to compute the Gaussian Process distribution for given inputs. 
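The usage-example diff that follows also corrects the bounds construction: the earlier [0.0, 1.0] * N_NUMERICAL + ... followed by .view(2, -1) interleaved lower and upper values, so row 0 was not a row of lower bounds. optimize_acqf_mixed expects a (2, d) tensor with all lower bounds in row 0 and all upper bounds in row 1. A standalone sketch of the corrected layout, using the same constants as the example:

    import torch

    N_NUMERICAL = 2
    N_CATEGORICAL = 2
    N_CATEGORICAL_VALUES_PER_CATEGORY = 3

    # Row 0: lower bounds, row 1: upper bounds, one column per feature.
    bounds = torch.tensor([
        [0.0] * (N_NUMERICAL + N_CATEGORICAL),
        [1.0] * N_NUMERICAL
        + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL,
    ])
    assert bounds.shape == (2, N_NUMERICAL + N_CATEGORICAL)

    # The old construction produced rows [0., 1., 0., 1.] and [0., 2., 0., 2.],
    # which mix lower and upper values and are not valid bounds rows.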
diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py index 98dec20a..19208d99 100644 --- a/grakel_replace/mixed_single_task_gp_usage_example.py +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -5,7 +5,6 @@ from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement from botorch.fit import fit_gpytorch_mll from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel -from botorch.optim import optimize_acqf, optimize_acqf_mixed from gpytorch import ExactMarginalLogLikelihood from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import AdditiveKernel, MaternKernel @@ -109,6 +108,7 @@ print("Covariance matrix:", covar) # =============== Fitting the GP using botorch =============== +print("\nFitting the GP model using botorch...") mll = ExactMarginalLogLikelihood(gp.likelihood, gp) fit_gpytorch_mll(mll) @@ -116,7 +116,7 @@ # Define the acquisition function acq_function = qLogNoisyExpectedImprovement( model=gp, - X_baseline=X, + X_baseline=train_x, objective=LinearMCObjective(weights=torch.tensor([-1.0])), prune_baseline=True, ) @@ -124,11 +124,12 @@ # Define bounds bounds = torch.tensor( [ - [0.0, 1.0] * N_NUMERICAL - + [0.0, N_CATEGORICAL_VALUES_PER_CATEGORY - 1] * N_CATEGORICAL + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL, + [1.0] * N_NUMERICAL + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL ] -).view(2, -1) +) +# Setup categorical feature optimization cats_per_column: dict[int, list[float]] = { column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)] for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL) @@ -150,10 +151,10 @@ acq_function=acq_function, bounds=bounds, fixed_features_list=fixed_cats, - train_graphs=graphs, - num_graph_samples=10, # Number of graphs to sample + train_graphs=train_graphs, + num_graph_samples=6, num_restarts=3, - raw_samples=250, + raw_samples=10, q=1, ) From 84d010442942df097ea139077aa57f1850838b2e Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 16:15:32 +0100 Subject: [PATCH 18/32] Remove not used argument --- grakel_replace/single_task_gp_usage_example.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grakel_replace/single_task_gp_usage_example.py b/grakel_replace/single_task_gp_usage_example.py index bdc9a0ed..f6b24e12 100644 --- a/grakel_replace/single_task_gp_usage_example.py +++ b/grakel_replace/single_task_gp_usage_example.py @@ -71,7 +71,6 @@ gp = SingleTaskGP( train_X=train_x, train_Y=train_y, - mean_module=None, # We can leave it as the default it uses which is `ConstantMean` covar_module=combined_num_cat_kernel, ) From d63239ada8d012e1870029f5b9254b25fcde17f7 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 16:52:32 +0100 Subject: [PATCH 19/32] Update sample_graphs --- grakel_replace/optimize.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/grakel_replace/optimize.py b/grakel_replace/optimize.py index 1f0fe72a..eff7adac 100644 --- a/grakel_replace/optimize.py +++ b/grakel_replace/optimize.py @@ -3,11 +3,11 @@ import random from typing import TYPE_CHECKING +import networkx as nx import torch from botorch.optim import optimize_acqf_mixed if TYPE_CHECKING: - import networkx as nx from botorch.acquisition import AcquisitionFunction @@ -24,19 +24,31 @@ def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: sampled_graphs = 
[] for _ in range(num_samples): base_graph = random.choice(graphs) - sampled_graph = base_graph.copy() # Copy base graph - # Modify the graph with edge additions or removals - for _ in range(random.randint(1, 3)): - if len(sampled_graph.edges) > 0: - # Randomly remove or add edges - if random.random() > 0.5: - u, v = random.choice(list(sampled_graph.edges)) - sampled_graph.remove_edge(u, v) - else: - u = random.choice(list(sampled_graph.nodes)) - v = random.choice(list(sampled_graph.nodes)) - sampled_graph.add_edge(u, v) + sampled_graph = base_graph.copy() + + # More aggressive modifications + num_modifications = random.randint(2, 5) # Increase minimum modifications + for _ in range(num_modifications): + if random.random() > 0.3: # 70% chance to add edge + nodes = list(sampled_graph.nodes) + if len(nodes) >= 2: + u, v = random.sample(nodes, 2) + if not sampled_graph.has_edge(u, v): + sampled_graph.add_edge(u, v) + elif sampled_graph.edges: # 30% chance to remove edge + u, v = random.choice(list(sampled_graph.edges)) + sampled_graph.remove_edge(u, v) + + # Ensure the graph stays connected + if not nx.is_connected(sampled_graph): + components = list(nx.connected_components(sampled_graph)) + for i in range(len(components) - 1): + u = random.choice(list(components[i])) + v = random.choice(list(components[i + 1])) + sampled_graph.add_edge(u, v) + sampled_graphs.append(sampled_graph) + return sampled_graphs From 3db3f8935692bb503d0c1caeb3c6d34173b7388b Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 16:53:57 +0100 Subject: [PATCH 20/32] Handle different batch dimensions --- grakel_replace/mixed_single_task_gp.py | 69 ++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 10 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index d1dd7983..de04a14a 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import networkx as nx +import torch from botorch.models import SingleTaskGP from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import AdditiveKernel, Kernel @@ -54,12 +55,35 @@ def forward( Tensor: The computed kernel matrix. 
""" if x2 is None: - # Return the precomputed training kernel matrix - return self._K_train - - # Compute cross-kernel between training graphs and new test graphs - test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs - return self._wl_kernel(self._train_graph_dataset, test_dataset) + K = self._K_train + # Handle batch dimension if present in x1 + if x1.dim() > 2: # We have a batch dimension + batch_size = x1.size(0) + target_size = x1.size(1) # This should be 11 in our case + # Resize K to match the expected dimensions + K = K.unsqueeze(0) # Add batch dimension + # Pad or interpolate K to match target size + if K.size(1) != target_size: + K_resized = torch.zeros(1, target_size, target_size, dtype=K.dtype, + device=K.device) + K_resized[:, :K.size(1), :K.size(2)] = K + K = K_resized + K = K.expand(batch_size, target_size, target_size) + return K.to(dtype=x1.dtype) + + # Similar logic for cross-kernel case + test_dataset = GraphDataset.from_networkx(x2) + K = self._wl_kernel(self._train_graph_dataset, test_dataset) + if x1.dim() > 2: + batch_size = x1.size(0) + target_size = x1.size(1) + if K.size(0) != target_size: + K_resized = torch.zeros(target_size, target_size, dtype=K.dtype, + device=K.device) + K_resized[:K.size(0), :K.size(1)] = K + K = K_resized + K = K.unsqueeze(0).expand(batch_size, target_size, target_size) + return K.to(dtype=x1.dtype) class MixedSingleTaskGP(SingleTaskGP): @@ -161,16 +185,41 @@ def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: if not all(isinstance(g, nx.Graph) for g in graphs): raise TypeError("Expected input type is a list of NetworkX graphs.") - # Process the new graph inputs into a compatible dataset + # Process the new graph inputs into a compatible dataset proc_graphs = GraphDataset.from_networkx(graphs) # Compute the kernel matrix for the new graphs - K_new = self._wl_kernel(proc_graphs) # Shape: (n_samples, n_samples) + K_new = self._wl_kernel(proc_graphs) + K_new = K_new.to(dtype=X.dtype) # Combine the graph kernel with the numerical/categorical kernel (if present) if self.num_cat_kernel is not None: - K_num_cat = self.num_cat_kernel(X) # Compute kernel for num/cat features - K_combined = K_num_cat + K_new # Combine the two kernels + K_num_cat = self.num_cat_kernel(X) + + # Ensure K_new matches K_num_cat dimensions + if K_num_cat.dim() > 2: + batch_size = K_num_cat.size(0) + target_size = K_num_cat.size(1) + + # Resize K_new if needed + if K_new.size(-1) != target_size: + K_new_resized = torch.zeros( + *K_new.shape[:-2], target_size, target_size, + dtype=K_new.dtype, + device=K_new.device + ) + K_new_resized[..., :K_new.size(-2), :K_new.size(-1)] = K_new + K_new = K_new_resized + + if K_new.dim() < K_num_cat.dim(): + K_new = K_new.unsqueeze(0).expand(batch_size, target_size, + target_size) + + # Convert to dense tensor if needed + if hasattr(K_num_cat, "to_dense"): + K_num_cat = K_num_cat.to_dense() + + K_combined = K_num_cat + K_new else: K_combined = K_new From f69ddbe27bb28d87384885c00ac76be3a2879274 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 16:54:17 +0100 Subject: [PATCH 21/32] Set num_restarts=10 --- grakel_replace/mixed_single_task_gp_usage_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py index 19208d99..19d1662b 100644 --- a/grakel_replace/mixed_single_task_gp_usage_example.py +++ b/grakel_replace/mixed_single_task_gp_usage_example.py 
@@ -153,7 +153,7 @@ fixed_features_list=fixed_cats, train_graphs=train_graphs, num_graph_samples=6, - num_restarts=3, + num_restarts=10, raw_samples=10, q=1, ) From 1c4cc833488e917d394cce465ddceefd0398cbc6 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 22:45:58 +0100 Subject: [PATCH 22/32] Add acquisition function --- .../single_task_gp_usage_example.py | 72 ++++++++++++++++--- 1 file changed, 64 insertions(+), 8 deletions(-) diff --git a/grakel_replace/single_task_gp_usage_example.py b/grakel_replace/single_task_gp_usage_example.py index f6b24e12..9e295852 100644 --- a/grakel_replace/single_task_gp_usage_example.py +++ b/grakel_replace/single_task_gp_usage_example.py @@ -1,9 +1,20 @@ +from __future__ import annotations + +from itertools import product +from typing import TYPE_CHECKING + import torch +from botorch import fit_gpytorch_mll +from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement from botorch.models import SingleTaskGP from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel -from gpytorch.distributions.multivariate_normal import MultivariateNormal +from botorch.optim import optimize_acqf_mixed +from gpytorch import ExactMarginalLogLikelihood from gpytorch.kernels import AdditiveKernel, MaternKernel +if TYPE_CHECKING: + from gpytorch.distributions.multivariate_normal import MultivariateNormal + TRAIN_CONFIGS = 10 TEST_CONFIGS = 10 TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS @@ -51,7 +62,6 @@ ) kernels.append(hamming) - combined_num_cat_kernel = AdditiveKernel(*kernels) train_x = X[:TRAIN_CONFIGS] @@ -61,9 +71,6 @@ test_y = y[TRAIN_CONFIGS:] K_matrix = combined_num_cat_kernel.forward(train_x, train_x) -print( - "K_matrix: ", K_matrix.to_dense() -) train_y = train_y.unsqueeze(-1) test_y = test_y.unsqueeze(-1) @@ -75,6 +82,55 @@ ) multivariate_normal: MultivariateNormal = gp.forward(train_x) -print("Mean:", multivariate_normal.mean) -print("Variance:", multivariate_normal.variance) -print("Covariance matrix:", multivariate_normal.covariance_matrix) + +# =============== Fitting the GP using botorch =============== + +print("\nFitting the GP model using botorch...") + +mll = ExactMarginalLogLikelihood(gp.likelihood, gp) +fit_gpytorch_mll(mll) + +acq_function = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=train_x, + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + prune_baseline=True, +) + +# Define bounds +bounds = torch.tensor( + [ + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL, + [1.0] * N_NUMERICAL + [ + float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + ] +) + +# Setup categorical feature optimization +cats_per_column: dict[int, list[float]] = { + column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)] + for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL) +} + +# Generate fixed categorical features +fixed_cats: list[dict[int, float]] +if len(cats_per_column) == 1: + col, choice_indices = next(iter(cats_per_column.items())) + fixed_cats = [{col: i} for i in choice_indices] +else: + fixed_cats = [ + dict(zip(cats_per_column.keys(), combo)) + for combo in product(*cats_per_column.values()) + ] + +best_candidate, best_score = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=fixed_cats, + num_restarts=10, + raw_samples=10, + q=1, +) + +print("Best candidate:", best_candidate) +print("Acquisition score:", best_score) From dab9a8c81c2d54a3fd07380c5852f561f4182753 Mon Sep 17 00:00:00 2001 From: 
vladislavalerievich Date: Sat, 7 Dec 2024 22:46:45 +0100 Subject: [PATCH 23/32] Update WLKernel --- grakel_replace/mixed_single_task_gp.py | 37 ++++--------------- .../mixed_single_task_gp_usage_example.py | 4 +- 2 files changed, 9 insertions(+), 32 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index de04a14a..9f322d05 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -55,35 +55,12 @@ def forward( Tensor: The computed kernel matrix. """ if x2 is None: - K = self._K_train - # Handle batch dimension if present in x1 - if x1.dim() > 2: # We have a batch dimension - batch_size = x1.size(0) - target_size = x1.size(1) # This should be 11 in our case - # Resize K to match the expected dimensions - K = K.unsqueeze(0) # Add batch dimension - # Pad or interpolate K to match target size - if K.size(1) != target_size: - K_resized = torch.zeros(1, target_size, target_size, dtype=K.dtype, - device=K.device) - K_resized[:, :K.size(1), :K.size(2)] = K - K = K_resized - K = K.expand(batch_size, target_size, target_size) - return K.to(dtype=x1.dtype) - - # Similar logic for cross-kernel case - test_dataset = GraphDataset.from_networkx(x2) - K = self._wl_kernel(self._train_graph_dataset, test_dataset) - if x1.dim() > 2: - batch_size = x1.size(0) - target_size = x1.size(1) - if K.size(0) != target_size: - K_resized = torch.zeros(target_size, target_size, dtype=K.dtype, - device=K.device) - K_resized[:K.size(0), :K.size(1)] = K - K = K_resized - K = K.unsqueeze(0).expand(batch_size, target_size, target_size) - return K.to(dtype=x1.dtype) + # Return the precomputed training kernel matrix + return self._K_train + + # Compute cross-kernel between training graphs and new test graphs + test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs + return self._wl_kernel(self._train_graph_dataset, test_dataset) class MixedSingleTaskGP(SingleTaskGP): @@ -185,7 +162,7 @@ def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: if not all(isinstance(g, nx.Graph) for g in graphs): raise TypeError("Expected input type is a list of NetworkX graphs.") - # Process the new graph inputs into a compatible dataset + # Process the new graph inputs into a compatible dataset proc_graphs = GraphDataset.from_networkx(graphs) # Compute the kernel matrix for the new graphs diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py index 19d1662b..cd8528e4 100644 --- a/grakel_replace/mixed_single_task_gp_usage_example.py +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -105,9 +105,9 @@ print("\nMean:", predictions) print("Variance:", uncertainties) -print("Covariance matrix:", covar) # =============== Fitting the GP using botorch =============== + print("\nFitting the GP model using botorch...") mll = ExactMarginalLogLikelihood(gp.likelihood, gp) @@ -152,7 +152,7 @@ bounds=bounds, fixed_features_list=fixed_cats, train_graphs=train_graphs, - num_graph_samples=6, + num_graph_samples=20, num_restarts=10, raw_samples=10, q=1, From 2999582cad468b4af90c2341b6e4f8d0a833086b Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 23:06:56 +0100 Subject: [PATCH 24/32] Make train_inputs private --- grakel_replace/mixed_single_task_gp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index 9f322d05..5b502b19 
100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -76,7 +76,7 @@ class MixedSingleTaskGP(SingleTaskGP): Attributes: _wl_kernel (TorchWLKernel): Instance of the Weisfeiler-Lehman kernel. _K_train (Tensor): Precomputed graph kernel matrix for training graphs. - train_inputs (tuple[Tensor, list[nx.Graph]]): Tuple of training inputs. + _train_inputs (tuple[Tensor, list[nx.Graph]]): Tuple of training inputs. num_cat_kernel (Module | None): Kernel for numerical/categorical features. """ @@ -114,9 +114,6 @@ def __init__( covar_module=num_cat_kernel, **kwargs, ) - # Store graphs as part of train_inputs for using them in the __call__ method - self.train_inputs = (train_X, train_graphs) - # Initialize the Weisfeiler-Lehman kernel or use a default one self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True) @@ -125,6 +122,9 @@ def __init__( self._train_graph_dataset = GraphDataset.from_networkx(train_graphs) self._K_train = self._wl_kernel(self._train_graph_dataset) + # Store graphs as part of train_inputs for using them in the __call__ method + self._train_inputs = (train_X, train_graphs) + # If a kernel for numerical/categorical features is provided, combine it with # the WL kernel if num_cat_kernel is not None: @@ -138,7 +138,7 @@ def __init__( def __call__(self, X: Tensor, graphs: list[nx.Graph] | None = None, **kwargs): """Custom __call__ method that retrieves train graphs if not explicitly passed.""" if graphs is None: # Use stored graphs from train_inputs if not provided - graphs = self.train_inputs[1] + graphs = self._train_inputs[1] return self.forward(X, graphs) def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: From ad5503003b329fa70b9c8ff3e2e536caee83a1bc Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Sat, 7 Dec 2024 23:10:21 +0100 Subject: [PATCH 25/32] Update tests --- tests/test_mixed_single_task_gp.py | 49 ++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/tests/test_mixed_single_task_gp.py b/tests/test_mixed_single_task_gp.py index 3119d66d..d9c9fade 100644 --- a/tests/test_mixed_single_task_gp.py +++ b/tests/test_mixed_single_task_gp.py @@ -72,7 +72,7 @@ def test_model_initialization_and_validation(sample_data, sample_kernels): wl_kernel=wl_kernel, ) - assert gp._train_graphs == graphs + assert gp._train_inputs[1] == graphs assert isinstance(gp._wl_kernel, TorchWLKernel) assert gp.num_cat_kernel == combined_kernel @@ -133,6 +133,10 @@ def test_kernel_combinations_and_properties(sample_data): """Test kernel combination, invariance, and consistency properties.""" X, graphs, y = sample_data + # Ensure inputs are properly formatted + X = X.float() + y = y.float() + # Test kernel combination and variance changes # Create GP with only graph kernel gp_graph_only = MixedSingleTaskGP( @@ -142,16 +146,16 @@ def test_kernel_combinations_and_properties(sample_data): ) output_graph = gp_graph_only.forward(X, graphs) - graph_var = output_graph.variance + graph_var = output_graph.variance.detach() # Detach to avoid gradient computation # Create GP with combined kernels - n_numerical = 2 + n_numerical = X.shape[1] # Use actual number of features from X matern = ScaleKernel( MaternKernel( nu=2.5, ard_num_dims=n_numerical, active_dims=tuple(range(n_numerical)), - ), + ) ) gp_combined = MixedSingleTaskGP( @@ -162,24 +166,35 @@ def test_kernel_combinations_and_properties(sample_data): ) output_combined = gp_combined.forward(X, graphs) - combined_var = 
output_combined.variance + combined_var = output_combined.variance.detach() # Combined kernel should have larger variance due to addition - assert torch.all(combined_var > graph_var) + assert torch.all(combined_var >= graph_var - 1e-6) # Allow for numerical precision - # Use graphs with slight variations to avoid singular matrix - similar_graphs = [ - nx.complete_graph(5) for _ in range(len(graphs)) - ] + # Create similar but slightly different graphs + similar_graphs = [] + for _ in range(len(graphs)): + G = nx.Graph() + G.add_nodes_from(range(5)) + G.add_edges_from([(i, j) for i in range(5) for j in range(i + 1, 5)]) + similar_graphs.append(G) # Add small random perturbations to make graphs slightly different for i in range(1, len(similar_graphs)): G = similar_graphs[i] # Add or remove edges with a small probability - edges_to_add = [(u, v) for u in range(5) for v in range(u + 1, 5) - if not G.has_edge(u, v) and torch.rand(1) < 0.1] - edges_to_remove = [(u, v) for (u, v) in G.edges() - if torch.rand(1) < 0.1] + edges_to_add = [] + edges_to_remove = [] + + # Use fixed random seed for reproducibility + torch.manual_seed(i) + + for u in range(5): + for v in range(u + 1, 5): + if not G.has_edge(u, v) and torch.rand(1) < 0.1: + edges_to_add.append((u, v)) + elif G.has_edge(u, v) and torch.rand(1) < 0.1: + edges_to_remove.append((u, v)) G.add_edges_from(edges_to_add) G.remove_edges_from(edges_to_remove) @@ -195,10 +210,12 @@ def test_kernel_combinations_and_properties(sample_data): diag = kernel_matrix.diag() # Allow for slight variations due to graph perturbations - assert torch.allclose(diag, diag[0], atol=1e-1) + assert torch.allclose(diag, torch.ones_like(diag), atol=1e-6) # Check that the matrix is not completely uniform - assert not torch.allclose(kernel_matrix, torch.ones_like(kernel_matrix), rtol=1e-5) + off_diag = kernel_matrix - torch.eye(kernel_matrix.size(0), + device=kernel_matrix.device) + assert not torch.allclose(off_diag, torch.zeros_like(off_diag), atol=1e-3) def test_model_prediction_consistency(sample_data, sample_kernels): From 8093d3118157ff65ea0128507fddba6c6f2ace51 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 16 Dec 2024 15:13:38 +0100 Subject: [PATCH 26/32] fix: Implement graph acquisition --- grakel_replace/mixed_single_task_gp.py | 126 +-------------- .../mixed_single_task_gp_usage_example.py | 125 +++++++++------ grakel_replace/optimize.py | 72 +++++++-- grakel_replace/torch_wl_kernel.py | 149 +++++++++++++++--- grakel_replace/torch_wl_usage_example.py | 20 ++- 5 files changed, 286 insertions(+), 206 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index 5b502b19..9a381c7f 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -2,67 +2,15 @@ from typing import TYPE_CHECKING -import networkx as nx -import torch from botorch.models import SingleTaskGP -from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import AdditiveKernel, Kernel from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel if TYPE_CHECKING: - from gpytorch.module import Module + import networkx as nx from torch import Tensor -class WLKernel(Kernel): - """Weisfeiler-Lehman Kernel for graph similarity - integrated into the GPyTorch framework. - - This kernel encapsulates the precomputed Weisfeiler-Lehman graph kernel matrix - and provides it in a GPyTorch-compatible format. 
- It computes either the training kernel - or the cross-kernel between training and test graphs as needed. - """ - - def __init__( - self, - K_train: Tensor, - wl_kernel: TorchWLKernel, - train_graph_dataset: GraphDataset - ) -> None: - super().__init__() - self._K_train = K_train - self._wl_kernel = wl_kernel - self._train_graph_dataset = train_graph_dataset - - def forward( - self, x1: Tensor, - x2: Tensor | None = None, - diag: bool = False, - last_dim_is_batch: bool = False - ) -> Tensor: - """Forward method to compute the kernel matrix for the graph inputs. - - Args: - x1 (Tensor): First input tensor - (unused, required for interface compatibility). - x2 (Tensor | None): Second input tensor. - If None, computes the training kernel matrix. - diag (bool): Whether to return only the diagonal of the kernel matrix. - last_dim_is_batch (bool): Whether the last dimension is a batch dimension. - - Returns: - Tensor: The computed kernel matrix. - """ - if x2 is None: - # Return the precomputed training kernel matrix - return self._K_train - - # Compute cross-kernel between training graphs and new test graphs - test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs - return self._wl_kernel(self._train_graph_dataset, test_dataset) - - class MixedSingleTaskGP(SingleTaskGP): """A Gaussian Process model for mixed input spaces containing numerical, categorical, and graph features. @@ -85,9 +33,9 @@ def __init__( train_X: Tensor, train_graphs: list[nx.Graph], train_Y: Tensor, + num_cat_kernel: Kernel, + wl_kernel: TorchWLKernel, train_Yvar: Tensor | None = None, - num_cat_kernel: Module | None = None, - wl_kernel: TorchWLKernel | None = None, **kwargs, ) -> None: """Initialize the mixed-input Gaussian Process model. @@ -115,7 +63,7 @@ def __init__( **kwargs, ) # Initialize the Weisfeiler-Lehman kernel or use a default one - self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True) + self._wl_kernel = wl_kernel # Preprocess the training graphs into a compatible format and compute the graph # kernel matrix @@ -137,69 +85,7 @@ def __init__( def __call__(self, X: Tensor, graphs: list[nx.Graph] | None = None, **kwargs): """Custom __call__ method that retrieves train graphs if not explicitly passed.""" + print("__call__", X.shape, len(graphs) if graphs is not None else None) # noqa: T201 if graphs is None: # Use stored graphs from train_inputs if not provided graphs = self._train_inputs[1] - return self.forward(X, graphs) - - def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: - """Forward pass to compute the Gaussian Process distribution for given inputs. - - This combines the numerical/categorical kernel with the graph kernel - to compute the joint covariance matrix. - - Args: - X (Tensor): Input tensor for numerical and categorical features. - graphs (list[nx.Graph]): List of input graphs. - - Returns: - MultivariateNormal: The Gaussian Process distribution for the inputs. 
- """ - if len(X) != len(graphs): - raise ValueError( - f"Number of feature vectors ({len(X)}) must match " - f"number of graphs ({len(graphs)})" - ) - if not all(isinstance(g, nx.Graph) for g in graphs): - raise TypeError("Expected input type is a list of NetworkX graphs.") - - # Process the new graph inputs into a compatible dataset - proc_graphs = GraphDataset.from_networkx(graphs) - - # Compute the kernel matrix for the new graphs - K_new = self._wl_kernel(proc_graphs) - K_new = K_new.to(dtype=X.dtype) - - # Combine the graph kernel with the numerical/categorical kernel (if present) - if self.num_cat_kernel is not None: - K_num_cat = self.num_cat_kernel(X) - - # Ensure K_new matches K_num_cat dimensions - if K_num_cat.dim() > 2: - batch_size = K_num_cat.size(0) - target_size = K_num_cat.size(1) - - # Resize K_new if needed - if K_new.size(-1) != target_size: - K_new_resized = torch.zeros( - *K_new.shape[:-2], target_size, target_size, - dtype=K_new.dtype, - device=K_new.device - ) - K_new_resized[..., :K_new.size(-2), :K_new.size(-1)] = K_new - K_new = K_new_resized - - if K_new.dim() < K_num_cat.dim(): - K_new = K_new.unsqueeze(0).expand(batch_size, target_size, - target_size) - - # Convert to dense tensor if needed - if hasattr(K_num_cat, "to_dense"): - K_num_cat = K_num_cat.to_dense() - - K_combined = K_num_cat + K_new - else: - K_combined = K_new - - # Compute the mean using the mean module and construct the GP distribution - mean_x = self.mean_module(X) - return MultivariateNormal(mean_x, K_combined) + return self.forward(X) diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py index cd8528e4..67306456 100644 --- a/grakel_replace/mixed_single_task_gp_usage_example.py +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -1,30 +1,40 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager from itertools import product +from typing import TYPE_CHECKING import networkx as nx import torch +from botorch import fit_gpytorch_mll from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement -from botorch.fit import fit_gpytorch_mll -from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from botorch.models.gp_regression_mixed import CategoricalKernel, Kernel, ScaleKernel from gpytorch import ExactMarginalLogLikelihood -from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import AdditiveKernel, MaternKernel -from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP from grakel_replace.optimize import optimize_acqf_graph from grakel_replace.torch_wl_kernel import TorchWLKernel -TRAIN_CONFIGS = 10 +if TYPE_CHECKING: + from gpytorch.distributions.multivariate_normal import MultivariateNormal + +TRAIN_CONFIGS = 50 TEST_CONFIGS = 10 TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS N_NUMERICAL = 2 -N_CATEGORICAL = 2 -N_CATEGORICAL_VALUES_PER_CATEGORY = 3 -N_GRAPH = 2 +N_CATEGORICAL = 1 +N_CATEGORICAL_VALUES_PER_CATEGORY = 2 +N_GRAPH = 1 +assert N_GRAPH == 1, "This example only supports a single graph feature" kernels = [] # Create numerical and categorical features -X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64) +X = torch.empty( + size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH), + dtype=torch.float64, +) if N_NUMERICAL > 0: X[:, :N_NUMERICAL] = torch.rand( size=(TOTAL_CONFIGS, N_NUMERICAL), @@ -32,7 +42,7 @@ ) if 
N_CATEGORICAL > 0: - X[:, N_NUMERICAL:] = torch.randint( + X[:, N_NUMERICAL : N_NUMERICAL + N_CATEGORICAL] = torch.randint( 0, N_CATEGORICAL_VALUES_PER_CATEGORY, size=(TOTAL_CONFIGS, N_CATEGORICAL), @@ -45,8 +55,21 @@ G = nx.erdos_renyi_graph(n=5, p=0.5) # Random graph with 5 nodes graphs.append(G) +# Assign a new index column to the graphs +X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64) + # Create random target values -y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) +y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + 0.5 + +# Split into train and test sets +train_x = X[:TRAIN_CONFIGS] +train_graphs = graphs[:TRAIN_CONFIGS] +train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch + +test_x = X[TRAIN_CONFIGS:] +test_graphs = graphs[TRAIN_CONFIGS:] +test_y = y[TRAIN_CONFIGS:].unsqueeze(-1) + # Setup kernels for numerical and categorical features if N_NUMERICAL > 0: @@ -68,47 +91,56 @@ ) kernels.append(hamming) -# Combine numerical and categorical kernels -combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None +if N_GRAPH > 0: + wl_kernel = ScaleKernel( + TorchWLKernel( + graph_lookup=train_graphs, + n_iter=5, + normalize=True, + active_dims=(X.shape[1] - 1,), # Last column + ) + ) + kernels.append(wl_kernel) -# Create WL kernel for graphs -wl_kernel = TorchWLKernel(n_iter=5, normalize=True) -# Split into train and test sets -train_x = X[:TRAIN_CONFIGS] -train_graphs = graphs[:TRAIN_CONFIGS] -train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch +# Combine numerical and categorical kernels +kernel = AdditiveKernel(*kernels) -test_x = X[TRAIN_CONFIGS:] -test_graphs = graphs[TRAIN_CONFIGS:] -test_y = y[TRAIN_CONFIGS:].unsqueeze(-1) +from botorch.models import SingleTaskGP # Initialize the mixed GP -gp = MixedSingleTaskGP( - train_X=train_x, - train_graphs=train_graphs, - train_Y=train_y, - num_cat_kernel=combined_num_cat_kernel, - wl_kernel=wl_kernel, -) +gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=kernel) # Compute the posterior distribution -multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs) -print("Posterior distribution:", multivariate_normal) +# The wl_kernel will use the indices to index into the training graphs it is holding +# on to... 
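As an aside, the index convention described in the comment above can be exercised on its own: the last column of X holds plain row indices, and any kernel whose active_dims selects that column can map entries back to concrete graphs. A minimal sketch with toy sizes (all names here are illustrative, not part of the patch):

import networkx as nx
import torch

lookup = [nx.erdos_renyi_graph(5, 0.5) for _ in range(3)]
X_toy = torch.rand(3, 3, dtype=torch.float64)
X_toy[:, -1] = torch.arange(3, dtype=torch.float64)  # index column, as above
rows = X_toy[:, -1].to(torch.int64).tolist()
assert [lookup[i] for i in rows][2] is lookup[2]     # indices recover the graphs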
+multivariate_normal: MultivariateNormal = gp.forward(train_x) + # Making predictions on test data -with torch.no_grad(): - posterior = gp.forward(test_x, test_graphs) +# No the wl_kernel needs to be aware of the test graphs +@contextmanager +def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]: + kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + for kern in _gp.covar_module.sub_kernels(): + if isinstance(kern, TorchWLKernel): + kernel_prev_graphs.append((kern, kern.graph_lookup)) + kern.set_graph_lookup(new_graphs) + + yield + + for _kern, _prev_graphs in kernel_prev_graphs: + _kern.set_graph_lookup(_prev_graphs) + + +with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs): + posterior = gp.forward(test_x) predictions = posterior.mean uncertainties = posterior.variance.sqrt() covar = posterior.covariance_matrix -print("\nMean:", predictions) -print("Variance:", uncertainties) - # =============== Fitting the GP using botorch =============== -print("\nFitting the GP model using botorch...") mll = ExactMarginalLogLikelihood(gp.likelihood, gp) fit_gpytorch_mll(mll) @@ -124,8 +156,10 @@ # Define bounds bounds = torch.tensor( [ - [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL, - [1.0] * N_NUMERICAL + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, + [1.0] * N_NUMERICAL + + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + + [len(X) - 1] * N_GRAPH, ] ) @@ -142,21 +176,20 @@ fixed_cats = [{col: i} for i in choice_indices] else: fixed_cats = [ - dict(zip(cats_per_column.keys(), combo)) + dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in product(*cats_per_column.values()) ] + +print("------------------") # noqa: T201 # Use the graph-optimized acquisition function best_candidate, best_score = optimize_acqf_graph( acq_function=acq_function, bounds=bounds, fixed_features_list=fixed_cats, train_graphs=train_graphs, - num_graph_samples=20, - num_restarts=10, - raw_samples=10, + num_graph_samples=2, + num_restarts=2, + raw_samples=16, q=1, ) - -print("Best candidate:", best_candidate) -print("Acquisition score:", best_score) diff --git a/grakel_replace/optimize.py b/grakel_replace/optimize.py index eff7adac..4f21886e 100644 --- a/grakel_replace/optimize.py +++ b/grakel_replace/optimize.py @@ -1,14 +1,49 @@ from __future__ import annotations import random +from collections.abc import Iterator +from contextlib import contextmanager from typing import TYPE_CHECKING import networkx as nx import torch from botorch.optim import optimize_acqf_mixed +from grakel_replace.torch_wl_kernel import TorchWLKernel if TYPE_CHECKING: from botorch.acquisition import AcquisitionFunction + from botorch.models.gp_regression_mixed import Kernel + + +# Making predictions on test data +# No the wl_kernel needs to be aware of the test graphs +@contextmanager +def set_graph_lookup( + kernel: Kernel, + new_graphs: list[nx.Graph], + *, + append: bool = True, +) -> Iterator[None]: + kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + if isinstance(kernel, TorchWLKernel): + modules = [kernel] + else: + assert hasattr( + kernel, "sub_kernels" + ), "Kernel module must have sub_kernels method." 
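The set_graph_lookup helpers above and below both restore the previous lookup after yield, but neither wraps the body in try/finally, so an exception raised inside the with block would leave the kernel pointing at the temporary graphs. A hardened sketch of the same save/patch/restore pattern (a sketch only; the helper name is illustrative):

from collections.abc import Iterator
from contextlib import contextmanager

@contextmanager
def swapped_graph_lookup(kernel, new_graphs: list) -> Iterator[None]:
    # Save, patch, and always restore, even if the body raises.
    prev = kernel.graph_lookup
    kernel.set_graph_lookup(new_graphs)
    try:
        yield
    finally:
        kernel.set_graph_lookup(prev)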
+ modules = [k for k in kernel.sub_kernels() if isinstance(k, TorchWLKernel)] + + for kern in modules: + kernel_prev_graphs.append((kern, kern.graph_lookup)) + if append: + kern.set_graph_lookup([*kern.graph_lookup, *new_graphs]) + else: + kern.set_graph_lookup(new_graphs) + + yield + + for _kern, _prev_graphs in kernel_prev_graphs: + _kern.set_graph_lookup(_prev_graphs) def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: @@ -35,7 +70,7 @@ def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: u, v = random.sample(nodes, 2) if not sampled_graph.has_edge(u, v): sampled_graph.add_edge(u, v) - elif sampled_graph.edges: # 30% chance to remove edge + elif sampled_graph.edges: # 30% chance to remove edge u, v = random.choice(list(sampled_graph.edges)) sampled_graph.remove_edge(u, v) @@ -81,21 +116,34 @@ def optimize_acqf_graph( raise ValueError("train_graphs cannot be None.") sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples) + gp = acq_function.model + covar_module = gp.covar_module best_candidates, best_scores = [], [] + TODO_GRAPH_COLUMN_INDEX = bounds.shape[1] - 1 + for _graph in sampled_graphs: - for fixed_features in fixed_features_list or [{}]: - candidates, scores = optimize_acqf_mixed( - acq_function=acq_function, - bounds=bounds, - fixed_features_list=[fixed_features], - num_restarts=num_restarts, - raw_samples=raw_samples, - q=q, - ) - best_candidates.append(candidates) - best_scores.append(scores) + # This is new, we essentially iterate through all the kernels and + # include the sampled graph. + with set_graph_lookup(covar_module, [_graph], append=True): + for fixed_features in fixed_features_list or [{}]: + # We then consider this graph as a fixed feature, i.e. in the X's + # generated during acquisition, the graph column will just be full + # of `-1` indicating to select the very last graph in the lookup + # they used. 
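Concretely, the pinning described in the comment above is just a dictionary merge. With the example's two numerical columns and one categorical column, the graph column is index 3 (toy numbers, illustrative only):

fixed_features = {2: 1.0}                      # categorical column pinned by the caller
graph_column = 3                               # == bounds.shape[1] - 1
pinned = {**fixed_features, graph_column: -1.0}
assert pinned == {2: 1.0, 3: -1.0}
# optimize_acqf_mixed holds X[..., 3] at -1.0, and plain list indexing makes
# graph_lookup[-1] the graph that set_graph_lookup just appended.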
+ _fixed_features = {**fixed_features, TODO_GRAPH_COLUMN_INDEX: -1.0} + + candidates, scores = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=[_fixed_features], + num_restarts=num_restarts, + raw_samples=raw_samples, + q=q, + ) + best_candidates.append(candidates) + best_scores.append(scores) best_scores_tensor = torch.tensor(best_scores) best_idx = torch.argmax(best_scores_tensor) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index ba2ae071..5c946555 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -1,13 +1,122 @@ from __future__ import annotations from collections import Counter +from typing import Any import networkx as nx import torch +from botorch.models.gp_regression_mixed import Kernel from torch import nn -class TorchWLKernel(nn.Module): +class TorchWLKernel(Kernel): + has_lengthscale = False + + def __init__( + self, + graph_lookup: list[nx.Graph], + n_iter: int = 5, + *, + normalize: bool = True, + active_dims: tuple[int, ...], + **kwargs: Any, + ) -> None: + super().__init__(active_dims=active_dims, **kwargs) + self.graph_lookup = graph_lookup + self.n_iter = n_iter + self.normalize = normalize + + # NOTE: set in the `super().__init__()` + self.active_dims: torch.Tensor + + def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None: + self.graph_lookup = graph_lookup + + def forward( + self, + x1: torch.Tensor, + x2: torch.Tensor, + *, + diag: bool = False, + last_dim_is_batch: bool = False, + **params: Any, + ): + if last_dim_is_batch: + raise NotImplementedError("TODO: Figure this out") + + assert x1.shape[-1] == 1, "Last dimension must be the graph index" + assert x2.shape[-1] == 1, "Last dimension must be the graph index" + + # TODO: Optimizations + # + # 1. We're computing the whole K Matrix, but we only need the K_x1_x2 + # + # K + # -------------------- + # | K_x1_x1 K_x1_x2 | + # | K_x2_x1 K_x2_x2 | + # -------------------- + # + # However in the case where x1 == x2, we can shortcut this slightly as in + # the above, K_x1_x2 == K_x2_x1 == K_x1_x1 == K_x2_x2 + # This shortcut is implemented below based on this flag. + # + # 2. The _TorchWLKernel used below has the following properties, which + # get set on forward. In the case where x1.ndim == 3 then the first dim is + # the `q` dim. Doesn't matter what it is other than we end up repeating the + # processing the graphs `q` times. Given that it's likely that the indices + # in last dimension are likely to be constant (i.e. all `4`, indicating the + # `4th` graph, we are effectively doing a lot of extra calculation. We could + # shortcut this by pre-computing these for each index. Could be nice to somehow + # have the inned `_TorchWLKernel` be aware of this extra dimension but it's + # fine if not as long as we can reduce the extraneuous computations. We could + # change the interface of `_TorchWLKernel` to take in the raw processed + # tensors instead of `nx.Graph` objects, which we would instead preprocess here. + # If that's the case, we could move the `_TorchWLKernel` to essentially just + # be functions we call instead with the correct pre-processed data. + # + # .self.label_dict + # .self.label_counter + # + x1_is_x2 = torch.equal(x1, x2) + + # NOTE: The active dim is already selected out for us and is the last dimension + # (not including whatever happens when last_dim_is_batch) is True. 
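Before the batched case below, the block structure drawn in the TODO above can be checked with plain tensors. A toy sketch in which a random feature matrix stands in for the stacked WL feature vectors:

import torch

F = torch.randn(7, 4)          # features for n1 + n2 = 5 train + 2 test graphs
K = F @ F.T                    # the full joint kernel matrix
n1 = 5
K_x1_x2 = K[:n1, n1:]          # the only block needed in the cross-kernel case
assert K_x1_x2.shape == (5, 2)
assert torch.allclose(K_x1_x2, K[n1:, :n1].T)  # off-diagonal blocks are transposes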
+ if x1.ndim == 3: + # - x1: torch.Size([32, 5, 1]) + # - x2: torch.Size([32, 55, 1]) + # - output: torch.Size([32, 5, 55]) + q_dim_size = x1.shape[0] + assert x2.shape[0] == q_dim_size + + out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device) + for q in range(q_dim_size): + out[q] = self.forward(x1[q], x2[q], diag=diag) + return out + + if x1_is_x2: + _ixs = x1.flatten().to(torch.int64).tolist() + all_graphs = [self.graph_lookup[i] for i in _ixs] + + # No selection requires + select = None + else: + _ixs1 = x1.flatten().to(torch.int64).tolist() + _ixs2 = x2.flatten().to(torch.int64).tolist() + all_graphs = [self.graph_lookup[i] for i in _ixs1 + _ixs2] + + # Select out K_x1_x2 + select = lambda _K: _K[: len(_ixs1), len(_ixs1) :] + + _kernel = _TorchWLKernel(n_iter=self.n_iter, normalize=self.normalize) + K = _kernel(all_graphs) + K_selected = K if select is None else select(K) + if diag: + return torch.diag(K_selected) + return K_selected + + +class _TorchWLKernel(nn.Module): """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch. The WL Kernel is a graph kernel that measures similarity between graphs based on @@ -24,7 +133,7 @@ class TorchWLKernel(nn.Module): label_counter: Counter for generating new label indices """ - def __init__(self, n_iter: int = 5, normalize: bool = True) -> None: + def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None: super().__init__() self.n_iter = n_iter self.normalize = normalize @@ -49,7 +158,7 @@ def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor: indices=torch.empty((2, 0), dtype=torch.long), values=torch.empty(0), size=(num_nodes, num_nodes), - device=self.device + device=self.device, ) # Create bidirectional edge indices for undirected graph @@ -60,8 +169,7 @@ def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor: values = torch.ones(len(edge_indices), dtype=torch.float, device=self.device) return torch.sparse_coo_tensor( - indices, values, (num_nodes, num_nodes), - device=self.device + indices, values, (num_nodes, num_nodes), device=self.device ) def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: @@ -88,9 +196,7 @@ def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: return torch.tensor(labels, dtype=torch.long, device=self.device) def _wl_iteration( - self, - adj: torch.sparse.Tensor, - labels: torch.Tensor + self, adj: torch.sparse.Tensor, labels: torch.Tensor ) -> torch.Tensor: """Perform one WL iteration to update node labels. Concatenate own label with sorted neighbor labels. @@ -126,11 +232,7 @@ def _wl_iteration( return torch.tensor(new_labels, dtype=torch.long, device=self.device) - def _compute_feature_vector( - self, - labels: torch.Tensor, - size: int - ) -> torch.Tensor: + def _compute_feature_vector(self, labels: torch.Tensor, size: int) -> torch.Tensor: """Compute histogram feature vector from node labels. 
        Args:
@@ -172,8 +274,9 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor:
            TypeError: If input is not a list of NetworkX graphs
        """
        # Validate input
-        if (not isinstance(graphs, list) or
-                not all(isinstance(g, nx.Graph) for g in graphs)):
+        if not isinstance(graphs, list) or not all(
+            isinstance(g, nx.Graph) for g in graphs
+        ):
            raise TypeError("Expected input type is a list of NetworkX graphs.")

        # Setup computation
@@ -201,10 +304,12 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor:
        # Compute feature matrices using final label count
        feature_matrices = [
-            torch.stack([
-                self._compute_feature_vector(labels, self.label_counter)
-                for labels in iteration_labels
-            ])
+            torch.stack(
+                [
+                    self._compute_feature_vector(labels, self.label_counter)
+                    for labels in iteration_labels
+                ]
+            )
            for iteration_labels in all_label_tensors
        ]

@@ -226,9 +331,9 @@ class GraphDataset:
    """Utility class to convert NetworkX graphs for WL kernel."""

    @staticmethod
-    def from_networkx(graphs: list[nx.Graph], node_labels_tag: str = "label") -> list[
-        nx.Graph]:
-
+    def from_networkx(
+        graphs: list[nx.Graph], node_labels_tag: str = "label"
+    ) -> list[nx.Graph]:
        if not all(isinstance(g, nx.Graph) for g in graphs):
            raise TypeError("Expected input type is a list of NetworkX graphs.")

diff --git a/grakel_replace/torch_wl_usage_example.py b/grakel_replace/torch_wl_usage_example.py
index a24d2536..f9958045 100644
--- a/grakel_replace/torch_wl_usage_example.py
+++ b/grakel_replace/torch_wl_usage_example.py
@@ -1,4 +1,7 @@
+from __future__ import annotations
+
 import networkx as nx
+import torch
 from torch_wl_kernel import GraphDataset, TorchWLKernel

 # Create the same graphs as for the Grakel example
@@ -10,12 +13,17 @@
 G3.add_edges_from([(0, 1), (1, 3), (3, 2)])

 # Process graphs
-graphs = GraphDataset.from_networkx([G1, G2, G3])
+graphs: list[nx.Graph] = GraphDataset.from_networkx([G1, G2, G3])

 # Initialize and run WL kernel
-wl_kernel = TorchWLKernel(n_iter=2, normalize=False)
-
-K = wl_kernel(graphs)
+wl_kernel = TorchWLKernel(
+    graph_lookup=graphs,
+    n_iter=2,
+    normalize=True,
+    active_dims=(1,),
+)
+X1 = torch.tensor([[42.4, 43.4, 44.5], [0, 1, 2]]).T
+X2 = torch.tensor([[42.4, 43.4, 44.5], [0, 1, 2]]).T

-print("Kernel matrix (pairwise similarities):")
-print(K)
+K = wl_kernel(X1, X2)
+print(K.to_dense())  # noqa: T201
From a1a29a8aae9ed4c3b32cfa5066822c582f0ad469 Mon Sep 17 00:00:00 2001
From: vladislavalerievich
Date: Tue, 24 Dec 2024 09:20:17 +0100
Subject: [PATCH 27/32] Delete unused MixedSingleTaskGP

---
 grakel_replace/mixed_single_task_gp.py |  91 --------
 tests/test_mixed_single_task_gp.py     | 298 -------------------------
 2 files changed, 389 deletions(-)
 delete mode 100644 grakel_replace/mixed_single_task_gp.py
 delete mode 100644 tests/test_mixed_single_task_gp.py

diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py
deleted file mode 100644
index 9a381c7f..00000000
--- a/grakel_replace/mixed_single_task_gp.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-from botorch.models import SingleTaskGP
-from gpytorch.kernels import AdditiveKernel, Kernel
-from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel
-
-if TYPE_CHECKING:
-    import networkx as nx
-    from torch import Tensor
-
-
-class MixedSingleTaskGP(SingleTaskGP):
-    """A Gaussian Process model for mixed input spaces containing numerical, categorical,
-    and graph features.
- - This class extends BoTorch's SingleTaskGP to support hybrid inputs by combining: - - Standard kernels for numerical and categorical features. - - Weisfeiler-Lehman kernel for graph structures. - - The kernels are combined using an additive kernel structure. - - Attributes: - _wl_kernel (TorchWLKernel): Instance of the Weisfeiler-Lehman kernel. - _K_train (Tensor): Precomputed graph kernel matrix for training graphs. - _train_inputs (tuple[Tensor, list[nx.Graph]]): Tuple of training inputs. - num_cat_kernel (Module | None): Kernel for numerical/categorical features. - """ - - def __init__( - self, - train_X: Tensor, - train_graphs: list[nx.Graph], - train_Y: Tensor, - num_cat_kernel: Kernel, - wl_kernel: TorchWLKernel, - train_Yvar: Tensor | None = None, - **kwargs, - ) -> None: - """Initialize the mixed-input Gaussian Process model. - - Args: - train_X (Tensor): Training tensor for numerical and categorical features. - train_graphs (list[nx.Graph]): List of training graph instances. - train_Y (Tensor): Target values for training data. - train_Yvar (Tensor | None): Observation noise variance (optional). - num_cat_kernel (Module | None): Kernel for numerical/categorical features - (optional). - wl_kernel (TorchWLKernel | None): Weisfeiler-Lehman kernel instance - (optional). - **kwargs: Additional arguments for SingleTaskGP initialization. - """ - if train_X.size(0) == 0 or len(train_graphs) == 0: - raise ValueError("Training inputs (features and graphs) cannot be empty.") - - # Initialize the base SingleTaskGP with a num/cat kernel (if provided) - super().__init__( - train_X=train_X, - train_Y=train_Y, - train_Yvar=train_Yvar, - covar_module=num_cat_kernel, - **kwargs, - ) - # Initialize the Weisfeiler-Lehman kernel or use a default one - self._wl_kernel = wl_kernel - - # Preprocess the training graphs into a compatible format and compute the graph - # kernel matrix - self._train_graph_dataset = GraphDataset.from_networkx(train_graphs) - self._K_train = self._wl_kernel(self._train_graph_dataset) - - # Store graphs as part of train_inputs for using them in the __call__ method - self._train_inputs = (train_X, train_graphs) - - # If a kernel for numerical/categorical features is provided, combine it with - # the WL kernel - if num_cat_kernel is not None: - self.covar_module = AdditiveKernel( - num_cat_kernel, - WLKernel(self._K_train, self._wl_kernel, self._train_graph_dataset), - ) - - self.num_cat_kernel = num_cat_kernel - - def __call__(self, X: Tensor, graphs: list[nx.Graph] | None = None, **kwargs): - """Custom __call__ method that retrieves train graphs if not explicitly passed.""" - print("__call__", X.shape, len(graphs) if graphs is not None else None) # noqa: T201 - if graphs is None: # Use stored graphs from train_inputs if not provided - graphs = self._train_inputs[1] - return self.forward(X) diff --git a/tests/test_mixed_single_task_gp.py b/tests/test_mixed_single_task_gp.py deleted file mode 100644 index d9c9fade..00000000 --- a/tests/test_mixed_single_task_gp.py +++ /dev/null @@ -1,298 +0,0 @@ -import pytest -import torch -import networkx as nx -from gpytorch.kernels import MaternKernel, AdditiveKernel -from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel -from gpytorch.distributions.multivariate_normal import MultivariateNormal - -from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP -from grakel_replace.torch_wl_kernel import TorchWLKernel - - -@pytest.fixture -def sample_data(): - """Create sample data for testing.""" - n_samples = 5 - 
n_numerical = 2 - n_categorical = 2 - - # Create numerical and categorical features - X = torch.empty(size=(n_samples, n_numerical + n_categorical), dtype=torch.float64) - X[:, :n_numerical] = torch.rand(size=(n_samples, n_numerical), dtype=torch.float64) - X[:, n_numerical:] = torch.randint(0, 3, size=(n_samples, n_categorical), - dtype=torch.float64) - - # Create sample graphs - graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(n_samples)] - - # Create target values - y = torch.rand(size=(n_samples, 1), dtype=torch.float64) - - return X, graphs, y - - -@pytest.fixture -def sample_kernels(): - """Create sample kernels for testing.""" - n_numerical = 2 - n_categorical = 2 - - matern = ScaleKernel( - MaternKernel( - nu=2.5, - ard_num_dims=n_numerical, - active_dims=tuple(range(n_numerical)), - ), - ) - - hamming = ScaleKernel( - CategoricalKernel( - ard_num_dims=n_categorical, - active_dims=tuple(range(n_numerical, n_numerical + n_categorical)), - ), - ) - - combined_kernel = AdditiveKernel(matern, hamming) - wl_kernel = TorchWLKernel(n_iter=3, normalize=True) - - return combined_kernel, wl_kernel - - -def test_model_initialization_and_validation(sample_data, sample_kernels): - """Test GP initialization, inputs validation, and basic properties.""" - X, graphs, y = sample_data - combined_kernel, wl_kernel = sample_kernels - - # Test successful initialization with all parameters - gp = MixedSingleTaskGP( - train_X=X, - train_graphs=graphs, - train_Y=y, - num_cat_kernel=combined_kernel, - wl_kernel=wl_kernel, - ) - - assert gp._train_inputs[1] == graphs - assert isinstance(gp._wl_kernel, TorchWLKernel) - assert gp.num_cat_kernel == combined_kernel - - # Test empty input validation - with pytest.raises(ValueError, match="Training inputs.*cannot be empty"): - MixedSingleTaskGP( - train_X=torch.empty((0, 4), dtype=torch.float64), - train_graphs=[], - train_Y=torch.empty((0, 1), dtype=torch.float64), - ) - - -def test_forward_pass_and_predictions(sample_data, sample_kernels): - """Test forward pass, shape consistency, and prediction characteristics.""" - X, graphs, y = sample_data - combined_kernel, wl_kernel = sample_kernels - - gp = MixedSingleTaskGP( - train_X=X, - train_graphs=graphs, - train_Y=y, - num_cat_kernel=combined_kernel, - wl_kernel=wl_kernel, - ) - - # Test forward pass with training data - output = gp.forward(X, graphs) - assert isinstance(output, MultivariateNormal) - assert output.mean.shape == (len(X),) - assert output.covariance_matrix.shape == (len(X), len(X)) - - # Test forward pass with test data - n_test = 3 - test_X = torch.rand(size=(n_test, X.shape[1]), dtype=torch.float64) - test_graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(n_test)] - - output_test = gp.forward(test_X, test_graphs) - assert output_test.mean.shape == (n_test,) - assert output_test.covariance_matrix.shape == (n_test, n_test) - - # Test input validation - # Mismatched number of features and graphs - mismatched_test_X = torch.rand(size=(3, X.shape[1]), dtype=torch.float64) - mismatched_test_graphs = [nx.erdos_renyi_graph(n=4, p=0.5) for _ in range(4)] - - with pytest.raises(ValueError, - match="Number of feature vectors.*must match.*number of graphs"): - gp.forward(mismatched_test_X, mismatched_test_graphs) - - # Invalid graph input - invalid_graphs = ["not_a_graph", 123, None, False, True] - with pytest.raises(TypeError, - match="Expected input type is a list of NetworkX graphs."): - gp.forward(X, invalid_graphs) - - -def test_kernel_combinations_and_properties(sample_data): - """Test 
kernel combination, invariance, and consistency properties.""" - X, graphs, y = sample_data - - # Ensure inputs are properly formatted - X = X.float() - y = y.float() - - # Test kernel combination and variance changes - # Create GP with only graph kernel - gp_graph_only = MixedSingleTaskGP( - train_X=X, - train_graphs=graphs, - train_Y=y, - ) - - output_graph = gp_graph_only.forward(X, graphs) - graph_var = output_graph.variance.detach() # Detach to avoid gradient computation - - # Create GP with combined kernels - n_numerical = X.shape[1] # Use actual number of features from X - matern = ScaleKernel( - MaternKernel( - nu=2.5, - ard_num_dims=n_numerical, - active_dims=tuple(range(n_numerical)), - ) - ) - - gp_combined = MixedSingleTaskGP( - train_X=X, - train_graphs=graphs, - train_Y=y, - num_cat_kernel=matern, - ) - - output_combined = gp_combined.forward(X, graphs) - combined_var = output_combined.variance.detach() - - # Combined kernel should have larger variance due to addition - assert torch.all(combined_var >= graph_var - 1e-6) # Allow for numerical precision - - # Create similar but slightly different graphs - similar_graphs = [] - for _ in range(len(graphs)): - G = nx.Graph() - G.add_nodes_from(range(5)) - G.add_edges_from([(i, j) for i in range(5) for j in range(i + 1, 5)]) - similar_graphs.append(G) - - # Add small random perturbations to make graphs slightly different - for i in range(1, len(similar_graphs)): - G = similar_graphs[i] - # Add or remove edges with a small probability - edges_to_add = [] - edges_to_remove = [] - - # Use fixed random seed for reproducibility - torch.manual_seed(i) - - for u in range(5): - for v in range(u + 1, 5): - if not G.has_edge(u, v) and torch.rand(1) < 0.1: - edges_to_add.append((u, v)) - elif G.has_edge(u, v) and torch.rand(1) < 0.1: - edges_to_remove.append((u, v)) - - G.add_edges_from(edges_to_add) - G.remove_edges_from(edges_to_remove) - - gp_similar = MixedSingleTaskGP( - train_X=X, - train_graphs=similar_graphs, - train_Y=y, - ) - - # Compute kernel matrix and check diagonal consistency - kernel_matrix = gp_similar._K_train - diag = kernel_matrix.diag() - - # Allow for slight variations due to graph perturbations - assert torch.allclose(diag, torch.ones_like(diag), atol=1e-6) - - # Check that the matrix is not completely uniform - off_diag = kernel_matrix - torch.eye(kernel_matrix.size(0), - device=kernel_matrix.device) - assert not torch.allclose(off_diag, torch.zeros_like(off_diag), atol=1e-3) - - -def test_model_prediction_consistency(sample_data, sample_kernels): - """Test prediction consistency and mean prediction bounds.""" - X, graphs, y = sample_data - combined_kernel, wl_kernel = sample_kernels - - gp = MixedSingleTaskGP( - train_X=X, - train_graphs=graphs, - train_Y=y, - num_cat_kernel=combined_kernel, - wl_kernel=wl_kernel, - ) - - # Multiple forward passes should give consistent results - output1 = gp.forward(X, graphs) - output2 = gp.forward(X, graphs) - - assert torch.allclose(output1.mean, output2.mean) - assert torch.allclose(output1.variance, output2.variance) - - # Mean predictions should be within reasonable bounds - with torch.no_grad(): - output = gp.forward(X, graphs) - predictions = output.mean - uncertainties = output.variance.sqrt() - - assert torch.all(predictions >= y.min() - 2 * uncertainties) - assert torch.all(predictions <= y.max() + 2 * uncertainties) - - -def test_graph_kernel_caching(sample_data, sample_kernels): - """Test that graph kernel matrices are properly cached.""" - X, graphs, y = sample_data - 
combined_kernel, wl_kernel = sample_kernels - - gp = MixedSingleTaskGP( - train_X=X, - train_graphs=graphs, - train_Y=y, - num_cat_kernel=combined_kernel, - wl_kernel=wl_kernel, - ) - - # First forward pass - _ = gp.forward(X, graphs) - K_train_1 = gp._K_train.clone() - - # Second forward pass - _ = gp.forward(X, graphs) - K_train_2 = gp._K_train.clone() - - # Cached kernel matrices should be identical - assert torch.allclose(K_train_1, K_train_2) - - -def test_large_dataset_handling(): - """Test the model's behavior with large datasets.""" - n_samples = 100 - n_features = 4 - - # Create large numerical and categorical features - X = torch.rand(size=(n_samples, n_features), dtype=torch.float64) - - # Create a large set of random graphs - graphs = [nx.erdos_renyi_graph(n=10, p=0.2) for _ in range(n_samples)] - - # Create target values - y = torch.rand(size=(n_samples, 1), dtype=torch.float64) - - gp = MixedSingleTaskGP( - train_X=X, - train_graphs=graphs, - train_Y=y, - ) - - output = gp.forward(X, graphs) - assert output.mean.shape == (n_samples,) - assert output.covariance_matrix.shape == (n_samples, n_samples) From 046ad66dc136e294171b00e1cf9f5fb8d81f229d Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 24 Dec 2024 09:46:19 +0100 Subject: [PATCH 28/32] Add seed_all and min_max_scale --- .../mixed_single_task_gp_usage_example.py | 163 ++++++------------ grakel_replace/utils.py | 24 +++ 2 files changed, 76 insertions(+), 111 deletions(-) create mode 100644 grakel_replace/utils.py diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py index 67306456..335bc1b4 100644 --- a/grakel_replace/mixed_single_task_gp_usage_example.py +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from collections.abc import Iterator from contextlib import contextmanager from itertools import product @@ -7,17 +8,23 @@ import networkx as nx import torch -from botorch import fit_gpytorch_mll +from botorch import fit_gpytorch_mll, settings from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement +from botorch.models import SingleTaskGP from botorch.models.gp_regression_mixed import CategoricalKernel, Kernel, ScaleKernel from gpytorch import ExactMarginalLogLikelihood from gpytorch.kernels import AdditiveKernel, MaternKernel from grakel_replace.optimize import optimize_acqf_graph from grakel_replace.torch_wl_kernel import TorchWLKernel +from grakel_replace.utils import min_max_scale, seed_all if TYPE_CHECKING: from gpytorch.distributions.multivariate_normal import MultivariateNormal +start_time = time.time() +settings.debug._set_state(True) +seed_all() + TRAIN_CONFIGS = 50 TEST_CONFIGS = 10 TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS @@ -26,90 +33,42 @@ N_CATEGORICAL = 1 N_CATEGORICAL_VALUES_PER_CATEGORY = 2 N_GRAPH = 1 + assert N_GRAPH == 1, "This example only supports a single graph feature" -kernels = [] +# Generate random data +X = torch.cat([ + torch.rand((TOTAL_CONFIGS, N_NUMERICAL), dtype=torch.float64), + torch.randint(0, N_CATEGORICAL_VALUES_PER_CATEGORY, (TOTAL_CONFIGS, N_CATEGORICAL), + dtype=torch.float64), + torch.arange(TOTAL_CONFIGS, dtype=torch.float64).unsqueeze(1) +], dim=1) -# Create numerical and categorical features -X = torch.empty( - size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH), - dtype=torch.float64, -) -if N_NUMERICAL > 0: - X[:, :N_NUMERICAL] = torch.rand( - size=(TOTAL_CONFIGS, N_NUMERICAL), - 
dtype=torch.float64, - ) - -if N_CATEGORICAL > 0: - X[:, N_NUMERICAL : N_NUMERICAL + N_CATEGORICAL] = torch.randint( - 0, - N_CATEGORICAL_VALUES_PER_CATEGORY, - size=(TOTAL_CONFIGS, N_CATEGORICAL), - dtype=torch.float64, - ) - -# Create random graph architectures -graphs = [] -for _ in range(TOTAL_CONFIGS): - G = nx.erdos_renyi_graph(n=5, p=0.5) # Random graph with 5 nodes - graphs.append(G) - -# Assign a new index column to the graphs -X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64) - -# Create random target values -y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + 0.5 +# Generate random graphs +graphs = [nx.erdos_renyi_graph(5, 0.5) for _ in range(TOTAL_CONFIGS)] + +# Generate random target values +y = torch.rand(TOTAL_CONFIGS, dtype=torch.float64) + 0.5 # Split into train and test sets -train_x = X[:TRAIN_CONFIGS] -train_graphs = graphs[:TRAIN_CONFIGS] -train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch - -test_x = X[TRAIN_CONFIGS:] -test_graphs = graphs[TRAIN_CONFIGS:] -test_y = y[TRAIN_CONFIGS:].unsqueeze(-1) - - -# Setup kernels for numerical and categorical features -if N_NUMERICAL > 0: - matern = ScaleKernel( - MaternKernel( - nu=2.5, - ard_num_dims=N_NUMERICAL, - active_dims=tuple(range(N_NUMERICAL)), - ), - ) - kernels.append(matern) - -if N_CATEGORICAL > 0: - hamming = ScaleKernel( - CategoricalKernel( - ard_num_dims=N_CATEGORICAL, - active_dims=tuple(range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)), - ), - ) - kernels.append(hamming) - -if N_GRAPH > 0: - wl_kernel = ScaleKernel( - TorchWLKernel( - graph_lookup=train_graphs, - n_iter=5, - normalize=True, - active_dims=(X.shape[1] - 1,), # Last column - ) - ) - kernels.append(wl_kernel) - - -# Combine numerical and categorical kernels -kernel = AdditiveKernel(*kernels) +train_x, test_x = X[:TRAIN_CONFIGS], X[TRAIN_CONFIGS:] +train_graphs, test_graphs = graphs[:TRAIN_CONFIGS], graphs[TRAIN_CONFIGS:] +train_y, test_y = y[:TRAIN_CONFIGS].unsqueeze(-1), y[TRAIN_CONFIGS:].unsqueeze(-1) -from botorch.models import SingleTaskGP +train_x, test_x = min_max_scale(train_x), min_max_scale(test_x) -# Initialize the mixed GP -gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=kernel) +kernels = [ + ScaleKernel( + MaternKernel(nu=2.5, ard_num_dims=N_NUMERICAL, active_dims=range(N_NUMERICAL))), + ScaleKernel(CategoricalKernel( + ard_num_dims=N_CATEGORICAL, + active_dims=range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL))), + ScaleKernel(TorchWLKernel( + graph_lookup=train_graphs, n_iter=5, normalize=True, + active_dims=(X.shape[1] - 1,))) +] + +gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=AdditiveKernel(*kernels)) # Compute the posterior distribution # The wl_kernel will use the indices to index into the training graphs it is holding @@ -139,13 +98,9 @@ def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[ uncertainties = posterior.variance.sqrt() covar = posterior.covariance_matrix -# =============== Fitting the GP using botorch =============== - - mll = ExactMarginalLogLikelihood(gp.likelihood, gp) fit_gpytorch_mll(mll) -# Define the acquisition function acq_function = qLogNoisyExpectedImprovement( model=gp, X_baseline=train_x, @@ -153,36 +108,18 @@ def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[ prune_baseline=True, ) -# Define bounds -bounds = torch.tensor( - [ - [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, - [1.0] * N_NUMERICAL - + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * 
N_CATEGORICAL - + [len(X) - 1] * N_GRAPH, - ] -) +bounds = torch.tensor([ + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, + [1.0] * N_NUMERICAL + [ + float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [ + len(X) - 1] * N_GRAPH, +]) + +cats_per_column = {i: list(range(N_CATEGORICAL_VALUES_PER_CATEGORY)) for i in + range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL)} +fixed_cats = [dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in + product(*cats_per_column.values())] -# Setup categorical feature optimization -cats_per_column: dict[int, list[float]] = { - column_ix: [float(i) for i in range(N_CATEGORICAL_VALUES_PER_CATEGORY)] - for column_ix in range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL) -} - -# Generate fixed categorical features -fixed_cats: list[dict[int, float]] -if len(cats_per_column) == 1: - col, choice_indices = next(iter(cats_per_column.items())) - fixed_cats = [{col: i} for i in choice_indices] -else: - fixed_cats = [ - dict(zip(cats_per_column.keys(), combo, strict=False)) - for combo in product(*cats_per_column.values()) - ] - - -print("------------------") # noqa: T201 -# Use the graph-optimized acquisition function best_candidate, best_score = optimize_acqf_graph( acq_function=acq_function, bounds=bounds, @@ -193,3 +130,7 @@ def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[ raw_samples=16, q=1, ) + +print(f"Best candidate: {best_candidate}") +print(f"Best score: {best_score}") +print(f"Elapsed time: {time.time() - start_time} seconds") diff --git a/grakel_replace/utils.py b/grakel_replace/utils.py new file mode 100644 index 00000000..245550eb --- /dev/null +++ b/grakel_replace/utils.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import random + +import numpy as np +import torch + + +def seed_all(seed: int = 100): + """Seed all random generators for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # Ensure reproducibility with CuDNN (may reduce performance) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def min_max_scale(tensor: torch.Tensor) -> torch.Tensor: + """Scale the input tensor to the range [0, 1].""" + min_vals = tensor.min(dim=0, keepdim=True).values + max_vals = tensor.max(dim=0, keepdim=True).values + return (tensor - min_vals) / (max_vals - min_vals) From 0a609f710e01eddc6fb46b49fdfd440bd36e492e Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 24 Dec 2024 09:48:32 +0100 Subject: [PATCH 29/32] Refactor optimize.py --- grakel_replace/optimize.py | 44 +++-- grakel_replace/torch_wl_kernel.py | 287 ++++++++++++------------------ 2 files changed, 133 insertions(+), 198 deletions(-) diff --git a/grakel_replace/optimize.py b/grakel_replace/optimize.py index 4f21886e..b54b671c 100644 --- a/grakel_replace/optimize.py +++ b/grakel_replace/optimize.py @@ -15,8 +15,6 @@ from botorch.models.gp_regression_mixed import Kernel -# Making predictions on test data -# No the wl_kernel needs to be aware of the test graphs @contextmanager def set_graph_lookup( kernel: Kernel, @@ -24,7 +22,17 @@ def set_graph_lookup( *, append: bool = True, ) -> Iterator[None]: + """Context manager to temporarily set the graph lookup for a kernel. + + Args: + kernel (Kernel): The kernel whose graph lookup is to be set. + new_graphs (list[nx.Graph]): The new graphs to set in the graph lookup. + append (bool, optional): Whether to append the new graphs to the existing graph + lookup. 
Defaults to True. + """ kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + + # Determine the modules to update based on the kernel type if isinstance(kernel, TorchWLKernel): modules = [kernel] else: @@ -33,6 +41,7 @@ def set_graph_lookup( ), "Kernel module must have sub_kernels method." modules = [k for k in kernel.sub_kernels() if isinstance(k, TorchWLKernel)] + # Save the current graph lookup and set the new graph lookup for kern in modules: kernel_prev_graphs.append((kern, kern.graph_lookup)) if append: @@ -42,8 +51,9 @@ def set_graph_lookup( yield - for _kern, _prev_graphs in kernel_prev_graphs: - _kern.set_graph_lookup(_prev_graphs) + # Restore the original graph lookup after the context manager exits + for kern, prev_graphs in kernel_prev_graphs: + kern.set_graph_lookup(prev_graphs) def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: @@ -116,35 +126,23 @@ def optimize_acqf_graph( raise ValueError("train_graphs cannot be None.") sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples) - gp = acq_function.model - covar_module = gp.covar_module - best_candidates, best_scores = [], [] - - TODO_GRAPH_COLUMN_INDEX = bounds.shape[1] - 1 - + graph_idx = bounds.shape[1] - 1 + # Iterate through all the kernels and include the sampled graph. for _graph in sampled_graphs: - # This is new, we essentially iterate through all the kernels and - # include the sampled graph. - with set_graph_lookup(covar_module, [_graph], append=True): + with set_graph_lookup(acq_function.model.covar_module, [_graph], append=True): for fixed_features in fixed_features_list or [{}]: # We then consider this graph as a fixed feature, i.e. in the X's # generated during acquisition, the graph column will just be full # of `-1` indicating to select the very last graph in the lookup # they used. 
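The -1 sentinel is resolved inside the kernel itself: the refactored TorchWLKernel.forward later in this patch maps it to the last lookup entry explicitly rather than relying on Python's negative indexing. A toy check of that mapping (sizes illustrative only):

lookup_size = 51                   # e.g. 50 training graphs plus one appended candidate
raw_indices = [0, 49, -1]
resolved = [lookup_size - 1 if i == -1 else i for i in raw_indices]
assert resolved == [0, 49, 50]     # the sentinel lands on the appended graph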
- _fixed_features = {**fixed_features, TODO_GRAPH_COLUMN_INDEX: -1.0} - candidates, scores = optimize_acqf_mixed( acq_function=acq_function, bounds=bounds, - fixed_features_list=[_fixed_features], + fixed_features_list=[{**fixed_features, graph_idx: -1.0}], num_restarts=num_restarts, - raw_samples=raw_samples, - q=q, - ) + raw_samples=raw_samples, q=q) best_candidates.append(candidates) best_scores.append(scores) - - best_scores_tensor = torch.tensor(best_scores) - best_idx = torch.argmax(best_scores_tensor) - return best_candidates[best_idx], best_scores_tensor[best_idx].item() + best_idx = torch.argmax(torch.tensor(best_scores)) + return best_candidates[best_idx], best_scores[best_idx].item() diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index 5c946555..656d4688 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections import Counter from typing import Any import networkx as nx @@ -26,11 +25,25 @@ def __init__( self.n_iter = n_iter self.normalize = normalize - # NOTE: set in the `super().__init__()` - self.active_dims: torch.Tensor + # Cache adjacency matrices and initial node labels + self.adjacency_cache = {} + self.label_cache = {} + + self._precompute_graph_data() + + def _precompute_graph_data(self) -> None: + """Precompute adjacency matrices and initial node labels for all graphs.""" + self.adjacency_cache = {} + self.label_cache = {} + + for idx, graph in enumerate(self.graph_lookup): + self.adjacency_cache[idx] = self._get_sparse_adj(graph) + self.label_cache[idx] = self._init_node_labels(graph) def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None: + """Update the graph lookup and refresh the cached data.""" self.graph_lookup = graph_lookup + self._precompute_graph_data() def forward( self, @@ -47,45 +60,9 @@ def forward( assert x1.shape[-1] == 1, "Last dimension must be the graph index" assert x2.shape[-1] == 1, "Last dimension must be the graph index" - # TODO: Optimizations - # - # 1. We're computing the whole K Matrix, but we only need the K_x1_x2 - # - # K - # -------------------- - # | K_x1_x1 K_x1_x2 | - # | K_x2_x1 K_x2_x2 | - # -------------------- - # - # However in the case where x1 == x2, we can shortcut this slightly as in - # the above, K_x1_x2 == K_x2_x1 == K_x1_x1 == K_x2_x2 - # This shortcut is implemented below based on this flag. - # - # 2. The _TorchWLKernel used below has the following properties, which - # get set on forward. In the case where x1.ndim == 3 then the first dim is - # the `q` dim. Doesn't matter what it is other than we end up repeating the - # processing the graphs `q` times. Given that it's likely that the indices - # in last dimension are likely to be constant (i.e. all `4`, indicating the - # `4th` graph, we are effectively doing a lot of extra calculation. We could - # shortcut this by pre-computing these for each index. Could be nice to somehow - # have the inned `_TorchWLKernel` be aware of this extra dimension but it's - # fine if not as long as we can reduce the extraneuous computations. We could - # change the interface of `_TorchWLKernel` to take in the raw processed - # tensors instead of `nx.Graph` objects, which we would instead preprocess here. - # If that's the case, we could move the `_TorchWLKernel` to essentially just - # be functions we call instead with the correct pre-processed data. 
-        #
-        # .self.label_dict
-        # .self.label_counter
-        # x1_is_x2 = torch.equal(x1, x2)
-        # NOTE: The active dim is already selected out for us and is the last dimension
-        # (not including whatever happens when last_dim_is_batch) is True.
         if x1.ndim == 3:
-            # - x1: torch.Size([32, 5, 1])
-            # - x2: torch.Size([32, 55, 1])
-            # - output: torch.Size([32, 5, 55])
             q_dim_size = x1.shape[0]
             assert x2.shape[0] == q_dim_size
@@ -95,61 +72,33 @@ def forward(
             return out
 
         if x1_is_x2:
-            _ixs = x1.flatten().to(torch.int64).tolist()
-            all_graphs = [self.graph_lookup[i] for i in _ixs]
-
-            # No selection requires
+            indices = x1.flatten().to(torch.int64).tolist()
+            all_graphs = indices
             select = None
         else:
-            _ixs1 = x1.flatten().to(torch.int64).tolist()
-            _ixs2 = x2.flatten().to(torch.int64).tolist()
-            all_graphs = [self.graph_lookup[i] for i in _ixs1 + _ixs2]
+            indices1 = x1.flatten().to(torch.int64).tolist()
+            indices2 = x2.flatten().to(torch.int64).tolist()
+            all_graphs = indices1 + indices2
+            select = lambda K: K[: len(indices1), len(indices1):]
+
+        # Handle the special case for -1
+        all_graphs = [
+            len(self.graph_lookup) - 1 if i == -1 else i for i in all_graphs
+        ]
 
-            # Select out K_x1_x2
-            select = lambda _K: _K[: len(_ixs1), len(_ixs1) :]
+        # Use cached adjacency matrices and labels
+        adj_matrices = [self.adjacency_cache[i] for i in all_graphs]
+        label_tensors = [self.label_cache[i] for i in all_graphs]
 
         _kernel = _TorchWLKernel(n_iter=self.n_iter, normalize=self.normalize)
-        K = _kernel(all_graphs)
+        K = _kernel(adj_matrices, label_tensors)
         K_selected = K if select is None else select(K)
         if diag:
             return torch.diag(K_selected)
         return K_selected
 
-
-class _TorchWLKernel(nn.Module):
-    """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch.
-
-    The WL Kernel is a graph kernel that measures similarity between graphs based on
-    their structural properties. It works by iteratively updating node labels based on
-    their neighborhoods and computing feature vectors from label distributions.
-
-    Args:
-        n_iter: Number of WL iterations to perform
-        normalize: bool, optional. Whether to normalize the kernel matrix
-
-    Attributes:
-        device: torch.device for computation (CPU/GPU)
-        label_dict: Mapping from node labels to numerical indices
-        label_counter: Counter for generating new label indices
-    """
-
-    def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None:
-        super().__init__()
-        self.n_iter = n_iter
-        self.normalize = normalize
-        self.device: torch.device = torch.device("cpu")
-        self.label_dict: dict[str, int] = {}
-        self.label_counter: int = 0
-
     def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor:
-        """Convert a NetworkX graph to a sparse adjacency tensor.
-
-        Args:
-            graph: Input NetworkX graph
-
-        Returns:
-            Sparse tensor representation of the graph's adjacency matrix
-        """
+        """Convert a NetworkX graph to a sparse adjacency tensor."""
         edges = list(graph.edges())
         num_nodes = graph.number_of_nodes()
@@ -158,42 +107,62 @@ def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor:
                 indices=torch.empty((2, 0), dtype=torch.long),
                 values=torch.empty(0),
                 size=(num_nodes, num_nodes),
-                device=self.device,
+                device=torch.device("cpu"),
             )
 
-        # Create bidirectional edge indices for undirected graph
         edge_indices: list[tuple[int, int]] = edges + [(v, u) for u, v in edges]
         rows, cols = zip(*edge_indices, strict=False)
-        indices = torch.tensor([rows, cols], dtype=torch.long, device=self.device)
-        values = torch.ones(len(edge_indices), dtype=torch.float, device=self.device)
+        indices = torch.tensor([rows, cols], dtype=torch.long)
+        values = torch.ones(len(edge_indices), dtype=torch.float)
 
         return torch.sparse_coo_tensor(
-            indices, values, (num_nodes, num_nodes), device=self.device
+            indices, values, (num_nodes, num_nodes), device=torch.device("cpu")
         )
 
     def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor:
-        """Initialize node label tensor from graph attributes.
-
-        Args:
-            graph: Input NetworkX graph
-
-        Returns:
-            Tensor of numerical node label indices
-        """
+        """Initialize node label tensor from graph attributes."""
         labels: list[int] = []
+        label_dict: dict[str, int] = {}
+        label_counter = 0
 
         for node in range(graph.number_of_nodes()):
             if "label" in graph.nodes[node]:
                 label = graph.nodes[node]["label"]
             else:
                 label = str(node)
-            if label not in self.label_dict:
-                self.label_dict[label] = self.label_counter
-                self.label_counter += 1
-            labels.append(self.label_dict[label])
+            if label not in label_dict:
+                label_dict[label] = label_counter
+                label_counter += 1
+            labels.append(label_dict[label])
+
+        return torch.tensor(labels, dtype=torch.long)
+
+
+class _TorchWLKernel(nn.Module):
+    """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch.
+
+    The WL Kernel is a graph kernel that measures similarity between graphs based on
+    their structural properties. It works by iteratively updating node labels based on
+    their neighborhoods and computing feature vectors from label distributions.
+
+    Args:
+        n_iter: Number of WL iterations to perform
+        normalize: bool, optional. Whether to normalize the kernel matrix
+
+    Attributes:
+        device: torch.device for computation (CPU/GPU)
+        label_dict: Mapping from node labels to numerical indices
+        label_counter: Counter for generating new label indices
+    """
 
-        return torch.tensor(labels, dtype=torch.long, device=self.device)
+    def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None:
+        super().__init__()
+        self.n_iter = n_iter
+        self.normalize = normalize
+        self.device: torch.device = torch.device("cpu")
+        self.label_dict: dict[str, int] = {}
+        self.label_counter: int = 0
 
     def _wl_iteration(
         self, adj: torch.sparse.Tensor, labels: torch.Tensor
@@ -208,29 +177,20 @@ def _wl_iteration(
 
         Returns:
             Updated node label tensor
         """
-        new_labels: list[int] = []
         indices = adj.coalesce().indices()
+        new_labels = torch.empty_like(labels)
 
         for node in range(adj.size(0)):
-            # Step 1. Get current node's label
-            node_label = labels[node].item()
-            # Step 2. Get neighbor labels for current node
             neighbors = indices[1][indices[0] == node]
             neighbor_labels = sorted([labels[n].item() for n in neighbors])
 
-            # Check if all neighbors have the same label as the current node
-            if all(labels[n] == labels[node] for n in neighbors):
-                new_labels.append(node_label)
-            else:
-                # Step 3. Create a new label combining node and neighbor information
-                combined_label = f"{node_label}_{neighbor_labels}"
-                # Step 4. Assign a numerical index to this new label
-                if combined_label not in self.label_dict:
-                    self.label_dict[combined_label] = self.label_counter
-                    self.label_counter += 1
-                new_labels.append(self.label_dict[combined_label])
+            combined_label = (labels[node].item(), tuple(neighbor_labels))
+            if combined_label not in self.label_dict:
+                self.label_dict[combined_label] = self.label_counter
+                self.label_counter += 1
+            new_labels[node] = self.label_dict[combined_label]
 
-        return torch.tensor(new_labels, dtype=torch.long, device=self.device)
+        return new_labels
 
     def _compute_feature_vector(self, labels: torch.Tensor, size: int) -> torch.Tensor:
         """Compute histogram feature vector from node labels.
@@ -242,79 +202,56 @@ def _compute_feature_vector(self, labels: torch.Tensor, size: int) -> torch.Tens
 
         Returns:
             Feature vector representing label distribution
         """
-        # Handle the case where all node labels are the same
-        unique_labels = set(labels.cpu().numpy())
-        if len(unique_labels) == 1:
-            feature = torch.zeros(size, device=self.device)
-            feature[labels[0].item()] = len(labels)
-            return feature
-
-        # Count the frequency of each label
-        label_counts = Counter(labels.cpu().numpy())
-        # In the feature vector, each position represents a label
-        feature = torch.zeros(size, device=self.device)
-
-        for label, count in label_counts.items():
-            if label < size:  # Safety check
-                # The value represents how many times that label appears in the graph
-                feature[label] = count
-
+        feature = torch.zeros(size, device=self.device, dtype=torch.float32)
+        unique, counts = torch.unique(labels, return_counts=True)
+        feature[unique] = counts.to(dtype=feature.dtype)
         return feature
 
-    def forward(self, graphs: list[nx.Graph]) -> torch.Tensor:
+    def forward(
+        self,
+        adj_matrices: list[torch.sparse.Tensor],
+        label_tensors: list[torch.Tensor],
+    ) -> torch.Tensor:
         """Compute WL kernel matrix for a list of graphs.
 
         Args:
-            graphs: List of NetworkX graphs to compare
+            adj_matrices: Precomputed sparse adjacency matrices for graphs.
+            label_tensors: Precomputed node label tensors for graphs.
 
         Returns:
-            Kernel matrix containing pairwise graph similarities
-
-        Raises:
-            TypeError: If input is not a list of NetworkX graphs
+            Kernel matrix containing pairwise graph similarities.
         """
-        # Validate input
-        if not isinstance(graphs, list) or not all(
-            isinstance(g, nx.Graph) for g in graphs
-        ):
-            raise TypeError("Expected input type is a list of NetworkX graphs.")
-
-        # Setup computation
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.label_dict = {}
-        self.label_counter = 0
-
-        # Handle a case of empty graphs list or empty individual graphs
-        if not graphs or all(g.number_of_nodes() == 0 for g in graphs):
-            return torch.zeros((len(graphs), len(graphs)), device=self.device)
+        # Validate inputs
+        if len(adj_matrices) != len(label_tensors):
+            raise ValueError("Mismatch between adjacency matrices and label tensors.")
 
-        # Convert graphs to sparse adjacency matrices and initialize labels
-        adj_matrices = [self._get_sparse_adj(g) for g in graphs]
-        # Initialize node labels - either use provided labels or default to node indices
-        label_tensors = [self._init_node_labels(g) for g in graphs]
-
-        # Collect label tensors from all iterations
-        all_label_tensors: list[list[torch.Tensor]] = [label_tensors]
+        # Perform WL iterations to update node labels
+        all_labels = [label_tensors]
        for _ in range(self.n_iter):
             new_labels = [
                 self._wl_iteration(adj, labels)
-                for adj, labels in zip(adj_matrices, all_label_tensors[-1], strict=False)
+                for adj, labels in zip(adj_matrices, all_labels[-1], strict=False)
             ]
-            all_label_tensors.append(new_labels)
-
-        # Compute feature matrices using final label count
-        feature_matrices = [
-            torch.stack(
-                [
-                    self._compute_feature_vector(labels, self.label_counter)
-                    for labels in iteration_labels
-                ]
+            all_labels.append(new_labels)
+
+        # Compute feature vectors for each graph at each iteration
+        feature_vectors = []
+        label_counter = max(
+            max(labels.max().item() for labels in label_set)
+            for label_set in all_labels
+        ) + 1
+        for iteration_labels in all_labels:
+            feature_vectors.append(
+                torch.stack(
+                    [
                        self._compute_feature_vector(labels, label_counter)
+                        for labels in iteration_labels
+                    ]
+                )
             )
-            for iteration_labels in all_label_tensors
-        ]
 
         # Combine features from all iterations
-        final_features = torch.stack(feature_matrices).sum(dim=0)
+        final_features = torch.stack(feature_vectors).sum(dim=0)
 
         # Compute kernel matrix (similarity matrix)
         kernel_matrix = torch.mm(final_features, final_features.t())
@@ -322,7 +259,7 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor:
 
         # Apply normalization if requested
         if self.normalize:
             diag = torch.sqrt(torch.diag(kernel_matrix))
-            kernel_matrix = kernel_matrix / (diag.unsqueeze(0) * diag.unsqueeze(1))
+            kernel_matrix /= diag.unsqueeze(0) * diag.unsqueeze(1)
 
         return kernel_matrix
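For orientation, a minimal sketch (not part of the patch series) of how the refactored private module is driven once per-graph tensors have been precomputed. The build_adj helper below is hypothetical; it simply mirrors the bidirectional COO construction of _get_sparse_adj, and _TorchWLKernel is assumed importable given this patch's module layout. Unlabeled graphs get a uniform starting label here.

import networkx as nx
import torch

from grakel_replace.torch_wl_kernel import _TorchWLKernel

def build_adj(g: nx.Graph) -> torch.Tensor:
    # Store both edge directions so the undirected adjacency is symmetric.
    edges = list(g.edges())
    edge_indices = edges + [(v, u) for u, v in edges]
    rows, cols = zip(*edge_indices)
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.ones(len(edge_indices), dtype=torch.float)
    n = g.number_of_nodes()
    return torch.sparse_coo_tensor(indices, values, (n, n))

graphs = [nx.path_graph(4), nx.cycle_graph(5)]
adj_matrices = [build_adj(g) for g in graphs]
# All nodes start with label 0; the WL iterations refine them.
label_tensors = [torch.zeros(g.number_of_nodes(), dtype=torch.long) for g in graphs]

wl = _TorchWLKernel(n_iter=3, normalize=True)
K = wl(adj_matrices, label_tensors)
print(K)  # 2x2 similarity matrix; the diagonal is 1.0 after normalization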
""" - # Validate input - if not isinstance(graphs, list) or not all( - isinstance(g, nx.Graph) for g in graphs - ): - raise TypeError("Expected input type is a list of NetworkX graphs.") - - # Setup computation - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.label_dict = {} - self.label_counter = 0 - - # Handle a case of empty graphs list or empty individual graphs - if not graphs or all(g.number_of_nodes() == 0 for g in graphs): - return torch.zeros((len(graphs), len(graphs)), device=self.device) + # Validate inputs + if len(adj_matrices) != len(label_tensors): + raise ValueError("Mismatch between adjacency matrices and label tensors.") - # Convert graphs to sparse adjacency matrices and initialize labels - adj_matrices = [self._get_sparse_adj(g) for g in graphs] - # Initialize node labels - either use provided labels or default to node indices - label_tensors = [self._init_node_labels(g) for g in graphs] - - # Collect label tensors from all iterations - all_label_tensors: list[list[torch.Tensor]] = [label_tensors] + # Perform WL iterations to update node labels + all_labels = [label_tensors] for _ in range(self.n_iter): new_labels = [ self._wl_iteration(adj, labels) - for adj, labels in zip(adj_matrices, all_label_tensors[-1], strict=False) + for adj, labels in zip(adj_matrices, all_labels[-1], strict=False) ] - all_label_tensors.append(new_labels) - - # Compute feature matrices using final label count - feature_matrices = [ - torch.stack( - [ - self._compute_feature_vector(labels, self.label_counter) - for labels in iteration_labels - ] + all_labels.append(new_labels) + + # Compute feature vectors for each graph at each iteration + feature_vectors = [] + label_counter = max( + max(labels.max().item() for labels in label_set) + for label_set in all_labels + ) + 1 + for iteration_labels in all_labels: + feature_vectors.append( + torch.stack( + [ + self._compute_feature_vector(labels, label_counter) + for labels in iteration_labels + ] + ) ) - for iteration_labels in all_label_tensors - ] # Combine features from all iterations - final_features = torch.stack(feature_matrices).sum(dim=0) + final_features = torch.stack(feature_vectors).sum(dim=0) # Compute kernel matrix (similarity matrix) kernel_matrix = torch.mm(final_features, final_features.t()) @@ -322,7 +259,7 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: # Apply normalization if requested if self.normalize: diag = torch.sqrt(torch.diag(kernel_matrix)) - kernel_matrix = kernel_matrix / (diag.unsqueeze(0) * diag.unsqueeze(1)) + kernel_matrix /= diag.unsqueeze(0) * diag.unsqueeze(1) return kernel_matrix From 5486dcc3704ad0cc85730c9ffb9ecad868368e43 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 24 Dec 2024 11:35:11 +0100 Subject: [PATCH 30/32] Speed up WL kernel computations --- grakel_replace/torch_wl_kernel.py | 42 +++++++++++++++++++------------ 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index 656d4688..8615e4a3 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -5,7 +5,6 @@ import networkx as nx import torch from botorch.models.gp_regression_mixed import Kernel -from torch import nn class TorchWLKernel(Kernel): @@ -139,7 +138,7 @@ def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: return torch.tensor(labels, dtype=torch.long) -class _TorchWLKernel(nn.Module): +class _TorchWLKernel(torch.nn.Module): """A custom implementation of 
Weisfeiler-Lehman (WL) Kernel in PyTorch. The WL Kernel is a graph kernel that measures similarity between graphs based on @@ -161,12 +160,11 @@ def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None: self.n_iter = n_iter self.normalize = normalize self.device: torch.device = torch.device("cpu") - self.label_dict: dict[str, int] = {} + self.label_dict: dict[tuple, int] = {} self.label_counter: int = 0 - def _wl_iteration( - self, adj: torch.sparse.Tensor, labels: torch.Tensor - ) -> torch.Tensor: + def _wl_iteration(self, adj: torch.sparse.Tensor, + labels: torch.Tensor) -> torch.Tensor: """Perform one WL iteration to update node labels. Concatenate own label with sorted neighbor labels. @@ -177,18 +175,30 @@ def _wl_iteration( Returns: Updated node label tensor """ - indices = adj.coalesce().indices() + # Ensure the adjacency matrix is in COO format + adj = adj.coalesce() + indices = adj.indices() + adj.values() + + # Get the neighbors for each node + rows, cols = indices + neighbors = cols[rows] + + # Create a list of combined labels for each node + combined_labels = [] + for node in range(labels.size(0)): + node_neighbors = neighbors[rows == node] + node_neighbor_labels = labels[node_neighbors] + combined_label = (labels[node].item(), tuple(node_neighbor_labels.tolist())) + combined_labels.append(combined_label) + + # Update the label dictionary and counter new_labels = torch.empty_like(labels) - - for node in range(adj.size(0)): - neighbors = indices[1][indices[0] == node] - neighbor_labels = sorted([labels[n].item() for n in neighbors]) - - combined_label = (labels[node].item(), tuple(neighbor_labels)) - if combined_label not in self.label_dict: - self.label_dict[combined_label] = self.label_counter + for i, label in enumerate(combined_labels): + if label not in self.label_dict: + self.label_dict[label] = self.label_counter self.label_counter += 1 - new_labels[node] = self.label_dict[combined_label] + new_labels[i] = self.label_dict[label] return new_labels From f140c56c3931861075a93655e9571d52ad977c35 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Tue, 24 Dec 2024 14:25:48 +0100 Subject: [PATCH 31/32] Process wl iterations in batches --- grakel_replace/torch_wl_kernel.py | 159 +++++++++++++++--------------- 1 file changed, 79 insertions(+), 80 deletions(-) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index 8615e4a3..9d5c62a5 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -5,6 +5,8 @@ import networkx as nx import torch from botorch.models.gp_regression_mixed import Kernel +from torch import Tensor +from torch.nn import Module class TorchWLKernel(Kernel): @@ -46,13 +48,13 @@ def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None: def forward( self, - x1: torch.Tensor, - x2: torch.Tensor, + x1: Tensor, + x2: Tensor, *, diag: bool = False, last_dim_is_batch: bool = False, **params: Any, - ): + ) -> Tensor: if last_dim_is_batch: raise NotImplementedError("TODO: Figure this out") @@ -96,7 +98,7 @@ def forward( return torch.diag(K_selected) return K_selected - def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor: + def _get_sparse_adj(self, graph: nx.Graph) -> Tensor: """Convert a NetworkX graph to a sparse adjacency tensor.""" edges = list(graph.edges()) num_nodes = graph.number_of_nodes() @@ -106,7 +108,7 @@ def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor: indices=torch.empty((2, 0), dtype=torch.long), values=torch.empty(0), size=(num_nodes, 
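To make the tuple-keyed relabeling above concrete, here is a small dependency-free illustration (not from the patch) of a single WL refinement step on a three-node path. Each node's new label is the compressed index of its (own label, sorted neighbor labels) signature, which is exactly the role label_dict plays in _wl_iteration.

adjacency = {0: [1], 1: [0, 2], 2: [1]}  # path graph 0 - 1 - 2
labels = {0: 0, 1: 0, 2: 0}              # all nodes start identical

label_dict: dict[tuple, int] = {}
new_labels = {}
for node, neighbors in adjacency.items():
    signature = (labels[node], tuple(sorted(labels[n] for n in neighbors)))
    if signature not in label_dict:
        label_dict[signature] = len(label_dict)
    new_labels[node] = label_dict[signature]

print(new_labels)  # {0: 0, 1: 1, 2: 0} -- the endpoints now differ from the middle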
From f140c56c3931861075a93655e9571d52ad977c35 Mon Sep 17 00:00:00 2001
From: vladislavalerievich
Date: Tue, 24 Dec 2024 14:25:48 +0100
Subject: [PATCH 31/32] Process wl iterations in batches

---
 grakel_replace/torch_wl_kernel.py | 159 +++++++++++++++---------------
 1 file changed, 79 insertions(+), 80 deletions(-)

diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py
index 8615e4a3..9d5c62a5 100644
--- a/grakel_replace/torch_wl_kernel.py
+++ b/grakel_replace/torch_wl_kernel.py
@@ -5,6 +5,8 @@
 import networkx as nx
 import torch
 from botorch.models.gp_regression_mixed import Kernel
+from torch import Tensor
+from torch.nn import Module
 
 
 class TorchWLKernel(Kernel):
@@ -46,13 +48,13 @@ def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None:
 
     def forward(
         self,
-        x1: torch.Tensor,
-        x2: torch.Tensor,
+        x1: Tensor,
+        x2: Tensor,
         *,
         diag: bool = False,
         last_dim_is_batch: bool = False,
         **params: Any,
-    ):
+    ) -> Tensor:
         if last_dim_is_batch:
             raise NotImplementedError("TODO: Figure this out")
 
@@ -96,7 +98,7 @@ def forward(
             return torch.diag(K_selected)
         return K_selected
 
-    def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor:
+    def _get_sparse_adj(self, graph: nx.Graph) -> Tensor:
         """Convert a NetworkX graph to a sparse adjacency tensor."""
         edges = list(graph.edges())
         num_nodes = graph.number_of_nodes()
@@ -106,7 +108,7 @@ def _get_sparse_adj(self, graph: nx.Graph) -> Tensor:
                 indices=torch.empty((2, 0), dtype=torch.long),
                 values=torch.empty(0),
                 size=(num_nodes, num_nodes),
-                device=torch.device("cpu"),
+                device=self.device,
             )
 
         edge_indices: list[tuple[int, int]] = edges + [(v, u) for u, v in edges]
@@ -116,10 +118,10 @@ def _get_sparse_adj(self, graph: nx.Graph) -> Tensor:
         values = torch.ones(len(edge_indices), dtype=torch.float)
 
         return torch.sparse_coo_tensor(
-            indices, values, (num_nodes, num_nodes), device=torch.device("cpu")
+            indices, values, (num_nodes, num_nodes), device=self.device
         )
 
-    def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor:
+    def _init_node_labels(self, graph: nx.Graph) -> Tensor:
         """Initialize node label tensor from graph attributes."""
         labels: list[int] = []
         label_dict: dict[str, int] = {}
@@ -135,10 +137,10 @@ def _init_node_labels(self, graph: nx.Graph) -> Tensor:
             label_counter += 1
         labels.append(label_dict[label])
 
-        return torch.tensor(labels, dtype=torch.long)
+        return torch.tensor(labels, dtype=torch.long, device=self.device)
 
 
-class _TorchWLKernel(torch.nn.Module):
+class _TorchWLKernel(Module):
     """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch.
 
     The WL Kernel is a graph kernel that measures similarity between graphs based on
@@ -159,69 +161,83 @@ def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None:
         super().__init__()
         self.n_iter = n_iter
         self.normalize = normalize
-        self.device: torch.device = torch.device("cpu")
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.label_dict: dict[tuple, int] = {}
         self.label_counter: int = 0
+        self.hash_module = torch.nn.Linear(2, 1, bias=False)
+        torch.nn.init.normal_(self.hash_module.weight)
 
-    def _wl_iteration(self, adj: torch.sparse.Tensor,
-                      labels: torch.Tensor) -> torch.Tensor:
-        """Perform one WL iteration to update node labels.
-        Concatenate own label with sorted neighbor labels.
-
-        Args:
-            adj: Sparse adjacency matrix
-            labels: Current node label tensor
-
-        Returns:
-            Updated node label tensor
-        """
-        # Ensure the adjacency matrix is in COO format
+    def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor:
+        """Perform one iteration of the WL algorithm to update node labels."""
         adj = adj.coalesce()
         indices = adj.indices()
-        adj.values()
-
-        # Get the neighbors for each node
         rows, cols = indices
-        neighbors = cols[rows]
-
-        # Create a list of combined labels for each node
-        combined_labels = []
-        for node in range(labels.size(0)):
-            node_neighbors = neighbors[rows == node]
-            node_neighbor_labels = labels[node_neighbors]
-            combined_label = (labels[node].item(), tuple(node_neighbor_labels.tolist()))
-            combined_labels.append(combined_label)
-
-        # Update the label dictionary and counter
+        num_nodes = labels.size(0)
-        new_labels = torch.empty_like(labels)
-        for i, label in enumerate(combined_labels):
-            if label not in self.label_dict:
-                self.label_dict[label] = self.label_counter
-                self.label_counter += 1
-            new_labels[i] = self.label_dict[label]
+        # Create a mask for each node's neighbors
+        neighbor_mask = torch.zeros((num_nodes, num_nodes), dtype=torch.bool,
+                                    device=self.device)
+        neighbor_mask[rows, cols] = True
 
-        return new_labels
-
-    def _compute_feature_vector(self, labels: torch.Tensor, size: int) -> torch.Tensor:
-        """Compute histogram feature vector from node labels.
+        # Get neighbor labels for each node
+        # Shape: [num_nodes, num_nodes]
+        neighbor_labels = labels.unsqueeze(0).expand(num_nodes, -1)
+        neighbor_labels = neighbor_labels.masked_fill(~neighbor_mask, -1)
 
-        Args:
-            labels: Node label tensor
-            size: Size of the feature vector
+        # Sort neighbor labels for each node
+        sorted_neighbor_labels, _ = torch.sort(neighbor_labels, dim=1, descending=True)
 
-        Returns:
-            Feature vector representing label distribution
-        """
-        feature = torch.zeros(size, device=self.device, dtype=torch.float32)
-        unique, counts = torch.unique(labels, return_counts=True)
-        feature[unique] = counts.to(dtype=feature.dtype)
-        return feature
+        # Remove padding (-1 values) from sorted labels
+        valid_neighbors_mask = sorted_neighbor_labels != -1
+        max_neighbors = valid_neighbors_mask.sum(1).max().item()
+        sorted_neighbor_labels = sorted_neighbor_labels[:, :max_neighbors]
+
+        # Combine node labels with neighbor labels
+        node_labels_expanded = labels.unsqueeze(1).expand(-1, max_neighbors)
+
+        # Create feature vectors
+        features = torch.cat([
+            node_labels_expanded.unsqueeze(-1).float(),
+            sorted_neighbor_labels.unsqueeze(-1).float()
+        ], dim=-1)
+
+        # Hash the combined features
+        hashed_features = self.hash_module(features).squeeze(-1)
+        hashed_labels = hashed_features.sum(dim=1)
+
+        # Convert to discrete labels
+        _, new_labels = torch.unique(hashed_labels, sorted=True, return_inverse=True)
+        return new_labels
+
+    def _compute_feature_vector(self, all_labels: list[list[Tensor]]) -> Tensor:
+        """Compute feature vectors for all graphs in a batch."""
+        max_label = 0
+        for iteration_labels in all_labels:
+            for labels in iteration_labels:
+                max_label = max(max_label, labels.max().item())
+
+        batch_size = len(all_labels[0])
+        features = torch.zeros((batch_size, max_label + 1),
+                               dtype=torch.float32, device=self.device)
+
+        # Accumulate label counts for each graph
+        for graph_idx in range(batch_size):
+            # Sum contributions from all WL iterations
+            for iteration_labels in all_labels:
+                graph_labels = iteration_labels[graph_idx]
+                label_counts = torch.bincount(
+                    graph_labels,
+                    minlength=max_label + 1
+                ).float()
+                features[graph_idx] += label_counts
+
+        return features
 
     def forward(
         self,
-        adj_matrices: list[torch.sparse.Tensor],
-        label_tensors: list[torch.Tensor],
-    ) -> torch.Tensor:
+        adj_matrices: list[Tensor],
+        label_tensors: list[Tensor],
+    ) -> Tensor:
         """Compute WL kernel matrix for a list of graphs.
 
         Args:
@@ -231,11 +247,10 @@ def forward(
 
         Returns:
             Kernel matrix containing pairwise graph similarities.
         """
-        # Validate inputs
         if len(adj_matrices) != len(label_tensors):
             raise ValueError("Mismatch between adjacency matrices and label tensors.")
 
-        # Perform WL iterations to update node labels
+        # Perform WL iterations to update the node labels
         all_labels = [label_tensors]
         for _ in range(self.n_iter):
             new_labels = [
@@ -244,24 +259,8 @@ def forward(
             ]
             all_labels.append(new_labels)
 
-        # Compute feature vectors for each graph at each iteration
-        feature_vectors = []
-        label_counter = max(
-            max(labels.max().item() for labels in label_set)
-            for label_set in all_labels
-        ) + 1
-        for iteration_labels in all_labels:
-            feature_vectors.append(
-                torch.stack(
-                    [
-                        self._compute_feature_vector(labels, label_counter)
-                        for labels in iteration_labels
-                    ]
-                )
-            )
-
-        # Combine features from all iterations
-        final_features = torch.stack(feature_vectors).sum(dim=0)
+        # Compute feature vectors for each graph in the batch
+        final_features = self._compute_feature_vector(all_labels)
 
         # Compute kernel matrix (similarity matrix)
         kernel_matrix = torch.mm(final_features, final_features.t())
@@ -269,7 +268,7 @@ def forward(
 
         # Apply normalization if requested
         if self.normalize:
             diag = torch.sqrt(torch.diag(kernel_matrix))
-            kernel_matrix /= diag.unsqueeze(0) * diag.unsqueeze(1)
+            kernel_matrix /= (diag.unsqueeze(0) * diag.unsqueeze(1))
 
         return kernel_matrix
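One observation on the design: unlike the exact dictionary lookup it replaces, the random linear hash can in principle map two distinct (node label, neighbor labels) signatures to the same value, so this patch trades exact relabeling for batched speed. The histogram step, by contrast, stays exact. A small sketch (not from the patch, with made-up label tensors) of the bincount accumulation that _compute_feature_vector performs:

import torch

# Two graphs, two WL iterations; each tensor holds one graph's node labels.
iteration_0 = [torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1])]
iteration_1 = [torch.tensor([2, 2, 3]), torch.tensor([2, 3, 3])]
all_labels = [iteration_0, iteration_1]

max_label = max(int(labels.max()) for it in all_labels for labels in it)
features = torch.zeros((2, max_label + 1))
for it in all_labels:
    for graph_idx, labels in enumerate(it):
        features[graph_idx] += torch.bincount(labels, minlength=max_label + 1).float()

print(features)               # tensor([[2., 1., 2., 1.], [1., 2., 1., 2.]])
print(features @ features.T)  # unnormalized WL kernel matrix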
""" - # Validate inputs if len(adj_matrices) != len(label_tensors): raise ValueError("Mismatch between adjacency matrices and label tensors.") - # Perform WL iterations to update node labels + # Perform WL iterations to update the node labels all_labels = [label_tensors] for _ in range(self.n_iter): new_labels = [ @@ -244,24 +259,8 @@ def forward( ] all_labels.append(new_labels) - # Compute feature vectors for each graph at each iteration - feature_vectors = [] - label_counter = max( - max(labels.max().item() for labels in label_set) - for label_set in all_labels - ) + 1 - for iteration_labels in all_labels: - feature_vectors.append( - torch.stack( - [ - self._compute_feature_vector(labels, label_counter) - for labels in iteration_labels - ] - ) - ) - - # Combine features from all iterations - final_features = torch.stack(feature_vectors).sum(dim=0) + # Compute feature vectors for each graph in the batch + final_features = self._compute_feature_vector(all_labels) # Compute kernel matrix (similarity matrix) kernel_matrix = torch.mm(final_features, final_features.t()) @@ -269,7 +268,7 @@ def forward( # Apply normalization if requested if self.normalize: diag = torch.sqrt(torch.diag(kernel_matrix)) - kernel_matrix /= diag.unsqueeze(0) * diag.unsqueeze(1) + kernel_matrix /= (diag.unsqueeze(0) * diag.unsqueeze(1)) return kernel_matrix From 371b53009eacdea440e21ac44ae53583748d6373 Mon Sep 17 00:00:00 2001 From: vladislavalerievich Date: Wed, 25 Dec 2024 08:57:50 +0100 Subject: [PATCH 32/32] Use CSR --- grakel_replace/torch_wl_kernel.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index 9d5c62a5..e4f6fd2b 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -80,7 +80,7 @@ def forward( indices1 = x1.flatten().to(torch.int64).tolist() indices2 = x2.flatten().to(torch.int64).tolist() all_graphs = indices1 + indices2 - select = lambda K: K[: len(indices1), len(indices1):] + select = lambda K: K[:len(indices1), len(indices1):] # Handle the special case for -1 all_graphs = [ @@ -119,7 +119,7 @@ def _get_sparse_adj(self, graph: nx.Graph) -> Tensor: return torch.sparse_coo_tensor( indices, values, (num_nodes, num_nodes), device=self.device - ) + ).to_sparse_csr() # Convert to CSR for efficient operations def _init_node_labels(self, graph: nx.Graph) -> Tensor: """Initialize node label tensor from graph attributes.""" @@ -169,6 +169,10 @@ def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None: def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor: """Perform one iteration of the WL algorithm to update node labels.""" + # Ensure the adjacency matrix is in COO format before coalescing + if adj.layout == torch.sparse_csr: + adj = adj.to_sparse_coo() + adj = adj.coalesce() indices = adj.indices() rows, cols = indices @@ -180,7 +184,6 @@ def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor: neighbor_mask[rows, cols] = True # Get neighbor labels for each node - # Shape: [num_nodes, num_nodes] neighbor_labels = labels.unsqueeze(0).expand(num_nodes, -1) neighbor_labels = neighbor_labels.masked_fill(~neighbor_mask, -1)