Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multiplication chip witness generation and range check #61

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions alu_u32/src/mul/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct Mul32Cols<T> {
pub is_real: T,

pub counter: T,
pub counter_mult: T,
}

pub const NUM_MUL_COLS: usize = size_of::<Mul32Cols<u8>>();
Expand Down
91 changes: 86 additions & 5 deletions alu_u32/src/mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use columns::{Mul32Cols, MUL_COL_MAP, NUM_MUL_COLS};
use core::iter::Sum;
use core::ops::Mul;
use itertools::iproduct;
use valida_bus::MachineWithGeneralBus;
use valida_cpu::MachineWithCpuChip;
use valida_machine::{instructions, Chip, Instruction, Interaction, Operands, Word};
use valida_machine::{instructions, BusArgument, Chip, Instruction, Interaction, Operands, Word};
use valida_opcodes::MUL32;
use valida_range::MachineWithRangeChip;

use core::borrow::BorrowMut;
use p3_air::VirtualPairCol;
use p3_field::PrimeField;
use p3_field::{PrimeField, PrimeField64};
use p3_matrix::dense::RowMajorMatrix;

pub mod columns;
Expand All @@ -29,7 +32,7 @@ pub struct Mul32Chip {

impl<F, M> Chip<M> for Mul32Chip
where
F: PrimeField,
F: PrimeField64,
M: MachineWithGeneralBus<F = F>,
{
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<M::F> {
Expand All @@ -44,6 +47,7 @@ where
let row = &mut values[i * NUM_MUL_COLS..(i + 1) * NUM_MUL_COLS];
let cols: &mut Mul32Cols<F> = row.borrow_mut();
cols.counter = F::from_canonical_usize(i + 1);
cols.is_real = F::ONE;
self.op_to_row(op, cols);
}

Expand All @@ -56,6 +60,16 @@ where
self.op_to_row(&dummy_op, cols);
}

// Set counter multiplicity
let num_rows = values.len() / NUM_MUL_COLS;
let mut mult = vec![F::ZERO; num_rows];
for i in 0..num_rows {
let r = values[MUL_COL_MAP.r + i * NUM_MUL_COLS].as_canonical_u64();
let s = values[MUL_COL_MAP.s + i * NUM_MUL_COLS].as_canonical_u64();
mult[r as usize] += F::ONE;
mult[s as usize] += F::ONE;
}

RowMajorMatrix {
values,
width: NUM_MUL_COLS,
Expand All @@ -82,8 +96,26 @@ where
}

fn local_sends(&self) -> Vec<Interaction<M::F>> {
// TODO
vec![]
let send_r = Interaction {
fields: vec![VirtualPairCol::single_main(MUL_COL_MAP.r)],
count: VirtualPairCol::one(),
argument_index: BusArgument::Local(0),
};
let send_s = Interaction {
fields: vec![VirtualPairCol::single_main(MUL_COL_MAP.s)],
count: VirtualPairCol::one(),
argument_index: BusArgument::Local(0),
};
vec![send_r, send_s]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we also need to range check output bytes

}

fn local_receives(&self) -> Vec<Interaction<M::F>> {
let receives = Interaction {
fields: vec![VirtualPairCol::single_main(MUL_COL_MAP.counter)],
count: VirtualPairCol::single_main(MUL_COL_MAP.counter_mult),
argument_index: BusArgument::Local(0),
};
vec![receives]
}
}

Expand All @@ -97,11 +129,60 @@ impl Mul32Chip {
cols.input_1 = b.transform(F::from_canonical_u8);
cols.input_2 = c.transform(F::from_canonical_u8);
cols.output = a.transform(F::from_canonical_u8);

// Compute $r$ to satisfy $pi - z = 2^32 r$.
let base_m32: [u64; 4] = [1, 1 << 8, 1 << 16, 1 << 24];
let pi = pi_m::<4, u64, u64>(
&base_m32,
b.transform(|x| x as u64),
c.transform(|x| x as u64),
);
let z: u32 = (*a).into();
let z: u64 = z as u64;
let r = (pi - z) / (1u64 << 32);
let r = r as u32;
cols.r = F::from_canonical_u32(r);

// Compute $s$ to satisfy $pi' - z' = 2^16 s$.
let base_m16: [u32; 2] = [1, 1 << 8];
let pi_prime = pi_m::<2, u32, u32>(
&base_m16,
b.transform(|x| x as u32),
c.transform(|x| x as u32),
);
let z_prime = a[3] as u32 + (1u32 << 8) * a[2] as u32;
let z_prime: u32 = z_prime.into();
let s = (pi_prime - z_prime) / (1u32 << 16);
cols.s = F::from_canonical_u32(s);
}
}
}
}

