Skip to content

Commit

Permalink
Fix icicle MSM stuff
Browse files: browse the repository at this point in the history
  • Loading branch information
moodlezoup committed Jan 16, 2025
1 parent b72983c commit 6ce9ae0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 41 deletions.
37 changes: 22 additions & 15 deletions jolt-core/src/msm/icicle/adapter.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::msm::{GpuBaseType, MsmType, VariableBaseMSM};
use crate::field::JoltField;
use crate::msm::{GpuBaseType, VariableBaseMSM};
use ark_bn254::G1Projective;
use ark_ec::{CurveGroup, ScalarMul};
use ark_ff::{BigInteger, Field, PrimeField};
Expand Down Expand Up @@ -60,11 +61,11 @@ pub trait Icicle: ScalarMul {
}

#[tracing::instrument(skip_all, name = "icicle_msm")]
pub fn icicle_msm<V: VariableBaseMSM>(
bases: &[GpuBaseType<V>],
scalars: &[V::ScalarField],
bit_size: usize,
) -> V {
pub fn icicle_msm<V>(bases: &[GpuBaseType<V>], scalars: &[V::ScalarField], max_num_bits: usize) -> V
where
V: VariableBaseMSM,
V::ScalarField: JoltField,
{
assert!(scalars.len() <= bases.len());

let mut bases_slice = DeviceVec::<GpuBaseType<V>>::device_malloc(bases.len()).unwrap();
Expand Down Expand Up @@ -99,7 +100,7 @@ pub fn icicle_msm<V: VariableBaseMSM>(
cfg.stream_handle = IcicleStreamHandle::from(&stream);
cfg.is_async = false;
cfg.are_scalars_montgomery_form = true;
cfg.bitsize = bit_size as i32;
cfg.bitsize = max_num_bits as i32;

let span = tracing::span!(tracing::Level::INFO, "gpu_msm");
let _guard = span.enter();
Expand Down Expand Up @@ -133,11 +134,15 @@ pub fn icicle_msm<V: VariableBaseMSM>(
/// Batch process MSMs - assumes batches are equal in size.
/// Variable batch sizes are not currently supported by icicle.
#[tracing::instrument(skip_all)]
pub fn icicle_batch_msm<V: VariableBaseMSM>(
pub fn icicle_batch_msm<V>(
bases: &[GpuBaseType<V>],
scalar_batches: &[&[V::ScalarField]],
batch_type: MsmType,
) -> Vec<V> {
max_num_bits: usize,
) -> Vec<V>
where
V: VariableBaseMSM,
V::ScalarField: JoltField,
{
let bases_len = bases.len();
let batch_size = scalar_batches.len();
assert!(scalar_batches.par_iter().all(|s| s.len() == bases_len));
Expand Down Expand Up @@ -192,7 +197,7 @@ pub fn icicle_batch_msm<V: VariableBaseMSM>(
cfg.is_async = true;
cfg.are_scalars_montgomery_form = true;
cfg.batch_size = batch_size as i32;
cfg.bitsize = batch_type.num_bits() as i32;
cfg.bitsize = max_num_bits as i32;
cfg.ext
.set_int(icicle_core::msm::CUDA_MSM_LARGE_BUCKET_FACTOR, 5);

Expand Down Expand Up @@ -311,7 +316,7 @@ mod tests {
let icicle_res = icicle_msm::<G1Projective>(&gpu_bases, &scalars, 256);
let arkworks_res: G1Projective = ark_VariableBaseMSM::msm(&bases, &scalars).unwrap();
let no_gpu_res: G1Projective =
VariableBaseMSM::inner_msm(&bases, None, &scalars, false, None).unwrap();
VariableBaseMSM::msm_field_elements(&bases, None, &scalars, None, false).unwrap();

assert_eq!(icicle_res, arkworks_res);
assert_eq!(icicle_res, no_gpu_res);
Expand All @@ -337,15 +342,17 @@ mod tests {
.par_iter()
.map(|base| <G1Projective as Icicle>::from_ark_affine(base))
.collect::<Vec<_>>();
let icicle_res =
icicle_batch_msm::<G1Projective>(&gpu_bases, &scalar_batches, MsmType::Large(256));
let icicle_res = icicle_batch_msm::<G1Projective>(&gpu_bases, &scalar_batches, 256);
let arkworks_res: Vec<G1Projective> = (0..20)
.into_iter()
.map(|_| ark_VariableBaseMSM::msm(&bases, &scalars).unwrap())
.collect();
let no_gpu_res: Vec<G1Projective> = (0..20)
.into_iter()
.map(|_| VariableBaseMSM::inner_msm(&bases, None, &scalars, false, None).unwrap())
.map(|_| {
VariableBaseMSM::msm_field_elements(&bases, None, &scalars, None, false)
.unwrap()
})
.collect();

assert_eq!(icicle_res, arkworks_res);
Expand Down
59 changes: 33 additions & 26 deletions jolt-core/src/msm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rayon::prelude::*;

pub(crate) mod icicle;
use crate::field::JoltField;
use crate::poly::multilinear_polynomial::MultilinearPolynomial;
use crate::poly::{dense_mlpoly::DensePolynomial, multilinear_polynomial::MultilinearPolynomial};

Check failure on line 12 in jolt-core/src/msm/mod.rs

View workflow job for this annotation

GitHub Actions / Build Wasm

unused import: `dense_mlpoly::DensePolynomial`

Check failure on line 12 in jolt-core/src/msm/mod.rs

View workflow job for this annotation

GitHub Actions / clippy

unused import: `dense_mlpoly::DensePolynomial`

Check failure on line 12 in jolt-core/src/msm/mod.rs

View workflow job for this annotation

GitHub Actions / Onchain Verifier Tests

unused import: `dense_mlpoly::DensePolynomial`

Check failure on line 12 in jolt-core/src/msm/mod.rs

View workflow job for this annotation

GitHub Actions / test

unused import: `dense_mlpoly::DensePolynomial`
use crate::utils::errors::ProofVerifyError;
use crate::utils::math::Math;
pub use icicle::*;
Expand Down Expand Up @@ -244,7 +244,7 @@ where
let (cpu_batch, gpu_batch): (Vec<_>, Vec<_>) =
polys.par_iter().enumerate().partition_map(|(i, poly)| {
let max_num_bits = poly.max_num_bits();
if use_icicle && max_num_bits > 10 {
if use_icicle && max_num_bits > 64 {
Either::Right((i, max_num_bits, *poly))
} else {
Either::Left((i, max_num_bits, *poly))
Expand Down Expand Up @@ -292,15 +292,22 @@ where

// Process GPU batches with memory constraints
for work_chunk in gpu_batch.chunks(slices_at_a_time) {
let (scalar_types, chunk_scalars): (Vec<_>, Vec<&[Self::ScalarField]>) =
work_chunk
.par_iter()
.map(|(_, msm_type, scalars)| (*msm_type, *scalars))
.unzip();

let max_scalar_type = scalar_types.par_iter().max().unwrap();
let (max_num_bits, chunk_polys): (Vec<_>, Vec<_>) = work_chunk
.par_iter()
.map(|(_, max_num_bits, poly)| (*max_num_bits, *poly))
.unzip();

let max_num_bits = max_num_bits.iter().max().unwrap();
let scalars: Vec<_> = chunk_polys
.into_iter()
.map(|poly| {
let poly: &DensePolynomial<Self::ScalarField> =
poly.try_into().unwrap();
poly.evals_ref()
})
.collect();
let batch_results =
icicle_batch_msm(gpu_bases, &chunk_scalars, *max_scalar_type);
icicle_batch_msm(gpu_bases, &scalars, *max_num_bits as usize);

// Store GPU results using original indices
for ((original_idx, _, _), result) in work_chunk.iter().zip(batch_results) {
Expand Down Expand Up @@ -562,28 +569,28 @@ fn msm_medium<F, V, T>(
_gpu_bases: Option<&[GpuBaseType<V>]>,
scalars: &[T],
max_num_bits: usize,
use_icicle: bool,
_use_icicle: bool,
) -> V
where
F: JoltField,
V: VariableBaseMSM<ScalarField = F>,
T: Into<u64> + Zero + Copy + Sync,
{
if use_icicle {
#[cfg(feature = "icicle")]
{
let mut backup = vec![];
let gpu_bases = _gpu_bases.unwrap_or_else(|| {
backup = Self::get_gpu_bases(bases);
&backup
});
return icicle_msm::<Self>(_gpu_bases, scalars, max_num_bits);
}
#[cfg(not(feature = "icicle"))]
{
unreachable!("icicle_init must not return true without the icicle feature");
}
}
// if use_icicle {
// #[cfg(feature = "icicle")]
// {
// let mut backup = vec![];
// let gpu_bases = _gpu_bases.unwrap_or_else(|| {
// backup = VariableBaseMSM::get_gpu_bases(bases);
// &backup
// });
// return icicle_msm(gpu_bases, scalars, max_num_bits);
// }
// #[cfg(not(feature = "icicle"))]
// {
// unreachable!("icicle_init must not return true without the icicle feature");
// }
// }

let c = if bases.len() < 32 {
3
Expand Down

0 comments on commit 6ce9ae0

Please sign in to comment.