Skip to content

Commit

Permalink
stash prep for Barrett Reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Sep 8, 2023
1 parent 9751ee2 commit 7b29fbd
Show file tree
Hide file tree
Showing 11 changed files with 729 additions and 26 deletions.
108 changes: 108 additions & 0 deletions benchmarks/bench_gmp_modexp.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import
../constantine/math/arithmetic,
../constantine/math/io/[io_bigints, io_fields],
../constantine/math_arbitrary_precision/arithmetic/[bigints_views, limbs_views, limbs_montgomery, limbs_mod2k],
../constantine/math/config/[type_bigint, curves, precompute],
../constantine/platforms/abstractions,
../constantine/serialization/codecs,
../helpers/prng_unsafe,
std/[times, monotimes, strformat]

import gmp
# import stint

const # https://gmplib.org/manual/Integer-Import-and-Export.html
  # `endian` argument of mpz_import/mpz_export: byte order inside each word.
  # Not every value is used by this benchmark, `{.used.}` silences the
  # unused-declaration hints while keeping the full set documented.
  GMP_WordLittleEndian {.used.} = -1'i32
  GMP_WordNativeEndian {.used.} = 0'i32
  GMP_WordBigEndian {.used.} = 1'i32

  # `order` argument of mpz_import/mpz_export: word order inside the buffer.
  GMP_MostSignificantWordFirst {.used.} = 1'i32
  GMP_LeastSignificantWordFirst {.used.} = -1'i32

# let M = Mod(BN254_Snarks)
const
  bits = 256     # Bit width of the random base and modulus
  expBits = bits # Stint only supports same size args

# The generator is seeded with a constant
# (presumably so that successive runs use the same operands).
var rng: RngState
rng.seed(1234)

# Runs 5 rounds of modular exponentiation on random operands,
# timing Constantine's `powMod_vartime` against GMP's `mpz_powm`
# and printing both results so they can be compared.
for i in 0 ..< 5:
  echo "i: ", i
  # -------------------------
  let M = rng.random_long01Seq(BigInt[bits])
  let a = rng.random_long01Seq(BigInt[bits])

  # The exponent is a byte string, it is read as big-endian
  # when converted for GMP below.
  var exponent = newSeq[byte](expBits div 8)
  for j in 0 ..< exponent.len:
    exponent[j] = byte rng.next()

  # -------------------------

  # Hex forms are computed once, they are displayed below
  # and are the inputs of the (currently disabled) Stint benchmark.
  let aHex = a.toHex()
  let eHex = exponent.toHex()
  let mHex = M.toHex()

  echo " base: ", aHex
  echo " exponent: ", eHex
  echo " modulus: ", mHex

  # -------------------------

  # `elapsedStint` is only written by the disabled Stint benchmark.
  var elapsedCtt, elapsedStint, elapsedGMP: int64

  block:
    var r: BigInt[bits]
    let start = getMonotime()
    r.limbs.powMod_vartime(a.limbs, exponent, M.limbs, window = 4)
    let stop = getMonotime()

    elapsedCtt = inNanoseconds(stop-start)

    echo " r Constantine: ", r.toHex()
    echo " elapsed Constantine: ", elapsedCtt, " ns"

  # -------------------------

  # block:
  #   let aa = Stuint[bits].fromHex(aHex)
  #   let ee = Stuint[expBits].fromHex(eHex)
  #   let mm = Stuint[bits].fromHex(mHex)

  #   var r: Stuint[bits]
  #   let start = getMonotime()
  #   r = powmod(aa, ee, mm)
  #   let stop = getMonotime()

  #   elapsedStint = inNanoseconds(stop-start)

  #   echo " r stint: ", r.toHex()
  #   echo " elapsed Stint: ", elapsedStint, " ns"

  block:
    var aa, ee, mm, rr: mpz_t
    mpz_init(aa)
    mpz_init(ee)
    mpz_init(mm)
    mpz_init(rr)

    # Limbs are handed to GMP as-is: least significant word first,
    # native byte order within each word.
    aa.mpz_import(a.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, a.limbs[0].unsafeAddr)
    let e = BigInt[expBits].unmarshal(exponent, bigEndian)
    ee.mpz_import(e.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, e.limbs[0].unsafeAddr)
    mm.mpz_import(M.limbs.len, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, M.limbs[0].unsafeAddr)

    let start = getMonotime()
    rr.mpz_powm(aa, ee, mm)
    let stop = getMonotime()

    elapsedGMP = inNanoseconds(stop-start)

    # `r` starts zeroed, so limbs that GMP does not write stay at zero.
    var r: BigInt[bits]
    var rWritten: csize
    discard r.limbs[0].addr.mpz_export(rWritten.addr, GMP_LeastSignificantWordFirst, sizeof(SecretWord), GMP_WordNativeEndian, 0, rr)

    echo " r GMP: ", r.toHex()
    echo " elapsed GMP: ", elapsedGMP, " ns"

    # A fresh set of GMP integers is initialized at every iteration,
    # release them to avoid leaking their storage.
    mpz_clear(aa)
    mpz_clear(ee)
    mpz_clear(mm)
    mpz_clear(rr)

  # echo &"\n ratio Stint/Constantine: {float64(elapsedStint)/float64(elapsedCtt):.3f}x"
  echo &" ratio GMP/Constantine: {float64(elapsedGMP)/float64(elapsedCtt):.3f}x"
  echo "---------------------------------------------------------"
