Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precompile Backend #562

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions common/precompiles/bn254_add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub fn bn254_add(input_1: [u8; 16], input_2: [u8; 16]) -> [u8; 16] {
// This is a placeholder for the actual implementation.
[0; 16]
}

#[cfg(test)]
mod tests {
#[test]
fn test_bn254_add() {
let input_1 = [1; 16];
let input_2 = [2; 16];
let expected_output = [3; 16];
assert_eq!(bn254_add(input_1, input_2), expected_output);
}
}
22 changes: 22 additions & 0 deletions common/precompiles/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
pub mod bn254_add;

pub enum Precompile {
Bn254_add,
}

impl Precompile {
pub fn from_u64(value: u64) -> Option<Self> {
match value {
0 => None,
1 => Some(Precompile::Bn254_add),
_ => None,
}
}

pub fn execute(&self, inputs: &[u8]) -> Vec<u8> {
match self {
Precompile::Bn254_add => bn254_add::bn254_add(inputs),
}
}
}
// trait to deserialize the raw input bytes and serialize the outputs
15 changes: 14 additions & 1 deletion common/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use std::collections::HashMap;
use syn::{Lit, Meta, MetaNameValue, NestedMeta};

use crate::constants::{
DEFAULT_MAX_INPUT_SIZE, DEFAULT_MAX_OUTPUT_SIZE, DEFAULT_MEMORY_SIZE, DEFAULT_STACK_SIZE,
DEFAULT_MAX_INPUT_SIZE, DEFAULT_MAX_OUTPUT_SIZE, DEFAULT_MEMORY_SIZE, DEFAULT_STACK_SIZE,
DEFAULT_MAX_PRECOMPILE_INPUT_SIZE, DEFAULT_MAX_PRECOMPILE_OUTPUT_SIZE,
};

