Commit dadd467
Touchups and import fixes
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 18, 2024
1 parent aa1f6db commit dadd467
Showing 4 changed files with 43 additions and 60 deletions.
2 changes: 2 additions & 0 deletions project/datamodules/__init__.py
@@ -3,13 +3,15 @@
 from .image_classification.fashion_mnist import FashionMNISTDataModule
 from .image_classification.imagenet import ImageNetDataModule
 from .image_classification.imagenet32 import ImageNet32DataModule, imagenet32_normalization
+from .image_classification.inaturalist import INaturalistDataModule
 from .image_classification.mnist import MNISTDataModule
 from .vision.base import VisionDataModule

 __all__ = [
     "cifar10_normalization",
     "CIFAR10DataModule",
     "FashionMNISTDataModule",
+    "INaturalistDataModule",
     "ImageClassificationDataModule",
     "imagenet32_normalization",
     "ImageNet32DataModule",
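
With the re-export in place, callers can pull the new datamodule from the package root instead of the deep module path. A minimal usage sketch (assuming the project is installed):

    # Both imports now resolve to the same class:
    from project.datamodules import INaturalistDataModule
    # from project.datamodules.image_classification.inaturalist import INaturalistDataModule

    # The matching __all__ entry also exposes the class to star-imports:
    # from project.datamodules import *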
9 changes: 6 additions & 3 deletions project/datamodules/image_classification/imagenet32.py
@@ -16,6 +16,7 @@
 from torchvision.datasets import VisionDataset
 from torchvision.transforms import v2 as transforms

+from project.utils.env_vars import SCRATCH
 from project.utils.types import C, H, StageStr, W

 from ..vision.base import VisionDataModule
@@ -177,10 +178,10 @@ class ImageNet32DataModule(VisionDataModule):
     def __init__(
         self,
         data_dir: str | Path,
-        readonly_datasets_dir: str | Path | None = None,
+        readonly_datasets_dir: str | Path | None = SCRATCH,
         val_split: int | float = -1,
         num_images_per_val_class: int | None = 50,
-        num_workers: int | None = 0,
+        num_workers: int = 0,
         normalize: bool = False,
         batch_size: int = 32,
         seed: int = 42,
Expand Down Expand Up @@ -221,7 +222,9 @@ def __init__(

# ImageNetDataModule uses num_imgs_per_val_class: int = 50, which makes sense! Here
# however we're using probably more than that for validation.
self.EXTRA_ARGS["readonly_datasets_dir"] = readonly_datasets_dir
self.train_kwargs["readonly_datasets_dir"] = readonly_datasets_dir
self.valid_kwargs["readonly_datasets_dir"] = readonly_datasets_dir
self.test_kwargs["readonly_datasets_dir"] = readonly_datasets_dir
self.dataset_train: ImageNet32Dataset | Subset
self.dataset_val: ImageNet32Dataset | Subset
self.dataset_test: ImageNet32Dataset | Subset
Expand Down
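
With this change, readonly_datasets_dir falls back to the cluster's scratch space instead of None. The definition of SCRATCH is not part of this diff; a plausible sketch of what project.utils.env_vars likely does (an assumption, not the actual source):

    import os
    from pathlib import Path

    # Assumed: SCRATCH is a Path when the $SCRATCH environment variable is set
    # (as on SLURM clusters such as Mila's) and None otherwise, so the new
    # default degrades gracefully on local machines.
    SCRATCH: Path | None = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None

Note also that num_workers drops its | None, so callers must now pass a plain int.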
64 changes: 23 additions & 41 deletions project/datamodules/vision/base.py
@@ -84,31 +84,19 @@ def __init__(
         self.test_transforms = test_transforms
         self.EXTRA_ARGS = kwargs

-        self.train_kwargs = self.EXTRA_ARGS | {
-            "transform": self.train_transforms or self.default_transforms()
-        }
-        self.valid_kwargs = self.EXTRA_ARGS | {
-            "transform": self.val_transforms or self.default_transforms()
-        }
-        self.test_kwargs = self.EXTRA_ARGS | {
-            "transform": self.test_transforms or self.default_transforms()
-        }
+        self.train_kwargs: dict = self.EXTRA_ARGS
+        self.valid_kwargs: dict = self.EXTRA_ARGS
+        self.test_kwargs: dict = self.EXTRA_ARGS
         if _has_constructor_argument(self.dataset_cls, "train"):
             self.train_kwargs["train"] = True
             self.valid_kwargs["train"] = True
             self.test_kwargs["train"] = False

         # todo: what about the shuffling at each epoch?
         _rng = torch.Generator(device="cpu").manual_seed(self.seed)
-        self.train_dl_rng_seed = int(
-            torch.randint(0, int(1e6), (1,), generator=_rng).item()
-        )
-        self.val_dl_rng_seed = int(
-            torch.randint(0, int(1e6), (1,), generator=_rng).item()
-        )
-        self.test_dl_rng_seed = int(
-            torch.randint(0, int(1e6), (1,), generator=_rng).item()
-        )
+        self.train_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())
+        self.val_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())
+        self.test_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item())

         self.test_dataset_cls = self.dataset_cls

@@ -135,6 +123,16 @@ def prepare_data(self) -> None:
         )
         self.test_dataset_cls(str(self.data_dir), **test_kwargs)

+        self.train_kwargs = self.EXTRA_ARGS | {
+            "transform": self.train_transforms or self.default_transforms()
+        }
+        self.valid_kwargs = self.EXTRA_ARGS | {
+            "transform": self.val_transforms or self.default_transforms()
+        }
+        self.test_kwargs = self.EXTRA_ARGS | {
+            "transform": self.test_transforms or self.default_transforms()
+        }
+
     def setup(self, stage: StageStr | None = None) -> None:
         """Creates train, val, and test dataset."""
         if stage in ["fit", "validate"] or stage is None:
@@ -156,9 +154,7 @@ def setup(self, stage: StageStr | None = None) -> None:

         if stage == "test" or stage is None:
             logger.debug(f"creating test dataset with kwargs {self.train_kwargs}")
-            self.dataset_test = self.test_dataset_cls(
-                str(self.data_dir), **self.test_kwargs
-            )
+            self.dataset_test = self.test_dataset_cls(str(self.data_dir), **self.test_kwargs)

     def _split_dataset(self, dataset: VisionDataset, train: bool = True) -> Dataset:
         """Splits the dataset into train and validation set."""
@@ -190,9 +186,7 @@ def _get_splits(self, len_dataset: int) -> list[int]:
     def default_transforms(self) -> Callable:
         """Default transform for the dataset."""

-    def train_dataloader[
-        **P
-    ](
+    def train_dataloader[**P](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
@@ -212,9 +206,7 @@ def train_dataloader[
             ),
         )

-    def val_dataloader[
-        **P
-    ](
+    def val_dataloader[**P](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
@@ -226,15 +218,10 @@ def val_dataloader[
             self.dataset_val,
             _dataloader_fn=_dataloader_fn,
             *args,
-            **(
-                dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed))
-                | kwargs
-            ),
+            **(dict(generator=torch.Generator().manual_seed(self.val_dl_rng_seed)) | kwargs),
         )

-    def test_dataloader[
-        **P
-    ](
+    def test_dataloader[**P](
         self,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
         *args: P.args,
@@ -248,15 +235,10 @@ def test_dataloader[
             self.dataset_test,
             _dataloader_fn=_dataloader_fn,
             *args,
-            **(
-                dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed))
-                | kwargs
-            ),
+            **(dict(generator=torch.Generator().manual_seed(self.test_dl_rng_seed)) | kwargs),
         )

-    def _data_loader[
-        **P
-    ](
+    def _data_loader[**P](
         self,
         dataset: Dataset,
         _dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
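
One subtlety in this refactor: in __init__, train_kwargs, valid_kwargs, and test_kwargs are now all bound to the same self.EXTRA_ARGS dict, so the per-split "train" assignments that follow mutate one shared mapping; the dict merges moved into prepare_data then rebuild them as fresh dicts. A minimal sketch of both behaviors, with hypothetical names:

    # Aliasing: three names, one dict.
    extra_args: dict = {"download": True}
    train_kwargs = extra_args
    test_kwargs = extra_args
    train_kwargs["train"] = True
    test_kwargs["train"] = False
    print(train_kwargs["train"])  # False: the last write wins for every alias.

    # The `|` merge used in prepare_data creates a *new* dict instead, with the
    # right-hand operand taking precedence on key collisions:
    train_kwargs = extra_args | {"transform": "example_transform"}
    assert train_kwargs is not extra_args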
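
The dataloader methods also collapse to the single-line PEP 695 generic syntax, def train_dataloader[**P](...), which Python 3.12+ parses natively. The pattern forwards arbitrary DataLoader arguments while injecting a default, seeded torch.Generator so shuffling is reproducible unless the caller overrides it. A standalone sketch of the same pattern (names here are illustrative, not the project's):

    from collections.abc import Callable
    from typing import Concatenate

    import torch
    from torch.utils.data import DataLoader, Dataset, TensorDataset

    def make_train_dataloader[**P](
        dataset: Dataset,
        dataloader_fn: Callable[Concatenate[Dataset, P], DataLoader] = DataLoader,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> DataLoader:
        # Callers can pass their own generator; otherwise a fixed seed is used,
        # mirroring the `dict(...) | kwargs` precedence in the diff above.
        default = dict(generator=torch.Generator().manual_seed(42))
        return dataloader_fn(dataset, *args, **(default | kwargs))

    ds = TensorDataset(torch.arange(10).unsqueeze(1))
    dl = make_train_dataloader(ds, batch_size=4, shuffle=True)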
28 changes: 12 additions & 16 deletions project/utils/testutils.py
@@ -24,8 +24,7 @@
 from torch import Tensor, nn
 from torch.optim import Optimizer

-from project.configs import Config, cs
-from project.configs.datamodule import DATA_DIR
+from project.configs import Config
 from project.datamodules.image_classification import (
     ImageClassificationDataModule,
 )
@@ -47,7 +46,7 @@
     "inaturalist": [
         pytest.mark.slow,
         pytest.mark.xfail(
-            not Path("/network/datasets/inat").exists(),
+            not (NETWORK_DIR and (NETWORK_DIR / "datasets/inat").exists()),
             strict=True,
             raises=hydra.errors.InstantiationException,
             reason="Expects to be run on the Mila cluster for now",
@@ -62,17 +61,6 @@
             reason="Expects to be run on a cluster with the ImageNet dataset.",
         ),
     ],
-    "rl": [
-        pytest.mark.xfail(
-            strict=False,
-            raises=AssertionError,
-            # match="Shapes are not the same."
-            reason="Isn't entirely deterministic yet.",
-        ),
-    ],
-    "moving_mnist": [
-        (pytest.mark.slow if not (DATA_DIR / "MovingMNIST").exists() else pytest.mark.timeout(5))
-    ],
 }
 """Dict with some default marks for some configs name."""

@@ -169,11 +157,15 @@ def get_all_algorithm_names() -> list[str]:
     return get_all_configs_in_group("algorithm")


-def get_type_for_config_name(config_group: str, config_name: str, _cs: ConfigStore = cs) -> type:
+def get_type_for_config_name(
+    config_group: str, config_name: str, _cs: ConfigStore | None = None
+) -> type:
     """Returns the class that is to be instantiated by the given config name.

     In the case of inner dataclasses (e.g. Model.HParams), this returns the outer class (Model).
     """
+    if _cs is None:
+        from project.configs import cs as _cs

     config_loader = get_config_loader()
     _, caching_repo = config_loader._parse_overrides_and_create_caching_repo(
@@ -288,7 +280,11 @@ def test_network_output_is_reproducible(network: nn.Module, x: Tensor):

 def get_all_datamodule_names() -> list[str]:
     """Retrieves the names of all the datamodules that are saved in the ConfigStore of Hydra."""
-    return get_all_configs_in_group("datamodule")
+    datamodules = get_all_configs_in_group("datamodule")
+    # todo: automatically detect which ones are configs for ABCs and remove them?
+    if "vision" in datamodules:
+        datamodules.remove("vision")
+    return datamodules


 def get_all_datamodule_names_params():
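
The get_type_for_config_name change swaps a module-level default (_cs: ConfigStore = cs, evaluated when testutils is imported) for a lazy import inside the function body, so project.configs is only imported when no store is passed explicitly. A minimal, self-contained sketch of this deferred-default pattern (toy names, not the project's API):

    def get_answer(config: dict | None = None) -> int:
        if config is None:
            # Stand-in for `from project.configs import cs as _cs`: the
            # potentially expensive import/lookup happens only at call time.
            config = {"answer": 42}
        return config["answer"]

    assert get_answer() == 42
    assert get_answer({"answer": 1}) == 1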
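
Similarly, the inaturalist mark now guards on NETWORK_DIR (presumably a Path | None resolved from the cluster environment, by analogy with SCRATCH) instead of probing a hard-coded path. The `and` short-circuits when NETWORK_DIR is unset, so the path is never touched off-cluster:

    from pathlib import Path

    NETWORK_DIR: Path | None = None  # assumed: None when not on the cluster
    # None short-circuits to a falsy value, so the xfail condition is True
    # whenever the dataset directory is unavailable.
    condition = not (NETWORK_DIR and (NETWORK_DIR / "datasets/inat").exists())
    assert condition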