182 changes: 182 additions & 0 deletions benchmarks/bench_gmp_modmul.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
# Standard library
std/[macros, times, strutils, monotimes],
# Third-party
gmp,
# Internal
../constantine/math/io/io_bigints,
../constantine/math/arithmetic,
../constantine/math_arbitrary_precision/arithmetic/limbs_divmod_vartime,
../constantine/platforms/abstractions,
../constantine/serialization/codecs,
# Test utilities
../helpers/prng_unsafe

echo "\n------------------------------------------------------\n"
# We test up to 1024-bit, more is really slow
# (the benchmarked operand sizes are listed in the `testSizes` macro)

macro testSizes(rBits, aBits, bBits, body: untyped): untyped =
  ## Instantiates `body` once per benchmarked operand width.
  ##
  ## Each copy of `body` lives in its own scope, in which the identifiers
  ## given as `aBits` and `bBits` are compile-time constants equal to the
  ## operand width and `rBits` is a constant equal to twice that width.
  const operandWidths = [256, 384, 121*8]

  result = newStmtList()
  for width in operandWidths:
    let
      aLit = newLit(width)
      bLit = newLit(width)
      rLit = newLit(width * 2)

    result.add quote do:
      block:
        const `aBits` = `aLit`
        const `bBits` = `bLit`
        const `rBits` = `rLit`
        block:
          `body`

# Arguments of mpz_import / mpz_export
# https://gmplib.org/manual/Integer-Import-and-Export.html

const # `endian` argument: byte order inside each word
  GMP_WordBigEndian {.used.} = 1'i32
  GMP_WordNativeEndian {.used.} = 0'i32
  GMP_WordLittleEndian {.used.} = -1'i32

const # `order` argument: word order inside the buffer
  GMP_MostSignificantWordFirst = 1'i32
  GMP_LeastSignificantWordFirst {.used.} = -1'i32