pub struct Attributes {
Expand All @@ -11,6 +12,8 @@ pub struct Attributes {
pub stack_size: u64,
pub max_input_size: u64,
pub max_output_size: u64,
pub max_precompile_input_size: u64,
pub max_precompile_output_size: u64,
}

pub fn parse_attributes(attr: &Vec<NestedMeta>) -> Attributes {
Expand All @@ -30,6 +33,8 @@ pub fn parse_attributes(attr: &Vec<NestedMeta>) -> Attributes {
"stack_size" => attributes.insert("stack_size", value),
"max_input_size" => attributes.insert("max_input_size", value),
"max_output_size" => attributes.insert("max_output_size", value),
"max_precompile_input_size" => attributes.insert("max_precompile_input_size", value),
"max_precompile_output_size" => attributes.insert("max_precompile_output_size", value),
_ => panic!("invalid attribute"),
};
}
Expand All @@ -50,12 +55,20 @@ pub fn parse_attributes(attr: &Vec<NestedMeta>) -> Attributes {
let max_output_size = *attributes
.get("max_output_size")
.unwrap_or(&DEFAULT_MAX_OUTPUT_SIZE);
let max_precompile_input_size = *attributes
.get("max_precompile_input_size")
.unwrap_or(&DEFAULT_MAX_PRECOMPILE_INPUT_SIZE);
let max_precompile_output_size = *attributes
.get("max_precompile_output_size")
.unwrap_or(&DEFAULT_MAX_PRECOMPILE_OUTPUT_SIZE);

Attributes {
wasm,
memory_size,
stack_size,
max_input_size,
max_output_size,
max_precompile_input_size,
max_precompile_output_size,
}
}
2 changes: 2 additions & 0 deletions common/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub const DEFAULT_MEMORY_SIZE: u64 = 10 * 1024 * 1024;
pub const DEFAULT_STACK_SIZE: u64 = 4096;
pub const DEFAULT_MAX_INPUT_SIZE: u64 = 4096;
pub const DEFAULT_MAX_OUTPUT_SIZE: u64 = 4096;
pub const DEFAULT_MAX_PRECOMPILE_INPUT_SIZE: u64 = 512;
pub const DEFAULT_MAX_PRECOMPILE_OUTPUT_SIZE: u64 = 512;

pub const fn virtual_register_index(index: u64) -> u64 {
index + VIRTUAL_REGISTER_COUNT
Expand Down
55 changes: 52 additions & 3 deletions common/src/rv_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ pub enum CircuitFlags {
Assert,
/// Used in virtual sequences; the program counter should be the same for the full sequence.
DoNotUpdatePC,
/// 1 if the instruction is an ecall/precompile
Precompile,
}
pub const NUM_CIRCUIT_FLAGS: usize = CircuitFlags::COUNT;

Expand Down Expand Up @@ -429,6 +431,7 @@ pub enum RV32IM {
VIRTUAL_ASSERT_EQ,
VIRTUAL_ASSERT_VALID_DIV0,
VIRTUAL_ASSERT_HALFWORD_ALIGNMENT,
VIRTUAL_PRECOMPILE,
}

impl FromStr for RV32IM {
Expand Down Expand Up @@ -588,6 +591,8 @@ pub struct JoltDevice {
pub outputs: Vec<u8>,
pub panic: bool,
pub memory_layout: MemoryLayout,
pub precompile_input: Vec<u8>,
pub precompile_output: Vec<u8>,
}

impl JoltDevice {
Expand All @@ -597,6 +602,8 @@ impl JoltDevice {
outputs: Vec::new(),
panic: false,
memory_layout: MemoryLayout::new(max_input_size, max_output_size),
precompile_input: Vec::new(),
precompile_output: Vec::new(),
}
}

Expand All @@ -619,6 +626,20 @@ impl JoltDevice {
} else {
self.outputs[internal_address]
}
} else if self.is_precompile_input(address) {
let internal_address = self.convert_read_address(address);
if self.precompile_inputs.len() <= internal_address {
0
} else {
self.precompile_inputs[internal_address]
}
} else if self.is_precompile_output(address) {
let internal_address = self.convert_read_address(address);
if self.precompile_outputs.len() <= internal_address {
0
} else {
self.precompile_outputs[internal_address]
}
} else {
0 // zero-padding
}
Expand All @@ -636,10 +657,19 @@ impl JoltDevice {
}

let internal_address = self.convert_write_address(address);

if self.outputs.len() <= internal_address {
self.outputs.resize(internal_address + 1, 0);
}

if self.precompile_outputs.len() <= internal_address {
self.precompile_outputs.resize(internal_address + 1, 0);
}

if self.precompile_outputs.len() <= internal_address {
self.precompile_outputs.resize(internal_address + 1, 0);
}

self.outputs[internal_address] = value;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The store method only takes one value as an input. Should there be additional Option inputs for precompile input and output?

}

Expand All @@ -655,6 +685,14 @@ impl JoltDevice {
address >= self.memory_layout.output_start && address < self.memory_layout.termination
}

pub fn is_precompile_input(&self, address: u64) -> bool {
address >= self.memory_layout.precompile_input_start && address < self.memory_layout.precompile_input_end
}

pub fn is_precompile_output(&self, address: u64) -> bool {
address >= self.memory_layout.precompile_output_start && address < self.memory_layout.precompile_output_end
}

pub fn is_panic(&self, address: u64) -> bool {
address == self.memory_layout.panic
}
Expand Down Expand Up @@ -682,10 +720,13 @@ pub struct MemoryLayout {
pub input_end: u64,
pub output_start: u64,
pub output_end: u64,
pub precompile_input_start: u64,
pub precompile_input_end: u64,
pub precompile_output_start: u64,
pub precompile_output_end: u64,
pub panic: u64,
pub termination: u64,
}

impl MemoryLayout {
pub fn new(mut max_input_size: u64, mut max_output_size: u64) -> Self {
// Must be word-aligned
Expand All @@ -694,7 +735,7 @@ impl MemoryLayout {

// Adds 8 to account for panic bit and termination bit
// (they each occupy one full 4-byte word)
let io_region_num_bytes = max_input_size + max_output_size + 8;
let io_region_num_bytes = max_input_size + max_output_size + 8 + 32;

// Padded so that the witness index corresponding to `RAM_START_ADDRESS`
// is a power of 2
Expand All @@ -704,7 +745,11 @@ impl MemoryLayout {
let input_end = input_start + max_input_size;
let output_start = input_end;
let output_end = output_start + max_output_size;
let panic = output_end;
let precompile_input_start = output_end;
let precompile_input_end = precompile_input_start + 16; // 512 bits
let precompile_output_start = precompile_input_end;
let precompile_output_end = precompile_output_start + 16; // 512 bits
let panic = precompile_output_end;
let termination = panic + 4;

Self {
Expand All @@ -714,6 +759,10 @@ impl MemoryLayout {
input_end,
output_start,
output_end,
precompile_input_start,
precompile_input_end,
precompile_output_start,
precompile_output_end,
panic,
termination,
}
Expand Down
Loading