From 69354ae29c48dde3dc666fdcab3e6e433a0e6c2c Mon Sep 17 00:00:00 2001 From: Oleksandr Brezhniev Date: Tue, 28 Dec 2021 20:47:12 +0200 Subject: [PATCH] Faster ff arithmetics (regenerated code with the newest goff) (#43) --- babyjub/babyjub.go | 14 +- ff/arith.go | 66 +- ff/asm.go | 24 + ff/asm_noadx.go | 25 + ff/doc.go | 43 + ff/element.go | 1309 +++++++++++++++---------- ff/element_fuzz.go | 136 +++ ff/element_mul_adx_amd64.s | 466 +++++++++ ff/element_mul_amd64.s | 488 +++++++++ ff/element_ops_amd64.go | 50 + ff/element_ops_amd64.s | 340 +++++++ ff/element_ops_noasm.go | 78 ++ ff/element_test.go | 1898 +++++++++++++++++++++++++++++++++--- ff/util.go | 6 - go.mod | 5 + go.sum | 4 + poseidon/poseidon.go | 2 +- 17 files changed, 4215 insertions(+), 739 deletions(-) create mode 100644 ff/asm.go create mode 100644 ff/asm_noadx.go create mode 100644 ff/doc.go create mode 100644 ff/element_fuzz.go create mode 100644 ff/element_mul_adx_amd64.s create mode 100644 ff/element_mul_amd64.s create mode 100644 ff/element_ops_amd64.go create mode 100644 ff/element_ops_amd64.s create mode 100644 ff/element_ops_noasm.go delete mode 100644 ff/util.go diff --git a/babyjub/babyjub.go b/babyjub/babyjub.go index 4e42f2a..317dc86 100644 --- a/babyjub/babyjub.go +++ b/babyjub/babyjub.go @@ -95,20 +95,20 @@ func (p *PointProjective) Add(q *PointProjective, o *PointProjective) *PointProj c := ff.NewElement().Mul(q.X, o.X) d := ff.NewElement().Mul(q.Y, o.Y) e := ff.NewElement().Mul(Dff, c) - e.MulAssign(d) + e.Mul(e, d) f := ff.NewElement().Sub(b, e) g := ff.NewElement().Add(b, e) x1y1 := ff.NewElement().Add(q.X, q.Y) x2y2 := ff.NewElement().Add(o.X, o.Y) x3 := ff.NewElement().Mul(x1y1, x2y2) - x3.SubAssign(c) - x3.SubAssign(d) - x3.MulAssign(a) - x3.MulAssign(f) + x3.Sub(x3, c) + x3.Sub(x3, d) + x3.Mul(x3, a) + x3.Mul(x3, f) ac := ff.NewElement().Mul(Aff, c) y3 := ff.NewElement().Sub(d, ac) - y3.MulAssign(a) - y3.MulAssign(g) + y3.Mul(y3, a) + y3.Mul(y3, g) z3 := ff.NewElement().Mul(f, g) p.X = x3 diff --git a/ff/arith.go b/ff/arith.go index 938c87a..790067f 100644 --- a/ff/arith.go +++ b/ff/arith.go @@ -1,4 +1,4 @@ -// Copyright 2020 ConsenSys AG +// Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Code generated by goff DO NOT EDIT +// Code generated by consensys/gnark-crypto DO NOT EDIT package ff @@ -20,15 +20,6 @@ import ( "math/bits" ) -func madd(a, b, t, u, v uint64) (uint64, uint64, uint64) { - var carry uint64 - hi, lo := bits.Mul64(a, b) - v, carry = bits.Add64(lo, v, 0) - u, carry = bits.Add64(hi, u, carry) - t, _ = bits.Add64(t, 0, carry) - return t, u, v -} - // madd0 hi = a*b + c (discards lo bits) func madd0(a, b, c uint64) (hi uint64) { var carry, lo uint64 @@ -58,59 +49,6 @@ func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { return } -// madd2s superhi, hi, lo = 2*a*b + c + d + e -func madd2s(a, b, c, d, e uint64) (superhi, hi, lo uint64) { - var carry, sum uint64 - - hi, lo = bits.Mul64(a, b) - lo, carry = bits.Add64(lo, lo, 0) - hi, superhi = bits.Add64(hi, hi, carry) - - sum, carry = bits.Add64(c, e, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, sum, 0) - hi, _ = bits.Add64(hi, 0, carry) - hi, _ = bits.Add64(hi, 0, d) - return -} - -func madd1s(a, b, d, e uint64) (superhi, hi, lo uint64) { - var carry uint64 - - hi, lo = bits.Mul64(a, b) - lo, carry = bits.Add64(lo, lo, 0) - hi, superhi = bits.Add64(hi, hi, carry) - lo, carry = bits.Add64(lo, e, 0) - hi, _ = bits.Add64(hi, 0, carry) - hi, _ = bits.Add64(hi, 0, d) - return -} - -func madd2sb(a, b, c, e uint64) (superhi, hi, lo uint64) { - var carry, sum uint64 - - hi, lo = bits.Mul64(a, b) - lo, carry = bits.Add64(lo, lo, 0) - hi, superhi = bits.Add64(hi, hi, carry) - - sum, carry = bits.Add64(c, e, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, sum, 0) - hi, _ = bits.Add64(hi, 0, carry) - return -} - -func madd1sb(a, b, e uint64) (superhi, hi, lo uint64) { - var carry uint64 - - hi, lo = bits.Mul64(a, b) - lo, carry = bits.Add64(lo, lo, 0) - hi, superhi = bits.Add64(hi, hi, carry) - lo, carry = bits.Add64(lo, e, 0) - hi, _ = bits.Add64(hi, 0, carry) - return -} - func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { var carry uint64 hi, lo = bits.Mul64(a, b) diff --git a/ff/asm.go b/ff/asm.go new file mode 100644 index 0000000..2718ff3 --- /dev/null +++ b/ff/asm.go @@ -0,0 +1,24 @@ +//go:build !noadx +// +build !noadx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ff + +import "golang.org/x/sys/cpu" + +var supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 diff --git a/ff/asm_noadx.go b/ff/asm_noadx.go new file mode 100644 index 0000000..23c3a0b --- /dev/null +++ b/ff/asm_noadx.go @@ -0,0 +1,25 @@ +//go:build noadx +// +build noadx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ff + +// note: this is needed for test purposes, as dynamically changing supportAdx doesn't flag +// certain errors (like fatal error: missing stackmap) +// this ensures we test all asm path. +var supportAdx = false diff --git a/ff/doc.go b/ff/doc.go new file mode 100644 index 0000000..114a4eb --- /dev/null +++ b/ff/doc.go @@ -0,0 +1,43 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package ff contains field arithmetic operations for modulus = 0x30644e...000001. +// +// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@zkteam/modular_multiplication) +// +// The modulus is hardcoded in all the operations. +// +// Field elements are represented as an array, and assumed to be in Montgomery form in all methods: +// type Element [4]uint64 +// +// Example API signature +// // Mul z = x * y mod q +// func (z *Element) Mul(x, y *Element) *Element +// +// and can be used like so: +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) +// +// Modulus +// 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 // base 16 +// 21888242871839275222246405745257275088548364400416034343698204186575808495617 // base 10 +package ff diff --git a/ff/element.go b/ff/element.go index 60b4e6b..c2ff2bc 100644 --- a/ff/element.go +++ b/ff/element.go @@ -1,4 +1,4 @@ -// Copyright 2020 ConsenSys AG +// Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,44 +12,95 @@ // See the License for the specific language governing permissions and // limitations under the License. -// field modulus q = -// -// 21888242871839275222246405745257275088548364400416034343698204186575808495617 -// Code generated by goff DO NOT EDIT -// goff version: - build: -// Element are assumed to be in Montgomery form in all methods +// Code generated by consensys/gnark-crypto DO NOT EDIT -// Package ff (generated by goff) contains field arithmetics operations package ff import ( "crypto/rand" "encoding/binary" + "errors" "io" "math/big" "math/bits" + "reflect" + "strconv" "sync" - - "unsafe" ) // Element represents a field element stored on 4 words (uint64) // Element are assumed to be in Montgomery form in all methods +// field modulus q = +// +// 21888242871839275222246405745257275088548364400416034343698204186575808495617 type Element [4]uint64 -// ElementLimbs number of 64 bits words needed to represent Element -const ElementLimbs = 4 +// Limbs number of 64 bits words needed to represent Element +const Limbs = 4 + +// Bits number bits needed to represent Element +const Bits = 254 + +// Bytes number bytes needed to represent Element +const Bytes = Limbs * 8 + +// field modulus stored as big.Int +var _modulus big.Int + +// Modulus returns q as a big.Int +// q = +// +// 21888242871839275222246405745257275088548364400416034343698204186575808495617 +func Modulus() *big.Int { + return new(big.Int).Set(&_modulus) +} + +// q (modulus) +var qElement = Element{ + 4891460686036598785, + 2896914383306846353, + 13281191951274694749, + 3486998266802970665, +} + +// rSquare +var rSquare = Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, +} + +var bigIntPool = sync.Pool{ + New: func() interface{} { + return new(big.Int) + }, +} + +func init() { + _modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) +} + +// NewElement returns a new Element +func NewElement() *Element { + return &Element{} +} -// ElementBits number bits needed to represent Element -const ElementBits = 254 +// NewElementFromUint64 returns a new Element from a uint64 value +// +// it is equivalent to +// var v NewElement +// v.SetUint64(...) +func NewElementFromUint64(v uint64) Element { + z := Element{v} + z.Mul(&z, &rSquare) + return z +} // SetUint64 z = v, sets z LSB to v (non-Montgomery form) and convert z to Montgomery form func (z *Element) SetUint64(v uint64) *Element { - z[0] = v - z[1] = 0 - z[2] = 0 - z[3] = 0 - return z.ToMont() + *z = Element{v} + return z.Mul(z, &rSquare) // z.ToMont() } // Set z = x @@ -61,6 +112,33 @@ func (z *Element) Set(x *Element) *Element { return z } +// SetInterface converts provided interface into Element +// returns an error if provided type is not supported +// supported types: Element, *Element, uint64, int, string (interpreted as base10 integer), +// *big.Int, big.Int, []byte +func (z *Element) SetInterface(i1 interface{}) (*Element, error) { + switch c1 := i1.(type) { + case Element: + return z.Set(&c1), nil + case *Element: + return z.Set(c1), nil + case uint64: + return z.SetUint64(c1), nil + case int: + return z.SetString(strconv.Itoa(c1)), nil + case string: + return z.SetString(c1), nil + case *big.Int: + return z.SetBigInt(c1), nil + case big.Int: + return z.SetBigInt(&c1), nil + case []byte: + return z.SetBytes(c1), nil + default: + return nil, errors.New("can't set ff.Element from type " + reflect.TypeOf(i1).String()) + } +} + // SetZero z = 0 func (z *Element) SetZero() *Element { z[0] = 0 @@ -79,19 +157,6 @@ func (z *Element) SetOne() *Element { return z } -// Neg z = q - x -func (z *Element) Neg(x *Element) *Element { - if x.IsZero() { - return z.SetZero() - } - var borrow uint64 - z[0], borrow = bits.Sub64(4891460686036598785, x[0], 0) - z[1], borrow = bits.Sub64(2896914383306846353, x[1], borrow) - z[2], borrow = bits.Sub64(13281191951274694749, x[2], borrow) - z[3], _ = bits.Sub64(3486998266802970665, x[3], borrow) - return z -} - // Div z = x*y^-1 mod q func (z *Element) Div(x, y *Element) *Element { var yInv Element @@ -100,6 +165,16 @@ func (z *Element) Div(x, y *Element) *Element { return z } +// Bit returns the i'th bit, with lsb == bit 0. +// It is the responsability of the caller to convert from Montgomery to Regular form if needed +func (z *Element) Bit(i uint64) uint64 { + j := i / 64 + if j >= 4 { + return 0 + } + return uint64(z[j] >> (i % 64) & 1) +} + // Equal returns z == x func (z *Element) Equal(x *Element) bool { return (z[3] == x[3]) && (z[2] == x[2]) && (z[1] == x[1]) && (z[0] == x[0]) @@ -110,200 +185,70 @@ func (z *Element) IsZero() bool { return (z[3] | z[2] | z[1] | z[0]) == 0 } -// field modulus stored as big.Int -var _elementModulusBigInt big.Int -var onceelementModulus sync.Once - -func elementModulusBigInt() *big.Int { - onceelementModulus.Do(func() { - _elementModulusBigInt.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) - }) - return &_elementModulusBigInt +// IsUint64 returns true if z[0] >= 0 and all other words are 0 +func (z *Element) IsUint64() bool { + return (z[3] | z[2] | z[1]) == 0 } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x -func (z *Element) Inverse(x *Element) *Element { - if x.IsZero() { - return z.Set(x) +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +// +func (z *Element) Cmp(x *Element) int { + _z := *z + _x := *x + _z.FromMont() + _x.FromMont() + if _z[3] > _x[3] { + return 1 + } else if _z[3] < _x[3] { + return -1 } - - // initialize u = q - var u = Element{ - 4891460686036598785, - 2896914383306846353, - 13281191951274694749, - 3486998266802970665, + if _z[2] > _x[2] { + return 1 + } else if _z[2] < _x[2] { + return -1 } - - // initialize s = r^2 - var s = Element{ - 1997599621687373223, - 6052339484930628067, - 10108755138030829701, - 150537098327114917, + if _z[1] > _x[1] { + return 1 + } else if _z[1] < _x[1] { + return -1 } + if _z[0] > _x[0] { + return 1 + } else if _z[0] < _x[0] { + return -1 + } + return 0 +} - // r = 0 - r := Element{} - - v := *x - - var carry, borrow, t, t2 uint64 - var bigger, uIsOne, vIsOne bool - - for !uIsOne && !vIsOne { - for v[0]&1 == 0 { - - // v = v >> 1 - t2 = v[3] << 63 - v[3] >>= 1 - t = t2 - t2 = v[2] << 63 - v[2] = (v[2] >> 1) | t - t = t2 - t2 = v[1] << 63 - v[1] = (v[1] >> 1) | t - t = t2 - v[0] = (v[0] >> 1) | t - - if s[0]&1 == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) - s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) - s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) - s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) - - } - - // s = s >> 1 - t2 = s[3] << 63 - s[3] >>= 1 - t = t2 - t2 = s[2] << 63 - s[2] = (s[2] >> 1) | t - t = t2 - t2 = s[1] << 63 - s[1] = (s[1] >> 1) | t - t = t2 - s[0] = (s[0] >> 1) | t - - } - for u[0]&1 == 0 { - - // u = u >> 1 - t2 = u[3] << 63 - u[3] >>= 1 - t = t2 - t2 = u[2] << 63 - u[2] = (u[2] >> 1) | t - t = t2 - t2 = u[1] << 63 - u[1] = (u[1] >> 1) | t - t = t2 - u[0] = (u[0] >> 1) | t - - if r[0]&1 == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) - r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) - r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) - r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) - - } - - // r = r >> 1 - t2 = r[3] << 63 - r[3] >>= 1 - t = t2 - t2 = r[2] << 63 - r[2] = (r[2] >> 1) | t - t = t2 - t2 = r[1] << 63 - r[1] = (r[1] >> 1) | t - t = t2 - r[0] = (r[0] >> 1) | t - - } - - // v >= u - bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) - - if bigger { - - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], _ = bits.Sub64(v[3], u[3], borrow) - - // r >= s - bigger = !(r[3] < s[3] || (r[3] == s[3] && (r[2] < s[2] || (r[2] == s[2] && (r[1] < s[1] || (r[1] == s[1] && (r[0] < s[0]))))))) - - if bigger { - - // s = s + q - s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) - s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) - s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) - s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) - - } - - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], _ = bits.Sub64(s[3], r[3], borrow) - - } else { - - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], _ = bits.Sub64(u[3], v[3], borrow) - - // s >= r - bigger = !(s[3] < r[3] || (s[3] == r[3] && (s[2] < r[2] || (s[2] == r[2] && (s[1] < r[1] || (s[1] == r[1] && (s[0] < r[0]))))))) - - if bigger { - - // r = r + q - r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) - r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) - r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) - r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) - - } - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], _ = bits.Sub64(r[3], s[3], borrow) +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *Element) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + // we check if the element is larger than (q-1) / 2 + // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - } - uIsOne = (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 - vIsOne = (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 - } + _z := *z + _z.FromMont() - if uIsOne { - z.Set(&r) - } else { - z.Set(&s) - } + var b uint64 + _, b = bits.Sub64(_z[0], 11669102379873075201, 0) + _, b = bits.Sub64(_z[1], 10671829228508198984, b) + _, b = bits.Sub64(_z[2], 15863968012492123182, b) + _, b = bits.Sub64(_z[3], 1743499133401485332, b) - return z + return b == 0 } // SetRandom sets z to a random element < q -func (z *Element) SetRandom() *Element { - bytes := make([]byte, 32) - io.ReadFull(rand.Reader, bytes) +func (z *Element) SetRandom() (*Element, error) { + var bytes [32]byte + if _, err := io.ReadFull(rand.Reader, bytes[:]); err != nil { + return nil, err + } z[0] = binary.BigEndian.Uint64(bytes[0:8]) z[1] = binary.BigEndian.Uint64(bytes[8:16]) z[2] = binary.BigEndian.Uint64(bytes[16:24]) @@ -311,6 +256,7 @@ func (z *Element) SetRandom() *Element { z[3] %= 3486998266802970665 // if z > q --> z -= q + // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) @@ -319,125 +265,156 @@ func (z *Element) SetRandom() *Element { z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) } - return z + return z, nil } -// Add z = x + y mod q -func (z *Element) Add(x, y *Element) *Element { - var carry uint64 +// One returns 1 (in montgommery form) +func One() Element { + var one Element + one.SetOne() + return one +} - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) +// Halve sets z to z / 2 (mod p) +func (z *Element) Halve() { + if z[0]&1 == 1 { + var carry uint64 + + // z = z + q + z[0], carry = bits.Add64(z[0], 4891460686036598785, 0) + z[1], carry = bits.Add64(z[1], 2896914383306846353, carry) + z[2], carry = bits.Add64(z[2], 13281191951274694749, carry) + z[3], _ = bits.Add64(z[3], 3486998266802970665, carry) - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) } + + // z = z >> 1 + + z[0] = z[0]>>1 | z[1]<<63 + z[1] = z[1]>>1 | z[2]<<63 + z[2] = z[2]>>1 | z[3]<<63 + z[3] >>= 1 + +} + +// API with assembly impl + +// Mul z = x * y mod q +// see https://hackmd.io/@zkteam/modular_multiplication +func (z *Element) Mul(x, y *Element) *Element { + mul(z, x, y) return z } -// AddAssign z = z + x mod q -func (z *Element) AddAssign(x *Element) *Element { - var carry uint64 +// Square z = x * x mod q +// see https://hackmd.io/@zkteam/modular_multiplication +func (z *Element) Square(x *Element) *Element { + mul(z, x, x) + return z +} - z[0], carry = bits.Add64(z[0], x[0], 0) - z[1], carry = bits.Add64(z[1], x[1], carry) - z[2], carry = bits.Add64(z[2], x[2], carry) - z[3], _ = bits.Add64(z[3], x[3], carry) +// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// sets and returns z = z * 1 +func (z *Element) FromMont() *Element { + fromMont(z) + return z +} - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) - } +// Add z = x + y mod q +func (z *Element) Add(x, y *Element) *Element { + add(z, x, y) return z } // Double z = x + x mod q, aka Lsh 1 func (z *Element) Double(x *Element) *Element { - var carry uint64 - - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) - } + double(z, x) return z } // Sub z = x - y mod q func (z *Element) Sub(x, y *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], 4891460686036598785, 0) - z[1], c = bits.Add64(z[1], 2896914383306846353, c) - z[2], c = bits.Add64(z[2], 13281191951274694749, c) - z[3], _ = bits.Add64(z[3], 3486998266802970665, c) - } + sub(z, x, y) return z } -// SubAssign z = z - x mod q -func (z *Element) SubAssign(x *Element) *Element { - var b uint64 - z[0], b = bits.Sub64(z[0], x[0], 0) - z[1], b = bits.Sub64(z[1], x[1], b) - z[2], b = bits.Sub64(z[2], x[2], b) - z[3], b = bits.Sub64(z[3], x[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], 4891460686036598785, 0) - z[1], c = bits.Add64(z[1], 2896914383306846353, c) - z[2], c = bits.Add64(z[2], 13281191951274694749, c) - z[3], _ = bits.Add64(z[3], 3486998266802970665, c) - } +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + neg(z, x) return z } -// Exp z = x^e mod q -func (z *Element) Exp(x Element, e uint64) *Element { - if e == 0 { - return z.SetOne() - } +// Generic (no ADX instructions, no AMD64) versions of multiplication and squaring algorithms - z.Set(&x) +func _mulGeneric(z, x, y *Element) { - l := bits.Len64(e) - 2 - for i := l; i >= 0; i-- { - z.Square(z) - if e&(1< q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } +} +func _fromMontGeneric(z *Element) { // the following lines implement z = z * 1 // with a modified CIOS montgomery multiplication { @@ -478,6 +455,85 @@ func (z *Element) FromMont() *Element { } // if z > q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } +} + +func _addGeneric(z, x, y *Element) { + var carry uint64 + + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], _ = bits.Add64(x[3], y[3], carry) + + // if z > q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } +} + +func _doubleGeneric(z, x *Element) { + var carry uint64 + + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], _ = bits.Add64(x[3], x[3], carry) + + // if z > q --> z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } +} + +func _subGeneric(z, x, y *Element) { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], 4891460686036598785, 0) + z[1], c = bits.Add64(z[1], 2896914383306846353, c) + z[2], c = bits.Add64(z[2], 13281191951274694749, c) + z[3], _ = bits.Add64(z[3], 3486998266802970665, c) + } +} + +func _negGeneric(z, x *Element) { + if x.IsZero() { + z.SetZero() + return + } + var borrow uint64 + z[0], borrow = bits.Sub64(4891460686036598785, x[0], 0) + z[1], borrow = bits.Sub64(2896914383306846353, x[1], borrow) + z[2], borrow = bits.Sub64(13281191951274694749, x[2], borrow) + z[3], _ = bits.Sub64(3486998266802970665, x[3], borrow) +} + +func _reduceGeneric(z *Element) { + + // if z > q --> z -= q + // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) @@ -485,19 +541,108 @@ func (z *Element) FromMont() *Element { z[2], b = bits.Sub64(z[2], 13281191951274694749, b) z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) } +} + +func mulByConstant(z *Element, c uint8) { + switch c { + case 0: + z.SetZero() + return + case 1: + return + case 2: + z.Double(z) + return + case 3: + _z := *z + z.Double(z).Add(z, &_z) + case 5: + _z := *z + z.Double(z).Double(z).Add(z, &_z) + default: + var y Element + y.SetUint64(uint64(c)) + z.Mul(z, &y) + } +} + +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := make([]bool, len(a)) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes[i] = true + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes[i] { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +func _butterflyGeneric(a, b *Element) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// BitLen returns the minimum number of bits needed to represent z +// returns 0 if z == 0 +func (z *Element) BitLen() int { + if z[3] != 0 { + return 192 + bits.Len64(z[3]) + } + if z[2] != 0 { + return 128 + bits.Len64(z[2]) + } + if z[1] != 0 { + return 64 + bits.Len64(z[1]) + } + return bits.Len64(z[0]) +} + +// Exp z = x^exponent mod q +func (z *Element) Exp(x Element, exponent *big.Int) *Element { + var bZero big.Int + if exponent.Cmp(&bZero) == 0 { + return z.SetOne() + } + + z.Set(&x) + + for i := exponent.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if exponent.Bit(i) == 1 { + z.Mul(z, &x) + } + } + return z } // ToMont converts z to Montgomery form // sets and returns z = z * r^2 func (z *Element) ToMont() *Element { - var rSquare = Element{ - 1997599621687373223, - 6052339484930628067, - 10108755138030829701, - 150537098327114917, - } - return z.MulAssign(&rSquare) + return z.Mul(z, &rSquare) } // ToRegular returns z in regular form (doesn't mutate z) @@ -507,65 +652,110 @@ func (z Element) ToRegular() Element { // String returns the string form of an Element in Montgomery form func (z *Element) String() string { - var _z big.Int - return z.ToBigIntRegular(&_z).String() + zz := *z + zz.FromMont() + if zz.IsUint64() { + return strconv.FormatUint(zz[0], 10) + } else { + var zzNeg Element + zzNeg.Neg(z) + zzNeg.FromMont() + if zzNeg.IsUint64() { + return "-" + strconv.FormatUint(zzNeg[0], 10) + } + } + vv := bigIntPool.Get().(*big.Int) + defer bigIntPool.Put(vv) + return zz.ToBigInt(vv).String() } // ToBigInt returns z as a big.Int in Montgomery form func (z *Element) ToBigInt(res *big.Int) *big.Int { - if bits.UintSize == 64 { - bits := (*[4]big.Word)(unsafe.Pointer(z)) - return res.SetBits(bits[:]) - } else { - var bits [8]big.Word - for i := 0; i < len(z); i++ { - bits[i*2] = big.Word(z[i]) - bits[i*2+1] = big.Word(z[i] >> 32) - } - return res.SetBits(bits[:]) - } + var b [Limbs * 8]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) } // ToBigIntRegular returns z as a big.Int in regular form func (z Element) ToBigIntRegular(res *big.Int) *big.Int { z.FromMont() - if bits.UintSize == 64 { - bits := (*[4]big.Word)(unsafe.Pointer(&z)) - return res.SetBits(bits[:]) - } else { - var bits [8]big.Word - for i := 0; i < len(z); i++ { - bits[i*2] = big.Word(z[i]) - bits[i*2+1] = big.Word(z[i] >> 32) - } - return res.SetBits(bits[:]) - } + return z.ToBigInt(res) +} + +// Bytes returns the regular (non montgomery) value +// of z as a big-endian byte array. +func (z *Element) Bytes() (res [Limbs * 8]byte) { + _z := z.ToRegular() + binary.BigEndian.PutUint64(res[24:32], _z[0]) + binary.BigEndian.PutUint64(res[16:24], _z[1]) + binary.BigEndian.PutUint64(res[8:16], _z[2]) + binary.BigEndian.PutUint64(res[0:8], _z[3]) + + return +} + +// Marshal returns the regular (non montgomery) value +// of z as a big-endian byte slice. +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value (in Montgomery form), and returns z. +func (z *Element) SetBytes(e []byte) *Element { + // get a big int from our pool + vv := bigIntPool.Get().(*big.Int) + vv.SetBytes(e) + + // set big int + z.SetBigInt(vv) + + // put temporary object back in pool + bigIntPool.Put(vv) + + return z } // SetBigInt sets z to v (regular form) and returns z in Montgomery form func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() - zero := big.NewInt(0) - q := elementModulusBigInt() + var zero big.Int - // copy input - vv := new(big.Int).Set(v) - - // while v < 0, v+=q - for vv.Cmp(zero) == -1 { - vv.Add(vv, q) - } - // while v > q, v-=q - for vv.Cmp(q) == 1 { - vv.Sub(vv, q) - } - // if v == q, return 0 - if vv.Cmp(q) == 0 { + // fast path + c := v.Cmp(&_modulus) + if c == 0 { + // v == 0 return z + } else if c != 1 && v.Cmp(&zero) != -1 { + // 0 < v < q + return z.setBigInt(v) } - // v should - vBits := vv.Bits() + + // get temporary big int from the pool + vv := bigIntPool.Get().(*big.Int) + + // copy input + modular reduction + vv.Set(v) + vv.Mod(v, &_modulus) + + // set big int byte value + z.setBigInt(vv) + + // release object into pool + bigIntPool.Put(vv) + return z +} + +// setBigInt assumes 0 <= v < q +func (z *Element) setBigInt(v *big.Int) *Element { + vBits := v.Bits() + if bits.UintSize == 64 { for i := 0; i < len(vBits); i++ { z[i] = uint64(vBits[i]) @@ -579,214 +769,269 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } } } + return z.ToMont() } // SetString creates a big.Int with s (in base 10) and calls SetBigInt on z func (z *Element) SetString(s string) *Element { - x, ok := new(big.Int).SetString(s, 10) - if !ok { + // get temporary big int from the pool + vv := bigIntPool.Get().(*big.Int) + + if _, ok := vv.SetString(s, 10); !ok { panic("Element.SetString failed -> can't parse number in base10 into a big.Int") } - return z.SetBigInt(x) + z.SetBigInt(vv) + + // release object into pool + bigIntPool.Put(vv) + + return z } -// Mul z = x * y mod q -func (z *Element) Mul(x, y *Element) *Element { +var ( + _bLegendreExponentElement *big.Int + _bSqrtExponentElement *big.Int +) - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) +func init() { + _bLegendreExponentElement, _ = new(big.Int).SetString("183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f8000000", 16) + const sqrtExponentElement = "183227397098d014dc2822db40c0ac2e9419f4243cdcb848a1f0fac9f" + _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) +} + +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) +func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.Exp(*z, _bLegendreExponentElement) + + if l.IsZero() { + return 0 } - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + // if l == 1 + if (l[3] == 1011752739694698287) && (l[2] == 7381016538464732718) && (l[1] == 3962172157175319849) && (l[0] == 12436184717236109307) { + return 1 } - return z + return -1 } -// MulAssign z = z * x mod q -func (z *Element) MulAssign(x *Element) *Element { - - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := z[0] - c[1], c[0] = bits.Mul64(v, x[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd1(v, x[1], c[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd1(v, x[2], c[1]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd1(v, x[3], c[1]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) +// Sqrt z = √x mod q +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 1 (mod 4) + // see modSqrtTonelliShanks in math/big/int.go + // using https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + + var y, b, t, w Element + // w = x^((s-1)/2)) + w.Exp(*x, _bSqrtExponentElement) + + // y = x^((s+1)/2)) = w * x + y.Mul(x, &w) + + // b = x^s = w * w * x = y * x + b.Mul(&w, &y) + + // g = nonResidue ^ s + var g = Element{ + 7164790868263648668, + 11685701338293206998, + 6216421865291908056, + 1756667274303109607, } - { - // round 1 - v := z[1] - c[1], c[0] = madd1(v, x[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, x[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, x[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, x[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + r := uint64(28) + + // compute legendre symbol + // t = x^((q-1)/2) = r-1 squaring of x^s + t = b + for i := uint64(0); i < r-1; i++ { + t.Square(&t) } - { - // round 2 - v := z[2] - c[1], c[0] = madd1(v, x[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, x[1], c[1], t[1]) - c[2], t[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, x[2], c[1], t[2]) - c[2], t[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, x[3], c[1], t[3]) - t[3], t[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + if t.IsZero() { + return z.SetZero() } - { - // round 3 - v := z[3] - c[1], c[0] = madd1(v, x[0], t[0]) - m := c[0] * 14042775128853446655 - c[2] = madd0(m, 4891460686036598785, c[0]) - c[1], c[0] = madd2(v, x[1], c[1], t[1]) - c[2], z[0] = madd2(m, 2896914383306846353, c[2], c[0]) - c[1], c[0] = madd2(v, x[2], c[1], t[2]) - c[2], z[1] = madd2(m, 13281191951274694749, c[2], c[0]) - c[1], c[0] = madd2(v, x[3], c[1], t[3]) - z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) + if !((t[3] == 1011752739694698287) && (t[2] == 7381016538464732718) && (t[1] == 3962172157175319849) && (t[0] == 12436184717236109307)) { + // t != 1, we don't have a square root + return nil } + for { + var m uint64 + t = b + + // for t != 1 + for !((t[3] == 1011752739694698287) && (t[2] == 7381016538464732718) && (t[1] == 3962172157175319849) && (t[0] == 12436184717236109307)) { + t.Square(&t) + m++ + } - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + if m == 0 { + return z.Set(&y) + } + // t = g^(2^(r-m-1)) mod q + ge := int(r - m - 1) + t = g + for ge > 0 { + t.Square(&t) + ge-- + } + + g.Square(&t) + y.Mul(&y, &t) + b.Mul(&b, &g) + r = m } - return z } -// Square z = x * x mod q -func (z *Element) Square(x *Element) *Element { - - var p [4]uint64 - - var u, v uint64 - { - // round 0 - u, p[0] = bits.Mul64(x[0], x[0]) - m := p[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, p[0]) - var t uint64 - t, u, v = madd1sb(x[0], x[1], u) - C, p[0] = madd2(m, 2896914383306846353, v, C) - t, u, v = madd1s(x[0], x[2], t, u) - C, p[1] = madd2(m, 13281191951274694749, v, C) - _, u, v = madd1s(x[0], x[3], t, u) - p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) - } - { - // round 1 - m := p[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, p[0]) - u, v = madd1(x[1], x[1], p[1]) - C, p[0] = madd2(m, 2896914383306846353, v, C) - var t uint64 - t, u, v = madd2sb(x[1], x[2], p[2], u) - C, p[1] = madd2(m, 13281191951274694749, v, C) - _, u, v = madd2s(x[1], x[3], p[3], t, u) - p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) +// Inverse z = x^-1 mod q +// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" +// if x == 0, sets and returns z = x +func (z *Element) Inverse(x *Element) *Element { + if x.IsZero() { + z.SetZero() + return z } - { - // round 2 - m := p[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, p[0]) - C, p[0] = madd2(m, 2896914383306846353, p[1], C) - u, v = madd1(x[2], x[2], p[2]) - C, p[1] = madd2(m, 13281191951274694749, v, C) - _, u, v = madd2sb(x[2], x[3], p[3], u) - p[3], p[2] = madd3(m, 3486998266802970665, v, C, u) + + // initialize u = q + var u = Element{ + 4891460686036598785, + 2896914383306846353, + 13281191951274694749, + 3486998266802970665, } - { - // round 3 - m := p[0] * 14042775128853446655 - C := madd0(m, 4891460686036598785, p[0]) - C, z[0] = madd2(m, 2896914383306846353, p[1], C) - C, z[1] = madd2(m, 13281191951274694749, p[2], C) - u, v = madd1(x[3], x[3], p[3]) - z[3], z[2] = madd3(m, 3486998266802970665, v, C, u) + + // initialize s = r^2 + var s = Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, } - // if z > q --> z -= q - if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) - z[1], b = bits.Sub64(z[1], 2896914383306846353, b) - z[2], b = bits.Sub64(z[2], 13281191951274694749, b) - z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + // r = 0 + r := Element{} + + v := *x + + var carry, borrow uint64 + var bigger bool + + for { + for v[0]&1 == 0 { + + // v = v >> 1 + + v[0] = v[0]>>1 | v[1]<<63 + v[1] = v[1]>>1 | v[2]<<63 + v[2] = v[2]>>1 | v[3]<<63 + v[3] >>= 1 + + if s[0]&1 == 1 { + + // s = s + q + s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) + s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) + s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) + s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + + } + + // s = s >> 1 + + s[0] = s[0]>>1 | s[1]<<63 + s[1] = s[1]>>1 | s[2]<<63 + s[2] = s[2]>>1 | s[3]<<63 + s[3] >>= 1 + + } + for u[0]&1 == 0 { + + // u = u >> 1 + + u[0] = u[0]>>1 | u[1]<<63 + u[1] = u[1]>>1 | u[2]<<63 + u[2] = u[2]>>1 | u[3]<<63 + u[3] >>= 1 + + if r[0]&1 == 1 { + + // r = r + q + r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) + r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) + r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) + r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + + } + + // r = r >> 1 + + r[0] = r[0]>>1 | r[1]<<63 + r[1] = r[1]>>1 | r[2]<<63 + r[2] = r[2]>>1 | r[3]<<63 + r[3] >>= 1 + + } + + // v >= u + bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) + + if bigger { + + // v = v - u + v[0], borrow = bits.Sub64(v[0], u[0], 0) + v[1], borrow = bits.Sub64(v[1], u[1], borrow) + v[2], borrow = bits.Sub64(v[2], u[2], borrow) + v[3], _ = bits.Sub64(v[3], u[3], borrow) + + // s = s - r + s[0], borrow = bits.Sub64(s[0], r[0], 0) + s[1], borrow = bits.Sub64(s[1], r[1], borrow) + s[2], borrow = bits.Sub64(s[2], r[2], borrow) + s[3], borrow = bits.Sub64(s[3], r[3], borrow) + + if borrow == 1 { + + // s = s + q + s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) + s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) + s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) + s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + + } + } else { + + // u = u - v + u[0], borrow = bits.Sub64(u[0], v[0], 0) + u[1], borrow = bits.Sub64(u[1], v[1], borrow) + u[2], borrow = bits.Sub64(u[2], v[2], borrow) + u[3], _ = bits.Sub64(u[3], v[3], borrow) + + // r = r - s + r[0], borrow = bits.Sub64(r[0], s[0], 0) + r[1], borrow = bits.Sub64(r[1], s[1], borrow) + r[2], borrow = bits.Sub64(r[2], s[2], borrow) + r[3], borrow = bits.Sub64(r[3], s[3], borrow) + + if borrow == 1 { + + // r = r + q + r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) + r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) + r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) + r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + + } + } + if (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 { + z.Set(&r) + return z + } + if (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 { + z.Set(&s) + return z + } } - return z + } diff --git a/ff/element_fuzz.go b/ff/element_fuzz.go new file mode 100644 index 0000000..cfb088a --- /dev/null +++ b/ff/element_fuzz.go @@ -0,0 +1,136 @@ +//go:build gofuzz +// +build gofuzz + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ff + +import ( + "bytes" + "encoding/binary" + "io" + "math/big" + "math/bits" +) + +const ( + fuzzInteresting = 1 + fuzzNormal = 0 + fuzzDiscard = -1 +) + +// Fuzz arithmetic operations fuzzer +func Fuzz(data []byte) int { + r := bytes.NewReader(data) + + var e1, e2 Element + e1.SetRawBytes(r) + e2.SetRawBytes(r) + + { + // mul assembly + + var c, _c Element + a, _a, b, _b := e1, e1, e2, e2 + c.Mul(&a, &b) + _mulGeneric(&_c, &_a, &_b) + + if !c.Equal(&_c) { + panic("mul asm != mul generic on Element") + } + } + + { + // inverse + inv := e1 + inv.Inverse(&inv) + + var bInv, b1, b2 big.Int + e1.ToBigIntRegular(&b1) + bInv.ModInverse(&b1, Modulus()) + inv.ToBigIntRegular(&b2) + + if b2.Cmp(&bInv) != 0 { + panic("inverse operation doesn't match big int result") + } + } + + { + // a + -a == 0 + a, b := e1, e1 + b.Neg(&b) + a.Add(&a, &b) + if !a.IsZero() { + panic("a + -a != 0") + } + } + + return fuzzNormal + +} + +// SetRawBytes reads up to Bytes (bytes needed to represent Element) from reader +// and interpret it as big endian uint64 +// used for fuzzing purposes only +func (z *Element) SetRawBytes(r io.Reader) { + + buf := make([]byte, 8) + + for i := 0; i < len(z); i++ { + if _, err := io.ReadFull(r, buf); err != nil { + goto eof + } + z[i] = binary.BigEndian.Uint64(buf[:]) + } +eof: + z[3] %= qElement[3] + + if z.BiggerModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], qElement[0], 0) + z[1], b = bits.Sub64(z[1], qElement[1], b) + z[2], b = bits.Sub64(z[2], qElement[2], b) + z[3], b = bits.Sub64(z[3], qElement[3], b) + } + + return +} + +func (z *Element) BiggerModulus() bool { + if z[3] > qElement[3] { + return true + } + if z[3] < qElement[3] { + return false + } + + if z[2] > qElement[2] { + return true + } + if z[2] < qElement[2] { + return false + } + + if z[1] > qElement[1] { + return true + } + if z[1] < qElement[1] { + return false + } + + return z[0] >= qElement[0] +} diff --git a/ff/element_mul_adx_amd64.s b/ff/element_mul_adx_amd64.s new file mode 100644 index 0000000..494e7bf --- /dev/null +++ b/ff/element_mul_adx_amd64.s @@ -0,0 +1,466 @@ +// +build amd64_adx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" +#include "funcdata.h" + +// modulus q +DATA q<>+0(SB)/8, $0x43e1f593f0000001 +DATA q<>+8(SB)/8, $0x2833e84879b97091 +DATA q<>+16(SB)/8, $0xb85045b68181585d +DATA q<>+24(SB)/8, $0x30644e72e131a029 +GLOBL q<>(SB), (RODATA+NOPTR), $32 + +// qInv0 q'[0] +DATA qInv0<>(SB)/8, $0xc2e1f593efffffff +GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 + +#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ + MOVQ ra0, rb0; \ + SUBQ q<>(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ q<>+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ q<>+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ q<>+24(SB), ra3; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + +// mul(res, x, y *Element) +TEXT ·mul(SB), NOSPLIT, $0-24 + + // the algorithm is described here + // https://hackmd.io/@zkteam/modular_multiplication + // however, to benefit from the ADCX and ADOX carry chains + // we split the inner loops in 2: + // for i=0 to N-1 + // for j=0 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + A + + MOVQ x+8(FP), SI + + // x[0] -> DI + // x[1] -> R8 + // x[2] -> R9 + // x[3] -> R10 + MOVQ 0(SI), DI + MOVQ 8(SI), R8 + MOVQ 16(SI), R9 + MOVQ 24(SI), R10 + MOVQ y+16(FP), R11 + + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ DI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ R8, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R9, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ R8, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ R8, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ R8, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce element(R14,R15,CX,BX) using temp registers (R13,SI,R12,R11) + REDUCE(R14,R15,CX,BX,R13,SI,R12,R11) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + RET + +TEXT ·fromMont(SB), NOSPLIT, $0-8 + + // the algorithm is described here + // https://hackmd.io/@zkteam/modular_multiplication + // when y = 1 we have: + // for i=0 to N-1 + // t[i] = x[i] + // for i=0 to N-1 + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + MOVQ res+0(FP), DX + MOVQ 0(DX), R14 + MOVQ 8(DX), R15 + MOVQ 16(DX), CX + MOVQ 24(DX), BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + + // reduce element(R14,R15,CX,BX) using temp registers (SI,DI,R8,R9) + REDUCE(R14,R15,CX,BX,SI,DI,R8,R9) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + RET diff --git a/ff/element_mul_amd64.s b/ff/element_mul_amd64.s new file mode 100644 index 0000000..38b3b6c --- /dev/null +++ b/ff/element_mul_amd64.s @@ -0,0 +1,488 @@ +// +build !amd64_adx + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" +#include "funcdata.h" + +// modulus q +DATA q<>+0(SB)/8, $0x43e1f593f0000001 +DATA q<>+8(SB)/8, $0x2833e84879b97091 +DATA q<>+16(SB)/8, $0xb85045b68181585d +DATA q<>+24(SB)/8, $0x30644e72e131a029 +GLOBL q<>(SB), (RODATA+NOPTR), $32 + +// qInv0 q'[0] +DATA qInv0<>(SB)/8, $0xc2e1f593efffffff +GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 + +#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ + MOVQ ra0, rb0; \ + SUBQ q<>(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ q<>+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ q<>+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ q<>+24(SB), ra3; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + +// mul(res, x, y *Element) +TEXT ·mul(SB), $24-24 + + // the algorithm is described here + // https://hackmd.io/@zkteam/modular_multiplication + // however, to benefit from the ADCX and ADOX carry chains + // we split the inner loops in 2: + // for i=0 to N-1 + // for j=0 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + A + + NO_LOCAL_POINTERS + CMPB ·supportAdx(SB), $1 + JNE l1 + MOVQ x+8(FP), SI + + // x[0] -> DI + // x[1] -> R8 + // x[2] -> R9 + // x[3] -> R10 + MOVQ 0(SI), DI + MOVQ 8(SI), R8 + MOVQ 16(SI), R9 + MOVQ 24(SI), R10 + MOVQ y+16(FP), R11 + + // A -> BP + // t[0] -> R14 + // t[1] -> R15 + // t[2] -> CX + // t[3] -> BX + // clear the flags + XORQ AX, AX + MOVQ 0(R11), DX + + // (A,t[0]) := x[0]*y[0] + A + MULXQ DI, R14, R15 + + // (A,t[1]) := x[1]*y[0] + A + MULXQ R8, AX, CX + ADOXQ AX, R15 + + // (A,t[2]) := x[2]*y[0] + A + MULXQ R9, AX, BX + ADOXQ AX, CX + + // (A,t[3]) := x[3]*y[0] + A + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 8(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[1] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[1] + A + ADCXQ BP, R15 + MULXQ R8, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[1] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[1] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 16(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[2] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[2] + A + ADCXQ BP, R15 + MULXQ R8, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[2] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[2] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // clear the flags + XORQ AX, AX + MOVQ 24(R11), DX + + // (A,t[0]) := t[0] + x[0]*y[3] + A + MULXQ DI, AX, BP + ADOXQ AX, R14 + + // (A,t[1]) := t[1] + x[1]*y[3] + A + ADCXQ BP, R15 + MULXQ R8, AX, BP + ADOXQ AX, R15 + + // (A,t[2]) := t[2] + x[2]*y[3] + A + ADCXQ BP, CX + MULXQ R9, AX, BP + ADOXQ AX, CX + + // (A,t[3]) := t[3] + x[3]*y[3] + A + ADCXQ BP, BX + MULXQ R10, AX, BP + ADOXQ AX, BX + + // A += carries from ADCXQ and ADOXQ + MOVQ $0, AX + ADCXQ AX, BP + ADOXQ AX, BP + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + + // clear the flags + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, R12 + ADCXQ R14, AX + MOVQ R12, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + + // t[3] = C + A + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ BP, BX + + // reduce element(R14,R15,CX,BX) using temp registers (R13,SI,R12,R11) + REDUCE(R14,R15,CX,BX,R13,SI,R12,R11) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + RET + +l1: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + MOVQ x+8(FP), AX + MOVQ AX, 8(SP) + MOVQ y+16(FP), AX + MOVQ AX, 16(SP) + CALL ·_mulGeneric(SB) + RET + +TEXT ·fromMont(SB), $8-8 + NO_LOCAL_POINTERS + + // the algorithm is described here + // https://hackmd.io/@zkteam/modular_multiplication + // when y = 1 we have: + // for i=0 to N-1 + // t[i] = x[i] + // for i=0 to N-1 + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // t[N-1] = C + CMPB ·supportAdx(SB), $1 + JNE l2 + MOVQ res+0(FP), DX + MOVQ 0(DX), R14 + MOVQ 8(DX), R15 + MOVQ 16(DX), CX + MOVQ 24(DX), BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + XORQ DX, DX + + // m := t[0]*q'[0] mod W + MOVQ qInv0<>(SB), DX + IMULQ R14, DX + XORQ AX, AX + + // C,_ := t[0] + m*q[0] + MULXQ q<>+0(SB), AX, BP + ADCXQ R14, AX + MOVQ BP, R14 + + // (C,t[0]) := t[1] + m*q[1] + C + ADCXQ R15, R14 + MULXQ q<>+8(SB), AX, R15 + ADOXQ AX, R14 + + // (C,t[1]) := t[2] + m*q[2] + C + ADCXQ CX, R15 + MULXQ q<>+16(SB), AX, CX + ADOXQ AX, R15 + + // (C,t[2]) := t[3] + m*q[3] + C + ADCXQ BX, CX + MULXQ q<>+24(SB), AX, BX + ADOXQ AX, CX + MOVQ $0, AX + ADCXQ AX, BX + ADOXQ AX, BX + + // reduce element(R14,R15,CX,BX) using temp registers (SI,DI,R8,R9) + REDUCE(R14,R15,CX,BX,SI,DI,R8,R9) + + MOVQ res+0(FP), AX + MOVQ R14, 0(AX) + MOVQ R15, 8(AX) + MOVQ CX, 16(AX) + MOVQ BX, 24(AX) + RET + +l2: + MOVQ res+0(FP), AX + MOVQ AX, (SP) + CALL ·_fromMontGeneric(SB) + RET diff --git a/ff/element_ops_amd64.go b/ff/element_ops_amd64.go new file mode 100644 index 0000000..777ba01 --- /dev/null +++ b/ff/element_ops_amd64.go @@ -0,0 +1,50 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ff + +//go:noescape +func MulBy3(x *Element) + +//go:noescape +func MulBy5(x *Element) + +//go:noescape +func MulBy13(x *Element) + +//go:noescape +func add(res, x, y *Element) + +//go:noescape +func sub(res, x, y *Element) + +//go:noescape +func neg(res, x *Element) + +//go:noescape +func double(res, x *Element) + +//go:noescape +func mul(res, x, y *Element) + +//go:noescape +func fromMont(res *Element) + +//go:noescape +func reduce(res *Element) + +//go:noescape +func Butterfly(a, b *Element) diff --git a/ff/element_ops_amd64.s b/ff/element_ops_amd64.s new file mode 100644 index 0000000..d5dca83 --- /dev/null +++ b/ff/element_ops_amd64.s @@ -0,0 +1,340 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" +#include "funcdata.h" + +// modulus q +DATA q<>+0(SB)/8, $0x43e1f593f0000001 +DATA q<>+8(SB)/8, $0x2833e84879b97091 +DATA q<>+16(SB)/8, $0xb85045b68181585d +DATA q<>+24(SB)/8, $0x30644e72e131a029 +GLOBL q<>(SB), (RODATA+NOPTR), $32 + +// qInv0 q'[0] +DATA qInv0<>(SB)/8, $0xc2e1f593efffffff +GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 + +#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ + MOVQ ra0, rb0; \ + SUBQ q<>(SB), ra0; \ + MOVQ ra1, rb1; \ + SBBQ q<>+8(SB), ra1; \ + MOVQ ra2, rb2; \ + SBBQ q<>+16(SB), ra2; \ + MOVQ ra3, rb3; \ + SBBQ q<>+24(SB), ra3; \ + CMOVQCS rb0, ra0; \ + CMOVQCS rb1, ra1; \ + CMOVQCS rb2, ra2; \ + CMOVQCS rb3, ra3; \ + +// add(res, x, y *Element) +TEXT ·add(SB), NOSPLIT, $0-24 + MOVQ x+8(FP), AX + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + MOVQ y+16(FP), DX + ADDQ 0(DX), CX + ADCQ 8(DX), BX + ADCQ 16(DX), SI + ADCQ 24(DX), DI + + // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) + REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) + + MOVQ res+0(FP), R12 + MOVQ CX, 0(R12) + MOVQ BX, 8(R12) + MOVQ SI, 16(R12) + MOVQ DI, 24(R12) + RET + +// sub(res, x, y *Element) +TEXT ·sub(SB), NOSPLIT, $0-24 + XORQ DI, DI + MOVQ x+8(FP), SI + MOVQ 0(SI), AX + MOVQ 8(SI), DX + MOVQ 16(SI), CX + MOVQ 24(SI), BX + MOVQ y+16(FP), SI + SUBQ 0(SI), AX + SBBQ 8(SI), DX + SBBQ 16(SI), CX + SBBQ 24(SI), BX + MOVQ $0x43e1f593f0000001, R8 + MOVQ $0x2833e84879b97091, R9 + MOVQ $0xb85045b68181585d, R10 + MOVQ $0x30644e72e131a029, R11 + CMOVQCC DI, R8 + CMOVQCC DI, R9 + CMOVQCC DI, R10 + CMOVQCC DI, R11 + ADDQ R8, AX + ADCQ R9, DX + ADCQ R10, CX + ADCQ R11, BX + MOVQ res+0(FP), R12 + MOVQ AX, 0(R12) + MOVQ DX, 8(R12) + MOVQ CX, 16(R12) + MOVQ BX, 24(R12) + RET + +// double(res, x *Element) +TEXT ·double(SB), NOSPLIT, $0-16 + MOVQ x+8(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + MOVQ res+0(FP), R11 + MOVQ DX, 0(R11) + MOVQ CX, 8(R11) + MOVQ BX, 16(R11) + MOVQ SI, 24(R11) + RET + +// neg(res, x *Element) +TEXT ·neg(SB), NOSPLIT, $0-16 + MOVQ res+0(FP), DI + MOVQ x+8(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + MOVQ DX, AX + ORQ CX, AX + ORQ BX, AX + ORQ SI, AX + TESTQ AX, AX + JEQ l1 + MOVQ $0x43e1f593f0000001, R8 + SUBQ DX, R8 + MOVQ R8, 0(DI) + MOVQ $0x2833e84879b97091, R8 + SBBQ CX, R8 + MOVQ R8, 8(DI) + MOVQ $0xb85045b68181585d, R8 + SBBQ BX, R8 + MOVQ R8, 16(DI) + MOVQ $0x30644e72e131a029, R8 + SBBQ SI, R8 + MOVQ R8, 24(DI) + RET + +l1: + MOVQ AX, 0(DI) + MOVQ AX, 8(DI) + MOVQ AX, 16(DI) + MOVQ AX, 24(DI) + RET + +TEXT ·reduce(SB), NOSPLIT, $0-8 + MOVQ res+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy3(x *Element) +TEXT ·MulBy3(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy5(x *Element) +TEXT ·MulBy5(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) + REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// MulBy13(x *Element) +TEXT ·MulBy13(SB), NOSPLIT, $0-8 + MOVQ x+0(FP), AX + MOVQ 0(AX), DX + MOVQ 8(AX), CX + MOVQ 16(AX), BX + MOVQ 24(AX), SI + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) + REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) + + MOVQ DX, R11 + MOVQ CX, R12 + MOVQ BX, R13 + MOVQ SI, R14 + ADDQ DX, DX + ADCQ CX, CX + ADCQ BX, BX + ADCQ SI, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ R11, DX + ADCQ R12, CX + ADCQ R13, BX + ADCQ R14, SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + ADDQ 0(AX), DX + ADCQ 8(AX), CX + ADCQ 16(AX), BX + ADCQ 24(AX), SI + + // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) + REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) + + MOVQ DX, 0(AX) + MOVQ CX, 8(AX) + MOVQ BX, 16(AX) + MOVQ SI, 24(AX) + RET + +// Butterfly(a, b *Element) sets a = a + b; b = a - b +TEXT ·Butterfly(SB), NOSPLIT, $0-16 + MOVQ a+0(FP), AX + MOVQ 0(AX), CX + MOVQ 8(AX), BX + MOVQ 16(AX), SI + MOVQ 24(AX), DI + MOVQ CX, R8 + MOVQ BX, R9 + MOVQ SI, R10 + MOVQ DI, R11 + XORQ AX, AX + MOVQ b+8(FP), DX + ADDQ 0(DX), CX + ADCQ 8(DX), BX + ADCQ 16(DX), SI + ADCQ 24(DX), DI + SUBQ 0(DX), R8 + SBBQ 8(DX), R9 + SBBQ 16(DX), R10 + SBBQ 24(DX), R11 + MOVQ $0x43e1f593f0000001, R12 + MOVQ $0x2833e84879b97091, R13 + MOVQ $0xb85045b68181585d, R14 + MOVQ $0x30644e72e131a029, R15 + CMOVQCC AX, R12 + CMOVQCC AX, R13 + CMOVQCC AX, R14 + CMOVQCC AX, R15 + ADDQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + ADCQ R15, R11 + MOVQ R8, 0(DX) + MOVQ R9, 8(DX) + MOVQ R10, 16(DX) + MOVQ R11, 24(DX) + + // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) + REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) + + MOVQ a+0(FP), AX + MOVQ CX, 0(AX) + MOVQ BX, 8(AX) + MOVQ SI, 16(AX) + MOVQ DI, 24(AX) + RET diff --git a/ff/element_ops_noasm.go b/ff/element_ops_noasm.go new file mode 100644 index 0000000..ca357bc --- /dev/null +++ b/ff/element_ops_noasm.go @@ -0,0 +1,78 @@ +//go:build !amd64 +// +build !amd64 + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package ff + +// /!\ WARNING /!\ +// this code has not been audited and is provided as-is. In particular, +// there is no security guarantees such as constant time implementation +// or side-channel attack resistance +// /!\ WARNING /!\ + +// MulBy3 x *= 3 +func MulBy3(x *Element) { + mulByConstant(x, 3) +} + +// MulBy5 x *= 5 +func MulBy5(x *Element) { + mulByConstant(x, 5) +} + +// MulBy13 x *= 13 +func MulBy13(x *Element) { + mulByConstant(x, 13) +} + +// Butterfly sets +// a = a + b +// b = a - b +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func mul(z, x, y *Element) { + _mulGeneric(z, x, y) +} + +// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// sets and returns z = z * 1 +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func add(z, x, y *Element) { + _addGeneric(z, x, y) +} + +func double(z, x *Element) { + _doubleGeneric(z, x) +} + +func sub(z, x, y *Element) { + _subGeneric(z, x, y) +} + +func neg(z, x *Element) { + _negGeneric(z, x) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} diff --git a/ff/element_test.go b/ff/element_test.go index 090313f..6c43f79 100644 --- a/ff/element_test.go +++ b/ff/element_test.go @@ -1,135 +1,75 @@ -// Code generated by goff DO NOT EDIT +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + package ff import ( "crypto/rand" "math/big" - mrand "math/rand" + "math/bits" "testing" -) -func TestELEMENTCorrectnessAgainstBigInt(t *testing.T) { - modulus, _ := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) - cmpEandB := func(e *Element, b *big.Int, name string) { - var _e big.Int - if e.FromMont().ToBigInt(&_e).Cmp(b) != 0 { - t.Fatal(name, "failed") - } - } - var modulusMinusOne, one big.Int - one.SetUint64(1) - - modulusMinusOne.Sub(modulus, &one) - - for i := 0; i < 1000; i++ { - - // sample 2 random big int - b1, _ := rand.Int(rand.Reader, modulus) - b2, _ := rand.Int(rand.Reader, modulus) - rExp := mrand.Uint64() - - // adding edge cases - // TODO need more edge cases - switch i { - case 0: - rExp = 0 - b1.SetUint64(0) - case 1: - b2.SetUint64(0) - case 2: - b1.SetUint64(0) - b2.SetUint64(0) - case 3: - rExp = 0 - case 4: - rExp = 1 - case 5: - rExp = ^uint64(0) // max uint - case 6: - rExp = 2 - b1.Set(&modulusMinusOne) - case 7: - b2.Set(&modulusMinusOne) - case 8: - b1.Set(&modulusMinusOne) - b2.Set(&modulusMinusOne) - } - - rbExp := new(big.Int).SetUint64(rExp) - - var bMul, bAdd, bSub, bDiv, bNeg, bLsh, bInv, bExp, bSquare big.Int - - // e1 = mont(b1), e2 = mont(b2) - var e1, e2, eMul, eAdd, eSub, eDiv, eNeg, eLsh, eInv, eExp, eSquare, eMulAssign, eSubAssign, eAddAssign Element - e1.SetBigInt(b1) - e2.SetBigInt(b2) - - // (e1*e2).FromMont() === b1*b2 mod q ... etc - eSquare.Square(&e1) - eMul.Mul(&e1, &e2) - eMulAssign.Set(&e1) - eMulAssign.MulAssign(&e2) - eAdd.Add(&e1, &e2) - eAddAssign.Set(&e1) - eAddAssign.AddAssign(&e2) - eSub.Sub(&e1, &e2) - eSubAssign.Set(&e1) - eSubAssign.SubAssign(&e2) - eDiv.Div(&e1, &e2) - eNeg.Neg(&e1) - eInv.Inverse(&e1) - eExp.Exp(e1, rExp) - eLsh.Double(&e1) - - // same operations with big int - bAdd.Add(b1, b2).Mod(&bAdd, modulus) - bMul.Mul(b1, b2).Mod(&bMul, modulus) - bSquare.Mul(b1, b1).Mod(&bSquare, modulus) - bSub.Sub(b1, b2).Mod(&bSub, modulus) - bDiv.ModInverse(b2, modulus) - bDiv.Mul(&bDiv, b1). - Mod(&bDiv, modulus) - bNeg.Neg(b1).Mod(&bNeg, modulus) - - bInv.ModInverse(b1, modulus) - bExp.Exp(b1, rbExp, modulus) - bLsh.Lsh(b1, 1).Mod(&bLsh, modulus) - - cmpEandB(&eSquare, &bSquare, "Square") - cmpEandB(&eMul, &bMul, "Mul") - cmpEandB(&eMulAssign, &bMul, "MulAssign") - cmpEandB(&eAdd, &bAdd, "Add") - cmpEandB(&eAddAssign, &bAdd, "AddAssign") - cmpEandB(&eSub, &bSub, "Sub") - cmpEandB(&eSubAssign, &bSub, "SubAssign") - cmpEandB(&eDiv, &bDiv, "Div") - cmpEandB(&eNeg, &bNeg, "Neg") - cmpEandB(&eInv, &bInv, "Inv") - cmpEandB(&eExp, &bExp, "Exp") - cmpEandB(&eLsh, &bLsh, "Lsh") - } -} - -func TestELEMENTIsRandom(t *testing.T) { - for i := 0; i < 1000; i++ { - var x, y Element - x.SetRandom() - y.SetRandom() - if x.Equal(&y) { - t.Fatal("2 random numbers are unlikely to be equal") - } - } -} + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) // ------------------------------------------------------------------------------------------------- // benchmarks // most benchmarks are rudimentary and should sample a large number of random inputs // or be run multiple times to ensure it didn't measure the fastest path of the function -// TODO: clean up and push benchmarking branch var benchResElement Element -func BenchmarkInverseELEMENT(b *testing.B) { +func BenchmarkElementSetBytes(b *testing.B) { + var x Element + x.SetRandom() + bb := x.Bytes() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.SetBytes(bb[:]) + } + +} + +func BenchmarkElementMulByConstants(b *testing.B) { + b.Run("mulBy3", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy3(&benchResElement) + } + }) + b.Run("mulBy5", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy5(&benchResElement) + } + }) + b.Run("mulBy13", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy13(&benchResElement) + } + }) +} + +func BenchmarkElementInverse(b *testing.B) { var x Element x.SetRandom() benchResElement.SetRandom() @@ -140,17 +80,29 @@ func BenchmarkInverseELEMENT(b *testing.B) { } } -func BenchmarkExpELEMENT(b *testing.B) { + +func BenchmarkElementButterfly(b *testing.B) { var x Element x.SetRandom() benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.Exp(x, mrand.Uint64()) + Butterfly(&x, &benchResElement) } } -func BenchmarkDoubleELEMENT(b *testing.B) { +func BenchmarkElementExp(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b1, _ := rand.Int(rand.Reader, Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, b1) + } +} + +func BenchmarkElementDouble(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -158,7 +110,7 @@ func BenchmarkDoubleELEMENT(b *testing.B) { } } -func BenchmarkAddELEMENT(b *testing.B) { +func BenchmarkElementAdd(b *testing.B) { var x Element x.SetRandom() benchResElement.SetRandom() @@ -168,7 +120,7 @@ func BenchmarkAddELEMENT(b *testing.B) { } } -func BenchmarkSubELEMENT(b *testing.B) { +func BenchmarkElementSub(b *testing.B) { var x Element x.SetRandom() benchResElement.SetRandom() @@ -178,7 +130,7 @@ func BenchmarkSubELEMENT(b *testing.B) { } } -func BenchmarkNegELEMENT(b *testing.B) { +func BenchmarkElementNeg(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -186,7 +138,7 @@ func BenchmarkNegELEMENT(b *testing.B) { } } -func BenchmarkDivELEMENT(b *testing.B) { +func BenchmarkElementDiv(b *testing.B) { var x Element x.SetRandom() benchResElement.SetRandom() @@ -196,7 +148,7 @@ func BenchmarkDivELEMENT(b *testing.B) { } } -func BenchmarkFromMontELEMENT(b *testing.B) { +func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -204,14 +156,14 @@ func BenchmarkFromMontELEMENT(b *testing.B) { } } -func BenchmarkToMontELEMENT(b *testing.B) { +func BenchmarkElementToMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { benchResElement.ToMont() } } -func BenchmarkSquareELEMENT(b *testing.B) { +func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -219,7 +171,17 @@ func BenchmarkSquareELEMENT(b *testing.B) { } } -func BenchmarkMulAssignELEMENT(b *testing.B) { +func BenchmarkElementSqrt(b *testing.B) { + var a Element + a.SetUint64(4) + a.Neg(&a) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} + +func BenchmarkElementMul(b *testing.B) { x := Element{ 1997599621687373223, 6052339484930628067, @@ -229,6 +191,1684 @@ func BenchmarkMulAssignELEMENT(b *testing.B) { benchResElement.SetOne() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.MulAssign(&x) + benchResElement.Mul(&benchResElement, &x) + } +} + +func BenchmarkElementCmp(b *testing.B) { + x := Element{ + 1997599621687373223, + 6052339484930628067, + 10108755138030829701, + 150537098327114917, + } + benchResElement = x + benchResElement[0] = 0 + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cmp(&x) + } +} + +func TestElementCmp(t *testing.T) { + var x, y Element + + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + one := One() + y.Sub(&y, &one) + + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } + + x = y + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + x.Sub(&x, &one) + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } +} + +func TestElementIsRandom(t *testing.T) { + for i := 0; i < 50; i++ { + var x, y Element + x.SetRandom() + y.SetRandom() + if x.Equal(&y) { + t.Fatal("2 random numbers are unlikely to be equal") + } + } +} + +// ------------------------------------------------------------------------------------------------- +// Gopter tests +// most of them are generated with a template + +const ( + nbFuzzShort = 200 + nbFuzz = 1000 +) + +// special values to be used in tests +var staticTestValues []Element + +func init() { + staticTestValues = append(staticTestValues, Element{}) // zero + staticTestValues = append(staticTestValues, One()) // one + staticTestValues = append(staticTestValues, rSquare) // r^2 + var e, one Element + one.SetOne() + e.Sub(&qElement, &one) + staticTestValues = append(staticTestValues, e) // q - 1 + e.Double(&one) + staticTestValues = append(staticTestValues, e) // 2 + + { + a := qElement + a[3]-- + staticTestValues = append(staticTestValues, a) + } + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + + for i := 0; i <= 3; i++ { + staticTestValues = append(staticTestValues, Element{uint64(i)}) + staticTestValues = append(staticTestValues, Element{0, uint64(i)}) + } + + { + a := qElement + a[3]-- + a[0]++ + staticTestValues = append(staticTestValues, a) + } + +} + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + +func TestElementReduce(t *testing.T) { + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, s := range testValues { + expected := s + reduce(&s) + _reduceGeneric(&expected) + if !s.Equal(&expected) { + t.Fatal("reduce failed: asm and generic impl don't match") + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(a Element) bool { + b := a + reduce(&a) + _reduceGeneric(&b) + return !a.biggerOrEqualModulus() && a.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + supportAdx = true + } + +} + +func TestElementBytes(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("SetBytes(Bytes()) should stayt constant", prop.ForAll( + func(a testPairElement) bool { + var b Element + bytes := a.element.Bytes() + b.SetBytes(bytes[:]) + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementInverseExp(t *testing.T) { + // inverse must be equal to exp^-2 + exp := Modulus() + exp.Sub(exp, new(big.Int).SetUint64(2)) + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("inv == exp^-2", prop.ForAll( + func(a testPairElement) bool { + var b Element + b.Set(&a.element) + a.element.Inverse(&a.element) + b.Exp(b, exp) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + supportAdx = true + } +} + +func TestElementMulByConstants(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + implemented := []uint8{0, 1, 2, 3, 5, 13} + properties.Property("mulByConstant", prop.ForAll( + func(a testPairElement) bool { + for _, c := range implemented { + var constant Element + constant.SetUint64(uint64(c)) + + b := a.element + b.Mul(&b, &constant) + + aa := a.element + mulByConstant(&aa, c) + + if !aa.Equal(&b) { + return false + } + } + + return true + }, + genA, + )) + + properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(3) + + b := a.element + b.Mul(&b, &constant) + + MulBy3(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(5) + + b := a.element + b.Mul(&b, &constant) + + MulBy5(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy13(x) == Mul(x, 13)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(13) + + b := a.element + b.Mul(&b, &constant) + + MulBy13(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + supportAdx = true + } + +} + +func TestElementLegendre(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("legendre should output same result than big.Int.Jacobi", prop.ForAll( + func(a testPairElement) bool { + return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus()) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + supportAdx = true + } + +} + +func TestElementButterflies(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("butterfly0 == a -b; a +b", prop.ForAll( + func(a, b testPairElement) bool { + a0, b0 := a.element, b.element + + _butterflyGeneric(&a.element, &b.element) + Butterfly(&a0, &b0) + + return a.element.Equal(&a0) && b.element.Equal(&b0) + }, + genA, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + supportAdx = true + } + +} + +func TestElementLexicographicallyLargest(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( + func(a testPairElement) bool { + var negA Element + negA.Neg(&a.element) + + cmpResult := a.element.Cmp(&negA) + lResult := a.element.LexicographicallyLargest() + + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + supportAdx = true + } + +} + +func TestElementAdd(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.ToBigIntRegular(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _addGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return !c.biggerOrEqualModulus() + }, + genA, + genB, + )) + + properties.Property("Add: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Add(&a.element, &b.element) + _addGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.ToBigIntRegular(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _addGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Add failed special test values: asm and generic impl don't match") + } + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementSub(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.ToBigIntRegular(&rb) + + var c Element + c.Sub(&a.element, &r) + d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _subGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Sub(&a.element, &b.element) + + return !c.biggerOrEqualModulus() + }, + genA, + genB, + )) + + properties.Property("Sub: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Sub(&a.element, &b.element) + _subGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.ToBigIntRegular(&bBig) + + var c Element + c.Sub(&a, &b) + d.Sub(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _subGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Sub failed special test values: asm and generic impl don't match") + } + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sub failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementMul(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Mul(&a.element, &b.element) + a.element.Mul(&a.element, &b.element) + b.element.Mul(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Mul(&a.element, &b.element) + + var d, e big.Int + d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.ToBigIntRegular(&rb) + + var c Element + c.Mul(&a.element, &r) + d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Mul(&a.element, &b.element) + + return !c.biggerOrEqualModulus() + }, + genA, + genB, + )) + + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.ToBigIntRegular(&bBig) + + var c Element + c.Mul(&a, &b) + d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special test values: asm and generic impl don't match") + } + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Mul failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementDiv(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Div(&a.element, &b.element) + a.element.Div(&a.element, &b.element) + b.element.Div(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Div: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Div(&a.element, &b.element) + + var d, e big.Int + d.ModInverse(&b.bigint, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.ToBigIntRegular(&rb) + + var c Element + c.Div(&a.element, &r) + d.ModInverse(&rb, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Div(&a.element, &b.element) + + return !c.biggerOrEqualModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.ToBigIntRegular(&bBig) + + var c Element + c.Div(&a, &b) + d.ModInverse(&bBig, Modulus()) + d.Mul(&d, &aBig).Mod(&d, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Div failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementExp(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Exp(a.element, &b.bigint) + a.element.Exp(a.element, &b.bigint) + b.element.Exp(d, &b.bigint) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Exp(a.element, &b.bigint) + + var d, e big.Int + d.Exp(&a.bigint, &b.bigint, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.ToBigIntRegular(&rb) + + var c Element + c.Exp(a.element, &rb) + d.Exp(&a.bigint, &rb, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Exp(a.element, &b.bigint) + + return !c.biggerOrEqualModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.ToBigIntRegular(&bBig) + + var c Element + c.Exp(a, &bBig) + d.Exp(&aBig, &bBig, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Exp failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + t.Log("disabling ADX") + supportAdx = false + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementSquare(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Square: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Square(&a.element) + a.element.Square(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Square: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + + var d, e big.Int + d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) + + return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + return !c.biggerOrEqualModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + var c Element + c.Square(&a) + + var d, e big.Int + d.Mul(&aBig, &aBig).Mod(&d, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Square failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + supportAdx = false + t.Log("disabling ADX") + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementInverse(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Inverse(&a.element) + a.element.Inverse(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + + var d, e big.Int + d.ModInverse(&a.bigint, Modulus()) + + return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + return !c.biggerOrEqualModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + var c Element + c.Inverse(&a) + + var d, e big.Int + d.ModInverse(&aBig, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Inverse failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + supportAdx = false + t.Log("disabling ADX") + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementSqrt(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + b := a.element + + b.Sqrt(&a.element) + a.element.Sqrt(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + + var d, e big.Int + d.ModSqrt(&a.bigint, Modulus()) + + return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + return !c.biggerOrEqualModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + var c Element + c.Sqrt(&a) + + var d, e big.Int + d.ModSqrt(&aBig, Modulus()) + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sqrt failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + supportAdx = false + t.Log("disabling ADX") + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementDouble(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return !c.biggerOrEqualModulus() + }, + genA, + )) + + properties.Property("Double: assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + var c, d Element + c.Double(&a.element) + _doubleGeneric(&d, &a.element) + return c.Equal(&d) + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _doubleGeneric(&cGeneric, &a) + if !cGeneric.Equal(&c) { + t.Fatal("Double failed special test values: asm and generic impl don't match") + } + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + supportAdx = false + t.Log("disabling ADX") + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementNeg(t *testing.T) { + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + return !c.biggerOrEqualModulus() + }, + genA, + )) + + properties.Property("Neg: assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + var c, d Element + c.Neg(&a.element) + _negGeneric(&d, &a.element) + return c.Equal(&d) + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.ToBigIntRegular(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _negGeneric(&cGeneric, &a) + if !cGeneric.Equal(&c) { + t.Fatal("Neg failed special test values: asm and generic impl don't match") + } + + if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + // if we have ADX instruction enabled, test both path in assembly + if supportAdx { + supportAdx = false + t.Log("disabling ADX") + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + supportAdx = true + } +} + +func TestElementHalve(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + var twoInv Element + twoInv.SetUint64(2) + twoInv.Inverse(&twoInv) + + properties.Property("z.Halve must match z / 2", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.Halve() + d.Mul(&d, &twoInv) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.FromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.FromMont().ToMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +type testPairElement struct { + element Element + bigint big.Int +} + +func (z *Element) biggerOrEqualModulus() bool { + if z[3] > qElement[3] { + return true + } + if z[3] < qElement[3] { + return false + } + + if z[2] > qElement[2] { + return true + } + if z[2] < qElement[2] { + return false + } + + if z[1] > qElement[1] { + return true + } + if z[1] < qElement[1] { + return false + } + + return z[0] >= qElement[0] +} + +func gen() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g testPairElement + + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + + for g.element.biggerOrEqualModulus() { + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + } + + g.element.ToBigIntRegular(&g.bigint) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + + genRandomFq := func() Element { + var g Element + + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + + for g.biggerOrEqualModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } + + return g + } + a := genRandomFq() + + var carry uint64 + a[0], carry = bits.Add64(a[0], qElement[0], carry) + a[1], carry = bits.Add64(a[1], qElement[1], carry) + a[2], carry = bits.Add64(a[2], qElement[2], carry) + a[3], _ = bits.Add64(a[3], qElement[3], carry) + + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult } } diff --git a/ff/util.go b/ff/util.go deleted file mode 100644 index fc3c5a8..0000000 --- a/ff/util.go +++ /dev/null @@ -1,6 +0,0 @@ -package ff - -// NewElement returns a new empty *Element -func NewElement() *Element { - return &Element{} -} diff --git a/go.mod b/go.mod index 23ca3a3..bd872a3 100644 --- a/go.mod +++ b/go.mod @@ -6,4 +6,9 @@ require ( github.com/dchest/blake512 v1.0.0 github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 + golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/leanovate/gopter v0.2.9 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/go.sum b/go.sum index 880feaa..8d5806f 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dchest/blake512 v1.0.0 h1:oDFEQFIqFSeuA34xLtXZ/rWxCXdSjirjzPhey5EUvmA= github.com/dchest/blake512 v1.0.0/go.mod h1:FV1x7xPPLWukZlpDpWQ88rF/SFwZ5qbskrzhLMB92JI= +github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= +github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2Sh+Jxxv8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -14,6 +16,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/poseidon/poseidon.go b/poseidon/poseidon.go index 8ee54b7..2cd81e0 100644 --- a/poseidon/poseidon.go +++ b/poseidon/poseidon.go @@ -20,7 +20,7 @@ func zero() *ff.Element { // exp5 performs x^5 mod p // https://eprint.iacr.org/2019/458.pdf page 8 func exp5(a *ff.Element) { - a.Exp(*a, 5) //nolint:gomnd + a.Exp(*a, big.NewInt(5)) //nolint:gomnd } // exp5state perform exp5 for whole state