Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed-up multiplication by small integers, and improve lookup part of compute_quotient_poly() #1153

Closed
wants to merge 10 commits into from
6 changes: 3 additions & 3 deletions evm/src/cross_table_lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use itertools::Itertools;
use plonky2::field::extension::{Extendable, FieldExtension};
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use plonky2::field::types::{Field, SmallPowers};
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::iop::target::Target;
Expand Down Expand Up @@ -79,14 +79,14 @@ impl<F: Field> Column<F> {
}

pub fn le_bits<I: IntoIterator<Item = impl Borrow<usize>>>(cs: I) -> Self {
Self::linear_combination(cs.into_iter().map(|c| *c.borrow()).zip(F::TWO.powers()))
Self::linear_combination(cs.into_iter().map(|c| *c.borrow()).zip(SmallPowers::new(2)))
}

pub fn le_bytes<I: IntoIterator<Item = impl Borrow<usize>>>(cs: I) -> Self {
Self::linear_combination(
cs.into_iter()
.map(|c| *c.borrow())
.zip(F::from_canonical_u16(256).powers()),
.zip(SmallPowers::new(256)),
)
}

Expand Down
13 changes: 9 additions & 4 deletions evm/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use plonky2::field::extension::Extendable;
use plonky2::field::packable::Packable;
use plonky2::field::packed::PackedField;
use plonky2::field::polynomial::{PolynomialCoeffs, PolynomialValues};
use plonky2::field::types::Field;
use plonky2::field::types::{Field, SmallPowers};
use plonky2::field::zero_poly_coset::ZeroPolyOnCoset;
use plonky2::fri::oracle::PolynomialBatch;
use plonky2::hash::hash_types::RichField;
Expand Down Expand Up @@ -498,11 +498,16 @@ where
// When opening the `Z`s polys at the "next" point, need to look at the point `next_step` steps away.
let next_step = 1 << quotient_degree_bits;

let powers_vec: Vec<F> = SmallPowers::new(F::coset_shift().to_noncanonical_u64() as u32)
.take(degree << quotient_degree_bits)
.collect_vec();

// Evaluation of the first Lagrange polynomial on the LDE domain.
let lagrange_first = PolynomialValues::selector(degree, 0).lde_onto_coset(quotient_degree_bits);
let lagrange_first = PolynomialValues::selector(degree, 0)
.lde_onto_coset(quotient_degree_bits, powers_vec.iter());
// Evaluation of the last Lagrange polynomial on the LDE domain.
let lagrange_last =
PolynomialValues::selector(degree, degree - 1).lde_onto_coset(quotient_degree_bits);
let lagrange_last = PolynomialValues::selector(degree, degree - 1)
.lde_onto_coset(quotient_degree_bits, powers_vec.iter());

let z_h_on_coset = ZeroPolyOnCoset::<F>::new(degree_bits, quotient_degree_bits);

Expand Down
2 changes: 2 additions & 0 deletions field/src/arch/x86_64/avx2_goldilocks_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ unsafe impl PackedField for Avx2GoldilocksField {
};
(Self::new(res0), Self::new(res1))
}

// TODO: overriding default Self::mul_u32() may yield interesting speed-ups
}

impl Square for Avx2GoldilocksField {
Expand Down
2 changes: 2 additions & 0 deletions field/src/arch/x86_64/avx512_goldilocks_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ unsafe impl PackedField for Avx512GoldilocksField {
};
(Self::new(res0), Self::new(res1))
}

// TODO: overriding default Self::mul_u32() may yield interesting speed-ups
}

