Skip to content

Commit

Permalink
fix: Implement graph acquisition (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
vladislavalerievich authored Dec 24, 2024
2 parents ad55030 + 8093d31 commit 9f978d6
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 206 deletions.
126 changes: 6 additions & 120 deletions grakel_replace/mixed_single_task_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,67 +2,15 @@

from typing import TYPE_CHECKING

import networkx as nx
import torch
from botorch.models import SingleTaskGP
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, Kernel
from grakel_replace.torch_wl_kernel import GraphDataset, TorchWLKernel

if TYPE_CHECKING:
from gpytorch.module import Module
import networkx as nx
from torch import Tensor


class WLKernel(Kernel):
"""Weisfeiler-Lehman Kernel for graph similarity
integrated into the GPyTorch framework.
This kernel encapsulates the precomputed Weisfeiler-Lehman graph kernel matrix
and provides it in a GPyTorch-compatible format.
It computes either the training kernel
or the cross-kernel between training and test graphs as needed.
"""

def __init__(
self,
K_train: Tensor,
wl_kernel: TorchWLKernel,
train_graph_dataset: GraphDataset
) -> None:
super().__init__()
self._K_train = K_train
self._wl_kernel = wl_kernel
self._train_graph_dataset = train_graph_dataset

def forward(
self, x1: Tensor,
x2: Tensor | None = None,
diag: bool = False,
last_dim_is_batch: bool = False
) -> Tensor:
"""Forward method to compute the kernel matrix for the graph inputs.
Args:
x1 (Tensor): First input tensor
(unused, required for interface compatibility).
x2 (Tensor | None): Second input tensor.
If None, computes the training kernel matrix.
diag (bool): Whether to return only the diagonal of the kernel matrix.
last_dim_is_batch (bool): Whether the last dimension is a batch dimension.
Returns:
Tensor: The computed kernel matrix.
"""
if x2 is None:
# Return the precomputed training kernel matrix
return self._K_train

# Compute cross-kernel between training graphs and new test graphs
test_dataset = GraphDataset.from_networkx(x2) # x2 should be test graphs
return self._wl_kernel(self._train_graph_dataset, test_dataset)


class MixedSingleTaskGP(SingleTaskGP):
"""A Gaussian Process model for mixed input spaces containing numerical, categorical,
and graph features.
Expand All @@ -85,9 +33,9 @@ def __init__(
train_X: Tensor,
train_graphs: list[nx.Graph],
train_Y: Tensor,
num_cat_kernel: Kernel,
wl_kernel: TorchWLKernel,
train_Yvar: Tensor | None = None,
num_cat_kernel: Module | None = None,
wl_kernel: TorchWLKernel | None = None,
**kwargs,
) -> None:
"""Initialize the mixed-input Gaussian Process model.
Expand Down Expand Up @@ -115,7 +63,7 @@ def __init__(
**kwargs,
)
# Initialize the Weisfeiler-Lehman kernel or use a default one
self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True)
self._wl_kernel = wl_kernel

