From 8093d3118157ff65ea0128507fddba6c6f2ace51 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 16 Dec 2024 15:13:38 +0100 Subject: [PATCH] fix: Implement graph acquisition --- grakel_replace/mixed_single_task_gp.py | 126 +-------------- .../mixed_single_task_gp_usage_example.py | 125 +++++++++------ grakel_replace/optimize.py | 72 +++++++-- grakel_replace/torch_wl_kernel.py | 149 +++++++++++++++--- grakel_replace/torch_wl_usage_example.py | 20 ++- 5 files changed, 286 insertions(+), 206 deletions(-) diff --git a/grakel_replace/mixed_single_task_gp.py b/grakel_replace/mixed_single_task_gp.py index 5b502b19..9a381c7f 100644 --- a/grakel_replace/mixed_single_task_gp.py +++ b/grakel_replace/mixed_single_task_gp.py @@ -2,67 +2,15 @@ from typing import TYPE_CHECKING -import networkx as nx -import torch from botorch.models import SingleTaskGP -from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import AdditiveKernel, Kernel from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel if TYPE_CHECKING: - from gpytorch.module import Module + import networkx as nx from torch import Tensor -class WLKernel(Kernel): - """Weisfeiler-Lehman Kernel for graph similarity - integrated into the GPyTorch framework. - - This kernel encapsulates the precomputed Weisfeiler-Lehman graph kernel matrix - and provides it in a GPyTorch-compatible format. - It computes either the training kernel - or the cross-kernel between training and test graphs as needed. - """ - - def __init__( - self, - K_train: Tensor, - wl_kernel: TorchWLKernel, - train_graph_dataset: GraphDataset - ) -> None: - super().__init__() - self._K_train = K_train - self._wl_kernel = wl_kernel - self._train_graph_dataset = train_graph_dataset - - def forward( - self, x1: Tensor, - x2: Tensor | None = None, - diag: bool = False, - last_dim_is_batch: bool = False - ) -> Tensor: - """Forward method to compute the kernel matrix for the graph inputs. - - Args: - x1 (Tensor): First input tensor - (unused, required for interface compatibility). - x2 (Tensor | None): Second input tensor. - If None, computes the training kernel matrix. - diag (bool): Whether to return only the diagonal of the kernel matrix. - last_dim_is_batch (bool): Whether the last dimension is a batch dimension. - - Returns: - Tensor: The computed kernel matrix. - """ - if x2 is None: - # Return the precomputed training kernel matrix - return self._K_train - - # Compute cross-kernel between training graphs and new test graphs - test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs - return self._wl_kernel(self._train_graph_dataset, test_dataset) - - class MixedSingleTaskGP(SingleTaskGP): """A Gaussian Process model for mixed input spaces containing numerical, categorical, and graph features. @@ -85,9 +33,9 @@ def __init__( train_X: Tensor, train_graphs: list[nx.Graph], train_Y: Tensor, + num_cat_kernel: Kernel, + wl_kernel: TorchWLKernel, train_Yvar: Tensor | None = None, - num_cat_kernel: Module | None = None, - wl_kernel: TorchWLKernel | None = None, **kwargs, ) -> None: """Initialize the mixed-input Gaussian Process model. @@ -115,7 +63,7 @@ def __init__( **kwargs, ) # Initialize the Weisfeiler-Lehman kernel or use a default one - self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True) + self._wl_kernel = wl_kernel # Preprocess the training graphs into a compatible format and compute the graph # kernel matrix @@ -137,69 +85,7 @@ def __init__( def __call__(self, X: Tensor, graphs: list[nx.Graph] | None = None, **kwargs): """Custom __call__ method that retrieves train graphs if not explicitly passed.""" + print("__call__", X.shape, len(graphs) if graphs is not None else None) # noqa: T201 if graphs is None: # Use stored graphs from train_inputs if not provided graphs = self._train_inputs[1] - return self.forward(X, graphs) - - def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal: - """Forward pass to compute the Gaussian Process distribution for given inputs. - - This combines the numerical/categorical kernel with the graph kernel - to compute the joint covariance matrix. - - Args: - X (Tensor): Input tensor for numerical and categorical features. - graphs (list[nx.Graph]): List of input graphs. - - Returns: - MultivariateNormal: The Gaussian Process distribution for the inputs. - """ - if len(X) != len(graphs): - raise ValueError( - f"Number of feature vectors ({len(X)}) must match " - f"number of graphs ({len(graphs)})" - ) - if not all(isinstance(g, nx.Graph) for g in graphs): - raise TypeError("Expected input type is a list of NetworkX graphs.") - - # Process the new graph inputs into a compatible dataset - proc_graphs = GraphDataset.from_networkx(graphs) - - # Compute the kernel matrix for the new graphs - K_new = self._wl_kernel(proc_graphs) - K_new = K_new.to(dtype=X.dtype) - - # Combine the graph kernel with the numerical/categorical kernel (if present) - if self.num_cat_kernel is not None: - K_num_cat = self.num_cat_kernel(X) - - # Ensure K_new matches K_num_cat dimensions - if K_num_cat.dim() > 2: - batch_size = K_num_cat.size(0) - target_size = K_num_cat.size(1) - - # Resize K_new if needed - if K_new.size(-1) != target_size: - K_new_resized = torch.zeros( - *K_new.shape[:-2], target_size, target_size, - dtype=K_new.dtype, - device=K_new.device - ) - K_new_resized[..., :K_new.size(-2), :K_new.size(-1)] = K_new - K_new = K_new_resized - - if K_new.dim() < K_num_cat.dim(): - K_new = K_new.unsqueeze(0).expand(batch_size, target_size, - target_size) - - # Convert to dense tensor if needed - if hasattr(K_num_cat, "to_dense"): - K_num_cat = K_num_cat.to_dense() - - K_combined = K_num_cat + K_new - else: - K_combined = K_new - - # Compute the mean using the mean module and construct the GP distribution - mean_x = self.mean_module(X) - return MultivariateNormal(mean_x, K_combined) + return self.forward(X) diff --git a/grakel_replace/mixed_single_task_gp_usage_example.py b/grakel_replace/mixed_single_task_gp_usage_example.py index cd8528e4..67306456 100644 --- a/grakel_replace/mixed_single_task_gp_usage_example.py +++ b/grakel_replace/mixed_single_task_gp_usage_example.py @@ -1,30 +1,40 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager from itertools import product +from typing import TYPE_CHECKING import networkx as nx import torch +from botorch import fit_gpytorch_mll from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement -from botorch.fit import fit_gpytorch_mll -from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel +from botorch.models.gp_regression_mixed import CategoricalKernel, Kernel, ScaleKernel from gpytorch import ExactMarginalLogLikelihood -from gpytorch.distributions.multivariate_normal import MultivariateNormal from gpytorch.kernels import AdditiveKernel, MaternKernel -from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP from grakel_replace.optimize import optimize_acqf_graph from grakel_replace.torch_wl_kernel import TorchWLKernel -TRAIN_CONFIGS = 10 +if TYPE_CHECKING: + from gpytorch.distributions.multivariate_normal import MultivariateNormal + +TRAIN_CONFIGS = 50 TEST_CONFIGS = 10 TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS N_NUMERICAL = 2 -N_CATEGORICAL = 2 -N_CATEGORICAL_VALUES_PER_CATEGORY = 3 -N_GRAPH = 2 +N_CATEGORICAL = 1 +N_CATEGORICAL_VALUES_PER_CATEGORY = 2 +N_GRAPH = 1 +assert N_GRAPH == 1, "This example only supports a single graph feature" kernels = [] # Create numerical and categorical features -X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64) +X = torch.empty( + size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH), + dtype=torch.float64, +) if N_NUMERICAL > 0: X[:, :N_NUMERICAL] = torch.rand( size=(TOTAL_CONFIGS, N_NUMERICAL), @@ -32,7 +42,7 @@ ) if N_CATEGORICAL > 0: - X[:, N_NUMERICAL:] = torch.randint( + X[:, N_NUMERICAL : N_NUMERICAL + N_CATEGORICAL] = torch.randint( 0, N_CATEGORICAL_VALUES_PER_CATEGORY, size=(TOTAL_CONFIGS, N_CATEGORICAL), @@ -45,8 +55,21 @@ G = nx.erdos_renyi_graph(n=5, p=0.5) # Random graph with 5 nodes graphs.append(G) +# Assign a new index column to the graphs +X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64) + # Create random target values -y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) +y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + 0.5 + +# Split into train and test sets +train_x = X[:TRAIN_CONFIGS] +train_graphs = graphs[:TRAIN_CONFIGS] +train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch + +test_x = X[TRAIN_CONFIGS:] +test_graphs = graphs[TRAIN_CONFIGS:] +test_y = y[TRAIN_CONFIGS:].unsqueeze(-1) + # Setup kernels for numerical and categorical features if N_NUMERICAL > 0: @@ -68,47 +91,56 @@ ) kernels.append(hamming) -# Combine numerical and categorical kernels -combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None +if N_GRAPH > 0: + wl_kernel = ScaleKernel( + TorchWLKernel( + graph_lookup=train_graphs, + n_iter=5, + normalize=True, + active_dims=(X.shape[1] - 1,), # Last column + ) + ) + kernels.append(wl_kernel) -# Create WL kernel for graphs -wl_kernel = TorchWLKernel(n_iter=5, normalize=True) -# Split into train and test sets -train_x = X[:TRAIN_CONFIGS] -train_graphs = graphs[:TRAIN_CONFIGS] -train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch +# Combine numerical and categorical kernels +kernel = AdditiveKernel(*kernels) -test_x = X[TRAIN_CONFIGS:] -test_graphs = graphs[TRAIN_CONFIGS:] -test_y = y[TRAIN_CONFIGS:].unsqueeze(-1) +from botorch.models import SingleTaskGP # Initialize the mixed GP -gp = MixedSingleTaskGP( - train_X=train_x, - train_graphs=train_graphs, - train_Y=train_y, - num_cat_kernel=combined_num_cat_kernel, - wl_kernel=wl_kernel, -) +gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=kernel) # Compute the posterior distribution -multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs) -print("Posterior distribution:", multivariate_normal) +# The wl_kernel will use the indices to index into the training graphs it is holding +# on to... +multivariate_normal: MultivariateNormal = gp.forward(train_x) + # Making predictions on test data -with torch.no_grad(): - posterior = gp.forward(test_x, test_graphs) +# No the wl_kernel needs to be aware of the test graphs +@contextmanager +def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]: + kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + for kern in _gp.covar_module.sub_kernels(): + if isinstance(kern, TorchWLKernel): + kernel_prev_graphs.append((kern, kern.graph_lookup)) + kern.set_graph_lookup(new_graphs) + + yield + + for _kern, _prev_graphs in kernel_prev_graphs: + _kern.set_graph_lookup(_prev_graphs) + + +with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs): + posterior = gp.forward(test_x) predictions = posterior.mean uncertainties = posterior.variance.sqrt() covar = posterior.covariance_matrix -print("\nMean:", predictions) -print("Variance:", uncertainties) - # =============== Fitting the GP using botorch =============== -print("\nFitting the GP model using botorch...") mll = ExactMarginalLogLikelihood(gp.likelihood, gp) fit_gpytorch_mll(mll) @@ -124,8 +156,10 @@ # Define bounds bounds = torch.tensor( [ - [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL, - [1.0] * N_NUMERICAL + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + [0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH, + [1.0] * N_NUMERICAL + + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL + + [len(X) - 1] * N_GRAPH, ] ) @@ -142,21 +176,20 @@ fixed_cats = [{col: i} for i in choice_indices] else: fixed_cats = [ - dict(zip(cats_per_column.keys(), combo)) + dict(zip(cats_per_column.keys(), combo, strict=False)) for combo in product(*cats_per_column.values()) ] + +print("------------------") # noqa: T201 # Use the graph-optimized acquisition function best_candidate, best_score = optimize_acqf_graph( acq_function=acq_function, bounds=bounds, fixed_features_list=fixed_cats, train_graphs=train_graphs, - num_graph_samples=20, - num_restarts=10, - raw_samples=10, + num_graph_samples=2, + num_restarts=2, + raw_samples=16, q=1, ) - -print("Best candidate:", best_candidate) -print("Acquisition score:", best_score) diff --git a/grakel_replace/optimize.py b/grakel_replace/optimize.py index eff7adac..4f21886e 100644 --- a/grakel_replace/optimize.py +++ b/grakel_replace/optimize.py @@ -1,14 +1,49 @@ from __future__ import annotations import random +from collections.abc import Iterator +from contextlib import contextmanager from typing import TYPE_CHECKING import networkx as nx import torch from botorch.optim import optimize_acqf_mixed +from grakel_replace.torch_wl_kernel import TorchWLKernel if TYPE_CHECKING: from botorch.acquisition import AcquisitionFunction + from botorch.models.gp_regression_mixed import Kernel + + +# Making predictions on test data +# No the wl_kernel needs to be aware of the test graphs +@contextmanager +def set_graph_lookup( + kernel: Kernel, + new_graphs: list[nx.Graph], + *, + append: bool = True, +) -> Iterator[None]: + kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = [] + if isinstance(kernel, TorchWLKernel): + modules = [kernel] + else: + assert hasattr( + kernel, "sub_kernels" + ), "Kernel module must have sub_kernels method." + modules = [k for k in kernel.sub_kernels() if isinstance(k, TorchWLKernel)] + + for kern in modules: + kernel_prev_graphs.append((kern, kern.graph_lookup)) + if append: + kern.set_graph_lookup([*kern.graph_lookup, *new_graphs]) + else: + kern.set_graph_lookup(new_graphs) + + yield + + for _kern, _prev_graphs in kernel_prev_graphs: + _kern.set_graph_lookup(_prev_graphs) def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: @@ -35,7 +70,7 @@ def sample_graphs(graphs: list[nx.Graph], num_samples: int) -> list[nx.Graph]: u, v = random.sample(nodes, 2) if not sampled_graph.has_edge(u, v): sampled_graph.add_edge(u, v) - elif sampled_graph.edges: # 30% chance to remove edge + elif sampled_graph.edges: # 30% chance to remove edge u, v = random.choice(list(sampled_graph.edges)) sampled_graph.remove_edge(u, v) @@ -81,21 +116,34 @@ def optimize_acqf_graph( raise ValueError("train_graphs cannot be None.") sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples) + gp = acq_function.model + covar_module = gp.covar_module best_candidates, best_scores = [], [] + TODO_GRAPH_COLUMN_INDEX = bounds.shape[1] - 1 + for _graph in sampled_graphs: - for fixed_features in fixed_features_list or [{}]: - candidates, scores = optimize_acqf_mixed( - acq_function=acq_function, - bounds=bounds, - fixed_features_list=[fixed_features], - num_restarts=num_restarts, - raw_samples=raw_samples, - q=q, - ) - best_candidates.append(candidates) - best_scores.append(scores) + # This is new, we essentially iterate through all the kernels and + # include the sampled graph. + with set_graph_lookup(covar_module, [_graph], append=True): + for fixed_features in fixed_features_list or [{}]: + # We then consider this graph as a fixed feature, i.e. in the X's + # generated during acquisition, the graph column will just be full + # of `-1` indicating to select the very last graph in the lookup + # they used. + _fixed_features = {**fixed_features, TODO_GRAPH_COLUMN_INDEX: -1.0} + + candidates, scores = optimize_acqf_mixed( + acq_function=acq_function, + bounds=bounds, + fixed_features_list=[_fixed_features], + num_restarts=num_restarts, + raw_samples=raw_samples, + q=q, + ) + best_candidates.append(candidates) + best_scores.append(scores) best_scores_tensor = torch.tensor(best_scores) best_idx = torch.argmax(best_scores_tensor) diff --git a/grakel_replace/torch_wl_kernel.py b/grakel_replace/torch_wl_kernel.py index ba2ae071..5c946555 100644 --- a/grakel_replace/torch_wl_kernel.py +++ b/grakel_replace/torch_wl_kernel.py @@ -1,13 +1,122 @@ from __future__ import annotations from collections import Counter +from typing import Any import networkx as nx import torch +from botorch.models.gp_regression_mixed import Kernel from torch import nn -class TorchWLKernel(nn.Module): +class TorchWLKernel(Kernel): + has_lengthscale = False + + def __init__( + self, + graph_lookup: list[nx.Graph], + n_iter: int = 5, + *, + normalize: bool = True, + active_dims: tuple[int, ...], + **kwargs: Any, + ) -> None: + super().__init__(active_dims=active_dims, **kwargs) + self.graph_lookup = graph_lookup + self.n_iter = n_iter + self.normalize = normalize + + # NOTE: set in the `super().__init__()` + self.active_dims: torch.Tensor + + def set_graph_lookup(self, graph_lookup: list[nx.Graph]) -> None: + self.graph_lookup = graph_lookup + + def forward( + self, + x1: torch.Tensor, + x2: torch.Tensor, + *, + diag: bool = False, + last_dim_is_batch: bool = False, + **params: Any, + ): + if last_dim_is_batch: + raise NotImplementedError("TODO: Figure this out") + + assert x1.shape[-1] == 1, "Last dimension must be the graph index" + assert x2.shape[-1] == 1, "Last dimension must be the graph index" + + # TODO: Optimizations + # + # 1. We're computing the whole K Matrix, but we only need the K_x1_x2 + # + # K + # -------------------- + # | K_x1_x1 K_x1_x2 | + # | K_x2_x1 K_x2_x2 | + # -------------------- + # + # However in the case where x1 == x2, we can shortcut this slightly as in + # the above, K_x1_x2 == K_x2_x1 == K_x1_x1 == K_x2_x2 + # This shortcut is implemented below based on this flag. + # + # 2. The _TorchWLKernel used below has the following properties, which + # get set on forward. In the case where x1.ndim == 3 then the first dim is + # the `q` dim. Doesn't matter what it is other than we end up repeating the + # processing the graphs `q` times. Given that it's likely that the indices + # in last dimension are likely to be constant (i.e. all `4`, indicating the + # `4th` graph, we are effectively doing a lot of extra calculation. We could + # shortcut this by pre-computing these for each index. Could be nice to somehow + # have the inned `_TorchWLKernel` be aware of this extra dimension but it's + # fine if not as long as we can reduce the extraneuous computations. We could + # change the interface of `_TorchWLKernel` to take in the raw processed + # tensors instead of `nx.Graph` objects, which we would instead preprocess here. + # If that's the case, we could move the `_TorchWLKernel` to essentially just + # be functions we call instead with the correct pre-processed data. + # + # .self.label_dict + # .self.label_counter + # + x1_is_x2 = torch.equal(x1, x2) + + # NOTE: The active dim is already selected out for us and is the last dimension + # (not including whatever happens when last_dim_is_batch) is True. + if x1.ndim == 3: + # - x1: torch.Size([32, 5, 1]) + # - x2: torch.Size([32, 55, 1]) + # - output: torch.Size([32, 5, 55]) + q_dim_size = x1.shape[0] + assert x2.shape[0] == q_dim_size + + out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device) + for q in range(q_dim_size): + out[q] = self.forward(x1[q], x2[q], diag=diag) + return out + + if x1_is_x2: + _ixs = x1.flatten().to(torch.int64).tolist() + all_graphs = [self.graph_lookup[i] for i in _ixs] + + # No selection requires + select = None + else: + _ixs1 = x1.flatten().to(torch.int64).tolist() + _ixs2 = x2.flatten().to(torch.int64).tolist() + all_graphs = [self.graph_lookup[i] for i in _ixs1 + _ixs2] + + # Select out K_x1_x2 + select = lambda _K: _K[: len(_ixs1), len(_ixs1) :] + + _kernel = _TorchWLKernel(n_iter=self.n_iter, normalize=self.normalize) + K = _kernel(all_graphs) + K_selected = K if select is None else select(K) + if diag: + return torch.diag(K_selected) + return K_selected + + +class _TorchWLKernel(nn.Module): """A custom implementation of Weisfeiler-Lehman (WL) Kernel in PyTorch. The WL Kernel is a graph kernel that measures similarity between graphs based on @@ -24,7 +133,7 @@ class TorchWLKernel(nn.Module): label_counter: Counter for generating new label indices """ - def __init__(self, n_iter: int = 5, normalize: bool = True) -> None: + def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None: super().__init__() self.n_iter = n_iter self.normalize = normalize @@ -49,7 +158,7 @@ def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor: indices=torch.empty((2, 0), dtype=torch.long), values=torch.empty(0), size=(num_nodes, num_nodes), - device=self.device + device=self.device, ) # Create bidirectional edge indices for undirected graph @@ -60,8 +169,7 @@ def _get_sparse_adj(self, graph: nx.Graph) -> torch.sparse.Tensor: values = torch.ones(len(edge_indices), dtype=torch.float, device=self.device) return torch.sparse_coo_tensor( - indices, values, (num_nodes, num_nodes), - device=self.device + indices, values, (num_nodes, num_nodes), device=self.device ) def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: @@ -88,9 +196,7 @@ def _init_node_labels(self, graph: nx.Graph) -> torch.Tensor: return torch.tensor(labels, dtype=torch.long, device=self.device) def _wl_iteration( - self, - adj: torch.sparse.Tensor, - labels: torch.Tensor + self, adj: torch.sparse.Tensor, labels: torch.Tensor ) -> torch.Tensor: """Perform one WL iteration to update node labels. Concatenate own label with sorted neighbor labels. @@ -126,11 +232,7 @@ def _wl_iteration( return torch.tensor(new_labels, dtype=torch.long, device=self.device) - def _compute_feature_vector( - self, - labels: torch.Tensor, - size: int - ) -> torch.Tensor: + def _compute_feature_vector(self, labels: torch.Tensor, size: int) -> torch.Tensor: """Compute histogram feature vector from node labels. Args: @@ -172,8 +274,9 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: TypeError: If input is not a list of NetworkX graphs """ # Validate input - if (not isinstance(graphs, list) or - not all(isinstance(g, nx.Graph) for g in graphs)): + if not isinstance(graphs, list) or not all( + isinstance(g, nx.Graph) for g in graphs + ): raise TypeError("Expected input type is a list of NetworkX graphs.") # Setup computation @@ -201,10 +304,12 @@ def forward(self, graphs: list[nx.Graph]) -> torch.Tensor: # Compute feature matrices using final label count feature_matrices = [ - torch.stack([ - self._compute_feature_vector(labels, self.label_counter) - for labels in iteration_labels - ]) + torch.stack( + [ + self._compute_feature_vector(labels, self.label_counter) + for labels in iteration_labels + ] + ) for iteration_labels in all_label_tensors ] @@ -226,9 +331,9 @@ class GraphDataset: """Utility class to convert NetworkX graphs for WL kernel.""" @staticmethod - def from_networkx(graphs: list[nx.Graph], node_labels_tag: str = "label") -> list[ - nx.Graph]: - + def from_networkx( + graphs: list[nx.Graph], node_labels_tag: str = "label" + ) -> list[nx.Graph]: if not all(isinstance(g, nx.Graph) for g in graphs): raise TypeError("Expected input type is a list of NetworkX graphs.") diff --git a/grakel_replace/torch_wl_usage_example.py b/grakel_replace/torch_wl_usage_example.py index a24d2536..f9958045 100644 --- a/grakel_replace/torch_wl_usage_example.py +++ b/grakel_replace/torch_wl_usage_example.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import networkx as nx +import torch from torch_wl_kernel import GraphDataset, TorchWLKernel # Create the same graphs as for the Grakel example @@ -10,12 +13,17 @@ G3.add_edges_from([(0, 1), (1, 3), (3, 2)]) # Process graphs -graphs = GraphDataset.from_networkx([G1, G2, G3]) +graphs: list[nx.Graph] = GraphDataset.from_networkx([G1, G2, G3]) # Initialize and run WL kernel -wl_kernel = TorchWLKernel(n_iter=2, normalize=False) - -K = wl_kernel(graphs) +wl_kernel = TorchWLKernel( + training_graph_list=graphs, + n_iter=2, + normalize=True, + active_dims=(1,), +) +X1 = torch.tensor([[42.4, 43.4, 44.5], [0, 1, 2]]).T +X2 = torch.tensor([[42.4, 43.4, 44.5], [0, 1, 2]]).T -print("Kernel matrix (pairwise similarities):") -print(K) +K = wl_kernel(X1, X2) +print(K.to_dense()) # noqa: T201