From 7f05f05b06ea3a2787c1fee6aa418b470a391bd4 Mon Sep 17 00:00:00 2001 From: Bernie Beckerman Date: Thu, 24 Oct 2024 18:29:12 -0700 Subject: [PATCH] introduce trial_indices argument to SupervisedDataset (#2595) Summary: X-link: https://github.com/facebook/Ax/pull/2960 Adds optional `trial_indices` to SupervisedDataset, whose dimensionality should correspond 1:1 with the first few dimensions of X and Y tensors, as validated in `_validate` ([pointer](https://www.internalfb.com/diff/D64764019?permalink=1739375523489084)). Reviewed By: Balandat Differential Revision: D64764019 --- botorch/utils/datasets.py | 29 ++++++++++++++++++++++++++++- test/utils/test_datasets.py | 22 +++++++++++++++++++--- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/botorch/utils/datasets.py b/botorch/utils/datasets.py index f11f5c80e7..e6de660f03 100644 --- a/botorch/utils/datasets.py +++ b/botorch/utils/datasets.py @@ -14,6 +14,7 @@ import torch from botorch.exceptions.errors import InputDataError, UnsupportedError from botorch.utils.containers import BotorchContainer, SliceContainer +from pyre_extensions import none_throws from torch import long, ones, Tensor @@ -54,6 +55,7 @@ def __init__( outcome_names: list[str], Yvar: BotorchContainer | Tensor | None = None, validate_init: bool = True, + trial_indices: Tensor | None = None, ) -> None: r"""Constructs a `SupervisedDataset`. @@ -65,12 +67,16 @@ def __init__( Yvar: An optional `Tensor` or `BotorchContainer` representing the observation noise. validate_init: If `True`, validates the input shapes. + trial_indices: A `Tensor` representing the trial indices of X and Y. This is + used to support learning-curve-based modeling. If provided, it must + have compatible shape with X and Y. """ self._X = X self._Y = Y self._Yvar = Yvar self.feature_names = feature_names self.outcome_names = outcome_names + self.trial_indices = trial_indices if validate_init: self._validate() @@ -96,6 +102,7 @@ def _validate( self, validate_feature_names: bool = True, validate_outcome_names: bool = True, + validate_trial_indices: bool = True, ) -> None: r"""Checks that the shapes of the inputs are compatible with each other. @@ -108,6 +115,8 @@ def _validate( `outcomes_names` matches the # of columns of `self.Y`. If a particular dataset, e.g., `RankingDataset`, is known to violate this assumption, this can be set to `False`. + validate_trial_indices: By default, we validate that the shape of + `trial_indices` matches the shape of X and Y. """ shape_X = self.X.shape if isinstance(self._X, BotorchContainer): @@ -133,8 +142,19 @@ def _validate( "`Y` must have the same number of columns as the number of " "outcomes in `outcome_names`." ) + if validate_trial_indices and self.trial_indices is not None: + if self.trial_indices.shape != shape_X: + raise ValueError( + f"{shape_X=} must have the same shape as {none_throws(self.trial_indices).shape=}." + ) def __eq__(self, other: Any) -> bool: + if self.trial_indices is None and other.trial_indices is None: + trial_indices_equal = True + elif self.trial_indices is None or other.trial_indices is None: + trial_indices_equal = False + else: + trial_indices_equal = torch.equal(self.trial_indices, other.trial_indices) return ( type(other) is type(self) and torch.equal(self.X, other.X) @@ -146,6 +166,7 @@ def __eq__(self, other: Any) -> bool: ) and self.feature_names == other.feature_names and self.outcome_names == other.outcome_names + and trial_indices_equal ) @@ -241,7 +262,11 @@ def __init__( ) def _validate(self) -> None: - super()._validate(validate_feature_names=False, validate_outcome_names=False) + super()._validate( + validate_feature_names=False, + validate_outcome_names=False, + validate_trial_indices=False, + ) if len(self.feature_names) != self._X.values.shape[-1]: raise ValueError( "The `values` field of `X` must have the same number of columns as " @@ -316,6 +341,7 @@ def __init__( self.has_heterogeneous_features = any( datasets[0].feature_names != ds.feature_names for ds in datasets[1:] ) + self.trial_indices = None @classmethod def from_joint_dataset( @@ -538,6 +564,7 @@ def __init__( c: [self.feature_names.index(i) for i in parameter_decomposition[c]] for c in self.context_buckets } + self.trial_indices = None @property def X(self) -> Tensor: diff --git a/test/utils/test_datasets.py b/test/utils/test_datasets.py index 22d8c24a50..586897fdf0 100644 --- a/test/utils/test_datasets.py +++ b/test/utils/test_datasets.py @@ -43,14 +43,20 @@ def make_dataset( class TestDatasets(BotorchTestCase): def test_supervised(self): # Generate some data - X = rand(3, 2) - Y = rand(3, 1) + n_rows = 3 + X = rand(n_rows, 2) + Y = rand(n_rows, 1) feature_names = ["x1", "x2"] outcome_names = ["y"] + trial_indices = tensor(range(n_rows)) # Test `__init__` dataset = SupervisedDataset( - X=X, Y=Y, feature_names=feature_names, outcome_names=outcome_names + X=X, + Y=Y, + feature_names=feature_names, + outcome_names=outcome_names, + trial_indices=trial_indices, ) self.assertIsInstance(dataset.X, Tensor) self.assertIsInstance(dataset._X, Tensor) @@ -58,12 +64,14 @@ def test_supervised(self): self.assertIsInstance(dataset._Y, Tensor) self.assertEqual(dataset.feature_names, feature_names) self.assertEqual(dataset.outcome_names, outcome_names) + self.assertTrue(torch.equal(dataset.trial_indices, trial_indices)) dataset2 = SupervisedDataset( X=DenseContainer(X, X.shape[-1:]), Y=DenseContainer(Y, Y.shape[-1:]), feature_names=feature_names, outcome_names=outcome_names, + trial_indices=trial_indices, ) self.assertIsInstance(dataset2.X, Tensor) self.assertIsInstance(dataset2._X, DenseContainer) @@ -101,6 +109,14 @@ def test_supervised(self): feature_names=feature_names, outcome_names=[], ) + with self.assertRaisesRegex(ValueError, "trial_indices"): + SupervisedDataset( + X=rand(2, 2), + Y=rand(2, 1), + feature_names=feature_names, + outcome_names=outcome_names, + trial_indices=tensor(range(n_rows + 1)), + ) # Test with Yvar. dataset = SupervisedDataset(