Skip to content

Commit

Permalink
fmpz_mat_mul_waksman (code from PML)
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrik-johansson committed Jan 26, 2024
1 parent 8329475 commit ea32ed8
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 74 deletions.
10 changes: 10 additions & 0 deletions doc/source/fmpz_mat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,16 @@ Matrix multiplication
The matrices must have compatible dimensions for matrix multiplication.
No aliasing is allowed.

.. function:: void fmpz_mat_mul_waksman(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)

Sets ``C`` to the matrix product `C = A B` computed using
Waksman multiplication, which does only `n^3/2 + O(n^2)`
products, but many additions. This is good for small matrices
with large entries.

The matrices must have compatible dimensions for matrix multiplication.
No aliasing is allowed.

.. function:: void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)

Sets `C = AB`. Dimensions must be compatible for matrix multiplication.
Expand Down
6 changes: 2 additions & 4 deletions src/fmpz_mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,8 @@ void fmpz_mat_scalar_mod_fmpz(fmpz_mat_t B, const fmpz_mat_t A, const fmpz_t m);
/* Multiplication */

void fmpz_mat_mul(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B);

void fmpz_mat_mul_classical(fmpz_mat_t C, const fmpz_mat_t A,
const fmpz_mat_t B);

void fmpz_mat_mul_classical(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B);
void fmpz_mat_mul_waksman(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B);
void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B);

#define fmpz_mat_mul_classical_inline _Pragma("GCC error \"'fmpz_mat_mul_classical_inline' is deprecated. Use 'fmpz_mat_mul_classical' instead.\"")
Expand Down
20 changes: 15 additions & 5 deletions src/fmpz_mat/mul.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (C) 2010,2011,2018 Fredrik Johansson
Copyright (C) 2010, 2011, 2018, 2024 Fredrik Johansson
Copyright (C) 2016 Aaditya Thakkar
This file is part of FLINT.
Expand Down Expand Up @@ -138,13 +138,12 @@ fmpz_mat_mul(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)

if (br == 1)
{
for (i = 0; i < ar; i++)
for (j = 0; j < bc; j++)
fmpz_mul(fmpz_mat_entry(C, i, j),
fmpz_mat_entry(A, i, 0), fmpz_mat_entry(B, 0, j));
fmpz_mat_mul_classical(C, A, B);
return;
}

/* todo: use Strassen, Waksman or FFT when entries are huge
and the matrix is not structured */
if (br == 2)
{
for (i = 0; i < ar; i++)
Expand All @@ -157,6 +156,7 @@ fmpz_mat_mul(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)

dim = FLINT_MIN(ar, bc);
dim = FLINT_MIN(dim, br);

/* TODO: for space reasons maybe just call strassen here if dim > 10000 */

abits = fmpz_mat_max_bits(A);
Expand Down Expand Up @@ -274,8 +274,18 @@ fmpz_mat_mul(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)
}
else
{
/* want balanced entries for mul_waksman */
/* todo: should check for structured matrices (favor mul_classical) */
slong min_bits = FLINT_MIN(abits, bbits);
slong max_bits = FLINT_MAX(abits, bbits);

if (dim >= 3 * FLINT_BIT_COUNT(cbits)) /* tuning param */
_fmpz_mat_mul_multi_mod(C, A, B, sign, cbits);
else if (dim < 20 && ((dim == 2 && min_bits >= 5000 && max_bits <= 1.1 * min_bits)
|| (max_bits <= 1.6 * min_bits && ((dim == 3 && min_bits >= 3000)
|| (dim >= 4 && min_bits >= 1000)
|| (dim >= 12 && min_bits >= 500)))))
fmpz_mat_mul_waksman(C, A, B);
else if (abits >= 500 && bbits >= 500 && dim >= 8) /* tuning param */
fmpz_mat_mul_strassen(C, A, B);
else
Expand Down
4 changes: 2 additions & 2 deletions src/fmpz_mat/mul_strassen.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)
b = A->c;
c = B->c;

if (a <= 4 || b <= 4 || c <= 4)
if (a <= 1 || b <= 1 || c <= 1)
{
fmpz_mat_mul(C, A, B);
fmpz_mat_mul_classical(C, A, B);
return;
}

Expand Down
131 changes: 131 additions & 0 deletions src/fmpz_mat/mul_waksman.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
Copyright (C) 2024 Éric Schost
Copyright (C) 2024 Vincent Neiger
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 2.1 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "fmpz.h"
#include "fmpz_vec.h"
#include "fmpz_mat.h"

