Skip to content

Commit

Permalink
Begin implementing inverse FFT
Browse files Browse the repository at this point in the history
Signed-off-by: Dashiell Stander <[email protected]>
  • Loading branch information
dashstander committed Sep 23, 2024
1 parent 89d3597 commit e08b2f0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 25 deletions.
32 changes: 11 additions & 21 deletions src/algebraist/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,11 @@ def _inverse_fourier_projection(ft: torch.Tensor, irrep: SnIrrep):
"""

matrices = irrep.matrix_tensor(ft.dtype, ft.device)
if ft.dim() < 2:
ft = ft.unsqueeze(0)

if irrep.dim == 1:
print(ft.shape)
result = torch.einsum('...i,g->...ig', ft, matrices).squeeze()
else:
dim = irrep.dim
Expand Down Expand Up @@ -198,24 +201,17 @@ def slow_sn_ift(ft, n: int):
"""
permutations = Permutation.full_group(n)
group_order = len(permutations)
irreps = {shape: SnIrrep(n, shape).matrix_tensor() for shape in ft.keys()}
trivial_irrep = (n,)
sign_irrep = tuple([1] * n)

irreps = {shape: SnIrrep(n, shape) for shape in ft.keys()}

batch_size = ft[(n - 1, 1)].shape[0] if len(ft[(n - 1, 1)].shape) == 3 else None
if batch_size is None:
ift = torch.zeros((group_order,), device=ft[(n - 1, 1)].device)
else:
ift = torch.zeros((batch_size, group_order), device=ft[(n - 1, 1)].device)
for shape, irrep_ft in ft.items():
# Properly the formula says we should multiply by $rho(g^{-1})$, i.e. the transpose here
inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device)
if shape == trivial_irrep or shape == sign_irrep:
ift += torch.einsum('...i,g->...ig', irrep_ft, inv_rep).squeeze()
else:
dim = inv_rep.shape[-1]
# But this contracts the tensors in the correct order without the transpose
ift += dim * torch.einsum('...ij,gij->...g', irrep_ft, inv_rep)
#inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device)
ift += _inverse_fourier_projection(irrep_ft, irreps[shape])

return (ift / group_order)

Expand All @@ -233,10 +229,7 @@ def slow_sn_fourier_decomposition(ft, n: int):
"""
permutations = Permutation.full_group(n)
group_order = len(permutations)
irreps = {shape: SnIrrep(n, shape).matrix_tensor() for shape in ft.keys()}
trivial_irrep = (n,)
sign_irrep = tuple([1] * n)

irreps = {shape: SnIrrep(n, shape) for shape in ft.keys()}
num_irreps = len(ft.keys())
batch_size = ft[(n - 1, 1)].shape[0] if len(ft[(n - 1, 1)].shape) == 3 else None

Expand All @@ -246,12 +239,9 @@ def slow_sn_fourier_decomposition(ft, n: int):
ift = torch.zeros((num_irreps, batch_size, group_order), device=ft[(n - 1, 1)].device)

for i, (shape, irrep_ft) in enumerate(ft.items()):
inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device)
if shape == trivial_irrep or shape == sign_irrep: # One-dimensional representation
ift[i] = torch.einsum('...i,g->...ig', irrep_ft, inv_rep).squeeze()
else: # Higher-dimensional representation
dim = inv_rep.shape[-1]
ift[i] += dim * torch.einsum('...ij,gij->...g', irrep_ft, inv_rep)
#inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device)
ift[i] = _inverse_fourier_projection(irrep_ft, irreps[shape])

if batch_size is not None:
ift = ift.permute(1, 0, 2)
return (ift / group_order).squeeze()
Expand Down
8 changes: 4 additions & 4 deletions tests/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def generate_random_function(n, batch_size=None):
@pytest.mark.parametrize("batch_size", [None, 1, 5])
def test_fourier_transform_invertibility(n, batch_size):
f = generate_random_function(n, batch_size)
ft = slow_sn_ft(f, n)
ft = sn_fft(f, n)
ift = slow_sn_ift(ft, n)
f = f.squeeze()
assert ift.shape == f.shape
Expand All @@ -71,7 +71,7 @@ def test_fourier_decomposition(n, batch_size):
@pytest.mark.parametrize("batch_size", [None, 1, 5])
def test_fourier_transform_norm_preservation(n, batch_size):
f = generate_random_function(n, batch_size)
ft = slow_sn_ft(f, n)
ft = sn_fft(f, n)
power = calc_power(ft, n)
total_power = sum(p for p in power.values())
if batch_size is None:
Expand All @@ -88,7 +88,7 @@ def test_convolution_theorem(n):
# Compute convolution in group domain
conv_group = convolve(g, f, n)

ft_conv_time = slow_sn_ft(conv_group, n)
ft_conv_time = sn_fft(conv_group, n)

# Compute convolution in Fourier domain
ft_f = sn_fft(f, n)
Expand All @@ -115,7 +115,7 @@ def test_permutation_action(n):
permutation_action = [(perm.inverse * p).permutation_index() for p in permutations ]
# Action in group domain
f_perm = f[permutation_action]
ft_perm = slow_sn_ft(f_perm, n)
ft_perm = sn_fft(f_perm, n)

# Action in Fourier domain
ft_action = {}
Expand Down

0 comments on commit e08b2f0

Please sign in to comment.