Improvement of qBayesianActiveLearningByDisagreement (#2457)
Summary:
Pull Request resolved: #2457

Improvement of the implementation of qBayesianActiveLearningByDisagreement
- Utilizes a Monte Carlo approach for approximating the entropy
- Does not use concatenate_pending_points, as it is not evident that fantasizing makes sense in the same way as for standard MC acquisition functions
- Can accept posterior transforms

- get_model and get_fully_bayesian_model are used in the tests to mirror other tests (e.g., JES and the subsequent active learning acqfs) and to enable a move to test_helpers
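
A minimal usage sketch of the updated acquisition function (not part of this diff; the data, fitting settings, and sample counts below are illustrative assumptions):

```python
import torch

from botorch.acquisition.bayesian_active_learning import (
    qBayesianActiveLearningByDisagreement,
)
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.sampling.normal import IIDNormalSampler

# Illustrative data and settings (assumptions, not taken from this commit).
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 1, dtype=torch.double)

# BALD needs a fully Bayesian model: MCMC yields an ensemble of GPs whose
# disagreement drives the acquisition value.
model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)

# New in this change: an optional MC sampler (used for the mixture-entropy
# estimate) and an optional posterior transform can be passed in.
acqf = qBayesianActiveLearningByDisagreement(
    model=model,
    sampler=IIDNormalSampler(sample_shape=torch.Size([64])),
)

# Evaluate on a t-batch of q=1 candidates; the output shape follows the
# t-batch shape, per the updated test expectations.
candidates = torch.rand(5, 1, 2, dtype=torch.double)
acq_values = acqf(candidates)
```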

Reviewed By: saitcakmak

Differential Revision: D60308502

fbshipit-source-id: 6de1dffc4f497ef4823428b2903b19ff8f0d60d7
Carl Hvarfner authored and facebook-github-bot committed Aug 1, 2024
1 parent 9ddd9eb commit e44280e
Showing 3 changed files with 117 additions and 52 deletions.
77 changes: 51 additions & 26 deletions botorch/acquisition/bayesian_active_learning.py
@@ -22,10 +22,11 @@

from typing import Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor

@@ -54,48 +55,72 @@ class qBayesianActiveLearningByDisagreement(
def __init__(
self,
model: SaasFullyBayesianSingleTaskGP,
sampler: Optional[MCSampler] = None,
posterior_transform: Optional[PosteriorTransform] = None,
X_pending: Optional[Tensor] = None,
) -> None:
"""
Batch implementation [kirsch2019batchbald]_ of BALD [Houlsby2011bald]_,
which maximizes the mutual information between the next observation and the
hyperparameters of the model. Computed by informational lower bound.
hyperparameters of the model. Computed by Monte Carlo integration.
Args:
model: A fully bayesian single-outcome model.
X_pending: A `batch_shape, m x d`-dim Tensor of `m` design points.
model: A fully bayesian model (SaasFullyBayesianSingleTaskGP).
sampler: The sampler used for drawing samples to approximate the entropy
of the Gaussian Mixture posterior.
posterior_transform: A PosteriorTransform. If using a multi-output model,
a PosteriorTransform that transforms the multi-output posterior into a
single-output posterior is required.
X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points
"""
super().__init__(model)
super().__init__(model=model)
MCSamplerMixin.__init__(self, sampler=sampler)
self.set_X_pending(X_pending)
self.posterior_transform = posterior_transform

@concatenate_pending_points
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate qBayesianActiveLearningByDisagreement on the candidate set `X`.
A Monte Carlo-estimated information gain is computed over the Gaussian
mixture marginal posterior and the Gaussian conditional posteriors to
obtain the qBayesianActiveLearningByDisagreement value on the candidate set `X`.
Args:
X: `batch_shape x q x D`-dim Tensor of input points.
Returns:
A `batch_shape x num_models`-dim Tensor of BALD values.
"""
return self._compute_lower_bound_information_gain(X)

def _compute_lower_bound_information_gain(self, X: Tensor) -> Tensor:
r"""Evaluates the lower bounded information gain on the candidate set `X`.
Args:
X: `batch_shape x q x D`-dim Tensor of input points.
Returns:
A `batch_shape x num_models`-dim Tensor of information gains.
"""
posterior = self.model.posterior(X, observation_noise=True)
marg_covar = posterior.mixture_covariance_matrix
cond_variances = posterior.variance

prev_entropy = torch.logdet(marg_covar).unsqueeze(-1)
# squeeze excess dim and mean over q-batch
post_ub_entropy = torch.log(cond_variances).squeeze(-1).mean(-1)

return prev_entropy - post_ub_entropy
posterior = self.model.posterior(
X, observation_noise=True, posterior_transform=self.posterior_transform
)
# draw samples from the mixture posterior.
# samples: num_samples x batch_shape x num_models x q x num_outputs
samples = self.get_posterior_samples(posterior=posterior)

# Estimate the entropy of 'num_samples' samples from 'num_models' models by
# evaluating the log_prob of each sample under the mixture posterior
# (which consists of M models); thus, on the order of N*M^2 computations.

# Make room and move the model dim to the front, squeeze the num_outputs dim.
# prev_samples: num_models x num_samples x batch_shape x 1 x q
prev_samples = samples.unsqueeze(0).transpose(0, MCMC_DIM).squeeze(-1)

# avg the probs over models in the mixture - dim (-2) will be broadcasted
# with the num_models of the posterior --> querying all samples on all models
# posterior.mvn takes q-dimensional input by default, which removes the q-dim
# component_sample_probs: num_models x num_samples x batch_shape x num_models
component_sample_probs = posterior.mvn.log_prob(prev_samples).exp()

# average over mixture components
mixture_sample_probs = component_sample_probs.mean(dim=-1)

# this is the average over the model and sample dim
prev_entropy = -mixture_sample_probs.log().mean(dim=[0, 1])

# the posterior entropy is an average entropy over gaussians, so no mixture
post_entropy = -posterior.mvn.log_prob(samples.squeeze(-1)).mean(0)
bald = prev_entropy.unsqueeze(-1) - post_entropy
return bald
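
For intuition, a self-contained sketch of the Monte Carlo BALD estimate that the new `forward` implements (not part of this commit; it uses plain `torch.distributions` Normals with made-up per-model means and stddevs in place of the model's Gaussian-mixture posterior): the marginal entropy of the equally weighted mixture over MCMC models is estimated from samples, and the conditional entropy is the average entropy of the individual per-model Gaussians.

```python
import torch
from torch.distributions import Normal

# Hypothetical per-model marginal means / stddevs at one candidate point
# (made-up numbers standing in for the num_models MCMC ensemble members).
means = torch.tensor([0.10, -0.30, 0.40])
stds = torch.tensor([0.50, 0.70, 0.60])
components = Normal(means, stds)  # batch of num_models = 3 Gaussians

# Draw samples from each component; together they are samples from the
# equally weighted Gaussian mixture posterior.
num_samples = 2048
samples = components.sample(torch.Size([num_samples]))  # num_samples x num_models

# Marginal (mixture) entropy: -E[log p_mix(y)], with the mixture density
# obtained by averaging the component densities over all models. Every
# sample is therefore scored under every component: O(N * M^2) work.
component_probs = components.log_prob(samples.unsqueeze(-1)).exp()  # N x M x M
mixture_log_probs = component_probs.mean(dim=-1).log()  # N x M
prev_entropy = -mixture_log_probs.mean()

# Conditional entropy: the average entropy of the individual Gaussians,
# estimated from each model's own samples.
post_entropy = -components.log_prob(samples).mean()

# BALD = mutual information between the next observation and the model
# hyperparameters (marginal minus conditional entropy).
bald = prev_entropy - post_entropy
print(bald)
```
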
4 changes: 4 additions & 0 deletions botorch/acquisition/input_constructors.py
@@ -1678,9 +1678,13 @@ def construct_inputs_qJES(
def construct_inputs_BALD(
model: Model,
X_pending: Optional[Tensor] = None,
sampler: Optional[MCSampler] = None,
posterior_transform: Optional[PosteriorTransform] = None,
):
inputs = {
"model": model,
"X_pending": X_pending,
"sampler": sampler,
"posterior_transform": posterior_transform,
}
return inputs
88 changes: 62 additions & 26 deletions test/acquisition/test_bayesian_active_learning.py
@@ -13,9 +13,32 @@
from botorch.models import SingleTaskGP
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.transforms.outcome import Standardize
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.testing import BotorchTestCase


def get_model(
train_X,
train_Y,
standardize_model,
**tkwargs,
):
num_objectives = train_Y.shape[-1]

if standardize_model:
outcome_transform = Standardize(m=num_objectives)
else:
outcome_transform = None

model = SingleTaskGP(
train_X=train_X,
train_Y=train_Y,
outcome_transform=outcome_transform,
)

return model


def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):

mcmc_samples = {
@@ -28,7 +51,7 @@ def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
return mcmc_samples


def get_model(
def get_fully_bayesian_model(
train_X,
train_Y,
num_models,
@@ -72,21 +95,26 @@ def test_q_bayesian_active_learning_by_disagreement(self):
tkwargs = {"device": self.device}
num_objectives = 1
num_models = 3
input_dim = 2

X_pending_list = [None, torch.rand(2, input_dim)]
for (
dtype,
standardize_model,
infer_noise,
X_pending,
) in product(
(torch.float, torch.double),
(False, True), # standardize_model
(True,), # infer_noise - only one option avail in PyroModels
X_pending_list,
):
X_pending = X_pending.to(**tkwargs) if X_pending is not None else None
tkwargs["dtype"] = dtype
input_dim = 2
train_X = torch.rand(4, input_dim, **tkwargs)
train_Y = torch.rand(4, num_objectives, **tkwargs)

model = get_model(
model = get_fully_bayesian_model(
train_X,
train_Y,
num_models,
@@ -96,32 +124,40 @@ def test_q_bayesian_active_learning_by_disagreement(self):
)

# test acquisition
X_pending_list = [None, torch.rand(2, input_dim, **tkwargs)]
for i in range(len(X_pending_list)):
X_pending = X_pending_list[i]

acq = qBayesianActiveLearningByDisagreement(
model=model,
X_pending=X_pending,
)

test_Xs = [
torch.rand(4, 1, input_dim, **tkwargs),
torch.rand(4, 3, input_dim, **tkwargs),
torch.rand(4, 5, 1, input_dim, **tkwargs),
torch.rand(4, 5, 3, input_dim, **tkwargs),
]

for j in range(len(test_Xs)):
acq_X = acq.forward(test_Xs[j])
acq_X = acq(test_Xs[j])
# assess shape
self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])
acq = qBayesianActiveLearningByDisagreement(
model=model,
X_pending=X_pending,
)

acq2 = qBayesianActiveLearningByDisagreement(
model=model, sampler=IIDNormalSampler(torch.Size([9]))
)
self.assertIsInstance(acq2.sampler, IIDNormalSampler)

test_Xs = [
torch.rand(4, 1, input_dim, **tkwargs),
torch.rand(4, 3, input_dim, **tkwargs),
torch.rand(4, 5, 1, input_dim, **tkwargs),
torch.rand(4, 5, 3, input_dim, **tkwargs),
torch.rand(5, 13, input_dim, **tkwargs),
]

for j in range(len(test_Xs)):
acq_X = acq.forward(test_Xs[j])
acq_X = acq(test_Xs[j])
# assess shape
self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])

self.assertTrue(torch.all(acq_X > 0))

# Support for non-fully Bayesian models is not possible; thus, we
# raise an error.
non_fully_bayesian_model = SingleTaskGP(train_X, train_Y)
with self.assertRaises(ValueError):
non_fully_bayesian_model = get_model(train_X, train_Y, False)
with self.assertRaisesRegex(
ValueError,
"Fully Bayesian acquisition functions require a "
"SaasFullyBayesianSingleTaskGP to run.",
):
acq = qBayesianActiveLearningByDisagreement(
model=non_fully_bayesian_model,
)