proc main() =
  ## Benchmarks GMP against Constantine for every size instantiated by
  ## `testSizes`: full-width multiplication, then modular reduction of a
  ## single-width and of a double-width value. After the timings, the
  ## product computed by both libraries is compared byte-per-byte.
  var rng: RngState
  let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32
  rng.seed(seed)
  echo "\n------------------------------------------------------\n"
  echo "rng xoshiro512** seed: ", seed
  echo ""

  var r, rMod, a, b: mpz_t
  mpz_init(r)
  mpz_init(rMod)
  mpz_init(a)
  mpz_init(b)
  defer:
    # Release the GMP integers on every exit path, including a failed check.
    mpz_clear(b)
    mpz_clear(a)
    mpz_clear(rMod)
    mpz_clear(r)

  testSizes(rBits, aBits, bBits):
    # echo "--------------------------------------------------------------------------------"
    echo "Testing: mul r (", align($rBits, 4), "-bit) <- a (", align($aBits, 4), "-bit) * b (", align($bBits, 4), "-bit)"

    # Build the bigints
    let aTest = rng.random_unsafe(BigInt[aBits])
    let bTest = rng.random_unsafe(BigInt[bBits])

    #########################################################
    # Conversion to GMP
    const aLen = (aBits + 7) div 8
    const bLen = (bBits + 7) div 8

    var aBuf: array[aLen, byte]
    var bBuf: array[bLen, byte]

    aBuf.marshal(aTest, bigEndian)
    bBuf.marshal(bTest, bigEndian)

    # Word size is 1 byte, hence the buffers are read as big-endian byte strings.
    mpz_import(a, aLen, GMP_MostSignificantWordFirst, 1, GMP_WordNativeEndian, 0, aBuf[0].addr)
    mpz_import(b, bLen, GMP_MostSignificantWordFirst, 1, GMP_WordNativeEndian, 0, bBuf[0].addr)

    #########################################################
    # Multiplication
    const NumIters = 1000000

    let startGMP = getMonoTime()
    for _ in 0 ..< NumIters:
      mpz_mul(r, a, b)
    let stopGMP = getMonoTime()
    echo "GMP - ", aBits, " x ", bBits, " -> ", rBits, " mul: ", float(inNanoseconds((stopGMP-startGMP)))/float(NumIters), " ns"

    # If a*b overflow the result size we truncate
    const numWords = wordsRequired(rBits)
    when numWords < wordsRequired(aBits+bBits):
      echo " truncating from ", wordsRequired(aBits+bBits), " words to ", numWords, " (2^", WordBitwidth * numWords, ")"
      r.mpz_tdiv_r_2exp(r, WordBitwidth * numWords)

    let startGMPmod = getMonoTime()
    for _ in 0 ..< NumIters:
      mpz_mod(rMod, a, b)
    let stopGMPmod = getMonoTime()
    echo "GMP - ", aBits, " mod ", bBits, " -> ", bBits, " mod: ", float(inNanoseconds((stopGMPmod-startGMPmod)))/float(NumIters), " ns"

    let startGMPmod2 = getMonoTime()
    for _ in 0 ..< NumIters:
      mpz_mod(rMod, r, b)
    let stopGMPmod2 = getMonoTime()
    echo "GMP - ", rBits, " mod ", bBits, " -> ", bBits, " mod: ", float(inNanoseconds((stopGMPmod2-startGMPmod2)))/float(NumIters), " ns"


    # Constantine
    var rTest: BigInt[rBits]

    let startCTT = getMonoTime()
    for _ in 0 ..< NumIters:
      rTest.prod(aTest, bTest)
    let stopCTT = getMonoTime()
    echo "Constantine - ", aBits, " x ", bBits, " -> ", rBits, " mul: ", float(inNanoseconds((stopCTT-startCTT)))/float(NumIters), " ns"

    var rTestMod: BigInt[bBits]

    let startCTTMod = getMonoTime()
    for _ in 0 ..< NumIters:
      rTestMod.reduce(aTest, bTest)
    let stopCTTMod = getMonoTime()
    echo "Constantine - ", aBits, " mod ", bBits, " -> ", bBits, " mod: ", float(inNanoseconds((stopCTTMod-startCTTMod)))/float(NumIters), " ns"

    let startCTTvartimeMod = getMonoTime()
    var q {.noInit.}: BigInt[bBits]
    for _ in 0 ..< NumIters:
      discard divRem_vartime(q.limbs, rTestMod.limbs, aTest.limbs, bTest.limbs)
    let stopCTTvartimeMod = getMonoTime()
    echo "Constantine - ", aBits, " mod ", bBits, " (vartime) -> ", bBits, " mod: ", float(inNanoseconds((stopCTTvartimeMod-startCTTvartimeMod)))/float(NumIters), " ns"

    let startCTTMod2 = getMonoTime()
    for _ in 0 ..< NumIters:
      rTestMod.reduce(rTest, bTest)
    let stopCTTMod2 = getMonoTime()
    echo "Constantine - ", rBits, " mod ", bBits, " -> ", bBits, " mod: ", float(inNanoseconds((stopCTTMod2-startCTTMod2)))/float(NumIters), " ns"

    let startCTTvartimeMod2 = getMonoTime()
    # NOTE(review): the quotient buffer is only bBits wide while the dividend
    # is rBits wide, and the result of divRem_vartime is discarded.
    # Confirm the quotient always fits, or that a too-small buffer is harmless here.
    var q2 {.noInit.}: BigInt[bBits]
    for _ in 0 ..< NumIters:
      discard divRem_vartime(q2.limbs, rTestMod.limbs, rTest.limbs, bTest.limbs)
    let stopCTTvartimeMod2 = getMonoTime()
    echo "Constantine - ", rBits, " mod ", bBits, " (vartime) -> ", bBits, " mod: ", float(inNanoseconds((stopCTTvartimeMod2-startCTTvartimeMod2)))/float(NumIters), " ns"

    echo ""

    #########################################################
    # Check

    {.push warnings: off.} # deprecated csize
    var aW, bW, rW: csize # Word written by GMP
    {.pop.}

    # Length of the result in bytes: a whole number of words.
    # (Was `numWords * WordBitWidth`, a size in bits, which made
    # both buffers 8x larger than the value they hold.)
    const rLen = numWords * sizeof(SecretWord)
    var rGMP: array[rLen, byte]
    discard mpz_export(rGMP[0].addr, rW.addr, GMP_MostSignificantWordFirst, 1, GMP_WordNativeEndian, 0, r)

    var rConstantine: array[rLen, byte]
    marshal(rConstantine, rTest, bigEndian)

    # Note: in bigEndian, GMP aligns left while constantine aligns right
    doAssert rGMP.toOpenArray(0, rW-1) == rConstantine.toOpenArray(rLen-rW, rLen-1), block:
      # Reexport as bigEndian for debugging
      discard mpz_export(aBuf[0].addr, aW.addr, GMP_MostSignificantWordFirst, 1, GMP_WordNativeEndian, 0, a)
      discard mpz_export(bBuf[0].addr, bW.addr, GMP_MostSignificantWordFirst, 1, GMP_WordNativeEndian, 0, b)
      "\nMultiplication with operands\n" &
        " a (" & align($aBits, 4) & "-bit): " & aBuf.toHex & "\n" &
        " b (" & align($bBits, 4) & "-bit): " & bBuf.toHex & "\n" &
        "into r of size " & align($rBits, 4) & "-bit failed:" & "\n" &
        " GMP: " & rGMP.toHex() & "\n" &
        " Constantine: " & rConstantine.toHex() & "\n" &
        "(Note that GMP aligns bytes left while constantine aligns bytes right)"

