From 2cdcc421510dd21d424979be3f972e08c210235e Mon Sep 17 00:00:00 2001 From: Ryan Kim Date: Fri, 12 Jul 2024 16:08:05 +0900 Subject: [PATCH] feat(halo2_proofs): add tachyon halo2 rust binding --- Cargo.lock | 126 +- halo2_proofs/Cargo.toml | 5 + halo2_proofs/build.rs | 48 + halo2_proofs/examples/simple-lookup.rs | 148 ++ halo2_proofs/include/bn254_blake2b_writer.h | 38 + halo2_proofs/include/bn254_evals.h | 42 + halo2_proofs/include/bn254_poly.h | 34 + halo2_proofs/include/bn254_poseidon_writer.h | 35 + halo2_proofs/include/bn254_prover.h | 68 + halo2_proofs/include/bn254_proving_key.h | 50 + halo2_proofs/include/bn254_rational_evals.h | 44 + .../include/bn254_rational_evals_view.h | 34 + halo2_proofs/include/bn254_sha256_writer.h | 38 + .../bn254_snark_verifier_poseidon_writer.h | 38 + halo2_proofs/include/cha_cha20_rng.h | 40 + halo2_proofs/include/xor_shift_rng.h | 40 + halo2_proofs/src/bn254.rs | 1225 +++++++++++++++++ halo2_proofs/src/bn254_blake2b_writer.cc | 45 + halo2_proofs/src/bn254_evals.cc | 27 + halo2_proofs/src/bn254_poly.cc | 9 + halo2_proofs/src/bn254_poseidon_writer.cc | 43 + halo2_proofs/src/bn254_prover.cc | 210 +++ halo2_proofs/src/bn254_proving_key.cc | 112 ++ halo2_proofs/src/bn254_rational_evals.cc | 28 + halo2_proofs/src/bn254_rational_evals_view.cc | 35 + halo2_proofs/src/bn254_sha256_writer.cc | 44 + .../bn254_snark_verifier_poseidon_writer.cc | 44 + halo2_proofs/src/cha_cha20_rng.cc | 44 + halo2_proofs/src/cha_cha20_rng.rs | 121 ++ halo2_proofs/src/consts.rs | 27 + halo2_proofs/src/lib.rs | 6 + halo2_proofs/src/plonk.rs | 2 + halo2_proofs/src/plonk/tachyon.rs | 572 ++++++++ halo2_proofs/src/rng.rs | 6 + halo2_proofs/src/rust_vec.h | 50 + halo2_proofs/src/xor_shift_rng.cc | 44 + halo2_proofs/src/xor_shift_rng.rs | 111 ++ 37 files changed, 3614 insertions(+), 19 deletions(-) create mode 100644 halo2_proofs/build.rs create mode 100644 halo2_proofs/examples/simple-lookup.rs create mode 100644 halo2_proofs/include/bn254_blake2b_writer.h create mode 100644 halo2_proofs/include/bn254_evals.h create mode 100644 halo2_proofs/include/bn254_poly.h create mode 100644 halo2_proofs/include/bn254_poseidon_writer.h create mode 100644 halo2_proofs/include/bn254_prover.h create mode 100644 halo2_proofs/include/bn254_proving_key.h create mode 100644 halo2_proofs/include/bn254_rational_evals.h create mode 100644 halo2_proofs/include/bn254_rational_evals_view.h create mode 100644 halo2_proofs/include/bn254_sha256_writer.h create mode 100644 halo2_proofs/include/bn254_snark_verifier_poseidon_writer.h create mode 100644 halo2_proofs/include/cha_cha20_rng.h create mode 100644 halo2_proofs/include/xor_shift_rng.h create mode 100644 halo2_proofs/src/bn254.rs create mode 100644 halo2_proofs/src/bn254_blake2b_writer.cc create mode 100644 halo2_proofs/src/bn254_evals.cc create mode 100644 halo2_proofs/src/bn254_poly.cc create mode 100644 halo2_proofs/src/bn254_poseidon_writer.cc create mode 100644 halo2_proofs/src/bn254_prover.cc create mode 100644 halo2_proofs/src/bn254_proving_key.cc create mode 100644 halo2_proofs/src/bn254_rational_evals.cc create mode 100644 halo2_proofs/src/bn254_rational_evals_view.cc create mode 100644 halo2_proofs/src/bn254_sha256_writer.cc create mode 100644 halo2_proofs/src/bn254_snark_verifier_poseidon_writer.cc create mode 100644 halo2_proofs/src/cha_cha20_rng.cc create mode 100644 halo2_proofs/src/cha_cha20_rng.rs create mode 100644 halo2_proofs/src/consts.rs create mode 100644 halo2_proofs/src/plonk/tachyon.rs create mode 100644 halo2_proofs/src/rng.rs create mode 100644 halo2_proofs/src/rust_vec.h create mode 100644 halo2_proofs/src/xor_shift_rng.cc create mode 100644 halo2_proofs/src/xor_shift_rng.rs diff --git a/Cargo.lock b/Cargo.lock index ad7523ba..8ca502b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -220,9 +220,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.73" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" +checksum = "066fce287b1d4eafef758e89e09d724a24808a9196fe9756b8ca90e86d0719a2" [[package]] name = "cfg-if" @@ -256,6 +256,16 @@ dependencies = [ "cc", ] +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "color_quant" version = "1.1.0" @@ -492,6 +502,50 @@ dependencies = [ "memchr", ] +[[package]] +name = "cxx" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.70", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.124" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.70", +] + [[package]] name = "darling" version = "0.10.2" @@ -513,7 +567,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn", + "syn 1.0.91", ] [[package]] @@ -524,7 +578,7 @@ checksum = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72" dependencies = [ "darling_core", "quote", - "syn", + "syn 1.0.91", ] [[package]] @@ -556,7 +610,7 @@ dependencies = [ "derive_builder_core", "proc-macro2", "quote", - "syn", + "syn 1.0.91", ] [[package]] @@ -568,7 +622,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn", + "syn 1.0.91", ] [[package]] @@ -810,7 +864,7 @@ checksum = "729f9bd3449d77e7831a18abfb7ba2f99ee813dfd15b8c2167c9a54ba20aa99d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.91", ] [[package]] @@ -863,6 +917,8 @@ dependencies = [ "cfg-if 0.1.10", "criterion", "crossbeam", + "cxx", + "cxx-build", "env_logger", "ff", "getrandom", @@ -1055,6 +1111,15 @@ version = "0.2.137" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" +[[package]] +name = "link-cplusplus" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d240c6f7e1ba3a28b0249f774e6a9dd0175054b52dfbb61b16eb8505c3785c9" +dependencies = [ + "cc", +] + [[package]] name = "lock_api" version = "0.4.7" @@ -1230,9 +1295,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.15.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "oorandom" @@ -1415,11 +1480,11 @@ checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" [[package]] name = "proc-macro2" -version = "1.0.37" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec757218438d5fda206afc041538b2f6d889286160d649a86a24d37e1235afd1" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ - "unicode-xid", + "unicode-ident", ] [[package]] @@ -1465,9 +1530,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.18" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1feb54ed693b93a84e14094943b84b7c4eae204c512b7ccb95ab0c66d278ad1" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" dependencies = [ "proc-macro2", ] @@ -1650,6 +1715,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "scratch" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" + [[package]] name = "semver" version = "1.0.7" @@ -1692,7 +1763,7 @@ checksum = "08597e7152fcd306f41838ed3e37be9eaeed2b61c42e2117266a554fab4662f9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.91", ] [[package]] @@ -1815,6 +1886,17 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "syn" +version = "2.0.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "tabbycat" version = "0.1.2" @@ -1881,7 +1963,7 @@ checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.91", ] [[package]] @@ -1914,7 +1996,7 @@ checksum = "11c75893af559bc8e10716548bdef5cb2b983f8e637db9d0e15126b61b484ee2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.91", ] [[package]] @@ -1950,6 +2032,12 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + [[package]] name = "unicode-width" version = "0.1.9" @@ -2021,7 +2109,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn 1.0.91", "wasm-bindgen-shared", ] @@ -2043,7 +2131,7 @@ checksum = "99ec0dc7a4756fffc231aab1b9f2f578d23cd391390ab27f952ae0c9b3ece20b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.91", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index ba76b868..0587129f 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -75,6 +75,8 @@ log = "0.4.17" # timer ark-std = { version = "0.3.0" } +# binding +cxx = "1.0" # Legacy circuit compatibility halo2_legacy_pdqsort = { version = "0.1.0", optional = true } @@ -87,6 +89,9 @@ gumdrop = "0.8" proptest = "1" rand_core = { version = "0.6", default-features = false, features = ["getrandom"] } +[build-dependencies] +cxx-build = "1.0" + [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] getrandom = { version = "0.2", features = ["js"] } diff --git a/halo2_proofs/build.rs b/halo2_proofs/build.rs new file mode 100644 index 00000000..40ca2c4f --- /dev/null +++ b/halo2_proofs/build.rs @@ -0,0 +1,48 @@ +fn main() { + let src_files = [ + "src/bn254_blake2b_writer.cc", + "src/bn254_evals.cc", + "src/bn254_poly.cc", + "src/bn254_poseidon_writer.cc", + "src/bn254_prover.cc", + "src/bn254_proving_key.cc", + "src/bn254_rational_evals.cc", + "src/bn254_rational_evals_view.cc", + "src/bn254_sha256_writer.cc", + "src/bn254_snark_verifier_poseidon_writer.cc", + "src/cha_cha20_rng.cc", + "src/xor_shift_rng.cc", + ]; + cxx_build::bridges([ + "src/bn254.rs", + "src/cha_cha20_rng.rs", + "src/xor_shift_rng.rs", + ]) + .files(src_files) + .flag_if_supported("-std=c++17") + .compile("halo2_proofs"); + + let mut dep_files = vec![ + "include/bn254_blake2b_writer.h", + "include/bn254_evals.h", + "include/bn254_poly.h", + "include/bn254_poseidon_writer.h", + "include/bn254_prover.h", + "include/bn254_proving_key.h", + "include/bn254_rational_evals.h", + "include/bn254_rational_evals_view.h", + "include/bn254_sha256_writer.h", + "include/bn254_snark_verifier_poseidon_writer.h", + "include/cha_cha20_rng.h", + "include/xor_shift_rng.h", + "src/bn254.rs", + "src/rust_vec.h", + "src/xor_shift_rng.rs", + ]; + dep_files.extend_from_slice(&src_files); + for file in dep_files { + println!("cargo:rerun-if-changed={file}"); + } + + println!("cargo:rustc-link-lib=dylib=tachyon"); +} diff --git a/halo2_proofs/examples/simple-lookup.rs b/halo2_proofs/examples/simple-lookup.rs new file mode 100644 index 00000000..d0b6d695 --- /dev/null +++ b/halo2_proofs/examples/simple-lookup.rs @@ -0,0 +1,148 @@ +use std::marker::PhantomData; + +use ff::{Field, PrimeField}; +use halo2_proofs::{ + bn254::Blake2bWrite, + circuit::{Layouter, SimpleFloorPlanner, Value}, + consts::XOR_SHIFT_SEED, + plonk::{ + create_proof, keygen_pk2, Advice, Circuit, Column, ConstraintSystem, Error, Expression, + Selector, TableColumn, + }, + poly::{ + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::ProverGWC, + }, + Rotation, + }, + transcript::{Challenge255, TranscriptWriterBuffer}, + xor_shift_rng::XORShiftRng, +}; +use halo2curves::bn256::{Bn256, G1Affine}; +use rand_core::SeedableRng; + +#[derive(Clone, Default)] +struct SimpleLookupCircuit { + _marker: PhantomData, +} + +#[derive(Clone)] +struct SimpleLookupConfig { + selector: Selector, + table: TableColumn, + advice: Column, +} + +impl Circuit for SimpleLookupCircuit { + type Config = SimpleLookupConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> SimpleLookupConfig { + let config = SimpleLookupConfig { + selector: meta.complex_selector(), + table: meta.lookup_table_column(), + advice: meta.advice_column(), + }; + + meta.lookup("lookup", |meta| { + let selector = meta.query_selector(config.selector); + let not_selector = Expression::Constant(F::ONE) - selector.clone(); + let advice = meta.query_advice(config.advice, Rotation::cur()); + vec![(selector * advice + not_selector, config.table)] + }); + + config + } + + fn synthesize( + &self, + config: SimpleLookupConfig, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_table( + || "3-bit table", + |mut table| { + for row in 0u64..(1 << 3) { + table.assign_cell( + || format!("row {}", row), + config.table, + row as usize, + || Value::known(F::from(row + 1)), + )?; + } + + Ok(()) + }, + )?; + + layouter.assign_region( + || "assign values", + |mut region| { + for offset in 0u64..(1 << 4) { + config.selector.enable(&mut region, offset as usize)?; + region.assign_advice( + || format!("offset {}", offset), + config.advice, + offset as usize, + || Value::known(F::from((offset % 8) + 1)), + )?; + } + + Ok(()) + }, + ) + } +} + +fn main() { + let vec = vec![ + 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, + ]; + let a = vec.binary_search(&1); + println!("{:?}", a); + + use halo2curves::bn256::Fr; + + env_logger::init(); + + // ANCHOR: test-circuit + // The number of rows in our circuit cannot exceed 2^k. Since our example + // circuit is very small, we can pick a very small value here. + let k = 5; + + // Instantiate the circuit with the private inputs. + let circuit = SimpleLookupCircuit:: { + _marker: PhantomData, + }; + // Arrange the public input. + let public_inputs = vec![]; + let public_inputs2 = vec![&public_inputs[..], &public_inputs[..]]; + + let s = Fr::from(2); + let params = ParamsKZG::::unsafe_setup_with_s(k, s.clone()); + let pk = keygen_pk2(¶ms, &circuit).expect("vk should not fail"); + + let rng = XORShiftRng::from_seed(XOR_SHIFT_SEED); + + let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); + + create_proof::, ProverGWC<_>, _, _, _, _>( + ¶ms, + &pk, + &[circuit.clone(), circuit.clone()], + public_inputs2.as_slice(), + rng.clone(), + &mut transcript, + ) + .expect("proof generation should not fail"); + + let proof = transcript.finalize(); + + println!("done!"); +} diff --git a/halo2_proofs/include/bn254_blake2b_writer.h b/halo2_proofs/include/bn254_blake2b_writer.h new file mode 100644 index 00000000..a2f525aa --- /dev/null +++ b/halo2_proofs/include/bn254_blake2b_writer.h @@ -0,0 +1,38 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_BLAKE2B_WRITER_H_ +#define HALO2_PROOFS_INCLUDE_BN254_BLAKE2B_WRITER_H_ + +#include +#include + +#include +#include + +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api::bn254 { + +constexpr size_t kBlake2bDigestLength = 64; +constexpr size_t kBlake2bStateLength = 216; + +class Blake2bWriter { + public: + Blake2bWriter(); + Blake2bWriter(const Blake2bWriter& other) = delete; + Blake2bWriter& operator=(const Blake2bWriter& other) = delete; + ~Blake2bWriter(); + + void update(rust::Slice data); + void finalize(std::array& result); + rust::Vec state() const; + + private: + tachyon_halo2_bn254_transcript_writer* writer_; +}; + +std::unique_ptr new_blake2b_writer(); + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_BLAKE2B_WRITER_H_ diff --git a/halo2_proofs/include/bn254_evals.h b/halo2_proofs/include/bn254_evals.h new file mode 100644 index 00000000..57b2ab60 --- /dev/null +++ b/halo2_proofs/include/bn254_evals.h @@ -0,0 +1,42 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_EVALS_H_ +#define HALO2_PROOFS_INCLUDE_BN254_EVALS_H_ + +#include + +#include +#include + +#include + +namespace tachyon::halo2_api::bn254 { + +struct Fr; + +class Evals { + public: + Evals(); + explicit Evals(tachyon_bn254_univariate_evaluations* evals) : evals_(evals) {} + Evals(const Evals& other) = delete; + Evals& operator=(const Evals& other) = delete; + ~Evals(); + + tachyon_bn254_univariate_evaluations* evals() { return evals_; } + const tachyon_bn254_univariate_evaluations* evals() const { return evals_; } + + tachyon_bn254_univariate_evaluations* release() { + return std::exchange(evals_, nullptr); + } + + size_t len() const; + void set_value(size_t idx, const Fr& value); + std::unique_ptr clone() const; + + private: + tachyon_bn254_univariate_evaluations* evals_; +}; + +std::unique_ptr zero_evals(); + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_EVALS_H_ diff --git a/halo2_proofs/include/bn254_poly.h b/halo2_proofs/include/bn254_poly.h new file mode 100644 index 00000000..804e1cbc --- /dev/null +++ b/halo2_proofs/include/bn254_poly.h @@ -0,0 +1,34 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_POLY_H_ +#define HALO2_PROOFS_INCLUDE_BN254_POLY_H_ + +#include + +#include + +namespace tachyon::halo2_api::bn254 { + +class Poly { + public: + Poly(); + explicit Poly(tachyon_bn254_univariate_dense_polynomial* poly) + : poly_(poly) {} + Poly(const Poly& other) = delete; + Poly& operator=(const Poly& other) = delete; + ~Poly(); + + tachyon_bn254_univariate_dense_polynomial* poly() { return poly_; } + const tachyon_bn254_univariate_dense_polynomial* poly() const { + return poly_; + } + + tachyon_bn254_univariate_dense_polynomial* release() { + return std::exchange(poly_, nullptr); + } + + private: + tachyon_bn254_univariate_dense_polynomial* poly_; +}; + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_POLY_H_ diff --git a/halo2_proofs/include/bn254_poseidon_writer.h b/halo2_proofs/include/bn254_poseidon_writer.h new file mode 100644 index 00000000..c7cf8259 --- /dev/null +++ b/halo2_proofs/include/bn254_poseidon_writer.h @@ -0,0 +1,35 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_POSEIDON_WRITER_H_ +#define HALO2_PROOFS_INCLUDE_BN254_POSEIDON_WRITER_H_ + +#include + +#include + +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api::bn254 { + +struct Fr; + +class PoseidonWriter { + public: + PoseidonWriter(); + PoseidonWriter(const PoseidonWriter& other) = delete; + PoseidonWriter& operator=(const PoseidonWriter& other) = delete; + ~PoseidonWriter(); + + void update(rust::Slice data); + rust::Box squeeze(); + rust::Vec state() const; + + private: + tachyon_halo2_bn254_transcript_writer* writer_; +}; + +std::unique_ptr new_poseidon_writer(); + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_POSEIDON_WRITER_H_ diff --git a/halo2_proofs/include/bn254_prover.h b/halo2_proofs/include/bn254_prover.h new file mode 100644 index 00000000..450741fe --- /dev/null +++ b/halo2_proofs/include/bn254_prover.h @@ -0,0 +1,68 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_PROVER_H_ +#define HALO2_PROOFS_INCLUDE_BN254_PROVER_H_ + +#include + +#include + +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api::bn254 { + +struct Fr; +struct G1ProjectivePoint; +struct G2AffinePoint; +struct InstanceSingle; +struct AdviceSingle; +class ProvingKey; +class Evals; +class RationalEvals; +class Poly; + +class Prover { + public: + Prover(uint8_t pcs_type, uint8_t transcript_type, uint32_t k, const Fr& s); + Prover(uint8_t pcs_type, uint8_t transcript_type, uint32_t k, + const uint8_t* params, size_t params_len); + Prover(const Prover& other) = delete; + Prover& operator=(const Prover& other) = delete; + ~Prover(); + + const tachyon_halo2_bn254_prover* prover() const { return prover_; } + + uint32_t k() const; + uint64_t n() const; + rust::Box s_g2() const; + rust::Box commit(const Poly& poly) const; + rust::Box commit_lagrange(const Evals& evals) const; + std::unique_ptr empty_evals() const; + std::unique_ptr empty_rational_evals() const; + std::unique_ptr ifft(const Evals& evals) const; + void batch_evaluate( + rust::Slice> rational_evals, + rust::Slice> evals) const; + void set_rng(uint8_t rng_type, rust::Slice state); + void set_transcript(rust::Slice state); + void set_extended_domain(const ProvingKey& pk); + void create_proof(ProvingKey& key, + rust::Slice instance_singles, + rust::Slice advice_singles, + rust::Slice challenges); + rust::Vec get_proof() const; + + private: + tachyon_halo2_bn254_prover* prover_; +}; + +std::unique_ptr new_prover(uint8_t pcs_type, uint8_t transcript_type, + uint32_t k, const Fr& s); + +std::unique_ptr new_prover_from_params( + uint8_t pcs_type, uint8_t transcript_type, uint32_t k, + rust::Slice params); + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_PROVER_H_ diff --git a/halo2_proofs/include/bn254_proving_key.h b/halo2_proofs/include/bn254_proving_key.h new file mode 100644 index 00000000..6c4f243d --- /dev/null +++ b/halo2_proofs/include/bn254_proving_key.h @@ -0,0 +1,50 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_PROVING_KEY_H_ +#define HALO2_PROOFS_INCLUDE_BN254_PROVING_KEY_H_ + +#include +#include + +#include + +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api::bn254 { + +struct Fr; +class Prover; + +class ProvingKey { + public: + explicit ProvingKey(rust::Slice pk_bytes); + ProvingKey(const ProvingKey& other) = delete; + ProvingKey& operator=(const ProvingKey& other) = delete; + ~ProvingKey(); + + const tachyon_bn254_plonk_proving_key* pk() const { return pk_; } + tachyon_bn254_plonk_proving_key* pk() { return pk_; } + + rust::Vec advice_column_phases() const; + uint32_t blinding_factors() const; + rust::Vec challenge_phases() const; + rust::Vec constants() const; + size_t num_advice_columns() const; + size_t num_challenges() const; + size_t num_instance_columns() const; + rust::Vec phases() const; + rust::Box transcript_repr(const Prover& prover); + + private: + const tachyon_bn254_plonk_verifying_key* GetVerifyingKey() const; + const tachyon_bn254_plonk_constraint_system* GetConstraintSystem() const; + + tachyon_bn254_plonk_proving_key* pk_; +}; + +std::unique_ptr new_proving_key( + rust::Slice pk_bytes); + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_PROVING_KEY_H_ diff --git a/halo2_proofs/include/bn254_rational_evals.h b/halo2_proofs/include/bn254_rational_evals.h new file mode 100644 index 00000000..889b197b --- /dev/null +++ b/halo2_proofs/include/bn254_rational_evals.h @@ -0,0 +1,44 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_RATIONAL_EVALS_H_ +#define HALO2_PROOFS_INCLUDE_BN254_RATIONAL_EVALS_H_ + +#include + +#include +#include + +#include + +namespace tachyon::halo2_api::bn254 { + +struct Fr; +class RationalEvalsView; + +class RationalEvals { + public: + RationalEvals(); + explicit RationalEvals(tachyon_bn254_univariate_rational_evaluations* evals) + : evals_(evals) {} + RationalEvals(const RationalEvals& other) = delete; + RationalEvals& operator=(const RationalEvals& other) = delete; + ~RationalEvals(); + + tachyon_bn254_univariate_rational_evaluations* evals() { return evals_; } + const tachyon_bn254_univariate_rational_evaluations* evals() const { + return evals_; + } + + tachyon_bn254_univariate_rational_evaluations* release() { + return std::exchange(evals_, nullptr); + } + + size_t len() const; + std::unique_ptr create_view(size_t start, size_t len); + std::unique_ptr clone() const; + + private: + tachyon_bn254_univariate_rational_evaluations* evals_; +}; + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_RATIONAL_EVALS_H_ diff --git a/halo2_proofs/include/bn254_rational_evals_view.h b/halo2_proofs/include/bn254_rational_evals_view.h new file mode 100644 index 00000000..c04fd5a3 --- /dev/null +++ b/halo2_proofs/include/bn254_rational_evals_view.h @@ -0,0 +1,34 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_RATIONAL_EVALS_VIEW_H_ +#define HALO2_PROOFS_INCLUDE_BN254_RATIONAL_EVALS_VIEW_H_ + +#include + +#include + +namespace tachyon::halo2_api::bn254 { + +struct Fr; + +class RationalEvalsView { + public: + RationalEvalsView(tachyon_bn254_univariate_rational_evaluations* evals, + size_t start, size_t len); + RationalEvalsView(const RationalEvalsView& other) = delete; + RationalEvalsView& operator=(const RationalEvalsView& other) = delete; + ~RationalEvalsView() = default; + + void set_zero(size_t idx); + void set_trivial(size_t idx, const Fr& numerator); + void set_rational(size_t idx, const Fr& numerator, const Fr& denominator); + void evaluate(size_t idx, Fr& value) const; + + private: + // not owned + tachyon_bn254_univariate_rational_evaluations* const evals_; + const size_t start_ = 0; + const size_t len_ = 0; +}; + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_RATIONAL_EVALS_VIEW_H_ diff --git a/halo2_proofs/include/bn254_sha256_writer.h b/halo2_proofs/include/bn254_sha256_writer.h new file mode 100644 index 00000000..85167c67 --- /dev/null +++ b/halo2_proofs/include/bn254_sha256_writer.h @@ -0,0 +1,38 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_SHA256_WRITER_H_ +#define HALO2_PROOFS_INCLUDE_BN254_SHA256_WRITER_H_ + +#include +#include + +#include +#include + +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api::bn254 { + +constexpr size_t kSha256DigestLength = 32; +constexpr size_t kSha256StateLength = 112; + +class Sha256Writer { + public: + Sha256Writer(); + Sha256Writer(const Sha256Writer& other) = delete; + Sha256Writer& operator=(const Sha256Writer& other) = delete; + ~Sha256Writer(); + + void update(rust::Slice data); + void finalize(std::array& result); + rust::Vec state() const; + + private: + tachyon_halo2_bn254_transcript_writer* writer_; +}; + +std::unique_ptr new_sha256_writer(); + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_SHA256_WRITER_H_ diff --git a/halo2_proofs/include/bn254_snark_verifier_poseidon_writer.h b/halo2_proofs/include/bn254_snark_verifier_poseidon_writer.h new file mode 100644 index 00000000..2b37205b --- /dev/null +++ b/halo2_proofs/include/bn254_snark_verifier_poseidon_writer.h @@ -0,0 +1,38 @@ +#ifndef HALO2_PROOFS_INCLUDE_BN254_SNARK_VERIFIER_POSEIDON_WRITER_H_ +#define HALO2_PROOFS_INCLUDE_BN254_SNARK_VERIFIER_POSEIDON_WRITER_H_ + +#include + +#include + +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api::bn254 { + +struct Fr; + +class SnarkVerifierPoseidonWriter { + public: + SnarkVerifierPoseidonWriter(); + SnarkVerifierPoseidonWriter(const SnarkVerifierPoseidonWriter& other) = + delete; + SnarkVerifierPoseidonWriter& operator=( + const SnarkVerifierPoseidonWriter& other) = delete; + ~SnarkVerifierPoseidonWriter(); + + void update(rust::Slice data); + rust::Box squeeze(); + rust::Vec state() const; + + private: + tachyon_halo2_bn254_transcript_writer* writer_; +}; + +std::unique_ptr +new_snark_verifier_poseidon_writer(); + +} // namespace tachyon::halo2_api::bn254 + +#endif // HALO2_PROOFS_INCLUDE_BN254_SNARK_VERIFIER_POSEIDON_WRITER_H_ diff --git a/halo2_proofs/include/cha_cha20_rng.h b/halo2_proofs/include/cha_cha20_rng.h new file mode 100644 index 00000000..4cc71919 --- /dev/null +++ b/halo2_proofs/include/cha_cha20_rng.h @@ -0,0 +1,40 @@ +#ifndef HALO2_PROOFS_INCLUDE_CHA_CHA20_RNG_H_ +#define HALO2_PROOFS_INCLUDE_CHA_CHA20_RNG_H_ + +#include +#include + +#include +#include + +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api { + +class ChaCha20Rng { + public: + constexpr static size_t kSeedSize = 32; + constexpr static size_t kStateSize = sizeof(size_t) + 128; + + explicit ChaCha20Rng(tachyon_rng* rng) : rng_(rng) {} + explicit ChaCha20Rng(std::array seed); + ChaCha20Rng(const ChaCha20Rng& other) = delete; + ChaCha20Rng& operator=(const ChaCha20Rng& other) = delete; + ~ChaCha20Rng(); + + uint32_t next_u32(); + std::unique_ptr clone() const; + rust::Vec state() const; + + private: + tachyon_rng* rng_; +}; + +std::unique_ptr new_cha_cha20_rng( + std::array seed); + +} // namespace tachyon::halo2_api + +#endif // HALO2_PROOFS_INCLUDE_CHA_CHA20_RNG_H_ diff --git a/halo2_proofs/include/xor_shift_rng.h b/halo2_proofs/include/xor_shift_rng.h new file mode 100644 index 00000000..776e6cdf --- /dev/null +++ b/halo2_proofs/include/xor_shift_rng.h @@ -0,0 +1,40 @@ +#ifndef HALO2_PROOFS_INCLUDE_XOR_SHIFT_RNG_H_ +#define HALO2_PROOFS_INCLUDE_XOR_SHIFT_RNG_H_ + +#include +#include + +#include +#include + +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api { + +class XORShiftRng { + public: + constexpr static size_t kSeedSize = 16; + constexpr static size_t kStateSize = 16; + + explicit XORShiftRng(tachyon_rng* rng) : rng_(rng) {} + explicit XORShiftRng(std::array seed); + XORShiftRng(const XORShiftRng& other) = delete; + XORShiftRng& operator=(const XORShiftRng& other) = delete; + ~XORShiftRng(); + + uint32_t next_u32(); + std::unique_ptr clone() const; + rust::Vec state() const; + + private: + tachyon_rng* rng_; +}; + +std::unique_ptr new_xor_shift_rng( + std::array seed); + +} // namespace tachyon::halo2_api + +#endif // HALO2_PROOFS_INCLUDE_XOR_SHIFT_RNG_H_ diff --git a/halo2_proofs/src/bn254.rs b/halo2_proofs/src/bn254.rs new file mode 100644 index 00000000..23454435 --- /dev/null +++ b/halo2_proofs/src/bn254.rs @@ -0,0 +1,1225 @@ +use log::trace; +use std::{ + fmt, + io::{self, Write}, + marker::PhantomData, +}; + +use crate::{ + consts::{PCSType, RNGType}, + helpers::base_to_scalar, + plonk::{sealed, Column, Fixed}, + poly::commitment::{Blind, CommitmentScheme}, + transcript::{ + Challenge255, EncodedChallenge, Transcript, TranscriptWrite, TranscriptWriterBuffer, + }, +}; +use ff::{Field, FromUniformBytes, PrimeField}; +use halo2curves::{bn256::G2Affine, Coordinates, CurveAffine}; + +#[repr(C)] +#[derive(Debug)] +pub struct Fq2 { + pub c0: Fq, + pub c1: Fq, +} + +#[repr(C)] +#[derive(Debug)] +pub struct G1Point2 { + pub x: Fq, + pub y: Fq, +} + +#[repr(C)] +#[derive(Debug)] +pub struct G1ProjectivePoint { + pub x: Fq, + pub y: Fq, + pub z: Fq, +} + +#[derive(Debug)] +pub struct G2AffinePoint { + pub x: Fq2, + pub y: Fq2, + pub infinity: bool, +} + +#[repr(transparent)] +#[derive(Debug)] +pub struct Fq(pub [u64; 4]); + +#[repr(transparent)] +#[derive(Debug)] +pub struct Fr(pub [u64; 4]); + +#[derive(Debug)] +pub struct InstanceSingle { + pub instance_values: Vec, + pub instance_polys: Vec, +} + +#[derive(Clone, Debug)] +pub struct AdviceSingle { + pub advice_polys: Vec, + pub advice_blinds: Vec>, +} + +#[cxx::bridge(namespace = "tachyon::halo2_api::bn254")] +pub mod ffi { + extern "Rust" { + type G1ProjectivePoint; + type G1Point2; + type G2AffinePoint; + type Fr; + type InstanceSingle; + type AdviceSingle; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_blake2b_writer.h"); + + type Blake2bWriter; + + fn new_blake2b_writer() -> UniquePtr; + fn update(self: Pin<&mut Blake2bWriter>, data: &[u8]); + fn finalize(self: Pin<&mut Blake2bWriter>, result: &mut [u8; 64]); + fn state(&self) -> Vec; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_poseidon_writer.h"); + + type PoseidonWriter; + + fn new_poseidon_writer() -> UniquePtr; + fn update(self: Pin<&mut PoseidonWriter>, data: &[u8]); + fn squeeze(self: Pin<&mut PoseidonWriter>) -> Box; + fn state(&self) -> Vec; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_sha256_writer.h"); + + type Sha256Writer; + + fn new_sha256_writer() -> UniquePtr; + fn update(self: Pin<&mut Sha256Writer>, data: &[u8]); + fn finalize(self: Pin<&mut Sha256Writer>, result: &mut [u8; 32]); + fn state(&self) -> Vec; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_snark_verifier_poseidon_writer.h"); + + type SnarkVerifierPoseidonWriter; + + fn new_snark_verifier_poseidon_writer() -> UniquePtr; + fn update(self: Pin<&mut SnarkVerifierPoseidonWriter>, data: &[u8]); + fn squeeze(self: Pin<&mut SnarkVerifierPoseidonWriter>) -> Box; + fn state(&self) -> Vec; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_proving_key.h"); + + type ProvingKey; + + fn new_proving_key(data: &[u8]) -> UniquePtr; + fn advice_column_phases(&self) -> Vec; + fn blinding_factors(&self) -> u32; + fn challenge_phases(&self) -> Vec; + fn constants(&self) -> Vec; + fn num_advice_columns(&self) -> usize; + fn num_challenges(&self) -> usize; + fn num_instance_columns(&self) -> usize; + fn phases(&self) -> Vec; + fn transcript_repr(self: Pin<&mut ProvingKey>, prover: &Prover) -> Box; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_evals.h"); + + type Evals; + + fn zero_evals() -> UniquePtr; + fn len(&self) -> usize; + fn set_value(self: Pin<&mut Evals>, idx: usize, value: &Fr); + fn clone(&self) -> UniquePtr; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_rational_evals.h"); + + type RationalEvals; + + fn len(&self) -> usize; + fn create_view( + self: Pin<&mut RationalEvals>, + start: usize, + len: usize, + ) -> UniquePtr; + fn clone(&self) -> UniquePtr; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_rational_evals_view.h"); + + type RationalEvalsView; + + fn set_zero(self: Pin<&mut RationalEvalsView>, idx: usize); + fn set_trivial(self: Pin<&mut RationalEvalsView>, idx: usize, numerator: &Fr); + fn set_rational( + self: Pin<&mut RationalEvalsView>, + idx: usize, + numerator: &Fr, + denominator: &Fr, + ); + fn evaluate(&self, idx: usize, value: &mut Fr); + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_poly.h"); + + type Poly; + } + + unsafe extern "C++" { + include!("halo2_proofs/include/bn254_prover.h"); + + type Prover; + + fn new_prover(pcs_type: u8, transcript_type: u8, k: u32, s: &Fr) -> UniquePtr; + fn new_prover_from_params( + pcs_type: u8, + transcript_type: u8, + k: u32, + params: &[u8], + ) -> UniquePtr; + fn k(&self) -> u32; + fn n(&self) -> u64; + fn s_g2(&self) -> Box; + fn commit(&self, poly: &Poly) -> Box; + fn commit_lagrange(&self, evals: &Evals) -> Box; + fn empty_evals(&self) -> UniquePtr; + fn empty_rational_evals(&self) -> UniquePtr; + fn ifft(&self, evals: &Evals) -> UniquePtr; + fn batch_evaluate( + &self, + rational_evals: &[UniquePtr], + evals: &mut [UniquePtr], + ); + fn set_rng(self: Pin<&mut Prover>, rng_type: u8, state: &[u8]); + fn set_transcript(self: Pin<&mut Prover>, state: &[u8]); + fn set_extended_domain(self: Pin<&mut Prover>, pk: &ProvingKey); + fn create_proof( + self: Pin<&mut Prover>, + key: Pin<&mut ProvingKey>, + instance_singles: &mut [InstanceSingle], + advice_singles: &mut [AdviceSingle], + challenges: &[Fr], + ); + fn get_proof(self: &Prover) -> Vec; + } +} + +impl fmt::Debug for ffi::Blake2bWriter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Blake2bWriter").finish() + } +} + +impl fmt::Debug for ffi::PoseidonWriter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PoseidonWriter").finish() + } +} + +impl fmt::Debug for ffi::Sha256Writer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Sha256Writer").finish() + } +} + +impl fmt::Debug for ffi::SnarkVerifierPoseidonWriter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SnarkVerifierPoseidonWriter").finish() + } +} + +impl fmt::Debug for ffi::ProvingKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ProvingKey").finish() + } +} + +impl fmt::Debug for ffi::Evals { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Evals").finish() + } +} + +impl fmt::Debug for ffi::RationalEvals { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RationalEvals").finish() + } +} + +impl fmt::Debug for ffi::RationalEvalsView { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RationalEvalsView").finish() + } +} + +impl fmt::Debug for ffi::Poly { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Poly").finish() + } +} + +impl fmt::Debug for ffi::Prover { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Prover").finish() + } +} + +pub trait TranscriptWriteState>: + TranscriptWrite +{ + fn state(&self) -> Vec; +} + +#[derive(Debug)] +pub struct Blake2bWrite> { + state: cxx::UniquePtr, + writer: W, + proof_idx: usize, + _marker: PhantomData<(W, C, E)>, +} + +impl Transcript> + for Blake2bWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn squeeze_challenge(&mut self) -> Challenge255 { + // Prefix to a prover's message soliciting a challenge + const BLAKE2B_PREFIX_CHALLENGE: u8 = 0; + self.state.pin_mut().update(&[BLAKE2B_PREFIX_CHALLENGE]); + let mut result: [u8; 64] = [0; 64]; + self.state.pin_mut().finalize(&mut result); + Challenge255::::new(&result) + } + + fn common_point(&mut self, point: C) -> io::Result<()> { + // Prefix to a prover's message containing a curve point + const BLAKE2B_PREFIX_POINT: u8 = 1; + self.state.pin_mut().update(&[BLAKE2B_PREFIX_POINT]); + let coords: Coordinates = Option::from(point.coordinates()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "cannot write points at infinity to the transcript", + ) + })?; + self.state.pin_mut().update(coords.x().to_repr().as_ref()); + self.state.pin_mut().update(coords.y().to_repr().as_ref()); + + Ok(()) + } + + fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + // Prefix to a prover's message containing a scalar + const BLAKE2B_PREFIX_SCALAR: u8 = 2; + self.state.pin_mut().update(&[BLAKE2B_PREFIX_SCALAR]); + self.state.pin_mut().update(scalar.to_repr().as_ref()); + Ok(()) + } +} + +impl TranscriptWrite> + for Blake2bWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn write_point(&mut self, point: C) -> io::Result<()> { + trace!( + "[Halo2:WriteToProof] Proof[{}]: {:?}", + self.proof_idx, + point + ); + self.proof_idx += 1; + self.common_point(point)?; + let compressed = point.to_bytes(); + self.writer.write_all(compressed.as_ref()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + trace!( + "[Halo2:WriteToProof] Proof[{}]: {:?}", + self.proof_idx, + scalar + ); + self.proof_idx += 1; + self.common_scalar(scalar)?; + let data = scalar.to_repr(); + self.writer.write_all(data.as_ref()) + } +} + +impl TranscriptWriteState> + for Blake2bWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn state(&self) -> Vec { + self.state.state() + } +} + +impl TranscriptWriterBuffer> + for Blake2bWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + /// Initialize a transcript given an output buffer. + fn init(writer: W) -> Self { + Blake2bWrite { + state: ffi::new_blake2b_writer(), + writer: writer, + proof_idx: 0, + _marker: PhantomData, + } + } + + fn finalize(self) -> W { + // TODO: handle outstanding scalars? see issue #138 + self.writer + } +} + +#[derive(Debug)] +pub struct PoseidonWrite> { + state: cxx::UniquePtr, + writer: W, + proof_idx: usize, + _marker: PhantomData<(W, C, E)>, +} + +impl Transcript> + for PoseidonWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn squeeze_challenge(&mut self) -> Challenge255 { + let scalar = *unsafe { + std::mem::transmute::<_, Box>(self.state.pin_mut().squeeze()) + }; + let mut scalar_bytes = scalar.to_repr().as_ref().to_vec(); + scalar_bytes.resize(64, 0u8); + Challenge255::::new(&scalar_bytes.try_into().unwrap()) + } + + fn common_point(&mut self, point: C) -> io::Result<()> { + let coords: Coordinates = Option::from(point.coordinates()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "cannot write points at infinity to the transcript", + ) + })?; + let x = coords.x(); + let y = coords.y(); + let slice = &[base_to_scalar::(x), base_to_scalar::(y)]; + let bytes = std::mem::size_of::() * 2; + unsafe { + self.state.pin_mut().update(std::slice::from_raw_parts( + slice.as_ptr() as *const u8, + bytes, + )); + } + + Ok(()) + } + + fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + let slice = &[scalar]; + let bytes = std::mem::size_of::(); + unsafe { + self.state.pin_mut().update(std::slice::from_raw_parts( + slice.as_ptr() as *const u8, + bytes, + )); + } + + Ok(()) + } +} + +impl TranscriptWrite> + for PoseidonWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn write_point(&mut self, point: C) -> io::Result<()> { + trace!( + "[Halo2:WriteToProof] Proof[{}]: {:?}", + self.proof_idx, + point + ); + self.common_point(point)?; + let compressed = point.to_bytes(); + self.writer.write_all(compressed.as_ref()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + trace!( + "[Halo2:WriteToProof] Proof[{}]: {:?}", + self.proof_idx, + scalar + ); + self.common_scalar(scalar)?; + let data = scalar.to_repr(); + self.writer.write_all(data.as_ref()) + } +} + +impl> PoseidonWrite { + /// Initialize a transcript given an output buffer. + pub fn init(writer: W) -> Self { + PoseidonWrite { + state: ffi::new_poseidon_writer(), + writer, + proof_idx: 0, + _marker: PhantomData, + } + } + + /// Conclude the interaction and return the output buffer (writer). + pub fn finalize(self) -> W { + // TODO: handle outstanding scalars? see issue #138 + self.writer + } +} + +impl TranscriptWriteState> + for PoseidonWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn state(&self) -> Vec { + self.state.state() + } +} + +#[derive(Debug)] +pub struct Sha256Write> { + state: cxx::UniquePtr, + writer: W, + proof_idx: usize, + _marker: PhantomData<(W, C, E)>, +} + +impl Transcript> for Sha256Write> +where + C::Scalar: FromUniformBytes<64>, +{ + fn squeeze_challenge(&mut self) -> Challenge255 { + const SHA256_PREFIX_CHALLENGE: u8 = 0; + self.state.pin_mut().update(&[SHA256_PREFIX_CHALLENGE]); + let mut result: [u8; 32] = [0; 32]; + self.state.pin_mut().finalize(&mut result); + + self.state = ffi::new_sha256_writer(); + self.state.pin_mut().update(result.as_slice()); + + let mut bytes = result.to_vec(); + bytes.resize(64, 0u8); + Challenge255::::new(&bytes.try_into().unwrap()) + } + + fn common_point(&mut self, point: C) -> io::Result<()> { + const SHA256_PREFIX_POINT: u8 = 1; + self.state.pin_mut().update(&[0u8; 31]); + self.state.pin_mut().update(&[SHA256_PREFIX_POINT]); + let coords: Coordinates = Option::from(point.coordinates()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "cannot write points at infinity to the transcript", + ) + })?; + + for base in &[coords.x(), coords.y()] { + let mut buf = base.to_repr().as_ref().to_vec(); + buf.resize(32, 0u8); + buf.reverse(); + self.state.pin_mut().update(buf.as_slice()); + } + + Ok(()) + } + + fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + const SHA256_PREFIX_SCALAR: u8 = 2; + self.state.pin_mut().update(&[0u8; 31]); + self.state.pin_mut().update(&[SHA256_PREFIX_SCALAR]); + + { + let mut buf = scalar.to_repr().as_ref().to_vec(); + buf.resize(32, 0u8); + buf.reverse(); + self.state.pin_mut().update(buf.as_slice()); + } + + Ok(()) + } +} + +impl TranscriptWrite> + for Sha256Write> +where + C::Scalar: FromUniformBytes<64>, +{ + fn write_point(&mut self, point: C) -> io::Result<()> { + trace!( + "[Halo2:WriteToProof] Proof[{}]: {:?}", + self.proof_idx, + point + ); + self.common_point(point)?; + + let coords = point.coordinates(); + let x = coords + .map(|v| *v.x()) + .unwrap_or(::Base::ZERO); + let y = coords + .map(|v| *v.y()) + .unwrap_or(::Base::ZERO); + + for base in &[&x, &y] { + self.writer.write_all(base.to_repr().as_ref())?; + } + + Ok(()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + trace!( + "[Halo2:WriteToProof] Proof[{}]: {:?}", + self.proof_idx, + scalar + ); + self.common_scalar(scalar)?; + let data = scalar.to_repr(); + + self.writer.write_all(data.as_ref()) + } +} + +impl TranscriptWriteState> + for Sha256Write> +where + C::Scalar: FromUniformBytes<64>, +{ + fn state(&self) -> Vec { + self.state.state() + } +} + +impl> Sha256Write { + /// Initialize a transcript given an output buffer. + pub fn init(writer: W) -> Self { + Sha256Write { + state: ffi::new_sha256_writer(), + writer, + proof_idx: 0, + _marker: PhantomData, + } + } + + /// Conclude the interaction and return the output buffer (writer). + pub fn finalize(self) -> W { + // TODO: handle outstanding scalars? see issue #138 + self.writer + } +} + +#[derive(Debug)] +pub struct SnarkVerifierPoseidonWrite> { + state: cxx::UniquePtr, + writer: W, + _marker: PhantomData<(W, C, E)>, +} + +impl Transcript> + for SnarkVerifierPoseidonWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn squeeze_challenge(&mut self) -> Challenge255 { + let scalar = *unsafe { + std::mem::transmute::<_, Box>(self.state.pin_mut().squeeze()) + }; + let mut scalar_bytes = scalar.to_repr().as_ref().to_vec(); + scalar_bytes.resize(64, 0u8); + Challenge255::::new(&scalar_bytes.try_into().unwrap()) + } + + fn common_point(&mut self, point: C) -> io::Result<()> { + let coords: Coordinates = Option::from(point.coordinates()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "cannot write points at infinity to the transcript", + ) + })?; + let x = coords.x(); + let y = coords.y(); + let slice = &[base_to_scalar::(x), base_to_scalar::(y)]; + let bytes = std::mem::size_of::() * 2; + unsafe { + self.state.pin_mut().update(std::slice::from_raw_parts( + slice.as_ptr() as *const u8, + bytes, + )); + } + + Ok(()) + } + + fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + let slice = &[scalar]; + let bytes = std::mem::size_of::(); + unsafe { + self.state.pin_mut().update(std::slice::from_raw_parts( + slice.as_ptr() as *const u8, + bytes, + )); + } + + Ok(()) + } +} + +impl TranscriptWrite> + for SnarkVerifierPoseidonWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn write_point(&mut self, point: C) -> io::Result<()> { + self.common_point(point)?; + let compressed = point.to_bytes(); + self.writer.write_all(compressed.as_ref()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + self.common_scalar(scalar)?; + let data = scalar.to_repr(); + self.writer.write_all(data.as_ref()) + } +} + +impl> SnarkVerifierPoseidonWrite { + /// Initialize a transcript given an output buffer. + pub fn init(writer: W) -> Self { + SnarkVerifierPoseidonWrite { + state: ffi::new_snark_verifier_poseidon_writer(), + writer, + _marker: PhantomData, + } + } + + /// Conclude the interaction and return the output buffer (writer). + pub fn finalize(self) -> W { + // TODO: handle outstanding scalars? + // See https://github.com/zcash/halo2/issues/138. + self.writer + } +} + +impl TranscriptWriteState> + for SnarkVerifierPoseidonWrite> +where + C::Scalar: FromUniformBytes<64>, +{ + fn state(&self) -> Vec { + self.state.state() + } +} + +#[derive(Debug)] +pub struct ProvingKey { + inner: cxx::UniquePtr, + _marker: PhantomData, +} + +impl ProvingKey { + pub fn from(data: &[u8]) -> ProvingKey { + ProvingKey { + inner: ffi::new_proving_key(data), + _marker: PhantomData, + } + } + + // NOTE(chokobole): We name this as plural since it contains multi phases. + // pk.vk.cs.advice_column_phase + pub fn advice_column_phases(&self) -> Vec { + unsafe { + let phases: Vec = std::mem::transmute(self.inner.advice_column_phases()); + phases + } + } + + // pk.vk.cs.blinding_factors() + pub fn blinding_factors(&self) -> u32 { + self.inner.blinding_factors() + } + + // NOTE(chokobole): We name this as plural since it contains multi phases. + // pk.vk.cs.challenge_phase + pub fn challenge_phases(&self) -> Vec { + unsafe { + let phases: Vec = std::mem::transmute(self.inner.challenge_phases()); + phases + } + } + + // pk.vk.cs.constants + pub fn constants(&self) -> Vec> { + let constants = self + .inner + .constants() + .iter() + .map(|index| Column { + index: *index, + column_type: Fixed, + }) + .collect::>(); + constants + } + + // pk.vk.cs.num_advice_columns + pub fn num_advice_columns(&self) -> usize { + self.inner.num_advice_columns() + } + + // pk.vk.cs.num_challenges + pub fn num_challenges(&self) -> usize { + self.inner.num_challenges() + } + + // pk.vk.cs.num_instance_columns + pub fn num_instance_columns(&self) -> usize { + self.inner.num_instance_columns() + } + + // pk.vk.cs.phases() + pub fn phases(&self) -> Vec { + unsafe { + let phases: Vec = std::mem::transmute(self.inner.phases()); + phases + } + } + + // pk.vk.transcript_repr + pub fn transcript_repr>( + &mut self, + prover: &P, + ) -> C::Scalar { + *unsafe { + std::mem::transmute::<_, Box>( + self.inner.pin_mut().transcript_repr(prover.inner()), + ) + } + } +} + +#[derive(Debug)] +pub struct Evals { + inner: cxx::UniquePtr, +} + +impl Evals { + pub fn zero() -> Evals { + Self::new(ffi::zero_evals()) + } + + pub fn new(inner: cxx::UniquePtr) -> Evals { + Evals { inner } + } + + pub fn len(&self) -> usize { + self.inner.len() + } + + pub fn set_value(&mut self, idx: usize, fr: &halo2curves::bn256::Fr) { + let cpp_fr = unsafe { std::mem::transmute::<_, &Fr>(fr) }; + self.inner.pin_mut().set_value(idx, cpp_fr) + } +} + +impl Clone for Evals { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +#[derive(Debug)] +pub struct RationalEvals { + inner: cxx::UniquePtr, +} + +unsafe impl Send for ffi::RationalEvals {} +unsafe impl Sync for ffi::RationalEvals {} + +impl RationalEvals { + pub fn new(inner: cxx::UniquePtr) -> RationalEvals { + RationalEvals { inner } + } + + pub fn len(&self) -> usize { + self.inner.len() + } + + pub fn create_view(&mut self, start: usize, len: usize) -> RationalEvalsView { + RationalEvalsView::new(self.inner.pin_mut().create_view(start, len)) + } +} + +impl Clone for RationalEvals { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +#[derive(Debug)] +pub struct RationalEvalsView { + inner: cxx::UniquePtr, +} + +unsafe impl Send for ffi::RationalEvalsView {} +unsafe impl Sync for ffi::RationalEvalsView {} + +impl RationalEvalsView { + pub fn new(inner: cxx::UniquePtr) -> RationalEvalsView { + RationalEvalsView { inner } + } + + pub fn set_zero(&mut self, idx: usize) { + self.inner.pin_mut().set_zero(idx) + } + + pub fn set_trivial(&mut self, idx: usize, numerator: &halo2curves::bn256::Fr) { + let cpp_numerator = unsafe { std::mem::transmute::<_, &Fr>(numerator) }; + self.inner.pin_mut().set_trivial(idx, cpp_numerator) + } + + pub fn set_rational( + &mut self, + idx: usize, + numerator: &halo2curves::bn256::Fr, + denominator: &halo2curves::bn256::Fr, + ) { + let cpp_numerator = unsafe { std::mem::transmute::<_, &Fr>(numerator) }; + let cpp_denominator = unsafe { std::mem::transmute::<_, &Fr>(denominator) }; + self.inner + .pin_mut() + .set_rational(idx, cpp_numerator, cpp_denominator) + } + + pub fn evaluate(&self, idx: usize, value: &mut halo2curves::bn256::Fr) { + self.inner + .evaluate(idx, unsafe { std::mem::transmute::<_, &mut Fr>(value) }) + } +} + +#[derive(Debug)] +pub struct Poly { + inner: cxx::UniquePtr, +} + +impl Poly { + pub fn new(inner: cxx::UniquePtr) -> Poly { + Poly { inner } + } +} + +pub trait TachyonProver { + const QUERY_INSTANCE: bool; + + fn inner(&self) -> &ffi::Prover; + + fn k(&self) -> u32; + + fn n(&self) -> u64; + + fn s_g2(&self) -> G2Affine; + + fn commit(&self, poly: &Poly) -> ::CurveExt; + + fn commit_lagrange(&self, evals: &Evals) -> ::CurveExt; + + fn empty_evals(&self) -> Evals; + + fn empty_rational_evals(&self) -> RationalEvals; + + fn batch_evaluate(&self, rational_evals: &[RationalEvals], evals: &mut [Evals]); + + fn ifft(&self, evals: &Evals) -> Poly; + + fn set_rng(&mut self, rng_type: RNGType, state: &[u8]); + + fn set_transcript(&mut self, state: &[u8]); + + fn set_extended_domain(&mut self, pk: &ProvingKey); + + fn create_proof( + &mut self, + key: &mut ProvingKey, + instance_singles: &mut [InstanceSingle], + advice_singles: &mut [AdviceSingle], + challenges: &[Fr], + ); + + fn get_proof(&self) -> Vec; + + fn transcript_repr(&self, pk: &mut ProvingKey) -> Scheme::Scalar; +} + +#[derive(Debug)] +pub struct GWCProver { + inner: cxx::UniquePtr, + _marker: PhantomData, +} + +impl GWCProver { + pub fn new(transcript_type: u8, k: u32, s: &halo2curves::bn256::Fr) -> GWCProver { + let cpp_s = unsafe { std::mem::transmute::<_, &Fr>(s) }; + GWCProver { + inner: ffi::new_prover(PCSType::GWC as u8, transcript_type, k, cpp_s), + _marker: PhantomData, + } + } + + pub fn from_params(transcript_type: u8, k: u32, params: &[u8]) -> GWCProver { + GWCProver { + inner: ffi::new_prover_from_params(PCSType::GWC as u8, transcript_type, k, params), + _marker: PhantomData, + } + } +} + +impl TachyonProver for GWCProver { + const QUERY_INSTANCE: bool = true; + + fn inner(&self) -> &ffi::Prover { + &self.inner + } + + fn k(&self) -> u32 { + self.inner.k() + } + + fn n(&self) -> u64 { + self.inner.n() + } + + fn s_g2(&self) -> G2Affine { + *unsafe { std::mem::transmute::<_, Box>(self.inner.s_g2()) } + } + + fn commit(&self, poly: &Poly) -> ::CurveExt { + *unsafe { + std::mem::transmute::<_, Box<::CurveExt>>( + self.inner.commit(&poly.inner), + ) + } + } + + fn commit_lagrange(&self, evals: &Evals) -> ::CurveExt { + *unsafe { + std::mem::transmute::<_, Box<::CurveExt>>( + self.inner.commit_lagrange(&evals.inner), + ) + } + } + + fn empty_evals(&self) -> Evals { + Evals::new(self.inner.empty_evals()) + } + + fn empty_rational_evals(&self) -> RationalEvals { + RationalEvals::new(self.inner.empty_rational_evals()) + } + + fn batch_evaluate(&self, rational_evals: &[RationalEvals], evals: &mut [Evals]) { + unsafe { + let rational_evals: &[cxx::UniquePtr] = + std::mem::transmute(rational_evals); + let evals: &mut [cxx::UniquePtr] = std::mem::transmute(evals); + self.inner.batch_evaluate(rational_evals, evals) + } + } + + fn ifft(&self, evals: &Evals) -> Poly { + Poly::new(self.inner.ifft(&evals.inner)) + } + + fn set_rng(&mut self, rng_type: RNGType, state: &[u8]) { + self.inner.pin_mut().set_rng(rng_type as u8, state) + } + + fn set_transcript(&mut self, state: &[u8]) { + self.inner.pin_mut().set_transcript(state) + } + + fn set_extended_domain(&mut self, pk: &ProvingKey) { + self.inner.pin_mut().set_extended_domain(&pk.inner) + } + + fn create_proof( + &mut self, + key: &mut ProvingKey, + instance_singles: &mut [InstanceSingle], + advice_singles: &mut [AdviceSingle], + challenges: &[Fr], + ) { + self.inner.pin_mut().create_proof( + key.inner.pin_mut(), + instance_singles, + advice_singles, + challenges, + ) + } + + fn get_proof(&self) -> Vec { + self.inner.get_proof() + } + + fn transcript_repr( + &self, + pk: &mut ProvingKey<::Curve>, + ) -> Scheme::Scalar { + pk.transcript_repr(self) + } +} + +#[derive(Debug)] +pub struct SHPlonkProver { + inner: cxx::UniquePtr, + _marker: PhantomData, +} + +impl SHPlonkProver { + pub fn new(transcript_type: u8, k: u32, s: &halo2curves::bn256::Fr) -> SHPlonkProver { + let cpp_s = unsafe { std::mem::transmute::<_, &Fr>(s) }; + SHPlonkProver { + inner: ffi::new_prover(PCSType::SHPlonk as u8, transcript_type, k, cpp_s), + _marker: PhantomData, + } + } + + pub fn from_params(transcript_type: u8, k: u32, params: &[u8]) -> SHPlonkProver { + SHPlonkProver { + inner: ffi::new_prover_from_params(PCSType::SHPlonk as u8, transcript_type, k, params), + _marker: PhantomData, + } + } +} + +impl TachyonProver for SHPlonkProver { + const QUERY_INSTANCE: bool = false; + + fn inner(&self) -> &ffi::Prover { + &self.inner + } + + fn k(&self) -> u32 { + self.inner.k() + } + + fn n(&self) -> u64 { + self.inner.n() + } + + fn s_g2(&self) -> G2Affine { + *unsafe { std::mem::transmute::<_, Box>(self.inner.s_g2()) } + } + + fn commit(&self, poly: &Poly) -> ::CurveExt { + *unsafe { + std::mem::transmute::<_, Box<::CurveExt>>( + self.inner.commit(&poly.inner), + ) + } + } + + fn commit_lagrange(&self, evals: &Evals) -> ::CurveExt { + *unsafe { + std::mem::transmute::<_, Box<::CurveExt>>( + self.inner.commit_lagrange(&evals.inner), + ) + } + } + + fn empty_evals(&self) -> Evals { + Evals::new(self.inner.empty_evals()) + } + + fn empty_rational_evals(&self) -> RationalEvals { + RationalEvals::new(self.inner.empty_rational_evals()) + } + + fn batch_evaluate(&self, rational_evals: &[RationalEvals], evals: &mut [Evals]) { + unsafe { + let rational_evals: &[cxx::UniquePtr] = + std::mem::transmute(rational_evals); + let evals: &mut [cxx::UniquePtr] = std::mem::transmute(evals); + self.inner.batch_evaluate(rational_evals, evals) + } + } + + fn ifft(&self, evals: &Evals) -> Poly { + Poly::new(self.inner.ifft(&evals.inner)) + } + + fn set_rng(&mut self, rng_type: RNGType, state: &[u8]) { + self.inner.pin_mut().set_rng(rng_type as u8, state) + } + + fn set_transcript(&mut self, state: &[u8]) { + self.inner.pin_mut().set_transcript(state) + } + + fn set_extended_domain(&mut self, pk: &ProvingKey) { + self.inner.pin_mut().set_extended_domain(&pk.inner) + } + + fn create_proof( + &mut self, + key: &mut ProvingKey, + instance_singles: &mut [InstanceSingle], + advice_singles: &mut [AdviceSingle], + challenges: &[Fr], + ) { + self.inner.pin_mut().create_proof( + key.inner.pin_mut(), + instance_singles, + advice_singles, + challenges, + ) + } + + fn get_proof(&self) -> Vec { + self.inner.get_proof() + } + + fn transcript_repr( + &self, + pk: &mut ProvingKey<::Curve>, + ) -> Scheme::Scalar { + pk.transcript_repr(self) + } +} diff --git a/halo2_proofs/src/bn254_blake2b_writer.cc b/halo2_proofs/src/bn254_blake2b_writer.cc new file mode 100644 index 00000000..62b8453a --- /dev/null +++ b/halo2_proofs/src/bn254_blake2b_writer.cc @@ -0,0 +1,45 @@ +#include "halo2_proofs/include/bn254_blake2b_writer.h" + +#include + +namespace tachyon::halo2_api::bn254 { + +Blake2bWriter::Blake2bWriter() + : writer_(tachyon_halo2_bn254_transcript_writer_create( + TACHYON_HALO2_BLAKE2B_TRANSCRIPT)) {} + +Blake2bWriter::~Blake2bWriter() { + tachyon_halo2_bn254_transcript_writer_destroy(writer_); +} + +void Blake2bWriter::update(rust::Slice data) { + tachyon_halo2_bn254_transcript_writer_update(writer_, data.data(), + data.size()); +} + +void Blake2bWriter::finalize( + std::array& result) { + uint8_t data[kBlake2bDigestLength]; + size_t data_size; + tachyon_halo2_bn254_transcript_writer_finalize(writer_, data, &data_size); + memcpy(result.data(), data, data_size); +} + +rust::Vec Blake2bWriter::state() const { + rust::Vec ret; + // NOTE(chokobole): |rust::Vec| doesn't have |resize()|. + ret.reserve(kBlake2bStateLength); + for (size_t i = 0; i < kBlake2bStateLength; ++i) { + ret.push_back(0); + } + size_t state_size; + tachyon_halo2_bn254_transcript_writer_get_state(writer_, ret.data(), + &state_size); + return ret; +} + +std::unique_ptr new_blake2b_writer() { + return std::make_unique(); +} + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_evals.cc b/halo2_proofs/src/bn254_evals.cc new file mode 100644 index 00000000..8faacde2 --- /dev/null +++ b/halo2_proofs/src/bn254_evals.cc @@ -0,0 +1,27 @@ +#include "halo2_proofs/include/bn254_evals.h" + +#include "halo2_proofs/src/bn254.rs.h" + +namespace tachyon::halo2_api::bn254 { + +Evals::Evals() : evals_(tachyon_bn254_univariate_evaluations_create()) {} + +Evals::~Evals() { tachyon_bn254_univariate_evaluations_destroy(evals_); } + +size_t Evals::len() const { + return tachyon_bn254_univariate_evaluations_len(evals_); +} + +void Evals::set_value(size_t idx, const Fr& fr) { + tachyon_bn254_univariate_evaluations_set_value( + evals_, idx, reinterpret_cast(&fr)); +} + +std::unique_ptr Evals::clone() const { + return std::make_unique( + tachyon_bn254_univariate_evaluations_clone(evals_)); +} + +std::unique_ptr zero_evals() { return std::make_unique(); } + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_poly.cc b/halo2_proofs/src/bn254_poly.cc new file mode 100644 index 00000000..80f80583 --- /dev/null +++ b/halo2_proofs/src/bn254_poly.cc @@ -0,0 +1,9 @@ +#include "halo2_proofs/include/bn254_poly.h" + +namespace tachyon::halo2_api::bn254 { + +Poly::Poly() : poly_(tachyon_bn254_univariate_dense_polynomial_create()) {} + +Poly::~Poly() { tachyon_bn254_univariate_dense_polynomial_destroy(poly_); } + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_poseidon_writer.cc b/halo2_proofs/src/bn254_poseidon_writer.cc new file mode 100644 index 00000000..108a0a78 --- /dev/null +++ b/halo2_proofs/src/bn254_poseidon_writer.cc @@ -0,0 +1,43 @@ +#include "halo2_proofs/include/bn254_poseidon_writer.h" + +namespace tachyon::halo2_api::bn254 { + +PoseidonWriter::PoseidonWriter() + : writer_(tachyon_halo2_bn254_transcript_writer_create( + TACHYON_HALO2_POSEIDON_TRANSCRIPT)) {} + +PoseidonWriter::~PoseidonWriter() { + tachyon_halo2_bn254_transcript_writer_destroy(writer_); +} + +void PoseidonWriter::update(rust::Slice data) { + tachyon_halo2_bn254_transcript_writer_update(writer_, data.data(), + data.size()); +} + +rust::Box PoseidonWriter::squeeze() { + tachyon_bn254_fr* ret = new tachyon_bn254_fr; + *ret = tachyon_halo2_bn254_transcript_writer_squeeze(writer_); + return rust::Box::from_raw(reinterpret_cast(ret)); +} + +rust::Vec PoseidonWriter::state() const { + size_t state_size; + tachyon_halo2_bn254_transcript_writer_get_state(writer_, nullptr, + &state_size); + rust::Vec ret; + // NOTE(chokobole): |rust::Vec| doesn't have |resize()|. + ret.reserve(state_size); + for (size_t i = 0; i < state_size; ++i) { + ret.push_back(0); + } + tachyon_halo2_bn254_transcript_writer_get_state(writer_, ret.data(), + &state_size); + return ret; +} + +std::unique_ptr new_poseidon_writer() { + return std::make_unique(); +} + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_prover.cc b/halo2_proofs/src/bn254_prover.cc new file mode 100644 index 00000000..eb8721c2 --- /dev/null +++ b/halo2_proofs/src/bn254_prover.cc @@ -0,0 +1,210 @@ +#include "halo2_proofs/include/bn254_prover.h" + +#include +#include + +#include "halo2_proofs/src/bn254.rs.h" +#include "halo2_proofs/src/rust_vec.h" + +namespace tachyon::halo2_api::bn254 { + +Prover::Prover(uint8_t pcs_type, uint8_t transcript_type, uint32_t k, + const Fr& s) + : prover_(tachyon_halo2_bn254_prover_create_from_unsafe_setup( + pcs_type, TACHYON_HALO2_LOG_DERIVATIVE_HALO2_LS, transcript_type, k, + reinterpret_cast(&s))) {} + +Prover::Prover(uint8_t pcs_type, uint8_t transcript_type, uint32_t k, + const uint8_t* params, size_t params_len) + : prover_(tachyon_halo2_bn254_prover_create_from_params( + pcs_type, TACHYON_HALO2_LOG_DERIVATIVE_HALO2_LS, transcript_type, k, + params, params_len)) {} + +Prover::~Prover() { tachyon_halo2_bn254_prover_destroy(prover_); } + +uint32_t Prover::k() const { return tachyon_halo2_bn254_prover_get_k(prover_); } + +uint64_t Prover::n() const { + return static_cast(tachyon_halo2_bn254_prover_get_n(prover_)); +} + +rust::Box Prover::s_g2() const { + return rust::Box::from_raw( + reinterpret_cast(new tachyon_bn254_g2_affine( + *tachyon_halo2_bn254_prover_get_s_g2(prover_)))); +} + +rust::Box Prover::commit(const Poly& poly) const { + return rust::Box::from_raw( + reinterpret_cast( + tachyon_halo2_bn254_prover_commit(prover_, poly.poly()))); +} + +rust::Box Prover::commit_lagrange(const Evals& evals) const { + return rust::Box::from_raw( + reinterpret_cast( + tachyon_halo2_bn254_prover_commit_lagrange(prover_, evals.evals()))); +} + +std::unique_ptr Prover::empty_evals() const { + return std::make_unique( + tachyon_bn254_univariate_evaluation_domain_empty_evals( + tachyon_halo2_bn254_prover_get_domain(prover_))); +} + +std::unique_ptr Prover::empty_rational_evals() const { + return std::make_unique( + tachyon_bn254_univariate_evaluation_domain_empty_rational_evals( + tachyon_halo2_bn254_prover_get_domain(prover_))); +} + +std::unique_ptr Prover::ifft(const Evals& evals) const { + // NOTE(chokobole): The zero degrees might be removed. This might cause an + // unexpected error if you use this carelessly. Since this is only used to + // compute instance polynomial and this is used only in Tachyon side, so it's + // fine. + return std::make_unique(tachyon_bn254_univariate_evaluation_domain_ifft( + tachyon_halo2_bn254_prover_get_domain(prover_), evals.evals())); +} + +void Prover::batch_evaluate( + rust::Slice> rational_evals, + rust::Slice> evals) const { + for (size_t i = 0; i < rational_evals.size(); ++i) { + evals[i] = std::make_unique( + tachyon_bn254_univariate_rational_evaluations_batch_evaluate( + rational_evals[i]->evals())); + } +} + +void Prover::set_rng(uint8_t rng_type, rust::Slice state) { + tachyon_halo2_bn254_prover_set_rng_state(prover_, rng_type, state.data(), + state.size()); +} + +void Prover::set_transcript(rust::Slice state) { + tachyon_halo2_bn254_prover_set_transcript_state(prover_, state.data(), + state.size()); +} + +void Prover::set_extended_domain(const ProvingKey& pk) { + tachyon_halo2_bn254_prover_set_extended_domain(prover_, pk.pk()); +} + +void Prover::create_proof(ProvingKey& key, + rust::Slice instance_singles, + rust::Slice advice_singles, + rust::Slice challenges) { + tachyon_bn254_blinder* blinder = + tachyon_halo2_bn254_prover_get_blinder(prover_); + const tachyon_bn254_plonk_verifying_key* vk = + tachyon_bn254_plonk_proving_key_get_verifying_key(key.pk()); + const tachyon_bn254_plonk_constraint_system* cs = + tachyon_bn254_plonk_verifying_key_get_constraint_system(vk); + uint32_t blinding_factors = + tachyon_bn254_plonk_constraint_system_compute_blinding_factors(cs); + tachyon_halo2_bn254_blinder_set_blinding_factors(blinder, blinding_factors); + + size_t num_circuits = instance_singles.size(); + + tachyon_halo2_bn254_argument_data* data = + tachyon_halo2_bn254_argument_data_create(num_circuits); + + tachyon_halo2_bn254_argument_data_reserve_challenges(data, challenges.size()); + for (size_t i = 0; i < challenges.size(); ++i) { + tachyon_halo2_bn254_argument_data_add_challenge( + data, reinterpret_cast(&challenges[i])); + } + + size_t num_bytes = sizeof(RustVec); + uint8_t* advice_single_data = + reinterpret_cast(advice_singles.data()); + uint8_t* instance_single_data = + reinterpret_cast(instance_singles.data()); + for (size_t i = 0; i < num_circuits; ++i) { + RustVec vec; + vec.Read(advice_single_data); + size_t num_advice_columns = vec.length; + uintptr_t* advice_columns_ptr = reinterpret_cast(vec.ptr); + tachyon_halo2_bn254_argument_data_reserve_advice_columns( + data, i, num_advice_columns); + for (size_t j = 0; j < num_advice_columns; ++j) { + tachyon_halo2_bn254_argument_data_add_advice_column( + data, i, reinterpret_cast(advice_columns_ptr[j])->release()); + } + advice_single_data += num_bytes; + + vec.Read(&advice_single_data[0]); + size_t num_blinds = vec.length; + const tachyon_bn254_fr* blinds_ptr = + reinterpret_cast(vec.ptr); + tachyon_halo2_bn254_argument_data_reserve_advice_blinds(data, i, + num_blinds); + for (size_t j = 0; j < num_blinds; ++j) { + tachyon_halo2_bn254_argument_data_add_advice_blind(data, i, + &blinds_ptr[j]); + } + advice_single_data += num_bytes; + + vec.Read(&instance_single_data[0]); + size_t num_instance_columns = vec.length; + uintptr_t* instance_columns_ptr = reinterpret_cast(vec.ptr); + tachyon_halo2_bn254_argument_data_reserve_instance_columns( + data, i, num_instance_columns); + for (size_t j = 0; j < num_instance_columns; ++j) { + tachyon_halo2_bn254_argument_data_add_instance_column( + data, i, + reinterpret_cast(instance_columns_ptr[j])->release()); + } + instance_single_data += num_bytes; + + vec.Read(&instance_single_data[0]); + uintptr_t* instance_poly_ptr = reinterpret_cast(vec.ptr); + tachyon_halo2_bn254_argument_data_reserve_instance_polys( + data, i, num_instance_columns); + for (size_t j = 0; j < num_instance_columns; ++j) { + tachyon_halo2_bn254_argument_data_add_instance_poly( + data, i, reinterpret_cast(instance_poly_ptr[j])->release()); + } + instance_single_data += num_bytes; + } + + tachyon_halo2_bn254_prover_create_proof(prover_, key.pk(), data); + tachyon_halo2_bn254_argument_data_destroy(data); +} + +rust::Vec Prover::get_proof() const { + size_t proof_len; + tachyon_halo2_bn254_prover_get_proof(prover_, nullptr, &proof_len); + rust::Vec proof; + // NOTE(chokobole): |rust::Vec| doesn't have |resize()|. + proof.reserve(proof_len); + for (size_t i = 0; i < proof_len; ++i) { + proof.push_back(0); + } + tachyon_halo2_bn254_prover_get_proof(prover_, proof.data(), &proof_len); + return proof; +} + +std::unique_ptr new_prover(uint8_t pcs_type, uint8_t transcript_type, + uint32_t k, const Fr& s) { + return std::make_unique(pcs_type, transcript_type, k, s); +} + +std::unique_ptr new_prover_from_params( + uint8_t pcs_type, uint8_t transcript_type, uint32_t k, + rust::Slice params) { + return std::make_unique(pcs_type, transcript_type, k, params.data(), + params.size()); +} + +rust::Box ProvingKey::transcript_repr(const Prover& prover) { + tachyon_halo2_bn254_prover_set_transcript_repr(prover.prover(), pk_); + tachyon_bn254_fr* ret = new tachyon_bn254_fr; + tachyon_bn254_fr repr = tachyon_bn254_plonk_verifying_key_get_transcript_repr( + tachyon_bn254_plonk_proving_key_get_verifying_key(pk_)); + memcpy(ret->limbs, repr.limbs, sizeof(uint64_t) * 4); + return rust::Box::from_raw(reinterpret_cast(ret)); +} + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_proving_key.cc b/halo2_proofs/src/bn254_proving_key.cc new file mode 100644 index 00000000..4f3a52fd --- /dev/null +++ b/halo2_proofs/src/bn254_proving_key.cc @@ -0,0 +1,112 @@ +#include "halo2_proofs/include/bn254_proving_key.h" + +#include + +#include "halo2_proofs/src/bn254.rs.h" + +namespace tachyon::halo2_api::bn254 { + +namespace { + +using GetPhasesAPI = void (*)(const tachyon_bn254_plonk_constraint_system*, + tachyon_phase*, size_t*); +using GetFixedColumnsAPI = + void (*)(const tachyon_bn254_plonk_constraint_system*, + tachyon_fixed_column_key*, size_t*); + +rust::Vec DoGetPhases(const tachyon_bn254_plonk_constraint_system* cs, + GetPhasesAPI api) { + static_assert(sizeof(uint8_t) == sizeof(tachyon_phase)); + rust::Vec phases; + size_t phases_len; + api(cs, nullptr, &phases_len); + phases.reserve(phases_len); + for (size_t i = 0; i < phases_len; ++i) { + phases.push_back(0); + } + api(cs, reinterpret_cast(phases.data()), &phases_len); + return phases; +} + +rust::Vec GetFixedColumns( + const tachyon_bn254_plonk_constraint_system* cs, GetFixedColumnsAPI api) { + static_assert(sizeof(size_t) == sizeof(tachyon_fixed_column_key)); + rust::Vec fixed_columns; + size_t fixed_columns_len; + api(cs, nullptr, &fixed_columns_len); + fixed_columns.reserve(fixed_columns_len); + for (size_t i = 0; i < fixed_columns_len; ++i) { + fixed_columns.push_back(0); + } + api(cs, reinterpret_cast(fixed_columns.data()), + &fixed_columns_len); + return fixed_columns; +} + +} // namespace + +ProvingKey::ProvingKey(rust::Slice pk_bytes) + : pk_(tachyon_bn254_plonk_proving_key_create_from_state( + TACHYON_HALO2_LOG_DERIVATIVE_HALO2_LS, pk_bytes.data(), + pk_bytes.size())) {} + +ProvingKey::~ProvingKey() { tachyon_bn254_plonk_proving_key_destroy(pk_); } + +rust::Vec ProvingKey::advice_column_phases() const { + return DoGetPhases( + GetConstraintSystem(), + &tachyon_bn254_plonk_constraint_system_get_advice_column_phases); +} + +uint32_t ProvingKey::blinding_factors() const { + return tachyon_bn254_plonk_constraint_system_compute_blinding_factors( + GetConstraintSystem()); +} + +rust::Vec ProvingKey::challenge_phases() const { + return DoGetPhases( + GetConstraintSystem(), + &tachyon_bn254_plonk_constraint_system_get_challenge_phases); +} + +rust::Vec ProvingKey::constants() const { + return GetFixedColumns(GetConstraintSystem(), + &tachyon_bn254_plonk_constraint_system_get_constants); +} + +size_t ProvingKey::num_advice_columns() const { + return tachyon_bn254_plonk_constraint_system_get_num_advice_columns( + GetConstraintSystem()); +} + +size_t ProvingKey::num_challenges() const { + return tachyon_bn254_plonk_constraint_system_get_num_challenges( + GetConstraintSystem()); +} + +size_t ProvingKey::num_instance_columns() const { + return tachyon_bn254_plonk_constraint_system_get_num_instance_columns( + GetConstraintSystem()); +} + +rust::Vec ProvingKey::phases() const { + return DoGetPhases(GetConstraintSystem(), + &tachyon_bn254_plonk_constraint_system_get_phases); +} + +const tachyon_bn254_plonk_verifying_key* ProvingKey::GetVerifyingKey() const { + return tachyon_bn254_plonk_proving_key_get_verifying_key(pk_); +} + +const tachyon_bn254_plonk_constraint_system* ProvingKey::GetConstraintSystem() + const { + return tachyon_bn254_plonk_verifying_key_get_constraint_system( + GetVerifyingKey()); +} + +std::unique_ptr new_proving_key( + rust::Slice pk_bytes) { + return std::make_unique(pk_bytes); +} + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_rational_evals.cc b/halo2_proofs/src/bn254_rational_evals.cc new file mode 100644 index 00000000..21197d2d --- /dev/null +++ b/halo2_proofs/src/bn254_rational_evals.cc @@ -0,0 +1,28 @@ +#include "halo2_proofs/include/bn254_rational_evals.h" + +#include "halo2_proofs/include/bn254_rational_evals_view.h" + +namespace tachyon::halo2_api::bn254 { + +RationalEvals::RationalEvals() + : evals_(tachyon_bn254_univariate_rational_evaluations_create()) {} + +RationalEvals::~RationalEvals() { + tachyon_bn254_univariate_rational_evaluations_destroy(evals_); +} + +size_t RationalEvals::len() const { + return tachyon_bn254_univariate_rational_evaluations_len(evals_); +} + +std::unique_ptr RationalEvals::create_view(size_t start, + size_t len) { + return std::make_unique(evals_, start, len); +} + +std::unique_ptr RationalEvals::clone() const { + return std::make_unique( + tachyon_bn254_univariate_rational_evaluations_clone(evals_)); +} + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_rational_evals_view.cc b/halo2_proofs/src/bn254_rational_evals_view.cc new file mode 100644 index 00000000..34377594 --- /dev/null +++ b/halo2_proofs/src/bn254_rational_evals_view.cc @@ -0,0 +1,35 @@ +#include + +#include "halo2_proofs/src/bn254.rs.h" + +namespace tachyon::halo2_api::bn254 { + +RationalEvalsView::RationalEvalsView( + tachyon_bn254_univariate_rational_evaluations* evals, size_t start, + size_t len) + : evals_(evals), start_(start), len_(len) {} + +void RationalEvalsView::set_zero(size_t idx) { + tachyon_bn254_univariate_rational_evaluations_set_zero(evals_, start_ + idx); +} + +void RationalEvalsView::set_trivial(size_t idx, const Fr& numerator) { + tachyon_bn254_univariate_rational_evaluations_set_trivial( + evals_, start_ + idx, + reinterpret_cast(&numerator)); +} + +void RationalEvalsView::set_rational(size_t idx, const Fr& numerator, + const Fr& denominator) { + tachyon_bn254_univariate_rational_evaluations_set_rational( + evals_, start_ + idx, + reinterpret_cast(&numerator), + reinterpret_cast(&denominator)); +} + +void RationalEvalsView::evaluate(size_t idx, Fr& value) const { + tachyon_bn254_univariate_rational_evaluations_evaluate( + evals_, start_ + idx, reinterpret_cast(&value)); +} + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_sha256_writer.cc b/halo2_proofs/src/bn254_sha256_writer.cc new file mode 100644 index 00000000..5dd98888 --- /dev/null +++ b/halo2_proofs/src/bn254_sha256_writer.cc @@ -0,0 +1,44 @@ +#include "halo2_proofs/include/bn254_sha256_writer.h" + +#include + +namespace tachyon::halo2_api::bn254 { + +Sha256Writer::Sha256Writer() + : writer_(tachyon_halo2_bn254_transcript_writer_create( + TACHYON_HALO2_SHA256_TRANSCRIPT)) {} + +Sha256Writer::~Sha256Writer() { + tachyon_halo2_bn254_transcript_writer_destroy(writer_); +} + +void Sha256Writer::update(rust::Slice data) { + tachyon_halo2_bn254_transcript_writer_update(writer_, data.data(), + data.size()); +} + +void Sha256Writer::finalize(std::array& result) { + uint8_t data[kSha256DigestLength]; + size_t data_size; + tachyon_halo2_bn254_transcript_writer_finalize(writer_, data, &data_size); + memcpy(result.data(), data, data_size); +} + +rust::Vec Sha256Writer::state() const { + rust::Vec ret; + // NOTE(chokobole): |rust::Vec| doesn't have |resize()|. + ret.reserve(kSha256StateLength); + for (size_t i = 0; i < kSha256StateLength; ++i) { + ret.push_back(0); + } + size_t state_size; + tachyon_halo2_bn254_transcript_writer_get_state(writer_, ret.data(), + &state_size); + return ret; +} + +std::unique_ptr new_sha256_writer() { + return std::make_unique(); +} + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/bn254_snark_verifier_poseidon_writer.cc b/halo2_proofs/src/bn254_snark_verifier_poseidon_writer.cc new file mode 100644 index 00000000..7613d473 --- /dev/null +++ b/halo2_proofs/src/bn254_snark_verifier_poseidon_writer.cc @@ -0,0 +1,44 @@ +#include "halo2_proofs/include/bn254_snark_verifier_poseidon_writer.h" + +namespace tachyon::halo2_api::bn254 { + +SnarkVerifierPoseidonWriter::SnarkVerifierPoseidonWriter() + : writer_(tachyon_halo2_bn254_transcript_writer_create( + TACHYON_HALO2_SNARK_VERIFIER_POSEIDON_TRANSCRIPT)) {} + +SnarkVerifierPoseidonWriter::~SnarkVerifierPoseidonWriter() { + tachyon_halo2_bn254_transcript_writer_destroy(writer_); +} + +void SnarkVerifierPoseidonWriter::update(rust::Slice data) { + tachyon_halo2_bn254_transcript_writer_update(writer_, data.data(), + data.size()); +} + +rust::Box SnarkVerifierPoseidonWriter::squeeze() { + tachyon_bn254_fr* ret = new tachyon_bn254_fr; + *ret = tachyon_halo2_bn254_transcript_writer_squeeze(writer_); + return rust::Box::from_raw(reinterpret_cast(ret)); +} + +rust::Vec SnarkVerifierPoseidonWriter::state() const { + size_t state_size; + tachyon_halo2_bn254_transcript_writer_get_state(writer_, nullptr, + &state_size); + rust::Vec ret; + // NOTE(chokobole): |rust::Vec| doesn't have |resize()|. + ret.reserve(state_size); + for (size_t i = 0; i < state_size; ++i) { + ret.push_back(0); + } + tachyon_halo2_bn254_transcript_writer_get_state(writer_, ret.data(), + &state_size); + return ret; +} + +std::unique_ptr +new_snark_verifier_poseidon_writer() { + return std::make_unique(); +} + +} // namespace tachyon::halo2_api::bn254 diff --git a/halo2_proofs/src/cha_cha20_rng.cc b/halo2_proofs/src/cha_cha20_rng.cc new file mode 100644 index 00000000..b22d9a0f --- /dev/null +++ b/halo2_proofs/src/cha_cha20_rng.cc @@ -0,0 +1,44 @@ +#include "halo2_proofs/include/cha_cha20_rng.h" + +#include + +namespace tachyon::halo2_api { + +ChaCha20Rng::ChaCha20Rng(std::array seed) { + uint8_t seed_copy[kSeedSize]; + memcpy(seed_copy, seed.data(), kSeedSize); + rng_ = + tachyon_rng_create_from_seed(TACHYON_RNG_CHA_CHA20, seed_copy, kSeedSize); +} + +ChaCha20Rng::~ChaCha20Rng() { tachyon_rng_destroy(rng_); } + +uint32_t ChaCha20Rng::next_u32() { return tachyon_rng_get_next_u32(rng_); } + +std::unique_ptr ChaCha20Rng::clone() const { + uint8_t state[kStateSize]; + size_t state_len; + tachyon_rng_get_state(rng_, state, &state_len); + tachyon_rng* rng = + tachyon_rng_create_from_state(TACHYON_RNG_CHA_CHA20, state, kStateSize); + return std::make_unique(rng); +} + +rust::Vec ChaCha20Rng::state() const { + rust::Vec ret; + // NOTE(chokobole): |rust::Vec| doesn't have |resize()|. + ret.reserve(kStateSize); + for (size_t i = 0; i < kStateSize; ++i) { + ret.push_back(0); + } + size_t state_len; + tachyon_rng_get_state(rng_, ret.data(), &state_len); + return ret; +} + +std::unique_ptr new_cha_cha20_rng( + std::array seed) { + return std::make_unique(seed); +} + +} // namespace tachyon::halo2_api diff --git a/halo2_proofs/src/cha_cha20_rng.rs b/halo2_proofs/src/cha_cha20_rng.rs new file mode 100644 index 00000000..6f189a61 --- /dev/null +++ b/halo2_proofs/src/cha_cha20_rng.rs @@ -0,0 +1,121 @@ +use crate::{consts::RNGType, rng::SerializableRng}; +use std::fmt; + +#[cxx::bridge(namespace = "tachyon::halo2_api")] +pub mod ffi { + unsafe extern "C++" { + include!("halo2_proofs/include/cha_cha20_rng.h"); + + type ChaCha20Rng; + + fn new_cha_cha20_rng(seed: [u8; 32]) -> UniquePtr; + fn next_u32(self: Pin<&mut ChaCha20Rng>) -> u32; + fn clone(&self) -> UniquePtr; + fn state(&self) -> Vec; + } +} + +impl fmt::Debug for ffi::ChaCha20Rng { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChaCha20Rng").finish() + } +} + +#[derive(Debug)] +pub struct ChaCha20Rng { + inner: cxx::UniquePtr, +} + +impl SerializableRng for ChaCha20Rng { + fn state(&self) -> Vec { + self.inner.state() + } + + fn rng_type() -> RNGType { + RNGType::ChaCha20 + } +} + +impl Clone for ChaCha20Rng { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl rand_core::SeedableRng for ChaCha20Rng { + type Seed = [u8; 32]; + + fn from_seed(seed: Self::Seed) -> Self { + Self { + inner: ffi::new_cha_cha20_rng(seed), + } + } +} + +impl rand_core::RngCore for ChaCha20Rng { + fn next_u32(&mut self) -> u32 { + self.inner.pin_mut().next_u32() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + rand_core::impls::next_u64_via_u32(self) + } + + #[inline] + fn fill_bytes(&mut self, dest: &mut [u8]) { + rand_core::impls::fill_bytes_via_next(self, dest) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +#[cfg(test)] +mod test { + use rand_core::{RngCore, SeedableRng}; + + use crate::{consts::CHA_CHA20_SEED, rng::SerializableRng}; + + #[test] + fn test_rng() { + let mut rng = rand_chacha::ChaCha20Rng::from_seed(CHA_CHA20_SEED); + let mut rng_tachyon = crate::cha_cha20_rng::ChaCha20Rng::from_seed(CHA_CHA20_SEED); + + const LEN: i32 = 100; + let random_u64s = (0..LEN).map(|_| rng.next_u64()).collect::>(); + let random_u64s_tachyon = (0..LEN).map(|_| rng_tachyon.next_u64()).collect::>(); + assert_eq!(random_u64s, random_u64s_tachyon); + } + + #[test] + fn test_clone() { + let mut rng = crate::cha_cha20_rng::ChaCha20Rng::from_seed(CHA_CHA20_SEED); + let mut rng_clone = rng.clone(); + + const LEN: i32 = 100; + let random_u64s = (0..LEN).map(|_| rng.next_u64()).collect::>(); + let random_u64s_clone = (0..LEN).map(|_| rng_clone.next_u64()).collect::>(); + assert_eq!(random_u64s, random_u64s_clone); + } + + #[test] + fn test_state() { + let rng = crate::cha_cha20_rng::ChaCha20Rng::from_seed(CHA_CHA20_SEED); + assert_eq!( + rng.state(), + vec![ + 16, 0, 0, 0, 0, 0, 0, 0, 101, 120, 112, 97, 110, 100, 32, 51, 50, 45, 98, 121, 116, + 101, 32, 107, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0 + ] + ); + } +} diff --git a/halo2_proofs/src/consts.rs b/halo2_proofs/src/consts.rs new file mode 100644 index 00000000..0f5af61e --- /dev/null +++ b/halo2_proofs/src/consts.rs @@ -0,0 +1,27 @@ +#[derive(Debug)] +pub enum PCSType { + GWC, + SHPlonk, +} + +#[derive(Debug)] +pub enum TranscriptType { + Blake2b, + Poseidon, + Sha256, + SnarkVerifierPoseidon, +} + +#[derive(Debug)] +pub enum RNGType { + XORShift, + ChaCha20, +} + +pub const XOR_SHIFT_SEED: [u8; 16] = [ + 0x59, 0x62, 0xbe, 0x5d, 0x76, 0x3d, 0x31, 0x8d, 0x17, 0xdb, 0x37, 0x32, 0x54, 0x06, 0xbc, 0xe5, +]; + +pub const CHA_CHA20_SEED: [u8; 32] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +]; diff --git a/halo2_proofs/src/lib.rs b/halo2_proofs/src/lib.rs index da5706f0..d6934c27 100644 --- a/halo2_proofs/src/lib.rs +++ b/halo2_proofs/src/lib.rs @@ -40,3 +40,9 @@ pub mod transcript; pub mod dev; mod helpers; pub use helpers::SerdeFormat; + +pub mod bn254; +pub mod cha_cha20_rng; +pub mod consts; +pub mod rng; +pub mod xor_shift_rng; diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index b16d0ed4..473d8e88 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -32,6 +32,8 @@ mod vanishing; mod prover; mod verifier; +pub mod tachyon; + pub use assigned::*; pub use circuit::*; pub use error::*; diff --git a/halo2_proofs/src/plonk/tachyon.rs b/halo2_proofs/src/plonk/tachyon.rs new file mode 100644 index 00000000..3c5fd8fe --- /dev/null +++ b/halo2_proofs/src/plonk/tachyon.rs @@ -0,0 +1,572 @@ +use std::{ + collections::{BTreeSet, HashMap}, + ops::{Range, RangeTo}, + sync::Arc, +}; + +use crate::{ + bn254::{ + AdviceSingle, Evals, InstanceSingle, ProvingKey as TachyonProvingKey, RationalEvals, + RationalEvalsView, TachyonProver, TranscriptWriteState, + }, + circuit::Value, + plonk::{ + sealed, Advice, Any, Assigned, Assignment, Challenge, Circuit, Column, ConstraintSystem, + Error, Fixed, FloorPlanner, Instance, Selector, + }, + poly::{ + commitment::{Blind, CommitmentScheme}, + LagrangeCoeff, Polynomial, + }, + rng::SerializableRng, + transcript::EncodedChallenge, +}; +use ff::{Field, FromUniformBytes, WithSmallOrderMulGroup}; +use halo2curves::{ + bn256::Fr, + group::{prime::PrimeCurveAffine, Curve}, + CurveAffine, +}; +use rand_core::RngCore; + +/// This creates a proof for the provided `circuits` when given the public +/// parameters `params` and the proving key [`ProvingKey`] that was +/// generated previously for the same circuit. The provided `instances` +/// are zero-padded internally. +pub fn create_proof< + 'params, + Scheme: CommitmentScheme, + P: TachyonProver, + E: EncodedChallenge, + R: RngCore + SerializableRng, + T: TranscriptWriteState, + ConcreteCircuit: Circuit, +>( + prover: &mut P, + pk: &mut TachyonProvingKey, + circuits: &[ConcreteCircuit], + instances: &[&[&[Scheme::Scalar]]], + fixed_values: Vec>, + mut rng: R, + transcript: &mut T, +) -> Result<(), Error> +where + Scheme::Scalar: WithSmallOrderMulGroup<3> + FromUniformBytes<64> + Ord, +{ + if circuits.len() != instances.len() { + return Err(Error::InvalidInstances); + } + + for instance in instances.iter() { + if instance.len() != pk.num_instance_columns() { + return Err(Error::InvalidInstances); + } + } + + prover.set_extended_domain(pk); + // Hash verification key into transcript + transcript.common_scalar(prover.transcript_repr(pk))?; + + let mut meta = ConstraintSystem::default(); + #[cfg(feature = "circuit-params")] + let config = ConcreteCircuit::configure_with_params(&mut meta, circuits[0].params()); + #[cfg(not(feature = "circuit-params"))] + let config = ConcreteCircuit::configure(&mut meta); + + // Selector optimizations cannot be applied here; use the ConstraintSystem + // from the verification key. + + let mut instance: Vec = instances + .iter() + .map(|instance| -> Result { + let instance_values = instance + .iter() + .map(|values| { + let mut poly = prover.empty_evals(); + assert_eq!(poly.len(), prover.n() as usize); + if values.len() > (poly.len() - ((pk.blinding_factors() as usize) + 1)) { + return Err(Error::InstanceTooLarge); + } + + for i in 0..values.len() { + if !P::QUERY_INSTANCE { + transcript.common_scalar(values[i])?; + } + poly.set_value(i, unsafe { + std::mem::transmute::<_, &halo2curves::bn256::Fr>(&values[i]) + }); + } + Ok(poly) + }) + .collect::, _>>()?; + + if P::QUERY_INSTANCE { + let instance_commitments_projective: Vec<_> = instance_values + .iter() + .map(|poly| prover.commit_lagrange(poly)) + .collect(); + let mut instance_commitments = + vec![Scheme::Curve::identity(); instance_commitments_projective.len()]; + ::CurveExt::batch_normalize( + &instance_commitments_projective, + &mut instance_commitments, + ); + let instance_commitments = instance_commitments; + drop(instance_commitments_projective); + + for commitment in &instance_commitments { + transcript.common_point(*commitment)?; + } + } + + let instance_polys: Vec<_> = instance_values + .iter() + .map(|evals| prover.ifft(evals)) + .collect(); + + Ok(InstanceSingle { + instance_values, + instance_polys, + }) + }) + .collect::, _>>()?; + + struct WitnessCollection<'a, F: Field> { + k: u32, + current_phase: sealed::Phase, + advice_vec: Arc>, + advice: Vec, + challenges: &'a HashMap, + instances: &'a [&'a [F]], + fixed_values: &'a [Polynomial], + rw_rows: Range, + usable_rows: RangeTo, + _marker: std::marker::PhantomData, + } + + impl<'a, F: Field> Assignment for WitnessCollection<'a, F> { + fn enter_region(&mut self, _: N) + where + NR: Into, + N: FnOnce() -> NR, + { + // Do nothing; we don't care about regions in this context. + } + + fn exit_region(&mut self) { + // Do nothing; we don't care about regions in this context. + } + + fn enable_selector(&mut self, _: A, _: &Selector, _: usize) -> Result<(), Error> + where + A: FnOnce() -> AR, + AR: Into, + { + // We only care about advice columns here + + Ok(()) + } + + fn fork(&mut self, ranges: &[Range]) -> Result, Error> { + let mut range_start = self.rw_rows.start; + for (i, sub_range) in ranges.iter().enumerate() { + if sub_range.start < range_start { + log::error!( + "subCS_{} sub_range.start ({}) < range_start ({})", + i, + sub_range.start, + range_start + ); + return Err(Error::Synthesis); + } + if i == ranges.len() - 1 && sub_range.end > self.rw_rows.end { + log::error!( + "subCS_{} sub_range.end ({}) > self.rw_rows.end ({})", + i, + sub_range.end, + self.rw_rows.end + ); + return Err(Error::Synthesis); + } + range_start = sub_range.end; + log::debug!( + "subCS_{} rw_rows: {}..{}", + i, + sub_range.start, + sub_range.end + ); + } + + let mut sub_cs = vec![]; + for sub_range in ranges { + let advice = Arc::try_unwrap(self.advice_vec.clone()) + .expect("there must only one Arc for advice_vec") + .iter_mut() + .map(|advice| { + advice.create_view(sub_range.start, sub_range.end - sub_range.start) + }) + .collect::>(); + + sub_cs.push(Self { + k: 0, + current_phase: self.current_phase, + advice_vec: self.advice_vec.clone(), + advice, + challenges: self.challenges, + instances: self.instances, + fixed_values: self.fixed_values, + rw_rows: sub_range.clone(), + usable_rows: self.usable_rows, + _marker: Default::default(), + }); + } + + Ok(sub_cs) + } + + fn merge(&mut self, _sub_cs: Vec) -> Result<(), Error> { + Ok(()) + } + + fn annotate_column(&mut self, _annotation: A, _column: Column) + where + A: FnOnce() -> AR, + AR: Into, + { + // Do nothing + } + + /// Get the last assigned value of a cell. + fn query_advice(&self, column: Column, row: usize) -> Result { + if !self.usable_rows.contains(&row) { + return Err(Error::not_enough_rows_available(self.k)); + } + if !self.rw_rows.contains(&row) { + log::error!("query_advice: {:?}, row: {}", column, row); + return Err(Error::Synthesis); + } + self.advice + .get(column.index()) + .and_then(|v| { + let mut r = F::ZERO; + v.evaluate(row - self.rw_rows.start, unsafe { + std::mem::transmute::<_, &mut Fr>(&mut r) + }); + Some(r) + }) + .ok_or(Error::BoundsFailure) + } + + fn query_fixed(&self, column: Column, row: usize) -> Result { + self.fixed_values + .get(column.index()) + .and_then(|v| v.get(row)) + .copied() + .ok_or(Error::BoundsFailure) + } + + fn query_instance(&self, column: Column, row: usize) -> Result, Error> { + if !self.usable_rows.contains(&row) { + return Err(Error::not_enough_rows_available(self.k)); + } + + self.instances + .get(column.index()) + .and_then(|column| column.get(row)) + .map(|v| Value::known(*v)) + .ok_or(Error::BoundsFailure) + } + + fn assign_advice( + &mut self, + _: A, + column: Column, + row: usize, + to: V, + ) -> Result<(), Error> + where + V: FnOnce() -> Value, + VR: Into>, + A: FnOnce() -> AR, + AR: Into, + { + // Ignore assignment of advice column in different phase than current one. + if self.current_phase.0 < column.column_type().phase.0 { + return Ok(()); + } + + if !self.usable_rows.contains(&row) { + return Err(Error::not_enough_rows_available(self.k)); + } + + if !self.rw_rows.contains(&row) { + log::error!("assign_advice: {:?}, row: {}", column, row); + return Err(Error::Synthesis); + } + + let rational_evals = self + .advice + .get_mut(column.index()) + .ok_or(Error::BoundsFailure)?; + + let row_idx = row - self.rw_rows.start; + let value = to().into_field().assign()?; + match &value { + Assigned::Zero => rational_evals.set_zero(row_idx), + Assigned::Trivial(numerator) => { + let numerator = unsafe { std::mem::transmute::<_, &Fr>(numerator) }; + rational_evals.set_trivial(row_idx, numerator); + } + Assigned::Rational(numerator, denominator) => { + let numerator = unsafe { std::mem::transmute::<_, &Fr>(numerator) }; + let denominator = unsafe { std::mem::transmute::<_, &Fr>(denominator) }; + rational_evals.set_rational(row_idx, numerator, denominator) + } + } + + Ok(()) + } + + fn assign_fixed( + &mut self, + _: A, + _: Column, + _: usize, + _: V, + ) -> Result<(), Error> + where + V: FnOnce() -> Value, + VR: Into>, + A: FnOnce() -> AR, + AR: Into, + { + // We only care about advice columns here + + Ok(()) + } + + fn copy( + &mut self, + _: Column, + _: usize, + _: Column, + _: usize, + ) -> Result<(), Error> { + // We only care about advice columns here + + Ok(()) + } + + fn fill_from_row( + &mut self, + _: Column, + _: usize, + _: Value>, + ) -> Result<(), Error> { + Ok(()) + } + + fn get_challenge(&self, challenge: Challenge) -> Value { + self.challenges + .get(&challenge.index()) + .cloned() + .map(Value::known) + .unwrap_or_else(Value::unknown) + } + + fn push_namespace(&mut self, _: N) + where + NR: Into, + N: FnOnce() -> NR, + { + // Do nothing; we don't care about namespaces in this context. + } + + fn pop_namespace(&mut self, _: Option) { + // Do nothing; we don't care about namespaces in this context. + } + } + + let (mut advice, challenges) = { + let num_advice_columns = pk.num_advice_columns(); + let num_challenges = pk.num_challenges(); + let mut advice = vec![ + AdviceSingle { + advice_polys: vec![prover.empty_evals(); num_advice_columns], + advice_blinds: vec![Blind::default(); num_advice_columns], + }; + instances.len() + ]; + #[cfg(feature = "phase-check")] + let mut advice_assignments = + vec![vec![prover.empty_rational_evals(); num_advice_columns]; instances.len()]; + let mut challenges = HashMap::::with_capacity(num_challenges); + + let unusable_rows_start = prover.n() as usize - ((pk.blinding_factors() as usize) + 1); + for current_phase in pk.phases() { + let column_indices = meta + .advice_column_phase + .iter() + .enumerate() + .filter_map(|(column_index, phase)| { + if current_phase == *phase { + Some(column_index) + } else { + None + } + }) + .collect::>(); + + for (_circuit_idx, ((circuit, advice), instances)) in circuits + .iter() + .zip(advice.iter_mut()) + .zip(instances) + .enumerate() + { + let mut advice_vec = + Arc::new(vec![prover.empty_rational_evals(); num_advice_columns]); + let advice_slice = Arc::get_mut(&mut advice_vec) + .unwrap() + .iter_mut() + .map(|advice| advice.create_view(0, advice.len())) + .collect::>(); + let mut witness = WitnessCollection { + k: prover.k(), + current_phase, + advice_vec, + advice: advice_slice, + instances, + fixed_values: fixed_values.as_slice(), + challenges: &challenges, + // The prover will not be allowed to assign values to advice + // cells that exist within inactive rows, which include some + // number of blinding factors and an extra row for use in the + // permutation argument. + usable_rows: ..unusable_rows_start, + rw_rows: 0..unusable_rows_start, + _marker: std::marker::PhantomData, + }; + + // Synthesize the circuit to obtain the witness and other information. + + log::info!("create_proof synthesize phase {current_phase:?} begin"); + ConcreteCircuit::FloorPlanner::synthesize( + &mut witness, + circuit, + config.clone(), + pk.constants(), + )?; + log::info!("create_proof synthesize phase {current_phase:?} end"); + + #[cfg(feature = "phase-check")] + { + let advice_column_phases = pk.advice_column_phases(); + for (idx, advice_col) in witness.advice.iter().enumerate() { + if advice_column_phases[idx].0 < current_phase.0 { + if advice_assignments[circuit_idx][idx].values != advice_col.values { + log::error!( + "advice column {}(at {:?}) changed when {:?}", + idx, + advice_column_phases[idx], + current_phase + ); + } + } + } + } + + let advice_assigned_values = Arc::try_unwrap(witness.advice_vec) + .expect("there must only one Arc for advice_vec") + .into_iter() + .enumerate() + .filter_map(|(column_index, advice)| { + if column_indices.contains(&column_index) { + #[cfg(feature = "phase-check")] + { + advice_assignments[circuit_idx][column_index] = advice.clone(); + } + Some(advice) + } else { + None + } + }) + .collect::>(); + let mut advice_values = vec![Evals::zero(); advice_assigned_values.len()]; + prover.batch_evaluate( + advice_assigned_values.as_slice(), + advice_values.as_mut_slice(), + ); + + // Add blinding factors to advice columns + for advice_values in &mut advice_values { + //for cell in &mut advice_values[unusable_rows_start..] { + //*cell = C::Scalar::random(&mut rng); + //*cell = C::Scalar::one(); + //} + let idx = advice_values.len() - 1; + advice_values.set_value(idx, &Fr::one()); + } + + // Compute commitments to advice column polynomials + let blinds: Vec<_> = advice_values + .iter() + .map(|_| Blind(Fr::random(&mut rng))) + .collect(); + let advice_commitments_projective: Vec<_> = advice_values + .iter() + .zip(blinds.iter()) + .map(|(poly, _)| prover.commit_lagrange(poly)) + .collect(); + let mut advice_commitments = + vec![Scheme::Curve::identity(); advice_commitments_projective.len()]; + ::CurveExt::batch_normalize( + &advice_commitments_projective, + &mut advice_commitments, + ); + let advice_commitments = advice_commitments; + drop(advice_commitments_projective); + + for commitment in &advice_commitments { + transcript.write_point(unsafe { + std::mem::transmute::<_, Scheme::Curve>(*commitment) + })?; + } + for ((column_index, advice_values), blind) in + column_indices.iter().zip(advice_values).zip(blinds) + { + advice.advice_polys[*column_index] = advice_values; + advice.advice_blinds[*column_index] = blind; + } + } + + for (index, phase) in pk.challenge_phases().iter().enumerate() { + if current_phase == *phase { + let existing = + challenges.insert(index, *transcript.squeeze_challenge_scalar::<()>()); + assert!(existing.is_none()); + } + } + } + + assert_eq!(challenges.len(), num_challenges); + let challenges = (0..num_challenges) + .map(|index| challenges.remove(&index).unwrap()) + .collect::>(); + + (advice, challenges) + }; + + drop(fixed_values); + + prover.set_rng(R::rng_type(), rng.state().as_slice()); + prover.set_transcript(transcript.state().as_slice()); + + let challenges = unsafe { std::mem::transmute::<_, Vec>(challenges) }; + prover.create_proof( + pk, + instance.as_mut_slice(), + advice.as_mut_slice(), + challenges.as_slice(), + ); + Ok(()) +} diff --git a/halo2_proofs/src/rng.rs b/halo2_proofs/src/rng.rs new file mode 100644 index 00000000..865ad3b2 --- /dev/null +++ b/halo2_proofs/src/rng.rs @@ -0,0 +1,6 @@ +use crate::consts::RNGType; + +pub trait SerializableRng { + fn state(&self) -> Vec; + fn rng_type() -> RNGType; +} diff --git a/halo2_proofs/src/rust_vec.h b/halo2_proofs/src/rust_vec.h new file mode 100644 index 00000000..de17e1a0 --- /dev/null +++ b/halo2_proofs/src/rust_vec.h @@ -0,0 +1,50 @@ +#ifndef HALO2_PROOFS_SRC_RUST_VEC_H_ +#define HALO2_PROOFS_SRC_RUST_VEC_H_ + +#include +#include +#include + +#include +#include +#include + +#include "rust/cxx.h" + +namespace tachyon::halo2_api { + +struct RustVec { + uintptr_t ptr; + size_t capacity; + size_t length; + + void Read(const uint8_t* data) { + memcpy(&capacity, data, sizeof(size_t)); + data += sizeof(size_t); + memcpy(&ptr, data, sizeof(uintptr_t)); + data += sizeof(uintptr_t); + memcpy(&length, data, sizeof(size_t)); + data += sizeof(size_t); + } + + std::string ToString() const { + std::stringstream ss; + ss << std::hex << "ptr: " << std::dec << ptr << " capacity: " << capacity + << " length: " << length; + return ss.str(); + } +}; + +template +rust::Vec ConvertCppContainerToRustVec(const Container& container) { + rust::Vec ret; + ret.reserve(std::size(container)); + for (const T& elem : container) { + ret.push_back(elem); + } + return ret; +} + +} // namespace tachyon::halo2_api + +#endif // HALO2_PROOFS_SRC_RUST_VEC_H_ diff --git a/halo2_proofs/src/xor_shift_rng.cc b/halo2_proofs/src/xor_shift_rng.cc new file mode 100644 index 00000000..fdbc7c7e --- /dev/null +++ b/halo2_proofs/src/xor_shift_rng.cc @@ -0,0 +1,44 @@ +#include "halo2_proofs/include/xor_shift_rng.h" + +#include + +namespace tachyon::halo2_api { + +XORShiftRng::XORShiftRng(std::array seed) { + uint8_t seed_copy[kSeedSize]; + memcpy(seed_copy, seed.data(), kSeedSize); + rng_ = + tachyon_rng_create_from_seed(TACHYON_RNG_XOR_SHIFT, seed_copy, kSeedSize); +} + +XORShiftRng::~XORShiftRng() { tachyon_rng_destroy(rng_); } + +uint32_t XORShiftRng::next_u32() { return tachyon_rng_get_next_u32(rng_); } + +std::unique_ptr XORShiftRng::clone() const { + uint8_t state[kStateSize]; + size_t state_len; + tachyon_rng_get_state(rng_, state, &state_len); + tachyon_rng* rng = + tachyon_rng_create_from_state(TACHYON_RNG_XOR_SHIFT, state, kStateSize); + return std::make_unique(rng); +} + +rust::Vec XORShiftRng::state() const { + rust::Vec ret; + // NOTE(chokobole): |rust::Vec| doesn't have |resize()|. + ret.reserve(kStateSize); + for (size_t i = 0; i < kStateSize; ++i) { + ret.push_back(0); + } + size_t state_len; + tachyon_rng_get_state(rng_, ret.data(), &state_len); + return ret; +} + +std::unique_ptr new_xor_shift_rng( + std::array seed) { + return std::make_unique(seed); +} + +} // namespace tachyon::halo2_api diff --git a/halo2_proofs/src/xor_shift_rng.rs b/halo2_proofs/src/xor_shift_rng.rs new file mode 100644 index 00000000..c9771f56 --- /dev/null +++ b/halo2_proofs/src/xor_shift_rng.rs @@ -0,0 +1,111 @@ +use std::fmt; + +#[cxx::bridge(namespace = "tachyon::halo2_api")] +pub mod ffi { + unsafe extern "C++" { + include!("halo2_proofs/include/xor_shift_rng.h"); + + type XORShiftRng; + + fn new_xor_shift_rng(seed: [u8; 16]) -> UniquePtr; + fn next_u32(self: Pin<&mut XORShiftRng>) -> u32; + fn clone(&self) -> UniquePtr; + fn state(&self) -> Vec; + } +} + +impl fmt::Debug for ffi::XORShiftRng { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("XORShiftRng").finish() + } +} + +#[derive(Debug)] +pub struct XORShiftRng { + inner: cxx::UniquePtr, +} + +unsafe impl Send for ffi::XORShiftRng {} + +impl XORShiftRng { + pub fn state(&self) -> Vec { + self.inner.state() + } +} + +impl Clone for XORShiftRng { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl rand_core::SeedableRng for XORShiftRng { + type Seed = [u8; 16]; + + fn from_seed(seed: Self::Seed) -> Self { + Self { + inner: ffi::new_xor_shift_rng(seed), + } + } +} + +impl rand_core::RngCore for XORShiftRng { + fn next_u32(&mut self) -> u32 { + self.inner.pin_mut().next_u32() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + rand_core::impls::next_u64_via_u32(self) + } + + #[inline] + fn fill_bytes(&mut self, dest: &mut [u8]) { + rand_core::impls::fill_bytes_via_next(self, dest) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill_bytes(dest); + Ok(()) + } +} + +#[cfg(test)] +mod test { + use rand_core::{RngCore, SeedableRng}; + + use crate::consts::XOR_SHIFT_SEED; + + #[test] + fn test_rng() { + let mut rng = rand_xorshift::XorShiftRng::from_seed(XOR_SHIFT_SEED); + let mut rng_tachyon = crate::xor_shift_rng::XORShiftRng::from_seed(XOR_SHIFT_SEED); + + const LEN: i32 = 100; + let random_u64s = (0..LEN).map(|_| rng.next_u64()).collect::>(); + let random_u64s_tachyon = (0..LEN).map(|_| rng_tachyon.next_u64()).collect::>(); + assert_eq!(random_u64s, random_u64s_tachyon); + } + + #[test] + fn test_clone() { + let mut rng = crate::xor_shift_rng::XORShiftRng::from_seed(XOR_SHIFT_SEED); + let mut rng_clone = rng.clone(); + + const LEN: i32 = 100; + let random_u64s = (0..LEN).map(|_| rng.next_u64()).collect::>(); + let random_u64s_clone = (0..LEN).map(|_| rng_clone.next_u64()).collect::>(); + assert_eq!(random_u64s, random_u64s_clone); + } + + #[test] + fn test_state() { + let rng = crate::xor_shift_rng::XORShiftRng::from_seed(XOR_SHIFT_SEED); + assert_eq!( + rng.state(), + vec![89, 98, 190, 93, 118, 61, 49, 141, 23, 219, 55, 50, 84, 6, 188, 229] + ); + } +}