diff --git a/src/msm.rs b/src/msm.rs index e30777bc..2240eece 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -522,6 +522,112 @@ pub fn multiexp_serial(coeffs: &[C::Scalar], bases: &[C], acc: & } } +pub fn multiexp_serial_skip_zeros( + coeffs: &[C::Scalar], + bases: &[C], + acc: &mut C::Curve, +) { + // Do conversion to bytes once + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + + let c = if bases.len() < 4 { + 1 + } else if bases.len() < 32 { + 3 + } else { + (f64::from(bases.len() as u32)).ln().ceil() as usize + }; + + let number_of_windows = C::Scalar::NUM_BITS as usize / c + 1; + // println!("c = {}, num_win = {}", c, number_of_windows); + + // In each window, get the booth index of each coefficient + let mut coeffs_in_windows = Vec::with_capacity(number_of_windows); + // Track what is the last window where we actually have nonzero booth index, so we completely skip buckets where the scalar bits for all coeffs are 0 + let mut max_nonzero_window = None; + for current_window in 0..number_of_windows { + let coeffs_in_window: Vec = coeffs + .iter() + .map(|coeff| { + let coeff = get_booth_index(current_window, c, coeff.as_ref()); + if coeff != 0 { + max_nonzero_window = Some(current_window); + } + coeff + }) + .collect(); + coeffs_in_windows.push(coeffs_in_window); + } + // Save memory and drop coeffs as bytes since it's not needed anymore + drop(coeffs); + + if max_nonzero_window.is_none() { + return; + } + // println!("max_nonzero_win = {:?}", max_nonzero_window); + for coeffs_in_window in coeffs_in_windows + .into_iter() + .take(max_nonzero_window.unwrap() + 1) + .rev() + { + for _ in 0..c { + *acc = acc.double(); + } + + #[derive(Clone, Copy)] + enum Bucket { + None, + Affine(C), + Projective(C::Curve), + } + + impl Bucket { + fn add_assign(&mut self, other: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*other), + Bucket::Affine(a) => Bucket::Projective(a + *other), + Bucket::Projective(mut a) => { + a += *other; + Bucket::Projective(a) + } + } + } + + fn add(self, mut other: C::Curve) -> C::Curve { + match self { + Bucket::None => other, + Bucket::Affine(a) => { + other += a; + other + } + Bucket::Projective(a) => other + a, + } + } + } + + let mut buckets: Vec> = vec![Bucket::None; 1 << (c - 1)]; + + for (coeff, base) in coeffs_in_window.into_iter().zip(bases.iter()) { + if coeff.is_positive() { + buckets[coeff as usize - 1].add_assign(base); + } + if coeff.is_negative() { + buckets[coeff.unsigned_abs() as usize - 1].add_assign(&base.neg()); + } + } + + // Summation by parts + // e.g. 3a + 2b + 1c = a + + // (a) + b + + // ((a) + b) + c + let mut running_sum = C::Curve::identity(); + for exp in buckets.into_iter().rev() { + running_sum = exp.add(running_sum); + *acc += &running_sum; + } + } +} + pub fn best_multiexp_bit(coeffs: &BitSlice, bases: &[C]) -> C::Curve { assert_eq!(coeffs.len(), bases.len()); @@ -583,6 +689,40 @@ pub fn best_multiexp_small( } } +/// Performs a multi-exponentiation operation. +/// +/// This function will panic if coeffs and bases have a different length. +/// +/// This will use multithreading if beneficial. +pub fn best_multiexp_skip_zeros(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + assert_eq!(coeffs.len(), bases.len()); + + let num_threads = rayon::current_num_threads(); + if coeffs.len() > num_threads { + let chunk = coeffs.len() / num_threads; + let num_chunks = coeffs.chunks(chunk).len(); + let mut results = vec![C::Curve::identity(); num_chunks]; + rayon::scope(|scope| { + let chunk = coeffs.len() / num_threads; + + for ((coeffs, bases), acc) in coeffs + .chunks(chunk) + .zip(bases.chunks(chunk)) + .zip(results.iter_mut()) + { + scope.spawn(move |_| { + multiexp_serial_skip_zeros(coeffs, bases, acc); + }); + } + }); + results.iter().fold(C::Curve::identity(), |a, b| a + b) + } else { + let mut acc = C::Curve::identity(); + multiexp_serial_skip_zeros(coeffs, bases, &mut acc); + acc + } +} + /// Performs a multi-exponentiation operation. /// /// This function will panic if coeffs and bases have a different length. @@ -916,6 +1056,7 @@ mod test { let points = affine_points; const BYTES: usize = 1; + println!("bits = {}", BYTES * 8); assert!(BYTES <= 16); let max_val = 2u128.pow((BYTES * 8) as u32); let mut scalars = vec![C::Scalar::ZERO; 1 << max_k]; @@ -923,7 +1064,7 @@ mod test { for i in 0..1 << max_k { let v_lo = OsRng.next_u64() as u128; let v_hi = OsRng.next_u64() as u128; - let mut v = v_lo + v_hi << 64; + let mut v = v_lo + (v_hi << 64); if BYTES < 16 { v %= max_val; } @@ -950,8 +1091,8 @@ mod test { end_timer!(t1); // assert_eq!(e0, e1); - let t11 = start_timer!(|| format!("older_prime k={}", k)); - let e11 = super::best_multiexp_prime(scalars, points); + let t11 = start_timer!(|| format!("older_skip_zeros k={}", k)); + let e11 = super::best_multiexp_skip_zeros(scalars, points); end_timer!(t11); assert_eq!(e11, e1); @@ -970,7 +1111,7 @@ mod test { #[test] fn test_msm_cross_small() { - run_msm_cross_small::(14, 22); + run_msm_cross_small::(18, 22); // run_msm_cross::(18, 20); } }