diff --git a/GANDLF/config_manager.py b/GANDLF/config_manager.py
index ed26ec8e1..c0db814b2 100644
--- a/GANDLF/config_manager.py
+++ b/GANDLF/config_manager.py
@@ -11,6 +11,7 @@
 from GANDLF.metrics import surface_distance_ids
 from importlib.metadata import version
 
+from GANDLF.utils.pydantic_config import Parameters
 
 ## dictionary to define defaults for appropriate options, which are evaluated
 parameter_defaults = {
@@ -653,6 +654,7 @@ def _parseConfig(
     if "opt" in params:
         print("DeprecationWarning: 'opt' has been superseded by 'optimizer'")
         params["optimizer"] = params["opt"]
+        params.pop("opt")
 
     # initialize defaults for patch sampler
     temp_patch_sampler_dict = {
@@ -747,7 +749,10 @@ def ConfigManager(
         dict: The parameter dictionary.
     """
     try:
-        return _parseConfig(config_file_path, version_check_flag)
+        parameters = Parameters(
+            **_parseConfig(config_file_path, version_check_flag)
+        ).model_dump()
+        return parameters
     except Exception as e:
         ## todo: ensure logging captures assertion errors
         assert (
diff --git a/GANDLF/models/imagenet_unet.py b/GANDLF/models/imagenet_unet.py
index f1a203d4a..31945db67 100644
--- a/GANDLF/models/imagenet_unet.py
+++ b/GANDLF/models/imagenet_unet.py
@@ -253,7 +253,7 @@ def __init__(self, parameters) -> None:
         )
 
         # all BatchNorm should be replaced with InstanceNorm for DP experiments
-        if "differential_privacy" in parameters:
+        if parameters["differential_privacy"] is not None:
             self.replace_batchnorm(self.model)
 
         if self.n_dimensions == 3:
diff --git a/GANDLF/utils/data_splitter.py b/GANDLF/utils/data_splitter.py
index 939c8cfca..5e76d0a06 100644
--- a/GANDLF/utils/data_splitter.py
+++ b/GANDLF/utils/data_splitter.py
@@ -22,12 +22,12 @@ def split_data(
         "nested_training" in parameters
     ), "`nested_training` key missing in parameters"
     # populate the headers
-    if "headers" not in parameters:
+    if parameters["headers"] is None:
         _, parameters["headers"] = parseTrainingCSV(full_dataset)
 
     parameters = (
         populate_header_in_parameters(parameters, parameters["headers"])
-        if "problem_type" not in parameters
+        if parameters["problem_type"] is None
         else parameters
     )
 
diff --git a/GANDLF/utils/pydantic_config.py b/GANDLF/utils/pydantic_config.py
new file mode 100644
index 000000000..ca533d98a
--- /dev/null
+++ b/GANDLF/utils/pydantic_config.py
@@ -0,0 +1,73 @@
+from pydantic import BaseModel, ConfigDict
+from typing import Dict, List, Optional, Union
+
+
+class Version(BaseModel):
+    minimum: str
+    maximum: str
+
+
+class Model(BaseModel):
+    dimension: int
+    base_filters: int
+    architecture: str
+    norm_type: str
+    final_layer: str
+    class_list: list[Union[int, str]]
+    ignore_label_validation: Union[int, None]
+    amp: bool
+    print_summary: bool
+    type: str
+    data_type: str
+    save_at_every_epoch: bool
+    num_channels: Optional[int] = None
+
+
+class Parameters(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    version: Version
+    model: Model
+    modality: str
+    scheduler: dict
+    learning_rate: float
+    weighted_loss: bool
+    verbose: bool
+    q_verbose: bool
+    medcam_enabled: bool
+    save_training: bool
+    save_output: bool
+    in_memory: bool
+    pin_memory_dataloader: bool
+    scaling_factor: Union[float, int]
+    q_max_length: int
+    q_samples_per_volume: int
+    q_num_workers: int
+    num_epochs: int
+    patience: int
+    batch_size: int
+    clip_grad: Union[None, float]
+    track_memory_usage: bool
+    memory_save_mode: bool
+    print_rgb_label_warning: bool
+    data_postprocessing: Dict  # TODO: maybe better to create a class
+    data_preprocessing: Dict  # TODO: maybe better to create a class
+    grid_aggregator_overlap: str
+    determinism: bool
+    previous_parameters: None
+    metrics: Union[List, dict]
+    parallel_compute_command: Union[str, bool, None]
+    loss_function: Union[str, Dict]
+    data_augmentation: dict  # TODO: maybe better to create a class
+    nested_training: dict  # TODO: maybe better to create a class
+    optimizer: Union[dict, str]
+    patch_sampler: Union[dict, str]
+    patch_size: Union[List[int], int]
+    clip_mode: Union[str, None]
+    inference_mechanism: dict
+    data_postprocessing_after_reverse_one_hot_encoding: dict
+    enable_padding: Optional[Union[dict, bool]] = None
+    headers: Optional[dict] = None
+    output_dir: Optional[str] = ""
+    problem_type: Optional[str] = None
+    differential_privacy: Optional[dict] = None
+    # opt: Optional[Union[dict, str]] = {}  # TODO: find a better way
diff --git a/setup.py b/setup.py
index bd75e5ae9..e15aa6d69 100644
--- a/setup.py
+++ b/setup.py
@@ -85,6 +85,7 @@
     "openslide-bin",
     "openslide-python==1.4.1",
     "lion-pytorch==0.2.2",
+    "pydantic",
 ]
 
 if __name__ == "__main__":
diff --git a/testing/test_full.py b/testing/test_full.py
index eccf0b3c8..83e6b064d 100644
--- a/testing/test_full.py
+++ b/testing/test_full.py
@@ -20,6 +20,7 @@
     get_patch_size_in_microns,
     convert_to_tiff,
 )
+from GANDLF.utils.pydantic_config import Parameters
 from GANDLF.config_manager import ConfigManager
 from GANDLF.parseConfig import parseConfig
 from GANDLF.training_manager import TrainingManager
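
The behavioral thread running through this diff: `ConfigManager` now passes the parsed config through the pydantic `Parameters` model and returns `model_dump()`, so every declared field is guaranteed to exist in the resulting dict (possibly as `None`). That is why the `"key" in parameters` membership checks in `imagenet_unet.py` and `data_splitter.py` become `is None` checks, and why the deprecated `opt` key must be popped before validation: `extra="forbid"` would reject it. A minimal sketch of that pattern, assuming pydantic v2 (the `Toy` model below is hypothetical and only mirrors the shape of `Parameters`):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict, ValidationError


class Toy(BaseModel):
    """Hypothetical stand-in for GANDLF's Parameters model."""

    model_config = ConfigDict(extra="forbid")  # unknown keys raise at load time
    batch_size: int
    headers: Optional[dict] = None  # defaulted, like Parameters.headers


# After model_dump(), every declared field is present in the dict ...
params = Toy(batch_size=4).model_dump()
assert "headers" in params  # always True now
# ... so presence must be tested with `is None`, not `in`:
assert params["headers"] is None

# extra="forbid" is why _parseConfig pops the deprecated "opt" key
# before Parameters(**...) is constructed:
try:
    Toy(batch_size=4, opt="adam")
except ValidationError as err:
    print(err)  # reports extra inputs are not permitted for the unknown key
```

One consequence worth noting for reviewers: because `model_dump()` materializes defaults, downstream code can no longer distinguish "key absent" from "key explicitly set to `None`", which is exactly the semantics the `data_splitter.py` and `imagenet_unet.py` hunks adopt.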