Skip to content

Commit

Permalink
Merge pull request #1750 from albinahlback/mulhigh_normalized
Browse files Browse the repository at this point in the history
Add flint_mpn_mulhigh_normalized_n
  • Loading branch information
albinahlback authored Jan 26, 2024
2 parents 30e64ae + 5df1b26 commit 82f6d57
Show file tree
Hide file tree
Showing 19 changed files with 2,647 additions and 2 deletions.
308 changes: 308 additions & 0 deletions dev/gen_mulhigh_basecase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,295 @@ function mulhigh(n::Int; debug::Bool = false)
end
end

###############################################################################
# mulhigh, normalised
###############################################################################

function mulhigh_normalised_1()
r0 = _regs[9] # Important that r0 is rax
r1 = _regs[4]

res = _regs[1]
ap = _regs[2]
bp = _regs[3]

body = ""

body *= "\tmov\t0*8($bp), %rdx\n"
body *= "\tmulx\t0*8($ap), $r0, $r1\n"

# Check if normalised
body *= "\tmov\t\$0, %rdx\n"
body *= "\ttest\t$r1, $r1\n"
body *= "\tsetns\t$(R8("%rdx"))\n"
body *= "\tjs\t.Lcontinue\n"

# If not normalised, shift by one
body *= "\tadd\t$r0, $r0\n"
body *= "\tadc\t$r1, $r1\n"

body *= ".Lcontinue:\n"
body *= "\tmov\t$r1, 0*8($res)\n"

return body * "\n\tret\n"
end

function mulhigh_normalised_2()
r0 = _regs[9]
r1 = _regs[4]
r2 = _regs[5]

sc = _regs[6]
zr = _regs[7]

res = _regs[1]
ap = _regs[2]
bp_old = _regs[3]
b1 = _regs[8]

body = ""

body *= "\tmov\t1*8($bp_old), $b1\n"
body *= "\tmov\t0*8($bp_old), %rdx\n"
body *= "\txor\t$(R32(zr)), $(R32(zr))\n"
body *= "\tmulx\t0*8($ap), $sc, $r0\n"
body *= "\tmulx\t1*8($ap), $sc, $r1\n"
body *= "\tadcx\t$sc, $r0\n"
body *= "\tadcx\t$zr, $r1\n"

body *= "\tmov\t$b1, %rdx\n"
body *= "\tmulx\t0*8($ap), $sc, $r2\n"
body *= "\tadcx\t$sc, $r0\n"
body *= "\tadcx\t$r2, $r1\n"
body *= "\tmulx\t1*8($ap), $sc, $r2\n"
body *= "\tadox\t$sc, $r1\n"
body *= "\tadox\t$zr, $r2\n"
body *= "\tadcx\t$zr, $r2\n"

# Check if normalised
body *= "\tmov\t\$0, %rdx\n"
body *= "\ttest\t$r2, $r2\n"
body *= "\tsetns\t$(R8("%rdx"))\n"
body *= "\tjs\t.Lcontinue\n"

# If not normalised, shift by one
body *= "\tadd\t$r0, $r0\n"
body *= "\tadc\t$r1, $r1\n"
body *= "\tadc\t$r2, $r2\n"

body *= ".Lcontinue:\n"
body *= "\tmov\t$r1, 0*8($res)\n"
body *= "\tmov\t$r2, 1*8($res)\n"

return body * "\n\tret\n"
end

function mulhigh_normalised(n::Int; debug::Bool = false)
if n < 1
error()
elseif n == 1
return mulhigh_normalised_1()
elseif n == 2
return mulhigh_normalised_2()
elseif n <= 12
# Continue
else
error()
end

if debug
res = "res"
ap = "ap"
bp_old = "bp_old"
bp = "bp"
sc = "sc"
zr = "zr"
else
res = _regs[1]
ap = _regs[2]
bp_old = _regs[3] # rdx
bp = _regs[4] # rdx is used by mulx, so we need to switch register for bp

if n != 12
sc = _regs[5] # scrap register
zr = (n < 10) ? _regs[6] : "%rsp" # zero
else
sc = "%rsp"
zr = sc
end

if n < 9
_r = [_regs[9]; _regs[7:8]; __regs[1:n - 2]]
elseif n == 9
_r = [_regs[9]; _regs[7:8]; __regs[1:6]; res]
elseif n == 10
_r = [_regs[9]; _regs[6:8]; __regs[1:6]; res]
elseif n == 11
_r = [_regs[9]; _regs[6:8]; __regs[1:6]; res; bp]
elseif n == 12
_r = [_regs[9]; _regs[5:8]; __regs[1:6]; res; bp]
end
end

