LLVM: field addition with saturated fields #456

Merged: 6 commits, Aug 14, 2024. Changes from all commits are shown below.
8 changes: 8 additions & 0 deletions PLANNING.md
@@ -101,6 +101,14 @@ Other tracks are stretch goals, contributions towards them are accepted.
- introduce batchAffine_vartime
- Optimized square_repeated in assembly for Montgomery and Crandall/Pseudo-Mersenne primes
- Optimized elliptic curve directly calling assembly without ADX checks and limited input/output movement in registers or using function multi-versioning.
- LLVM IR:
  - use the internal or private linkage type
  - look into calling conventions like "fast" (`fastcc`) or "tail" (`tailcc`)
  - check if returning a value from a function is properly optimized
    compared to an in-place result
  - use the `readnone` (pure) and `readonly` attributes for functions
  - look into passing parameters as arrays instead of pointers?
  - use the `hot` function attribute (see the IR sketch after this list)
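
  A minimal hand-written LLVM IR sketch of the last few items (not Constantine output; `_big_add_u64x4`, the 4×64-bit limb layout and the elided bodies are assumptions), contrasting an in-place result passed by pointer with a value returned as an array aggregate:

  ```llvm
  ; In-place result: the caller provides storage for r through a pointer.
  define internal fastcc void @_big_add_u64x4(ptr %r, ptr %a, ptr %b) #0 {
    ; ... 4-limb add-with-carry, stores into %r ...
    ret void
  }

  ; Value return: a first-class [4 x i64] aggregate, which may let the
  ; optimizer keep the limbs in registers instead of going through memory.
  define internal fastcc [4 x i64] @_big_add_u64x4_ret([4 x i64] %a, [4 x i64] %b) #0 {
    ; ... 4-limb add-with-carry, result built with insertvalue ...
    ret [4 x i64] zeroinitializer        ; placeholder body
  }

  attributes #0 = { hot }
  ```

  Whether the value-returning form actually keeps the limbs out of memory is precisely what the bullet above proposes to verify.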

### User Experience track

@@ -80,7 +80,7 @@ proc finalSubMayOverflowImpl*(
ctx.mov scratch[i], a[i]
ctx.sbb scratch[i], M[i]

# If it overflows here, it means that it was
# If it underflows here, it means that it was
# smaller than the modulus and we don't need `scratch`
ctx.sbb scratchReg, 0

83 changes: 83 additions & 0 deletions constantine/math_compiler/README.md
@@ -0,0 +1,83 @@
# Cryptography primitive compiler

This implements a cryptography compiler that can be used to produce
- high-performance JIT code for GPUs
- or assembly files for CPUs, when we want to ensure
  there are no side-channel regressions for secret data
- or vectorized assembly files, as LLVM IR is significantly
  more convenient for modeling vector operations

There are also LLVM IR => FPGA translators that might be useful
in the future.

## Platform limitations

- x86 cannot use the dual carry-chain ADCX/ADOX instructions easily:
  - no native support for clearing a flag with `xor`
    and keeping it clear.
  - inline assembly cannot use the raw ASM printer,
    so the workflow will need to compile -> decompile.
- Nvidia GPUs cannot lower types larger than 64 bits, hence we cannot use i256 for example.
- AMD GPUs have 1/4 the throughput for i32 MUL compared to f32 MUL or i24 MUL.
- non-x86 targets may not be as well optimized for matching
  add-carry and sub-borrow patterns, even with `@llvm.usub.with.overflow`
  (see the sketch after this list).
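
As an illustration of the last point, here is a hand-written sketch (not Constantine-generated IR, and not following the library naming below) of a 2-limb subtract-with-borrow expressed with `@llvm.usub.with.overflow`. On x86-64 this is expected to lower to `sub`/`sbb`; other backends may or may not match the whole borrow chain:

```llvm
declare { i64, i1 } @llvm.usub.with.overflow.i64(i64, i64)

define { i64, i64, i1 } @sub2_u64x2(i64 %a0, i64 %a1, i64 %b0, i64 %b1) {
entry:
  ; low limb: r0 = a0 - b0, borrow0 = (a0 < b0)
  %s0 = call { i64, i1 } @llvm.usub.with.overflow.i64(i64 %a0, i64 %b0)
  %r0 = extractvalue { i64, i1 } %s0, 0
  %borrow0 = extractvalue { i64, i1 } %s0, 1
  %borrow0.w = zext i1 %borrow0 to i64
  ; high limb: r1 = a1 - b1 - borrow0
  %s1 = call { i64, i1 } @llvm.usub.with.overflow.i64(i64 %a1, i64 %b1)
  %t1 = extractvalue { i64, i1 } %s1, 0
  %borrow1.a = extractvalue { i64, i1 } %s1, 1
  %s2 = call { i64, i1 } @llvm.usub.with.overflow.i64(i64 %t1, i64 %borrow0.w)
  %r1 = extractvalue { i64, i1 } %s2, 0
  %borrow1.b = extractvalue { i64, i1 } %s2, 1
  %borrow = or i1 %borrow1.a, %borrow1.b
  ; pack (r0, r1, borrow-out)
  %ret0 = insertvalue { i64, i64, i1 } undef, i64 %r0, 0
  %ret1 = insertvalue { i64, i64, i1 } %ret0, i64 %r1, 1
  %ret2 = insertvalue { i64, i64, i1 } %ret1, i1 %borrow, 2
  ret { i64, i64, i1 } %ret2
}
```

If a backend does not fold the chain, the code remains correct; it just materializes the borrow with extra register operations.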

## ABI

Internal functions are:
- prefixed with `_`
- Linkage: internal
- calling convention: "fast"
- marked `hot` for field arithmetic functions

Internal global constants are:
- prefixed with `_`
- Linkage: linkonce_odr (so they are merged with globals of the same name)

External functions use the default calling convention (see the IR sketch below).

We ensure parameters and return values fit in registers:
- https://llvm.org/docs/Frontend/PerformanceTips.html
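
A sketch of what these conventions could look like in textual IR (hand-written; the symbol names, the 6×64-bit limb count and the elided body are assumptions):

```llvm
; Internal global constant: `_` prefix, linkonce_odr so duplicates merge
; across modules that emit the same symbol.
@_bls12_381_fp_mod = linkonce_odr constant [6 x i64] zeroinitializer, align 64

; Internal field-arithmetic function: `_` prefix, internal linkage,
; "fast" calling convention, marked hot.
define internal fastcc void @_mod_add_u64x6(ptr %r, ptr %a, ptr %b) #0 {
  ; ... limb-wise add-with-carry, then conditional final subtraction ...
  ret void
}

attributes #0 = { hot }
```

The `linkonce_odr` linkage is what lets several generated kernels embed the same modulus without duplicate-symbol errors, per the note above.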

TODO:
- function alignment: look into
- https://www.bazhenov.me/posts/2024-02-performance-roulette/
- https://lkml.org/lkml/2015/5/21/443
- function multiversioning
- aggregate alignment (via datalayout)

Naming convention for internal procedures:
- _big_add_u64x4
- _finalsub_mayo_u64x4 -> final subtraction may overflow
- _finalsub_noo_u64x4 -> final sub no overflow
- _mod_add_u64x4
- _mod_add2x_u64x8 -> FpDbl backend
- _mty_mulur_u64x4b2 -> unreduced Montgomery multiplication (unreduced result valid iff 2 spare bits)
- _mty_mul_u64x4b1 -> reduced Montgomery multiplication (result valid iff at least 1 spare bit)
- _mty_mul_u64x4 -> reduced Montgomery multiplication
- _mty_nsqrur_u64x4b2 -> unreduced square n times
- _mty_nsqr_u64x4b1 -> reduced square n times
- _mty_sqr_u64x4 -> square
- _mty_red_u64x4 -> reduction u64x4 <- u64x8
- _pmp_red_mayo_u64x4 -> Pseudo-Mersenne Prime partial reduction may overflow (secp256k1)
- _pmp_red_noo_u64x4 -> Pseudo-Mersenne Prime partial reduction no overflow
- _secp256k1_red -> special reduction
- _fp2x_sqr2x_u64x4 -> Fp2 complex, Fp -> FpDbl lazy reduced squaring
- _fp2g_sqr2x_u64x4 -> Fp2 generic/non-complex (do we pass the mul-non-residue as parameter?)
- _fp2_sqr_u64x4 -> Fp2 (pass the mul-by-non-residue function as parameter)
- _fp4o2_mulnr1pi_u64x4 -> Fp4 over Fp2 mul with (1+i) non-residue optimization
- _fp4o2_mulbynr_u64x4
- _fp12_add_u64x4
- _fp12o4o2_mul_u64x4 -> Fp12 over Fp4 over Fp2
- _ecg1swjac_adda0_u64x4 -> Short Weierstrass G1 Jacobian addition with a=0
- _ecg1swjac_add_u64x4_var -> Short Weierstrass G1 Jacobian vartime addition
- _ectwprj_add_u64x4 -> Twisted Edwards Projective addition

Vectorized:
- _big_add_u64x4v4
- _big_add_u32x8v8

Naming for external procedures:
- bls12_381_fp_add
- bls12_381_fr_add
- bls12_381_fp12_add
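
For example (a hypothetical sketch, assuming BLS12-381 Fp is represented with 6×64-bit limbs and reusing `_mod_add_u64x6` from the ABI sketch above), the stable external name is a thin wrapper over the internal fastcc routine:

```llvm
; Internal routine, defined elsewhere in the module (see the ABI section).
declare fastcc void @_mod_add_u64x6(ptr, ptr, ptr)

; External procedure: stable name, default calling convention.
define void @bls12_381_fp_add(ptr %r, ptr %a, ptr %b) {
  tail call fastcc void @_mod_add_u64x6(ptr %r, ptr %a, ptr %b)
  ret void
}
```

Only these outer symbols are part of the public ABI; everything `_`-prefixed stays internal to the module.
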
20 changes: 2 additions & 18 deletions constantine/math_compiler/codegen_nvidia.nim
@@ -9,12 +9,12 @@
import
constantine/platforms/abis/nvidia_abi {.all.},
constantine/platforms/abis/c_abi,
constantine/platforms/llvm/[llvm, nvidia_inlineasm],
constantine/platforms/llvm/llvm,
constantine/platforms/primitives,
./ir

export
nvidia_abi, nvidia_inlineasm,
nvidia_abi,
Flag, flag, wrapOpenArrayLenType

# ############################################################
@@ -115,22 +115,6 @@ proc cudaDeviceInit*(deviceID = 0'i32): CUdevice =
#
# ############################################################

proc tagCudaKernel(module: ModuleRef, fn: FnDef) =
## Tag a function as a Cuda Kernel, i.e. callable from host

doAssert fn.fnTy.getReturnType().isVoid(), block:
"Kernels must not return values but function returns " & $fn.fnTy.getReturnType().getTypeKind()

let ctx = module.getContext()
module.addNamedMetadataOperand(
"nvvm.annotations",
ctx.asValueRef(ctx.metadataNode([
fn.fnImpl.asMetadataRef(),
ctx.metadataNode("kernel"),
constInt(ctx.int32_t(), 1, LlvmBool(false)).asMetadataRef()
]))
)

proc wrapInCallableCudaKernel*(module: ModuleRef, fn: FnDef) =
## Create a public wrapper of a cuda device function
##
216 changes: 216 additions & 0 deletions constantine/math_compiler/impl_fields_globals.nim
@@ -0,0 +1,216 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/bithacks,
constantine/platforms/llvm/llvm,
constantine/serialization/[io_limbs, codecs],
constantine/named/deriv/precompute

import ./ir

# ############################################################
#
# Metadata precomputation
#
# ############################################################

# Constantine on CPU is configured at compile-time for several properties that need to be runtime-configurable on GPUs:
# - word size (32-bit or 64-bit)
# - access to curve properties like the modulus bitsize or -1/M[0], a.k.a. m0ninv
# - constants are stored in freestanding `const`
#
# This is because it's not possible to store a BigInt[254] and a BigInt[384]
# in a generic way in the same structure, especially without using heap allocation.
# And with Nim's dead code elimination, unused curves are not compiled in.
#
# There would be no easy way to dynamically retrieve (via an array or a table)
#   const BLS12_381_modulus = ...
#   const BN254_Snarks_modulus = ...
#
# - We would need a macro to properly access each constant.
# - We would need to create a 32-bit and a 64-bit version.
# - Unused curves would be compiled into the program.
#
# Note: on GPU we don't manipulate secrets, hence branches and dynamic memory allocations are allowed.
#
# As GPU is a niche use case, we instead recreate the relevant `precompute` and IO procedures
# with dynamic word-size support.

type
DynWord = uint32 or uint64
BigNum[T: DynWord] = object
bits: uint32
limbs: seq[T]

# Serialization
# ------------------------------------------------

func byteLen(bits: SomeInteger): SomeInteger {.inline.} =
## Length in bytes to serialize BigNum
(bits + 7) shr 3 # (bits + 8 - 1) div 8

func fromHex[T](a: var BigNum[T], s: string) =
var bytes = newSeq[byte](a.bits.byteLen())
bytes.paddedFromHex(s, bigEndian)

# 2. Convert canonical uint to BigNum
const wordBitwidth = sizeof(T) * 8
a.limbs.unmarshal(bytes, wordBitwidth, bigEndian)

func fromHex[T](BN: type BigNum[T], bits: uint32, s: string): BN =
const wordBitwidth = sizeof(T) * 8
let numWords = wordsRequired(bits, wordBitwidth)

result.bits = bits
result.limbs.setLen(numWords)
result.fromHex(s)

func toHexLlvm*[T](a: BigNum[T]): string =
## Conversion to big-endian hex suitable for LLVM literals
## It MUST NOT have a prefix
## This is variable-time
# 1. Convert BigInt to canonical uint
const wordBitwidth = sizeof(T) * 8
var bytes = newSeq[byte](byteLen(a.bits))
bytes.marshal(a.limbs, wordBitwidth, bigEndian)

# 2. Convert canonical uint to hex
const hexChars = "0123456789abcdef"
result = newString(2 * bytes.len)
for i in 0 ..< bytes.len:
let bi = bytes[i]
result[2*i] = hexChars[bi shr 4 and 0xF]
result[2*i+1] = hexChars[bi and 0xF]

# Checks
# ------------------------------------------------

func checkValidModulus(M: BigNum) =
const wordBitwidth = uint32(BigNum.T.sizeof() * 8)
let expectedMsb = M.bits-1 - wordBitwidth * (M.limbs.len.uint32 - 1)
let msb = log2_vartime(M.limbs[M.limbs.len-1])

doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those:\n" &
" Modulus '0x" & M.toHexLlvm() & "' is declared with " & $M.bits &
" bits but uses " & $(msb + wordBitwidth * uint32(M.limbs.len - 1)) & " bits."

# Fields metadata
# ------------------------------------------------

func negInvModWord[T](M: BigNum[T]): T =
## Returns the Montgomery domain magic constant for the input modulus:
##
## µ ≡ -1/M[0] (mod SecretWord)
##
## M[0] is the least significant limb of M
## M must be odd and greater than 2.
##
## Assuming 64-bit words:
##
## µ ≡ -1/M[0] (mod 2^64)
checkValidModulus(M)
return M.limbs[0].negInvModWord()

# ############################################################
#
# Globals in IR
#
# ############################################################

proc getModulusPtr*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef =
let modname = fd.name & "_mod"
var M = asy.module.getGlobal(cstring modname)
if M.isNil():
M = asy.defineGlobalConstant(
name = modname,
section = fd.name,
constIntOfStringAndSize(fd.intBufTy, fd.modulus, 16),
fd.intBufTy,
alignment = 64
)
return M

proc getM0ninv*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef =
let m0ninvname = fd.name & "_m0ninv"
var m0ninv = asy.module.getGlobal(cstring m0ninvname)
if m0ninv.isNil():
if fd.w == 32:
let M = BigNum[uint32].fromHex(fd.bits, fd.modulus)
m0ninv = asy.defineGlobalConstant(
name = m0ninvname,
section = fd.name,
constInt(fd.wordTy, M.negInvModWord()),
fd.wordTy
)
else:
let M = BigNum[uint64].fromHex(fd.bits, fd.modulus)
m0ninv = asy.defineGlobalConstant(
name = m0ninvname,
section = fd.name,
constInt(fd.wordTy, M.negInvModWord()),
fd.wordTy
)


return m0ninv

when isMainModule:
let asy = Assembler_LLVM.new("test_module", bkX86_64_Linux)
let fd = asy.ctx.configureField(
"bls12_381_fp",
381,
"1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab",
v = 1, w = 64)

discard asy.getModulusPtr(fd)
discard asy.getM0ninv(fd)

echo "========================================="
echo "LLVM IR\n"

echo asy.module
echo "========================================="

asy.module.verify(AbortProcessAction)

# --------------------------------------------
# See the assembly - note it might be different from what the JIT compiler did
initializeFullNativeTarget()

const triple = "x86_64-pc-linux-gnu"

let machine = createTargetMachine(
target = toTarget(triple),
triple = triple,
cpu = "",
features = "adx,bmi2", # TODO check the proper way to pass options
level = CodeGenLevelAggressive,
reloc = RelocDefault,
codeModel = CodeModelDefault
)

let pbo = createPassBuilderOptions()
let err = asy.module.runPasses(
"default<O3>,function-attrs,memcpyopt,sroa,mem2reg,gvn,dse,instcombine,inline,adce",
machine,
pbo
)
if not err.pointer().isNil():
writeStackTrace()
let errMsg = err.getErrorMessage()
stderr.write("\"codegenX86_64\" for module '" & astToStr(module) & "' " & $instantiationInfo() &
" exited with error: " & $cstring(errMsg) & '\n')
errMsg.dispose()
quit 1

echo "========================================="
echo "Assembly\n"

echo machine.emitTo[:string](asy.module, AssemblyFile)
echo "========================================="