/** ------------------------------------------------------------ */
/** Waksman's algorithm for matrix multiplication */
/** does n^3/2+O(n^2) products, but many additions */
/** good for small matrices with large entries */
/** ------------------------------------------------------------ */
void fmpz_mat_mul_waksman(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)
{
slong m = A->r;
slong n = B->r;
slong p = B->c;

if (m == 0 || n == 0 || p == 0)
{
fmpz_mat_zero(C);
return;
}

slong i, l, j, k;

fmpz * Crow = _fmpz_vec_init(p + m);
fmpz * Ccol = Crow + p;

slong np = n >> 1;

fmpz_t val0, val1, val2, crow;

fmpz_init(val0);
fmpz_init(val1);
fmpz_init(val2);
fmpz_init(crow);

for (i = 0; i < p; i++)
fmpz_zero(Crow + i);

for (i = 0; i < m; i++)
fmpz_zero(Ccol + i);

for (i = 0; i < m; i++)
for (j = 0; j < p; j++)
fmpz_zero(fmpz_mat_entry(C, i, j));

for (j = 1; j <= np; j++)
{
slong j2 = (j << 1) - 1;

for (k = 0; k < p; k++)
{
fmpz_add(val1, fmpz_mat_entry(A, 0, j2-1), fmpz_mat_entry(B, j2, k));
fmpz_add(val2, fmpz_mat_entry(A, 0, j2), fmpz_mat_entry(B, j2-1, k));
fmpz_addmul(fmpz_mat_entry(C, 0, k), val1, val2);

fmpz_sub(val1, fmpz_mat_entry(A, 0, j2-1), fmpz_mat_entry(B, j2, k));
fmpz_sub(val2, fmpz_mat_entry(A, 0, j2), fmpz_mat_entry(B, j2-1, k));
fmpz_addmul(Crow + k, val1, val2);
}

for (l = 1; l < m; l++)
{
fmpz_add(val1, fmpz_mat_entry(A, l, j2-1), fmpz_mat_entry(B, j2, 0));
fmpz_add(val2, fmpz_mat_entry(A, l, j2), fmpz_mat_entry(B, j2-1, 0));
fmpz_addmul(fmpz_mat_entry(C, l, 0), val1, val2);

fmpz_sub(val1, fmpz_mat_entry(A, l, j2-1), fmpz_mat_entry(B, j2, 0));
fmpz_sub(val2, fmpz_mat_entry(A, l, j2), fmpz_mat_entry(B, j2-1, 0));
fmpz_addmul(Ccol + l, val1, val2);
}

for (k = 1; k < p; k++)
{
for (l = 1; l < m; l++)
{
fmpz_add(val1, fmpz_mat_entry(A, l, j2-1), fmpz_mat_entry(B, j2, k));
fmpz_add(val2, fmpz_mat_entry(A, l, j2), fmpz_mat_entry(B, j2-1, k));
fmpz_addmul(fmpz_mat_entry(C, l, k), val1, val2);
}
}
}

for (l = 1; l < m; l++)
{
fmpz_add(val1, Ccol + l, fmpz_mat_entry(C, l, 0));
fmpz_tdiv_q_2exp(Ccol+ l, val1, 1);
fmpz_sub(fmpz_mat_entry(C, l, 0), fmpz_mat_entry(C, l, 0), Ccol + l);
}

fmpz_add(val1, Crow, fmpz_mat_entry(C, 0, 0));
fmpz_tdiv_q_2exp(val0, val1, 1);
fmpz_sub(fmpz_mat_entry(C, 0, 0), fmpz_mat_entry(C, 0, 0), val0);

for (k = 1; k < p; k++)
{
fmpz_add(crow, Crow + k, fmpz_mat_entry(C, 0, k));
fmpz_tdiv_q_2exp(val1, crow, 1);
fmpz_sub(fmpz_mat_entry(C, 0, k), fmpz_mat_entry(C, 0, k), val1);
fmpz_sub(crow, val1, val0);

for (l = 1; l < m; l++)
{
fmpz_sub(val2, fmpz_mat_entry(C, l, k), crow);
fmpz_sub(fmpz_mat_entry(C, l, k), val2, Ccol + l);
}
}

if ((n & 1) == 1)
for (l = 0; l < m; l++)
for (k = 0; k < p; k++)
fmpz_addmul(fmpz_mat_entry(C, l, k), fmpz_mat_entry(A, l, n-1), fmpz_mat_entry(B, n-1, k));

_fmpz_vec_clear(Crow, p + m);

fmpz_clear(val0);
fmpz_clear(val1);
fmpz_clear(val2);
fmpz_clear(crow);
}
Loading

0 comments on commit ea32ed8

Please sign in to comment.