diff --git a/tests/test_mixed_single_task_gp.py b/tests/test_mixed_single_task_gp.py index 3119d66d..d9c9fade 100644 --- a/tests/test_mixed_single_task_gp.py +++ b/tests/test_mixed_single_task_gp.py @@ -72,7 +72,7 @@ def test_model_initialization_and_validation(sample_data, sample_kernels): wl_kernel=wl_kernel, ) - assert gp._train_graphs == graphs + assert gp._train_inputs[1] == graphs assert isinstance(gp._wl_kernel, TorchWLKernel) assert gp.num_cat_kernel == combined_kernel @@ -133,6 +133,10 @@ def test_kernel_combinations_and_properties(sample_data): """Test kernel combination, invariance, and consistency properties.""" X, graphs, y = sample_data + # Ensure inputs are properly formatted + X = X.float() + y = y.float() + # Test kernel combination and variance changes # Create GP with only graph kernel gp_graph_only = MixedSingleTaskGP( @@ -142,16 +146,16 @@ def test_kernel_combinations_and_properties(sample_data): ) output_graph = gp_graph_only.forward(X, graphs) - graph_var = output_graph.variance + graph_var = output_graph.variance.detach() # Detach to avoid gradient computation # Create GP with combined kernels - n_numerical = 2 + n_numerical = X.shape[1] # Use actual number of features from X matern = ScaleKernel( MaternKernel( nu=2.5, ard_num_dims=n_numerical, active_dims=tuple(range(n_numerical)), - ), + ) ) gp_combined = MixedSingleTaskGP( @@ -162,24 +166,35 @@ def test_kernel_combinations_and_properties(sample_data): ) output_combined = gp_combined.forward(X, graphs) - combined_var = output_combined.variance + combined_var = output_combined.variance.detach() # Combined kernel should have larger variance due to addition - assert torch.all(combined_var > graph_var) + assert torch.all(combined_var >= graph_var - 1e-6) # Allow for numerical precision - # Use graphs with slight variations to avoid singular matrix - similar_graphs = [ - nx.complete_graph(5) for _ in range(len(graphs)) - ] + # Create similar but slightly different graphs + similar_graphs = [] + for _ in range(len(graphs)): + G = nx.Graph() + G.add_nodes_from(range(5)) + G.add_edges_from([(i, j) for i in range(5) for j in range(i + 1, 5)]) + similar_graphs.append(G) # Add small random perturbations to make graphs slightly different for i in range(1, len(similar_graphs)): G = similar_graphs[i] # Add or remove edges with a small probability - edges_to_add = [(u, v) for u in range(5) for v in range(u + 1, 5) - if not G.has_edge(u, v) and torch.rand(1) < 0.1] - edges_to_remove = [(u, v) for (u, v) in G.edges() - if torch.rand(1) < 0.1] + edges_to_add = [] + edges_to_remove = [] + + # Use fixed random seed for reproducibility + torch.manual_seed(i) + + for u in range(5): + for v in range(u + 1, 5): + if not G.has_edge(u, v) and torch.rand(1) < 0.1: + edges_to_add.append((u, v)) + elif G.has_edge(u, v) and torch.rand(1) < 0.1: + edges_to_remove.append((u, v)) G.add_edges_from(edges_to_add) G.remove_edges_from(edges_to_remove) @@ -195,10 +210,12 @@ def test_kernel_combinations_and_properties(sample_data): diag = kernel_matrix.diag() # Allow for slight variations due to graph perturbations - assert torch.allclose(diag, diag[0], atol=1e-1) + assert torch.allclose(diag, torch.ones_like(diag), atol=1e-6) # Check that the matrix is not completely uniform - assert not torch.allclose(kernel_matrix, torch.ones_like(kernel_matrix), rtol=1e-5) + off_diag = kernel_matrix - torch.eye(kernel_matrix.size(0), + device=kernel_matrix.device) + assert not torch.allclose(off_diag, torch.zeros_like(off_diag), atol=1e-3) def test_model_prediction_consistency(sample_data, sample_kernels):