From 9eaccbb8578f3c1741d4961eef7405a4e80a9dfe Mon Sep 17 00:00:00 2001 From: einar-taiko <126954546+einar-taiko@users.noreply.github.com> Date: Mon, 26 Jun 2023 23:24:24 +0900 Subject: [PATCH] Resolve Prover optimization: memory reduction #77 (#6) * Resolve taikoxyz/zkevm-circuits#77 * Please Clippy --- halo2_proofs/Cargo.toml | 1 + halo2_proofs/src/plonk.rs | 11 +- halo2_proofs/src/plonk/circuit.rs | 67 ++ halo2_proofs/src/plonk/evaluation.rs | 980 ++++++++++++++----- halo2_proofs/src/plonk/keygen.rs | 16 +- halo2_proofs/src/plonk/permutation.rs | 7 +- halo2_proofs/src/plonk/permutation/keygen.rs | 14 +- halo2_proofs/src/plonk/permutation/prover.rs | 4 - halo2_proofs/src/poly/domain.rs | 337 +++++++ 9 files changed, 1130 insertions(+), 307 deletions(-) diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index 7f8d8652d8..6c37821353 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -48,6 +48,7 @@ name = "fft" harness = false [dependencies] +itertools = "0.10" backtrace = { version = "0.3", optional = true } rayon = "1.5.1" ff = "0.13" diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index a32f2342f2..84a762c4a2 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -371,12 +371,11 @@ pub struct PinnedVerificationKey<'a, C: CurveAffine> { #[derive(Clone, Debug)] pub struct ProvingKey { vk: VerifyingKey, - l0: Polynomial, - l_last: Polynomial, - l_active_row: Polynomial, + l0: Polynomial, + l_last: Polynomial, + l_active_row: Polynomial, fixed_values: Vec>, fixed_polys: Vec>, - fixed_cosets: Vec>, permutation: permutation::ProvingKey, ev: Evaluator, } @@ -398,7 +397,6 @@ where + scalar_len * (self.l0.len() + self.l_last.len() + self.l_active_row.len()) + polynomial_slice_byte_length(&self.fixed_values) + polynomial_slice_byte_length(&self.fixed_polys) - + polynomial_slice_byte_length(&self.fixed_cosets) + self.permutation.bytes_length() } } @@ -424,7 +422,6 @@ where self.l_active_row.write(writer, format)?; write_polynomial_slice(&self.fixed_values, writer, format)?; write_polynomial_slice(&self.fixed_polys, writer, format)?; - write_polynomial_slice(&self.fixed_cosets, writer, format)?; self.permutation.write(writer, format)?; Ok(()) } @@ -456,7 +453,6 @@ where let l_active_row = Polynomial::read(reader, format)?; let fixed_values = read_polynomial_vec(reader, format)?; let fixed_polys = read_polynomial_vec(reader, format)?; - let fixed_cosets = read_polynomial_vec(reader, format)?; let permutation = permutation::ProvingKey::read(reader, format)?; let ev = Evaluator::new(vk.cs()); Ok(Self { @@ -466,7 +462,6 @@ where l_active_row, fixed_values, fixed_polys, - fixed_cosets, permutation, ev, }) diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index ae3f1eb7ae..4f48b21b0b 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -7,6 +7,7 @@ use crate::{ use core::cmp::max; use core::ops::{Add, Mul}; use ff::Field; +use itertools::Itertools; use sealed::SealedPhase; use std::cmp::Ordering; use std::collections::HashMap; @@ -1248,6 +1249,72 @@ impl Expression { &|a, _| a, ) } + + /// Extracts all used instance columns in this expression + pub fn extract_instances(&self) -> Vec { + self.evaluate( + &|_| vec![], + &|_| vec![], + &|_| vec![], + &|_| vec![], + &|query| vec![query.column_index], + &|_| vec![], + &|a| a, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|a, _| a, + ) + } + + /// Extracts all used advice columns in this expression + pub fn extract_advices(&self) -> Vec { + self.evaluate( + &|_| vec![], + &|_| vec![], + &|_| vec![], + &|query| vec![query.column_index], + &|_| vec![], + &|_| vec![], + &|a| a, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|a, _| a, + ) + } + + /// Extracts all used fixed columns in this expression + pub fn extract_fixed(&self) -> Vec { + self.evaluate( + &|_| vec![], + &|_| vec![], + &|query| vec![query.column_index], + &|_| vec![], + &|_| vec![], + &|_| vec![], + &|a| a, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|mut a, b| { + a.extend(b); + a.into_iter().unique().collect() + }, + &|a, _| a, + ) + } } impl std::fmt::Debug for Expression { diff --git a/halo2_proofs/src/plonk/evaluation.rs b/halo2_proofs/src/plonk/evaluation.rs index 1c05a261a0..cc99277cf8 100644 --- a/halo2_proofs/src/plonk/evaluation.rs +++ b/halo2_proofs/src/plonk/evaluation.rs @@ -16,10 +16,12 @@ use group::{ ff::{BatchInvert, Field, PrimeField, WithSmallOrderMulGroup}, Curve, }; +use itertools::Itertools; use std::any::TypeId; use std::convert::TryInto; use std::num::ParseIntError; use std::slice; +use std::time::Instant; use std::{ collections::BTreeMap, iter, @@ -55,9 +57,7 @@ pub enum ValueSource { /// theta Theta(), /// y - Y(), - /// Previous value - PreviousValue(), + Y(usize), } impl Default for ValueSource { @@ -73,34 +73,35 @@ impl ValueSource { rotations: &[usize], constants: &[F], intermediates: &[F], - fixed_values: &[Polynomial], - advice_values: &[Polynomial], - instance_values: &[Polynomial], + fixed_values: &[Option>], + advice_values: &[Option>], + instance_values: &[Option>], challenges: &[F], + y_powers: &[F], beta: &F, gamma: &F, theta: &F, - y: &F, - previous_value: &F, ) -> F { match self { ValueSource::Constant(idx) => constants[*idx], ValueSource::Intermediate(idx) => intermediates[*idx], ValueSource::Fixed(column_index, rotation) => { - fixed_values[*column_index][rotations[*rotation]] + assert!(fixed_values[*column_index].is_some()); + fixed_values[*column_index].as_ref().unwrap()[rotations[*rotation]] } ValueSource::Advice(column_index, rotation) => { - advice_values[*column_index][rotations[*rotation]] + assert!(advice_values[*column_index].is_some()); + advice_values[*column_index].as_ref().unwrap()[rotations[*rotation]] } ValueSource::Instance(column_index, rotation) => { - instance_values[*column_index][rotations[*rotation]] + assert!(instance_values[*column_index].is_some()); + instance_values[*column_index].as_ref().unwrap()[rotations[*rotation]] } ValueSource::Challenge(index) => challenges[*index], ValueSource::Beta() => *beta, ValueSource::Gamma() => *gamma, ValueSource::Theta() => *theta, - ValueSource::Y() => *y, - ValueSource::PreviousValue() => *previous_value, + ValueSource::Y(idx) => y_powers[*idx], } } } @@ -133,15 +134,14 @@ impl Calculation { rotations: &[usize], constants: &[F], intermediates: &[F], - fixed_values: &[Polynomial], - advice_values: &[Polynomial], - instance_values: &[Polynomial], + fixed_values: &[Option>], + advice_values: &[Option>], + instance_values: &[Option>], challenges: &[F], + y_powers: &[F], beta: &F, gamma: &F, theta: &F, - y: &F, - previous_value: &F, ) -> F { let get_value = |value: &ValueSource| { value.get( @@ -152,11 +152,10 @@ impl Calculation { advice_values, instance_values, challenges, + y_powers, beta, gamma, theta, - y, - previous_value, ) }; match self { @@ -179,13 +178,41 @@ impl Calculation { } } +#[derive(Clone, Default, Debug)] +struct ConstraintCluster { + /// Used fixed columns in each cluster + used_fixed_columns: Vec, + /// Used instance columns in each cluster + used_instance_columns: Vec, + /// Used advice columns in each cluster + used_advice_columns: Vec, + /// Custom gates evalution + evaluator: GraphEvaluator, + /// The first index of constraints are being evaluated at in each cluster + first_constraint_idx: usize, + /// The last index of constraints are being evaluated at in each cluster + last_constraint_idx: usize, + /// The last value source + last_value_source: Option, +} + /// Evaluator #[derive(Clone, Default, Debug)] pub struct Evaluator { - /// Custom gates evalution - pub custom_gates: GraphEvaluator, - /// Lookups evalution - pub lookups: Vec>, + /// list of constraint clusters + custom_gate_clusters: Vec>, + /// Number of custom gate constraints + num_custom_gate_constraints: usize, + /// Lookups evalution, degree, used instance and advice columns + #[allow(clippy::type_complexity)] + lookups: Vec<( + GraphEvaluator, + usize, + (Vec, Vec, Vec), + )>, + + /// Powers of y + num_y_powers: usize, } /// GraphEvaluator @@ -219,46 +246,122 @@ pub struct CalculationInfo { pub target: usize, } +fn merge_unique(a: Vec, b: Vec) -> Vec { + let mut result = a; + result.extend(b); + result.into_iter().unique().collect() +} + impl Evaluator { /// Creates a new evaluation structure pub fn new(cs: &ConstraintSystem) -> Self { let mut ev = Evaluator::default(); + let mut constraint_idx = 0; + + // Compute the max cluster index + let quotient_poly_degree = (cs.degree() - 1) as u64; + let mut max_cluster_idx = 0; + while (1 << max_cluster_idx) < quotient_poly_degree { + max_cluster_idx += 1; + } + + ev.custom_gate_clusters + .resize(max_cluster_idx + 1, ConstraintCluster::default()); // Custom gates - let mut parts = Vec::new(); for gate in cs.gates.iter() { - parts.extend( - gate.polynomials() - .iter() - .map(|poly| ev.custom_gates.add_expression(poly)), - ); + for poly in gate.polynomials() { + constraint_idx += 1; + let cluster_idx = Self::compute_cluster_idx(poly.degree(), max_cluster_idx); + let custom_gate_cluster = &mut ev.custom_gate_clusters[cluster_idx]; + custom_gate_cluster.used_fixed_columns = merge_unique( + custom_gate_cluster.used_fixed_columns.clone(), + poly.extract_fixed(), + ); + custom_gate_cluster.used_instance_columns = merge_unique( + custom_gate_cluster.used_instance_columns.clone(), + poly.extract_instances(), + ); + custom_gate_cluster.used_advice_columns = merge_unique( + custom_gate_cluster.used_advice_columns.clone(), + poly.extract_advices(), + ); + let curr = custom_gate_cluster.evaluator.add_expression(poly); + if let Some(last) = custom_gate_cluster.last_value_source { + custom_gate_cluster.last_value_source = Some( + custom_gate_cluster + .evaluator + .add_calculation(Calculation::Horner( + last, + vec![curr], + ValueSource::Y( + constraint_idx - custom_gate_cluster.last_constraint_idx, + ), + )), + ); + } else { + assert_eq!(custom_gate_cluster.last_constraint_idx, 0); + custom_gate_cluster.last_value_source = Some(curr); + custom_gate_cluster.first_constraint_idx = constraint_idx; + } + custom_gate_cluster.last_constraint_idx = constraint_idx; + } } - ev.custom_gates.add_calculation(Calculation::Horner( - ValueSource::PreviousValue(), - parts, - ValueSource::Y(), - )); + + ev.num_custom_gate_constraints = constraint_idx; // Lookups for lookup in cs.lookups.iter() { + constraint_idx += 5; let mut graph = GraphEvaluator::default(); let mut evaluate_lc = |expressions: &Vec>| { + let mut max_degree = 0; + let mut used_fixed_columns = vec![]; + let mut used_instance_columns = vec![]; + let mut used_advice_columns = vec![]; let parts = expressions .iter() - .map(|expr| graph.add_expression(expr)) + .map(|expr| { + max_degree = max_degree.max(expr.degree()); + used_fixed_columns = + merge_unique(used_fixed_columns.clone(), expr.extract_fixed()); + used_instance_columns = + merge_unique(used_instance_columns.clone(), expr.extract_instances()); + used_advice_columns = + merge_unique(used_advice_columns.clone(), expr.extract_advices()); + graph.add_expression(expr) + }) .collect(); - graph.add_calculation(Calculation::Horner( - ValueSource::Constant(0), - parts, - ValueSource::Theta(), - )) + ( + graph.add_calculation(Calculation::Horner( + ValueSource::Constant(0), + parts, + ValueSource::Theta(), + )), + max_degree, + used_fixed_columns, + used_instance_columns, + used_advice_columns, + ) }; // Input coset - let compressed_input_coset = evaluate_lc(&lookup.input_expressions); + let ( + compressed_input_coset, + max_input_degree, + input_used_fixed, + input_used_instances, + input_used_advices, + ) = evaluate_lc(&lookup.input_expressions); // table coset - let compressed_table_coset = evaluate_lc(&lookup.table_expressions); + let ( + compressed_table_coset, + max_table_degree, + table_used_fixed, + table_used_instances, + table_used_advices, + ) = evaluate_lc(&lookup.table_expressions); // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) let right_gamma = graph.add_calculation(Calculation::Add( compressed_table_coset, @@ -269,10 +372,21 @@ impl Evaluator { ValueSource::Beta(), )); graph.add_calculation(Calculation::Mul(lc, right_gamma)); - - ev.lookups.push(graph); + ev.lookups.push(( + graph, + max_input_degree + max_table_degree, + ( + merge_unique(input_used_fixed, table_used_fixed), + merge_unique(input_used_instances, table_used_instances), + merge_unique(input_used_advices, table_used_advices), + ), + )); } + // Count the constraints in permutation + let num_sets = (cs.permutation.get_columns().len() + (cs.degree() - 3)) / (cs.degree() - 2); + constraint_idx += 1 + num_sets * 2; + ev.num_y_powers = constraint_idx + 10; ev } @@ -291,242 +405,582 @@ impl Evaluator { permutations: &[permutation::prover::Committed], ) -> Polynomial { let domain = &pk.vk.domain; - let size = domain.extended_len(); - let rot_scale = 1 << (domain.extended_k() - domain.k()); - let fixed = &pk.fixed_cosets[..]; + let size = 1 << domain.k() as usize; + let rot_scale = 1; let extended_omega = domain.get_extended_omega(); + let omega = domain.get_omega(); let isize = size as i32; let one = C::ScalarExt::ONE; - let l0 = &pk.l0; - let l_last = &pk.l_last; - let l_active_row = &pk.l_active_row; let p = &pk.vk.cs.permutation; + let num_parts = domain.extended_len() >> domain.k(); + let num_clusters = (domain.extended_k() - domain.k() + 1) as usize; - // Calculate the advice and instance cosets - let start = start_measure("cosets", false); - let advice: Vec>> = advice_polys - .iter() - .map(|advice_polys| { - advice_polys - .iter() - .map(|poly| domain.coeff_to_extended(poly.clone())) - .collect() - }) - .collect(); - let instance: Vec>> = instance_polys - .iter() - .map(|instance_polys| { - instance_polys - .iter() - .map(|poly| domain.coeff_to_extended(poly.clone())) - .collect() - }) - .collect(); - stop_measure(start); + assert!(self.custom_gate_clusters.len() <= num_clusters); - let mut values = domain.empty_extended(); + // Initialize the the powers of y and constraint counter + let mut y_powers = vec![C::ScalarExt::ONE; self.num_y_powers * instance_polys.len()]; + for i in 1..self.num_y_powers { + y_powers[i] = y_powers[i - 1] * y; + } - // Core expression evaluations - let num_threads = multicore::current_num_threads(); - for (((advice, instance), lookups), permutation) in advice - .iter() - .zip(instance.iter()) - .zip(lookups.iter()) - .zip(permutations.iter()) + let need_to_compute = |part_idx, cluster_idx| part_idx % (num_parts >> cluster_idx) == 0; + let compute_part_idx_in_cluster = + |part_idx, cluster_idx| part_idx >> (num_clusters - cluster_idx - 1); + + let mut value_part_clusters = Vec::new(); + value_part_clusters.resize(num_clusters, Vec::new()); + for (cluster_idx, cluster) in value_part_clusters + .iter_mut() + .enumerate() + .take(num_clusters) { - // Custom gates - let start = start_measure("custom gates", false); - multicore::scope(|scope| { - let chunk_size = (size + num_threads - 1) / num_threads; - for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() { - let start = thread_idx * chunk_size; - scope.spawn(move |_| { - let mut eval_data = self.custom_gates.instance(); - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - *value = self.custom_gates.evaluate( - &mut eval_data, - fixed, - advice, - instance, - challenges, - &beta, - &gamma, - &theta, - &y, - value, - idx, - rot_scale, - isize, - ); + cluster.resize(1 << cluster_idx, domain.empty_lagrange()); + } + + // Calculate the quotient polynomial for each part + let mut current_extended_omega = one; + for part_idx in 0..num_parts { + let mut fixed: Vec>> = + vec![None; pk.fixed_polys.len()]; + let l0 = domain.coeff_to_extended_part(pk.l0.clone(), current_extended_omega); + let l_last = domain.coeff_to_extended_part(pk.l_last.clone(), current_extended_omega); + let l_active_row = + domain.coeff_to_extended_part(pk.l_active_row.clone(), current_extended_omega); + + let mut constraint_idx = 0; + let mut cluster_last_constraint_idx = vec![0; num_clusters]; + + // Core expression evaluations + let num_threads = multicore::current_num_threads(); + for (((advice_polys, instance_polys), lookups), permutation) in advice_polys + .iter() + .zip(instance_polys.iter()) + .zip(lookups.iter()) + .zip(permutations.iter()) + { + // Calculate the advice and instance cosets + let mut advice: Vec>> = + vec![None; advice_polys.len()]; + let mut instance: Vec>> = + vec![None; instance_polys.len()]; + + // Custom gates + let start = start_measure("custom gates", false); + for (cluster_idx, custom_gates) in self.custom_gate_clusters.iter().enumerate() { + if !need_to_compute(part_idx, cluster_idx) + || custom_gates.last_value_source.is_none() + { + continue; + } + let values = &mut value_part_clusters[cluster_idx] + [compute_part_idx_in_cluster(part_idx, cluster_idx)]; + for fixed_idx in custom_gates.used_fixed_columns.iter() { + if fixed[*fixed_idx].is_none() { + fixed[*fixed_idx] = Some(domain.coeff_to_extended_part( + pk.fixed_polys[*fixed_idx].clone(), + current_extended_omega, + )); + } + } + for instance_idx in custom_gates.used_instance_columns.iter() { + if instance[*instance_idx].is_none() { + instance[*instance_idx] = Some(domain.coeff_to_extended_part( + instance_polys[*instance_idx].clone(), + current_extended_omega, + )); + } + } + for advice_idx in custom_gates.used_advice_columns.iter() { + if advice[*advice_idx].is_none() { + advice[*advice_idx] = Some(domain.coeff_to_extended_part( + advice_polys[*advice_idx].clone(), + current_extended_omega, + )); + } + } + let fixed_slice = &fixed[..]; + let advice_slice = &advice[..]; + let instance_slice = &instance[..]; + let y_power_slice = &y_powers[..]; + let y_power = y_powers[constraint_idx + custom_gates.first_constraint_idx + - cluster_last_constraint_idx[cluster_idx]]; + multicore::scope(|scope| { + let chunk_size = (size + num_threads - 1) / num_threads; + for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() { + let start = thread_idx * chunk_size; + scope.spawn(move |_| { + let mut eval_data = custom_gates.evaluator.instance(); + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + *value = *value * y_power + + custom_gates.evaluator.evaluate( + &mut eval_data, + fixed_slice, + advice_slice, + instance_slice, + challenges, + y_power_slice, + &beta, + &gamma, + &theta, + idx, + rot_scale, + isize, + ); + } + }); } }); + + // Update the constraint index + cluster_last_constraint_idx[cluster_idx] = + constraint_idx + custom_gates.last_constraint_idx; } - }); - stop_measure(start); - - // Permutations - let start = start_measure("permutations", false); - let sets = &permutation.sets; - if !sets.is_empty() { - let blinding_factors = pk.vk.cs.blinding_factors(); - let last_rotation = Rotation(-((blinding_factors + 1) as i32)); - let chunk_len = pk.vk.cs.degree() - 2; - let delta_start = beta * &C::Scalar::ZETA; - - let first_set = sets.first().unwrap(); - let last_set = sets.last().unwrap(); - - // Permutation constraints - parallelize(&mut values, |values, start| { - let mut beta_term = extended_omega.pow_vartime(&[start as u64, 0, 0, 0]); - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - let r_next = get_rotation_idx(idx, 1, rot_scale, isize); - let r_last = get_rotation_idx(idx, last_rotation.0, rot_scale, isize); - - // Enforce only for the first set. - // l_0(X) * (1 - z_0(X)) = 0 - *value = *value * y - + ((one - first_set.permutation_product_coset[idx]) * l0[idx]); - // Enforce only for the last set. - // l_last(X) * (z_l(X)^2 - z_l(X)) = 0 - *value = *value * y - + ((last_set.permutation_product_coset[idx] - * last_set.permutation_product_coset[idx] - - last_set.permutation_product_coset[idx]) - * l_last[idx]); - // Except for the first set, enforce. - // l_0(X) * (z_i(X) - z_{i-1}(\omega^(last) X)) = 0 - for (set_idx, set) in sets.iter().enumerate() { - if set_idx != 0 { - *value = *value * y - + ((set.permutation_product_coset[idx] - - permutation.sets[set_idx - 1].permutation_product_coset - [r_last]) - * l0[idx]); + constraint_idx += self.num_custom_gate_constraints; + stop_measure(start); + + + // Permutations + let start = start_measure("permutations", false); + let sets = &permutation.sets; + if !sets.is_empty() { + let blinding_factors = pk.vk.cs.blinding_factors(); + let last_rotation = Rotation(-((blinding_factors + 1) as i32)); + let chunk_len = pk.vk.cs.degree() - 2; + let delta_start = beta * &C::Scalar::ZETA; + + let permutation_product_cosets: Vec> = + sets.iter() + .map(|set| { + domain.coeff_to_extended_part( + set.permutation_product_poly.clone(), + current_extended_omega, + ) + }) + .collect(); + + let first_set_permutation_product_coset = + permutation_product_cosets.first().unwrap(); + let last_set_permutation_product_coset = + permutation_product_cosets.last().unwrap(); + + // Permutation constraints + constraint_idx += 1; + if need_to_compute(part_idx, 1) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; + parallelize( + &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // Enforce only for the first set. + // l_0(X) * (1 - z_0(X)) = 0, degree = 2 + *value = *value * y_power + + ((one - first_set_permutation_product_coset[idx]) + * l0[idx]); + } + }, + ); + cluster_last_constraint_idx[1] = constraint_idx; + } + + constraint_idx += 1; + if need_to_compute(part_idx, 2) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; + parallelize( + &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // Enforce only for the last set. + // l_last(X) * (z_l(X)^2 - z_l(X)) = 0, degree = 3 + *value = *value * y_power + + ((last_set_permutation_product_coset[idx] + * last_set_permutation_product_coset[idx] + - last_set_permutation_product_coset[idx]) + * l_last[idx]); + } + }, + ); + cluster_last_constraint_idx[2] = constraint_idx; + } + + constraint_idx += sets.len() - 1; + if need_to_compute(part_idx, 1) { + let y_skip = y_powers + [constraint_idx + 1 - sets.len() - cluster_last_constraint_idx[1]]; + parallelize( + &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // Except for the first set, enforce. + // l_0(X) * (z_i(X) - z_{i-1}(\omega^(last) X)) = 0, degree = 2 + let r_last = + get_rotation_idx(idx, last_rotation.0, rot_scale, isize); + + *value = *value * y_skip; + + for (set_idx, permutation_product_coset) in + permutation_product_cosets.iter().enumerate() + { + if set_idx != 0 { + *value = *value * y + + ((permutation_product_coset[idx] + - permutation_product_cosets[set_idx - 1] + [r_last]) + * l0[idx]); + } + } + } + }, + ); + cluster_last_constraint_idx[1] = constraint_idx; + } + + constraint_idx += sets.len(); + let running_prod_cluster = + Self::compute_cluster_idx(2 + chunk_len, num_clusters - 1); + if need_to_compute(part_idx, running_prod_cluster) { + for column in p.columns.iter() { + match column.column_type() { + Any::Advice(_) => { + let advice = &mut advice[column.index()]; + if (*advice).is_none() { + *advice = Some(domain.coeff_to_extended_part( + advice_polys[column.index()].clone(), + current_extended_omega, + )); + } + } + Any::Instance => { + let instance = &mut instance[column.index()]; + if instance.is_none() { + *instance = Some(domain.coeff_to_extended_part( + instance_polys[column.index()].clone(), + current_extended_omega, + )); + } + } + Any::Fixed => { + let fixed = &mut fixed[column.index()]; + if fixed.is_none() { + *fixed = Some(domain.coeff_to_extended_part( + pk.fixed_polys[column.index()].clone(), + current_extended_omega, + )); + } + } } } - // And for all the sets we enforce: - // (1 - (l_last(X) + l_blind(X))) * ( - // z_i(\omega X) \prod_j (p(X) + \beta s_j(X) + \gamma) - // - z_i(X) \prod_j (p(X) + \delta^j \beta X + \gamma) - // ) - let mut current_delta = delta_start * beta_term; - for ((set, columns), cosets) in sets + + let permutation_cosets: Vec> = pk + .permutation + .polys .iter() - .zip(p.columns.chunks(chunk_len)) - .zip(pk.permutation.cosets.chunks(chunk_len)) - { - let mut left = set.permutation_product_coset[r_next]; - for (values, permutation) in columns - .iter() - .map(|&column| match column.column_type() { - Any::Advice(_) => &advice[column.index()], - Any::Fixed => &fixed[column.index()], - Any::Instance => &instance[column.index()], - }) - .zip(cosets.iter()) - { - left *= values[idx] + beta * permutation[idx] + gamma; - } + .map(|p| { + domain.coeff_to_extended_part(p.clone(), current_extended_omega) + }) + .collect(); + + let y_skip = y_powers[constraint_idx + - sets.len() + - cluster_last_constraint_idx[running_prod_cluster]]; + + parallelize( + &mut value_part_clusters[running_prod_cluster] + [compute_part_idx_in_cluster(part_idx, running_prod_cluster)], + |values, start| { + let mut beta_term = current_extended_omega + * omega.pow_vartime(&[start as u64, 0, 0, 0]); + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let r_next = get_rotation_idx(idx, 1, rot_scale, isize); + + *value = *value * y_skip; + + // And for all the sets we enforce: + // (1 - (l_last(X) + l_blind(X))) * ( + // z_i(\omega X) \prod_j (p(X) + \beta s_j(X) + \gamma) + // - z_i(X) \prod_j (p(X) + \delta^j \beta X + \gamma) + // ), degree = 2 + chunk_len + let mut current_delta = delta_start * beta_term; + for ( + (columns, permutation_product_coset), + permutation_coset_chunk, + ) in p + .columns + .chunks(chunk_len) + .zip(permutation_product_cosets.iter()) + .zip(permutation_cosets.chunks(chunk_len)) + { + let mut left = permutation_product_coset[r_next]; + for (values, permutation) in columns + .iter() + .map(|&column| match column.column_type() { + Any::Advice(_) => { + advice[column.index()].as_ref().unwrap() + } + Any::Fixed => { + fixed[column.index()].as_ref().unwrap() + } + Any::Instance => { + instance[column.index()].as_ref().unwrap() + } + }) + .zip(permutation_coset_chunk.iter()) + { + left *= values[idx] + beta * permutation[idx] + gamma; + } + + let mut right = permutation_product_coset[idx]; + for values in columns.iter().map(|&column| { + match column.column_type() { + Any::Advice(_) => { + advice[column.index()].as_ref().unwrap() + } + Any::Fixed => { + fixed[column.index()].as_ref().unwrap() + } + Any::Instance => { + instance[column.index()].as_ref().unwrap() + } + } + }) { + right *= values[idx] + current_delta + gamma; + current_delta *= &C::Scalar::DELTA; + } + + *value = *value * y + ((left - right) * l_active_row[idx]); + } + beta_term *= ω + } + }, + ); + cluster_last_constraint_idx[running_prod_cluster] = constraint_idx; + } + } + stop_measure(start); + + // Lookups + let start = start_measure("lookups", false); + for (n, lookup) in lookups.iter().enumerate() { + let (lookup_evaluator, max_degree, used_columns) = &self.lookups[n]; + let running_prod_cluster = + Self::compute_cluster_idx(max_degree + 2, num_clusters - 1); + if !need_to_compute(part_idx, 1) + && !need_to_compute(part_idx, 2) + && !need_to_compute(part_idx, running_prod_cluster) + { + constraint_idx += 5; + continue; + } + + // Polynomials required for this lookup. + // Calculated here so these only have to be kept in memory for the short time + // they are actually needed. + let product_coset = pk.vk.domain.coeff_to_extended_part( + lookup.product_poly.clone(), + current_extended_omega, + ); + let permuted_input_coset = pk.vk.domain.coeff_to_extended_part( + lookup.permuted_input_poly.clone(), + current_extended_omega, + ); + let permuted_table_coset = pk.vk.domain.coeff_to_extended_part( + lookup.permuted_table_poly.clone(), + current_extended_omega, + ); + + // Lookup constraints + constraint_idx += 1; + if need_to_compute(part_idx, 1) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; + + parallelize( + &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // l_0(X) * (1 - z(X)) = 0, degree = 2 + *value = + *value * y_power + ((one - product_coset[idx]) * l0[idx]); + } + }, + ); + cluster_last_constraint_idx[1] = constraint_idx; + } - let mut right = set.permutation_product_coset[idx]; - for values in columns.iter().map(|&column| match column.column_type() { - Any::Advice(_) => &advice[column.index()], - Any::Fixed => &fixed[column.index()], - Any::Instance => &instance[column.index()], - }) { - right *= values[idx] + current_delta + gamma; - current_delta *= &C::Scalar::DELTA; + constraint_idx += 1; + if need_to_compute(part_idx, 2) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; + parallelize( + &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + // l_last(X) * (z(X)^2 - z(X)) = 0, degree = 3 + *value = *value * y_power + + ((product_coset[idx] * product_coset[idx] + - product_coset[idx]) + * l_last[idx]); + } + }, + ); + cluster_last_constraint_idx[2] = constraint_idx; + } + constraint_idx += 1; + if need_to_compute(part_idx, running_prod_cluster) { + for fixed_column in used_columns.0.iter() { + let fixed = &mut fixed[*fixed_column]; + if fixed.is_none() { + *fixed = Some(domain.coeff_to_extended_part( + pk.fixed_polys[*fixed_column].clone(), + current_extended_omega, + )); + } + } + for instance_column in used_columns.1.iter() { + let instance = &mut instance[*instance_column]; + if instance.is_none() { + *instance = Some(domain.coeff_to_extended_part( + instance_polys[*instance_column].clone(), + current_extended_omega, + )); } + } - *value = *value * y + ((left - right) * l_active_row[idx]); + for advice_column in used_columns.2.iter() { + let advice = &mut advice[*advice_column]; + if (*advice).is_none() { + *advice = Some(domain.coeff_to_extended_part( + advice_polys[*advice_column].clone(), + current_extended_omega, + )); + } } - beta_term *= &extended_omega; + + let y_power = y_powers + [constraint_idx - cluster_last_constraint_idx[running_prod_cluster]]; + let fixed_slice = &fixed[..]; + let advice_slice = &advice[..]; + let instance_slice = &instance[..]; + let y_power_slice = &y_powers[..]; + parallelize( + &mut value_part_clusters[running_prod_cluster] + [compute_part_idx_in_cluster(part_idx, running_prod_cluster)], + |values, start| { + let mut eval_data = lookup_evaluator.instance(); + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let table_value = lookup_evaluator.evaluate( + &mut eval_data, + fixed_slice, + advice_slice, + instance_slice, + challenges, + y_power_slice, + &beta, + &gamma, + &theta, + idx, + rot_scale, + isize, + ); + + let r_next = get_rotation_idx(idx, 1, rot_scale, isize); + + // (1 - (l_last(X) + l_blind(X))) * ( + // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) + // - z(X) (\theta^{m-1} a_0(X) + ... + a_{m-1}(X) + \beta) + // (\theta^{m-1} s_0(X) + ... + s_{m-1}(X) + \gamma) + // ) = 0, degree = 2 + max(deg(a)) + max(deg(s)) + *value = *value * y_power + + ((product_coset[r_next] + * (permuted_input_coset[idx] + beta) + * (permuted_table_coset[idx] + gamma) + - product_coset[idx] * table_value) + * l_active_row[idx]); + } + }, + ); + cluster_last_constraint_idx[running_prod_cluster] = constraint_idx; } - }); - } - stop_measure(start); - - // Lookups - let start = start_measure("lookups", false); - for (n, lookup) in lookups.iter().enumerate() { - // Polynomials required for this lookup. - // Calculated here so these only have to be kept in memory for the short time - // they are actually needed. - let product_coset = pk.vk.domain.coeff_to_extended(lookup.product_poly.clone()); - let permuted_input_coset = pk - .vk - .domain - .coeff_to_extended(lookup.permuted_input_poly.clone()); - let permuted_table_coset = pk - .vk - .domain - .coeff_to_extended(lookup.permuted_table_poly.clone()); - - // Lookup constraints - parallelize(&mut values, |values, start| { - let lookup_evaluator = &self.lookups[n]; - let mut eval_data = lookup_evaluator.instance(); - for (i, value) in values.iter_mut().enumerate() { - let idx = start + i; - - let table_value = lookup_evaluator.evaluate( - &mut eval_data, - fixed, - advice, - instance, - challenges, - &beta, - &gamma, - &theta, - &y, - &C::ScalarExt::ZERO, - idx, - rot_scale, - isize, + + constraint_idx += 1; + if need_to_compute(part_idx, 1) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[1]]; + parallelize( + &mut value_part_clusters[1][compute_part_idx_in_cluster(part_idx, 1)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let a_minus_s = + permuted_input_coset[idx] - permuted_table_coset[idx]; + // Check that the first values in the permuted input expression and permuted + // fixed expression are the same. + // l_0(X) * (a'(X) - s'(X)) = 0, degree = 2 + *value = *value * y_power + (a_minus_s * l0[idx]); + } + }, ); + cluster_last_constraint_idx[1] = constraint_idx; + } + + constraint_idx += 1; + if need_to_compute(part_idx, 2) { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[2]]; + parallelize( + &mut value_part_clusters[2][compute_part_idx_in_cluster(part_idx, 2)], + |values, start| { + for (i, value) in values.iter_mut().enumerate() { + let idx = start + i; + let r_prev = get_rotation_idx(idx, -1, rot_scale, isize); - let r_next = get_rotation_idx(idx, 1, rot_scale, isize); - let r_prev = get_rotation_idx(idx, -1, rot_scale, isize); - - let a_minus_s = permuted_input_coset[idx] - permuted_table_coset[idx]; - // l_0(X) * (1 - z(X)) = 0 - *value = *value * y + ((one - product_coset[idx]) * l0[idx]); - // l_last(X) * (z(X)^2 - z(X)) = 0 - *value = *value * y - + ((product_coset[idx] * product_coset[idx] - product_coset[idx]) - * l_last[idx]); - // (1 - (l_last(X) + l_blind(X))) * ( - // z(\omega X) (a'(X) + \beta) (s'(X) + \gamma) - // - z(X) (\theta^{m-1} a_0(X) + ... + a_{m-1}(X) + \beta) - // (\theta^{m-1} s_0(X) + ... + s_{m-1}(X) + \gamma) - // ) = 0 - *value = *value * y - + ((product_coset[r_next] - * (permuted_input_coset[idx] + beta) - * (permuted_table_coset[idx] + gamma) - - product_coset[idx] * table_value) - * l_active_row[idx]); - // Check that the first values in the permuted input expression and permuted - // fixed expression are the same. - // l_0(X) * (a'(X) - s'(X)) = 0 - *value = *value * y + (a_minus_s * l0[idx]); - // Check that each value in the permuted lookup input expression is either - // equal to the value above it, or the value at the same index in the - // permuted table expression. - // (1 - (l_last + l_blind)) * (a′(X) − s′(X))⋅(a′(X) − a′(\omega^{-1} X)) = 0 - *value = *value * y - + (a_minus_s - * (permuted_input_coset[idx] - permuted_input_coset[r_prev]) - * l_active_row[idx]); + // Check that each value in the permuted lookup input expression is either + // equal to the value above it, or the value at the same index in the + // permuted table expression. + // (1 - (l_last + l_blind)) * (a′(X) − s′(X))⋅(a′(X) − a′(\omega^{-1} X)) = 0, degree = 3 + let a_minus_s = + permuted_input_coset[idx] - permuted_table_coset[idx]; + *value = *value * y_power + + (a_minus_s + * (permuted_input_coset[idx] + - permuted_input_coset[r_prev]) + * l_active_row[idx]); + } + }, + ); + cluster_last_constraint_idx[2] = constraint_idx; } - }); + } + stop_measure(start); } - stop_measure(start); + // Align the constraints by different powers of y. + for (i, cluster) in value_part_clusters.iter_mut().enumerate() { + if need_to_compute(part_idx, i) && cluster_last_constraint_idx[i] > 0 { + let y_power = y_powers[constraint_idx - cluster_last_constraint_idx[i]]; + parallelize( + &mut cluster[compute_part_idx_in_cluster(part_idx, i)], + |values, _| { + for value in values.iter_mut() { + *value = *value * y_power; + } + }, + ); + } + } + current_extended_omega *= extended_omega; + } + domain.lagrange_vecs_to_extended(value_part_clusters) + } + + fn compute_cluster_idx(degree: usize, max_cluster_idx: usize) -> usize { + let mut idx = (31 - (degree as u32).leading_zeros()) as usize; + if 1 << idx < degree { + idx = idx + 1; } - values + std::cmp::min(max_cluster_idx, idx) } } @@ -708,15 +1162,14 @@ impl GraphEvaluator { pub fn evaluate( &self, data: &mut EvaluationData, - fixed: &[Polynomial], - advice: &[Polynomial], - instance: &[Polynomial], + fixed: &[Option>], + advice: &[Option>], + instance: &[Option>], challenges: &[C::ScalarExt], + y_powers: &[C::ScalarExt], beta: &C::ScalarExt, gamma: &C::ScalarExt, theta: &C::ScalarExt, - y: &C::ScalarExt, - previous_value: &C::ScalarExt, idx: usize, rot_scale: i32, isize: i32, @@ -736,11 +1189,10 @@ impl GraphEvaluator { advice, instance, challenges, + y_powers, beta, gamma, theta, - y, - previous_value, ); } diff --git a/halo2_proofs/src/plonk/keygen.rs b/halo2_proofs/src/plonk/keygen.rs index 16ef4c6fda..9b2a22550e 100644 --- a/halo2_proofs/src/plonk/keygen.rs +++ b/halo2_proofs/src/plonk/keygen.rs @@ -320,11 +320,6 @@ where .map(|poly| vk.domain.lagrange_to_coeff(poly.clone())) .collect(); - let fixed_cosets = fixed_polys - .iter() - .map(|poly| vk.domain.coeff_to_extended(poly.clone())) - .collect(); - let permutation_pk = assembly .permutation .build_pk(params, &vk.domain, &cs.permutation); @@ -334,7 +329,6 @@ where let mut l0 = vk.domain.empty_lagrange(); l0[0] = C::Scalar::ONE; let l0 = vk.domain.lagrange_to_coeff(l0); - let l0 = vk.domain.coeff_to_extended(l0); // Compute l_blind(X) which evaluates to 1 for each blinding factor row // and 0 otherwise over the domain. @@ -342,19 +336,15 @@ where for evaluation in l_blind[..].iter_mut().rev().take(cs.blinding_factors()) { *evaluation = C::Scalar::ONE; } - let l_blind = vk.domain.lagrange_to_coeff(l_blind); - let l_blind = vk.domain.coeff_to_extended(l_blind); // Compute l_last(X) which evaluates to 1 on the first inactive row (just // before the blinding factors) and 0 otherwise over the domain let mut l_last = vk.domain.empty_lagrange(); l_last[params.n() as usize - cs.blinding_factors() - 1] = C::Scalar::ONE; - let l_last = vk.domain.lagrange_to_coeff(l_last); - let l_last = vk.domain.coeff_to_extended(l_last); // Compute l_active_row(X) let one = C::Scalar::ONE; - let mut l_active_row = vk.domain.empty_extended(); + let mut l_active_row = vk.domain.empty_lagrange(); parallelize(&mut l_active_row, |values, start| { for (i, value) in values.iter_mut().enumerate() { let idx = i + start; @@ -362,6 +352,9 @@ where } }); + let l_last = vk.domain.lagrange_to_coeff(l_last); + let l_active_row = vk.domain.lagrange_to_coeff(l_active_row); + // Compute the optimized evaluation data structure let ev = Evaluator::new(&vk.cs); @@ -372,7 +365,6 @@ where l_active_row, fixed_values: fixed, fixed_polys, - fixed_cosets, permutation: permutation_pk, ev, }) diff --git a/halo2_proofs/src/plonk/permutation.rs b/halo2_proofs/src/plonk/permutation.rs index 0e92ccc705..2abcda3708 100644 --- a/halo2_proofs/src/plonk/permutation.rs +++ b/halo2_proofs/src/plonk/permutation.rs @@ -127,8 +127,7 @@ impl VerifyingKey { #[derive(Clone, Debug)] pub(crate) struct ProvingKey { permutations: Vec>, - polys: Vec>, - pub(super) cosets: Vec>, + pub(super) polys: Vec>, } impl ProvingKey @@ -139,11 +138,9 @@ where pub(super) fn read(reader: &mut R, format: SerdeFormat) -> io::Result { let permutations = read_polynomial_vec(reader, format)?; let polys = read_polynomial_vec(reader, format)?; - let cosets = read_polynomial_vec(reader, format)?; Ok(ProvingKey { permutations, polys, - cosets, }) } @@ -155,7 +152,6 @@ where ) -> io::Result<()> { write_polynomial_slice(&self.permutations, writer, format)?; write_polynomial_slice(&self.polys, writer, format)?; - write_polynomial_slice(&self.cosets, writer, format)?; Ok(()) } } @@ -165,6 +161,5 @@ impl ProvingKey { pub(super) fn bytes_length(&self) -> usize { polynomial_slice_byte_length(&self.permutations) + polynomial_slice_byte_length(&self.polys) - + polynomial_slice_byte_length(&self.cosets) } } diff --git a/halo2_proofs/src/plonk/permutation/keygen.rs b/halo2_proofs/src/plonk/permutation/keygen.rs index bcc0d02b54..868149685a 100644 --- a/halo2_proofs/src/plonk/permutation/keygen.rs +++ b/halo2_proofs/src/plonk/permutation/keygen.rs @@ -197,7 +197,7 @@ impl Assembly { }); } - // Compute permutation polynomials, convert to coset form. + // Compute permutation polynomials. let mut permutations = vec![domain.empty_lagrange(); p.columns.len()]; { parallelize(&mut permutations, |o, start| { @@ -222,21 +222,9 @@ impl Assembly { }); } - let mut cosets = vec![domain.empty_extended(); p.columns.len()]; - { - parallelize(&mut cosets, |o, start| { - for (x, coset) in o.iter_mut().enumerate() { - let i = start + x; - let poly = polys[i].clone(); - *coset = domain.coeff_to_extended(poly); - } - }); - } - ProvingKey { permutations, polys, - cosets, } } diff --git a/halo2_proofs/src/plonk/permutation/prover.rs b/halo2_proofs/src/plonk/permutation/prover.rs index f837a3686e..fe9544a6db 100644 --- a/halo2_proofs/src/plonk/permutation/prover.rs +++ b/halo2_proofs/src/plonk/permutation/prover.rs @@ -21,7 +21,6 @@ use crate::{ pub(crate) struct CommittedSet { pub(crate) permutation_product_poly: Polynomial, - pub(crate) permutation_product_coset: Polynomial, permutation_product_blind: Blind, } @@ -173,8 +172,6 @@ impl Argument { let z = domain.lagrange_to_coeff(z); let permutation_product_poly = z.clone(); - let permutation_product_coset = domain.coeff_to_extended(z); - let permutation_product_commitment = permutation_product_commitment_projective.to_affine(); @@ -183,7 +180,6 @@ impl Argument { sets.push(CommittedSet { permutation_product_poly, - permutation_product_coset, permutation_product_blind, }); } diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs index fd72f48473..901620a90e 100644 --- a/halo2_proofs/src/poly/domain.rs +++ b/halo2_proofs/src/poly/domain.rs @@ -185,6 +185,32 @@ impl> EvaluationDomain { } } + /// Obtains a polynomial in ExtendedLagrange form when given a vector of + /// Lagrange polynomials with total size `extended_n`; panics if the + /// provided vector is the wrong length. + pub fn lagrange_vec_to_extended( + &self, + values: Vec>, + ) -> Polynomial { + assert_eq!(values.len(), (self.extended_len() >> self.k) as usize); + assert_eq!(values[0].len(), self.n as usize); + + // transpose the values in parallel + let mut transposed = vec![vec![F::ZERO; values.len()]; self.n as usize]; + values.into_iter().enumerate().for_each(|(i, p)| { + parallelize(&mut transposed, |transposed, start| { + for (transposed, p) in transposed.iter_mut().zip(p.values[start..].iter()) { + transposed[i] = *p; + } + }); + }); + + Polynomial { + values: transposed.into_iter().flatten().collect(), + _marker: PhantomData, + } + } + /// Returns an empty (zero) polynomial in the coefficient basis pub fn empty_coeff(&self) -> Polynomial { Polynomial { @@ -277,6 +303,82 @@ impl> EvaluationDomain { } } + /// This takes us from an n-length coefficient vector into parts of the + /// extended evaluation domain. For example, for a polynomial with size n, + /// and an extended domain of size mn, we can compute all parts + /// independently, which are + /// `FFT(f(zeta * X), n)` + /// `FFT(f(zeta * extended_omega * X), n)` + /// ... + /// `FFT(f(zeta * extended_omega^{m-1} * X), n)` + pub fn coeff_to_extended_parts( + &self, + a: &Polynomial, + ) -> Vec> { + assert_eq!(a.values.len(), 1 << self.k); + + let num_parts = self.extended_len() >> self.k; + let mut extended_omega_factor = F::ONE; + (0..num_parts) + .map(|_| { + let part = self.coeff_to_extended_part(a.clone(), extended_omega_factor); + extended_omega_factor *= self.extended_omega; + part + }) + .collect() + } + + /// This takes us from several n-length coefficient vectors each into parts + /// of the extended evaluation domain. For example, for a polynomial with + /// size n, and an extended domain of size mn, we can compute all parts + /// independently, which are + /// `FFT(f(zeta * X), n)` + /// `FFT(f(zeta * extended_omega * X), n)` + /// ... + /// `FFT(f(zeta * extended_omega^{m-1} * X), n)` + pub fn batched_coeff_to_extended_parts( + &self, + a: &[Polynomial], + ) -> Vec>> { + assert_eq!(a[0].values.len(), 1 << self.k); + + let mut extended_omega_factor = F::ONE; + let num_parts = self.extended_len() >> self.k; + (0..num_parts) + .map(|_| { + let a_lagrange = a + .iter() + .map(|poly| self.coeff_to_extended_part(poly.clone(), extended_omega_factor)) + .collect(); + extended_omega_factor *= self.extended_omega; + a_lagrange + }) + .collect() + } + + /// This takes us from an n-length coefficient vector into a part of the + /// extended evaluation domain. For example, for a polynomial with size n, + /// and an extended domain of size mn, we can compute one of the m parts + /// separately, which is + /// `FFT(f(zeta * extended_omega_factor * X), n)` + /// where `extended_omega_factor` is `extended_omega^i` with `i` in `[0, m)`. + pub fn coeff_to_extended_part( + &self, + mut a: Polynomial, + extended_omega_factor: F, + ) -> Polynomial { + assert_eq!(a.values.len(), 1 << self.k); + + self.distribute_powers(&mut a.values, self.g_coset * extended_omega_factor); + let data = self.get_fft_data(a.len()); + best_fft(&mut a.values, self.omega, self.k, data, false); + + Polynomial { + values: a.values, + _marker: PhantomData, + } + } + /// Rotate the extended domain polynomial over the original domain. pub fn rotate_extended( &self, @@ -326,6 +428,66 @@ impl> EvaluationDomain { a.values } + /// This takes us from the a list of lagrange-based polynomials with + /// different degrees and gets their extended lagrange-based summation. + pub fn lagrange_vecs_to_extended( + &self, + mut a: Vec>>, + ) -> Polynomial { + let mut result_poly = if a[a.len() - 1].len() == 1 << (self.extended_k - self.k) { + self.lagrange_vec_to_extended(a.pop().unwrap()) + } else { + self.empty_extended() + }; + + // Transform from each cluster of lagrange representations to coeff representations. + let mut ifft_divisor = self.extended_ifft_divisor; + let mut omega_inv = self.extended_omega_inv; + { + let mut i = a.last().unwrap().len() << self.k; + while i < (1 << self.extended_k) { + ifft_divisor = ifft_divisor + ifft_divisor; + omega_inv = omega_inv * omega_inv; + i = i << 1; + } + } + + let mut result = vec![F::ZERO; 1 << self.extended_k as usize]; + for (i, a_parts) in a.into_iter().enumerate().rev() { + // transpose the values in parallel + assert_eq!(1 << i, a_parts.len()); + let mut a_poly: Vec = { + let mut transposed = vec![vec![F::ZERO; a_parts.len()]; self.n as usize]; + a_parts.into_iter().enumerate().for_each(|(j, p)| { + parallelize(&mut transposed, |transposed, start| { + for (transposed, p) in transposed.iter_mut().zip(p.values[start..].iter()) { + transposed[j] = *p; + } + }); + }); + transposed.into_iter().flatten().collect() + }; + + self.ifft(&mut a_poly, omega_inv, self.k + i as u32, ifft_divisor); + ifft_divisor = ifft_divisor + ifft_divisor; + omega_inv = omega_inv * omega_inv; + + parallelize(&mut result[0..(self.n << i) as usize], |result, start| { + for (other, current) in result.iter_mut().zip(a_poly[start..].iter()) { + * other += current; + } + }); + } + let data = self.get_fft_data(result.len()); + best_fft(&mut result, self.extended_omega, self.extended_k, data, false); + parallelize(&mut result_poly.values, |values, start| { + for (value, other) in values.iter_mut().zip(result[start..].iter()) { + * value += other; + } + }); + result_poly + } + /// This divides the polynomial (in the extended domain) by the vanishing /// polynomial of the $2^k$ size domain. pub fn divide_by_vanishing_poly( @@ -374,6 +536,19 @@ impl> EvaluationDomain { }); } + /// Given a slice of group elements `[a_0, a_1, a_2, ...]`, this returns + /// `[a_0, [c]a_1, [c^2]a_2, [c^3]a_3, [c^4]a_4, ...]`, + /// + fn distribute_powers(&self, a: &mut [F], c: F) { + parallelize(a, |a, index| { + let mut c_power = c.pow_vartime(&[index as u64, 0, 0, 0]); + for a in a { + * a *= c_power; + c_power = c_power * c; + } + }); + } + fn ifft(&self, a: &mut Vec, omega_inv: F, log_n: u32, divisor: F) { self.fft_inner(a, omega_inv, log_n, true); parallelize(a, |a, _| { @@ -609,3 +784,165 @@ fn test_l_i() { assert_eq!(eval_polynomial(&l[(8 - i) % 8][..], x), evaluations[7 - i]); } } + +#[test] +fn test_coeff_to_extended_part() { + use halo2curves::pasta::pallas::Scalar; + use rand_core::OsRng; + + let domain = EvaluationDomain::::new(1, 3); + let rng = OsRng; + let mut poly = domain.empty_coeff(); + assert_eq!(poly.len(), 8); + for value in poly.iter_mut() { + *value = Scalar::random(rng); + } + + let want = domain.coeff_to_extended(poly.clone()); + let got = { + let parts = domain.coeff_to_extended_parts(&poly); + domain.lagrange_vec_to_extended(parts) + }; + assert_eq!(want.values, got.values); +} + +#[test] +fn bench_coeff_to_extended_parts() { + use halo2curves::pasta::pallas::Scalar; + use rand_core::OsRng; + use std::time::Instant; + + let k = 20; + let domain = EvaluationDomain::::new(3, k); + let rng = OsRng; + let mut poly1 = domain.empty_coeff(); + assert_eq!(poly1.len(), 1 << k); + + for value in poly1.iter_mut() { + *value = Scalar::random(rng); + } + + let poly2 = poly1.clone(); + + let coeff_to_extended_timer = Instant::now(); + let _ = domain.coeff_to_extended(poly1); + println!( + "domain.coeff_to_extended time: {}s", + coeff_to_extended_timer.elapsed().as_secs_f64() + ); + + let coeff_to_extended_parts_timer = Instant::now(); + let _ = domain.coeff_to_extended_parts(&poly2); + println!( + "domain.coeff_to_extended_parts time: {}s", + coeff_to_extended_parts_timer.elapsed().as_secs_f64() + ); +} + +#[test] +fn test_lagrange_vecs_to_extended() { + use halo2curves::pasta::pallas::Scalar; + use rand_core::OsRng; + + let rng = OsRng; + let domain = EvaluationDomain::::new(8, 3); + let mut poly_vec = vec![]; + let mut poly_lagrange_vecs = vec![]; + let mut want = domain.empty_extended(); + let mut omega = domain.extended_omega; + for i in (0..(domain.extended_k - domain.k + 1)).rev() { + let mut poly = vec![Scalar::zero(); (1 << i) * domain.n as usize]; + for value in poly.iter_mut() { + *value = Scalar::random(rng); + } + // poly under coeff representation. + poly_vec.push(poly.clone()); + // poly under lagrange vector representation. + let mut poly2 = poly.clone(); + let data = domain.get_fft_data(poly2.len()); + best_fft(&mut poly2, omega, i + domain.k, data, false); + let transposed_poly: Vec> = (0..(1 << i)) + .map(|j| { + let mut p = domain.empty_lagrange(); + for k in 0..domain.n { + p[k as usize] = poly2[j + (k as usize) * (1 << i)]; + } + p + }) + .collect(); + poly_lagrange_vecs.push(transposed_poly); + // poly under extended representation. + poly.resize(domain.extended_len() as usize, Scalar::zero()); + let data = domain.get_fft_data(poly.len()); + best_fft(&mut poly, domain.extended_omega, domain.extended_k, data, false); + let poly = { + let mut p = domain.empty_extended(); + p.values = poly; + p + }; + want = want + &poly; + omega = omega * omega; + } + + poly_lagrange_vecs.reverse(); + let got = domain.lagrange_vecs_to_extended(poly_lagrange_vecs); + assert_eq!(want.values, got.values); +} + +#[test] +fn bench_lagrange_vecs_to_extended() { + use halo2curves::pasta::pallas::Scalar; + use rand_core::OsRng; + use std::time::Instant; + + let rng = OsRng; + let domain = EvaluationDomain::::new(8, 10); + let mut poly_vec = vec![]; + let mut poly_lagrange_vecs = vec![]; + let mut poly_extended_vecs = vec![]; + let mut omega = domain.extended_omega; + + for i in (0..(domain.extended_k - domain.k + 1)).rev() { + let mut poly = vec![Scalar::zero(); (1 << i) * domain.n as usize]; + for value in poly.iter_mut() { + *value = Scalar::random(rng); + } + // poly under coeff representation. + poly_vec.push(poly.clone()); + // poly under lagrange vector representation. + let mut poly2 = poly.clone(); + let data = domain.get_fft_data(poly2.len()); + best_fft(&mut poly2, omega, i + domain.k, data, false); + let transposed_poly: Vec> = (0..(1 << i)) + .map(|j| { + let mut p = domain.empty_lagrange(); + for k in 0..domain.n { + p[k as usize] = poly2[j + (k as usize) * (1 << i)]; + } + p + }) + .collect(); + poly_lagrange_vecs.push(transposed_poly); + // poly under extended representation. + poly.resize(domain.extended_len() as usize, Scalar::zero()); + let data = domain.get_fft_data(poly.len()); + best_fft(&mut poly, domain.extended_omega, domain.extended_k, data, false); + let poly = { + let mut p = domain.empty_extended(); + p.values = poly; + p + }; + poly_extended_vecs.push(poly); + omega = omega * omega; + } + + let want_timer = Instant::now(); + let _ = poly_extended_vecs + .iter() + .fold(domain.empty_extended(), |acc, p| acc + p); + println!("want time: {}s", want_timer.elapsed().as_secs_f64()); + poly_lagrange_vecs.reverse(); + let got_timer = Instant::now(); + let _ = domain.lagrange_vecs_to_extended(poly_lagrange_vecs); + println!("got time: {}s", got_timer.elapsed().as_secs_f64()); +}