impl Add<Self> for Avx512GoldilocksField {
Expand Down
5 changes: 5 additions & 0 deletions field/src/extension/quadratic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ impl<F: Extendable<2>> Field for QuadraticExtension<F> {
fn from_noncanonical_u64(n: u64) -> Self {
F::from_noncanonical_u64(n).into()
}

/// Multiplies this element by the small integer `x`, applying `mul_u32`
/// to each of the two stored components independently.
#[inline]
fn mul_u32(&self, x: u32) -> Self {
    let [c0, c1] = &self.0;
    Self([c0.mul_u32(x), c1.mul_u32(x)])
}
}

impl<F: Extendable<2>> Display for QuadraticExtension<F> {
Expand Down
10 changes: 10 additions & 0 deletions field/src/extension/quartic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ impl<F: Extendable<4>> Field for QuarticExtension<F> {
fn from_noncanonical_u64(n: u64) -> Self {
F::from_noncanonical_u64(n).into()
}

/// Multiplies this element by the small integer `x`, applying `mul_u32`
/// to each of the four stored components independently.
#[inline]
fn mul_u32(&self, x: u32) -> Self {
    let [c0, c1, c2, c3] = &self.0;
    Self([
        c0.mul_u32(x),
        c1.mul_u32(x),
        c2.mul_u32(x),
        c3.mul_u32(x),
    ])
}
}

impl<F: Extendable<4>> Display for QuarticExtension<F> {
Expand Down
11 changes: 11 additions & 0 deletions field/src/extension/quintic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ impl<F: Extendable<5>> Field for QuinticExtension<F> {
fn from_noncanonical_u64(n: u64) -> Self {
F::from_noncanonical_u64(n).into()
}

/// Multiplies this element by the small integer `x`, applying `mul_u32`
/// to each of the five stored components independently.
#[inline]
fn mul_u32(&self, x: u32) -> Self {
    let [c0, c1, c2, c3, c4] = &self.0;
    Self([
        c0.mul_u32(x),
        c1.mul_u32(x),
        c2.mul_u32(x),
        c3.mul_u32(x),
        c4.mul_u32(x),
    ])
}
}

impl<F: Extendable<5>> Display for QuinticExtension<F> {
Expand Down
6 changes: 6 additions & 0 deletions field/src/goldilocks_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ impl Field for GoldilocksField {
// u64 + u64 * u64 cannot overflow.
reduce128((self.0 as u128) + (x.0 as u128) * (y.0 as u128))
}

/// Multiplies this element by the small integer `x` using a single widening
/// multiplication followed by a 96-bit reduction.
#[inline]
fn mul_u32(&self, x: u32) -> Self {
    // A 64-bit value times a 32-bit value is below 2^96, so the product fits
    // exactly in the (low u64, high u32) pair taken by `from_noncanonical_u96`.
    let product = u128::from(self.0) * u128::from(x);
    let low = product as u64;
    let high = (product >> 64) as u32;
    Self::from_noncanonical_u96((low, high))
}
}

impl PrimeField for GoldilocksField {
Expand Down
5 changes: 5 additions & 0 deletions field/src/packed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ where
fn doubles(&self) -> Self {
*self * Self::Scalar::TWO
}

/// Multiplies `self` by the small integer `x`.
///
/// This default converts `x` into a scalar and performs a regular
/// multiplication; implementors may override it with something cheaper.
#[inline]
fn mul_u32(&self, x: u32) -> Self {
    let factor = Self::Scalar::from_canonical_u32(x);
    *self * factor
}
}

unsafe impl<F: Field> PackedField for F {
Expand Down
12 changes: 10 additions & 2 deletions field/src/polynomial/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,17 @@ impl<F: Field> PolynomialValues<F> {
}

/// Low-degree extend `Self` (seen as evaluations over the subgroup) onto a coset.
pub fn lde_onto_coset(self, rate_bits: usize) -> Self {
pub fn lde_onto_coset<'a, I>(self, rate_bits: usize, shift_powers: I) -> Self
where
I: Iterator<Item = &'a F>,
{
let coeffs = ifft(self).lde(rate_bits);
coeffs.coset_fft_with_options(F::coset_shift(), Some(rate_bits), None)
let modified_poly: PolynomialCoeffs<F> = shift_powers
.zip(&coeffs.coeffs)
.map(|(&r, &c)| r * c)
.collect::<Vec<_>>()
.into();
modified_poly.fft_with_options(Some(rate_bits), None)
}

pub fn degree(&self) -> usize {
Expand Down
35 changes: 35 additions & 0 deletions field/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ pub trait Field:
const CHARACTERISTIC_TWO_ADICITY: usize;

/// Generator of the entire multiplicative group, i.e. all non-zero elements.
///
/// **For prime fields, this element is expected to fit in a u32
/// when canonically reduced.**
const MULTIPLICATIVE_GROUP_GENERATOR: Self;
/// Generator of a multiplicative subgroup of order `2^TWO_ADICITY`.
const POWER_OF_TWO_GENERATOR: Self;
Expand Down Expand Up @@ -444,6 +447,12 @@ pub trait Field:
// Default implementation.
*self + x * y
}

/// Multiplies `self` by the small integer `x`.
///
/// Produces the same result as a regular multiplication by
/// `Self::from_canonical_u32(x)`, but implementors may override it with a
/// cheaper routine.
#[inline]
fn mul_u32(&self, x: u32) -> Self {
    let factor = Self::from_canonical_u32(x);
    *self * factor
}
}

pub trait PrimeField: Field {
Expand Down Expand Up @@ -592,3 +601,29 @@ impl<F: Field> Powers<F> {
}
}
}

/// An iterator similar to `Powers`, but whose base fits in a `u32`.
///
/// Yields `1, base, base^2, ...` as elements of `F`, advancing with
/// `Field::mul_u32` at each step. The iterator never returns `None`.
#[derive(Clone)]
pub struct SmallPowers<F: Field> {
    /// The small integer each successive item is multiplied by.
    base: u32,
    /// The power that the next call to `next` will return.
    current: F,
}

impl<F: Field> SmallPowers<F> {
    /// Creates an iterator over the powers of `base`, starting at `F::ONE`.
    pub fn new(base: u32) -> Self {
        let current = F::ONE;
        Self { base, current }
    }
}

impl<F: Field> Iterator for SmallPowers<F> {
    type Item = F;

