From cacb22d2941e1b24c5d391e3cbdd00689a63f6a3 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 16:45:58 +0100 Subject: [PATCH 01/22] wrap `execCudaImpl` macro logic in a block Otherwise we run into problems if we have two execs in the same scope. --- constantine/math_compiler/codegen_nvidia.nim | 3 +++ 1 file changed, 3 insertions(+) diff --git a/constantine/math_compiler/codegen_nvidia.nim b/constantine/math_compiler/codegen_nvidia.nim index 08af496f..a963fb0f 100644 --- a/constantine/math_compiler/codegen_nvidia.nim +++ b/constantine/math_compiler/codegen_nvidia.nim @@ -448,6 +448,9 @@ proc execCudaImpl(jitFn, res, inputs: NimNode): NimNode = x[0] ) ) + result = quote do: + block: + `result` macro execCuda*(jitFn: CUfunction, res: typed, From d4e640ce1def3ebefaedc8372eb1cfb601baf57b Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:00:42 +0100 Subject: [PATCH 02/22] add more EC Jac operations to helper templates --- .../math_compiler/impl_curves_ops_jacobian.nim | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/constantine/math_compiler/impl_curves_ops_jacobian.nim b/constantine/math_compiler/impl_curves_ops_jacobian.nim index b040df6f..ab40502c 100644 --- a/constantine/math_compiler/impl_curves_ops_jacobian.nim +++ b/constantine/math_compiler/impl_curves_ops_jacobian.nim @@ -60,7 +60,9 @@ proc store*(dst: EcPointJac, src: EcPointJac) = template declEllipticJacOps*(asy: Assembler_LLVM, cd: CurveDescriptor): untyped = ## This template can be used to make operations on `Field` elements ## more convenient. - ## XXX: extend to include all ops + # Setters + template setNeutral(x: EcPointJac): untyped = asy.setNeutral(cd, x.buf) + # Boolean checks template isNeutral(res, x: EcPointJac): untyped = asy.isNeutral(cd, res, x.buf) template isNeutral(x: EcPointJac): untyped = @@ -68,6 +70,16 @@ template declEllipticJacOps*(asy: Assembler_LLVM, cd: CurveDescriptor): untyped asy.isNeutral(cd, res, x.buf) res + # Mutating assignment ops + template sum(res, x, y: EcPointJac): untyped = asy.sum(cd, res.buf, x.buf, y.buf) + template `+=`(x, y: EcPointJac): untyped = x.sum(x, y) + template mixedSum(res, x: EcPointJac, y: EcPointAff): untyped = asy.mixedSum(cd, res.buf, x.buf, y.buf) + template `+=`(x: EcPointJac, y: EcPointAff): untyped = x.mixedSum(x, y) + + # Arithmetic mutations + template double(res, x: EcPointJac): untyped = asy.double(cd, res.buf, x.buf) + template double(x: EcPointJac): untyped = x.double(x) + # Conditional ops template ccopy(x, y: EcPointJac, c): untyped = asy.ccopy(cd, x.buf, y.buf, derefBool c) From 22c05658f4c68b921b8f4eca346986296055b083 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:03:54 +0100 Subject: [PATCH 03/22] do not quit on failure in NvidiaAssembler destructor A failure in the check from the destructor almost certainly means that we destroyed early, due to an exception. We don't want to hide the exception, hence we don't quit. --- constantine/math_compiler/codegen_nvidia.nim | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/constantine/math_compiler/codegen_nvidia.nim b/constantine/math_compiler/codegen_nvidia.nim index a963fb0f..808a6004 100644 --- a/constantine/math_compiler/codegen_nvidia.nim +++ b/constantine/math_compiler/codegen_nvidia.nim @@ -71,15 +71,17 @@ export # Cuda Driver API # ------------------------------------------------------------ -template check*(status: CUresult) = +template check*(status: CUresult, quitOnFailure = true) = ## Check the status code of a CUDA operation ## Exit program with error if failure let code = status # ensure that the input expression is evaluated once only + if code != CUDA_SUCCESS: writeStackTrace() stderr.write(astToStr(status) & " " & $instantiationInfo() & " exited with error: " & $code & '\n') - quit 1 + if quitOnFailure: + quit 1 # NOTE: this hides exceptions if they are thrown! func cuModuleLoadData*(module: var CUmodule, sourceCode: openArray[char]): CUresult {.inline.}= cuModuleLoadData(module, sourceCode[0].unsafeAddr) @@ -516,8 +518,18 @@ type proc `=destroy`*(nv: NvidiaAssemblerObj) = ## XXX: Need to also call the finalizer for `asy` in the future! - check nv.cuMod.cuModuleUnload() - check nv.cuCtx.cuCtxDestroy() + # NOTE: In the destructor we don't want to quit on a `check` failure. + # The reason is that if we throw an exception with an `NvidiaAssembler` + # in scope, it will trigger the destructor here (with a likely invalid + # state in the CUDA module / context). However, in this case + # we will crash anyway and would just end up hiding the actual cause of + # the error. + # In the unlikely case that all CUDA operations worked correctly up + # to this point, but then fail to unload, we currently ignore this + # as a failure mode. + # Hopefully we find a better solution in the future. + check nv.cuMod.cuModuleUnload(), quitOnFailure = false + check nv.cuCtx.cuCtxDestroy(), quitOnFailure = false `=destroy`(nv.asy) proc initNvAsm*[Name: static Algebra](field: type FF[Name], wordSize: int = 32, backend = bkNvidiaPTX): NvidiaAssembler = From c1257ac2696fc1a0346eef8ff2c4d503ea1200ed Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:09:29 +0100 Subject: [PATCH 04/22] add CurveDescriptor fields for LLVM type for Fr, scalars for MSM --- constantine/math_compiler/ir.nim | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 01fa2518..ddb752eb 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -226,9 +226,13 @@ type fd*: FieldDescriptor # of the underlying field family*: CurveFamily modulus*: string # Modulus as Big-Endian uppercase hex, NOT prefixed with 0x - modulusBitWidth*: uint32 + modulusBitWidth*: uint32 # bits required for elements of `Fp` order*: string - orderBitWidth*: uint32 + orderBitWidth*: uint32 # bits required for scalar elements `Fr` + + # type of the field Fr + fieldScalarTy*: TypeRef + numWordsScalar*: uint32 # num words required for it cofactor*: string eqForm*: CurveEquationForm @@ -265,6 +269,10 @@ proc configureCurve*(ctx: ContextRef, # Array of 2 arrays for affine coords result.curveTyAff = array_t(result.fd.fieldTy, 2) + # and the type for elements of Fr + result.numWordsScalar = uint32 wordsRequired(curveOrderBitWidth, w) + result.fieldScalarTy = array_t(result.fd.wordTy, result.numWordsScalar) + # Curve parameters result.coef_a = coef_a result.coef_b = coef_b # unused From 43e4d19e9450100c9caa1a3395111ff1262f05e8 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:10:32 +0100 Subject: [PATCH 05/22] [LLVM] add `isPointerTy` helper to determine if type is a pointer --- constantine/platforms/abis/llvm_abi.nim | 2 ++ 1 file changed, 2 insertions(+) diff --git a/constantine/platforms/abis/llvm_abi.nim b/constantine/platforms/abis/llvm_abi.nim index 334e2eca..bac0e76b 100644 --- a/constantine/platforms/abis/llvm_abi.nim +++ b/constantine/platforms/abis/llvm_abi.nim @@ -277,6 +277,8 @@ type proc getContext*(ty: TypeRef): ContextRef {.importc: "LLVMGetTypeContext".} proc getTypeKind*(ty: TypeRef): TypeKind {.importc: "LLVMGetTypeKind".} +proc isPointerType*(ty: TypeRef): bool = ty.getTypeKind == tkPointer + proc dumpType*(ty: TypeRef) {.sideeffect, importc: "LLVMDumpType".} proc toLLVMstring(ty: TypeRef): LLVMstring {.used, importc: "LLVMPrintTypeToString".} From 6cd3ca8ecac6140e2573faf73f2fdaa285d7cee4 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:11:33 +0100 Subject: [PATCH 06/22] [tests] add sanity test for adding neutral EC element to EC sum --- tests/gpu/t_ec_sum.nim | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/gpu/t_ec_sum.nim b/tests/gpu/t_ec_sum.nim index 02724c3f..da928bde 100644 --- a/tests/gpu/t_ec_sum.nim +++ b/tests/gpu/t_ec_sum.nim @@ -19,7 +19,8 @@ import helpers/prng_unsafe proc testSum[Name: static Algebra](field: type FF[Name], wordSize: int, - a, b: EC_ShortW_Jac[field, G1]) = + a, b: EC_ShortW_Jac[field, G1], + iters = 1000) = # Codegen # ------------------------- let nv = initNvAsm(EC_ShortW_Jac[field, G1], wordSize) @@ -40,7 +41,7 @@ proc testSum[Name: static Algebra](field: type FF[Name], wordSize: int, # return point rGPU var res = a - for i in 0 ..< 1000: # `res = res + b`, starting with `a + b` + for i in 0 ..< iters: # `res = res + b`, starting with `a + b` res = checkSum(res, b) proc testMixedSum[Name: static Algebra](field: type FF[Name], wordSize: int, @@ -68,15 +69,16 @@ proc testMixedSum[Name: static Algebra](field: type FF[Name], wordSize: int, for i in 0 ..< 1000: # `res = res + b`, starting with `a + b` res = checkSum(res, b) +type EC = EC_ShortW_Jac[Fp[BN254_Snarks], G1] let x = "0x2ef34a5db00ff691849861d49415d8081d9d0e10cba33b57b2dd1f37f13eeee0" let y = "0x2beb0d0d6115007676f30bcc462fe814bf81198848f139621a3e9fa454fe8e6a" -let pt = EC_ShortW_Jac[Fp[BN254_Snarks], G1].fromHex(x, y) +let pt = EC.fromHex(x, y) echo pt.toHex() let x2 = "0x226c85cf65f4596a77da7d247310a81ac9aa9220e819e3ef23b6cbe0218ce272" let y2 = "0xf53265870f65aa18bded3ccb9c62a4d8b060a32a05a75d455710bce95a991df" -let pt2 = EC_ShortW_Jac[Fp[BN254_Snarks], G1].fromHex(x2, y2) +let pt2 = EC.fromHex(x2, y2) ## If `skipFinalSub` is set to `true` in the EC sum implementation ## `S1.prod(Q.z, Z2Z2, skipFinalSub = true)` @@ -88,6 +90,11 @@ pt2Aff.affine(pt2) testMixedSum(Fp[BN254_Snarks], 32, pt, pt2Aff) +block CheckAddZero: + var pt3: EC + pt3.setNeutral() + testSum(Fp[BN254_Snarks], 32, pt, pt3, iters = 2) + ## NOTE: While these inputs a, b are the ones that end up causing the ## CPU / GPU mismatch: From d8b21c5d8343dbf67af3af8ebbfd59cdefd49517 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:12:17 +0100 Subject: [PATCH 07/22] store EC order bit width in CurveDescriptor --- constantine/math_compiler/codegen_nvidia.nim | 3 ++- constantine/math_compiler/ir.nim | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/constantine/math_compiler/codegen_nvidia.nim b/constantine/math_compiler/codegen_nvidia.nim index 808a6004..b6407bd1 100644 --- a/constantine/math_compiler/codegen_nvidia.nim +++ b/constantine/math_compiler/codegen_nvidia.nim @@ -586,7 +586,8 @@ proc initNvAsm*[Name: static Algebra](field: type EC_ShortW_Jac[Fp[Name], G1], w Fp[Name].getModulus().toHex(), v = 1, w = wordSize, coef_a = Fp[Name].Name.getCoefA(), - coef_B = Fp[Name].Name.getCoefB() + coef_B = Fp[Name].Name.getCoefB(), + curveOrderBitWidth = Fr[Name].bits() ) result.fd = result.cd.fd result.asy.definePrimitives(result.cd) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index ddb752eb..e4569b7e 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -254,7 +254,8 @@ proc configureCurve*(ctx: ContextRef, name: string, modBits: int, modulus: string, v, w: int, - coefA, coefB: int): CurveDescriptor = + coefA, coefB: int, + curveOrderBitWidth: int): CurveDescriptor = ## Configure a curve descriptor with: ## - v: vector length ## - w: base word size in bits @@ -276,6 +277,7 @@ proc configureCurve*(ctx: ContextRef, # Curve parameters result.coef_a = coef_a result.coef_b = coef_b # unused + result.orderBitWidth = curveOrderBitWidth.uint32 proc definePrimitives*(asy: Assembler_LLVM, cd: CurveDescriptor) = asy.definePrimitives(cd.fd) From 7a786ef40e6c0ec6dda31de4679bec3879fc3d4c Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:17:45 +0100 Subject: [PATCH 08/22] make `store` for `ValueRef` safer by checking for pointer-ness Also adds `storePtr` if user really wants to store a pointer --- constantine/math_compiler/ir.nim | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index e4569b7e..3e6d51d2 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -646,4 +646,19 @@ template load2*(asy: Assembler_LLVM, ty: TypeRef, `ptr`: ValueRef, name: cstring asy.br.load2(ty, `ptr`, name) template store*(asy: Assembler_LLVM, dst, src: ValueRef, name: cstring = "") = + if not dst.getTypeOf.isPointerType(): + raise newException(ValueError, "The destination argument to `store` is not a pointer type.") + if src.getTypeOf.isPointerType(): + raise newException(ValueError, "The source argument to `store` is a pointer type. " & + "You must `load2()` it before the store. Or use the `MutableValue` type, in which case " & + "we can load it automatically for you. If you really wish to store the pointer " & + "to the destination, use `storePtr` instead.") + asy.br.store(src, dst) + +template storePtr*(asy: Assembler_LLVM, dst, src: ValueRef, name: cstring = "") = + if not dst.getTypeOf.isPointerType(): + raise newException(ValueError, "The destination argument to `storePtr` is not a pointer type.") + if not src.getTypeOf.isPointerType(): + raise newException(ValueError, "The source argument to `storePtr` is not a pointer type. " & + "You likely want to call `store` instead.") asy.br.store(src, dst) From 9ee8fe5be3dde259f1a67ed9aa41c9a588572320 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:18:55 +0100 Subject: [PATCH 09/22] forbid `=copy` on Array, likely *not* what user wants Easy to introduce bugs by thinking one stores, when in fact one just copies the reference. --- constantine/math_compiler/ir.nim | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 3e6d51d2..3b522660 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -304,8 +304,11 @@ type elemTy*: TypeRef int32_t: TypeRef -proc `[]`*(a: Array, index: SomeInteger): ValueRef {.inline.} -proc `[]=`*(a: Array, index: SomeInteger, val: ValueRef) {.inline.} +proc `=copy`*(m: var Array, x: Array) {.error: "Copying an Array is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + +proc `[]`*(a: Array, index: SomeInteger | ValueRef): ValueRef {.inline.} +proc `[]=`*(a: Array, index: SomeInteger | ValueRef, val: ValueRef) {.inline.} proc asArray*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): Array = Array( From 494e4ca355cec9df77b50a7103ebfe84993653ec Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:20:07 +0100 Subject: [PATCH 10/22] allow access read/write of `Array` using `ValueRef` --- constantine/math_compiler/ir.nim | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 3b522660..39e9f910 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -348,13 +348,34 @@ proc getElementPtr*(a: Array, indices: varargs[int]): ValueRef = idxs[i] = constInt(a.int32_t, idx) result = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, idxs) -proc `[]`*(a: Array, index: SomeInteger): ValueRef {.inline.}= +proc getElementPtr*(a: Array, indices: varargs[ValueRef]): ValueRef = + ## Helper to get an element pointer from a (nested) array using + ## indices that are already `ValueRef` + let idxs = @indices + result = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, idxs) + +template asInt(x: SomeInteger | ValueRef): untyped = + when typeof(x) is ValueRef: x + else: x.int + +proc getPtr*(a: Array, index: SomeInteger | ValueRef): ValueRef {.inline.}= + ## First dereference the array pointer with 0, then access the `index` + ## but do not load the element! + when typeof(index) is SomeInteger: + result = a.getElementPtr(0, index.int) + else: + result = a.getElementPtr(constInt(a.int32_t, 0), index) + +proc `[]`*(a: Array, index: SomeInteger | ValueRef): ValueRef {.inline.}= # First dereference the array pointer with 0, then access the `index` - let pelem = a.getElementPtr(0, index.int) + let pelem = getPtr(a, index) a.builder.load2(a.elemTy, pelem) -proc `[]=`*(a: Array, index: SomeInteger, val: ValueRef) {.inline.}= - let pelem = a.getElementPtr(0, index.int) +proc `[]=`*(a: Array, index: SomeInteger | ValueRef, val: ValueRef) {.inline.}= + when typeof(index) is SomeInteger: + let pelem = a.getElementPtr(0, index.int) + else: + let pelem = a.getElementPtr(constInt(a.int32_t, 0), index) a.builder.store(val, pelem) proc store*(asy: Assembler_LLVM, dst: Array, src: Array) {.inline.}= From ec28afc2471e83e3ce12b2a5e077b1dffd185a94 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:21:07 +0100 Subject: [PATCH 11/22] add `FieldScalar`, `FieldScalarArray`, `EcAffArray`, `EcAffArray` - for safer handling of multiple EC points in different coordinates - separate logic of elements of Fp (`Field`) from those of Fr (`FieldScalar`) --- .../math_compiler/impl_curves_ops_affine.nim | 29 +++++++ .../impl_curves_ops_jacobian.nim | 29 +++++++ constantine/math_compiler/ir.nim | 79 ++++++++++++++----- 3 files changed, 116 insertions(+), 21 deletions(-) diff --git a/constantine/math_compiler/impl_curves_ops_affine.nim b/constantine/math_compiler/impl_curves_ops_affine.nim index c207be7b..35e4b5a1 100644 --- a/constantine/math_compiler/impl_curves_ops_affine.nim +++ b/constantine/math_compiler/impl_curves_ops_affine.nim @@ -22,6 +22,16 @@ const SectionName = "ctt.curves_affine" type EcPointAff* {.borrow: `.`.} = distinct Array +proc `=copy`*(m: var EcPointAff, x: EcPointAff) {.error: "Copying an EcPointAff is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + +proc asEcPointAff*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointAff = + ## Constructs an elliptic curve point in Affine coordinates from an array pointer. + ## + ## `arrayTy` is an `array[FieldTy, 2]` where `FieldTy` itsel is an array of + ## `array[WordTy, NumWords]`. + result = EcPointAff(br.asArray(arrayPtr, arrayTy)) + proc asEcPointAff*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointAff = ## Constructs an elliptic curve point in Affine coordinates from an array pointer. ## @@ -54,6 +64,25 @@ proc store*(dst: EcPointAff, src: EcPointAff) = store(dst.getX(), src.getX()) store(dst.getY(), src.getY()) +# Array of EC points in affine coordinates +type EcAffArray* {.borrow: `.`.} = distinct Array + +proc `=copy`(m: var EcAffArray, x: EcAffArray) {.error: "Copying an EcAffArray is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + +proc `[]`*(a: EcAffArray, index: SomeInteger | ValueRef): EcPointAff = a.builder.asEcPointAff((distinctBase(a).getPtr(index)), a.elemTy) +proc `[]=`*(a: EcAffArray, index: SomeInteger | ValueRef, val: EcPointAff) = distinctBase(a)[index] = val.buf + +proc asEcAffArray*(asy: Assembler_LLVM, cd: CurveDescriptor, a: ValueRef, num: int): EcAffArray = + ## Interpret the given value `a` as an array of EC elements in Affine coordinates. + let ty = array_t(cd.curveTyAff, num) + result = EcAffArray(asy.br.asArray(a, ty)) + +proc initEcAffArray*(asy: Assembler_LLVM, cd: CurveDescriptor, num: int): EcAffArray = + ## Initialize a new EcAffArray for `num` elements + let ty = array_t(cd.curveTyAff, num) + result = EcAffArray(asy.makeArray(ty)) + template declEllipticAffOps*(asy: Assembler_LLVM, cd: CurveDescriptor): untyped = ## This template can be used to make operations on `Field` elements ## more convenient. diff --git a/constantine/math_compiler/impl_curves_ops_jacobian.nim b/constantine/math_compiler/impl_curves_ops_jacobian.nim index ab40502c..8ecae2f4 100644 --- a/constantine/math_compiler/impl_curves_ops_jacobian.nim +++ b/constantine/math_compiler/impl_curves_ops_jacobian.nim @@ -23,6 +23,16 @@ const SectionName = "ctt.curves_jacobian" type EcPointJac* {.borrow: `.`.} = distinct Array +proc `=copy`(m: var EcPointJac, x: EcPointJac) {.error: "Copying an EcPointJac is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + +proc asEcPointJac*(br: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointJac = + ## Constructs an elliptic curve point in Jacobian coordinates from an array pointer. + ## + ## `arrayTy` is an `array[FieldTy, 3]` where `FieldTy` itsel is an array of + ## `array[WordTy, NumWords]`. + result = EcPointJac(br.asArray(arrayPtr, arrayTy)) + proc asEcPointJac*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): EcPointJac = ## Constructs an elliptic curve point in Jacobian coordinates from an array pointer. ## @@ -57,6 +67,25 @@ proc store*(dst: EcPointJac, src: EcPointJac) = store(dst.getY(), src.getY()) store(dst.getZ(), src.getZ()) +# Representation of a finite field point with some utilities +type EcJacArray* {.borrow: `.`.} = distinct Array + +proc `=copy`(m: var EcJacArray, x: EcJacArray) {.error: "Copying an EcJacArray is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + +proc `[]`*(a: EcJacArray, index: SomeInteger | ValueRef): EcPointJac = a.builder.asEcPointJac((distinctBase(a).getPtr(index)), a.elemTy) +proc `[]=`*(a: EcJacArray, index: SomeInteger | ValueRef, val: EcPointJac) = distinctBase(a)[index] = val.buf + +proc asEcJacArray*(asy: Assembler_LLVM, cd: CurveDescriptor, a: ValueRef, num: int): EcJacArray = + ## Interpret the given value `a` as an array of EC elements in Jacobian coordinates. + let ty = array_t(cd.curveTy, num) + result = EcJacArray(asy.br.asArray(a, ty)) + +proc initEcJacArray*(asy: Assembler_LLVM, cd: CurveDescriptor, num: int): EcJacArray = + ## Initialize a new EcJacArray for `num` elements + let ty = array_t(cd.curveTy, num) + result = EcJacArray(asy.makeArray(ty)) + template declEllipticJacOps*(asy: Assembler_LLVM, cd: CurveDescriptor): untyped = ## This template can be used to make operations on `Field` elements ## more convenient. diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 39e9f910..7d518a68 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -388,27 +388,64 @@ proc store*(asy: Assembler_LLVM, dst: Array, src: ValueRef) {.inline.}= asy.br.store(src, dst.buf) # Representation of a finite field point with some utilities -type Field* {.borrow: `.`.} = distinct Array - -proc `[]`*(a: Field, index: SomeInteger): ValueRef = distinctBase(a)[index] -proc `[]=`*(a: Field, index: SomeInteger, val: ValueRef) = distinctBase(a)[index] = val - -proc asField*(br: BuilderRef, a: ValueRef, fieldTy: TypeRef): Field = - result = Field(br.asArray(a, fieldTy)) -proc asField*(asy: Assembler_LLVM, a: ValueRef, fieldTy: TypeRef): Field = - asy.br.asField(a, fieldTy) -proc asField*(asy: Assembler_LLVM, fd: FieldDescriptor, a: ValueRef): Field = - asy.br.asField(a, fd.fieldTy) - -proc newField*(asy: Assembler_LLVM, fd: FieldDescriptor): Field = - ## Use field descriptor for size etc? - result = Field(asy.makeArray(fd.fieldTy)) - -proc store*(dst: Field, src: Field) = - ## Stores the `dst` in `src`. Both must correspond to the same field of course. - assert dst.arrayTy.getArrayLength() == src.arrayTy.getArrayLength() - for i in 0 ..< dst.arrayTy.getArrayLength: - dst[i] = src[i] + +template genField(name, desc, field: untyped): untyped = + type name* {.borrow: `.`.} = distinct Array + + proc `=copy`(m: var name, x: name) {.error: "Copying a " & $name & " is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + + proc `[]`*(a: name, index: SomeInteger | ValueRef): ValueRef = distinctBase(a)[index] + proc `[]=`*(a: name, index: SomeInteger | ValueRef, val: ValueRef) = distinctBase(a)[index] = val + + proc `as name`*(br: BuilderRef, a: ValueRef, fieldTy: TypeRef): name = + result = name(br.asArray(a, fieldTy)) + proc `as name`*(asy: Assembler_LLVM, a: ValueRef, fieldTy: TypeRef): name = + asy.br.`as name`(a, fieldTy) + proc `as name`*(asy: Assembler_LLVM, d: desc, a: ValueRef): name = + asy.br.`as name`(a, d.field) + + proc `new name`*(asy: Assembler_LLVM, d: desc): name = + ## Use field descriptor for size etc? + result = name(asy.makeArray(d.field)) + + proc store*(dst: name, src: name) = + ## Stores the `dst` in `src`. Both must correspond to the same field of course. + assert dst.arrayTy.getArrayLength() == src.arrayTy.getArrayLength() + for i in 0 ..< dst.arrayTy.getArrayLength: + dst[i] = src[i] + + +genField(Field, FieldDescriptor, fieldTy) # intended for elements of `Fp[Curve]` +genField(FieldScalar, CurveDescriptor, fieldScalarTy) # intended for elements of `Fr[Curve]` + +# Representation of a finite field point with some utilities +type FieldArray* {.borrow: `.`.} = distinct Array + +proc `=copy`(m: var FieldArray, x: FieldArray) {.error: "Copying an FieldArray is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + +proc `[]`*(a: FieldArray, index: SomeInteger | ValueRef): Field = asField(a.builder, distinctBase(a).getPtr(index), a.elemTy) +proc `[]=`*(a: FieldArray, index: SomeInteger | ValueRef, val: ValueRef) = distinctBase(a)[index] = val + +proc asFieldArray*(asy: Assembler_LLVM, fd: FieldDescriptor, a: ValueRef, num: int): FieldArray = + ## Interpret the given value `a` as an array of Field elements. + let ty = array_t(fd.fieldTy, num) + result = FieldArray(asy.br.asArray(a, ty)) + +type FieldScalarArray* {.borrow: `.`.} = distinct Array + +proc `=copy`(m: var FieldScalarArray, x: FieldScalarArray) {.error: "Copying an FieldScalarArray is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + +proc `[]`*(a: FieldScalarArray, index: SomeInteger | ValueRef): FieldScalar = asFieldScalar(a.builder, distinctBase(a).getPtr(index), a.elemTy) +proc `[]=`*(a: FieldScalarArray, index: SomeInteger | ValueRef, val: ValueRef) = distinctBase(a)[index] = val + +proc asFieldScalarArray*(asy: Assembler_LLVM, cd: CurveDescriptor, a: ValueRef, num: int): FieldScalarArray = + ## Interpret the given value `a` as an array of Field elements. + let ty = array_t(cd.fieldScalarTy, num) + result = FieldScalarArray(asy.br.asArray(a, ty)) + # Conversion to native LLVM int # ------------------------------- From 21fb88c647c1df10d6c3a5bf27a03afc6052b8da Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:22:51 +0100 Subject: [PATCH 12/22] extend doc string of `compile` taking a string --- constantine/math_compiler/codegen_nvidia.nim | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/constantine/math_compiler/codegen_nvidia.nim b/constantine/math_compiler/codegen_nvidia.nim index b6407bd1..dbce048a 100644 --- a/constantine/math_compiler/codegen_nvidia.nim +++ b/constantine/math_compiler/codegen_nvidia.nim @@ -596,6 +596,19 @@ proc compile*(nv: NvidiaAssembler, kernName: string): CUfunction = ## Overload of `compile` below. ## Call this version if you have manually used the Assembler_LLVM object ## to build instructions and have a kernel name you wish to compile. + ## + ## Use this overload if your generator function does not match the `FieldFnGenerator` or + ## `CurveFnGenerator` signatures. This is useful if your function requires additional + ## arguments that are compile time values in the context of LLVM. + ## + ## Example: + ## + ## ```nim + ## let nv = initNvAsm(EC, wordSize) + ## let kernel = nv.compile(asy.genEcMSM(cd, 3, 1000) # window size, num. points + ## ``` + ## where `genEcMSM` returns the name of the kernel. + let ptx = nv.asy.codegenNvidiaPTX(nv.sm) # convert to PTX # GPU exec From cf095cf66e7585c23146bb7c6f8c5c53f3bff1cf Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:38:34 +0100 Subject: [PATCH 13/22] add ConstantValue, MutableValue wrappers around ValueRef Dealing with ValueRef and the fact that pointers are now opaque in LLVM is extremely annoying. So here are 2 types that wrap the LLVM values with their respective underlying types which also provide easier load / write access. --- constantine/math_compiler/ir.nim | 128 +++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 7d518a68..1c0afebb 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -723,3 +723,131 @@ template storePtr*(asy: Assembler_LLVM, dst, src: ValueRef, name: cstring = "") raise newException(ValueError, "The source argument to `storePtr` is not a pointer type. " & "You likely want to call `store` instead.") asy.br.store(src, dst) + +## No-op to support calling it on `MutableValue | ConstantValue | ValueRef` +proc getValueRef*(x: ValueRef): ValueRef = x + +proc nimToLlvmType[T](asy: Assembler_LLVM, _: typedesc[T]): TypeRef = + when T is SomeInteger: + result = asy.ctx.int_t(sizeof(T) * 8) + else: + {.error: "Unsupported so far: " & $T.} + +type + ## A value constructed using `constX` + ConstantValue* = object + br: BuilderRef + val: ValueRef + typ: TypeRef + +proc `=copy`(m: var ConstantValue, x: ConstantValue) {.error: "Copying a constant value is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + + +proc initConstVal*(br: BuilderRef, val: ValueRef): ConstantValue = + ## Construct a constant value from a given LLVM value. + result = ConstantValue(br: br, val: val, typ: getTypeOf(val)) + +proc initConstVal*(asy: Assembler_LLVM, val: ValueRef): ConstantValue = + asy.br.initConstVal(val) + +proc initConstVal*[T: SomeInteger](asy: Assembler_LLVM, x: T): ConstantValue = + let t = asy.nimToLlvmType(T) + result = initConstVal(asy.br, constInt(t, x)) + +proc initConstVal*(br: BuilderRef, val: int{lit}, typ: TypeRef): ConstantValue = + ## Construct an LLVM value from an integer literal of the targe type `typ`. + result = br.initConstVal(constInt(typ, val)) + +proc initConstVal*(asy: Assembler_LLVM, val: int{lit}, typ: TypeRef): ConstantValue = + ## Construct an LLVM value from an integer literal of the targe type `typ`. + result = asy.br.initConstVal(constInt(typ, val)) + +template store*(asy: Assembler_LLVM, dst: ValueRef, src: ConstantValue, name: cstring = "") = + if not dst.getTypeOf.isPointerType(): + raise newException(ValueError, "The destination argument to `store` is not a pointer type.") + asy.br.store(src.val, dst) + +proc getValueRef*(v: ConstantValue): ValueRef = v.val + +proc asLlvmConstInt[T: SomeInteger | ValueRef](x: T, dtype: TypeRef): ValueRef = + ## Given either a value that is already an LLVM value ref or + ## a Nim value, return a `constInt` + when T is ValueRef: + ## XXX: check type is int + result = x + else: + result = constInt(dtype, x) + +type + ## A value constructed using `constX` + MutableValue* = object + br: BuilderRef + buf: ValueRef + typ: TypeRef ## type of the *underlying* type, not the pointer + +proc `=copy`(m: var MutableValue, x: MutableValue) {.error: "Copying a mutable value is not allowed. " & + "You likely want to copy the LLVM value. Use `dst.store(src)` instead.".} + +proc initMutVal*(br: BuilderRef, x: ValueRef): MutableValue = + ## Initializes a mutable value from a given LLVM value. Raises if the given + ## value is of pointer type. + if x.getTypeOf().isPointerType(): + raise newException(ValueError, "Initializing a mutable value from a pointer type is not supported.") + let typ = x.getTypeOf() + result = MutableValue( + br: br, + buf: br.alloca(typ), + typ: typ + ) + br.store(x, result.buf) # LLVM store is (source, dest) + +proc initMutVal*(br: BuilderRef, x: ConstantValue): MutableValue = + br.initMutVal(x.val) + +proc initMutVal*(asy: Assembler_LLVM, x: ConstantValue): MutableValue = + asy.br.initMutVal(x) + +proc initMutVal*(br: BuilderRef, typ: TypeRef): MutableValue = + if typ.getTypeKind != tkInteger: + raise newException(ValueError, "Initializing a mutable value from a non integer type without value is not supported. " & + "Type is: " & $typ) + br.initMutVal(constInt(typ, 0)) + +proc initMutVal*(asy: Assembler_LLVM, typ: TypeRef): MutableValue = + asy.br.initMutVal(typ) + +proc initMutVal*[T](br: BuilderRef): MutableValue = + br.initMutVal(default(T)) # initialize with default value for correct type info + +proc initMutVal*[T](asy: Assembler_LLVM): MutableValue = + asy.br.initMutVal[:T]() + +proc load*(m: MutableValue): ConstantValue = + result = m.br.initConstVal(m.br.load2(m.typ, m.buf)) + +proc store*(m: MutableValue, val: ValueRef) = + if val.getTypeOf.isPointerType(): + raise newException(ValueError, "The source argument to `store` is a pointer type. " & + "You must `load2()` it before the store. Or use the `MutableValue` type, in which case " & + "we can load it automatically for you. If you really wish to store the pointer " & + "to the destination, use `storePtr` instead.") + m.br.store(val, m.buf) # LLVM store uses (target, source) + +proc store*(asy: Assembler_LLVM, dst: ValueRef, m: MutableValue) = + asy.store(dst, m.load().val) # delegate to regular template defined further above + +proc store*(asy: Assembler_LLVM, dst: MutableValue, x: ValueRef) = + asy.store(dst.buf, x) # delegate to regular template defined further above + +proc storePtr*(m: MutableValue, val: ValueRef) = + if not val.getTypeOf.isPointerType(): + raise newException(ValueError, "The source argument to `store` is not a pointer type. " & + "You likely want to call `store` instead.") + m.br.store(val, m.buf) # LLVM store uses (target, source) + +proc store*(m: MutableValue, val: ConstantValue) = + m.store(val.val) + +proc getValueRef*(m: MutableValue): ValueRef = m.load().val + From 5d7f03dd94a6c6479413245b1ca48aca2b70be2a Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:39:50 +0100 Subject: [PATCH 14/22] add `llvmFor` macro that produces code for a for loop in LLVM --- constantine/math_compiler/ir.nim | 86 ++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 1c0afebb..94f361de 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -851,3 +851,89 @@ proc store*(m: MutableValue, val: ConstantValue) = proc getValueRef*(m: MutableValue): ValueRef = m.load().val +## Convenience templates that make writing code more succinct + +import std / macros +template llvmForImpl(asy, iter, suffix: untyped, start, stop, isCountup: typed, body: untyped): untyped = + ## `asy: Assembler_LLVM`, `fn` need to be in scope! + ## Start and stop need to be Nim values (CT or RT) + block: + let loopEntry = asy.ctx.appendBasicBlock(fn, "loop.entry" & suffix) + let loopBody = asy.ctx.appendBasicBlock(fn, "loop.body" & suffix) + let loopExit = asy.ctx.appendBasicBlock(fn, "loop.exit" & suffix) + + # Branch to loop entry + asy.br.br(loopEntry) + + # Position at loop entry + asy.br.positionAtEnd(loopEntry) + + # stopping value & increment / decrement per iteration + let cStart = asLlvmConstInt(start, asy.ctx.int32_t()) + let cStop = asLlvmConstInt(stop, asy.ctx.int32_t()) + let change = if isCountup: 1 else: -1 + let cChange = constInt(asy.ctx.int32_t(), change) + + # Loop entry condition + let cmp = if isCountup: kSLE else: kSGE + let condition = asy.br.icmp(cmp, cStart, cStop) + asy.br.condBr(condition, loopBody, loopExit) + + # Loop body + asy.br.positionAtEnd(loopBody) + let phi = asy.br.phi(getTypeOf cStart) + phi.addIncoming(cStart, loopEntry) + + # Inject the phi node as the iterator + let iter {.inject.} = phi + # The loop body + body + + # Increment / decrement for next iteration + let nextIter = asy.br.add(phi, cChange) # will subtract for countdown + + ## After the loop body the builder may not be in the `loopBody` anymore. + ## Consider: + ## + ## llvmFor i, 0, 10, true: # Outer loop + ## # Block: outer.body + ## llvmFor j, 0, 5, true: # Inner loop + ## # Block: inner.body + ## # ... instructions ... + ## # After inner loop - which block are we in? + ## + ## # Need to add PHI incoming edge for outer loop + ## phi.addIncoming(nextIter, ????) # <-- `getInsertBlock` yields the after block of the inner loop + ## # `loopBody` would be incorrect as a result of the inner loop. + phi.addIncoming(nextIter, asy.br.getInsertBlock()) + + # Check if we should continue looping + let continueLoop = asy.br.icmp(cmp, nextIter, cStop) + asy.br.condBr(continueLoop, loopBody, loopExit) + + # Loop exit + asy.br.positionAtEnd(loopExit) + +macro llvmFor*(asy: untyped, iter: untyped, start, stop, isCountup: typed, body: untyped): untyped = + let label = $genSym(nskLabel, "loop") + result = quote do: + llvmForImpl(`asy`, `iter`, `label`, `start`, `stop`, `isCountup`, `body`) + +template llvmFor*(asy: untyped, iter: untyped, start, stop: typed, body: untyped): untyped {.dirty.} = + ## Start and stop must be Nim values + block: + let isCountup = start < stop + asy.llvmFor iter, start, stop, isCountup: + body + +template llvmForCountup*(asy: untyped, iter: untyped, start, stop: typed, body: untyped): untyped {.dirty.} = + ## Start and stop can either be Nim or LLVM values + block: + asy.llvmFor iter, start, stop, true: + body + +template llvmForCountdown*(asy: untyped, iter: untyped, start, stop: typed, body: untyped): untyped {.dirty.} = + ## Start and stop can either be Nim or LLVM values + block: + asy.llvmFor iter, start, stop, false: + body From 9a0f8eb54456af5e3406ad9e63525b900ada5839 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:47:54 +0100 Subject: [PATCH 15/22] add helpers for arithmetic, boolean logic for ValueRef, M/CValue --- constantine/math_compiler/ir.nim | 82 ++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 94f361de..05342a3f 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -937,3 +937,85 @@ template llvmForCountdown*(asy: untyped, iter: untyped, start, stop: typed, body block: asy.llvmFor iter, start, stop, false: body + +## Convenience utilities for `ValueRef` (representing numbers) for LLVM +template declNumberOps*(asy: Assembler_LLVM, fd: FieldDescriptor): untyped = + ## Declares templates similar to the field and EC ops templates + ## for `ValueRef`, `MutableValue` and `ConstantValue` so that one + ## can effectively write regular arithmetic / boolean logic code + ## with LLVM values to produce the correct code. + template genLhsRhsVariants(name, fn: untyped): untyped = + ## Generates variants for mix of Nim integer + ValueRef and + ## pure ValueRef + type T = int | uint32 | uint64 + type U = ValueRef | MutableValue | ConstantValue + type X = MutableValue | ConstantValue + + let I = fd.wordTy + + template name(lhs, rhs: ValueRef): untyped = + if $getTypeOf(lhs) != $getTypeOf(rhs): + raise newException(ValueError, "Inputs do not have matching types. LHS = " & $getTypeOf(lhs) & ", RHS = " & $getTypeOf(rhs)) + elif getTypeOf(lhs).isPointerType(): + raise newException(ValueError, "Inputs must not be pointer types.") + asy.br.fn(lhs, rhs) + template name(lhs: SomeInteger, rhs: U): untyped = + block: + let lhsV = constInt(I, lhs) + asy.br.fn(lhsV, getValueRef rhs) + template name(lhs: U, rhs: SomeInteger): untyped = + block: + let rhsV = constInt(I, rhs) + asy.br.fn(getValueRef lhs, rhsV) + template name[T: X; U: X](lhs: T; rhs: U): untyped = + asy.br.fn(getValueRef lhs, getValueRef rhs) + + template genLhsRhsBooleanVariants(name, pred: untyped): untyped = + ## Generates variants for mix of Nim integer + ValueRef and + ## pure ValueRef for boolean operations + type T = int | uint32 | uint64 + type U = ValueRef | MutableValue | ConstantValue + type X = MutableValue | ConstantValue + + let I = fd.wordTy + + template name(lhs, rhs: ValueRef): untyped = + if $getTypeOf(lhs) != $getTypeOf(rhs): + raise newException(ValueError, "Inputs do not have matching types. LHS = " & $getTypeOf(lhs) & ", RHS = " & $getTypeOf(rhs)) + elif getTypeOf(lhs).isPointerType(): + raise newException(ValueError, "Inputs must not be pointer types.") + asy.br.icmp(pred, lhs, rhs) + template name(lhs: T; rhs: U): untyped = + block: + let lhsV = constInt(I, lhs) + name(lhsV, getValueRef rhs) + template name(lhs: U; rhs: T): untyped = + block: + let rhsV = constInt(I, rhs) + name(getValueRef lhs, rhsV) + template name[T: X; U: X](lhs: T; rhs: U): untyped = + name(getValueRef lhs, getValueRef rhs) + + # standard binary operations + genLhsRhsVariants(`shl`, lshl) + genLhsRhsVariants(`shr`, lshr) + genLhsRhsVariants(`and`, `and`) + genLhsRhsVariants(`or`, `or`) + genLhsRhsVariants(`+`, add) + genLhsRhsVariants(`-`, sub) + genLhsRhsVariants(`*`, mul) + + # boolean based on `icmp` + genLhsRhsBooleanVariants(`<`, kSLT) + genLhsRhsBooleanVariants(`<=`, kSLE) + genLhsRhsBooleanVariants(`==`, kEQ) + ## XXX: The following cause overload resolution errors for + ## bog standard types, i.e. `>` of `uint32` or `!=` for `string`. + ## I think this is because `!=`, `>` and `>=` are implemented as + ## untyped templates in system.nim. + ## Slightly problematic, we need to add `not` for LLVM to achieve + ## the correct behavior. + #genLhsRhsBooleanVariants(`>`, kSGT) + #genLhsRhsBooleanVariants(`>=`, kSGE) + #genLhsRhsBooleanVariants(`!=`, kNE) + From c67548cd2bd1c134ec9024ff21580a3c3ccb4c65 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:48:37 +0100 Subject: [PATCH 16/22] add `llvmIf` to generate code for if statements It _wraps around_ a full if statement. --- constantine/math_compiler/ir.nim | 148 +++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 05342a3f..47fc506c 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -1019,3 +1019,151 @@ template declNumberOps*(asy: Assembler_LLVM, fd: FieldDescriptor): untyped = #genLhsRhsBooleanVariants(`>=`, kSGE) #genLhsRhsBooleanVariants(`!=`, kNE) +proc collectElifBranches(n: NimNode): tuple[elifs: seq[NimNode], els: NimNode] = + ## The `else` branch is an optional second argument, `els` + doAssert n.kind == nnkIfStmt + result.els = newEmptyNode() # set to empty as default + for el in n: + case el.kind + of nnkElifBranch: result.elifs.add el + of nnkElse: result.els = el + else: raiseAssert "Invalid branch: " & $el.kind + +macro llvmIf*(asy, body: untyped): untyped = + ## Rewrites the given body, which *must* contain an if statement + ## with (possibly) multiple branches) into conditional branches + ## on LLVM. We jump from the current block of `asy` into the + ## conditional branches and provide a block at the end, which we + ## will reach from every if branch. + ## + ## NOTE: This can only be used inside of an `llvmInternalFnDef` template, + ## because it needs access to the current function, `fn` identifier. + ## + ## BE CAREFUL: For the moment this macro does not handle using values + ## assigned in its body after the if statements. This would require + ## creating a φ-node for the value, which we currently do not do. + ## Mainly, because this requires a more complicated traversal of the + ## macro body to detect such a requirement. Instead we might add + ## an alternative `llvmIfUse` or similar in the future, where exactly + ## one assignment is allowed. + ## + ## IfStmt + ## ElifBranch + ## Ident "true" + ## StmtList + ## Command + ## Ident "echo" + ## StrLit "x" + ## Else + ## StmtList + ## Command + ## Ident "echo" + ## StrLit "y" + ## + doAssert body.kind in {nnkIfStmt, nnkStmtList}, "Input *must* be an if statement, but is: " & $body.kind + var body = body + if body.kind == nnkStmtList: + doAssert body.len == 1 and body[0].kind == nnkIfStmt, "If a nnkStmtList, must only contain an nnkIfStmt, but: " & $body.treerepr + body = body[0] + + # 1. collect all elif branches (and possible else) + let (elifs, els) = collectElifBranches(body) + let hasElse = els.kind != nnkEmpty + + # For each `elif` (including the first `if`) we need 2 blocks: + # - elif condition + # - if true body + # If `els` is set, need an additional: + # - else body + # Finally, need an + # - after if/else body + result = newStmtList() + + # 2. generate all required blocks + var elifBranches = newSeq[tuple[cond, body: NimNode]]() # contains the *identifiers* for the blocks + for i, el in elifs: + # 2.1 create the identifiers + let condId = genSym(nskLet, "elifCond") + let bodyId = genSym(nskLet, "elifBody") + # 2.2 generate let stmts to append the blocks to LLVM context + let idx = $i + result.add quote do: + let `condId` = `asy`.ctx.appendBasicBlock(fn, "elif.condition." & `idx`) + let `bodyId` = `asy`.ctx.appendBasicBlock(fn, "elif.body." & `idx`) + # 2.3 store + elifBranches.add (cond: condId, body: bodyId) + # 2.4 create `else` block if needed + var elseId: NimNode + if hasElse: + elseId = genSym(nskLet, "elseBody") + result.add quote do: + let `elseId` = `asy`.ctx.appendBasicBlock(fn, "else.body") + # 2.5 create 'after if/else' block + let afterId = genSym(nskLet, "afterBody") + result.add quote do: + let `afterId` = `asy`.ctx.appendBasicBlock(fn, "after.body") + + # 3. jump to the first if condition + let firstCond = elifBranches[0].cond + result.add quote do: + `asy`.br.br(`firstCond`) + + # 4. fill all the blocks + for i, el in elifs: + # 4.1 take condition of `elif` + # ElifBranch + # Ident "true" <- `el[0]` + # StmtList <- `el[1]` + # Command + # Ident "echo" + # StrLit "x" + let cond = el[0] + # 4.2 set our builder to this block + let condId = elifBranches[i].cond + let condVal = genSym(nskLet, "condVal") + result.add quote do: + `asy`.br.positionAtEnd(`condId`) + # 4.3 fill the block with the condition + result.add quote do: + let `condVal` = `cond` + + #result.add nnkBlockStmt.newTree(ident("Block_" & $condId), cond) + # 4.4 determine the `else` block to jump to (else, after or next elif) + let ifFalseNext = + if i < elifBranches.high: # has another elif + elifBranches[i+1].cond + elif hasElse: # last, jump to existing else + elseId + else: # neither, jump after if/else + afterId + + # 4.5 conditionally branch based on condition the block with a conditional branch + let bodyId = elifBranches[i].body + result.add quote do: + `asy`.br.condBr(`condVal`, `bodyId`, `ifFalseNext`) + + # 4.6 set builder to if body + let blkBody = el[1] + result.add quote do: + `asy`.br.positionAtEnd(`bodyId`) + # 4.7 fill the block with the body + result.add nnkBlockStmt.newTree(ident("Block_" & $bodyId), blkBody) + # 4.8 branch to after if + result.add quote do: + `asy`.br.br(`afterId`) + + # 5. handle the `else` branch if any + if hasElse: + # 5.1 position at else block + result.add quote do: + `asy`.br.positionAtEnd(`elseId`) + # 5.2 add the else body to this block + result.add els[0] + # 5.3 jump to after block + result.add quote do: + `asy`.br.br(`afterId`) + + # 6. position builder at after block + result.add quote do: + `asy`.br.positionAtEnd(`afterId`) + From 0b282325a069cb0bf77359a75c60636c76220251 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:49:30 +0100 Subject: [PATCH 17/22] add `to` type conversion helper which extends/truncates int types --- constantine/math_compiler/ir.nim | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 47fc506c..5f4db183 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -1167,3 +1167,26 @@ macro llvmIf*(asy, body: untyped): untyped = result.add quote do: `asy`.br.positionAtEnd(`afterId`) +proc to*(asy: Assembler_LLVM, x: ValueRef, dtype: TypeRef, signed = false): ValueRef = + ## Converts the given integer type of `x` to the target type `T`. + ## The numbers are treated as signed integers if `signed` is true, else + ## as unsigned. + let outsize = getIntTypeWidth(dtype) + let tk = x.getTypeOf().getTypeKind() + if tk != tkInteger: + raise newException(ValueError, "The argument is not an integer type, but: " & $getTypeOf(x)) + let inSize = getTypeOf(x).getIntTypeWidth() + if inSize == outsize: + result = x + elif inSize < outsize: + # extend, + if signed: + result = asy.br.sext(x, dtype, "to.i" & $outSize) + else: + result = asy.br.zext(x, dtype, "to.u" & $outSize) + else: # trunacte + result = asy.br.trunc(x, dtype, "trunc.to.i" & $outSize) + +proc to*[T](asy: Assembler_LLVM, x: ValueRef, dtype: typedesc[T], signed = false): ValueRef = + let outTyp = asy.nimToLlvmType(T) + result = asy.to(x, outTyp, signed) From 085b233ebb5a1923ed7c8b90b0ba732d5afa578d Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:49:50 +0100 Subject: [PATCH 18/22] use `llvmForCountdown` in `genFpNsqrRt` instead of fixed countdown logic --- constantine/math_compiler/pub_fields.nim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/constantine/math_compiler/pub_fields.nim b/constantine/math_compiler/pub_fields.nim index ce24a4d3..16019901 100644 --- a/constantine/math_compiler/pub_fields.nim +++ b/constantine/math_compiler/pub_fields.nim @@ -330,7 +330,7 @@ proc genFpNsqrRT*(asy: Assembler_LLVM, fd: FieldDescriptor): string = for i in 0 ..< fd.numWords: rA[i] = aA[i] - asy.genCountdownLoop(fn, count): + asy.llvmForCountdown i, count, 0: # countdown from count to 0 # use `mtymul` to multiply `r·r` and store it again in `r` asy.mtymul(fd, r, r, r, M) From bdf667dd19e6b6f034a3c1e5a5da8f18b00255fa Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:50:18 +0100 Subject: [PATCH 19/22] add `getWindowAt` helper required for baseline MSM implementation --- constantine/math_compiler/impl_fields_ops.nim | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/constantine/math_compiler/impl_fields_ops.nim b/constantine/math_compiler/impl_fields_ops.nim index 2564ea0f..a4eeb55c 100644 --- a/constantine/math_compiler/impl_fields_ops.nim +++ b/constantine/math_compiler/impl_fields_ops.nim @@ -7,6 +7,7 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import + constantine/platforms/bithacks, # for log2_vartime constantine/platforms/llvm/[llvm, asm_nvidia], ./ir, ./impl_fields_globals, @@ -725,3 +726,46 @@ proc scalarMul*(asy: Assembler_LLVM, fd: FieldDescriptor, a: ValueRef, b: int) = asy.br.retVoid() asy.callFn(name, [a]) + +proc getWindowAt*(asy: Assembler_LLVM, cd: CurveDescriptor, r, c, bI, wI: ValueRef) {.used.} = + ## Generate an internal field `getWindowAt` function + ## with signature + ## void name(BaseType r, FieldType c, int bitIndex, int windowSize) + let name = cd.fd.name & "_getWindowAt" + asy.llvmInternalFnDef( + name, SectionName, + asy.void_t, toTypes([r, c, bI, wI]), + {kHot}): + tagParameter(1, "sret") + + # Operations for numbers as `ValueRef` + declNumberOps(asy, cd.fd) + + let (ri, ci, bitIndex, windowSize) = llvmParams + let rA = asy.asFieldScalar(cd, ri) + let cA = asy.asFieldScalar(cd, ci) + let fd = cd.fd + + # Nim values + let SlotShift = log2_vartime(fd.w.uint32) + let WordMask = fd.w - 1 + let WindowMask = (1 shl windowSize) - 1 # LLVM + + # LLVM values + let slot = bitIndex shr SlotShift + let word = cA[slot] # word in limbs + let pos = bitIndex and WordMask # position in the word + + # This is constant-time, the branch does not depend on secret data. + llvmIf(asy): # transforms an `if` statement body into llvm conditional branches + if pos + windowSize > fd.w and slot+1 < fd.numWords: + # Read next word as well + let x = ((word shr pos) or (cA[slot+1] shl (fd.w - pos))) and WindowMask + asy.store(ri, x) + else: + let x = (word shr pos) and WindowMask + asy.store(ri, x) + + asy.br.retVoid() + + asy.callFn(name, [r, c, bI, wI]) From 44ce9df944efb0f425420f8bd3f253f2269bdca1 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:57:20 +0100 Subject: [PATCH 20/22] add serial MSM implementation for Nvidia using bucket method This implementation is a bit of a proof of concept and playground to investigate how easily we can generate code on the LLVM target with the help of Nim macros. --- constantine/math_compiler/impl_msm_nvidia.nim | 96 +++++++++++++++++++ .../math_compiler/pub_curves_jacobian.nim | 24 ++++- 2 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 constantine/math_compiler/impl_msm_nvidia.nim diff --git a/constantine/math_compiler/impl_msm_nvidia.nim b/constantine/math_compiler/impl_msm_nvidia.nim new file mode 100644 index 00000000..19f28a2e --- /dev/null +++ b/constantine/math_compiler/impl_msm_nvidia.nim @@ -0,0 +1,96 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + constantine/platforms/llvm/[llvm, asm_nvidia], + constantine/platforms/[primitives], + ./ir, + ./impl_fields_globals, + ./impl_fields_dispatch, + ./impl_fields_ops, + ./impl_curves_ops_affine, + ./impl_curves_ops_jacobian, + std / typetraits # for distinctBase + +## Section name used for `llvmInternalFnDef` +const SectionName = "ctt.msm_nvidia" + +proc msm*(asy: Assembler_LLVM, cd: CurveDescriptor, r, coefs, points: ValueRef, + c, N: int) {.used.} = + ## Inner implementation of MSM, for static dispatch over c, the bucket bit length + ## This is a straightforward simple translation of BDLO12, section 4 + ## + ## Entirely serial implementation! + ## + ## Important note: The coefficients given to this procedure must be in canonical + ## representation instead of Montgomery representation! Thus, you cannot pass + ## values of type `Fr[Curve]` directly, as they are internally stored in Montgomery + ## rep. Convert to a `BigInt` using `fromField`. + let name = cd.name & "_msm_impl" + asy.llvmInternalFnDef( + name, SectionName, + asy.void_t, toTypes([r, coefs, points]), + {kHot}): + tagParameter(1, "sret") + + # Inject templates for convenient access + declFieldOps(asy, cd.fd) + declEllipticJacOps(asy, cd) + declEllipticAffOps(asy, cd) + declNumberOps(asy, cd.fd) + + let (ri, coefsIn, pointsIn) = llvmParams + let rA = asy.asEcPointJac(cd, ri) + let cs = asy.asFieldScalarArray(cd, coefsIn, N) # coefficients + let Ps = asy.asEcAffArray(cd, pointsIn, N) # EC points + # Prologue + # -------- + let numBuckets = 1 shl c - 1 # bucket 0 is unused + let numWindows = cd.orderBitWidth.int.ceilDiv_vartime(c) + + let miniMSMs = asy.initEcJacArray(cd, numWindows) + let buckets = asy.initEcJacArray(cd, numBuckets) + + # Algorithm + # --------- + var cNonZero = asy.initMutVal(cd.fd.wordTy) + asy.llvmFor w, 0, numWindows - 1, true: + # Place our points in a bucket corresponding to + # how many times their bit pattern in the current window of size c + asy.llvmFor i, 0, numBuckets - 1, true: + buckets[i].setNeutral() + + # 1. Bucket accumulation. Cost: n - (2ᶜ-1) => n points in 2ᶜ-1 buckets, first point per bucket is just copied + asy.llvmFor j, 0, N-1, true: + var b = asy.initMutVal(cd.fd.wordTy) + let w0 = asy.initConstVal(0, cd.fd.wordTy) + asy.getWindowAt(cd, b.buf, cs[j].buf, asy.to(w, cd.fd.wordTy) * c, constInt(cd.fd.wordTy, c)) + llvmIf(asy): + if b != w0: + buckets[b-1] += Ps[j] + + var accumBuckets = asy.newEcPointJac(cd) + var miniMSM = asy.newEcPointJac(cd) + accumBuckets = buckets[numBuckets-1] + miniMSM.store(buckets[numBuckets-1]) + + asy.llvmFor k, numBuckets-2, 0, false: + accumBuckets += buckets[k] # Stores S₈ then S₈+S₇ then S₈+S₇+S₆ then ... + miniMSM += accumBuckets # Stores S₈ then [2]S₈+S₇ then [3]S₈+[2]S₇+S₆ then ... + + miniMSMs[w].store(miniMSM) + + rA.store(miniMSMs[numWindows-1]) + asy.llvmFor w, numWindows-2, 0, false: + asy.llvmFor j, 0, c-1: + rA.double() + rA += miniMSMs[w] + + asy.br.retVoid() + + asy.callFn(name, [r, coefs, points]) diff --git a/constantine/math_compiler/pub_curves_jacobian.nim b/constantine/math_compiler/pub_curves_jacobian.nim index e9040ecf..0f8d8908 100644 --- a/constantine/math_compiler/pub_curves_jacobian.nim +++ b/constantine/math_compiler/pub_curves_jacobian.nim @@ -15,7 +15,9 @@ import ./impl_fields_dispatch, ./impl_fields_ops, ./impl_curves_ops_affine, - ./impl_curves_ops_jacobian + ./impl_curves_ops_jacobian, + ./impl_msm_nvidia + ## Section name used for `llvmInternalFnDef` const SectionName = "ctt.pub_curves_jacobian" @@ -158,3 +160,23 @@ proc genEcMixedSum*(asy: Assembler_LLVM, cd: CurveDescriptor): string = asy.mixedSum(cd, ri, pi, qi) asy.br.retVoid() result = name + +proc genEcMSM*(asy: Assembler_LLVM, cd: CurveDescriptor, c, N: int): string = + ## Generate a publc elliptic curve MSM proc for EC points in affine + ## coordinates and coefficients in canonical representation. + ## Uses the bucket method and is currently fully serial. So don't + ## expect any speedup from the CPU implementation. + ## + ## `c` is the window size and `N` the number of points. The code + ## requires these to be defined at compile time. + ## + ## Returns the name of the produced kernel to call it. + let fT = array_t(cd.fieldScalarTy, N) + let cT = array_t(cd.curveTyAff, N) + let name = cd.name & "_msm_public_c_" & $c & "_N_" & $N + + asy.llvmPublicFnDef(name, "ctt." & cd.name, asy.void_t, [cd.curveTy, fT, cT]): + let (ri, cs, ps) = llvmParams + asy.msm(cd, ri, cs, ps, c, N) + asy.br.retVoid() + result = name From 6eb0c607aebb9ffc83c5a85d02c0b2d6a19c2d18 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 17:58:13 +0100 Subject: [PATCH 21/22] [tests] add mini test case for MSM on Nvidia --- tests/gpu/t_msm.nim | 68 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 tests/gpu/t_msm.nim diff --git a/tests/gpu/t_msm.nim b/tests/gpu/t_msm.nim new file mode 100644 index 00000000..877c9227 --- /dev/null +++ b/tests/gpu/t_msm.nim @@ -0,0 +1,68 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + # Internal + constantine/named/algebras, + constantine/math/io/[io_bigints, io_fields, io_ec], + constantine/math/arithmetic, + constantine/math/elliptic/[ec_shortweierstrass_affine, ec_shortweierstrass_jacobian, ec_multi_scalar_mul], + constantine/platforms/abstractions, + constantine/platforms/llvm/llvm, + constantine/math_compiler/[ir, pub_fields, pub_curves_jacobian, codegen_nvidia, impl_fields_globals], + # Test utilities + helpers/prng_unsafe + +type + EC = EC_ShortW_Jac[Fp[BN254_Snarks], G1] + ECAff = EC_ShortW_Aff[Fp[BN254_Snarks], G1] +const wordSize = 32 + +# 2 EC points +let x = "0x2ef34a5db00ff691849861d49415d8081d9d0e10cba33b57b2dd1f37f13eeee0" +let y = "0x2beb0d0d6115007676f30bcc462fe814bf81198848f139621a3e9fa454fe8e6a" +let pt = ECAff.fromHex(x, y) +echo pt.toHex() + +let x2 = "0x226c85cf65f4596a77da7d247310a81ac9aa9220e819e3ef23b6cbe0218ce272" +let y2 = "0xf53265870f65aa18bded3ccb9c62a4d8b060a32a05a75d455710bce95a991df" +let pt2 = ECAff.fromHex(x2, y2) + +# 2 coefficients +let a = Fr[BN254_Snarks].fromUInt(1'u32) +let b = Fr[BN254_Snarks].fromHex("0x2beb0d0d6115007676f30bcc462fe814bf81198848f139621a3e9fa454fe8e6a") + +proc fromField[BigInt](x: FF): BigInt = + result.fromField(x) + +type CB = Fr[BN254_Snarks].getBigInt() + +template toPOA(x): untyped = cast[ptr UncheckedArray[x[0].typeof]](x[0].addr) + +let bN = fromField[CB](b) # convert to BigInt to go from Montgomery rep to canonical rep +let coefs = [bN, bN] + +let points = [pt, pt2] + +block MSM: + # Codegen + # ------------------------- + let nv = initNvAsm(EC, wordSize) + let kernel = nv.compile(nv.asy.genEcMSM(nv.cd, 3, coefs.len)) + + # For CPU: + var rCPU: EC + rCPU.multiScalarMul_reference_vartime(@coefs, @points) + + # For GPU: + var rGPU: EC + kernel.execCuda(res = rGPU, inputs = (coefs, points)) + echo "CPU: ", rCPU.toHex() + echo "GPU: ", rGPU.toHex() + # Verify CPU and GPU agree + doAssert bool(rCPU == rGPU) From 83e603a83810ac2715b69e87b2038beff0f927f8 Mon Sep 17 00:00:00 2001 From: Vindaar Date: Tue, 5 Nov 2024 18:08:46 +0100 Subject: [PATCH 22/22] whoops, revert local change to test CT error on `=copy` --- constantine/math_compiler/impl_msm_nvidia.nim | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/constantine/math_compiler/impl_msm_nvidia.nim b/constantine/math_compiler/impl_msm_nvidia.nim index 19f28a2e..c2c3ab13 100644 --- a/constantine/math_compiler/impl_msm_nvidia.nim +++ b/constantine/math_compiler/impl_msm_nvidia.nim @@ -76,7 +76,7 @@ proc msm*(asy: Assembler_LLVM, cd: CurveDescriptor, r, coefs, points: ValueRef, var accumBuckets = asy.newEcPointJac(cd) var miniMSM = asy.newEcPointJac(cd) - accumBuckets = buckets[numBuckets-1] + accumBuckets.store(buckets[numBuckets-1]) miniMSM.store(buckets[numBuckets-1]) asy.llvmFor k, numBuckets-2, 0, false: