diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index a21b41d9..6a520de9 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -67,6 +67,9 @@ log = "0.4.17" # timer ark-std = { version = "0.3.0" } +# binding +cxx = "1.0" + [dev-dependencies] assert_matches = "1.5" criterion = "0.3" @@ -74,6 +77,9 @@ gumdrop = "0.8" proptest = "1" rand_core = { version = "0.6", default-features = false, features = ["getrandom"] } +[build-dependencies] +cxx-build = "1.0" + [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] getrandom = { version = "0.2", features = ["js"] } @@ -87,6 +93,7 @@ shplonk = [] gwc = [] phase-check = [] profile = ["ark-std/print-trace"] +tachyon_msm_gpu = [] [lib] bench = false diff --git a/halo2_proofs/build.rs b/halo2_proofs/build.rs new file mode 100644 index 00000000..7da8005d --- /dev/null +++ b/halo2_proofs/build.rs @@ -0,0 +1,23 @@ +fn main() { + cxx_build::bridge("src/lib.rs") + .files([ + "src/msm.cc", + #[cfg(feature = "tachyon_msm_gpu")] + "src/msm_gpu.cc", + ]) + .flag_if_supported("-std=c++17") + .compile("halo2_proofs"); + + let dep_files = vec![ + "src/lib.rs", + "src/msm.cc", + #[cfg(feature = "tachyon_msm_gpu")] + "src/msm_gpu.cc", + "include/msm.h", + ]; + for file in dep_files { + println!("cargo:rerun-if-changed={file}"); + } + + println!("cargo:rustc-link-lib=dylib=tachyon"); +} diff --git a/halo2_proofs/include/msm.h b/halo2_proofs/include/msm.h new file mode 100644 index 00000000..100c8df5 --- /dev/null +++ b/halo2_proofs/include/msm.h @@ -0,0 +1,26 @@ +#ifndef HALO2_PROOFS_INCLUDE_MSM_H_ +#define HALO2_PROOFS_INCLUDE_MSM_H_ + +#include "rust/cxx.h" + +namespace tachyon { +namespace halo2 { + +struct CppG1Affine; +struct CppG1Jacobian; +struct CppFr; + +rust::Box msm(rust::Slice bases, + rust::Slice scalars); + +void init_msm_gpu(uint8_t degree); + +void release_msm_gpu(); + +rust::Box msm_gpu(rust::Slice bases, + rust::Slice scalars); + +} // namespace halo2 +} // namespace tachyon + +#endif // HALO2_PROOFS_INCLUDE_MSM_H_ diff --git a/halo2_proofs/src/lib.rs b/halo2_proofs/src/lib.rs index e577a8c0..9a08c65d 100644 --- a/halo2_proofs/src/lib.rs +++ b/halo2_proofs/src/lib.rs @@ -19,7 +19,7 @@ )] #![deny(broken_intra_doc_links)] #![deny(missing_debug_implementations)] -#![deny(unsafe_code)] +#![allow(unsafe_code)] // Remove this once we update pasta_curves #![allow(unused_imports)] #![allow(clippy::derive_partial_eq_without_eq)] @@ -35,3 +35,50 @@ pub mod transcript; pub mod dev; mod helpers; pub use helpers::SerdeFormat; + +#[cxx::bridge(namespace = "tachyon::halo2")] +mod ffi { + // Rust types and signatures exposed to C++. + extern "Rust" { + type CppG1Affine; + type CppG1Jacobian; + type CppFq; + type CppFr; + } + + // C++ types and signatures exposed to Rust. + unsafe extern "C++" { + include!("halo2_proofs/include/msm.h"); + + fn msm(bases: &[CppG1Affine], scalars: &[CppFr]) -> Box; + #[cfg(feature = "tachyon_msm_gpu")] + fn init_msm_gpu(degree: u8); + #[cfg(feature = "tachyon_msm_gpu")] + fn release_msm_gpu(); + #[cfg(feature = "tachyon_msm_gpu")] + fn msm_gpu(bases: &[CppG1Affine], scalars: &[CppFr]) -> Box; + } +} + +#[repr(C)] +#[derive(Debug)] +pub struct CppG1Affine { + pub x: CppFq, + pub y: CppFq, +} + +#[repr(C)] +#[derive(Debug)] +pub struct CppG1Jacobian { + pub x: CppFq, + pub y: CppFq, + pub z: CppFq, +} + +#[repr(transparent)] +#[derive(Debug)] +pub struct CppFq(pub [u64; 4]); + +#[repr(transparent)] +#[derive(Debug)] +pub struct CppFr(pub [u64; 4]); diff --git a/halo2_proofs/src/msm.cc b/halo2_proofs/src/msm.cc new file mode 100644 index 00000000..b7b26cc5 --- /dev/null +++ b/halo2_proofs/src/msm.cc @@ -0,0 +1,20 @@ +#include "halo2_proofs/include/msm.h" + +#include + +#include "halo2_proofs/src/lib.rs.h" + +namespace tachyon { +namespace halo2 { + +rust::Box msm(rust::Slice bases, + rust::Slice scalars) { + auto ret = tachyon_msm_g1_point2( + reinterpret_cast(bases.data()), + bases.length(), reinterpret_cast(scalars.data()), + scalars.length()); + return rust::Box::from_raw(reinterpret_cast(ret)); +} + +} // namespace halo2 +} // namespace tachyon diff --git a/halo2_proofs/src/msm_gpu.cc b/halo2_proofs/src/msm_gpu.cc new file mode 100644 index 00000000..3841b44a --- /dev/null +++ b/halo2_proofs/src/msm_gpu.cc @@ -0,0 +1,26 @@ +// clang-format off +#include +// clang-format on + +#include "halo2_proofs/src/lib.rs.h" +#include "halo2_proofs/include/msm.h" + +namespace tachyon { + +namespace halo2 { + +void init_msm_gpu(uint8_t degree) { tachyon_init_msm_gpu(degree); } + +void release_msm_gpu() { tachyon_release_msm_gpu(); } + +rust::Box msm_gpu(rust::Slice bases, + rust::Slice scalars) { + auto ret = tachyon_msm_g1_point2_gpu( + reinterpret_cast(bases.data()), + bases.length(), reinterpret_cast(scalars.data()), + scalars.length()); + return rust::Box::from_raw(reinterpret_cast(ret)); +} + +} // namespace halo2 +} // namespace tachyon diff --git a/halo2_proofs/src/poly/kzg/commitment.rs b/halo2_proofs/src/poly/kzg/commitment.rs index aa86fdc1..f48e889d 100644 --- a/halo2_proofs/src/poly/kzg/commitment.rs +++ b/halo2_proofs/src/poly/kzg/commitment.rs @@ -295,11 +295,24 @@ where poly: &Polynomial, _: Blind, ) -> E::G1 { - let mut scalars = Vec::with_capacity(poly.len()); + let mut scalars: Vec = Vec::with_capacity(poly.len()); scalars.extend(poly.iter()); let bases = &self.g_lagrange; let size = scalars.len(); assert!(bases.len() >= size); + #[cfg(feature = "tachyon_msm_gpu")] + unsafe { + use crate::{ffi, CppFr, CppG1Affine}; + use std::mem; + + let bases: &[CppG1Affine] = mem::transmute(bases.as_slice()); + let scalars: &[CppFr] = mem::transmute(scalars.as_slice()); + + let ret = ffi::msm_gpu(bases, scalars); + let ret: Box = mem::transmute(ret); + *ret + } + #[cfg(not(feature = "tachyon_msm_gpu"))] best_multiexp(&scalars, &bases[0..size]) } @@ -337,11 +350,24 @@ where } fn commit(&self, poly: &Polynomial, _: Blind) -> E::G1 { - let mut scalars = Vec::with_capacity(poly.len()); + let mut scalars: Vec = Vec::with_capacity(poly.len()); scalars.extend(poly.iter()); let bases = &self.g; let size = scalars.len(); assert!(bases.len() >= size); + #[cfg(feature = "tachyon_msm_gpu")] + unsafe { + use crate::{ffi, CppFr, CppG1Affine}; + use std::mem; + + let bases: &[CppG1Affine] = mem::transmute(bases.as_slice()); + let scalars: &[CppFr] = mem::transmute(scalars.as_slice()); + + let ret = ffi::msm_gpu(bases, scalars); + let ret: Box = mem::transmute(ret); + *ret + } + #[cfg(not(feature = "tachyon_msm_gpu"))] best_multiexp(&scalars, &bases[0..size]) } diff --git a/halo2_proofs/src/poly/kzg/msm.rs b/halo2_proofs/src/poly/kzg/msm.rs index 6cc90a51..3d13b081 100644 --- a/halo2_proofs/src/poly/kzg/msm.rs +++ b/halo2_proofs/src/poly/kzg/msm.rs @@ -66,6 +66,19 @@ impl MSM for MSMKZG { use group::prime::PrimeCurveAffine; let mut bases = vec![E::G1Affine::identity(); self.scalars.len()]; E::G1::batch_normalize(&self.bases, &mut bases); + #[cfg(feature = "tachyon_msm_gpu")] + unsafe { + use crate::{ffi, CppFr, CppG1Affine}; + use std::mem; + + let bases: &[CppG1Affine] = mem::transmute(bases.as_slice()); + let scalars: &[CppFr] = mem::transmute(self.scalars.as_slice()); + + let ret = ffi::msm_gpu(bases, scalars); + let ret: Box = mem::transmute(ret); + *ret + } + #[cfg(not(feature = "tachyon_msm_gpu"))] best_multiexp(&self.scalars, &bases) }