From 6ce9ae05eeb28ed995c9765fad1cd6de542582a5 Mon Sep 17 00:00:00 2001 From: Michael Zhu Date: Thu, 16 Jan 2025 11:01:07 -0500 Subject: [PATCH] Fix icicle MSM stuff --- jolt-core/src/msm/icicle/adapter.rs | 37 ++++++++++-------- jolt-core/src/msm/mod.rs | 59 ++++++++++++++++------------- 2 files changed, 55 insertions(+), 41 deletions(-) diff --git a/jolt-core/src/msm/icicle/adapter.rs b/jolt-core/src/msm/icicle/adapter.rs index 0d40890ca..6196b3bd6 100644 --- a/jolt-core/src/msm/icicle/adapter.rs +++ b/jolt-core/src/msm/icicle/adapter.rs @@ -1,4 +1,5 @@ -use crate::msm::{GpuBaseType, MsmType, VariableBaseMSM}; +use crate::field::JoltField; +use crate::msm::{GpuBaseType, VariableBaseMSM}; use ark_bn254::G1Projective; use ark_ec::{CurveGroup, ScalarMul}; use ark_ff::{BigInteger, Field, PrimeField}; @@ -60,11 +61,11 @@ pub trait Icicle: ScalarMul { } #[tracing::instrument(skip_all, name = "icicle_msm")] -pub fn icicle_msm( - bases: &[GpuBaseType], - scalars: &[V::ScalarField], - bit_size: usize, -) -> V { +pub fn icicle_msm(bases: &[GpuBaseType], scalars: &[V::ScalarField], max_num_bits: usize) -> V +where + V: VariableBaseMSM, + V::ScalarField: JoltField, +{ assert!(scalars.len() <= bases.len()); let mut bases_slice = DeviceVec::>::device_malloc(bases.len()).unwrap(); @@ -99,7 +100,7 @@ pub fn icicle_msm( cfg.stream_handle = IcicleStreamHandle::from(&stream); cfg.is_async = false; cfg.are_scalars_montgomery_form = true; - cfg.bitsize = bit_size as i32; + cfg.bitsize = max_num_bits as i32; let span = tracing::span!(tracing::Level::INFO, "gpu_msm"); let _guard = span.enter(); @@ -133,11 +134,15 @@ pub fn icicle_msm( /// Batch process msms - assumes batches are equal in size /// Variable Batch sizes is not currently supported by icicle #[tracing::instrument(skip_all)] -pub fn icicle_batch_msm( +pub fn icicle_batch_msm( bases: &[GpuBaseType], scalar_batches: &[&[V::ScalarField]], - batch_type: MsmType, -) -> Vec { + max_num_bits: usize, +) -> Vec +where + V: VariableBaseMSM, + V::ScalarField: JoltField, +{ let bases_len = bases.len(); let batch_size = scalar_batches.len(); assert!(scalar_batches.par_iter().all(|s| s.len() == bases_len)); @@ -192,7 +197,7 @@ pub fn icicle_batch_msm( cfg.is_async = true; cfg.are_scalars_montgomery_form = true; cfg.batch_size = batch_size as i32; - cfg.bitsize = batch_type.num_bits() as i32; + cfg.bitsize = max_num_bits as i32; cfg.ext .set_int(icicle_core::msm::CUDA_MSM_LARGE_BUCKET_FACTOR, 5); @@ -311,7 +316,7 @@ mod tests { let icicle_res = icicle_msm::(&gpu_bases, &scalars, 256); let arkworks_res: G1Projective = ark_VariableBaseMSM::msm(&bases, &scalars).unwrap(); let no_gpu_res: G1Projective = - VariableBaseMSM::inner_msm(&bases, None, &scalars, false, None).unwrap(); + VariableBaseMSM::msm_field_elements(&bases, None, &scalars, None, false).unwrap(); assert_eq!(icicle_res, arkworks_res); assert_eq!(icicle_res, no_gpu_res); @@ -337,15 +342,17 @@ mod tests { .par_iter() .map(|base| ::from_ark_affine(base)) .collect::>(); - let icicle_res = - icicle_batch_msm::(&gpu_bases, &scalar_batches, MsmType::Large(256)); + let icicle_res = icicle_batch_msm::(&gpu_bases, &scalar_batches, 256); let arkworks_res: Vec = (0..20) .into_iter() .map(|_| ark_VariableBaseMSM::msm(&bases, &scalars).unwrap()) .collect(); let no_gpu_res: Vec = (0..20) .into_iter() - .map(|_| VariableBaseMSM::inner_msm(&bases, None, &scalars, false, None).unwrap()) + .map(|_| { + VariableBaseMSM::msm_field_elements(&bases, None, &scalars, None, false) + .unwrap() + }) .collect(); assert_eq!(icicle_res, arkworks_res); diff --git a/jolt-core/src/msm/mod.rs b/jolt-core/src/msm/mod.rs index ee19ad8e2..bb73d57e1 100644 --- a/jolt-core/src/msm/mod.rs +++ b/jolt-core/src/msm/mod.rs @@ -9,7 +9,7 @@ use rayon::prelude::*; pub(crate) mod icicle; use crate::field::JoltField; -use crate::poly::multilinear_polynomial::MultilinearPolynomial; +use crate::poly::{dense_mlpoly::DensePolynomial, multilinear_polynomial::MultilinearPolynomial}; use crate::utils::errors::ProofVerifyError; use crate::utils::math::Math; pub use icicle::*; @@ -244,7 +244,7 @@ where let (cpu_batch, gpu_batch): (Vec<_>, Vec<_>) = polys.par_iter().enumerate().partition_map(|(i, poly)| { let max_num_bits = poly.max_num_bits(); - if use_icicle && max_num_bits > 10 { + if use_icicle && max_num_bits > 64 { Either::Right((i, max_num_bits, *poly)) } else { Either::Left((i, max_num_bits, *poly)) @@ -292,15 +292,22 @@ where // Process GPU batches with memory constraints for work_chunk in gpu_batch.chunks(slices_at_a_time) { - let (scalar_types, chunk_scalars): (Vec<_>, Vec<&[Self::ScalarField]>) = - work_chunk - .par_iter() - .map(|(_, msm_type, scalars)| (*msm_type, *scalars)) - .unzip(); - - let max_scalar_type = scalar_types.par_iter().max().unwrap(); + let (max_num_bits, chunk_polys): (Vec<_>, Vec<_>) = work_chunk + .par_iter() + .map(|(_, max_num_bits, poly)| (*max_num_bits, *poly)) + .unzip(); + + let max_num_bits = max_num_bits.iter().max().unwrap(); + let scalars: Vec<_> = chunk_polys + .into_iter() + .map(|poly| { + let poly: &DensePolynomial = + poly.try_into().unwrap(); + poly.evals_ref() + }) + .collect(); let batch_results = - icicle_batch_msm(gpu_bases, &chunk_scalars, *max_scalar_type); + icicle_batch_msm(gpu_bases, &scalars, *max_num_bits as usize); // Store GPU results using original indices for ((original_idx, _, _), result) in work_chunk.iter().zip(batch_results) { @@ -562,28 +569,28 @@ fn msm_medium( _gpu_bases: Option<&[GpuBaseType]>, scalars: &[T], max_num_bits: usize, - use_icicle: bool, + _use_icicle: bool, ) -> V where F: JoltField, V: VariableBaseMSM, T: Into + Zero + Copy + Sync, { - if use_icicle { - #[cfg(feature = "icicle")] - { - let mut backup = vec![]; - let gpu_bases = _gpu_bases.unwrap_or_else(|| { - backup = Self::get_gpu_bases(bases); - &backup - }); - return icicle_msm::(_gpu_bases, scalars, max_num_bits); - } - #[cfg(not(feature = "icicle"))] - { - unreachable!("icicle_init must not return true without the icicle feature"); - } - } + // if use_icicle { + // #[cfg(feature = "icicle")] + // { + // let mut backup = vec![]; + // let gpu_bases = _gpu_bases.unwrap_or_else(|| { + // backup = VariableBaseMSM::get_gpu_bases(bases); + // &backup + // }); + // return icicle_msm(gpu_bases, scalars, max_num_bits); + // } + // #[cfg(not(feature = "icicle"))] + // { + // unreachable!("icicle_init must not return true without the icicle feature"); + // } + // } let c = if bases.len() < 32 { 3