Skip to content

Commit

Permalink
Use termination bit to check that full trace is proven (#482)
Browse files Browse the repository at this point in the history
* Use termination bit to check that full trace is proven

* Don't set termination bit if program panicked

* Add test for truncated trace

* Update docs

* Update memory layout diagrams
  • Loading branch information
moodlezoup authored Oct 21, 2024
1 parent 249bb31 commit 133fe3d
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 15 deletions.
22 changes: 11 additions & 11 deletions book/src/how/read_write_memory.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Read-write memory (registers and RAM)

Jolt proves the validity of registers and RAM using offline memory checking.
In contrast to our usage of offline memory checking in other modules, registers and RAM are *writable* memory.
In contrast to our usage of offline memory checking in other modules, registers and RAM are *writable* memory.

## Memory layout

Expand All @@ -10,15 +10,15 @@ This remapped address space is laid out as follows:

![Memory layout](../imgs/memory_layout.png)

The zero-padding depicted above is sized so that RAM starts at a power-of-two offset (this is explained [below](#handling-program-io)).
As noted in the diagram, the size of the witness scales with the highest memory addressed over the course of the program's execution.
The zero-padding depicted above is sized so that RAM starts at a power-of-two offset (this is explained [below](#handling-program-io)).
As noted in the diagram, the size of the witness scales with the highest memory addressed over the course of the program's execution.
In addition to the zero-padding between the "Program I/O" and "RAM" sections, the end of the witness is zero-padded to a power of two.

## Handling program I/O

### Inputs
### Inputs

Program inputs and outputs (and the panic bit, which indicates whether the proram panicked) live in the same memory address space as RAM.
Program inputs and outputs (plus the panic and termination bits, which indicate whether the program has panicked or otherwise terminated, respectively) live in the same memory address space as RAM.
Program inputs populate the designated input space upon initialization:
![init memory](../imgs/initial_memory_state.png)

Expand All @@ -28,28 +28,28 @@ The verifier can efficiently compute the MLE of this initial memory state on its

On the other hand, the verifier cannot compute the MLE of the final memory state on its own –– though the program I/O is known to the verifier, the final memory state contains values written to registers/RAM over the course of program execution, which are *not* known to the verifier.

The verifier is, however, able to compute the MLE of the program I/O values (padded on both sides with zeros) –– this is denoted `v_io` below.
If the prover is honest, then the final memory state (`v_final` below) should agree with `v_io` at the indices corresponding to program I/O.
The verifier is, however, able to compute the MLE of the program I/O values (padded on both sides with zeros) –– this is denoted `v_io` below.
If the prover is honest, then the final memory state (`v_final` below) should agree with `v_io` at the indices corresponding to program I/O.

![final memory](../imgs/final_memory_state.png)

To enforce this, we invoke the sumcheck protocol to perform a "zero-check" on the difference between `v_final` and `v_io`:

![final memory](../imgs/program_output_sumcheck.png)

This also motivates the zero-padding between the "Program I/O" and "RAM" sections of the witness. The zero-padding ensures that both `input_start` and `ram_witness_offset` are powers of two, which makes it easier for the verifier to compute the MLEs of `v_init` and `v_io`.
This also motivates the zero-padding between the "Program I/O" and "RAM" sections of the witness. The zero-padding ensures that both `input_start` and `ram_witness_offset` are powers of two, which makes it easier for the verifier to compute the MLEs of `v_init` and `v_io`.

## Timestamp range check

Registers and RAM are *writable* memory, which introduces additional requirements compared to offline memory checking in a read-only context.

The multiset equality check for read-only memory, typically expressed as $I \cdot W = R \cdot F$, is not adequate for ensuring the accuracy of read values. It is essential to also verify that each read operation retrieves a value that was written in a *previous* step (not a future step). (Here, $I$ denotes the tuples capturing initialization of memory and $W$ the tuples capturing all of the writes to memory following initialization. $R$ denotes the tuples capturing all read operations, and $F$ denotes tuples capturing a final read pass over all memory cells).
The multiset equality check for read-only memory, typically expressed as $I \cdot W = R \cdot F$, is not adequate for ensuring the accuracy of read values. It is essential to also verify that each read operation retrieves a value that was written in a *previous* step (not a future step). (Here, $I$ denotes the tuples capturing initialization of memory and $W$ the tuples capturing all of the writes to memory following initialization. $R$ denotes the tuples capturing all read operations, and $F$ denotes tuples capturing a final read pass over all memory cells).

To formalize this, we assert that the timestamp of each read operation, denoted as $\text{read\_timestamp}$, must not exceed the global timestamp at that particular step.
To formalize this, we assert that the timestamp of each read operation, denoted as $\text{read\_timestamp}$, must not exceed the global timestamp at that particular step.
The global timestamp starts at 0 and is incremented once per step.

The verification of $\text{read\_timestamp} \leq \text{global\_timestamp}$ is equivalent to confirming that $\text{read\_timestamp}$ falls within the range $[0, \text{TRACE\_LENGTH})$ and that the difference $(\text{global\_timestamp} - \text{read\_timestamp})$ is also within the same range.

The process of ensuring that both $\text{read\_timestamp}$ and $(\text{global\_timestamp} - \text{read\_timestamp})$ lie within the specified range is known as range-checking. This is the procedure implemented in [`timestamp_range_check.rs`](https://github.com/a16z/jolt/blob/main/jolt-core/src/jolt/vm/timestamp_range_check.rs), using a modified version of Lasso.

Intuitively, checking that each read timestamp does not exceed the global timestamp prevents an attacker from answering all read operations to a given cell with "the right set of values, but out of order". Such an attack requires the attacker to "jump forward and backward in time". That is, for this attack to succeed, at some timestamp t when the cell is read, the attacker would have to return a value that will be written to that cell in the future (and at some later timestamp t' when the same cell is read the attacker would have to return a value that was written to that cell much earlier). This attack is prevented by confirming that all values returned have a timestamp that does not exceed the current global timestamp.
Intuitively, checking that each read timestamp does not exceed the global timestamp prevents an attacker from answering all read operations to a given cell with "the right set of values, but out of order". Such an attack requires the attacker to "jump forward and backward in time". That is, for this attack to succeed, at some timestamp t when the cell is read, the attacker would have to return a value that will be written to that cell in the future (and at some later timestamp t' when the same cell is read the attacker would have to return a value that was written to that cell much earlier). This attack is prevented by confirming that all values returned have a timestamp that does not exceed the current global timestamp.
Binary file modified book/src/imgs/final_memory_state.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified book/src/imgs/initial_memory_state.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified book/src/imgs/memory_layout.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions common/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ pub const fn virtual_register_index(index: u64) -> u64 {
}

// Layout of the witness (where || denotes concatenation):
// registers || inputs || outputs || panic || padding || RAM
// registers || virtual registers || inputs || outputs || panic || termination || padding || RAM
// Layout of VM memory:
// peripheral devices || inputs || outputs || panic || padding || RAM
// peripheral devices || inputs || outputs || panic || termination || padding || RAM
// Notably, we want to be able to map the VM memory address space to witness indices
// using a constant shift, namely (RAM_WITNESS_OFFSET + RAM_START_ADDRESS)
19 changes: 17 additions & 2 deletions common/src/rv_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,10 @@ impl JoltDevice {
return;
}

if address == self.memory_layout.termination {
return;
}

let internal_address = self.convert_write_address(address);
if self.outputs.len() <= internal_address {
self.outputs.resize(internal_address + 1, 0);
Expand All @@ -684,13 +688,17 @@ impl JoltDevice {
}

pub fn is_output(&self, address: u64) -> bool {
address >= self.memory_layout.output_start && address < self.memory_layout.panic
address >= self.memory_layout.output_start && address < self.memory_layout.termination
}

pub fn is_panic(&self, address: u64) -> bool {
address == self.memory_layout.panic
}

pub fn is_termination(&self, address: u64) -> bool {
address == self.memory_layout.termination
}

fn convert_read_address(&self, address: u64) -> usize {
(address - self.memory_layout.input_start) as usize
}
Expand All @@ -712,6 +720,7 @@ pub struct MemoryLayout {
pub output_start: u64,
pub output_end: u64,
pub panic: u64,
pub termination: u64,
}

impl MemoryLayout {
Expand All @@ -725,12 +734,14 @@ impl MemoryLayout {
output_start: output_start(max_input_size, max_output_size),
output_end: output_end(max_input_size, max_output_size),
panic: panic_address(max_input_size, max_output_size),
termination: termination_address(max_input_size, max_output_size),
}
}
}

pub fn ram_witness_offset(max_input: u64, max_output: u64) -> u64 {
(REGISTER_COUNT + max_input + max_output + 1).next_power_of_two()
// Adds 2 to account for panic bit and termination bit
(REGISTER_COUNT + max_input + max_output + 2).next_power_of_two()
}

fn input_start(max_input: u64, max_output: u64) -> u64 {
Expand All @@ -752,3 +763,7 @@ fn output_end(max_input: u64, max_output: u64) -> u64 {
fn panic_address(max_input: u64, max_output: u64) -> u64 {
output_end(max_input, max_output) + 1
}

fn termination_address(max_input: u64, max_output: u64) -> u64 {
panic_address(max_input, max_output) + 1
}
16 changes: 16 additions & 0 deletions jolt-core/src/jolt/vm/read_write_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ impl<F: JoltField> ReadWriteMemoryPolynomials<F> {
trace: &Vec<JoltTraceStep<InstructionSet>>,
) -> (Self, [Vec<u64>; MEMORY_OPS_PER_INSTRUCTION]) {
assert!(program_io.inputs.len() <= program_io.memory_layout.max_input_size as usize);
println!("program_io.outputs.len(): {}", program_io.outputs.len());
println!("program_io.memory_layout: {:?}", program_io.memory_layout);
assert!(program_io.outputs.len() <= program_io.memory_layout.max_output_size as usize);

let m = trace.len();
Expand Down Expand Up @@ -1221,6 +1223,13 @@ where
program_io.memory_layout.panic,
program_io.memory_layout.ram_witness_offset,
)] = program_io.panic as u64;
if !program_io.panic {
// Set termination bit
v_io[memory_address_to_witness_index(
program_io.memory_layout.termination,
program_io.memory_layout.ram_witness_offset,
)] = 1;
}

let mut sumcheck_polys = vec![
eq,
Expand Down Expand Up @@ -1329,6 +1338,13 @@ where
memory_layout.panic,
memory_layout.ram_witness_offset,
)] = preprocessing.program_io.as_ref().unwrap().panic as u64;
if !preprocessing.program_io.as_ref().unwrap().panic {
// Set termination bit
v_io[memory_address_to_witness_index(
memory_layout.termination,
memory_layout.ram_witness_offset,
)] = 1;
}
let mut v_io_eval = DensePolynomial::from_u64(&v_io)
.evaluate(&r_sumcheck[(proof.num_rounds - log_nonzero_memory_size)..]);
v_io_eval *= r_prod;
Expand Down
29 changes: 29 additions & 0 deletions jolt-core/src/jolt/vm/rv32i_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,4 +444,33 @@ mod tests {
verification_result.err()
);
}

#[test]
#[should_panic]
fn truncated_trace() {
let artifact_guard = FIB_FILE_LOCK.lock().unwrap();
let mut program = host::Program::new("fibonacci-guest");
program.set_input(&9u32);
let (bytecode, memory_init) = program.decode();
let (mut io_device, mut trace) = program.trace();
trace.truncate(100);
io_device.outputs[0] = 0; // change the output to 0
drop(artifact_guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let (proof, commitments, debug_info) =
<RV32IJoltVM as Jolt<Fr, HyperKZG<Bn254>, C, M>>::prove(
io_device,
trace,
preprocessing.clone(),
);
let verification_result =
RV32IJoltVM::verify(preprocessing, proof, commitments, debug_info);
assert!(
verification_result.is_ok(),
"Verification failed with error: {:?}",
verification_result.err()
);
}
}
4 changes: 4 additions & 0 deletions jolt-sdk/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ impl MacroBuilder {
let output_start = memory_layout.output_start;
let max_input_len = attributes.max_input_size as usize;
let max_output_len = attributes.max_output_size as usize;
let termination_bit = memory_layout.termination as usize;

let get_input_slice = quote! {
let input_ptr = #input_start as *const u8;
Expand Down Expand Up @@ -341,6 +342,9 @@ impl MacroBuilder {
#check_input_len
#block
#handle_return
unsafe {
core::ptr::write_volatile(#termination_bit as *mut u8, 1);
}
}

#panic_fn
Expand Down
2 changes: 2 additions & 0 deletions tracer/src/emulator/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ impl Mmu {
if effective_address < DRAM_BASE {
if self.jolt_device.is_output(effective_address)
|| self.jolt_device.is_panic(effective_address)
|| self.jolt_device.is_termination(effective_address)
{
self.tracer.push_memory(MemoryState::Write {
address: effective_address,
Expand Down Expand Up @@ -644,6 +645,7 @@ impl Mmu {
_ => {
if self.jolt_device.is_output(effective_address)
|| self.jolt_device.is_panic(effective_address)
|| self.jolt_device.is_termination(effective_address)
{
self.jolt_device.store(effective_address, value);
} else {
Expand Down

0 comments on commit 133fe3d

Please sign in to comment.