select which wells to include in fit
ziw-liu committed Dec 11, 2024
1 parent bea9413 commit a23efe9
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion viscy/data/gpu_aug.py
@@ -197,6 +197,8 @@ class CachedOmeZarrDataModule(GPUTransformDataModule):
Use page-locked memory in data-loaders, by default True
skip_cache : bool, optional
Skip caching for this dataset, by default False
include_wells : list[str], optional
List of well names to include in the dataset, by default None (all)
"""

def __init__(
@@ -212,6 +214,7 @@ def __init__(
val_gpu_transforms: list[DictTransform],
pin_memory: bool = True,
skip_cache: bool = False,
include_wells: list[str] | None = None,
):
super().__init__()
self.data_path = data_path
@@ -225,6 +228,7 @@ def __init__(
self._val_gpu_transforms = Compose(val_gpu_transforms)
self.pin_memory = pin_memory
self.skip_cache = skip_cache
self._include_wells = include_wells

@property
def train_cpu_transforms(self) -> Compose:
@@ -248,12 +252,30 @@ def _set_fit_global_state(self, num_positions: int) -> list[int]:
# shuffle positions, randomness is handled globally
return torch.randperm(num_positions).tolist()

def _include_well_name(self, name: str) -> bool:
if self._include_wells is None:
return True
else:
return name in self._include_wells

def _filter_fit_fovs(self, plate: Plate) -> list[Position]:
positions = []
for well_name, well in plate.wells():
if self._include_well_name(well_name):
for _, p in well.positions():
positions.append(p)
if len(positions) < 2:
raise ValueError(
"At least 2 FOVs are required for training and validation."
)
return positions

def setup(self, stage: Literal["fit", "validate"]) -> None:
if stage not in ("fit", "validate"):
raise NotImplementedError("Only fit and validate stages are supported.")
cache_map = Manager().dict()
plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs")
positions = [p for _, p in plate.positions()]
positions = self._filter_fit_fovs(plate)
shuffled_indices = self._set_fit_global_state(len(positions))
num_train_fovs = int(len(positions) * self.split_ratio)
train_fovs = [positions[i] for i in shuffled_indices[:num_train_fovs]]
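Note: the well filter introduced by this commit can be sketched outside the data module, assuming a hypothetical plate path and hypothetical well names; it relies only on the iohub calls already visible in the diff (open_ome_zarr, Plate.wells(), Well.positions()). This is a minimal illustration, not the data module itself.

    # Minimal sketch of the well-filtering behavior added in this commit.
    # The plate path and well names are hypothetical placeholders.
    from iohub import open_ome_zarr

    include_wells = ["A/1", "B/2"]  # hypothetical; None would select all wells

    plate = open_ome_zarr("plate.zarr", mode="r", layout="hcs")  # hypothetical path
    positions = []
    for well_name, well in plate.wells():
        # same predicate as CachedOmeZarrDataModule._include_well_name
        if include_wells is None or well_name in include_wells:
            for _, position in well.positions():
                positions.append(position)

    if len(positions) < 2:
        # the data module enforces this so that the training and
        # validation splits each receive at least one FOV
        raise ValueError("At least 2 FOVs are required for training and validation.")

    print(f"{len(positions)} FOVs selected for fitting")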
