Skip to content

Commit

Permalink
WIP: call msm using gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
chokobole committed Aug 10, 2023
1 parent 9922fbb commit 406c992
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 3 deletions.
7 changes: 7 additions & 0 deletions halo2_proofs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,19 @@ log = "0.4.17"
# timer
ark-std = { version = "0.3.0" }

# binding
cxx = "1.0"

[dev-dependencies]
assert_matches = "1.5"
criterion = "0.3"
gumdrop = "0.8"
proptest = "1"
rand_core = { version = "0.6", default-features = false, features = ["getrandom"] }

[build-dependencies]
cxx-build = "1.0"

[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies]
getrandom = { version = "0.2", features = ["js"] }

Expand All @@ -87,6 +93,7 @@ shplonk = []
gwc = []
phase-check = []
profile = ["ark-std/print-trace"]
tachyon_msm_gpu = []

[lib]
bench = false
Expand Down
23 changes: 23 additions & 0 deletions halo2_proofs/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
fn main() {
cxx_build::bridge("src/lib.rs")
.files([
"src/msm.cc",
#[cfg(feature = "tachyon_msm_gpu")]
"src/msm_gpu.cc",
])
.flag_if_supported("-std=c++17")
.compile("halo2_proofs");

let dep_files = vec![
"src/lib.rs",
"src/msm.cc",
#[cfg(feature = "tachyon_msm_gpu")]
"src/msm_gpu.cc",
"include/msm.h",
];
for file in dep_files {
println!("cargo:rerun-if-changed={file}");
}

println!("cargo:rustc-link-lib=dylib=tachyon");
}
26 changes: 26 additions & 0 deletions halo2_proofs/include/msm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef HALO2_PROOFS_INCLUDE_MSM_H_
#define HALO2_PROOFS_INCLUDE_MSM_H_

#include "rust/cxx.h"

namespace tachyon {
namespace halo2 {

struct CppG1Affine;
struct CppG1Jacobian;
struct CppFr;

rust::Box<CppG1Jacobian> msm(rust::Slice<const CppG1Affine> bases,
rust::Slice<const CppFr> scalars);

void init_msm_gpu(uint8_t degree);

void release_msm_gpu();

rust::Box<CppG1Jacobian> msm_gpu(rust::Slice<const CppG1Affine> bases,
rust::Slice<const CppFr> scalars);

} // namespace halo2
} // namespace tachyon

#endif // HALO2_PROOFS_INCLUDE_MSM_H_
49 changes: 48 additions & 1 deletion halo2_proofs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)]
#![deny(broken_intra_doc_links)]
#![deny(missing_debug_implementations)]
#![deny(unsafe_code)]
#![allow(unsafe_code)]
// Remove this once we update pasta_curves
#![allow(unused_imports)]
#![allow(clippy::derive_partial_eq_without_eq)]
Expand All @@ -35,3 +35,50 @@ pub mod transcript;
pub mod dev;
mod helpers;
pub use helpers::SerdeFormat;

#[cxx::bridge(namespace = "tachyon::halo2")]
mod ffi {
// Rust types and signatures exposed to C++.
extern "Rust" {
type CppG1Affine;
type CppG1Jacobian;
type CppFq;
type CppFr;
}

// C++ types and signatures exposed to Rust.
unsafe extern "C++" {
include!("halo2_proofs/include/msm.h");

fn msm(bases: &[CppG1Affine], scalars: &[CppFr]) -> Box<CppG1Jacobian>;
#[cfg(feature = "tachyon_msm_gpu")]
fn init_msm_gpu(degree: u8);
#[cfg(feature = "tachyon_msm_gpu")]
fn release_msm_gpu();
#[cfg(feature = "tachyon_msm_gpu")]
fn msm_gpu(bases: &[CppG1Affine], scalars: &[CppFr]) -> Box<CppG1Jacobian>;
}
}

#[repr(C)]
#[derive(Debug)]
pub struct CppG1Affine {
pub x: CppFq,
pub y: CppFq,
}

#[repr(C)]
#[derive(Debug)]
pub struct CppG1Jacobian {
pub x: CppFq,
pub y: CppFq,
pub z: CppFq,
}

#[repr(transparent)]
#[derive(Debug)]
pub struct CppFq(pub [u64; 4]);

#[repr(transparent)]
#[derive(Debug)]
pub struct CppFr(pub [u64; 4]);
20 changes: 20 additions & 0 deletions halo2_proofs/src/msm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "halo2_proofs/include/msm.h"

#include <tachyon/c/math/msm/msm.h>

#include "halo2_proofs/src/lib.rs.h"

