Skip to content

Commit

Permalink
Improve accuracy of nfloat subtraction and hardcode add/sub up to nli…
Browse files Browse the repository at this point in the history
…mbs=4 (#1977)
  • Loading branch information
fredrik-johansson authored May 15, 2024
1 parent 0d6152b commit 8611346
Show file tree
Hide file tree
Showing 8 changed files with 1,085 additions and 89 deletions.
6 changes: 6 additions & 0 deletions doc/source/nfloat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,5 +317,11 @@ Internal functions
int _nfloat_cmpabs(nfloat_srcptr x, nfloat_srcptr y, gr_ctx_t ctx)
int _nfloat_add_1(nfloat_ptr res, ulong x0, slong xexp, int xsgnbit, ulong y0, slong delta, gr_ctx_t ctx)
int _nfloat_sub_1(nfloat_ptr res, ulong x0, slong xexp, int xsgnbit, ulong y0, slong delta, gr_ctx_t ctx)
int _nfloat_add_2(nfloat_ptr res, nn_srcptr xd, slong xexp, int xsgnbit, nn_srcptr yd, slong delta, gr_ctx_t ctx)
int _nfloat_sub_2(nfloat_ptr res, nn_srcptr xd, slong xexp, int xsgnbit, nn_srcptr yd, slong delta, gr_ctx_t ctx)
int _nfloat_add_3(nfloat_ptr res, nn_srcptr x, slong xexp, int xsgnbit, nn_srcptr y, slong delta, gr_ctx_t ctx)
int _nfloat_sub_3(nfloat_ptr res, nn_srcptr x, slong xexp, int xsgnbit, nn_srcptr y, slong delta, gr_ctx_t ctx)
int _nfloat_add_4(nfloat_ptr res, nn_srcptr x, slong xexp, int xsgnbit, nn_srcptr y, slong delta, gr_ctx_t ctx)
int _nfloat_sub_4(nfloat_ptr res, nn_srcptr x, slong xexp, int xsgnbit, nn_srcptr y, slong delta, gr_ctx_t ctx)
int _nfloat_add_n(nfloat_ptr res, nn_srcptr xd, slong xexp, int xsgnbit, nn_srcptr yd, slong delta, slong nlimbs, gr_ctx_t ctx)
int _nfloat_sub_n(nfloat_ptr res, nn_srcptr xd, slong xexp, int xsgnbit, nn_srcptr yd, slong delta, slong nlimbs, gr_ctx_t ctx)
6 changes: 6 additions & 0 deletions src/nfloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ int nfloat_im(nfloat_ptr res, nfloat_srcptr x, gr_ctx_t ctx);

int _nfloat_add_1(nfloat_ptr res, ulong x0, slong xexp, int xsgnbit, ulong y0, slong delta, gr_ctx_t ctx);
int _nfloat_sub_1(nfloat_ptr res, ulong x0, slong xexp, int xsgnbit, ulong y0, slong delta, gr_ctx_t ctx);
int _nfloat_add_2(nfloat_ptr res, nn_srcptr xd, slong xexp, int xsgnbit, nn_srcptr yd, slong delta, gr_ctx_t ctx);
int _nfloat_sub_2(nfloat_ptr res, nn_srcptr xd, slong xexp, int xsgnbit, nn_srcptr yd, slong delta, gr_ctx_t ctx);
int _nfloat_add_3(nfloat_ptr res, nn_srcptr x, slong xexp, int xsgnbit, nn_srcptr y, slong delta, gr_ctx_t ctx);
int _nfloat_sub_3(nfloat_ptr res, nn_srcptr x, slong xexp, int xsgnbit, nn_srcptr y, slong delta, gr_ctx_t ctx);
int _nfloat_add_4(nfloat_ptr res, nn_srcptr x, slong xexp, int xsgnbit, nn_srcptr y, slong delta, gr_ctx_t ctx);
int _nfloat_sub_4(nfloat_ptr res, nn_srcptr x, slong xexp, int xsgnbit, nn_srcptr y, slong delta, gr_ctx_t ctx);
int _nfloat_add_n(nfloat_ptr res, nn_srcptr xd, slong xexp, int xsgnbit, nn_srcptr yd, slong delta, slong nlimbs, gr_ctx_t ctx);
int _nfloat_sub_n(nfloat_ptr res, nn_srcptr xd, slong xexp, int xsgnbit, nn_srcptr yd, slong delta, slong nlimbs, gr_ctx_t ctx);

Expand Down
102 changes: 98 additions & 4 deletions src/nfloat/dot.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,64 @@
(r1) = __r1; (r2) = __r2; (r3) = __r3; \
} while (0)

/* todo: define in longlong.h */
#if FLINT_BITS == 64 && defined(__GNUC__) && defined(__AVX2__)

#define add_sssssaaaaaaaaaa(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \
__asm__ ("addq %14,%q4\n\tadcq %12,%q3\n\tadcq %10,%q2\n\tadcq %8,%q1\n\tadcq %6,%q0" \
: "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \
: "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \
"1" ((ulong)(a3)), "rme" ((ulong)(b3)), \
"2" ((ulong)(a2)), "rme" ((ulong)(b2)), \
"3" ((ulong)(a1)), "rme" ((ulong)(b1)), \
"4" ((ulong)(a0)), "rme" ((ulong)(b0)))