main()
2 changes: 1 addition & 1 deletion constantine/math/arithmetic/bigints.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import
./limbs,
./limbs_extmul,
./limbs_exgcd,
../../math_arbitrary_precision/arithmetic/limbs_division
../../math_arbitrary_precision/arithmetic/limbs_divmod

export BigInt

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import
./limbs_mod2k,
./limbs_multiprec,
./limbs_extmul,
./limbs_division
./limbs_divmod_vartime

# No exceptions allowed
{.push raises: [], checks: off.}
Expand Down Expand Up @@ -61,6 +61,7 @@ func powOddMod_vartime*(

if eBits == 1:
r.view().reduce(a.view(), aBits, M.view(), mBits)
# discard r.reduce_vartime(a, M)
return

let L = wordsRequired(mBits)
Expand All @@ -77,8 +78,9 @@ func powOddMod_vartime*(
# For now, we call explicit reduction as it can handle all sizes.
# TODO: explicit reduction uses constant-time division which is **very** expensive
if a.len != M.len:
let t = allocStackArray(SecretWord, L)
var t = allocStackArray(SecretWord, L)
t.LimbsViewMut.reduce(a.view(), aBits, M.view(), mBits)
# discard t.toOpenArray(0, L-1).reduce_vartime(a, M)
rMont.LimbsViewMut.getMont(LimbsViewConst t, M.view(), LimbsViewConst r2.view(), m0ninv, mBits)
else:
rMont.LimbsViewMut.getMont(a.view(), M.view(), LimbsViewConst r2.view(), m0ninv, mBits)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func shlAddMod_estimate(a: LimbsViewMut, aLen: int,
a1 = (a[^1] shl (WordBitWidth-R)) or (a[^2] shr R)
m0 = (M[^1] shl (WordBitWidth-R)) or (M[^2] shr R)

# m0 has its high bit set. (a0, a1)/p0 fits in a limb.
# m0 has its high bit set. (a0, a1)/m0 fits in a limb.
# Get a quotient q, at most we will be 2 iterations off
# from the true quotient
var q, r: SecretWord
Expand All @@ -78,29 +78,29 @@ func shlAddMod_estimate(a: LimbsViewMut, aLen: int,

# Now subtract a*2^64 - q*p
var carry = Zero
var over_p = CtTrue # Track if quotient greater than the modulus
var overM = CtTrue # Track if quotient greater than the modulus

for i in 0 ..< MLen:
var qp_lo: SecretWord
var qm_lo: SecretWord

block: # q*p
# q * p + carry (doubleword) carry from previous limb
muladd1(carry, qp_lo, q, M[i], carry)
block: # q*m
# q * m + carry (doubleword) carry from previous limb
muladd1(carry, qm_lo, q, M[i], carry)

block: # a*2^64 - q*p
var borrow: Borrow
subB(borrow, a[i], a[i], qp_lo, Borrow(0))
subB(borrow, a[i], a[i], qm_lo, Borrow(0))
carry += SecretWord(borrow) # Adjust if borrow

over_p = mux(a[i] == M[i], over_p, a[i] > M[i])
overM = mux(a[i] == M[i], overM, a[i] > M[i])

# Fix quotient, the true quotient is either q-1, q or q+1
#
# if carry < q or carry == q and over_p we must do "a -= p"
# if carry > hi (negative result) we must do "a += p"
# if carry < q or carry == q and overM we must do "a -= m"
# if carry > hi (negative result) we must do "a += m"

result.neg = carry > hi
result.tooBig = not(result.neg) and (over_p or (carry < hi))
result.tooBig = not(result.neg) and (overM or (carry < hi))

func shlAddMod(a: LimbsViewMut, aLen: int,
c: SecretWord, M: LimbsViewConst, mBits: int) =
Expand Down
Loading

0 comments on commit 7b29fbd

Please sign in to comment.