Skip to content

Commit

Permalink
Merge pull request #1 from dashstander/ifft
Browse files Browse the repository at this point in the history
Fast Inverse Fourier Transform
  • Loading branch information
dashstander authored Sep 25, 2024
2 parents e869a70 + 0360516 commit 7a94751
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 102 deletions.
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Analyzing data on the symmetric group has applications in many fields including
## Features

- Efficient implementation of the $S_n$ FFT algorithm in PyTorch
- Support for both (fast) forward and (slow right now) inverse transforms.
- Support for both forward and inverse transforms.
- Utilities for working with permutations and representations of $S_n$
- Examples and tests demonstrating usage and correctness

Expand All @@ -34,7 +34,7 @@ Here's a basic example of how to use the $S_n$ FFT. For more in-depth examples o

```python
import torch
from algebraist import sn_fft
from algebraist import sn_fft, sn_ifft

# Create a function on S5 (represented as a tensor of size 120)
n = 5
Expand All @@ -44,6 +44,13 @@ fn = torch.randn(120)
ft = sn_fft(fn, n)

# ft is now a dictionary mapping partitions to their Fourier transforms
for partition, ft_matrix in ft.items():
# The frequencies of the Sn Fourier transform are the partitions of n
print(partition) # The partitions of 5 are (5,), (4, 1), (3, 1, 1), (2, 2, 1), (2, 1, 1, 1), and (1, 1, 1, 1, 1)
print(ft_matrix) # because S_n isn't abelian the output of the Fourier transform for each partition is a matrix

# The Fourier transform is completely invertible
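# (compare with a tolerance: `fn == ...` is element-wise and cannot be used directly in an assert)
assert torch.allclose(fn, sn_ifft(ft, n).to(fn.dtype), atol=1e-4)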
```

## Requirements
Expand Down
4 changes: 2 additions & 2 deletions src/algebraist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from algebraist.fourier import sn_fft
from algebraist.fourier import sn_fft, sn_ifft, sn_fourier_decomposition, calc_power


__all__ = ['sn_fft']
__all__ = ['sn_fft', 'sn_ifft', 'sn_fourier_decomposition', 'calc_power']
131 changes: 103 additions & 28 deletions src/algebraist/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,24 @@ def get_all_irreps(n: int) -> list[SnIrrep]:
return [SnIrrep(n, p) for p in generate_partitions(n)]


def sn_minus_1_coset(tensor: torch.Tensor, sn_perms: torch.Tensor, idx: int) -> torch.Tensor:
def lift_from_coset(
    lifted_fn: torch.Tensor,
    coset_fn: torch.Tensor,
    sn_perms: torch.Tensor,
    idx: int,
) -> None:
    """
    Inverse operation of restrict_to_coset. Assigns values from one S_{n-1} coset back to their
    positions in a function on S_n.

    The coset is the set of rows of ``sn_perms`` whose entry in column ``idx`` equals ``n - 1``.

    Args:
        lifted_fn (torch.Tensor): The function on S_n that is filled in place; its last axis is indexed
            by the rows of ``sn_perms`` (presumably shape (batch_size, n!) -- TODO confirm against callers)
        coset_fn (torch.Tensor): The function on the S_{n-1} cosets, indexed by coset on its first axis
            (presumably shape (n, batch_size, (n-1)!) -- TODO confirm against callers)
        sn_perms (torch.Tensor): A tensor-version of S_n with shape (n!, n)
        idx (int): The column of ``sn_perms`` that must hold the element ``n - 1``; also selects
            ``coset_fn[idx]`` as the values to write
    Returns:
        None, operates in place on lifted_fn
    """
    n = sn_perms.shape[1]
    fixed_element = n - 1
    # argwhere on a 1D mask returns shape (k, 1). Drop only that trailing axis: a bare squeeze()
    # would also collapse k == 1 (the n == 2 case) to a 0-dim index and break the assignment below.
    coset_idx = torch.argwhere(sn_perms[:, idx] == fixed_element).squeeze(-1)
    lifted_fn[:, coset_idx] = coset_fn[idx]


def restrict_to_coset(tensor: torch.Tensor, sn_perms: torch.Tensor, idx: int) -> torch.Tensor:
"""
Returns the values that a function on S_n takes on of one of the cosets of S_{n-1} < S_n
Expand All @@ -61,8 +78,6 @@ def sn_minus_1_coset(tensor: torch.Tensor, sn_perms: torch.Tensor, idx: int) ->
fixed_element = n - 1
coset_idx = torch.argwhere(sn_perms[:, idx] == fixed_element).squeeze()
return tensor[..., coset_idx]



def slow_sn_ft(fn_vals: torch.Tensor, n: int):
"""
Expand All @@ -88,7 +103,7 @@ def slow_sn_ft(fn_vals: torch.Tensor, n: int):
result = torch.einsum('bi,i->b', fn_vals, matrices)
else: # Higher-dimensional representation
result = torch.einsum('bi,ijk->bjk', fn_vals, matrices).squeeze()
results[irrep.shape] = result
results[irrep.partition] = result

return results

Expand Down Expand Up @@ -130,6 +145,8 @@ def _inverse_fourier_projection(ft: torch.Tensor, irrep: SnIrrep):
"""

matrices = irrep.matrix_tensor(ft.dtype, ft.device)
if ft.dim() < 2:
ft = ft.unsqueeze(0)

if irrep.dim == 1:
result = torch.einsum('...i,g->...ig', ft, matrices).squeeze()
Expand All @@ -142,6 +159,67 @@ def _inverse_fourier_projection(ft: torch.Tensor, irrep: SnIrrep):
return result


def inverse_fourier_projection(ft: torch.Tensor, irrep: SnIrrep) -> torch.Tensor:
    """
    Recursive inverse Fourier projection for a single irrep of S_n.

    When ``n <= BASE_CASE`` or the irrep is one-dimensional, this defers to
    ``_inverse_fourier_projection`` and divides by ``n!``. Otherwise the Fourier matrix is
    multiplied by the transposed coset-representative matrices, cut into the diagonal blocks
    given by ``irrep.get_block_indices()``, each block is passed recursively to the matching
    irrep of S_{n-1}, and the results are written into a tensor over S_n with ``lift_from_coset``.

    Args:
        ft (torch.Tensor): The Fourier transform at ``irrep``. A 2D input is treated as unbatched
            and given a leading batch axis (presumably the shapes are (irrep.dim, irrep.dim) and
            (batch, irrep.dim, irrep.dim) -- TODO confirm against callers)
        irrep (SnIrrep): The irrep whose Fourier component is being inverted
    Returns:
        torch.Tensor: On the recursive path, a tensor allocated as (batch_dim, n!) and divided by
            ``n!``; it is squeezed when the input was unbatched
    """
    n = irrep.n
    if n <= BASE_CASE or irrep.dim == 1:
        return _inverse_fourier_projection(ft, irrep) / math.factorial(n)

    sn_perms = generate_all_permutations(n)

    # Ensure ft is always 3D (batch_dim, irrep.dim, irrep.dim)
    has_batch = True
    if ft.dim() == 2:
        has_batch = False
        ft = ft.unsqueeze(0)

    batch_dim = ft.shape[0]

    # Transposes of the coset representative matrices, stacked with a leading axis of size 1
    # (presumably the transposes are the inverses, i.e. the matrices are orthogonal -- TODO confirm)
    coset_rep_matrices = torch.stack([mat.T for mat in irrep.coset_rep_matrices(ft.dtype, ft.device)]).unsqueeze(0)
    #assert coset_rep_matrices.shape == (1, n, irrep.dim, irrep.dim), \
    # f'{coset_rep_matrices.shape} != {(1, n, irrep.dim, irrep.dim)}'

    # equivalent to [coset_rep_inverse @ ft for coset_rep_inverse in cosets]
    # we have now translated the Fourier transform to be amenable to the S_{n-1} basis
    # NOTE(review): if coset_rep_matrices is (1, n, d, d) and ft.unsqueeze(1) is (batch, 1, d, d),
    # matmul broadcasting would give (batch, n, d, d) rather than the (n, batch, d, d) in the
    # commented-out assert below -- confirm which layout is intended
    coset_fts = torch.matmul(coset_rep_matrices, ft.unsqueeze(1))
    #assert coset_fts.shape == (n, batch_dim, irrep.dim, irrep.dim), \
    # f'{coset_fts.shape} != {(n, batch_dim, irrep.dim, irrep.dim)}'

    # One S_{n-1} irrep for every partition returned by irrep.split_partition()
    split_irreps = [SnIrrep(n-1, sub_partition) for sub_partition in irrep.split_partition()]


    # We have a big (irrep.dim, irrep.dim) matrix as the result of the forward FFT.
    # With respect to the S_{n-1} irreps it has a block structure; here we pull those blocks out.
    # We scale by (irrep.dim / (sub_irrep.dim * n)) to keep things working nicely with the recursion.
    # Non-recursively we scale by irrep.dim; we divide here to cancel that out. Basically, at the very
    # top level we only want the top or "main" irrep dim to contribute.
    sub_ft_blocks = [
        (irrep.dim / (sub_irrep.dim * n)) * coset_fts[..., rows, cols]
        for (rows, cols), sub_irrep in zip(irrep.get_block_indices(), split_irreps)
    ]

    # recursive call here, mapped over the leading axis of each block
    sub_ifts = [
        torch.vmap(inverse_fourier_projection, in_dims=(0, None))(block, sub_irrep)
        for block, sub_irrep in zip(sub_ft_blocks, split_irreps)
    ]
    # sub_ifts holds one entry per split irrep (it is built by zipping over split_irreps)
    # assert all([ift.shape == (batch_dim, math.factorial(n-1)) for ift in sub_ifts])

    fn_vals = torch.zeros((batch_dim, math.factorial(n)), dtype=ft.dtype, device=ft.device)

    # Write the S_{n-1} results back into their positions in the function on S_n.
    # NOTE(review): i enumerates the split irreps, yet lift_from_coset uses it as the coset index
    # (a column of sn_perms and the first axis of coset_ift), and each call assigns rather than
    # accumulates -- confirm this is the intended combination of the sub-irrep contributions
    for i, coset_ift in enumerate(sub_ifts):
        # operates in place on fn_vals
        lift_from_coset(fn_vals, coset_ift, sn_perms, i)

    if not has_batch:
        fn_vals = fn_vals.squeeze()

    return fn_vals / math.factorial(n)



def fourier_projection(fn_vals: torch.Tensor, irrep: SnIrrep) -> torch.Tensor:
"""
Fast projection of a function on S_n (given as a pytorch tensor) onto one of the irreducible representations (irreps) of S_n. If n > 5 then
Expand Down Expand Up @@ -172,7 +250,7 @@ def fourier_projection(fn_vals: torch.Tensor, irrep: SnIrrep) -> torch.Tensor:
has_batch = False
fn_vals = fn_vals.unsqueeze(0)

coset_fns = torch.stack([sn_minus_1_coset(fn_vals, sn_perms, i) for i in range(n)]).permute(1, 0, 2)
coset_fns = torch.stack([restrict_to_coset(fn_vals, sn_perms, i) for i in range(n)]).permute(1, 0, 2)
# Now coset_fns shape is (batch_dim, n, (n-1)!)
# assert coset_fns.shape == (fn_vals.shape[0], n, math.factorial(n-1)), coset_fns.shape

Expand All @@ -196,7 +274,7 @@ def fourier_projection(fn_vals: torch.Tensor, irrep: SnIrrep) -> torch.Tensor:

if not has_batch:
result = result.squeeze(0)

return result


Expand All @@ -206,10 +284,20 @@ def sn_fft(fn_vals: torch.Tensor, n: int, verbose=False) -> dict[tuple[int, ...]
if verbose:
all_irreps = tqdm(all_irreps)
for irrep in all_irreps:
result[irrep.shape] = fourier_projection(fn_vals, irrep)
result[irrep.partition] = fourier_projection(fn_vals, irrep)
return result


def sn_fourier_decomposition(ft, n):
    """
    Invert the Fourier transform one irrep at a time, keeping the contributions separate.

    Args:
        ft: The Fourier transform, indexed by partition (presumably the dictionary returned by
            sn_fft -- TODO confirm)
        n (int): Which symmetric group S_n the transform lives on
    Returns:
        dict: Maps the partition of each irrep yielded by SnIrrep.generate_all_irreps(n) to the
            result of inverse_fourier_projection for that irrep
    """
    components = {}
    for irrep in SnIrrep.generate_all_irreps(n):
        partition = irrep.partition
        components[partition] = inverse_fourier_projection(ft[partition], irrep)
    return components


def sn_ifft(ft, n):
    """
    Inverse Fourier transform on S_n: the sum of the per-irrep components returned by
    sn_fourier_decomposition.

    Args:
        ft: The Fourier transform, passed through unchanged to sn_fourier_decomposition
        n (int): Which symmetric group S_n the transform lives on
    Returns:
        The sum of all components, accumulated starting from the integer 0
    """
    total = 0
    for component in sn_fourier_decomposition(ft, n).values():
        total = total + component
    return total


def slow_sn_ift(ft, n: int):
"""
Expand All @@ -224,24 +312,17 @@ def slow_sn_ift(ft, n: int):
"""
permutations = Permutation.full_group(n)
group_order = len(permutations)
irreps = {shape: SnIrrep(n, shape).matrix_tensor() for shape in ft.keys()}
trivial_irrep = (n,)
sign_irrep = tuple([1] * n)

irreps = {shape: SnIrrep(n, shape) for shape in ft.keys()}

batch_size = ft[(n - 1, 1)].shape[0] if len(ft[(n - 1, 1)].shape) == 3 else None
if batch_size is None:
ift = torch.zeros((group_order,), device=ft[(n - 1, 1)].device)
else:
ift = torch.zeros((batch_size, group_order), device=ft[(n - 1, 1)].device)
for shape, irrep_ft in ft.items():
# Properly the formula says we should multiply by $rho(g^{-1})$, i.e. the transpose here
inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device)
if shape == trivial_irrep or shape == sign_irrep:
ift += torch.einsum('...i,g->...ig', irrep_ft, inv_rep).squeeze()
else:
dim = inv_rep.shape[-1]
# But this contracts the tensors in the correct order without the transpose
ift += dim * torch.einsum('...ij,gij->...g', irrep_ft, inv_rep)
#inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device)
ift += _inverse_fourier_projection(irrep_ft, irreps[shape])

return (ift / group_order)

Expand All @@ -259,10 +340,7 @@ def slow_sn_fourier_decomposition(ft, n: int):
"""
permutations = Permutation.full_group(n)
group_order = len(permutations)
irreps = {shape: SnIrrep(n, shape).matrix_tensor() for shape in ft.keys()}
trivial_irrep = (n,)
sign_irrep = tuple([1] * n)

irreps = {shape: SnIrrep(n, shape) for shape in ft.keys()}
num_irreps = len(ft.keys())
batch_size = ft[(n - 1, 1)].shape[0] if len(ft[(n - 1, 1)].shape) == 3 else None

Expand All @@ -272,12 +350,9 @@ def slow_sn_fourier_decomposition(ft, n: int):
ift = torch.zeros((num_irreps, batch_size, group_order), device=ft[(n - 1, 1)].device)

for i, (shape, irrep_ft) in enumerate(ft.items()):
inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device)
if shape == trivial_irrep or shape == sign_irrep: # One-dimensional representation
ift[i] = torch.einsum('...i,g->...ig', irrep_ft, inv_rep).squeeze()
else: # Higher-dimensional representation
dim = inv_rep.shape[-1]
ift[i] += dim * torch.einsum('...ij,gij->...g', irrep_ft, inv_rep)
#inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device)
ift[i] = _inverse_fourier_projection(irrep_ft, irreps[shape])

if batch_size is not None:
ift = ift.permute(1, 0, 2)
return (ift / group_order).squeeze()
Expand Down
63 changes: 34 additions & 29 deletions src/algebraist/irreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.


from copy import deepcopy
from functools import cached_property, reduce
from itertools import combinations, pairwise
import numpy as np
Expand All @@ -22,11 +21,11 @@
from typing import Iterator, Self

from algebraist.permutations import Permutation
from algebraist.tableau import enumerate_standard_tableau, generate_partitions, hook_length, YoungTableau
from algebraist.tableau import enumerate_standard_tableau, generate_partitions, hook_length, youngs_lattice_covering_relation, YoungTableau
from algebraist.utils import adj_trans_decomp, cycle_to_one_line, trans_to_one_line


def contiguous_cycle(n: int, i: int):
def contiguous_cycle(n: int, i: int) -> tuple[int, ...]:
""" Generates a permutation (in cycle notation) of the form (i, i+1, ..., n)
"""
if i == n - 1:
Expand All @@ -40,8 +39,8 @@ class SnIrrep:

def __init__(self, n: int, partition: tuple[int, ...]):
self.n = n
self.shape = partition
self.dim = hook_length(self.shape)
self.partition = partition
self.dim = hook_length(self.partition)
self.permutations = Permutation.full_group(n)

@staticmethod
Expand All @@ -50,37 +49,38 @@ def generate_all_irreps(n: int) -> Iterator[Self]:
yield SnIrrep(n, partition)

def __eq__(self, other) -> bool:
return self.shape == other.shape
return self.partition == other.shape

def __hash__(self) -> int:
return hash(str(self.shape))
return hash(str(self.partition))

def __repr__(self) -> str:
return f'S{self.n} Irrep: {self.shape}'
return f'S{self.n} Irrep: {self.partition}'

@cached_property
def basis(self) -> list[YoungTableau]:
return sorted(enumerate_standard_tableau(self.shape))
return sorted(enumerate_standard_tableau(self.partition))

def split_partition(self) -> list[tuple[int, ...]]:
new_partitions = []
k = len(self.shape)
for i in range(k - 1):
# check if valid subrepresentation
if self.shape[i] > self.shape[i+1]:
# if so, copy, modify, and append to list
partition = list(deepcopy(self.shape))
partition[i] -= 1
new_partitions.append(tuple(partition))
# the last subrep
partition = list(deepcopy(self.shape))
if partition[-1] > 1:
partition[-1] -= 1
else:
# removing last element of partition if it’s a 1
del partition[-1]
new_partitions.append(tuple(partition))
return sorted(new_partitions)
"""A list of the partitions directly underneath the partition that defines this irrep, in terms of Young's lattice.
These partitions directly beneath self.partition define the irreducible representations of S_{n-1} that this irrep "splits" into when we restrict to S_{n-1}. This relationship forms the core of the "fast" part of the FFT.
"""
return sorted(youngs_lattice_covering_relation(self.partition))

def get_block_indices(self):
    """Return the index ranges of the diagonal blocks this irrep has when restricted to S_{n-1}.

    There is one block per partition returned by ``self.split_partition()``, taken in that order.
    Each block is a ``(row_slice, col_slice)`` pair; both slices cover the same range, whose length
    is the dimension of the corresponding irrep of S_{n-1}, and consecutive blocks are contiguous
    starting from index 0.
    """
    blocks = []
    offset = 0
    for sub_partition in self.split_partition():
        end = offset + SnIrrep(self.n - 1, sub_partition).dim
        # Rows and columns advance together because the blocks sit on the diagonal
        blocks.append((slice(offset, end), slice(offset, end)))
        offset = end
    return blocks

def adjacent_transpositions(self) -> list[tuple[int, int]]:
return pairwise(range(self.n))
Expand Down Expand Up @@ -116,6 +116,8 @@ def generate_transposition_matrices(self) -> dict[tuple[int], ArrayLike]:
decomp = [matrices[pair] for pair in adj_trans_decomp(i, j)]
matrices[(i, j)] = reduce(lambda x, y: x @ y, decomp)
return matrices



@cached_property
def matrix_representations(self) -> dict[tuple[int], ArrayLike]:
Expand Down Expand Up @@ -145,9 +147,12 @@ def matrix_tensor(self, dtype=torch.float64, device=torch.device('cpu')) -> torc
]
return torch.concatenate(tensors, dim=0).squeeze().to(device)

def coset_rep_matrices(self, dtype=torch.float64) -> list[torch.Tensor]:
def coset_rep_matrices(self, dtype=torch.float64, device=torch.device('cpu')) -> list[torch.Tensor]:
coset_reps = [Permutation(contiguous_cycle(self.n, i)).sigma for i in range(self.n)]
return [torch.from_numpy(self.matrix_representations[rep]).to(dtype) for rep in coset_reps]
return [
torch.from_numpy(self.matrix_representations[rep]).to(dtype).to(device)
for rep in coset_reps
]

def alternating_matrix_tensor(self, dtype=torch.float64, device=torch.device('cpu')):
tensors = [
Expand Down
Loading

0 comments on commit 7a94751

Please sign in to comment.