diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index 8b48f836b2..ee9fb47fc5 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -11,6 +11,7 @@ use ff::Field; use sealed::SealedPhase; use std::collections::HashMap; use std::fmt::Debug; +use std::iter::{Product, Sum}; use std::{ convert::TryFrom, ops::{Neg, Sub}, @@ -478,7 +479,7 @@ impl Selector { } /// Query of fixed column at a certain relative location -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct FixedQuery { /// Query index pub(crate) index: Option, @@ -501,7 +502,7 @@ impl FixedQuery { } /// Query of advice column at a certain relative location -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct AdviceQuery { /// Query index pub(crate) index: Option, @@ -531,7 +532,7 @@ impl AdviceQuery { } /// Query of instance column at a certain relative location -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct InstanceQuery { /// Query index pub(crate) index: Option, @@ -792,7 +793,7 @@ pub trait Circuit { } /// Low-degree expression representing an identity that must hold over the committed columns. -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub enum Expression { /// This is a constant polynomial Constant(F), @@ -1352,6 +1353,20 @@ impl Mul for Expression { } } +impl Sum for Expression { + fn sum>(iter: I) -> Self { + iter.reduce(|acc, x| acc + x) + .unwrap_or(Expression::Constant(F::ZERO)) + } +} + +impl Product for Expression { + fn product>(iter: I) -> Self { + iter.reduce(|acc, x| acc * x) + .unwrap_or(Expression::Constant(F::ONE)) + } +} + /// Represents an index into a vector where each entry corresponds to a distinct /// point that polynomials are queried at. #[derive(Copy, Clone, Debug)] @@ -2439,3 +2454,47 @@ impl<'a, F: Field> VirtualCells<'a, F> { Expression::Challenge(challenge) } } + +#[cfg(test)] +mod tests { + use super::Expression; + use halo2curves::bn256::Fr; + + #[test] + fn iter_sum() { + let exprs: Vec> = vec![ + Expression::Constant(1.into()), + Expression::Constant(2.into()), + Expression::Constant(3.into()), + ]; + let happened: Expression = exprs.into_iter().sum(); + let expected: Expression = Expression::Sum( + Box::new(Expression::Sum( + Box::new(Expression::Constant(1.into())), + Box::new(Expression::Constant(2.into())), + )), + Box::new(Expression::Constant(3.into())), + ); + + assert_eq!(happened, expected); + } + + #[test] + fn iter_product() { + let exprs: Vec> = vec![ + Expression::Constant(1.into()), + Expression::Constant(2.into()), + Expression::Constant(3.into()), + ]; + let happened: Expression = exprs.into_iter().product(); + let expected: Expression = Expression::Product( + Box::new(Expression::Product( + Box::new(Expression::Constant(1.into())), + Box::new(Expression::Constant(2.into())), + )), + Box::new(Expression::Constant(3.into())), + ); + + assert_eq!(happened, expected); + } +}