Update deps and doc + refacto models + adds GNNs #13

Merged
merged 17 commits into from
Dec 9, 2024
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-runtime
FROM pytorch/pytorch:2.4.1-cuda11.8-cudnn9-runtime

RUN apt -y update && apt -y install git
WORKDIR /app
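
For reference, a minimal build-and-run sketch assuming the image is built from the repository root (the tag name is illustrative):

```bash
# build the image from the repo root and open a shell in it
docker build -t mfai:dev .
docker run --rm -it --gpus all mfai:dev bash
```
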
200 changes: 118 additions & 82 deletions README.md
@@ -10,12 +10,17 @@
# Table of contents

- [Neural Network Architectures](#neural-network-architectures)
  - deeplabv3/deeplabv3+
  - halfunet
  - unet/customunet
  - segformer
  - swinunetr
  - unetr++
  - Convolutional Neural Networks:
    - deeplabv3/deeplabv3+
    - halfunet
    - unet/customunet
  - Vision Transformers:
    - segformer
    - swinunetr
    - unetr++
  - Graph Neural Networks:
    - hilam
    - graphlam
- [SegmentationLightningModule](#segmentationlightningmodule)
- [NamedTensors](#namedtensors)
- [Metrics](#metrics)
@@ -32,10 +37,45 @@

# Neural Network Architectures

Each model we provide is a subclass of [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) and can be used in a PyTorch training loop. It has three **critical** class attributes:
Currently we support the following neural network architectures:


## Convolutional Neural Networks

| Model | Research Paper | Input Shape | ONNX exportable ? | Notes | Use-Cases at MF |
| :---: | :---: | :---: | :---: | :---: | :---: |
| [DeepLabV3Plus](mfai/torch/models/deeplabv3.py#L1) | [arxiv link](https://arxiv.org/abs/1802.02611) | (Batch, features, Height, Width) | Yes | Has a very large receptive field compared to U-Net, Half-UNet, ... | Front Detection, Nowcasting |
| [HalfUNet](mfai/torch/models/half_unet.py#L1) | [researchgate link](https://www.researchgate.net/publication/361186968_Half-UNet_A_Simplified_U-Net_Architecture_for_Medical_Image_Segmentation) | (Batch, features, Height, Width) | Yes | In prod/oper on [Espresso](https://www.mdpi.com/2674-0494/2/4/25) V2 with 128 filters and standard conv blocks instead of ghost | Satellite channels to rain estimation |
| [UNet](mfai/torch/models/unet.py#L1) | [arxiv link](https://arxiv.org/pdf/1505.04597.pdf) | (Batch, features, Height, Width) | Yes | Vanilla U-Net | Radar image cleaning |
| [CustomUnet](mfai/torch/models/unet.py#L1) | [arxiv link](https://arxiv.org/pdf/1505.04597.pdf) | (Batch, features, Height, Width) | Yes | U-Net like architecture with a variety of resnet encoder choices | Radar image cleaning |


## Vision Transformers

| Model | Research Paper | Input Shape | ONNX exportable ? | Notes | Use-Cases at MF |
| :---: | :---: | :---: | :---: | :---: | :---: |
| [Segformer](mfai/torch/models/segformer.py#L1) | [arxiv link](https://arxiv.org/abs/2105.15203) | (Batch, features, Height, Width) | Yes | On par with U-Net-like models on Deepsyg (MF internal); an upsampling stage was added. Adapted from [Lucidrains' github](https://github.com/lucidrains/segformer-pytorch) | Segmentation tasks |
| [SwinUNETR](mfai/torch/models/swinunetr.py#L1) | [arxiv link](https://arxiv.org/abs/2201.01266) | (Batch, features, Height, Width) | No | 2D Swin UNet transformer (Pangu and archweather use customised 3D versions of Swin Transformers). Plugged in from [MONAI](https://github.com/Project-MONAI/MONAI/). The decoders use Bilinear2D + Conv2d instead of Conv2dTranspose to remove artefacts/checkerboard effects | Segmentation tasks |
| [UNETRPP](mfai/torch/models/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) or (Batch, features, Height, Width, Depth) | Yes | Vision transformer with a reduced GFLOPS footprint adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both with 2d and 3d inputs. The decoders use Bilinear2D + Conv2d instead of Conv2dTranspose to remove artefacts/checkerboard effects | Front Detection, LAM Weather Forecasting |

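Both SwinUNETR and UNETRPP decoders replace transposed convolutions with bilinear upsampling followed by a convolution to avoid checkerboard effects. A minimal, self-contained sketch of that pattern (illustrative only, not the exact mfai decoder):

```python
import torch
from torch import nn


class UpsampleConvBlock(nn.Module):
    """Bilinear upsampling followed by a 3x3 convolution.

    Unlike ConvTranspose2d, the interpolation step spreads information evenly,
    avoiding the checkerboard artefacts caused by uneven kernel overlap.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(self.up(x))


x = torch.rand(1, 64, 32, 32)              # (Batch, features, Height, Width)
print(UpsampleConvBlock(64, 32)(x).shape)  # torch.Size([1, 32, 64, 64])
```
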
## Graph Neural Networks

| Model | Research Paper | Input Shape | ONNX exportable ? | Notes | Use-Cases at MF |
| :---: | :---: | :---: | :---: | :---: | :---: |
| [hilam, graphlam](mfai/torch/models/nlam/__init__.py) | [arxiv link](https://arxiv.org/abs/2309.17370) | (Batch, graph_node_id, features) |  | Imported and adapted from [Joel's github](https://github.com/joeloskarsson/neural-lam) |  |

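The GNN models consume node-flattened tensors rather than 2D grids. A small sketch of reshaping gridded data into the (Batch, graph_node_id, features) layout (the NamedTensor `flatten_` method shown further down does the same with named dimensions):

```python
import torch

# gridded data: (Batch, Height, Width, features)
x = torch.rand(4, 256, 256, 3)

# flatten the spatial dims into one node dimension: (Batch, graph_node_id, features)
x_nodes = x.flatten(1, 2)
print(x_nodes.shape)  # torch.Size([4, 65536, 3])
```
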
<details>
<summary>Details about our models</summary>

Each model we provide is a subclass of [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) and can be used in a PyTorch training loop. It has multiple class attributes to facilitate model usage in a project:
- **settings_kls**: a class that defines the settings of the model (number of filters, kernel size, ...). It is used to instantiate the model with a specific configuration.
- **onnx_supported**: a boolean that indicates if the model can be exported to onnx. Our CI validates that the model can be exported to onnx and reloaded for inference.
- **input_spatial_dims**: a tuple that describes the spatial dimensions of the input tensor supported by the model. A model that supports 2D spatial data will have **(2,)** as value. A model that supports 2d or 3d spatial data will have **(2, 3)** as value.
- **supported_num_spatial_dims**: a tuple that describes the spatial dimensions of the input tensor supported by the model. A model that supports 2D spatial data will have **(2,)** as value. A model that supports 2d or 3d spatial data will have **(2, 3)** as value.
- **num_spatial_dims**: an integer that describes the number of spatial dimensions of the input/output tensor expected by the instance of the model; it must be a value in **supported_num_spatial_dims**.
- **settings**: a runtime property returning the settings instance used to instantiate the model.
- **model_type**: an Enum describing the type of model: CONVOLUTIONAL, VISION_TRANSFORMER, GRAPH, LLM, MLLM.
- **features_last**: a boolean that indicates if the features dimension is the last dimension of the input/output tensor. If False, the features dimension is the second dimension of the input/output tensor.
- **register**: a boolean that indicates if the model should be registered in the **MODELS** registry. By default, it is set to False which allows the creation of intermediate subclasses not meant for direct use.

The Python interface contract for our models is enforced using [Python ABC](https://docs.python.org/3/library/abc.html), in our case via the [ModelABC](mfai/torch/models/base.py#L1) class.

@@ -52,21 +92,15 @@ class HalfUNetSettings:

```python
class HalfUNet(ModelABC, nn.Module):
    settings_kls = HalfUNetSettings
    onnx_supported = True
    input_spatial_dims = (2,)
    onnx_supported: bool = True
    supported_num_spatial_dims = (2,)
    num_spatial_dims: int = 2
    features_last: bool = False
    model_type: int = ModelType.CONVOLUTIONAL
    register: bool = True
```
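
As an illustration of how these attributes can drive tooling, a minimal ONNX-export sketch using PyTorch's standard exporter — the constructor arguments here are assumptions for illustration, not the exact mfai signature:

```python
import torch

from mfai.torch.models.half_unet import HalfUNet

# hypothetical constructor arguments; see the settings class for real options
model = HalfUNet(in_channels=3, out_channels=1, input_shape=(64, 64))
model.eval()

if model.onnx_supported:
    dummy_input = torch.rand(1, 3, 64, 64)  # (Batch, features, Height, Width)
    torch.onnx.export(model, dummy_input, "half_unet.onnx")
```
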
</details>

Currently we support the following neural network architectures:

| Model | Research Paper | Input Shape | ONNX exportable ? | Notes | Use-Cases at MF | Maintainer(s) |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [DeepLabV3Plus](mfai/torch/models/deeplabv3.py#L1) | [arxiv link](https://arxiv.org/abs/1802.02611) | (Batch, features, Height, Width) | Yes | Has a very large receptive field versus U-Net, Half-UNet, ... | Front Detection, Nowcasting | Theo Tournier / Frank Guibert |
| [HalfUNet](mfai/torch/models/half_unet.py#L1) | [researchgate link](https://www.researchgate.net/publication/361186968_Half-UNet_A_Simplified_U-Net_Architecture_for_Medical_Image_Segmentation) | (Batch, features, Height, Width) | Yes | In prod/oper on [Espresso](https://www.mdpi.com/2674-0494/2/4/25) V2 with 128 filters and standard conv blocks instead of ghost | Satellite channels to rain estimation | Frank Guibert |
| [UNet](mfai/torch/models/unet.py#L1) | [arxiv link](https://arxiv.org/pdf/1505.04597.pdf) | (Batch, features, Height, Width) | Yes | Vanilla U-Net | Radar image cleaning | Theo Tournier / Frank Guibert |
| [CustomUnet](mfai/torch/models/unet.py#L1) | [arxiv link](https://arxiv.org/pdf/1505.04597.pdf) | (Batch, features, Height, Width) | Yes | U-Net like architecture with a variety of resnet encoder choices | Radar image cleaning | Theo Tournier |
| [Segformer](mfai/torch/models/segformer.py#L1) | [arxiv link](https://arxiv.org/abs/2105.15203) | (Batch, features, Height, Width) | Yes | On par with u-net like on Deepsyg (MF internal), added an upsampling stage. Adapted from [Lucidrains' github](https://github.com/lucidrains/segformer-pytorch) | Segmentation tasks | Frank Guibert |
| [SwinUNETR](mfai/torch/models/swinunetr.py#L1) | [arxiv link](https://arxiv.org/abs/2201.01266) | (Batch, features, Height, Width) | No | 2D Swin UNet transformer (Pangu and archweather use customised 3D versions of Swin Transformers). Plugged in from [MONAI](https://github.com/Project-MONAI/MONAI/). The decoders have been modified to use Bilinear2D + Conv2d instead of Conv2dTranspose to remove artefacts/checkerboard effects | Segmentation tasks | Frank Guibert |
| [UNETRPP](mfai/torch/models/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) or (Batch, features, Height, Width, Depth) | Yes | Vision transformer with a reduced GFLOPS footprint adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both with 2d and 3d inputs | Front Detection | Frank Guibert |

# SegmentationLightningModule

@@ -93,7 +127,58 @@
PyTorch provides an experimental feature called [**named tensors**](https://pyto

NamedTensors are a way to give names to dimensions of tensors and to keep track of the names of the physical/weather parameters along the features dimension.

The [**NamedTensor**](../py4cast/datasets/base.py#L38) class is a wrapper around a PyTorch tensor, it allows us to pass consistent object linking data and metadata with extra utility methods (concat along features dimension, flatten in place, ...). See the implementation [here](../py4cast/datasets/base.py#L38) and usage for plots [here](../py4cast/observer.py)
The [**NamedTensor**](mfai/torch/namedtensor.py#L28) class is a wrapper around a PyTorch tensor with additional attributes and methods; it allows us to pass around a consistent object linking data and metadata, with extra utility methods (concat along the features dimension, flatten in place, ...).

An example of NamedTensor usage for gridded data on a 256x256 grid:

```python
import torch

from mfai.torch.namedtensor import NamedTensor

tensor = torch.rand(4, 256, 256, 3)

nt = NamedTensor(
    tensor,
    names=["batch", "lat", "lon", "features"],
    feature_names=["u", "v", "t2m"],
)

print(nt.dim_size("lat"))
# 256

nt2 = NamedTensor(
    torch.rand(4, 256, 256, 1),
    names=["batch", "lat", "lon", "features"],
    feature_names=["q"],
)

# concat along the features dimension
nt3 = nt | nt2

# index by feature name
nt3["u"]

# Create a new NamedTensor with the same names but different data (useful for autoregressive models)
nt4 = NamedTensor.new_like(torch.rand(4, 256, 256, 4), nt3)

# Flatten in place the lat and lon dimensions and rename the new dim to 'ngrid';
# this is typically used to feed our gridded data to GNNs
nt3.flatten_("ngrid", 1, 2)

# str representation of the NamedTensor yields useful statistics
>>> print(nt)
--- NamedTensor ---
Names: ['batch', 'lat', 'lon', 'features']
Tensor Shape: torch.Size([4, 256, 256, 3])
Features:
┌────────────────┬─────────────┬──────────┐
│ Feature name   │ Min         │ Max      │
├────────────────┼─────────────┼──────────┤
│ u              │ 1.3113e-06  │ 0.999996 │
│ v              │ 8.9407e-07  │ 0.999997 │
│ t2m            │ 5.06639e-06 │ 0.999995 │

# rearrange in place using einops-like syntax
nt3.rearrange_("batch ngrid features -> batch features ngrid")
```

# Metrics

@@ -113,14 +198,22 @@

```bash
cd mfai
pip install -e .
```

## Using pip (experimental)
## Using pip

You can install with pip, targeting the main branch:

```bash
pip install git+https://github.com/meteofrance/mfai
```

We have a first release on testpypi, you can install it with:
If you want to target a specific tag/version or branch:

```bash
pip install --index-url https://test.pypi.org/simple/ mfai
pip install git+https://github.com/meteofrance/mfai@v1.0.1
```

This syntax also works in **requirements.txt**. We do not provide wheels on PyPI for now.
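
For example, a hypothetical **requirements.txt** entry pinning the same tag (standard pip VCS-requirement syntax):

```
mfai @ git+https://github.com/meteofrance/mfai@v1.0.1
```
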

# Usage

## Instantiate a model
@@ -130,7 +223,6 @@
Our [unit tests](tests/test_models.py#L39) provide an example of how to use the
The last parameter is an instance of the model's settings class and is a keyword argument with a default value set to the default settings.



Here is an example of how to instantiate the UNet model with a 3-channel input (like an RGB image) and a 1-channel output, using its default settings (the example itself is collapsed in this diff view; a sketch follows below):

@@ -441,62 +533,6 @@

```python
# preds and target are assumed to be prediction and ground-truth tensors
csi_metric = CSINeighborood(task="multiclass", num_classes=2, num_neighbors=0)
csi = csi_metric(preds, target)
```
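
A minimal sketch of such a UNet instantiation — the constructor arguments are illustrative assumptions, not the verbatim README example:

```python
import torch

from mfai.torch.models.unet import UNet

# hypothetical arguments: 3 input channels (e.g. an RGB image), 1 output channel,
# default settings; the exact signature may differ in the repository
unet = UNet(in_channels=3, out_channels=1, input_shape=(64, 64))
y = unet(torch.rand(1, 3, 64, 64))  # (Batch, features, Height, Width)
print(y.shape)
```
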



# Running Tests

Our tests are written using [pytest](https://docs.pytest.org). We check that:
2 changes: 1 addition & 1 deletion mfai/config/cli_fit_test.yaml
@@ -1,7 +1,7 @@
seed_everything: true
model:
  model:
    class_path: mfai.torch.models.Segformer
    class_path: mfai.torch.models.segformer.Segformer
    init_args:
      in_channels: 2
      out_channels: 1
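
For context, a config like this is consumed by a LightningCLI-style entry point. A hedged sketch of such an entry point — the import path of SegmentationLightningModule below is an assumption for illustration, not mfai's actual CLI:

```python
# Hypothetical LightningCLI entry point; the module path below is an assumption.
from lightning.pytorch.cli import LightningCLI

from mfai.torch.segmentation_module import SegmentationLightningModule  # assumed path

if __name__ == "__main__":
    # run e.g.: python cli.py fit --config mfai/config/cli_fit_test.yaml
    LightningCLI(SegmentationLightningModule)
```
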
48 changes: 26 additions & 22 deletions mfai/torch/models/__init__.py
@@ -1,25 +1,31 @@
import importlib
import pkgutil
from pathlib import Path
from typing import Optional, Tuple

from torch import nn

from .deeplabv3 import DeepLabV3, DeepLabV3Plus
from .half_unet import HalfUNet
from .segformer import Segformer
from .swinunetr import SwinUNETR
from .unet import CustomUnet, UNet
from .unetrpp import UNETRPP

all_nn_architectures = (
    DeepLabV3,
    DeepLabV3Plus,
    HalfUNet,
    Segformer,
    SwinUNETR,
    UNet,
    CustomUnet,
    UNETRPP,
)
from .base import ModelABC


# Load all models from the torch.models package
# which are ModelABC subclasses and have the register attribute set to True
registry = dict()
package = importlib.import_module("mfai.torch.models")
for _, name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
    module = importlib.import_module(name)
    for object_name, kls in module.__dict__.items():
        if (
            isinstance(kls, type)
            and issubclass(kls, ModelABC)
            and kls != ModelABC
            and kls.register
        ):
            if kls.__name__ in registry:
                raise ValueError(
                    f"Model {kls.__name__} from plugin {object_name} already exists in the registry."
                )
            registry[kls.__name__] = kls
all_nn_architectures = list(registry.values())


def load_from_settings_file(
@@ -34,13 +40,11 @@
def load_from_settings_file(
"""

    # pick the class matching the supplied name
    model_kls = next(
        (kls for kls in all_nn_architectures if kls.__name__ == model_name), None
    )
    model_kls = registry.get(model_name, None)

    if model_kls is None:
        raise ValueError(
            f"Model {model_name} not found in available architectures: {[x.__name__ for x in all_nn_architectures]}"
            f"Model {model_name} not found in available architectures: {[x for x in registry]}. Make sure the model's `register` attribute is set to True (default is False)."
        )

    # load the settings
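
For illustration, a small sketch of how the resulting registry can be inspected at runtime (`registry` and `all_nn_architectures` are the module-level names defined above):

```python
from mfai.torch.models import all_nn_architectures, registry

# class names discovered by walking mfai.torch.models and keeping
# ModelABC subclasses whose `register` attribute is True
print(sorted(registry))
print([kls.__name__ for kls in all_nn_architectures])
```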