Skip to content

Commit

Permalink
Return the best graph from optimize_acqf_graph function
Browse files Browse the repository at this point in the history
  • Loading branch information
vladislavalerievich committed Jan 24, 2025
1 parent 6794a05 commit e94bbe8
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
product(*cats_per_column.values())]

# Optimize the acquisition function with graph sampling
best_candidate, best_score = optimize_acqf_graph(
best_candidate, best_graph, best_score = optimize_acqf_graph(
acq_function=acq_function,
bounds=bounds,
fixed_features_list=fixed_cats,
Expand All @@ -120,6 +120,10 @@
)

# Print the results
print(f"Best candidate: {best_candidate}")
print(f"Best graph: {best_graph}")
print(f"Best score: {best_score}")
print(f"Execution time: {time.time() - start_time:.2f} seconds")

# Clear caches after optimization to avoid memory leaks or unexpected behavior
BoTorchWLKernel._compute_kernel.cache_clear()
Expand Down
127 changes: 0 additions & 127 deletions neps/optimizers/models/graphs/graph_aware_gp_optimization_example.py

This file was deleted.

32 changes: 18 additions & 14 deletions neps/optimizers/models/graphs/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ def optimize_acqf_graph(
num_restarts: int = 10,
raw_samples: int = 1024,
q: int = 1,
) -> tuple[torch.Tensor, float]:
) -> tuple[torch.Tensor, nx.Graph, float]:
"""Optimize an acquisition function with graph sampling.
This function optimizes the acquisition function by sampling graphs from the training
set, temporarily updating the kernel's graph lookup, and evaluating the acquisition
function for each sampled graph. The best candidate and its corresponding acquisition
score are returned.
function for each sampled graph. The best candidate, the best graph, and its
corresponding acquisition score are returned.
Args:
acq_function (AcquisitionFunction): The acquisition function to optimize.
Expand All @@ -47,32 +47,31 @@ def optimize_acqf_graph(
q (int): The number of candidates to generate. Defaults to 1.
Returns:
tuple[torch.Tensor, float]: A tuple containing the best candidate (as a tensor)
and its corresponding acquisition score.
tuple[torch.Tensor, nx.Graph, float]: A tuple containing the best candidate
(as a tensor), the best graph, and its corresponding acquisition score.
Raises:
ValueError: If `train_graphs` is None.
"""
if train_graphs is None:
raise ValueError("train_graphs cannot be None.")

# Sample graphs from the training set
sampled_graphs = sample_graphs(train_graphs, num_samples=num_graph_samples)

# Initialize lists to store the best candidates and their scores
best_candidates, best_scores = [], []
best_candidates, best_graphs, best_scores = [], [], []

# Get the index of the graph feature in the bounds
graph_idx = bounds.shape[1] - 1

# Iterate through each sampled graph
# TODO: Instead of iterating over the graphs one at a time, put all
# sampled graphs into the kernel and compute the scores in a single batch.
# Update the caching logic accordingly.
for graph in sampled_graphs:
# Temporarily set the graph lookup for the kernel
with set_graph_lookup(acq_function.model.covar_module, [graph], append=True):
# Iterate through each fixed feature configuration (if provided)
for fixed_features in fixed_features_list or [{}]:
# Add the graph index to the fixed features, indicating that the last
# graphin the lookup should be used
# graph in the lookup should be used
updated_fixed_features = {**fixed_features, graph_idx: -1.0}

# Optimize the acquisition function with the updated fixed features
Expand All @@ -85,12 +84,17 @@ def optimize_acqf_graph(
q=q,
)

# Store the candidates and their scores
# Store the candidates, graphs, and their scores
best_candidates.append(candidates)
best_graphs.append(graph)
best_scores.append(scores)

# Find the index of the best score
best_idx = torch.argmax(torch.tensor(best_scores))

# Return the best candidate and its score
return best_candidates[best_idx], best_scores[best_idx].item()
# Return the best candidate (without the graph index), the best graph, and its score
return (
best_candidates[best_idx][:, :-1],
best_graphs[best_idx],
best_scores[best_idx].item()
)
48 changes: 40 additions & 8 deletions tests/test_graphs/test_optimization_over_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_acquisition_function_optimization(self, setup_data: dict) -> None:
product(*cats_per_column.values())]

# Optimize the acquisition function
best_candidate, best_score = optimize_acqf_graph(
best_candidate, best_graph, best_score = optimize_acqf_graph(
acq_function=acq_function,
bounds=bounds,
fixed_features_list=fixed_cats,
Expand All @@ -179,9 +179,17 @@ def test_acquisition_function_optimization(self, setup_data: dict) -> None:
q=1,
)

# Basic checks
assert best_candidate.shape == (1, train_x.shape[1])
assert isinstance(best_score, float)
# Assertions for the acquisition function optimization
assert isinstance(best_candidate,
torch.Tensor), "Best candidate should be a tensor"
assert best_candidate.shape == (1, train_x.shape[1] - 1), \
"Best candidate should have the correct shape (excluding the graph index)"
assert isinstance(best_graph, nx.Graph), "Best graph should be a NetworkX graph"
assert isinstance(best_score, float), "Best score should be a float"

# Ensure the best candidate does not contain the graph index column
assert best_candidate.shape[1] == train_x.shape[1] - 1, \
"Best candidate should not include the graph index column"

def test_graph_sampling(self, setup_data: dict) -> None:
"""Test the graph sampling functionality."""
Expand All @@ -192,10 +200,34 @@ def test_graph_sampling(self, setup_data: dict) -> None:
sampled_graphs = sample_graphs(train_graphs, num_samples=num_samples)

# Basic checks
assert len(sampled_graphs) == num_samples
for graph in sampled_graphs:
assert isinstance(graph, nx.Graph)
assert nx.is_connected(graph)
assert len(sampled_graphs) == num_samples, \
f"Expected {num_samples} sampled graphs, got {len(sampled_graphs)}"
assert all(isinstance(graph, nx.Graph) for graph in sampled_graphs), \
"All sampled graphs should be NetworkX graphs"
assert all(nx.is_connected(graph) for graph in sampled_graphs), \
"All sampled graphs should be connected"

def test_min_max_scaling(self, setup_data: dict) -> None:
    """Verify that ``min_max_scale`` rescales ``train_x`` column by column.

    The scaled tensor must stay within [0, 1], keep the input's shape, and
    agree with ``(x - min) / (max - min)`` for every non-constant column.
    """
    features = setup_data["train_x"]
    rescaled = min_max_scale(features)

    # Range and shape invariants of the scaled tensor.
    assert torch.all(rescaled >= 0), "Scaled values should be >= 0"
    assert torch.all(rescaled <= 1), "Scaled values should be <= 1"
    assert rescaled.shape == features.shape, (
        "Scaled data should have the same shape as the input data"
    )

    # Column-wise comparison against the closed-form expression.
    n_columns = features.shape[1]
    for i in range(n_columns):
        column = features[:, i]
        lowest, highest = column.min(), column.max()
        if lowest == highest:
            # A constant column would divide by zero; nothing to compare.
            continue
        reference = (column - lowest) / (highest - lowest)
        assert torch.allclose(rescaled[:, i], reference), (
            f"Scaling is incorrect for column {i}"
        )

def test_set_graph_lookup(self, setup_data: dict) -> None:
"""Test the set_graph_lookup context manager."""
Expand Down

0 comments on commit e94bbe8

Please sign in to comment.