# Preprocess the training graphs into a compatible format and compute the graph
# kernel matrix
Expand All @@ -137,69 +85,7 @@ def __init__(

def __call__(self, X: Tensor, graphs: list[nx.Graph] | None = None, **kwargs):
"""Custom __call__ method that retrieves train graphs if not explicitly passed."""
print("__call__", X.shape, len(graphs) if graphs is not None else None) # noqa: T201
if graphs is None: # Use stored graphs from train_inputs if not provided
graphs = self._train_inputs[1]
return self.forward(X, graphs)

def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal:
"""Forward pass to compute the Gaussian Process distribution for given inputs.
This combines the numerical/categorical kernel with the graph kernel
to compute the joint covariance matrix.
Args:
X (Tensor): Input tensor for numerical and categorical features.
graphs (list[nx.Graph]): List of input graphs.
Returns:
MultivariateNormal: The Gaussian Process distribution for the inputs.
"""
if len(X) != len(graphs):
raise ValueError(
f"Number of feature vectors ({len(X)}) must match "
f"number of graphs ({len(graphs)})"
)
if not all(isinstance(g, nx.Graph) for g in graphs):
raise TypeError("Expected input type is a list of NetworkX graphs.")

# Process the new graph inputs into a compatible dataset
proc_graphs = GraphDataset.from_networkx(graphs)

# Compute the kernel matrix for the new graphs
K_new = self._wl_kernel(proc_graphs)
K_new = K_new.to(dtype=X.dtype)

# Combine the graph kernel with the numerical/categorical kernel (if present)
if self.num_cat_kernel is not None:
K_num_cat = self.num_cat_kernel(X)

# Ensure K_new matches K_num_cat dimensions
if K_num_cat.dim() > 2:
batch_size = K_num_cat.size(0)
target_size = K_num_cat.size(1)

# Resize K_new if needed
if K_new.size(-1) != target_size:
K_new_resized = torch.zeros(
*K_new.shape[:-2], target_size, target_size,
dtype=K_new.dtype,
device=K_new.device
)
K_new_resized[..., :K_new.size(-2), :K_new.size(-1)] = K_new
K_new = K_new_resized

if K_new.dim() < K_num_cat.dim():
K_new = K_new.unsqueeze(0).expand(batch_size, target_size,
target_size)

# Convert to dense tensor if needed
if hasattr(K_num_cat, "to_dense"):
K_num_cat = K_num_cat.to_dense()

K_combined = K_num_cat + K_new
else:
K_combined = K_new

# Compute the mean using the mean module and construct the GP distribution
mean_x = self.mean_module(X)
return MultivariateNormal(mean_x, K_combined)
return self.forward(X)
125 changes: 79 additions & 46 deletions grakel_replace/mixed_single_task_gp_usage_example.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,48 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from itertools import product
from typing import TYPE_CHECKING

import networkx as nx
import torch
from botorch import fit_gpytorch_mll
from botorch.acquisition import LinearMCObjective, qLogNoisyExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models.gp_regression_mixed import CategoricalKernel, ScaleKernel
from botorch.models.gp_regression_mixed import CategoricalKernel, Kernel, ScaleKernel
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.kernels import AdditiveKernel, MaternKernel
from grakel_replace.mixed_single_task_gp import MixedSingleTaskGP
from grakel_replace.optimize import optimize_acqf_graph
from grakel_replace.torch_wl_kernel import TorchWLKernel

TRAIN_CONFIGS = 10
if TYPE_CHECKING:
from gpytorch.distributions.multivariate_normal import MultivariateNormal

TRAIN_CONFIGS = 50
TEST_CONFIGS = 10
TOTAL_CONFIGS = TRAIN_CONFIGS + TEST_CONFIGS

N_NUMERICAL = 2
N_CATEGORICAL = 2
N_CATEGORICAL_VALUES_PER_CATEGORY = 3
N_GRAPH = 2
N_CATEGORICAL = 1
N_CATEGORICAL_VALUES_PER_CATEGORY = 2
N_GRAPH = 1
assert N_GRAPH == 1, "This example only supports a single graph feature"

kernels = []

# Create numerical and categorical features
X = torch.empty(size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL), dtype=torch.float64)
X = torch.empty(
size=(TOTAL_CONFIGS, N_NUMERICAL + N_CATEGORICAL + N_GRAPH),
dtype=torch.float64,
)
if N_NUMERICAL > 0:
X[:, :N_NUMERICAL] = torch.rand(
size=(TOTAL_CONFIGS, N_NUMERICAL),
dtype=torch.float64,
)

if N_CATEGORICAL > 0:
X[:, N_NUMERICAL:] = torch.randint(
X[:, N_NUMERICAL : N_NUMERICAL + N_CATEGORICAL] = torch.randint(
0,
N_CATEGORICAL_VALUES_PER_CATEGORY,
size=(TOTAL_CONFIGS, N_CATEGORICAL),
Expand All @@ -45,8 +55,21 @@
G = nx.erdos_renyi_graph(n=5, p=0.5) # Random graph with 5 nodes
graphs.append(G)

# Assign a new index column to the graphs
X[:, -1] = torch.arange(TOTAL_CONFIGS, dtype=torch.float64)

# Create random target values
y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64)
y = torch.rand(size=(TOTAL_CONFIGS,), dtype=torch.float64) + 0.5

# Split into train and test sets
train_x = X[:TRAIN_CONFIGS]
train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch

test_x = X[TRAIN_CONFIGS:]
test_graphs = graphs[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)


# Setup kernels for numerical and categorical features
if N_NUMERICAL > 0:
Expand All @@ -68,47 +91,56 @@
)
kernels.append(hamming)

