Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vladislavalerievich committed Dec 7, 2024
1 parent 2999582 commit ad55030
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions tests/test_mixed_single_task_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_model_initialization_and_validation(sample_data, sample_kernels):
wl_kernel=wl_kernel,
)

assert gp._train_graphs == graphs
assert gp._train_inputs[1] == graphs
assert isinstance(gp._wl_kernel, TorchWLKernel)
assert gp.num_cat_kernel == combined_kernel

Expand Down Expand Up @@ -133,6 +133,10 @@ def test_kernel_combinations_and_properties(sample_data):
"""Test kernel combination, invariance, and consistency properties."""
X, graphs, y = sample_data

# Ensure inputs are properly formatted
X = X.float()
y = y.float()

# Test kernel combination and variance changes
# Create GP with only graph kernel
gp_graph_only = MixedSingleTaskGP(
Expand All @@ -142,16 +146,16 @@ def test_kernel_combinations_and_properties(sample_data):
)

output_graph = gp_graph_only.forward(X, graphs)
graph_var = output_graph.variance
graph_var = output_graph.variance.detach() # Detach to avoid gradient computation

# Create GP with combined kernels
n_numerical = 2
n_numerical = X.shape[1] # Use actual number of features from X
matern = ScaleKernel(
MaternKernel(
nu=2.5,
ard_num_dims=n_numerical,
active_dims=tuple(range(n_numerical)),
),
)
)

gp_combined = MixedSingleTaskGP(
Expand All @@ -162,24 +166,35 @@ def test_kernel_combinations_and_properties(sample_data):
)

output_combined = gp_combined.forward(X, graphs)
combined_var = output_combined.variance
combined_var = output_combined.variance.detach()

# Combined kernel should have larger variance due to addition
assert torch.all(combined_var > graph_var)
assert torch.all(combined_var >= graph_var - 1e-6) # Allow for numerical precision

# Use graphs with slight variations to avoid singular matrix
similar_graphs = [
nx.complete_graph(5) for _ in range(len(graphs))
]
# Create similar but slightly different graphs
similar_graphs = []
for _ in range(len(graphs)):
G = nx.Graph()
G.add_nodes_from(range(5))
G.add_edges_from([(i, j) for i in range(5) for j in range(i + 1, 5)])
similar_graphs.append(G)

# Add small random perturbations to make graphs slightly different
for i in range(1, len(similar_graphs)):
G = similar_graphs[i]
# Add or remove edges with a small probability
edges_to_add = [(u, v) for u in range(5) for v in range(u + 1, 5)
if not G.has_edge(u, v) and torch.rand(1) < 0.1]
edges_to_remove = [(u, v) for (u, v) in G.edges()
if torch.rand(1) < 0.1]
edges_to_add = []
edges_to_remove = []

# Use fixed random seed for reproducibility
torch.manual_seed(i)

for u in range(5):
for v in range(u + 1, 5):
if not G.has_edge(u, v) and torch.rand(1) < 0.1:
edges_to_add.append((u, v))
elif G.has_edge(u, v) and torch.rand(1) < 0.1:
edges_to_remove.append((u, v))

G.add_edges_from(edges_to_add)
G.remove_edges_from(edges_to_remove)
Expand All @@ -195,10 +210,12 @@ def test_kernel_combinations_and_properties(sample_data):
diag = kernel_matrix.diag()

# Allow for slight variations due to graph perturbations
assert torch.allclose(diag, diag[0], atol=1e-1)
assert torch.allclose(diag, torch.ones_like(diag), atol=1e-6)

# Check that the matrix is not completely uniform
assert not torch.allclose(kernel_matrix, torch.ones_like(kernel_matrix), rtol=1e-5)
off_diag = kernel_matrix - torch.eye(kernel_matrix.size(0),
device=kernel_matrix.device)
assert not torch.allclose(off_diag, torch.zeros_like(off_diag), atol=1e-3)


def test_model_prediction_consistency(sample_data, sample_kernels):
Expand Down

0 comments on commit ad55030

Please sign in to comment.