diff --git a/README.md b/README.md index 4d3ec5d..7eabfcb 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Analyzing data on the symmetric group has applications in many fields including ## Features - Efficient implementation of the $S_n$ FFT algorithm in PyTorch -- Support for both (fast) forward and (slow right now) inverse transforms. +- Support for both forward and inverse transforms. - Utilities for working with permutations and representations of $S_n$ - Examples and tests demonstrating usage and correctness @@ -34,7 +34,7 @@ Here's a basic example of how to use the $S_n$ FFT. For more in-depth examples o ```python import torch -from algebraist import sn_fft +from algebraist import sn_fft, sn_ifft # Create a function on S5 (represented as a tensor of size 120) n = 5 @@ -44,6 +44,13 @@ fn = torch.randn(120) ft = sn_fft(fn, n) # ft is now a dictionary mapping partitions to their Fourier transforms +for partition, ft_matrix in ft.items(): + # The frequencies of the Sn Fourier transform are the partitions of n + print(partition) # The partitions of 5 are (5,), (4, 1), (3, 1, 1), (2, 2, 1), (2, 1, 1, 1), and (1, 1, 1, 1, 1) + print(ft_matrix) # because S_n isn't abelian the output of the Fourier transform for each partition is a matrix + +# The Fourier transform is completely invertible +assert fn == sn_ifft(ft, n) ``` ## Requirements diff --git a/src/algebraist/__init__.py b/src/algebraist/__init__.py index 051f454..b1e3a92 100644 --- a/src/algebraist/__init__.py +++ b/src/algebraist/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. -from algebraist.fourier import sn_fft +from algebraist.fourier import sn_fft, sn_ifft, sn_fourier_decomposition, calc_power -__all__ = ['sn_fft'] +__all__ = ['sn_fft', 'sn_ifft', 'sn_fourier_decomposition', 'calc_power'] diff --git a/src/algebraist/fourier.py b/src/algebraist/fourier.py index 1cb093e..88e0ac0 100644 --- a/src/algebraist/fourier.py +++ b/src/algebraist/fourier.py @@ -43,7 +43,24 @@ def get_all_irreps(n: int) -> list[SnIrrep]: return [SnIrrep(n, p) for p in generate_partitions(n)] -def sn_minus_1_coset(tensor: torch.Tensor, sn_perms: torch.Tensor, idx: int) -> torch.Tensor: +def lift_from_coset(lifted_fn, coset_fn: torch.Tensor, sn_perms: torch.Tensor, idx: int) -> torch.Tensor: + """ + Inverse operation of restrict_to_coset. Assigns values from S_{n-1} cosets back to their correct positions in S_n. + + Args: + tensor (torch.Tensor): The function on S_{n-1} cosets, shape (n, batch_size, (n-1)!) + sn_perms (torch.Tensor): A tensor-version of S_n with shape (n!, n) + + Returns: + None, operates in place on lifted_fn + """ + n = sn_perms.shape[1] + fixed_element = n - 1 + coset_idx = torch.argwhere(sn_perms[:, idx] == fixed_element).squeeze() + lifted_fn[:, coset_idx] = coset_fn[idx] + + +def restrict_to_coset(tensor: torch.Tensor, sn_perms: torch.Tensor, idx: int) -> torch.Tensor: """ Returns the values that a function on S_n takes on of one of the cosets of S_{n-1} < S_n @@ -61,8 +78,6 @@ def sn_minus_1_coset(tensor: torch.Tensor, sn_perms: torch.Tensor, idx: int) -> fixed_element = n - 1 coset_idx = torch.argwhere(sn_perms[:, idx] == fixed_element).squeeze() return tensor[..., coset_idx] - - def slow_sn_ft(fn_vals: torch.Tensor, n: int): """ @@ -88,7 +103,7 @@ def slow_sn_ft(fn_vals: torch.Tensor, n: int): result = torch.einsum('bi,i->b', fn_vals, matrices) else: # Higher-dimensional representation result = torch.einsum('bi,ijk->bjk', fn_vals, matrices).squeeze() - results[irrep.shape] = result + results[irrep.partition] = result return results @@ -130,6 +145,8 @@ def _inverse_fourier_projection(ft: torch.Tensor, irrep: SnIrrep): """ matrices = irrep.matrix_tensor(ft.dtype, ft.device) + if ft.dim() < 2: + ft = ft.unsqueeze(0) if irrep.dim == 1: result = torch.einsum('...i,g->...ig', ft, matrices).squeeze() @@ -142,6 +159,67 @@ def _inverse_fourier_projection(ft: torch.Tensor, irrep: SnIrrep): return result +def inverse_fourier_projection(ft, irrep): + n = irrep.n + if n <= BASE_CASE or irrep.dim == 1: + return _inverse_fourier_projection(ft, irrep) / math.factorial(n) + + sn_perms = generate_all_permutations(n) + + # Ensure ft is always 3D (batch_dim, irrep.dim, irrep.dim) + has_batch = True + if ft.dim() == 2: + has_batch = False + ft = ft.unsqueeze(0) + + batch_dim = ft.shape[0] + + # Inverse this time + coset_rep_matrices = torch.stack([mat.T for mat in irrep.coset_rep_matrices(ft.dtype, ft.device)]).unsqueeze(0) + #assert coset_rep_matrices.shape == (1, n, irrep.dim, irrep.dim), \ + # f'{coset_rep_matrices.shape} != {(1, n, irrep.dim, irrep.dim)}' + + # equivalent to [coset_rep_inverse @ ft for coset_rep_inverse in cosets] + # we have now translated the Fourier transform to be amenable to the S_{n-1} basis + coset_fts = torch.matmul(coset_rep_matrices, ft.unsqueeze(1)) + #assert coset_fts.shape == (n, batch_dim, irrep.dim, irrep.dim), \ + # f'{coset_fts.shape} != {(n, batch_dim, irrep.dim, irrep.dim)}' + + split_irreps = [SnIrrep(n-1, sub_partition) for sub_partition in irrep.split_partition()] + + + # We have a big (irrep.dim, irrep.dim) matrix as the result of the forward FFT + # With respect to the S_{n-1} irreps it has a block structure, here we pull those blocks out + # We scale by (irrep.dim / (sub_irrep.dim * n)) to keep things working nicely with the recursion. + # Non - recursively we scale by irrep.dim, we divide here to cancel that out. Basically at the very top level + # we only want the top or "main" irrep dim to contribute + sub_ft_blocks = [ + (irrep.dim / (sub_irrep.dim * n)) * coset_fts[..., rows, cols] + for (rows, cols), sub_irrep in zip(irrep.get_block_indices(), split_irreps) + ] + + # recursive call here + sub_ifts = [ + torch.vmap(inverse_fourier_projection, in_dims=(0, None))(block, sub_irrep) + for block, sub_irrep in zip(sub_ft_blocks, split_irreps) + ] + # there are n elements in sub_ifts, each is an ift on + # assert all([ift.shape == (batch_dim, math.factorial(n-1)) for ift in sub_ifts]) + + fn_vals = torch.zeros((batch_dim, math.factorial(n)), dtype=ft.dtype, device=ft.device) + + # reshapes from + for i, coset_ift in enumerate(sub_ifts): + # operates in place on fn_vals + lift_from_coset(fn_vals, coset_ift, sn_perms, i) + + if not has_batch: + fn_vals = fn_vals.squeeze() + + return fn_vals / math.factorial(n) + + + def fourier_projection(fn_vals: torch.Tensor, irrep: SnIrrep) -> torch.Tensor: """ Fast projection of a function on S_n (given as a pytorch tensor) onto one of the irreducible representations (irreps) of S_n. If n > 5 then @@ -172,7 +250,7 @@ def fourier_projection(fn_vals: torch.Tensor, irrep: SnIrrep) -> torch.Tensor: has_batch = False fn_vals = fn_vals.unsqueeze(0) - coset_fns = torch.stack([sn_minus_1_coset(fn_vals, sn_perms, i) for i in range(n)]).permute(1, 0, 2) + coset_fns = torch.stack([restrict_to_coset(fn_vals, sn_perms, i) for i in range(n)]).permute(1, 0, 2) # Now coset_fns shape is (batch_dim, n, (n-1)!) # assert coset_fns.shape == (fn_vals.shape[0], n, math.factorial(n-1)), coset_fns.shape @@ -196,7 +274,7 @@ def fourier_projection(fn_vals: torch.Tensor, irrep: SnIrrep) -> torch.Tensor: if not has_batch: result = result.squeeze(0) - + return result @@ -206,10 +284,20 @@ def sn_fft(fn_vals: torch.Tensor, n: int, verbose=False) -> dict[tuple[int, ...] if verbose: all_irreps = tqdm(all_irreps) for irrep in all_irreps: - result[irrep.shape] = fourier_projection(fn_vals, irrep) + result[irrep.partition] = fourier_projection(fn_vals, irrep) return result +def sn_fourier_decomposition(ft, n): + return { + irrep.partition: inverse_fourier_projection(ft[irrep.partition], irrep) + for irrep in SnIrrep.generate_all_irreps(n) + } + + +def sn_ifft(ft, n): + return sum(sn_fourier_decomposition(ft, n).values()) + def slow_sn_ift(ft, n: int): """ @@ -224,10 +312,8 @@ def slow_sn_ift(ft, n: int): """ permutations = Permutation.full_group(n) group_order = len(permutations) - irreps = {shape: SnIrrep(n, shape).matrix_tensor() for shape in ft.keys()} - trivial_irrep = (n,) - sign_irrep = tuple([1] * n) - + irreps = {shape: SnIrrep(n, shape) for shape in ft.keys()} + batch_size = ft[(n - 1, 1)].shape[0] if len(ft[(n - 1, 1)].shape) == 3 else None if batch_size is None: ift = torch.zeros((group_order,), device=ft[(n - 1, 1)].device) @@ -235,13 +321,8 @@ def slow_sn_ift(ft, n: int): ift = torch.zeros((batch_size, group_order), device=ft[(n - 1, 1)].device) for shape, irrep_ft in ft.items(): # Properly the formula says we should multiply by $rho(g^{-1})$, i.e. the transpose here - inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device) - if shape == trivial_irrep or shape == sign_irrep: - ift += torch.einsum('...i,g->...ig', irrep_ft, inv_rep).squeeze() - else: - dim = inv_rep.shape[-1] - # But this contracts the tensors in the correct order without the transpose - ift += dim * torch.einsum('...ij,gij->...g', irrep_ft, inv_rep) + #inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device) + ift += _inverse_fourier_projection(irrep_ft, irreps[shape]) return (ift / group_order) @@ -259,10 +340,7 @@ def slow_sn_fourier_decomposition(ft, n: int): """ permutations = Permutation.full_group(n) group_order = len(permutations) - irreps = {shape: SnIrrep(n, shape).matrix_tensor() for shape in ft.keys()} - trivial_irrep = (n,) - sign_irrep = tuple([1] * n) - + irreps = {shape: SnIrrep(n, shape) for shape in ft.keys()} num_irreps = len(ft.keys()) batch_size = ft[(n - 1, 1)].shape[0] if len(ft[(n - 1, 1)].shape) == 3 else None @@ -272,12 +350,9 @@ def slow_sn_fourier_decomposition(ft, n: int): ift = torch.zeros((num_irreps, batch_size, group_order), device=ft[(n - 1, 1)].device) for i, (shape, irrep_ft) in enumerate(ft.items()): - inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device) - if shape == trivial_irrep or shape == sign_irrep: # One-dimensional representation - ift[i] = torch.einsum('...i,g->...ig', irrep_ft, inv_rep).squeeze() - else: # Higher-dimensional representation - dim = inv_rep.shape[-1] - ift[i] += dim * torch.einsum('...ij,gij->...g', irrep_ft, inv_rep) + #inv_rep = irreps[shape].to(irrep_ft.dtype).to(irrep_ft.device) + ift[i] = _inverse_fourier_projection(irrep_ft, irreps[shape]) + if batch_size is not None: ift = ift.permute(1, 0, 2) return (ift / group_order).squeeze() diff --git a/src/algebraist/irreps.py b/src/algebraist/irreps.py index d91949e..a70db69 100644 --- a/src/algebraist/irreps.py +++ b/src/algebraist/irreps.py @@ -13,7 +13,6 @@ # limitations under the License. -from copy import deepcopy from functools import cached_property, reduce from itertools import combinations, pairwise import numpy as np @@ -22,11 +21,11 @@ from typing import Iterator, Self from algebraist.permutations import Permutation -from algebraist.tableau import enumerate_standard_tableau, generate_partitions, hook_length, YoungTableau +from algebraist.tableau import enumerate_standard_tableau, generate_partitions, hook_length, youngs_lattice_covering_relation, YoungTableau from algebraist.utils import adj_trans_decomp, cycle_to_one_line, trans_to_one_line -def contiguous_cycle(n: int, i: int): +def contiguous_cycle(n: int, i: int) -> tuple[int, ...]: """ Generates a permutation (in cycle notation) of the form (i, i+1, ..., n) """ if i == n - 1: @@ -40,8 +39,8 @@ class SnIrrep: def __init__(self, n: int, partition: tuple[int, ...]): self.n = n - self.shape = partition - self.dim = hook_length(self.shape) + self.partition = partition + self.dim = hook_length(self.partition) self.permutations = Permutation.full_group(n) @staticmethod @@ -50,37 +49,38 @@ def generate_all_irreps(n: int) -> Iterator[Self]: yield SnIrrep(n, partition) def __eq__(self, other) -> bool: - return self.shape == other.shape + return self.partition == other.shape def __hash__(self) -> int: - return hash(str(self.shape)) + return hash(str(self.partition)) def __repr__(self) -> str: - return f'S{self.n} Irrep: {self.shape}' + return f'S{self.n} Irrep: {self.partition}' @cached_property def basis(self) -> list[YoungTableau]: - return sorted(enumerate_standard_tableau(self.shape)) + return sorted(enumerate_standard_tableau(self.partition)) def split_partition(self) -> list[tuple[int, ...]]: - new_partitions = [] - k = len(self.shape) - for i in range(k - 1): - # check if valid subrepresentation - if self.shape[i] > self.shape[i+1]: - # if so, copy, modify, and append to list - partition = list(deepcopy(self.shape)) - partition[i] -= 1 - new_partitions.append(tuple(partition)) - # the last subrep - partition = list(deepcopy(self.shape)) - if partition[-1] > 1: - partition[-1] -= 1 - else: - # removing last element of partition if it’s a 1 - del partition[-1] - new_partitions.append(tuple(partition)) - return sorted(new_partitions) + """A list of the partitions directly underneath the partition that defines this irrep, in terms of Young's lattice. + These partitions directly beneath self.partition define the irreducible representations of S_{n-1} that this irrep "splits" into when we restrict to S_{n-1}. This relationship form the core of the "fast" part of the FFT. + """ + return sorted(youngs_lattice_covering_relation(self.partition)) + + def get_block_indices(self): + """When restricted to S_{n-1} this irrep has a block-diagonal form--one block for each of the split irreps of S_{n-1}. This helper method gets the indices of those blocks. + """ + curr_row, curr_col = 0, 0 + block_idx = [] + for split_irrep in self.split_partition(): + dim = SnIrrep(self.n - 1, split_irrep).dim + next_row = curr_row + dim + next_col = curr_col + dim + block_idx.append((slice(curr_row, next_row ), slice(curr_col, next_col))) + curr_row = next_row + curr_col = next_col + + return block_idx def adjacent_transpositions(self) -> list[tuple[int, int]]: return pairwise(range(self.n)) @@ -116,6 +116,8 @@ def generate_transposition_matrices(self) -> dict[tuple[int], ArrayLike]: decomp = [matrices[pair] for pair in adj_trans_decomp(i, j)] matrices[(i, j)] = reduce(lambda x, y: x @ y, decomp) return matrices + + @cached_property def matrix_representations(self) -> dict[tuple[int], ArrayLike]: @@ -145,9 +147,12 @@ def matrix_tensor(self, dtype=torch.float64, device=torch.device('cpu')) -> torc ] return torch.concatenate(tensors, dim=0).squeeze().to(device) - def coset_rep_matrices(self, dtype=torch.float64) -> list[torch.Tensor]: + def coset_rep_matrices(self, dtype=torch.float64, device=torch.device('cpu')) -> list[torch.Tensor]: coset_reps = [Permutation(contiguous_cycle(self.n, i)).sigma for i in range(self.n)] - return [torch.from_numpy(self.matrix_representations[rep]).to(dtype) for rep in coset_reps] + return [ + torch.from_numpy(self.matrix_representations[rep]).to(dtype).to(device) + for rep in coset_reps + ] def alternating_matrix_tensor(self, dtype=torch.float64, device=torch.device('cpu')): tensors = [ diff --git a/tests/test_fourier.py b/tests/test_fourier.py index 021ed94..f693311 100644 --- a/tests/test_fourier.py +++ b/tests/test_fourier.py @@ -5,7 +5,7 @@ import torch import algebraist from algebraist.fourier import ( - slow_sn_ft, slow_sn_ift, slow_sn_fourier_decomposition, sn_fft, calc_power + slow_sn_ft, slow_sn_ift, sn_fft, sn_ifft, sn_fourier_decomposition, calc_power ) from algebraist.permutations import Permutation from algebraist.irreps import SnIrrep @@ -40,12 +40,24 @@ def generate_random_function(n, batch_size=None): return torch.randn(math.factorial(n)) return torch.randn(batch_size, math.factorial(n)) + +def generate_random_fourier_transform(n, batch_size=None): + has_batch = batch_size is not None + batch_size = batch_size if batch_size else 1 + ft = {} + for irrep in SnIrrep.generate_all_irreps(n): + ft[irrep.partition] = torch.randn(batch_size, irrep.dim, irrep.dim) + if not has_batch: + ft[irrep.partition] = ft[irrep.partition].squeeze() + return ft + + @pytest.mark.parametrize("n", [3, 4, 5]) @pytest.mark.parametrize("batch_size", [None, 1, 5]) def test_fourier_transform_invertibility(n, batch_size): f = generate_random_function(n, batch_size) - ft = slow_sn_ft(f, n) - ift = slow_sn_ift(ft, n) + ft = sn_fft(f, n) + ift = sn_ifft(ft, n) f = f.squeeze() assert ift.shape == f.shape assert torch.allclose(f, ift, atol=1e-5), f"Fourier transform not invertible for n={n}, batch_size={batch_size}" @@ -54,24 +66,17 @@ def test_fourier_transform_invertibility(n, batch_size): @pytest.mark.parametrize("batch_size", [None, 1, 5]) def test_fourier_decomposition(n, batch_size): f = generate_random_function(n, batch_size) - ft = slow_sn_ft(f, n) - print(ft[(n -1, 1)].shape) - decomp = slow_sn_fourier_decomposition(ft, n) - if batch_size is not None and batch_size > 1: - assert decomp.shape == (batch_size, len(ft), math.factorial(n)) - else: - assert decomp.shape == (len(ft), math.factorial(n)) - reconstructed = decomp.sum(dim=-2) - if batch_size is None: - f = f.unsqueeze(0) - assert torch.allclose(f, reconstructed, atol=1e-5), f"Fourier decomposition failed for n={n}, batch_size={batch_size}" + ft = sn_fft(f, n) + decomp = sn_fourier_decomposition(ft, n) + + assert torch.allclose(f, sum(decomp.values()), atol=1e-5), f"Fourier decomposition failed for n={n}, batch_size={batch_size}" @pytest.mark.parametrize("n", [3, 4, 5]) @pytest.mark.parametrize("batch_size", [None, 1, 5]) def test_fourier_transform_norm_preservation(n, batch_size): f = generate_random_function(n, batch_size) - ft = slow_sn_ft(f, n) + ft = sn_fft(f, n) power = calc_power(ft, n) total_power = sum(p for p in power.values()) if batch_size is None: @@ -88,7 +93,7 @@ def test_convolution_theorem(n): # Compute convolution in group domain conv_group = convolve(g, f, n) - ft_conv_time = slow_sn_ft(conv_group, n) + ft_conv_time = sn_fft(conv_group, n) # Compute convolution in Fourier domain ft_f = sn_fft(f, n) @@ -99,13 +104,11 @@ def test_convolution_theorem(n): ft_conv_freq[shape] = ft_f[shape] * ft_g[shape] else: ft_conv_freq[shape] = ft_f[shape] @ ft_g[shape] - #ft_conv = {shape: torch.matmul(ft_f[shape], ft_g[shape]) for shape in ft_f.keys()} for shape in ft_f.keys(): assert torch.allclose(ft_conv_time[shape], ft_conv_freq[shape], atol=1.e-4),\ f"Convolution theorem failed for n={n}, partition={shape}, max diff = {(ft_conv_time[shape] - ft_conv_freq[shape]).abs().max()}" - @pytest.mark.parametrize("n", [3, 4, 5]) def test_permutation_action(n): f = generate_random_function(n, None) @@ -115,7 +118,7 @@ def test_permutation_action(n): permutation_action = [(perm.inverse * p).permutation_index() for p in permutations ] # Action in group domain f_perm = f[permutation_action] - ft_perm = slow_sn_ft(f_perm, n) + ft_perm = sn_fft(f_perm, n) # Action in Fourier domain ft_action = {} @@ -132,8 +135,7 @@ def test_permutation_action(n): for shape in ft_action.keys(): assert torch.allclose(ft_perm[shape], ft_action[shape], atol=1e-4), \ - f"Permutation action failed for n={n}, shape={shape}" - + f"Permutation action failed for n={n}, shape={shape}" @pytest.mark.parametrize("n", [3, 4, 5]) @@ -148,6 +150,14 @@ def test_sn_fft(n): assert all(equalities.values()), equalities +@pytest.mark.parametrize("n", [3, 4, 5]) +def test_sn_ifft(n): + ft = generate_random_fourier_transform(n) + slow_ift = slow_sn_ift(ft, n) + fast_ift = sn_ifft(ft, n) + + assert torch.allclose(slow_ift, fast_ift) + if __name__ == '__main__': pytest.main(['-v', '-s']) diff --git a/tests/test_irreps.py b/tests/test_irreps.py index 5652bae..16ebbf8 100644 --- a/tests/test_irreps.py +++ b/tests/test_irreps.py @@ -25,7 +25,7 @@ def sn_with_permutations(draw): def test_snirrep_initialization(): irrep = SnIrrep(3, (2, 1)) assert irrep.n == 3 - assert irrep.shape == (2, 1) + assert irrep.partition == (2, 1) assert len(irrep.basis) == 2 # There are two standard Young tableaux for (2,1) @@ -78,12 +78,12 @@ def test_orthogonality_relations(n): # First orthogonality relation for irrep1 in irreps: for irrep2 in irreps: - if irrep1.shape != irrep2.shape: + if irrep1.partition != irrep2.partition: continue reps1 = irrep1.matrix_representations reps2 = irrep2.matrix_representations sum_matrix = sum(reps1[perm.sigma] @ np.conj(reps2[perm.sigma].T) for perm in Permutation.full_group(n)) - expected = np.eye(sum_matrix.shape[0]) if irrep1.shape == irrep2.shape else np.zeros_like(sum_matrix) + expected = np.eye(sum_matrix.shape[0]) if irrep1.partition == irrep2.partition else np.zeros_like(sum_matrix) assert np.allclose(sum_matrix / len(reps1), expected) # Second orthogonality relation (sum of squares of dimensions equals n!) @@ -97,7 +97,7 @@ def test_all_partitions(n): for partition in partitions: irrep = SnIrrep(n, partition) assert irrep.n == n - assert irrep.shape == partition + assert irrep.partition == partition assert irrep.dim == len(irrep.basis) reps = irrep.matrix_representations assert len(reps) == math.factorial(n) @@ -105,7 +105,7 @@ def test_all_partitions(n): assert matrix.shape == (irrep.dim, irrep.dim) -@given( sn_with_permutations()) +@given(sn_with_permutations()) def test_representation_homomorphism(n_and_permutations): n, permutations = n_and_permutations partitions = generate_partitions(n) diff --git a/tests/test_permutations.py b/tests/test_permutations.py index cc4ebc1..97be713 100644 --- a/tests/test_permutations.py +++ b/tests/test_permutations.py @@ -4,10 +4,11 @@ from algebraist.permutations import Permutation + # Helper strategy to generate valid permutations @st.composite def permutation_strategy(draw, max_n=10): - n = draw(st.integers(min_value=1, max_value=max_n)) + n = draw(st.integers(min_value=3, max_value=max_n)) return Permutation(draw(st.permutations(range(n)))) @@ -17,17 +18,17 @@ def test_init(): assert p.n == 3 -def test_full_group(): - group = Permutation.full_group(3) - assert len(group) == 6 - assert Permutation([0, 1, 2]) in group - assert Permutation([2, 1, 0]) in group +@given(n=st.integers(min_value=3, max_value=6)) +def test_full_group(n): + group = Permutation.full_group(n) + assert len(group) == math.factorial(n) -def test_identity(): - id3 = Permutation.identity(3) - assert id3.sigma == (0, 1, 2) - assert id3.is_identity() +@given(n=st.integers(min_value=3, max_value=6)) +def test_identity(n): + ident = Permutation.identity(n) + assert ident == tuple(range(n)) + assert ident.is_identity() def test_transposition(): @@ -41,6 +42,7 @@ def test_multiplication(): p3 = p1 * p2 assert p3.sigma == (0, 1, 2) + @pytest.mark.parametrize("perm, power, expected", [ (Permutation([1, 2, 0]), 2, (2, 0, 1)), (Permutation([1, 2, 0]), 3, (0, 1, 2)), @@ -79,7 +81,6 @@ def test_transposition_decomposition(): assert p.transposition_decomposition() == [(0, 2), (1, 2)] - @pytest.mark.parametrize("n", [3, 4, 5]) def test_permutation_index(n): indices = [p.permutation_index() for p in Permutation.full_group(n)] @@ -114,17 +115,19 @@ def test_parity_product_property(perm1, perm2): def test_double_inverse_property(perm): assert perm == perm.inverse.inverse + # Property: order of permutation divides group order (n!) @given(perm=permutation_strategy()) def test_order_divides_group_order(perm): - from math import factorial - assert factorial(perm.n) % perm.order == 0 + assert math.factorial(perm.n) % perm.order == 0 + # Property: conjugacy class partition sums to n @given(perm=permutation_strategy()) def test_conjugacy_class_sum(perm): assert sum(perm.conjugacy_class) == perm.n + # Property: multiplication is associative @given(perm1=permutation_strategy(), perm2=permutation_strategy(), perm3=permutation_strategy()) def test_multiplication_associativity(perm1, perm2, perm3): @@ -132,6 +135,7 @@ def test_multiplication_associativity(perm1, perm2, perm3): return # Skip if permutations are of different sizes assert (perm1 * perm2) * perm3 == perm1 * (perm2 * perm3) + # Property: identity permutation is neutral element @given(perm=permutation_strategy()) def test_identity_neutral(perm): @@ -139,7 +143,14 @@ def test_identity_neutral(perm): assert perm * identity == perm assert identity * perm == perm + # Property: permutation to the power of its order is identity @given(perm=permutation_strategy()) def test_power_order_is_identity(perm): assert (perm ** perm.order).is_identity() + + +@given(perm=permutation_strategy()) +def test_inverse_gives_identity(perm): + ident = Permutation.identity(perm.n) + assert perm * perm.inverse == ident diff --git a/tests/test_tableau.py b/tests/test_tableau.py index 7ebf8ee..5998e6e 100644 --- a/tests/test_tableau.py +++ b/tests/test_tableau.py @@ -85,8 +85,19 @@ def test_hook_length_known_values(): def test_generate_partitions(): - assert set(generate_partitions(4)) == {(4,), (3, 1), (2, 2), (2, 1, 1), (1, 1, 1, 1)} - assert set(generate_partitions(5)) == {(5,), (4, 1), (3, 2), (3, 1, 1), (2, 2, 1), (2, 1, 1, 1), (1, 1, 1, 1, 1)} + partsof9 = { + (9,), (8, 1), (7, 2), (7, 1, 1), (6, 3), (6, 2, 1), (6, 1, 1, 1), (5, 4), (5, 3, 1), + (5, 2, 2), (5, 2, 1, 1), (5, 1, 1, 1, 1), (4, 4, 1), (4, 3, 2), (4, 3, 1, 1), (4, 2, 2, 1), (4, 2, 1, 1, 1), + (4, 1, 1, 1, 1, 1), (3, 3, 3), (3, 3, 2, 1), (3, 3, 1, 1, 1), (3, 2, 2, 2), (3, 2, 2, 1, 1), (3, 2, 1, 1, 1, 1), (3, 1, 1, 1, 1, 1, 1), + (2, 2, 2, 2, 1), (2, 2, 2, 1, 1, 1), (2, 2, 1, 1, 1, 1, 1), (2, 1, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1, 1) + } + partsof8 = { + (8,), (7, 1), (6, 2), (6, 1, 1), (5, 3), (5, 2, 1), (5, 1, 1, 1), (4, 4), (4, 3, 1), (4, 2, 2), (4, 2, 1, 1), (4, 1, 1, 1, 1), + (3, 3, 2), (3, 3, 1, 1), (3, 2, 2, 1), (3, 2, 1, 1, 1), (3, 1, 1, 1, 1, 1), (2, 2, 2, 2), (2, 2, 2, 1, 1), (2, 2, 1, 1, 1, 1), + (2, 1, 1, 1, 1, 1, 1), (1, 1, 1, 1, 1, 1, 1, 1) + } + assert set(generate_partitions(8)) == partsof8 + assert set(generate_partitions(9)) == partsof9 def test_enumerate_standard_tableau():