    /// Returns the current power and advances the state to the next one.
    fn next(&mut self) -> Option<F> {
        let following = self.current.mul_u32(self.base);
        Some(core::mem::replace(&mut self.current, following))
    }
}
25 changes: 21 additions & 4 deletions plonky2/src/fri/oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use alloc::format;
use alloc::vec::Vec;

use itertools::Itertools;
use plonky2_field::types::Field;
use plonky2_field::types::{Field, SmallPowers};
use plonky2_maybe_rayon::*;

use crate::field::extension::Extendable;
Expand Down Expand Up @@ -112,8 +112,17 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
.par_iter()
.map(|p| {
assert_eq!(p.len(), degree, "Polynomial degrees inconsistent");
p.lde(rate_bits)
.coset_fft_with_options(F::coset_shift(), Some(rate_bits), fft_root_table)
let poly = p.lde(rate_bits);

// Custom coset_fft with small shift
let modified_poly: PolynomialCoeffs<F> =
SmallPowers::<F>::new(F::coset_shift().to_noncanonical_u64() as u32)
.zip(&poly.coeffs)
.map(|(r, &c)| r * c)
.collect::<Vec<_>>()
.into();
modified_poly
.fft_with_options(Some(rate_bits), fft_root_table)
.values
})
.chain(
Expand Down Expand Up @@ -197,10 +206,18 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
}

let lde_final_poly = final_poly.lde(fri_params.config.rate_bits);

// Custom coset_fft with small shift
let modified_poly: PolynomialCoeffs<F::Extension> =
SmallPowers::<F::Extension>::new(F::coset_shift().to_noncanonical_u64() as u32)
.zip(&lde_final_poly.coeffs)
.map(|(r, &c)| r * c)
.collect::<Vec<_>>()
.into();
let lde_final_values = timed!(
timing,
&format!("perform final FFT {}", lde_final_poly.len()),
lde_final_poly.coset_fft(F::coset_shift().into())
modified_poly.fft_with_options(None, None)
);

let fri_proof = fri_proof::<F, C, D>(
Expand Down
3 changes: 2 additions & 1 deletion plonky2/src/gadgets/split_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ impl<F: RichField + Extendable<D>, const B: usize, const D: usize> SimpleGenerat
.map(|&t| witness.get_bool_target(t))
.rev()
.fold(F::ZERO, |acc, limb| {
acc * F::from_canonical_usize(B) + F::from_bool(limb)
let t = acc.to_noncanonical_u64() as u128 * B as u128 + limb as u128;
F::from_noncanonical_u96((t as u64, (t >> 64) as u32))
});

out_buffer.set_target(Target::wire(self.row, BaseSumGate::<B>::WIRE_SUM), sum);
Expand Down
6 changes: 3 additions & 3 deletions plonky2/src/gates/base_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::iop::target::Target;
use crate::iop::witness::{PartitionWitness, Witness, WitnessWrite};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData};
use crate::plonk::plonk_common::{reduce_with_powers, reduce_with_powers_ext_circuit};
use crate::plonk::plonk_common::{reduce_with_powers_ext_circuit, reduce_with_powers_u32};
use crate::plonk::vars::{
EvaluationTargets, EvaluationVars, EvaluationVarsBase, EvaluationVarsBaseBatch,
EvaluationVarsBasePacked,
Expand Down Expand Up @@ -67,7 +67,7 @@ impl<F: RichField + Extendable<D>, const D: usize, const B: usize> Gate<F, D> fo
fn eval_unfiltered(&self, vars: EvaluationVars<F, D>) -> Vec<F::Extension> {
let sum = vars.local_wires[Self::WIRE_SUM];
let limbs = vars.local_wires[self.limbs()].to_vec();
let computed_sum = reduce_with_powers(&limbs, F::Extension::from_canonical_usize(B));
let computed_sum = reduce_with_powers_u32(&limbs, B as u32);
let mut constraints = vec![computed_sum - sum];
for limb in limbs {
constraints.push(
Expand Down Expand Up @@ -156,7 +156,7 @@ impl<F: RichField + Extendable<D>, const D: usize, const B: usize> PackedEvaluab
) {
let sum = vars.local_wires[Self::WIRE_SUM];
let limbs = vars.local_wires.view(self.limbs());
let computed_sum = reduce_with_powers(limbs, F::from_canonical_usize(B));
let computed_sum = reduce_with_powers_u32(limbs, B as u32);

yield_constr.one(computed_sum - sum);

Expand Down
14 changes: 14 additions & 0 deletions plonky2/src/plonk/plonk_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ pub(crate) fn reduce_with_powers_multi<
cumul
}

/// Combines `terms` with successive powers of the small integer `alpha`,
/// i.e. computes `terms[0] + alpha * terms[1] + alpha^2 * terms[2] + ...`.
///
/// The terms are consumed from last to first (Horner's rule), multiplying the
/// accumulator with `mul_u32` at each step. An empty input yields `P::ZEROS`.
pub fn reduce_with_powers_u32<'a, P: PackedField, T: IntoIterator<Item = &'a P>>(
    terms: T,
    alpha: u32,
) -> P
where
    T::IntoIter: DoubleEndedIterator,
{
    terms
        .into_iter()
        .rev()
        .fold(P::ZEROS, |acc, &term| acc.mul_u32(alpha) + term)
}

pub fn reduce_with_powers<'a, P: PackedField, T: IntoIterator<Item = &'a P>>(
terms: T,
alpha: P::Scalar,
Expand Down
22 changes: 18 additions & 4 deletions plonky2/src/plonk/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,11 @@ fn compute_lookup_polys<
let looked_inp = witness.get_wire(row, LookupTableGate::wire_ith_looked_inp(s));
let looked_out = witness.get_wire(row, LookupTableGate::wire_ith_looked_out(s));

looked_inp + deltas[LookupChallenges::ChallengeA as usize] * looked_out
let t = looked_inp.to_noncanonical_u64() as u128
+ deltas[LookupChallenges::ChallengeA as usize].to_noncanonical_u64()
as u128
* looked_out.to_noncanonical_u64() as u128;
F::from_noncanonical_u96((t as u64, (t >> 64) as u32))
})
.collect();
// Get (alpha - combo).
Expand All @@ -483,7 +487,11 @@ fn compute_lookup_polys<
let looked_inp = witness.get_wire(row, LookupTableGate::wire_ith_looked_inp(s));
let looked_out = witness.get_wire(row, LookupTableGate::wire_ith_looked_out(s));

looked_inp + deltas[LookupChallenges::ChallengeB as usize] * looked_out
let t = looked_inp.to_noncanonical_u64() as u128
+ deltas[LookupChallenges::ChallengeB as usize].to_noncanonical_u64()
as u128
* looked_out.to_noncanonical_u64() as u128;
F::from_noncanonical_u96((t as u64, (t >> 64) as u32))
})
.collect();

Expand Down Expand Up @@ -520,7 +528,11 @@ fn compute_lookup_polys<
let looking_in = witness.get_wire(row, LookupGate::wire_ith_looking_inp(s));
let looking_out = witness.get_wire(row, LookupGate::wire_ith_looking_out(s));

looking_in + deltas[LookupChallenges::ChallengeA as usize] * looking_out
let t = looking_in.to_noncanonical_u64() as u128
+ deltas[LookupChallenges::ChallengeA as usize].to_noncanonical_u64()
as u128
* looking_out.to_noncanonical_u64() as u128;
F::from_noncanonical_u96((t as u64, (t >> 64) as u32))
})
.collect();
// Get (alpha - combo).
Expand Down Expand Up @@ -685,7 +697,9 @@ fn compute_quotient_polys<
let mut local_wires_batch_refs = Vec::with_capacity(xs_batch.len());

for (&i, &x) in indices_batch.iter().zip(xs_batch) {
let shifted_x = F::coset_shift() * x;
// F::coset_shift() returns the multiplicative generator,
// which fits in a u32 for `RichField`.
let shifted_x = x.mul_u32(F::coset_shift().to_noncanonical_u64() as u32);
let i_next = (i + next_step) % lde_size;
let local_constants_sigmas = prover_data
.constants_sigmas_commitment
Expand Down
6 changes: 4 additions & 2 deletions plonky2/src/plonk/vanishing_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ pub(crate) fn get_lut_poly<F: RichField + Extendable<D>, const D: usize>(
let b = deltas[LookupChallenges::ChallengeB as usize];
let mut coeffs = Vec::new();
let n = common_data.luts[lut_index].len();
for (input, output) in common_data.luts[lut_index].iter() {
coeffs.push(F::from_canonical_u16(*input) + b * F::from_canonical_u16(*output));
for &(input, output) in common_data.luts[lut_index].iter() {
let t = input as u128 + b.to_noncanonical_u64() as u128 * output as u128;

coeffs.push(F::from_noncanonical_u96((t as u64, (t >> 64) as u32)));
}
coeffs.append(&mut vec![F::ZERO; degree - n]);
coeffs.reverse();
Expand Down
Loading