Skip to content

Commit

Permalink
wip: static data chip: making the constraints pass
Browse files Browse the repository at this point in the history
  • Loading branch information
morganthomas committed Apr 2, 2024
1 parent ae4b90f commit 68526ec
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 61 deletions.
6 changes: 3 additions & 3 deletions basic/tests/test_static_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ use valida_machine::__internal::p3_commit::ExtensionMmcs;
#[test]
fn prove_static_data() {
// _start:
// imm32 0(fp), 0, 0, 0, 0
// imm32 0(fp), 0, 0, 0, 0x13
// load32 -4(fp), 0(fp), 0, 0, 0
// bnei _start, 0(fp), 0x25, 0, 1 // infinite loop unless static value is loaded
// stop
let program = vec![
InstructionWord {
opcode: <Imm32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
operands: Operands([0, 0, 0, 0, 0]),
operands: Operands([0, 0, 0, 0, 0x13]),
},
InstructionWord {
opcode: <Load32Instruction as Instruction<BasicMachine<Val>, Val>>::OPCODE,
Expand All @@ -61,7 +61,7 @@ fn prove_static_data() {

let mut machine = BasicMachine::<Val>::default();
let rom = ProgramROM::new(program);
machine.static_data_mut().write(0, Word([0, 0, 0, 0x25]));
machine.static_data_mut().write(0x13, Word([0, 0, 0, 0x25]));
machine.program_mut().set_program_rom(&rom);
machine.cpu_mut().fp = 0x1000;
machine.cpu_mut().save_register_state(); // TODO: Initial register state should be saved
Expand Down
46 changes: 24 additions & 22 deletions cpu/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#![no_std]

extern crate alloc;

Expand Down Expand Up @@ -91,23 +90,23 @@ where

fn global_sends(&self, machine: &M) -> Vec<Interaction<SC::Val>> {
// Memory bus channels
let mem_sends = (0..3).map(|i| {
let channel = &CPU_COL_MAP.mem_channels[i];
let is_read = VirtualPairCol::single_main(channel.is_read);
let clk = VirtualPairCol::single_main(CPU_COL_MAP.clk);
let addr = VirtualPairCol::single_main(channel.addr);
let is_static_initial = VirtualPairCol::constant(SC::Val::zero());
let value = channel.value.0.map(VirtualPairCol::single_main);

let mut fields = vec![is_read, clk, addr, is_static_initial];
fields.extend(value);

Interaction {
fields,
count: VirtualPairCol::single_main(channel.used),
argument_index: machine.mem_bus(),
}
});
// let mem_sends = (0..3).map(|i| {
// let channel = &CPU_COL_MAP.mem_channels[i];
// let is_read = VirtualPairCol::single_main(channel.is_read);
// let clk = VirtualPairCol::single_main(CPU_COL_MAP.clk);
// let addr = VirtualPairCol::single_main(channel.addr);
// let is_static_initial = VirtualPairCol::constant(SC::Val::zero());
// let value = channel.value.0.map(VirtualPairCol::single_main);

// let mut fields = vec![is_read, clk, addr, is_static_initial];
// fields.extend(value);

// Interaction {
// fields,
// count: VirtualPairCol::single_main(channel.used),
// argument_index: machine.mem_bus(),
// }
// });

// General bus channel
let mut fields = vec![VirtualPairCol::single_main(CPU_COL_MAP.instruction.opcode)];
Expand Down Expand Up @@ -145,10 +144,11 @@ where
// argument_index: machine.program_bus(),
// };

mem_sends
.chain(iter::once(send_general))
// .chain(iter::once(send_program))
.collect()
vec![send_general]
//mem_sends
// .chain(iter::once(send_general))
// // .chain(iter::once(send_program))
// .collect()
}
}

Expand Down Expand Up @@ -212,6 +212,8 @@ impl CpuChip {
}
}

std::println!("cpu row: {:?}", row.clone());

row
}

Expand Down
2 changes: 1 addition & 1 deletion memory/src/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use valida_derive::AlignedBorrow;
use valida_machine::Word;
use valida_util::indices_arr;

#[derive(AlignedBorrow, Default)]
#[derive(AlignedBorrow, Default, Debug)]
pub struct MemoryCols<T> {
/// Memory address
pub addr: T,
Expand Down
68 changes: 38 additions & 30 deletions memory/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![no_std]

extern crate alloc;

use crate::columns::{MemoryCols, MEM_COL_MAP, NUM_MEM_COLS};
Expand Down Expand Up @@ -107,50 +105,56 @@ where
SC: StarkConfig,
{
fn generate_trace(&self, _machine: &M) -> RowMajorMatrix<SC::Val> {
let mut ops = self
.operations
.par_iter()
.map(|(clk, ops)| {
ops.iter()
.map(|op| (*clk, *op))
.collect::<Vec<(u32, Operation)>>()
})
.collect::<Vec<_>>()
.into_iter()
.flatten()
.collect::<Vec<_>>();

// Sort first by addr, then by clk
ops.sort_by_key(|(clk, op)| (op.get_address(), *clk));

// Consecutive sorted clock cycles for an address should differ no more
// than the length of the table (capped at 2^29)
Self::insert_dummy_reads(&mut ops);
// let mut ops = self
// .operations
// .par_iter()
// .map(|(clk, ops)| {
// ops.iter()
// .map(|op| (*clk, *op))
// .collect::<Vec<(u32, Operation)>>()
// })
// .collect::<Vec<_>>()
// .into_iter()
// .flatten()
// .collect::<Vec<_>>();

// // Sort first by addr, then by clk
// ops.sort_by_key(|(clk, op)| (op.get_address(), *clk));

// // Consecutive sorted clock cycles for an address should differ no more
// // than the length of the table (capped at 2^29)
// Self::insert_dummy_reads(&mut ops);

let mut rows = self.static_data
.iter()
.enumerate()
.map(|(n, (addr, value))| self.static_data_to_row(n, *addr, *value))
.collect::<Vec<_>>();

let padding_row = [SC::Val::zero(); NUM_MEM_COLS];

let n0 = rows.len();

let ops_rows = ops
.par_iter()
.enumerate()
.map(|(n, (clk, op))| self.op_to_row(n0+n, *clk as usize, *op))
.collect::<Vec<_>>();
rows.extend(ops_rows);
// let ops_rows = ops
// .par_iter()
// .enumerate()
// .map(|(n, (clk, op))| self.op_to_row(n0+n, *clk as usize, *op))
// .collect::<Vec<_>>();
// rows.extend(ops_rows.clone());

// Compute address difference values
self.compute_address_diffs(ops, &mut rows);
// self.compute_address_diffs(ops, &mut rows);

// Make sure the table length is a power of two
rows.resize(rows.len().next_power_of_two(), [SC::Val::zero(); NUM_MEM_COLS]);
rows.resize(rows.len().next_power_of_two(), padding_row);

let trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_COLS);
RowMajorMatrix::new(rows.clone().into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_COLS);

std::println!("static data = {:?}\nmemory trace rows = {:?}",
self.static_data,
rows);
// std::println!("static data = {:?}\nops = {:?}\nops rows = {:?}\nmemory trace = {:?}", self.static_data, ops, ops_rows, trace);
trace
}

Expand Down Expand Up @@ -233,6 +237,10 @@ impl MemoryChip {
cols.addr = F::from_canonical_u32(addr);
cols.value = value.transform(F::from_canonical_u8);
cols.is_write = F::one();
cols.is_read = F::zero();
cols.diff = F::zero();
cols.diff_inv = F::zero();
cols.addr_not_equal = F::zero();
row
}

Expand Down
13 changes: 8 additions & 5 deletions static_data/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#![no_std]

extern crate alloc;

use crate::columns::{NUM_STATIC_DATA_COLS, STATIC_DATA_COL_MAP};
use crate::columns::{StaticDataCols, NUM_STATIC_DATA_COLS, STATIC_DATA_COL_MAP};
use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use core::mem::transmute;
use p3_air::VirtualPairCol;
use p3_field::{AbstractField, Field};
use p3_matrix::dense::RowMajorMatrix;
Expand Down Expand Up @@ -55,8 +55,12 @@ where
fn generate_trace(&self, machine: &M) -> RowMajorMatrix<SC::Val> {
let mut rows = self.cells.iter()
.map(|(addr, value)| {
let mut row: Vec<SC::Val> = vec![SC::Val::from_canonical_u32(*addr)];
row.extend(value.0.into_iter().map(SC::Val::from_canonical_u8).collect::<Vec<_>>());
let mut row = [SC::Val::zero(); NUM_STATIC_DATA_COLS];
let cols: &mut StaticDataCols<SC::Val> = unsafe { transmute(&mut row) };
cols.addr = SC::Val::from_canonical_u32(*addr);
cols.value = value.transform(SC::Val::from_canonical_u8);
cols.is_real = SC::Val::one();
std::println!("static data row: {:?}\n", row.clone());
row
})
.flatten()
Expand All @@ -66,7 +70,6 @@ where
}

fn global_sends(&self, machine: &M) -> Vec<Interaction<SC::Val>> {
// return vec![]; // TODO
let addr = VirtualPairCol::single_main(STATIC_DATA_COL_MAP.addr);
let value = STATIC_DATA_COL_MAP.value.0.map(VirtualPairCol::single_main);
let is_read = VirtualPairCol::constant(SC::Val::zero());
Expand Down

0 comments on commit 68526ec

Please sign in to comment.