r(ix::Int) = debug ? "r$ix" : _r[ix + 1]

# Push
body = ""
for ix in 1:min(n - 2, 6)
body *= "\tpush\t$(__regs[ix])\n"
end
if n == 9
body *= "\tpush\t$res\n"
elseif n >= 10
body *= "\tvmovq\t%rsp, %xmm0\n"
body *= "\tvmovq\t$res, %xmm1\n"
end
body *= "\n"

# Prepare
body *= "\tmov\t$bp_old, $bp\n"
body *= "\tmov\t0*8($bp_old), %rdx\n"
if n != 12
body *= "\txor\t$(R32(zr)), $(R32(zr))\n"
end
body *= "\n"

# First multiplication chain
body *= "\tmulx\t$(n - 2)*8($ap), $sc, $(r(0))\n"
body *= "\tmulx\t$(n - 1)*8($ap), $sc, $(r(1))\n"
if n != 12
body *= "\tadcx\t$sc, $(r(0))\n"
body *= "\tadcx\t$zr, $(r(1))\n"
else
body *= "\tadd\t$sc, $(r(0))\n"
body *= "\tadc\t\$0, $(r(1))\n"
body *= "\ttest\t%al, %al\n"
end
body *= "\n"

# Intermediate multiplication chains
for ix in 1:min(n - 2, (n != 12) ? 8 : 9)
body *= "\tmov\t$ix*8($bp), %rdx\n"

body *= "\tmulx\t$(n - 2 - ix)*8($ap), $sc, $(r(ix + 2))\n"
body *= "\tmulx\t$(n - 1 - ix)*8($ap), $sc, $(r(ix + 1))\n"
body *= "\tadcx\t$(r(ix + 2)), $(r(0))\n"
body *= "\tadox\t$sc, $(r(0))\n"
body *= "\tadcx\t$(r(ix + 1)), $(r(1))\n"

for jx in 1:ix - 1
body *= "\tmulx\t$(n - 1 - ix + jx)*8($ap), $sc, $(r(ix + 1))\n"
body *= "\tadox\t$sc, $(r(jx + 0))\n"
body *= "\tadcx\t$(r(ix + 1)), $(r(jx + 1))\n"
end

body *= "\tmulx\t$(n - 1)*8($ap), $sc, $(r(ix + 1))\n"
body *= "\tadox\t$sc, $(r(ix + 0))\n"
if n == 12
body *= "\tmov\t\$0, $(R32(zr))\n"
end
body *= "\tadcx\t$zr, $(r(ix + 1))\n"
body *= "\tadox\t$zr, $(r(ix + 1))\n"

body *= "\n"
end

if n >= 11
N = n - 1
body *= "\tmov\t$(n - 2)*8($bp), %rdx\n"
for ix in 0:n - 2
body *= "\tmulx\t$ix*8($ap), $sc, $(r(N))\n"
if ix == 0
body *= "\tadcx\t$(r(N)), $(r(ix + 0))\n"
else
body *= "\tadox\t$sc, $(r(ix - 1))\n"
body *= "\tadcx\t$(r(N)), $(r(ix + 0))\n"
end
end
body *= "\tmulx\t$(n - 1)*8($ap), $sc, $(r(N))\n"
body *= "\tadox\t$sc, $(r(N - 1))\n"
if n == 12
body *= "\tmov\t\$0, $(R32(zr))\n"
end
body *= "\tadcx\t$zr, $(r(N))\n"
body *= "\tadox\t$zr, $(r(N))\n"
body *= "\n"
end

# Last multiplication chain
body *= "\tmov\t$(n - 1)*8($bp), %rdx\n"
for ix in 0:n - 2
body *= "\tmulx\t$ix*8($ap), $sc, $(r(n))\n"
if ix % 2 == 0
body *= "\tadcx\t$sc, $(r(ix + 0))\n"
body *= "\tadcx\t$(r(n)), $(r(ix + 1))\n"
else
body *= "\tadox\t$sc, $(r(ix + 0))\n"
body *= "\tadox\t$(r(n)), $(r(ix + 1))\n"
end
end
body *= "\tmulx\t$(n - 1)*8($ap), $sc, $(r(n))\n"
if (n - 1) % 2 == 0
body *= "\tadcx\t$sc, $(r(n - 1))\n"
else
body *= "\tadox\t$sc, $(r(n - 1))\n"
end
if n == 12
body *= "\tmov\t\$0, $(R32(zr))\n"
end
body *= "\tadcx\t$zr, $(r(n))\n"
if n == 9
# Use scrap register for storing pointer to res
res = sc
body *= "\tpop\t$res\n"
end
body *= "\tadox\t$zr, $(r(n))\n"
body *= "\n"

