Skip to content

Commit

Permalink
Merge pull request #1257 from o1-labs/feature/glv
Browse files Browse the repository at this point in the history
Make ECDSA more efficient with GLV
  • Loading branch information
mitschabaude authored Dec 19, 2023
2 parents b820dea + 43cf2a0 commit 3480c54
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 49 deletions.
2 changes: 1 addition & 1 deletion run
Original file line number Diff line number Diff line change
@@ -1 +1 @@
node --enable-source-maps src/build/run.js $@
node --enable-source-maps --stack-trace-limit=1000 src/build/run.js $@
2 changes: 1 addition & 1 deletion src/bindings
10 changes: 9 additions & 1 deletion src/lib/gadgets/basic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@ import type { Field, VarField } from '../field.js';
import { existsOne, toVar } from './common.js';
import { Gates } from '../gates.js';
import { TupleN } from '../util/types.js';
import { Snarky } from '../../snarky.js';

export { arrayGet, assertOneOf };
export { assertBoolean, arrayGet, assertOneOf };

/**
* Assert that x is either 0 or 1.
*/
function assertBoolean(x: VarField) {
Snarky.field.assertBoolean(x.value);
}

// TODO: create constant versions of these and expose on Gadgets

Expand Down
22 changes: 17 additions & 5 deletions src/lib/gadgets/ecdsa.unit-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ import {
Second,
bool,
equivalentProvable,
fromRandom,
map,
oneOf,
record,
} from '../testing/equivalent.js';
import { Bool } from '../bool.js';
import { Random } from '../testing/random.js';

// quick tests
const Secp256k1 = createCurveAffine(CurveParams.Secp256k1);
Expand Down Expand Up @@ -53,11 +55,21 @@ for (let Curve of curves) {
}
);

// with 30% prob, test the version without GLV even if the curve supports it
let noGlv = fromRandom(Random.map(Random.fraction(), (f) => f < 0.3));