#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \
__asm__ ("subq %11,%q3\n\tsbbq %9,%q2\n\tsbbq %7,%q1\n\tsbbq %5,%q0" \
: "=r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \
: "0" ((ulong)(a3)), "rme" ((ulong)(b3)), \
"1" ((ulong)(a2)), "rme" ((ulong)(b2)), \
"2" ((ulong)(a1)), "rme" ((ulong)(b1)), \
"3" ((ulong)(a0)), "rme" ((ulong)(b0)))

#define sub_dddddmmmmmsssss(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \
__asm__ ("subq %14,%q4\n\tsbbq %12,%q3\n\tsbbq %10,%q2\n\tsbbq %8,%q1\n\tsbbq %6,%q0" \
: "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \
: "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \
"1" ((ulong)(a3)), "rme" ((ulong)(b3)), \
"2" ((ulong)(a2)), "rme" ((ulong)(b2)), \
"3" ((ulong)(a1)), "rme" ((ulong)(b1)), \
"4" ((ulong)(a0)), "rme" ((ulong)(b0)))
#else

#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \
do { \
ulong __t0 = 0; \
add_ssssaaaaaaaa(__t0, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \
add_ssaaaa(s4, s3, a4, a3, b4, b3); \
add_ssaaaa(s4, s3, s4, s3, (ulong) 0, __t0); \
} while (0)


#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \
do { \
ulong __t1, __u1; \
sub_dddmmmsss(__t1, s1, s0, (ulong) 0, a1, a0, (ulong) 0, b1, b0); \
sub_ddmmss(__u1, s2, (ulong) 0, a2, (ulong) 0, b2); \
sub_ddmmss(s3, s2, (a3) - (b3), s2, -__u1, -__t1); \
} while (0)

#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \
do { \
ulong __t2, __u2; \
sub_ddddmmmmssss(__t2, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \
sub_ddmmss(__u2, s3, (ulong) 0, a3, (ulong) 0, b3); \
sub_ddmmss(s4, s3, (a4) - (b4), s3, -__u2, -__t2); \
} while (0)

#endif

int
__nfloat_vec_dot(nfloat_ptr res, nfloat_srcptr initial, int subtract, nfloat_srcptr x, slong sizeof_xstep, nfloat_srcptr y, slong sizeof_ystep, slong len, gr_ctx_t ctx)
{
Expand Down Expand Up @@ -467,12 +525,48 @@ __nfloat_vec_dot(nfloat_ptr res, nfloat_srcptr initial, int subtract, nfloat_src
if (delta < FLINT_BITS)
{
t[0] = flint_mpn_mulhigh_n(t + 1, NFLOAT_D(xi), NFLOAT_D(yi), nlimbs);
mpn_rshift(t, t, nlimbs + 1, delta);

if (xsgnbit)
mpn_sub_n(s, s, t, nlimbs + 1);
if (nlimbs == 3)
{
if (xsgnbit)
sub_ddddmmmmssss(s[3], s[2], s[1], s[0], s[3], s[2], s[1], s[0],
t[3] >> delta,
(t[2] >> delta) | (t[3] << (FLINT_BITS - delta)),
(t[1] >> delta) | (t[2] << (FLINT_BITS - delta)),
(t[0] >> delta) | (t[1] << (FLINT_BITS - delta)));
else
add_ssssaaaaaaaa(s[3], s[2], s[1], s[0], s[3], s[2], s[1], s[0],
t[3] >> delta,
(t[2] >> delta) | (t[3] << (FLINT_BITS - delta)),
(t[1] >> delta) | (t[2] << (FLINT_BITS - delta)),
(t[0] >> delta) | (t[1] << (FLINT_BITS - delta)));
}
else if (nlimbs == 4)
{
if (xsgnbit)
sub_dddddmmmmmsssss(s[4], s[3], s[2], s[1], s[0], s[4], s[3], s[2], s[1], s[0],
t[4] >> delta,
(t[3] >> delta) | (t[4] << (FLINT_BITS - delta)),
(t[2] >> delta) | (t[3] << (FLINT_BITS - delta)),
(t[1] >> delta) | (t[2] << (FLINT_BITS - delta)),
(t[0] >> delta) | (t[1] << (FLINT_BITS - delta)));
else
add_sssssaaaaaaaaaa(s[4], s[3], s[2], s[1], s[0], s[4], s[3], s[2], s[1], s[0],
t[4] >> delta,
(t[3] >> delta) | (t[4] << (FLINT_BITS - delta)),
(t[2] >> delta) | (t[3] << (FLINT_BITS - delta)),
(t[1] >> delta) | (t[2] << (FLINT_BITS - delta)),
(t[0] >> delta) | (t[1] << (FLINT_BITS - delta)));
}
else
mpn_add_n(s, s, t, nlimbs + 1);
{
mpn_rshift(t, t, nlimbs + 1, delta);

if (xsgnbit)
mpn_sub_n(s, s, t, nlimbs + 1);
else
mpn_add_n(s, s, t, nlimbs + 1);
}
}
else
{
Expand Down
Loading

0 comments on commit 8611346

Please sign in to comment.