namespace tachyon {
namespace halo2 {

rust::Box<CppG1Jacobian> msm(rust::Slice<const CppG1Affine> bases,
rust::Slice<const CppFr> scalars) {
auto ret = tachyon_msm_g1_point2(
reinterpret_cast<const tachyon_bn254_point2*>(bases.data()),
bases.length(), reinterpret_cast<const tachyon_bn254_fr*>(scalars.data()),
scalars.length());
return rust::Box<CppG1Jacobian>::from_raw(reinterpret_cast<CppG1Jacobian*>(ret));
}

} // namespace halo2
} // namespace tachyon
26 changes: 26 additions & 0 deletions halo2_proofs/src/msm_gpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// clang-format off
#include <tachyon/c/math/msm/msm_gpu.h>
// clang-format on

#include "halo2_proofs/src/lib.rs.h"
#include "halo2_proofs/include/msm.h"

namespace tachyon {

namespace halo2 {

void init_msm_gpu(uint8_t degree) { tachyon_init_msm_gpu(degree); }

void release_msm_gpu() { tachyon_release_msm_gpu(); }

rust::Box<CppG1Jacobian> msm_gpu(rust::Slice<const CppG1Affine> bases,
rust::Slice<const CppFr> scalars) {
auto ret = tachyon_msm_g1_point2_gpu(
reinterpret_cast<const tachyon_bn254_point2*>(bases.data()),
bases.length(), reinterpret_cast<const tachyon_bn254_fr*>(scalars.data()),
scalars.length());
return rust::Box<CppG1Jacobian>::from_raw(reinterpret_cast<CppG1Jacobian*>(ret));
}

} // namespace halo2
} // namespace tachyon
30 changes: 28 additions & 2 deletions halo2_proofs/src/poly/kzg/commitment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,24 @@ where
poly: &Polynomial<E::Scalar, LagrangeCoeff>,
_: Blind<E::Scalar>,
) -> E::G1 {
let mut scalars = Vec::with_capacity(poly.len());
let mut scalars: Vec<E::Scalar> = Vec::with_capacity(poly.len());
scalars.extend(poly.iter());
let bases = &self.g_lagrange;
let size = scalars.len();
assert!(bases.len() >= size);
#[cfg(feature = "tachyon_msm_gpu")]
unsafe {
use crate::{ffi, CppFr, CppG1Affine};
use std::mem;

let bases: &[CppG1Affine] = mem::transmute(bases.as_slice());
let scalars: &[CppFr] = mem::transmute(scalars.as_slice());

let ret = ffi::msm_gpu(bases, scalars);
let ret: Box<E::G1> = mem::transmute(ret);
*ret
}
#[cfg(not(feature = "tachyon_msm_gpu"))]
best_multiexp(&scalars, &bases[0..size])
}

Expand Down Expand Up @@ -337,11 +350,24 @@ where
}

fn commit(&self, poly: &Polynomial<E::Scalar, Coeff>, _: Blind<E::Scalar>) -> E::G1 {
let mut scalars = Vec::with_capacity(poly.len());
let mut scalars: Vec<E::Scalar> = Vec::with_capacity(poly.len());
scalars.extend(poly.iter());
let bases = &self.g;
let size = scalars.len();
assert!(bases.len() >= size);
#[cfg(feature = "tachyon_msm_gpu")]
unsafe {
use crate::{ffi, CppFr, CppG1Affine};
use std::mem;

let bases: &[CppG1Affine] = mem::transmute(bases.as_slice());
let scalars: &[CppFr] = mem::transmute(scalars.as_slice());

let ret = ffi::msm_gpu(bases, scalars);
let ret: Box<E::G1> = mem::transmute(ret);
*ret
}
#[cfg(not(feature = "tachyon_msm_gpu"))]
best_multiexp(&scalars, &bases[0..size])
}

Expand Down
13 changes: 13 additions & 0 deletions halo2_proofs/src/poly/kzg/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ impl<E: Engine + Debug> MSM<E::G1Affine> for MSMKZG<E> {
use group::prime::PrimeCurveAffine;
let mut bases = vec![E::G1Affine::identity(); self.scalars.len()];
E::G1::batch_normalize(&self.bases, &mut bases);
#[cfg(feature = "tachyon_msm_gpu")]
unsafe {
use crate::{ffi, CppFr, CppG1Affine};
use std::mem;

let bases: &[CppG1Affine] = mem::transmute(bases.as_slice());
let scalars: &[CppFr] = mem::transmute(self.scalars.as_slice());

let ret = ffi::msm_gpu(bases, scalars);
let ret: Box<E::G1> = mem::transmute(ret);
*ret
}
#[cfg(not(feature = "tachyon_msm_gpu"))]
best_multiexp(&self.scalars, &bases)
}

Expand Down

0 comments on commit 406c992

Please sign in to comment.