// provable method we want to test
const verify = (s: Second<typeof signature>) => {
const verify = (s: Second<typeof signature>, noGlv: boolean) => {
// invalid public key can lead to either a failing constraint, or verify() returning false
EllipticCurve.assertOnCurve(s.publicKey, Curve);
return Ecdsa.verify(Curve, s.signature, s.msg, s.publicKey);

let hasGlv = Curve.hasEndomorphism;
if (noGlv) Curve.hasEndomorphism = false; // hack to force non-GLV version
try {
return Ecdsa.verify(Curve, s.signature, s.msg, s.publicKey);
} finally {
Curve.hasEndomorphism = hasGlv;
}
};

// input validation equivalent to the one implicit in verify()
Expand All @@ -72,22 +84,22 @@ for (let Curve of curves) {
};

// positive test
equivalentProvable({ from: [signature], to: bool, verbose: true })(
equivalentProvable({ from: [signature, noGlv], to: bool, verbose: true })(
() => true,
verify,
`${Curve.name}: verifies`
);

// negative test
equivalentProvable({ from: [badSignature], to: bool, verbose: true })(
equivalentProvable({ from: [badSignature, noGlv], to: bool, verbose: true })(
(s) => checkInputs(s) && false,
verify,
`${Curve.name}: fails`
);

// test against constant implementation, with both invalid and valid signatures
equivalentProvable({
from: [oneOf(signature, badSignature)],
from: [oneOf(signature, badSignature), noGlv],
to: bool,
verbose: true,
})(
Expand Down
191 changes: 160 additions & 31 deletions src/lib/gadgets/elliptic-curve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import { inverse, mod } from '../../bindings/crypto/finite_field.js';
import { Field } from '../field.js';
import { Provable } from '../provable.js';
import { assert, exists } from './common.js';
import { Field3, ForeignField, split } from './foreign-field.js';
import { l } from './range-check.js';
import { Field3, ForeignField, split, weakBound } from './foreign-field.js';
import { l, l2, multiRangeCheck } from './range-check.js';
import { sha256 } from 'js-sha256';
import {
bigIntToBits,
Expand All @@ -18,7 +18,7 @@ import {
import { Bool } from '../bool.js';
import { provable } from '../circuit_value.js';
import { assertPositiveInteger } from '../../bindings/crypto/non-negative.js';
import { arrayGet } from './basic.js';
import { arrayGet, assertBoolean } from './basic.js';

// external API
export { EllipticCurve, Point, Ecdsa };
Expand Down Expand Up @@ -275,6 +275,32 @@ function verifyEcdsa(
return Provable.equal(Field3.provable, Rx, r);
}

/**
* Bigint implementation of ECDSA verify
*/
function verifyEcdsaConstant(
Curve: CurveAffine,
{ r, s }: Ecdsa.signature,
msgHash: bigint,
publicKey: point
) {
let pk = Curve.from(publicKey);
if (Curve.equal(pk, Curve.zero)) return false;
if (Curve.hasCofactor && !Curve.isInSubgroup(pk)) return false;
if (r < 1n || r >= Curve.order) return false;
if (s < 1n || s >= Curve.order) return false;

let sInv = Curve.Scalar.inverse(s);
assert(sInv !== undefined);
let u1 = Curve.Scalar.mul(msgHash, sInv);
let u2 = Curve.Scalar.mul(r, sInv);

let R = Curve.add(Curve.scale(Curve.one, u1), Curve.scale(pk, u2));
if (Curve.equal(R, Curve.zero)) return false;

return Curve.Scalar.equal(R.x, r);
}

/**
* Multi-scalar multiplication:
*
Expand All @@ -293,7 +319,6 @@ function verifyEcdsa(
*
* TODO: could use lookups for picking precomputed multiples, instead of O(2^c) provable switch
* TODO: custom bit representation for the scalar that avoids 0, to get rid of the degenerate addition case
* TODO: glv trick which cuts down ec doubles by half by splitting s*P = s0*P + s1*endo(P) with s0, s1 in [0, 2^128)
*/
function multiScalarMul(
scalars: Field3[],
Expand All @@ -309,6 +334,7 @@ function multiScalarMul(
let n = points.length;
assert(scalars.length === n, 'Points and scalars lengths must match');
assertPositiveInteger(n, 'Expected at least 1 point and scalar');
let useGlv = Curve.hasEndomorphism;

// constant case
if (scalars.every(Field3.isConstant) && points.every(Point.isConstant)) {
Expand All @@ -317,7 +343,11 @@ function multiScalarMul(
let P = points.map(Point.toBigint);
let sum = Curve.zero;
for (let i = 0; i < n; i++) {
sum = Curve.add(sum, Curve.scale(P[i], s[i]));
if (useGlv) {
sum = Curve.add(sum, Curve.Endo.scale(P[i], s[i]));
} else {
sum = Curve.add(sum, Curve.scale(P[i], s[i]));
}
}
if (mode === 'assert-zero') {
assert(sum.infinity, 'scalar multiplication: expected zero result');
Expand All @@ -333,16 +363,60 @@ function multiScalarMul(
getPointTable(Curve, P, windowSizes[i], tableConfigs[i]?.multiples)
);

let maxBits = Curve.Scalar.sizeInBits;

if (useGlv) {
maxBits = Curve.Endo.decomposeMaxBits;

// decompose scalars and handle signs
let n2 = 2 * n;
let scalars2: Field3[] = Array(n2);
let points2: Point[] = Array(n2);
let windowSizes2: number[] = Array(n2);
let tables2: Point[][] = Array(n2);
let mrcStack: Field[] = [];

for (let i = 0; i < n; i++) {
let [s0, s1] = decomposeNoRangeCheck(Curve, scalars[i]);
scalars2[2 * i] = s0.abs;
scalars2[2 * i + 1] = s1.abs;

let table = tables[i];
let endoTable = table.map((P, i) => {
if (i === 0) return P;
let [phiP, betaXBound] = endomorphism(Curve, P);
mrcStack.push(betaXBound);
return phiP;
});
tables2[2 * i] = table.map((P) =>
negateIf(s0.isNegative, P, Curve.modulus)
);
tables2[2 * i + 1] = endoTable.map((P) =>
negateIf(s1.isNegative, P, Curve.modulus)
);
points2[2 * i] = tables2[2 * i][1];
points2[2 * i + 1] = tables2[2 * i + 1][1];

windowSizes2[2 * i] = windowSizes2[2 * i + 1] = windowSizes[i];
}
reduceMrcStack(mrcStack);
// from now on, everything is the same as if these were the original points and scalars
points = points2;
tables = tables2;
scalars = scalars2;
windowSizes = windowSizes2;
n = n2;
}

// slice scalars
let b = Curve.order.toString(2).length;
let scalarChunks = scalars.map((s, i) =>
slice(s, { maxBits: b, chunkSize: windowSizes[i] })
slice(s, { maxBits, chunkSize: windowSizes[i] })
);

ia ??= initialAggregator(Curve);
let sum = Point.from(ia);

for (let i = b - 1; i >= 0; i--) {
for (let i = maxBits - 1; i >= 0; i--) {
// add in multiple of each point
for (let j = 0; j < n; j++) {
let windowSize = windowSizes[j];
Expand Down Expand Up @@ -371,7 +445,7 @@ function multiScalarMul(

// the sum is now 2^(b-1)*IA + sum_i s_i*P_i
// we assert that sum != 2^(b-1)*IA, and add -2^(b-1)*IA to get our result
let iaFinal = Curve.scale(Curve.fromNonzero(ia), 1n << BigInt(b - 1));
let iaFinal = Curve.scale(Curve.fromNonzero(ia), 1n << BigInt(maxBits - 1));
let isZero = equals(sum, iaFinal, Curve);

if (mode === 'assert-nonzero') {
Expand All @@ -386,31 +460,70 @@ function multiScalarMul(
return sum;
}

/**
* Bigint implementation of ECDSA verify
*/
function verifyEcdsaConstant(
Curve: CurveAffine,
{ r, s }: Ecdsa.signature,
msgHash: bigint,
publicKey: point
) {
let pk = Curve.from(publicKey);
if (Curve.equal(pk, Curve.zero)) return false;
if (!Curve.isOnCurve(pk)) return false;
if (Curve.hasCofactor && !Curve.isInSubgroup(pk)) return false;
if (r < 1n || r >= Curve.order) return false;
if (s < 1n || s >= Curve.order) return false;
function negateIf(condition: Field, P: Point, f: bigint) {
let y = Provable.if(
Bool.Unsafe.ofField(condition),
Field3.provable,
ForeignField.negate(P.y, f),
P.y
);
return { x: P.x, y };
}

let sInv = Curve.Scalar.inverse(s);
assert(sInv !== undefined);
let u1 = Curve.Scalar.mul(msgHash, sInv);
let u2 = Curve.Scalar.mul(r, sInv);
function endomorphism(Curve: CurveAffine, P: Point) {
let beta = Field3.from(Curve.Endo.base);
let betaX = ForeignField.mul(beta, P.x, Curve.modulus);
return [{ x: betaX, y: P.y }, weakBound(betaX[2], Curve.modulus)] as const;
}

let R = Curve.add(Curve.scale(Curve.one, u1), Curve.scale(pk, u2));
if (Curve.equal(R, Curve.zero)) return false;
/**
* Decompose s = s0 + s1*lambda where s0, s1 are guaranteed to be small
*
* Note: This assumes that s0 and s1 are range-checked externally; in scalar multiplication this happens because they are split into chunks.
*/
function decomposeNoRangeCheck(Curve: CurveAffine, s: Field3) {
assert(
Curve.Endo.decomposeMaxBits < l2,
'decomposed scalars assumed to be < 2*88 bits'
);
// witness s0, s1
let witnesses = exists(6, () => {
let [s0, s1] = Curve.Endo.decompose(Field3.toBigint(s));
let [s00, s01] = split(s0.abs);
let [s10, s11] = split(s1.abs);
// prettier-ignore
return [
s0.isNegative ? 1n : 0n, s00, s01,
s1.isNegative ? 1n : 0n, s10, s11,
];
});
let [s0Negative, s00, s01, s1Negative, s10, s11] = witnesses;
// we can hard-code highest limb to zero
// (in theory this would allow us to hard-code the high quotient limb to zero in the ffmul below, and save 2 RCs.. but not worth it)
let s0: Field3 = [s00, s01, Field.from(0n)];
let s1: Field3 = [s10, s11, Field.from(0n)];
assertBoolean(s0Negative);
assertBoolean(s1Negative);

// prove that s1*lambda = s - s0
let lambda = Provable.if(
Bool.Unsafe.ofField(s1Negative),
Field3.provable,
Field3.from(Curve.Scalar.negate(Curve.Endo.scalar)),
Field3.from(Curve.Endo.scalar)
);
let rhs = Provable.if(
Bool.Unsafe.ofField(s0Negative),
Field3.provable,
ForeignField.Sum(s).add(s0).finish(Curve.order),
ForeignField.Sum(s).sub(s0).finish(Curve.order)
);
ForeignField.assertMul(s1, lambda, rhs, Curve.order);

return Curve.Scalar.equal(R.x, r);
return [
{ isNegative: s0Negative, abs: s0 },
{ isNegative: s1Negative, abs: s1 },
] as const;
}

/**
Expand Down Expand Up @@ -684,3 +797,19 @@ const Ecdsa = {
verify: verifyEcdsa,
Signature: EcdsaSignature,
};

// MRC stack

function reduceMrcStack(xs: Field[]) {
let n = xs.length;
let nRemaining = n % 3;
let nFull = (n - nRemaining) / 3;
for (let i = 0; i < nFull; i++) {
multiRangeCheck([xs[3 * i], xs[3 * i + 1], xs[3 * i + 2]]);
}
let remaining: Field3 = [Field.from(0n), Field.from(0n), Field.from(0n)];
for (let i = 0; i < nRemaining; i++) {
remaining[i] = xs[3 * nFull + i];
}
multiRangeCheck(remaining);
}
Loading

0 comments on commit 3480c54

Please sign in to comment.