diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py
index a445fff5..09302aad 100644
--- a/viscy/data/gpu_aug.py
+++ b/viscy/data/gpu_aug.py
@@ -197,6 +197,8 @@ class CachedOmeZarrDataModule(GPUTransformDataModule):
         Use page-locked memory in data-loaders, by default True
     skip_cache : bool, optional
         Skip caching for this dataset, by default False
+    include_wells : list[str], optional
+        List of well names to include in the dataset, by default None (all)
     """

     def __init__(
@@ -212,6 +214,7 @@ def __init__(
         val_gpu_transforms: list[DictTransform],
         pin_memory: bool = True,
         skip_cache: bool = False,
+        include_wells: list[str] | None = None,
     ):
         super().__init__()
         self.data_path = data_path
@@ -225,6 +228,7 @@ def __init__(
         self._val_gpu_transforms = Compose(val_gpu_transforms)
         self.pin_memory = pin_memory
         self.skip_cache = skip_cache
+        self._include_wells = include_wells

     @property
     def train_cpu_transforms(self) -> Compose:
@@ -248,12 +252,30 @@ def _set_fit_global_state(self, num_positions: int) -> list[int]:
         # shuffle positions, randomness is handled globally
         return torch.randperm(num_positions).tolist()

+    def _include_well_name(self, name: str) -> bool:
+        if self._include_wells is None:
+            return True
+        else:
+            return name in self._include_wells
+
+    def _filter_fit_fovs(self, plate: Plate) -> list[Position]:
+        positions = []
+        for well_name, well in plate.wells():
+            if self._include_well_name(well_name):
+                for _, p in well.positions():
+                    positions.append(p)
+        if len(positions) < 2:
+            raise ValueError(
+                "At least 2 FOVs are required for training and validation."
+            )
+        return positions
+
     def setup(self, stage: Literal["fit", "validate"]) -> None:
         if stage not in ("fit", "validate"):
             raise NotImplementedError("Only fit and validate stages are supported.")
         cache_map = Manager().dict()
         plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs")
-        positions = [p for _, p in plate.positions()]
+        positions = self._filter_fit_fovs(plate)
         shuffled_indices = self._set_fit_global_state(len(positions))
         num_train_fovs = int(len(positions) * self.split_ratio)
         train_fovs = [positions[i] for i in shuffled_indices[:num_train_fovs]]
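
For context, a minimal sketch of what the new well filter selects, written against the same iohub plate API the diff already uses (`open_ome_zarr`, `plate.wells()`, `well.positions()`); the store path `plate.zarr` and the well names `A/1` and `B/2` are placeholders for illustration, not values from this change:

```python
from iohub import open_ome_zarr

# Hypothetical HCS store and well names, for illustration only.
include_wells = ["A/1", "B/2"]

with open_ome_zarr("plate.zarr", mode="r", layout="hcs") as plate:
    # Mirrors the filtering added in this diff: keep only positions (FOVs)
    # whose parent well name is listed; include_wells=None keeps every FOV,
    # matching the default behavior before this change.
    positions = [
        pos
        for well_name, well in plate.wells()
        if well_name in include_wells
        for _, pos in well.positions()
    ]

print(f"Selected {len(positions)} FOVs from wells {include_wells}")
```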