Skip to content

Commit

Permalink
Adds load_from_settings_file + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Frank Guibert committed Jul 4, 2024
1 parent 43df30d commit d3dbd86
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 28 deletions.
42 changes: 40 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
| [unet](mfai/torch/models/unet.py#L1) | [arxiv link](https://arxiv.org/pdf/1505.04597.pdf) | (Batch, features, Height, Width) | Yes | Vanilla U-Net | Radar image cleaning | Theo Tournier / Frank Guibert |
| [segformer](mfai/torch/models/segformer.py#L1) | [arxiv link](https://arxiv.org/abs/2105.15203) | (Batch, features, Height, Width) | Yes | On par with u-net like on Deepsyg (MF internal), added an upsampling stage. Adapted from [Lucidrains' github](https://github.com/lucidrains/segformer-pytorch) | Segmentation tasks | Frank Guibert |
| [swinunetr](mfai/torch/models/swinunetr.py#L1) | [arxiv link](https://arxiv.org/abs/2201.01266) | (Batch, features, Height, Width) | Yes | 2D Swin Unet transformer (Pangu and archweather uses customised 3D versions of Swin Transformers). Plugged in from [MONAI](https://github.com/Project-MONAI/MONAI/). The decoders have been modified to use Bilinear2D + Conv2d instead of Conv2dTranspose to remove artefacts/checkerboard effects | Segmentation tasks | Frank Guibert |
| [unetr++](mfai/torch/models/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) | Yes | Adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both for 2d and 3d inputs | Front Detection | Frank Guibert |
| [unetr++](mfai/torch/models/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) | Yes | Vision transformer with a reduced GFLOPS footprint adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both with 2d and 3d inputs | Front Detection | Frank Guibert |

# NamedTensors

Expand Down Expand Up @@ -94,6 +94,43 @@ Features:
pip install mfai
```

# Usage

Our [unit tests](tests/test_models.py#L39) provides an example of how to use the models in a PyTorch training loop. Our models are instanciated with 2 mandatory positional arguments: **in_channels** and **out_channels** respectively the number of input and output channels of the model. The other parameter is an instance of the model's settings class.

Here is an example of how to instanciate the UNet model with a 3 channels input (like an RGB image) and 1 channel output with its default settings:

```python
from mfai.torch.models import UNet
unet = UNet(in_channels=3, out_channels=1)
```

**_FEATURE:_** Once instanciated, the model (subclass of **nn.Module**) can be used like any standard [PyTorch model](https://pytorch.org/tutorials/beginner/introyt/trainingyt.html).

In order to instanciate a HalfUNet model with a 2 channels inputs, 2 channels outputs and a custom settings (128 filters, ghost module):

```python
from mfai.torch.models import HalfUNet
halfunet = HalfUNet(in_channels=2, out_channels=2, settings=HalfUNet.settings_kls(num_filters=128, use_ghost=True))
```

**_FEATURE:_** Each model has its settings class available under the **settings_kls** attribute.

You can use the **load_from_settings_file** function to instanciate a model with its settings from a json file:

```python
from pathlib import Path
from mfai.torch.models import load_from_settings_file
model = load_from_settings_file(
"HalfUNet",
2,
2,
Path(".") / "mfai" / "config" / "models" / "halfunet128.json",
)
```

**_FEATURE:_** Use the **load_from_settings_file** to have the strictest validation of the settings.

# Running Tests

Our tests are written using [pytest](https://docs.pytest.org). We check that:
Expand All @@ -112,5 +149,6 @@ We welcome contributions to this package. Our guidelines are the following:
- Make sure the current tests pass and add new tests if necessary to cover the new features. Our CI will fail with a **test coverage below 80%**.
- Make sure the code is formatted with [ruff](https://docs.astral.sh/ruff/)

# Acknowledgements


This package is maintained by the DSM/LabIA team at Météo-France. We would like to thank the authors of the papers and codes we used to implement the models (see [above links](#neural-network-architectures) to **arxiv** and **github**) and the authors of the libraries we use to build this package (see our [**requirements.txt**](requirements.txt)).
4 changes: 4 additions & 0 deletions mfai/config/models/halfunet128.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"num_filters": 128,
"use_ghost": true
}
35 changes: 31 additions & 4 deletions mfai/torch/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,44 @@
from pathlib import Path
from torch import nn
from .deeplabv3 import DeepLabV3, DeepLabV3Plus
from .half_unet import HalfUnet
from .half_unet import HalfUNet
from .segformer import Segformer
from .swinunetr import SwinUNETR
from .unet import Unet
from .unet import UNet
from .unetrpp import UNETRPP


all_nn_architectures = (
DeepLabV3,
DeepLabV3Plus,
HalfUnet,
HalfUNet,
Segformer,
SwinUNETR,
Unet,
UNet,
UNETRPP,
)


def load_from_settings_file(
model_name: str, in_channels: int, out_channels: int, settings_path: Path
) -> nn.Module:
"""
Instanciate a model from a settings file with Schema validation.
"""

# pick the class matching the supplied name
model_kls = next(
(kls for kls in all_nn_architectures if kls.__name__ == model_name), None
)

if model_kls is None:
raise ValueError(
f"Model {model_name} not found in available architectures: {[x.__name__ for x in all_nn_architectures]}"
)

# load the settings
with open(settings_path, "r") as f:
model_settings = model_kls.settings_kls.schema().loads(f.read())

# instanciate the model
return model_kls(in_channels, out_channels, settings=model_settings)
23 changes: 14 additions & 9 deletions mfai/torch/models/half_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@dataclass_json
@dataclass(slots=True)
class HalfUnetSettings:
class HalfUNetSettings:
num_filters: int = 64
dilation: int = 1
bias: bool = False
Expand All @@ -32,7 +32,6 @@ def __init__(
dilation=1,
):
super().__init__()
print(type(out_channels), out_channels)

self.conv = nn.Conv2d(
in_channels=in_channels,
Expand Down Expand Up @@ -61,16 +60,16 @@ def forward(self, x):
return self.relu(x)


class HalfUnet(nn.Module):
settings_kls = HalfUnetSettings
class HalfUNet(nn.Module):
settings_kls = HalfUNetSettings
onnx_supported = True

def __init__(
self,
in_channels: int,
out_channels: int,
input_shape: Union[None, Tuple[int, int]] = None,
settings: HalfUnetSettings = HalfUnetSettings(),
settings: HalfUNetSettings = HalfUNetSettings(),
*args,
**kwargs,
):
Expand All @@ -81,6 +80,12 @@ def __init__(

super().__init__(*args, **kwargs)

if settings.absolute_pos_embed:
if input_shape is None:
raise ValueError(
"You must provide an input_shape to use absolute_pos_embed in HalfUnet"
)

self.encoder1 = self._block(
in_channels,
settings.num_filters,
Expand All @@ -102,7 +107,7 @@ def __init__(
use_ghost=settings.use_ghost,
dilation=settings.dilation,
absolute_pos_embed=settings.absolute_pos_embed,
grid_shape=[x // 2 for x in input_shape],
grid_shape=[x // 2 for x in input_shape] if input_shape else None,
)
self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
Expand All @@ -115,7 +120,7 @@ def __init__(
use_ghost=settings.use_ghost,
dilation=settings.dilation,
absolute_pos_embed=settings.absolute_pos_embed,
grid_shape=[x // 4 for x in input_shape],
grid_shape=[x // 4 for x in input_shape] if input_shape else None,
)

self.up3 = nn.UpsamplingBilinear2d(scale_factor=4)
Expand All @@ -129,7 +134,7 @@ def __init__(
use_ghost=settings.use_ghost,
dilation=settings.dilation,
absolute_pos_embed=settings.absolute_pos_embed,
grid_shape=[x // 8 for x in input_shape],
grid_shape=[x // 8 for x in input_shape] if input_shape else None,
)
self.up4 = nn.UpsamplingBilinear2d(scale_factor=8)
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
Expand All @@ -142,7 +147,7 @@ def __init__(
use_ghost=settings.use_ghost,
dilation=settings.dilation,
absolute_pos_embed=settings.absolute_pos_embed,
grid_shape=[x // 16 for x in input_shape],
grid_shape=[x // 16 for x in input_shape] if input_shape else None,
)
self.up5 = nn.UpsamplingBilinear2d(scale_factor=16)

Expand Down
22 changes: 11 additions & 11 deletions mfai/torch/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class UnetSettings:
init_features: int = 64


class Unet(nn.Module):
class UNet(nn.Module):
"""
Returns a u_net architecture, with uninitialised weights, matching desired numbers of input and output channels.
Expand All @@ -75,34 +75,34 @@ def __init__(
input_shape: Union[None, Tuple[int, int]] = None,
settings: UnetSettings = UnetSettings(),
):
super(Unet, self).__init__()
super(UNet, self).__init__()

features = settings.init_features

self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

self.encoder1 = Unet._block(in_channels, features, name="enc1")
self.encoder2 = Unet._block(features, features * 2, name="enc2")
self.encoder3 = Unet._block(features * 2, features * 4, name="enc3")
self.encoder4 = Unet._block(features * 4, features * 8, name="enc4")
self.bottleneck = Unet._block(features * 8, features * 16, name="bottleneck")
self.encoder1 = UNet._block(in_channels, features, name="enc1")
self.encoder2 = UNet._block(features, features * 2, name="enc2")
self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

self.upconv4 = nn.ConvTranspose2d(
features * 16, features * 8, kernel_size=2, stride=2
)
self.decoder4 = Unet._block((features * 8) * 2, features * 8, name="dec4")
self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
self.upconv3 = nn.ConvTranspose2d(
features * 8, features * 4, kernel_size=2, stride=2
)
self.decoder3 = Unet._block((features * 4) * 2, features * 4, name="dec3")
self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
self.upconv2 = nn.ConvTranspose2d(
features * 4, features * 2, kernel_size=2, stride=2
)
self.decoder2 = Unet._block((features * 2) * 2, features * 2, name="dec2")
self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
self.upconv1 = nn.ConvTranspose2d(
features * 2, features, kernel_size=2, stride=2
)
self.decoder1 = Unet._block(features * 2, features, name="dec1")
self.decoder1 = UNet._block(features * 2, features, name="dec1")

self.conv = nn.Conv2d(features, out_channels, kernel_size=1)

Expand Down
30 changes: 29 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
3. onnx exported
4. onnx loaded and used for inference
"""

from pathlib import Path
import tempfile

from marshmallow.exceptions import ValidationError
import torch
import pytest
from mfai.torch import export_to_onnx, onnx_load_and_infer
from mfai.torch.models import all_nn_architectures
from mfai.torch.models import all_nn_architectures, load_from_settings_file


def to_numpy(tensor):
Expand Down Expand Up @@ -88,3 +91,28 @@ def test_torch_training_loop(model_kls):
export_to_onnx(model, sample, dst.name)
onnx_load_and_infer(dst.name, sample)


def test_load_model_by_name():
with pytest.raises(ValueError):
load_from_settings_file("NotAValidModel", 2, 2, None)

# Should work: valid settings file for this model
load_from_settings_file(
"HalfUNet",
2,
2,
Path(__file__).parents[1] / "mfai" / "config" / "models" / "halfunet128.json",
)

# Should raise: invalid settings file for this model
with pytest.raises(ValidationError):
load_from_settings_file(
"UNETRPP",
2,
2,
Path(__file__).parents[1]
/ "mfai"
/ "config"
/ "models"
/ "halfunet128.json",
)
2 changes: 1 addition & 1 deletion tests/test_namedtensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ def test_named_tensor():
assert nt_cat.feature_names == [f"feature_{i}" for i in range(10)] + [
f"v_{i}" for i in range(10)
] + [f"u_{i}" for i in range(10)]
assert nt_cat.names == ["batch", "lat", "lon", "features"]
assert nt_cat.names == ["batch", "lat", "lon", "features"]

0 comments on commit d3dbd86

Please sign in to comment.