-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds load_from_settings_file + tests
- Loading branch information
Frank Guibert
committed
Jul 4, 2024
1 parent
43df30d
commit d3dbd86
Showing
7 changed files
with
130 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
{ | ||
"num_filters": 128, | ||
"use_ghost": true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,44 @@ | ||
from pathlib import Path | ||
from torch import nn | ||
from .deeplabv3 import DeepLabV3, DeepLabV3Plus | ||
from .half_unet import HalfUnet | ||
from .half_unet import HalfUNet | ||
from .segformer import Segformer | ||
from .swinunetr import SwinUNETR | ||
from .unet import Unet | ||
from .unet import UNet | ||
from .unetrpp import UNETRPP | ||
|
||
|
||
all_nn_architectures = ( | ||
DeepLabV3, | ||
DeepLabV3Plus, | ||
HalfUnet, | ||
HalfUNet, | ||
Segformer, | ||
SwinUNETR, | ||
Unet, | ||
UNet, | ||
UNETRPP, | ||
) | ||
|
||
|
||
def load_from_settings_file( | ||
model_name: str, in_channels: int, out_channels: int, settings_path: Path | ||
) -> nn.Module: | ||
""" | ||
Instanciate a model from a settings file with Schema validation. | ||
""" | ||
|
||
# pick the class matching the supplied name | ||
model_kls = next( | ||
(kls for kls in all_nn_architectures if kls.__name__ == model_name), None | ||
) | ||
|
||
if model_kls is None: | ||
raise ValueError( | ||
f"Model {model_name} not found in available architectures: {[x.__name__ for x in all_nn_architectures]}" | ||
) | ||
|
||
# load the settings | ||
with open(settings_path, "r") as f: | ||
model_settings = model_kls.settings_kls.schema().loads(f.read()) | ||
|
||
# instanciate the model | ||
return model_kls(in_channels, out_channels, settings=model_settings) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters