diff --git a/README.md b/README.md
index 851bbba..822b4c4 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@
 | [unet](mfai/torch/models/unet.py#L1) | [arxiv link](https://arxiv.org/pdf/1505.04597.pdf) | (Batch, features, Height, Width) | Yes | Vanilla U-Net | Radar image cleaning | Theo Tournier / Frank Guibert |
 | [segformer](mfai/torch/models/segformer.py#L1) | [arxiv link](https://arxiv.org/abs/2105.15203) | (Batch, features, Height, Width) | Yes | On par with u-net like on Deepsyg (MF internal), added an upsampling stage. Adapted from [Lucidrains' github](https://github.com/lucidrains/segformer-pytorch) | Segmentation tasks | Frank Guibert |
 | [swinunetr](mfai/torch/models/swinunetr.py#L1) | [arxiv link](https://arxiv.org/abs/2201.01266) | (Batch, features, Height, Width) | Yes | 2D Swin Unet transformer (Pangu and archweather uses customised 3D versions of Swin Transformers). Plugged in from [MONAI](https://github.com/Project-MONAI/MONAI/). The decoders have been modified to use Bilinear2D + Conv2d instead of Conv2dTranspose to remove artefacts/checkerboard effects | Segmentation tasks | Frank Guibert |
-| [unetr++](mfai/torch/models/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) | Yes | Adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both for 2d and 3d inputs | Front Detection | Frank Guibert |
+| [unetr++](mfai/torch/models/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) | Yes | Vision transformer with a reduced GFLOPS footprint, adapted from the [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work with both 2d and 3d inputs | Front Detection | Frank Guibert |
 
 # NamedTensors
 
@@ -94,6 +94,43 @@ Features:
 pip install mfai
 ```
 
+# Usage
+
+Our [unit tests](tests/test_models.py#L39) provide an example of how to use the models in a PyTorch training loop. Our models are instantiated with 2 mandatory positional arguments, **in_channels** and **out_channels**, which are respectively the number of input and output channels of the model. The other parameter is an instance of the model's settings class.
+
+Here is an example of how to instantiate the UNet model with a 3-channel input (such as an RGB image) and a 1-channel output, using its default settings:
+
+```python
+from mfai.torch.models import UNet
+unet = UNet(in_channels=3, out_channels=1)
+```
+
+**_FEATURE:_** Once instantiated, the model (a subclass of **nn.Module**) can be used like any standard [PyTorch model](https://pytorch.org/tutorials/beginner/introyt/trainingyt.html); see the inference sketch at the end of this section.
+
+To instantiate a HalfUNet model with a 2-channel input, a 2-channel output and custom settings (128 filters, ghost module):
+
+```python
+from mfai.torch.models import HalfUNet
+halfunet = HalfUNet(in_channels=2, out_channels=2, settings=HalfUNet.settings_kls(num_filters=128, use_ghost=True))
+```
+
+**_FEATURE:_** Each model has its settings class available under the **settings_kls** attribute.
+
+You can use the **load_from_settings_file** function to instantiate a model with its settings loaded from a json file:
+
+```python
+from pathlib import Path
+from mfai.torch.models import load_from_settings_file
+model = load_from_settings_file(
+    "HalfUNet",
+    2,
+    2,
+    Path(".") / "mfai" / "config" / "models" / "halfunet128.json",
+)
+```
+
+**_FEATURE:_** Use **load_from_settings_file** to get the strictest validation of the settings.
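+
+For a quick sanity check, here is a minimal inference sketch using the `model` returned by **load_from_settings_file** above (the 64x64 spatial size is only an example):
+
+```python
+import torch
+
+# dummy (Batch, features, Height, Width) input: 1 sample, 2 channels, 64x64 pixels
+x = torch.randn(1, 2, 64, 64)
+with torch.no_grad():
+    y = model(x)
+# for this HalfUNet, y should have shape (1, 2, 64, 64): out_channels=2 at the input resolution
+print(y.shape)
+```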
+
 # Running Tests
 
 Our tests are written using [pytest](https://docs.pytest.org). We check that:
@@ -112,5 +149,6 @@ We welcome contributions to this package. Our guidelines are the following:
 - Make sure the current tests pass and add new tests if necessary to cover the new features. Our CI will fail with a **test coverage below 80%**.
 - Make sure the code is formatted with [ruff](https://docs.astral.sh/ruff/)
+# Acknowledgements
-
+This package is maintained by the DSM/LabIA team at Météo-France. We would like to thank the authors of the papers and code we used to implement the models (see the [links above](#neural-network-architectures) to **arxiv** and **github**) and the authors of the libraries we use to build this package (see our [**requirements.txt**](requirements.txt)).
diff --git a/mfai/config/models/halfunet128.json b/mfai/config/models/halfunet128.json
new file mode 100644
index 0000000..21ebfa3
--- /dev/null
+++ b/mfai/config/models/halfunet128.json
@@ -0,0 +1,4 @@
+{
+    "num_filters": 128,
+    "use_ghost": true
+}
\ No newline at end of file
diff --git a/mfai/torch/models/__init__.py b/mfai/torch/models/__init__.py
index ea85e75..b3bc3d6 100644
--- a/mfai/torch/models/__init__.py
+++ b/mfai/torch/models/__init__.py
@@ -1,17 +1,44 @@
+from pathlib import Path
+from torch import nn
 from .deeplabv3 import DeepLabV3, DeepLabV3Plus
-from .half_unet import HalfUnet
+from .half_unet import HalfUNet
 from .segformer import Segformer
 from .swinunetr import SwinUNETR
-from .unet import Unet
+from .unet import UNet
 from .unetrpp import UNETRPP
 
 all_nn_architectures = (
     DeepLabV3,
     DeepLabV3Plus,
-    HalfUnet,
+    HalfUNet,
     Segformer,
     SwinUNETR,
-    Unet,
+    UNet,
     UNETRPP,
 )
+
+
+def load_from_settings_file(
+    model_name: str, in_channels: int, out_channels: int, settings_path: Path
+) -> nn.Module:
+    """
+    Instantiate a model from a settings file, with schema validation.
+    """
+
+    # pick the class matching the supplied name
+    model_kls = next(
+        (kls for kls in all_nn_architectures if kls.__name__ == model_name), None
+    )
+
+    if model_kls is None:
+        raise ValueError(
+            f"Model {model_name} not found in available architectures: {[x.__name__ for x in all_nn_architectures]}"
+        )
+
+    # load the settings and validate them against the model's settings schema
+    with open(settings_path, "r") as f:
+        model_settings = model_kls.settings_kls.schema().loads(f.read())
+
+    # instantiate the model
+    return model_kls(in_channels, out_channels, settings=model_settings)
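The strict validation performed by `load_from_settings_file` comes from the `dataclass_json` schema attached to each model's `settings_kls`. A minimal, illustrative sketch of that behaviour (paths assumed relative to the repository root, mirroring the new test further down):

```python
from marshmallow.exceptions import ValidationError

from mfai.torch.models import HalfUNet, UNETRPP

with open("mfai/config/models/halfunet128.json") as f:
    raw = f.read()

# halfunet128.json matches HalfUNet's settings class, so parsing and validation succeed.
halfunet_settings = HalfUNet.settings_kls.schema().loads(raw)

# The same file does not match UNETRPP's settings class, so validation fails.
try:
    UNETRPP.settings_kls.schema().loads(raw)
except ValidationError as err:
    print("settings rejected:", err.messages)
```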
diff --git a/mfai/torch/models/half_unet.py b/mfai/torch/models/half_unet.py
index eb88a96..ed9f4ec 100644
--- a/mfai/torch/models/half_unet.py
+++ b/mfai/torch/models/half_unet.py
@@ -12,7 +12,7 @@
 
 @dataclass_json
 @dataclass(slots=True)
-class HalfUnetSettings:
+class HalfUNetSettings:
     num_filters: int = 64
     dilation: int = 1
     bias: bool = False
@@ -32,7 +32,6 @@ def __init__(
         dilation=1,
     ):
         super().__init__()
-        print(type(out_channels), out_channels)
 
         self.conv = nn.Conv2d(
             in_channels=in_channels,
@@ -61,8 +60,8 @@ def forward(self, x):
         return self.relu(x)
 
 
-class HalfUnet(nn.Module):
-    settings_kls = HalfUnetSettings
+class HalfUNet(nn.Module):
+    settings_kls = HalfUNetSettings
     onnx_supported = True
 
     def __init__(
@@ -70,7 +69,7 @@ def __init__(
         in_channels: int,
         out_channels: int,
         input_shape: Union[None, Tuple[int, int]] = None,
-        settings: HalfUnetSettings = HalfUnetSettings(),
+        settings: HalfUNetSettings = HalfUNetSettings(),
         *args,
         **kwargs,
     ):
@@ -81,6 +80,12 @@
 
         super().__init__(*args, **kwargs)
 
+        if settings.absolute_pos_embed:
+            if input_shape is None:
+                raise ValueError(
+                    "You must provide an input_shape to use absolute_pos_embed in HalfUNet"
+                )
+
         self.encoder1 = self._block(
             in_channels,
             settings.num_filters,
@@ -102,7 +107,7 @@ def __init__(
             use_ghost=settings.use_ghost,
             dilation=settings.dilation,
             absolute_pos_embed=settings.absolute_pos_embed,
-            grid_shape=[x // 2 for x in input_shape],
+            grid_shape=[x // 2 for x in input_shape] if input_shape else None,
         )
         self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
         self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
@@ -115,7 +120,7 @@ def __init__(
             use_ghost=settings.use_ghost,
             dilation=settings.dilation,
             absolute_pos_embed=settings.absolute_pos_embed,
-            grid_shape=[x // 4 for x in input_shape],
+            grid_shape=[x // 4 for x in input_shape] if input_shape else None,
         )
         self.up3 = nn.UpsamplingBilinear2d(scale_factor=4)
 
@@ -129,7 +134,7 @@ def __init__(
             use_ghost=settings.use_ghost,
             dilation=settings.dilation,
             absolute_pos_embed=settings.absolute_pos_embed,
-            grid_shape=[x // 8 for x in input_shape],
+            grid_shape=[x // 8 for x in input_shape] if input_shape else None,
         )
         self.up4 = nn.UpsamplingBilinear2d(scale_factor=8)
         self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
@@ -142,7 +147,7 @@ def __init__(
             use_ghost=settings.use_ghost,
             dilation=settings.dilation,
             absolute_pos_embed=settings.absolute_pos_embed,
-            grid_shape=[x // 16 for x in input_shape],
+            grid_shape=[x // 16 for x in input_shape] if input_shape else None,
         )
 
         self.up5 = nn.UpsamplingBilinear2d(scale_factor=16)
diff --git a/mfai/torch/models/unet.py b/mfai/torch/models/unet.py
index c796276..55b8b46 100644
--- a/mfai/torch/models/unet.py
+++ b/mfai/torch/models/unet.py
@@ -58,7 +58,7 @@ class UnetSettings:
     init_features: int = 64
 
 
-class Unet(nn.Module):
+class UNet(nn.Module):
     """
     Returns a u_net architecture, with uninitialised weights, matching desired numbers
     of input and output channels.
@@ -75,34 +75,34 @@ def __init__( input_shape: Union[None, Tuple[int, int]] = None, settings: UnetSettings = UnetSettings(), ): - super(Unet, self).__init__() + super(UNet, self).__init__() features = settings.init_features self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) - self.encoder1 = Unet._block(in_channels, features, name="enc1") - self.encoder2 = Unet._block(features, features * 2, name="enc2") - self.encoder3 = Unet._block(features * 2, features * 4, name="enc3") - self.encoder4 = Unet._block(features * 4, features * 8, name="enc4") - self.bottleneck = Unet._block(features * 8, features * 16, name="bottleneck") + self.encoder1 = UNet._block(in_channels, features, name="enc1") + self.encoder2 = UNet._block(features, features * 2, name="enc2") + self.encoder3 = UNet._block(features * 2, features * 4, name="enc3") + self.encoder4 = UNet._block(features * 4, features * 8, name="enc4") + self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck") self.upconv4 = nn.ConvTranspose2d( features * 16, features * 8, kernel_size=2, stride=2 ) - self.decoder4 = Unet._block((features * 8) * 2, features * 8, name="dec4") + self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4") self.upconv3 = nn.ConvTranspose2d( features * 8, features * 4, kernel_size=2, stride=2 ) - self.decoder3 = Unet._block((features * 4) * 2, features * 4, name="dec3") + self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3") self.upconv2 = nn.ConvTranspose2d( features * 4, features * 2, kernel_size=2, stride=2 ) - self.decoder2 = Unet._block((features * 2) * 2, features * 2, name="dec2") + self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2") self.upconv1 = nn.ConvTranspose2d( features * 2, features, kernel_size=2, stride=2 ) - self.decoder1 = Unet._block(features * 2, features, name="dec1") + self.decoder1 = UNet._block(features * 2, features, name="dec1") self.conv = nn.Conv2d(features, out_channels, kernel_size=1) diff --git a/tests/test_models.py b/tests/test_models.py index e92fd21..62a11a4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,12 +5,15 @@ 3. onnx exported 4. 
onnx loaded and used for inference """ + +from pathlib import Path import tempfile +from marshmallow.exceptions import ValidationError import torch import pytest from mfai.torch import export_to_onnx, onnx_load_and_infer -from mfai.torch.models import all_nn_architectures +from mfai.torch.models import all_nn_architectures, load_from_settings_file def to_numpy(tensor): @@ -88,3 +91,28 @@ def test_torch_training_loop(model_kls): export_to_onnx(model, sample, dst.name) onnx_load_and_infer(dst.name, sample) + +def test_load_model_by_name(): + with pytest.raises(ValueError): + load_from_settings_file("NotAValidModel", 2, 2, None) + + # Should work: valid settings file for this model + load_from_settings_file( + "HalfUNet", + 2, + 2, + Path(__file__).parents[1] / "mfai" / "config" / "models" / "halfunet128.json", + ) + + # Should raise: invalid settings file for this model + with pytest.raises(ValidationError): + load_from_settings_file( + "UNETRPP", + 2, + 2, + Path(__file__).parents[1] + / "mfai" + / "config" + / "models" + / "halfunet128.json", + ) diff --git a/tests/test_namedtensors.py b/tests/test_namedtensors.py index f436cad..0fd6080 100644 --- a/tests/test_namedtensors.py +++ b/tests/test_namedtensors.py @@ -127,4 +127,4 @@ def test_named_tensor(): assert nt_cat.feature_names == [f"feature_{i}" for i in range(10)] + [ f"v_{i}" for i in range(10) ] + [f"u_{i}" for i in range(10)] - assert nt_cat.names == ["batch", "lat", "lon", "features"] \ No newline at end of file + assert nt_cat.names == ["batch", "lat", "lon", "features"]
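As a complement to the name-based loading covered by the new test, the same test module also round-trips models through ONNX. Here is a minimal, illustrative sketch of that flow, using the `export_to_onnx` and `onnx_load_and_infer` helpers imported from `mfai.torch` in `tests/test_models.py` (the 64x64 sample size is an assumption):

```python
import tempfile

import torch

from mfai.torch import export_to_onnx, onnx_load_and_infer
from mfai.torch.models import UNet

# Build a model and a dummy (Batch, features, Height, Width) sample.
model = UNet(in_channels=3, out_channels=1)
sample = torch.randn(1, 3, 64, 64)

# Export to ONNX, then reload the exported file and run inference on the same sample,
# as done at the end of test_torch_training_loop in tests/test_models.py.
with tempfile.NamedTemporaryFile(suffix=".onnx") as dst:
    export_to_onnx(model, sample, dst.name)
    onnx_load_and_infer(dst.name, sample)
```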