Pydantic config #976

Draft: wants to merge 9 commits into base: master
7 changes: 6 additions & 1 deletion GANDLF/config_manager.py
@@ -11,6 +11,7 @@

from GANDLF.metrics import surface_distance_ids
from importlib.metadata import version
+ from GANDLF.utils.pydantic_config import Parameters

## dictionary to define defaults for appropriate options, which are evaluated
parameter_defaults = {
@@ -653,6 +654,7 @@ def _parseConfig(
if "opt" in params:
print("DeprecationWarning: 'opt' has been superseded by 'optimizer'")
params["optimizer"] = params["opt"]
params.pop("opt")

# initialize defaults for patch sampler
temp_patch_sampler_dict = {
@@ -747,7 +749,10 @@ def ConfigManager(
dict: The parameter dictionary.
"""
try:
- return _parseConfig(config_file_path, version_check_flag)
+ parameters = Parameters(
+     **_parseConfig(config_file_path, version_check_flag)
+ ).model_dump()
+ return parameters
except Exception as e:
## todo: ensure logging captures assertion errors
assert (
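The net effect of this change is that every parsed config now passes through the Pydantic `Parameters` model before `ConfigManager` returns it, so unknown keys fail fast instead of silently propagating. That is presumably also why the deprecated `opt` key is now popped after being copied to `optimizer`: with `extra="forbid"` on the model, a leftover `opt` key would trip validation. A minimal, self-contained sketch of the pattern follows; `TrainingConfig`, `load_config`, and the keys used are illustrative, not the real GANDLF schema.

```python
# Sketch only: the model and keys below are illustrative, not GANDLF's schema.
from pydantic import BaseModel, ConfigDict, ValidationError


class TrainingConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")  # unknown keys raise instead of passing through
    num_epochs: int
    learning_rate: float
    patience: int = 10  # optional field with a default


def load_config(raw: dict) -> dict:
    # Same shape as ConfigManager: validate the parsed dict, return a plain dict.
    return TrainingConfig(**raw).model_dump()


print(load_config({"num_epochs": 5, "learning_rate": 1e-3}))
# {'num_epochs': 5, 'learning_rate': 0.001, 'patience': 10}

try:
    load_config({"num_epochs": 5, "learning_rate": 1e-3, "learning_rat": 0.1})
except ValidationError as err:
    print(err)  # extra="forbid" flags the misspelled "learning_rat" key
```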
2 changes: 1 addition & 1 deletion GANDLF/models/imagenet_unet.py
@@ -253,7 +253,7 @@ def __init__(self, parameters) -> None:
)

# all BatchNorm should be replaced with InstanceNorm for DP experiments
if "differential_privacy" in parameters:
if parameters["differential_privacy"] is not None:
A collaborator commented:
Perhaps this check can be simpler, like this?

Suggested change:
- if parameters["differential_privacy"] is not None:
+ if parameters.get("differential_privacy"):

Would that work?

self.replace_batchnorm(self.model)

if self.n_dimensions == 3:
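On the suggestion above: `parameters.get("differential_privacy")` and `parameters["differential_privacy"] is not None` only disagree when the key holds a present-but-falsy value such as `{}`. After `Parameters(...).model_dump()` the key always exists (defaulting to `None`), so the choice mostly comes down to whether an empty `differential_privacy` section should still enable the DP branch. A small illustration, with a hypothetical parameters dict:

```python
# Hypothetical parameters dict: the DP section is present but empty.
parameters = {"differential_privacy": {}}

print(parameters["differential_privacy"] is not None)  # True  -> DP branch runs
print(bool(parameters.get("differential_privacy")))    # False -> DP branch skipped
```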
4 changes: 2 additions & 2 deletions GANDLF/utils/data_splitter.py
@@ -22,12 +22,12 @@ def split_data(
"nested_training" in parameters
), "`nested_training` key missing in parameters"
# populate the headers
if "headers" not in parameters:
if parameters["headers"] is None:
_, parameters["headers"] = parseTrainingCSV(full_dataset)

parameters = (
populate_header_in_parameters(parameters, parameters["headers"])
if "problem_type" not in parameters
if parameters["problem_type"] is None
else parameters
)

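These two checks change for the same reason: once the config has been through `Parameters(...).model_dump()`, optional fields such as `headers` and `problem_type` are always present in the dict (with a default of `None`), so `"headers" not in parameters` can never be true again. A small self-contained sketch, using an illustrative stand-in model rather than the real schema:

```python
# Illustrative only: a tiny stand-in for the real Parameters model.
from typing import Optional
from pydantic import BaseModel


class Cfg(BaseModel):
    nested_training: dict
    headers: Optional[dict] = None
    problem_type: Optional[str] = None


params = Cfg(nested_training={"testing": -5, "validation": 5}).model_dump()

print("headers" in params)        # True  -> the old `not in` check would never fire
print(params["headers"] is None)  # True  -> the new check still detects "not yet populated"
```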
75 changes: 75 additions & 0 deletions GANDLF/utils/pydantic_config.py
@@ -0,0 +1,75 @@
from pydantic import BaseModel, ConfigDict
from typing import Dict, List, Optional, Union


class Version(BaseModel):
minimum: str
maximum: str


class Model(BaseModel):
dimension: int
base_filters: int
architecture: str
norm_type: str
final_layer: str
class_list: list[Union[int, str]]
ignore_label_validation: Union[int, None]
amp: bool
print_summary: bool
type: str
data_type: str
save_at_every_epoch: bool
num_channels: Optional[int] = None


class Parameters(BaseModel):
model_config = ConfigDict(extra="forbid")
version: Version
model: Model
modality: str
scheduler: dict
learning_rate: float
weighted_loss: bool
verbose: bool
q_verbose: bool
medcam_enabled: bool
save_training: bool
save_output: bool
in_memory: bool
pin_memory_dataloader: bool
scaling_factor: Union[float, int]
q_max_length: int
q_samples_per_volume: int
q_num_workers: int
num_epochs: int
patience: int
batch_size: int
clip_grad: Union[None, float]
track_memory_usage: bool
memory_save_mode: bool
print_rgb_label_warning: bool
data_postprocessing: Dict  # TODO: maybe better to model this as a class
data_preprocessing: Dict  # TODO: maybe better to model this as a class
grid_aggregator_overlap: str
determinism: bool
previous_parameters: None
metrics: Union[List, dict]
parallel_compute_command: Union[str, bool, None]
loss_function: Union[str, Dict]
data_augmentation: dict  # TODO: maybe better to model this as a class
nested_training: dict  # TODO: maybe better to model this as a class
optimizer: Union[dict, str]
patch_sampler: Union[dict, str]
patch_size: Union[List[int], int]
clip_mode: Union[str, None]
inference_mechanism: dict
data_postprocessing_after_reverse_one_hot_encoding: dict
enable_padding: Optional[Union[dict, bool]] = None
headers: Optional[dict] = None
output_dir: Optional[str] = ""
problem_type: Optional[str] = None
differential_privacy: Optional[dict] = None
# opt: Optional[Union[dict, str]] = {} # TODO find a better way
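As for the "maybe better to model this as a class" TODOs above, one possible direction (not part of this PR) is to replace the bare `dict` fields with nested models so their contents are validated too. The field names in this sketch are hypothetical, not GANDLF's actual `nested_training` schema:

```python
# Hypothetical sketch of a nested sub-model; field names are illustrative.
from pydantic import BaseModel, ConfigDict


class NestedTraining(BaseModel):
    model_config = ConfigDict(extra="forbid")
    testing: int = -5
    validation: int = -5


class ParametersSketch(BaseModel):
    nested_training: NestedTraining  # instead of `nested_training: dict`


p = ParametersSketch(nested_training={"testing": 5, "validation": 5})
print(p.nested_training.testing)  # nested dicts are coerced into the sub-model and validated
```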
1 change: 1 addition & 0 deletions setup.py
@@ -85,6 +85,7 @@
"openslide-bin",
"openslide-python==1.4.1",
"lion-pytorch==0.2.2",
"pydantic",
]

if __name__ == "__main__":
1 change: 1 addition & 0 deletions testing/test_full.py
@@ -20,6 +20,7 @@
get_patch_size_in_microns,
convert_to_tiff,
)
+ from GANDLF.utils.pydantic_config import Parameters
from GANDLF.config_manager import ConfigManager
from GANDLF.parseConfig import parseConfig
from GANDLF.training_manager import TrainingManager