# Check if normalised
body *= "\tmov\t\$0, %rdx\n"
body *= "\ttest\t$(r(n)), $(r(n))\n"
body *= "\tsetns\t$(R8("%rdx"))\n"
body *= "\tjs\t.Lcontinue\n"

# If not normalised, shift by one
body *= "\tadd\t$(r(0)), $(r(0))\n"
for ix in 1:n
body *= "\tadc\t$(r(ix)), $(r(ix))\n"
end

body *= ".Lcontinue:\n"
if n == 10 || n == 11
res, zr = sc, "error zr"
body *= "\tvmovq\t%xmm1, $res\n"
body *= "\tvmovq\t%xmm0, %rsp\n"
elseif n == 12
res = sc
body *= "\tvmovq\t%xmm1, $res\n"
end

# Store result
for ix in 1:n
body *= "\tmov\t$(r(ix)), $(ix - 1)*8($res)\n"
end
body *= "\n"

# Pop
if n == 12
body *= "\tvmovq\t%xmm0, %rsp\n"
end
for ix in min(n - 2, 6):-1:1
body *= "\tpop\t$(__regs[ix])\n"
end
body *= "\n"

if debug
print(body * "\tret\n")
else
return body * "\tret\n"
end
end

###############################################################################
# Generate file
###############################################################################
Expand All @@ -396,8 +685,27 @@ function gen_mulhigh(m::Int, nofile::Bool = false)
end
end

function gen_mulhigh_normalised(m::Int, nofile::Bool = false)
(pre, post) = function_pre_post("flint_mpn_mulhigh_normalised_$m")
functionbody = mulhigh_normalised(m)

str = "$copyright\n$preamble\n$pre$functionbody$post"

if nofile
print(str)
else
path = String(@__DIR__) * "/../src/mpn_extras/broadwell/mulhigh_normalised_$m.asm"
file = open(path, "w")
write(file, str)
close(file)
end
end

function gen_all()
for m in 1:12
gen_mulhigh(m)
end
for m in 1:12
gen_mulhigh_normalised(m)
end
end
6 changes: 4 additions & 2 deletions doc/source/mpn_extras.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ Multiplication
is typically non-exact for sizes larger than one. The highest error
is *n + 2* ULP in the returned limb.

.. note:: This function may not exist on processors not supporting the ADX
instruction set.
.. note::

This function may not exist on processors not supporting the ADX instruction
set.


Divisibility
Expand Down
13 changes: 13 additions & 0 deletions src/mpn_extras.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,11 @@ flint_mpn_sqr(mp_ptr r, mp_srcptr x, mp_size_t n)

#define FLINT_HAVE_MULHIGH_N_FUNC(n) ((n) <= FLINT_MPN_MULHIGH_N_FUNC_TAB_WIDTH)

struct mp_limb_pair_t { mp_limb_t m1; mp_limb_t m2; };
typedef struct mp_limb_pair_t (* flint_mpn_mulhigh_normalised_func_t)(mp_ptr, mp_srcptr, mp_srcptr);

FLINT_DLL extern const flint_mpn_mul_func_t flint_mpn_mulhigh_n_func_tab[];
FLINT_DLL extern const flint_mpn_mulhigh_normalised_func_t flint_mpn_mulhigh_normalised_n_func_tab[];

/* NOTE: Aliasing is allowed! */
/* FIXME: How do we proceed for bigger n? */
Expand All @@ -250,6 +254,15 @@ mp_limb_t flint_mpn_mulhigh_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n
return flint_mpn_mulhigh_n_func_tab[n - 1](rp, xp, yp);
}

FLINT_FORCE_INLINE
struct mp_limb_pair_t flint_mpn_mulhigh_normalised_n(mp_ptr rp, mp_srcptr xp, mp_srcptr yp, mp_size_t n)
{
FLINT_ASSERT(n >= 1);
FLINT_ASSERT(FLINT_HAVE_MULHIGH_N_FUNC(n));

return flint_mpn_mulhigh_normalised_n_func_tab[n - 1](rp, xp, yp);
}

/*
return the high limb of a two limb left shift by n < GMP_LIMB_BITS bits.
Note: if GMP_NAIL_BITS != 0, the rest of flint is already broken anyways.
Expand Down
Loading

0 comments on commit 82f6d57

Please sign in to comment.