# Combine numerical and categorical kernels
combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None
if N_GRAPH > 0:
wl_kernel = ScaleKernel(
TorchWLKernel(
graph_lookup=train_graphs,
n_iter=5,
normalize=True,
active_dims=(X.shape[1] - 1,), # Last column
)
)
kernels.append(wl_kernel)

# Create WL kernel for graphs
wl_kernel = TorchWLKernel(n_iter=5, normalize=True)

# Split into train and test sets
train_x = X[:TRAIN_CONFIGS]
train_graphs = graphs[:TRAIN_CONFIGS]
train_y = y[:TRAIN_CONFIGS].unsqueeze(-1) # Add dimension for botorch
# Combine numerical and categorical kernels
kernel = AdditiveKernel(*kernels)

test_x = X[TRAIN_CONFIGS:]
test_graphs = graphs[TRAIN_CONFIGS:]
test_y = y[TRAIN_CONFIGS:].unsqueeze(-1)
from botorch.models import SingleTaskGP

# Initialize the mixed GP
gp = MixedSingleTaskGP(
train_X=train_x,
train_graphs=train_graphs,
train_Y=train_y,
num_cat_kernel=combined_num_cat_kernel,
wl_kernel=wl_kernel,
)
gp = SingleTaskGP(train_X=train_x, train_Y=train_y, covar_module=kernel)

# Compute the posterior distribution
multivariate_normal: MultivariateNormal = gp.forward(train_x, train_graphs)
print("Posterior distribution:", multivariate_normal)
# The wl_kernel will use the indices to index into the training graphs it is holding
# on to...
multivariate_normal: MultivariateNormal = gp.forward(train_x)


# Making predictions on test data
with torch.no_grad():
posterior = gp.forward(test_x, test_graphs)
# No the wl_kernel needs to be aware of the test graphs
@contextmanager
def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]:
kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []
for kern in _gp.covar_module.sub_kernels():
if isinstance(kern, TorchWLKernel):
kernel_prev_graphs.append((kern, kern.graph_lookup))
kern.set_graph_lookup(new_graphs)

yield

for _kern, _prev_graphs in kernel_prev_graphs:
_kern.set_graph_lookup(_prev_graphs)


with torch.no_grad(), set_graph_lookup(gp, train_graphs + test_graphs):
posterior = gp.forward(test_x)
predictions = posterior.mean
uncertainties = posterior.variance.sqrt()
covar = posterior.covariance_matrix

print("\nMean:", predictions)
print("Variance:", uncertainties)

# =============== Fitting the GP using botorch ===============

print("\nFitting the GP model using botorch...")

mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)
Expand All @@ -124,8 +156,10 @@
# Define bounds
bounds = torch.tensor(
[
[0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL,
[1.0] * N_NUMERICAL + [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL
[0.0] * N_NUMERICAL + [0.0] * N_CATEGORICAL + [-1.0] * N_GRAPH,
[1.0] * N_NUMERICAL
+ [float(N_CATEGORICAL_VALUES_PER_CATEGORY - 1)] * N_CATEGORICAL
+ [len(X) - 1] * N_GRAPH,
]
)

Expand All @@ -142,21 +176,20 @@
fixed_cats = [{col: i} for i in choice_indices]
else:
fixed_cats = [
dict(zip(cats_per_column.keys(), combo))
dict(zip(cats_per_column.keys(), combo, strict=False))
for combo in product(*cats_per_column.values())
]


print("------------------") # noqa: T201
# Use the graph-optimized acquisition function
best_candidate, best_score = optimize_acqf_graph(
acq_function=acq_function,
bounds=bounds,
fixed_features_list=fixed_cats,
train_graphs=train_graphs,
num_graph_samples=20,
num_restarts=10,
raw_samples=10,
num_graph_samples=2,
num_restarts=2,
raw_samples=16,
q=1,
)

print("Best candidate:", best_candidate)
print("Acquisition score:", best_score)
Loading

0 comments on commit 9f978d6

Please sign in to comment.