Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better naming of customized models #7

Merged
merged 2 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cliffordlayers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys

if sys.version_info >= (3, 8):
from importlib import metadata
else:
import importlib_metadata as metadata

__version__ = metadata.version("cliffordlayers")
12 changes: 6 additions & 6 deletions cliffordlayers/models/custom_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from cliffordlayers.nn.functional.utils import _w_assert


def get_2d_clifford_encoding_kernel(
def get_2d_scalar_vector_encoding_kernel(
w: Union[tuple, list, torch.Tensor, nn.Parameter, nn.ParameterList], g: torch.Tensor
) -> Tuple[int, torch.Tensor]:
"""Clifford kernel for 2d Clifford algebra encoding layers.
Expand All @@ -40,7 +40,7 @@ def get_2d_clifford_encoding_kernel(
return 4, k


def get_2d_clifford_decoding_kernel(
def get_2d_scalar_vector_decoding_kernel(
w: Union[tuple, list, torch.Tensor, nn.Parameter, nn.ParameterList], g: torch.Tensor
) -> Tuple[int, torch.Tensor]:
"""Clifford kernel for 2d Clifford algebra decoding layers.
Expand All @@ -66,7 +66,7 @@ def get_2d_clifford_decoding_kernel(
return 3, k


def get_2d_clifford_rotation_encoding_kernel(
def get_2d_rotation_scalar_vector_encoding_kernel(
w: Union[tuple, list, torch.Tensor, nn.Parameter, nn.ParameterList], g: torch.Tensor
) -> Tuple[int, torch.Tensor]:
"""Rotational Clifford kernel for 2d Clifford algebra encoding layers.
Expand Down Expand Up @@ -146,7 +146,7 @@ def get_2d_clifford_rotation_encoding_kernel(
return 4, k


def get_2d_clifford_rotation_decoding_kernel(
def get_2d_rotation_scalar_vector_decoding_kernel(
w: Union[tuple, list, torch.Tensor, nn.Parameter, nn.ParameterList], g: torch.Tensor
) -> Tuple[int, torch.Tensor]:
"""Rotational Clifford kernel for 2d Clifford algebra decoding layers.
Expand Down Expand Up @@ -220,7 +220,7 @@ def get_2d_clifford_rotation_decoding_kernel(
return 3, k


def get_3d_clifford_encoding_kernel(
def get_3d_maxwell_encoding_kernel(
w: Union[tuple, list, torch.Tensor, nn.Parameter, nn.ParameterList], g: torch.Tensor
) -> Tuple[int, torch.Tensor]:
"""Clifford kernel for 3d Clifford algebra encoding layers.
Expand Down Expand Up @@ -269,7 +269,7 @@ def get_3d_clifford_encoding_kernel(
return 8, k


def get_3d_clifford_decoding_kernel(
def get_3d_maxwell_decoding_kernel(
w: Union[tuple, list, torch.Tensor, nn.Parameter, nn.ParameterList], g: torch.Tensor
) -> Tuple[int, torch.Tensor]:
"""Clifford kernel for 3d Clifford algebra decoding layers.
Expand Down
40 changes: 20 additions & 20 deletions cliffordlayers/models/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
CliffordConv3d,
)
from cliffordlayers.models.custom_kernels import (
get_2d_clifford_encoding_kernel,
get_2d_clifford_decoding_kernel,
get_2d_clifford_rotation_encoding_kernel,
get_2d_clifford_rotation_decoding_kernel,
get_3d_clifford_encoding_kernel,
get_3d_clifford_decoding_kernel,
get_2d_scalar_vector_encoding_kernel,
get_2d_scalar_vector_decoding_kernel,
get_2d_rotation_scalar_vector_encoding_kernel,
get_2d_rotation_scalar_vector_decoding_kernel,
get_3d_maxwell_encoding_kernel,
get_3d_maxwell_decoding_kernel,
)


class CliffordConv2dEncoder(CliffordConv2d):
"""2d Clifford convolution encoder which inherits from CliffordConv2d."""
class CliffordConv2dScalarVectorEncoder(CliffordConv2d):
"""2d Clifford convolution encoder for scalar+vector input fields which inherits from CliffordConv2d."""

def __init__(
self,
Expand Down Expand Up @@ -54,16 +54,16 @@ def __init__(
)

if rotation:
self._get_kernel = get_2d_clifford_rotation_encoding_kernel
self._get_kernel = get_2d_rotation_scalar_vector_encoding_kernel
else:
self._get_kernel = get_2d_clifford_encoding_kernel
self._get_kernel = get_2d_scalar_vector_encoding_kernel

def forward(self, x: torch.Tensor) -> torch.Tensor:
return super(CliffordConv2d, self).forward(x, F.conv2d)


class CliffordConv2dDecoder(CliffordConv2d):
"""2d Clifford convolution decoder which inherits from CliffordConv2d."""
class CliffordConv2dScalarVectorDecoder(CliffordConv2d):
"""2d Clifford convolution decoder for scalar+vector output fields which inherits from CliffordConv2d."""

def __init__(
self,
Expand Down Expand Up @@ -94,18 +94,18 @@ def __init__(
)

if rotation:
self._get_kernel = get_2d_clifford_rotation_decoding_kernel
self._get_kernel = get_2d_rotation_scalar_vector_decoding_kernel
else:
self._get_kernel = get_2d_clifford_decoding_kernel
self._get_kernel = get_2d_scalar_vector_decoding_kernel

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.bias is True:
raise ValueError("Bias needs to be set to False for 2d Clifford decoding layers.")
return super(CliffordConv2d, self).forward(x, F.conv2d)


class CliffordConv3dEncoder(CliffordConv3d):
"""3d Clifford convolution encoder which inherits from CliffordConv3d."""
class CliffordConv3dMaxwellEncoder(CliffordConv3d):
"""3d Clifford convolution encoder for vector+bivector inputs which inherits from CliffordConv3d."""

def __init__(
self,
Expand Down Expand Up @@ -133,14 +133,14 @@ def __init__(
padding_mode,
)

self._get_kernel = get_3d_clifford_encoding_kernel
self._get_kernel = get_3d_maxwell_encoding_kernel

def forward(self, x: torch.Tensor) -> torch.Tensor:
return super(CliffordConv3d, self).forward(x, F.conv3d)


class CliffordConv3dDecoder(CliffordConv3d):
"""3d Clifford convolution decoder which inherits from CliffordConv3d."""
class CliffordConv3dMaxwellDecoder(CliffordConv3d):
"""3d Clifford convolution decoder for vector+bivector inputs which inherits from CliffordConv3d."""

def __init__(
self,
Expand Down Expand Up @@ -168,7 +168,7 @@ def __init__(
padding_mode,
)

self._get_kernel = get_3d_clifford_decoding_kernel
self._get_kernel = get_3d_maxwell_decoding_kernel

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.bias is True:
Expand Down
16 changes: 8 additions & 8 deletions cliffordlayers/models/models_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from cliffordlayers.nn.modules.cliffordconv import CliffordConv2d
from cliffordlayers.nn.modules.cliffordfourier import CliffordSpectralConv2d
from cliffordlayers.nn.modules.groupnorm import CliffordGroupNorm2d
from cliffordlayers.models.custom_layers import CliffordConv2dDecoder, CliffordConv2dEncoder
from cliffordlayers.models.custom_layers import CliffordConv2dScalarVectorEncoder, CliffordConv2dScalarVectorDecoder


class CliffordBasicBlock2d(nn.Module):
Expand Down Expand Up @@ -147,12 +147,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.activation(self.norm(x1 + x2))


class CliffordNet2d(nn.Module):
"""2D building block for Clifford architectures with ResNet backbone network.
The backbone networks follows these three steps:
1. Clifford encoding.
class CliffordFluidNet2d(nn.Module):
"""2D building block for Clifford architectures for fluid mechanics (vector field+scalar field)
with ResNet backbone network. The backbone networks follows these three steps:
1. Clifford scalar+vector field encoding.
2. Basic blocks as provided.
3. Decoding.
3. Clifford scalar+vector field decoding.

Args:
g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
Expand Down Expand Up @@ -186,15 +186,15 @@ def __init__(

self.activation = activation
# Encoding and decoding layers
self.encoder = CliffordConv2dEncoder(
self.encoder = CliffordConv2dScalarVectorEncoder(
g,
in_channels=in_channels,
out_channels=hidden_channels,
kernel_size=1,
padding=0,
rotation=rotation,
)
self.decoder = CliffordConv2dDecoder(
self.decoder = CliffordConv2dScalarVectorDecoder(
g,
in_channels=hidden_channels,
out_channels=out_channels,
Expand Down
14 changes: 7 additions & 7 deletions cliffordlayers/models/models_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from cliffordlayers.nn.modules.cliffordconv import CliffordConv3d
from cliffordlayers.nn.modules.cliffordfourier import CliffordSpectralConv3d
from cliffordlayers.nn.modules.groupnorm import CliffordGroupNorm3d
from cliffordlayers.models.custom_layers import CliffordConv3dDecoder, CliffordConv3dEncoder
from cliffordlayers.models.custom_layers import CliffordConv3dMaxwellEncoder, CliffordConv3dMaxwellDecoder


class CliffordFourierBasicBlock3d(nn.Module):
"""2D building block for Clifford FNO architectures.
"""3D building block for Clifford FNO architectures.

Args:
g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
Expand Down Expand Up @@ -79,12 +79,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.activation(self.norm(x1 + x2))


class CliffordNet3d(nn.Module):
class CliffordMaxwellNet3d(nn.Module):
"""3D building block for Clifford architectures with ResNet backbone network.
The backbone networks follows these three steps:
1. Clifford encoding.
1. Clifford vector+bivector encoding.
2. Basic blocks as provided.
3. Decoding.
3. Clifford vector+bivector decoding.

Args:
g (Union[tuple, list, torch.Tensor]): Signature of Clifford algebra.
Expand Down Expand Up @@ -116,14 +116,14 @@ def __init__(

self.activation = activation
# Encoding and decoding layers.
self.encoder = CliffordConv3dEncoder(
self.encoder = CliffordConv3dMaxwellEncoder(
g,
in_channels=in_channels,
out_channels=hidden_channels,
kernel_size=1,
padding=0,
)
self.decoder = CliffordConv3dDecoder(
self.decoder = CliffordConv3dMaxwellDecoder(
g,
in_channels=hidden_channels,
out_channels=out_channels,
Expand Down
4 changes: 0 additions & 4 deletions cliffordlayers/version.py

This file was deleted.

26 changes: 13 additions & 13 deletions tests/test_CliffordNet2d.py → tests/test_CliffordFluidNet2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
import torch.nn.functional as F
from cliffordlayers.models.utils import partialclass
from cliffordlayers.models.models_2d import (
CliffordNet2d,
CliffordFluidNet2d,
CliffordBasicBlock2d,
CliffordFourierBasicBlock2d,
)


def test_clifford_resnet():
"""Test shape compatibility of Clifford2d ResNet model."""
"""Test shape compatibility of CliffordFluidNet2d ResNet model."""
x = torch.randn(8, 4, 128, 128, 3)
in_channels = 4
out_channels = 1
model = CliffordNet2d(
model = CliffordFluidNet2d(
g=[1, 1],
block=CliffordBasicBlock2d,
num_blocks=[2, 2, 2, 2],
Expand All @@ -35,11 +35,11 @@ def test_clifford_resnet():


def test_clifford_resnet_norm():
"""Test shape compatibility of Clifford2d ResNet model using normalization."""
"""Test shape compatibility of CliffordFluidNet2d ResNet model using normalization."""
in_channels = 4
out_channels = 1
x = torch.randn(8, in_channels, 128, 128, 3)
model = CliffordNet2d(
model = CliffordFluidNet2d(
g=[1, 1],
block=CliffordBasicBlock2d,
num_blocks=[2, 2, 2, 2],
Expand All @@ -58,11 +58,11 @@ def test_clifford_resnet_norm():


def test_clifford_rotational_resnet_norm():
"""Test shape compatibility of Clifford2d rotational ResNet model using normalization."""
"""Test shape compatibility of CliffordFluidNet2d rotational ResNet model using normalization."""
in_channels = 4
out_channels = 1
x = torch.randn(8, in_channels, 128, 128, 3)
model = CliffordNet2d(
model = CliffordFluidNet2d(
g=[-1, -1],
block=CliffordBasicBlock2d,
num_blocks=[2, 2, 2, 2],
Expand All @@ -81,11 +81,11 @@ def test_clifford_rotational_resnet_norm():


def test_clifford_fourier_net():
"""Test shape compatibility of Clifford2d Fourier model."""
"""Test shape compatibility of CliffordFluidNet2d Fourier model."""
in_channels = 4
out_channels = 1
x = torch.randn(8, in_channels, 128, 128, 3)
model = CliffordNet2d(
model = CliffordFluidNet2d(
g=[1, 1],
block=partialclass("CliffordFourierBasicBlock2d", CliffordFourierBasicBlock2d, modes1=32, modes2=32),
num_blocks=[1, 1, 1, 1],
Expand All @@ -104,11 +104,11 @@ def test_clifford_fourier_net():


def test_clifford_fourier_net_norm():
"""Test shape compatibility of Clifford2d Fourier model using normalization."""
"""Test shape compatibility of CliffordFluidNet2d Fourier model using normalization."""
in_channels = 4
out_channels = 1
x = torch.randn(8, in_channels, 128, 128, 3)
model = CliffordNet2d(
model = CliffordFluidNet2d(
g=[1, 1],
block=partialclass("CliffordFourierBasicBlock2d", CliffordFourierBasicBlock2d, modes1=32, modes2=32),
num_blocks=[1, 1, 1, 1],
Expand All @@ -127,11 +127,11 @@ def test_clifford_fourier_net_norm():


def test_clifford_fourier_rotational_net_norm():
"""Test shapes compatibility of Clifford2d Fourier model using normalization (and rotation)."""
"""Test shapes compatibility of CliffordFluidNet2d Fourier model using normalization (and rotation)."""
in_channels = 4
out_channels = 1
x = torch.randn(8, in_channels, 128, 128, 3)
model = CliffordNet2d(
model = CliffordFluidNet2d(
g=[-1, -1],
block=partialclass("CliffordFourierBasicBlock2d", CliffordFourierBasicBlock2d, modes1=32, modes2=32),
num_blocks=[1, 1, 1, 1],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import torch
import torch.nn.functional as F
from cliffordlayers.models.models_3d import (
CliffordNet3d,
CliffordMaxwellNet3d,
CliffordFourierBasicBlock3d,
)


def test_clifford_fourier_resnet():
"""Test shape compatibility of Clifford3d Fourier model."""
x = torch.randn(8, 4, 64, 64, 64, 6)
model = CliffordNet3d(
"""Test shape compatibility of CliffordMaxwellNet3d Fourier model."""
x = torch.randn(8, 4, 32, 32, 32, 6)
model = CliffordMaxwellNet3d(
g=[1, 1, 1],
block=CliffordFourierBasicBlock3d,
num_blocks=[1, 1, 1, 1],
Expand All @@ -26,13 +26,13 @@ def test_clifford_fourier_resnet():
x = x.to("cuda:0")
model = model.to("cuda:0")
out = model(x)
assert out.shape == (8, 1, 64, 64, 64, 6)
assert out.shape == (8, 1, 32, 32, 32, 6)


def test_clifford_fourier_net_norm():
"""Test shape compatibility of Clifford2d Fourier model using normalization."""
x = torch.randn(8, 4, 64, 64, 64, 6)
model = CliffordNet3d(
"""Test shape compatibility of CliffordMaxwellNet2d Fourier model using normalization."""
x = torch.randn(8, 4, 32, 32, 32, 6)
model = CliffordMaxwellNet3d(
g=[1, 1, 1],
block=CliffordFourierBasicBlock3d,
num_blocks=[1, 1, 1, 1],
Expand All @@ -47,4 +47,4 @@ def test_clifford_fourier_net_norm():
x = x.to("cuda:0")
model = model.to("cuda:0")
out = model(x)
assert out.shape == (8, 6, 64, 64, 64, 6)
assert out.shape == (8, 6, 32, 32, 32, 6)
Loading
Loading