diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f136cec21..3904cf0a1b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - `Gadgets.rangeCheck64()`, new provable method to do efficient 64-bit range checks using lookup tables https://github.com/o1-labs/o1js/pull/1181 +- Added bitwise `XOR` operation support for native field elements. https://github.com/o1-labs/o1js/pull/1177 + - `Proof.dummy()` to create dummy proofs https://github.com/o1-labs/o1js/pull/1188 - You can use this to write ZkPrograms that handle the base case and the inductive case in the same method. diff --git a/src/bindings b/src/bindings index 6db2442764..dbe878db43 160000 --- a/src/bindings +++ b/src/bindings @@ -1 +1 @@ -Subproject commit 6db2442764e48977dc25145b5b048f272e3995f5 +Subproject commit dbe878db43d256ac3085f248551b05b75ffecfda diff --git a/src/examples/gadgets.ts b/src/examples/gadgets.ts new file mode 100644 index 0000000000..2048b294fa --- /dev/null +++ b/src/examples/gadgets.ts @@ -0,0 +1,31 @@ +import { Field, Provable, Gadgets, Experimental } from 'o1js'; + +const XOR = Experimental.ZkProgram({ + methods: { + baseCase: { + privateInputs: [], + method: () => { + let a = Provable.witness(Field, () => Field(5)); + let b = Provable.witness(Field, () => Field(2)); + let actual = Gadgets.xor(a, b, 4); + let expected = Field(7); + actual.assertEquals(expected); + }, + }, + }, +}); + +console.log('compiling..'); + +console.time('compile'); +await XOR.compile(); +console.timeEnd('compile'); + +console.log('proving..'); + +console.time('prove'); +let proof = await XOR.baseCase(); +console.timeEnd('prove'); + +if (!(await XOR.verify(proof))) throw Error('Invalid proof'); +else console.log('proof valid'); diff --git a/src/examples/primitive_constraint_system.ts b/src/examples/primitive_constraint_system.ts index 17735faa04..1ef5a5f87c 100644 --- a/src/examples/primitive_constraint_system.ts +++ b/src/examples/primitive_constraint_system.ts @@ -1,4 +1,4 @@ -import { Field, Group, Poseidon, Provable, Scalar } from 'o1js'; +import { Field, Group, Gadgets, Provable, Scalar } from 'o1js'; function mock(obj: { [K: string]: (...args: any) => void }, name: string) { let methodKeys = Object.keys(obj); @@ -63,4 +63,16 @@ const GroupMock = { }, }; +const BitwiseMock = { + xor() { + let a = Provable.witness(Field, () => new Field(5n)); + let b = Provable.witness(Field, () => new Field(5n)); + Gadgets.xor(a, b, 16); + Gadgets.xor(a, b, 32); + Gadgets.xor(a, b, 48); + Gadgets.xor(a, b, 64); + }, +}; + export const GroupCS = mock(GroupMock, 'Group Primitive'); +export const BitwiseCS = mock(BitwiseMock, 'Bitwise Primitive'); diff --git a/src/examples/regression_test.json b/src/examples/regression_test.json index e492f2d163..817796fa3a 100644 --- a/src/examples/regression_test.json +++ b/src/examples/regression_test.json @@ -164,5 +164,18 @@ "data": "", "hash": "" } + }, + "Bitwise Primitive": { + "digest": "Bitwise Primitive", + "methods": { + "xor": { + "rows": 15, + "digest": "b3595a9cc9562d4f4a3a397b6de44971" + } + }, + "verificationKey": { + "data": "", + "hash": "" + } } } \ No newline at end of file diff --git a/src/examples/vk_regression.ts b/src/examples/vk_regression.ts index 6dbcd31373..a76b41aef4 100644 --- a/src/examples/vk_regression.ts +++ b/src/examples/vk_regression.ts @@ -3,7 +3,7 @@ import { Voting_ } from './zkapps/voting/voting.js'; import { Membership_ } from './zkapps/voting/membership.js'; import { HelloWorld } from './zkapps/hello_world/hello_world.js'; import { TokenContract, createDex } from './zkapps/dex/dex.js'; -import { GroupCS } from './primitive_constraint_system.js'; +import { GroupCS, BitwiseCS } from './primitive_constraint_system.js'; // toggle this for quick iteration when debugging vk regressions const skipVerificationKeys = false; @@ -37,6 +37,7 @@ const ConstraintSystems: MinimumConstraintSystem[] = [ TokenContract, createDex().Dex, GroupCS, + BitwiseCS, ]; let filePath = jsonPath ? jsonPath : './src/examples/regression_test.json'; diff --git a/src/lib/field.ts b/src/lib/field.ts index 29193d64db..5d810ef20f 100644 --- a/src/lib/field.ts +++ b/src/lib/field.ts @@ -19,6 +19,7 @@ export { withMessage, readVarMessage, toConstantField, + toFp, }; type FieldConst = [0, bigint]; @@ -1245,13 +1246,24 @@ class Field { /** * **Warning**: This function is mainly for internal use. Normally it is not intended to be used by a zkApp developer. * - * As all {@link Field} elements have 31 bits, this function returns 31. + * As all {@link Field} elements have 32 bytes, this function returns 32. * - * @return The size of a {@link Field} element - 31. + * @return The size of a {@link Field} element - 32. */ static sizeInBytes() { return Fp.sizeInBytes(); } + + /** + * **Warning**: This function is mainly for internal use. Normally it is not intended to be used by a zkApp developer. + * + * As all {@link Field} elements have 255 bits, this function returns 255. + * + * @return The size of a {@link Field} element in bits - 255. + */ + static sizeInBits() { + return Fp.sizeInBits; + } } const FieldBinable = defineBinable({ diff --git a/src/lib/gadgets/bitwise.ts b/src/lib/gadgets/bitwise.ts new file mode 100644 index 0000000000..53e6c432a4 --- /dev/null +++ b/src/lib/gadgets/bitwise.ts @@ -0,0 +1,131 @@ +import { Provable } from '../provable.js'; +import { Field as Fp } from '../../provable/field-bigint.js'; +import { Field } from '../field.js'; +import * as Gates from '../gates.js'; + +export { xor }; + +function xor(a: Field, b: Field, length: number) { + // check that both input lengths are positive + assert(length > 0, `Input lengths need to be positive values.`); + + // check that length does not exceed maximum field size in bits + assert( + length <= Field.sizeInBits(), + `Length ${length} exceeds maximum of ${Field.sizeInBits()} bits.` + ); + + // obtain pad length until the length is a multiple of 16 for n-bit length lookup table + let padLength = Math.ceil(length / 16) * 16; + + // handle constant case + if (a.isConstant() && b.isConstant()) { + let max = 1n << BigInt(padLength); + + assert( + a.toBigInt() < max, + `${a.toBigInt()} does not fit into ${padLength} bits` + ); + + assert( + b.toBigInt() < max, + `${b.toBigInt()} does not fit into ${padLength} bits` + ); + + return new Field(Fp.xor(a.toBigInt(), b.toBigInt())); + } + + // calculate expected xor output + let outputXor = Provable.witness( + Field, + () => new Field(Fp.xor(a.toBigInt(), b.toBigInt())) + ); + + // builds the xor gadget chain + buildXor(a, b, outputXor, padLength); + + // return the result of the xor operation + return outputXor; +} + +// builds a xor chain +function buildXor( + a: Field, + b: Field, + expectedOutput: Field, + padLength: number +) { + // construct the chain of XORs until padLength is 0 + while (padLength !== 0) { + // slices the inputs into 4x 4bit-sized chunks + // slices of a + let in1_0 = witnessSlices(a, 0, 4); + let in1_1 = witnessSlices(a, 4, 4); + let in1_2 = witnessSlices(a, 8, 4); + let in1_3 = witnessSlices(a, 12, 4); + + // slices of b + let in2_0 = witnessSlices(b, 0, 4); + let in2_1 = witnessSlices(b, 4, 4); + let in2_2 = witnessSlices(b, 8, 4); + let in2_3 = witnessSlices(b, 12, 4); + + // slices of expected output + let out0 = witnessSlices(expectedOutput, 0, 4); + let out1 = witnessSlices(expectedOutput, 4, 4); + let out2 = witnessSlices(expectedOutput, 8, 4); + let out3 = witnessSlices(expectedOutput, 12, 4); + + // assert that the xor of the slices is correct, 16 bit at a time + Gates.xor( + a, + b, + expectedOutput, + in1_0, + in1_1, + in1_2, + in1_3, + in2_0, + in2_1, + in2_2, + in2_3, + out0, + out1, + out2, + out3 + ); + + // update the values for the next loop iteration + a = witnessNextValue(a); + b = witnessNextValue(b); + expectedOutput = witnessNextValue(expectedOutput); + padLength = padLength - 16; + } + + // inputs are zero and length is zero, add the zero check - we reached the end of our chain + Gates.zero(a, b, expectedOutput); + + let zero = new Field(0); + zero.assertEquals(a); + zero.assertEquals(b); + zero.assertEquals(expectedOutput); +} + +function assert(stmt: boolean, message?: string) { + if (!stmt) { + throw Error(message ?? 'Assertion failed'); + } +} + +function witnessSlices(f: Field, start: number, length: number) { + if (length <= 0) throw Error('Length must be a positive number'); + + return Provable.witness(Field, () => { + let n = f.toBigInt(); + return new Field((n >> BigInt(start)) & ((1n << BigInt(length)) - 1n)); + }); +} + +function witnessNextValue(current: Field) { + return Provable.witness(Field, () => new Field(current.toBigInt() >> 16n)); +} diff --git a/src/lib/gadgets/bitwise.unit-test.ts b/src/lib/gadgets/bitwise.unit-test.ts new file mode 100644 index 0000000000..4b92ce7aff --- /dev/null +++ b/src/lib/gadgets/bitwise.unit-test.ts @@ -0,0 +1,58 @@ +import { ZkProgram } from '../proof_system.js'; +import { + Spec, + equivalent, + equivalentAsync, + field, + fieldWithRng, +} from '../testing/equivalent.js'; +import { Fp, mod } from '../../bindings/crypto/finite_field.js'; +import { Field } from '../field.js'; +import { Gadgets } from './gadgets.js'; +import { Random } from '../testing/property.js'; + +let Bitwise = ZkProgram({ + publicOutput: Field, + methods: { + xor: { + privateInputs: [Field, Field], + method(a: Field, b: Field) { + return Gadgets.xor(a, b, 64); + }, + }, + }, +}); + +await Bitwise.compile(); + +let uint = (length: number) => fieldWithRng(Random.biguint(length)); + +[2, 4, 8, 16, 32, 64, 128].forEach((length) => { + equivalent({ from: [uint(length), uint(length)], to: field })( + Fp.xor, + (x, y) => Gadgets.xor(x, y, length) + ); +}); + +let maybeUint64: Spec = { + ...field, + rng: Random.map(Random.oneOf(Random.uint64, Random.uint64.invalid), (x) => + mod(x, Field.ORDER) + ), +}; + +// do a couple of proofs +await equivalentAsync( + { from: [maybeUint64, maybeUint64], to: field }, + { runs: 3 } +)( + (x, y) => { + if (x >= 2n ** 64n || y >= 2n ** 64n) + throw Error('Does not fit into 64 bits'); + return Fp.xor(x, y); + }, + async (x, y) => { + let proof = await Bitwise.xor(x, y); + return proof.publicOutput; + } +); diff --git a/src/lib/gadgets/gadgets.ts b/src/lib/gadgets/gadgets.ts index b4c799b2cb..5f993a9727 100644 --- a/src/lib/gadgets/gadgets.ts +++ b/src/lib/gadgets/gadgets.ts @@ -2,6 +2,7 @@ * Wrapper file for various gadgets, with a namespace and doccomments. */ import { rangeCheck64 } from './range-check.js'; +import { xor } from './bitwise.js'; import { Field } from '../core.js'; export { Gadgets }; @@ -33,4 +34,31 @@ const Gadgets = { rangeCheck64(x: Field) { return rangeCheck64(x); }, + + /** + * Bitwise XOR gadget on {@link Field} elements. Equivalent to the [bitwise XOR `^` operator in JavaScript](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Bitwise_XOR). + * A XOR gate works by comparing two bits and returning `1` if two bits differ, and `0` if two bits are equal. + * + * This gadget builds a chain of XOR gates recursively. Each XOR gate can verify 16 bit at most. If your input elements exceed 16 bit, another XOR gate will be added to the chain. + * + * The `length` parameter lets you define how many bits should be compared. `length` is rounded to the nearest multiple of 16, `paddedLength = ceil(length / 16) * 16`, and both input values are constrained to fit into `paddedLength` bits. The output is guaranteed to have at most `paddedLength` bits as well. + * + * **Note:** Specifying a larger `length` parameter adds additional constraints. + * It is also important to mention that specifying a smaller `length` allows the verifier to infer the length of the original input data (e.g. smaller than 16 bit if only one XOR gate has been used). + * A zkApp developer should consider these implications when choosing the `length` parameter and carefully weigh the trade-off between increased amount of constraints and security. + * + * **Note:** Both {@link Field} elements need to fit into `2^paddedLength - 1`. Otherwise, an error is thrown and no proof can be generated.. + * For example, with `length = 2` (`paddedLength = 16`), `xor()` will fail for any input that is larger than `2**16`. + * + * ```typescript + * let a = Field(5); // ... 000101 + * let b = Field(3); // ... 000011 + * + * let c = xor(a, b, 2); // ... 000110 + * c.assertEquals(6); + * ``` + */ + xor(a: Field, b: Field, length: number) { + return xor(a, b, length); + }, }; diff --git a/src/lib/gadgets/gadgets.unit-test.ts b/src/lib/gadgets/range-check.unit-test.ts similarity index 93% rename from src/lib/gadgets/gadgets.unit-test.ts rename to src/lib/gadgets/range-check.unit-test.ts index 13a44a059d..669f811174 100644 --- a/src/lib/gadgets/gadgets.unit-test.ts +++ b/src/lib/gadgets/range-check.unit-test.ts @@ -35,7 +35,7 @@ let maybeUint64: Spec = { // do a couple of proofs // TODO: we use this as a test because there's no way to check custom gates quickly :( -equivalentAsync({ from: [maybeUint64], to: boolean }, { runs: 3 })( +await equivalentAsync({ from: [maybeUint64], to: boolean }, { runs: 3 })( (x) => { if (x >= 1n << 64n) throw Error('expected 64 bits'); return true; diff --git a/src/lib/gates.ts b/src/lib/gates.ts index 503c4d2c6a..1a5a9a3c00 100644 --- a/src/lib/gates.ts +++ b/src/lib/gates.ts @@ -1,7 +1,7 @@ import { Snarky } from '../snarky.js'; import { FieldVar, FieldConst, type Field } from './field.js'; -export { rangeCheck64 }; +export { rangeCheck64, xor, zero }; /** * Asserts that x is at most 64 bits @@ -42,6 +42,49 @@ function rangeCheck64(x: Field) { ); } +/** + * Asserts that 16 bit limbs of input two elements are the correct XOR output + */ +function xor( + input1: Field, + input2: Field, + outputXor: Field, + in1_0: Field, + in1_1: Field, + in1_2: Field, + in1_3: Field, + in2_0: Field, + in2_1: Field, + in2_2: Field, + in2_3: Field, + out0: Field, + out1: Field, + out2: Field, + out3: Field +) { + Snarky.gates.xor( + input1.value, + input2.value, + outputXor.value, + in1_0.value, + in1_1.value, + in1_2.value, + in1_3.value, + in2_0.value, + in2_1.value, + in2_2.value, + in2_3.value, + out0.value, + out1.value, + out2.value, + out3.value + ); +} + +function zero(a: Field, b: Field, c: Field) { + Snarky.gates.zero(a.value, b.value, c.value); +} + function getBits(x: bigint, start: number, length: number) { return FieldConst.fromBigint( (x >> BigInt(start)) & ((1n << BigInt(length)) - 1n) diff --git a/src/lib/int.ts b/src/lib/int.ts index f4143a9a5b..a48118fa2c 100644 --- a/src/lib/int.ts +++ b/src/lib/int.ts @@ -458,6 +458,7 @@ class UInt32 extends CircuitValue { static MAXINT() { return new UInt32(Field((1n << 32n) - 1n)); } + /** * Integer division with remainder. * diff --git a/src/lib/testing/equivalent.ts b/src/lib/testing/equivalent.ts index c19748624e..22c1954a91 100644 --- a/src/lib/testing/equivalent.ts +++ b/src/lib/testing/equivalent.ts @@ -15,6 +15,7 @@ export { handleErrors, deepEqual as defaultAssertEqual, id, + fieldWithRng, }; export { field, bigintField, bool, boolean, unit }; export { Spec, ToSpec, FromSpec, SpecFromFunctions, ProvableSpec }; @@ -240,6 +241,10 @@ let boolean: Spec = { back: id, }; +function fieldWithRng(rng: Random): Spec { + return { ...field, rng }; +} + // helper to ensure two functions throw equivalent errors function handleErrors( diff --git a/src/snarky.d.ts b/src/snarky.d.ts index 7fb7f27fb8..512ac61ed1 100644 --- a/src/snarky.d.ts +++ b/src/snarky.d.ts @@ -306,6 +306,26 @@ declare const Snarky: { ], compact: FieldConst ): void; + + xor( + in1: FieldVar, + in2: FieldVar, + out: FieldVar, + in1_0: FieldVar, + in1_1: FieldVar, + in1_2: FieldVar, + in1_3: FieldVar, + in2_0: FieldVar, + in2_1: FieldVar, + in2_2: FieldVar, + in2_3: FieldVar, + out_0: FieldVar, + out_1: FieldVar, + out_2: FieldVar, + out_3: FieldVar + ): void; + + zero(in1: FieldVar, in2: FieldVar, out: FieldVar): void; }; bool: {