fn pi_m<const N: usize, I: Copy, O: Mul<I, Output = O> + Clone + Sum>(
base: &[O; N],
input_1: Word<I>,
input_2: Word<I>,
) -> O {
iproduct!(0..N, 0..N)
.filter(|(i, j)| i + j < N)
.map(|(i, j)| base[i + j].clone() * input_1[3 - i] * input_2[3 - j])
.sum()
}

fn sigma_m<const N: usize, I, O: Mul<I, Output = O> + Clone + Sum>(
base: &[O],
input: Word<I>,
) -> O {
input
.into_iter()
.rev()
.take(N)
.enumerate()
.map(|(i, x)| base[i].clone() * x)
.sum()
}

pub trait MachineWithMul32Chip: MachineWithCpuChip {
fn mul_u32(&self) -> &Mul32Chip;
fn mul_u32_mut(&mut self) -> &mut Mul32Chip;
Expand Down
42 changes: 11 additions & 31 deletions alu_u32/src/mul/stark.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use super::columns::Mul32Cols;
use super::Mul32Chip;
use super::{pi_m, sigma_m, Mul32Chip};
use core::borrow::Borrow;
use itertools::iproduct;
use valida_machine::Word;

use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, PrimeField};
Expand All @@ -21,23 +19,26 @@ where
let next: &Mul32Cols<AB::Var> = main.row_slice(1).borrow();

// Limb weights modulo 2^32
let base_m = [1, 1 << 8, 1 << 16, 1 << 24].map(AB::Expr::from_canonical_u32);
let base_m32 = [1, 1 << 8, 1 << 16, 1 << 24].map(AB::Expr::from_canonical_u32);

// Limb weights modulo 2^16
let base_m16 = [1, 1 << 8].map(AB::Expr::from_canonical_u32);

// Partially reduced summation of input product limbs (mod 2^32)
let pi = pi_m::<4, AB>(&base_m, local.input_1, local.input_2);
let pi = pi_m::<4, AB::Var, AB::Expr>(&base_m32, local.input_1, local.input_2);

// Partially reduced summation of output limbs (mod 2^32)
let sigma = sigma_m::<4, AB>(&base_m, local.output);
let sigma = sigma_m::<4, AB::Var, AB::Expr>(&base_m32, local.output);

// Partially reduced summation of input product limbs (mod 2^16)
let pi_prime = pi_m::<2, AB>(&base_m[..2], local.input_1, local.input_2);
let pi_prime = pi_m::<2, AB::Var, AB::Expr>(&base_m16, local.input_1, local.input_2);

// Partially reduced summation of output limbs (mod 2^16)
let sigma_prime = sigma_m::<2, AB>(&base_m[..2], local.output);
let sigma_prime = sigma_m::<2, AB::Var, AB::Expr>(&base_m16, local.output);

// Congruence checks
builder.assert_eq(pi - sigma, local.r * AB::Expr::TWO);
builder.assert_eq(pi_prime - sigma_prime, local.s * base_m[2].clone());
builder.assert_eq(pi - sigma, local.r * AB::Expr::from_wrapped_u64(1 << 32));
builder.assert_eq(pi_prime - sigma_prime, local.s * base_m32[2].clone());

// Range check counter
builder
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look like this enforces a range of 1 ..= 2^10, should it be 0 .. 2^10?

Can also use assert_bool on the counter diff

Expand All @@ -52,24 +53,3 @@ where
.assert_eq(local.counter, AB::Expr::from_canonical_u32(1 << 10));
}
}

fn pi_m<const N: usize, AB: AirBuilder>(
base: &[AB::Expr],
input_1: Word<AB::Var>,
input_2: Word<AB::Var>,
) -> AB::Expr {
iproduct!(0..N, 0..N)
.filter(|(i, j)| i + j < N)
.map(|(i, j)| base[i + j].clone() * input_1[3 - i] * input_2[3 - j])
.sum()
}

fn sigma_m<const N: usize, AB: AirBuilder>(base: &[AB::Expr], input: Word<AB::Var>) -> AB::Expr {
input
.into_iter()
.rev()
.take(N)
.enumerate()
.map(|(i, x)| base[i].clone() * x)
.sum()
}
1 change: 1 addition & 0 deletions basic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ valida-memory = { path = "../memory" }
valida-output = { path = "../output" }
valida-program = { path = "../program" }
valida-range = { path = "../range" }
valida-opcodes = { path = "../opcodes" }
p3-maybe-rayon = { path = "../../Plonky3/maybe-rayon" }
p3-baby-bear = { path = "../../Plonky3/baby-bear" }
byteorder = "1.4.3"
Expand Down
Loading