diff --git a/Cargo.lock b/Cargo.lock index a105054669..597d6a4fd5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6748,6 +6748,7 @@ dependencies = [ "elf", "enum-map", "eyre", + "generic-array 1.1.0", "hashbrown 0.14.5", "hex", "itertools 0.13.0", @@ -6938,7 +6939,7 @@ dependencies = [ "p3-poseidon2", "p3-symmetric", "serde", - "sha2 0.10.8", + "sha3", ] [[package]] @@ -7211,6 +7212,7 @@ dependencies = [ "p3-field", "rand 0.8.5", "sha2 0.10.8", + "sha3", "sp1-lib", "sp1-primitives", ] diff --git a/crates/core/executor/Cargo.toml b/crates/core/executor/Cargo.toml index 273b12aa4c..54f1b15154 100644 --- a/crates/core/executor/Cargo.toml +++ b/crates/core/executor/Cargo.toml @@ -28,6 +28,7 @@ bincode = "1.3.3" hashbrown = { version = "0.14.5", features = ["serde", "inline-more"] } itertools = "0.13.0" rand = "0.8.5" +generic-array = { version = "1.1.0", features = ["alloc", "serde"] } num = { version = "0.4.3" } typenum = "1.17.0" nohash-hasher = "0.2.0" diff --git a/crates/core/executor/src/events/memcpy.rs b/crates/core/executor/src/events/memcpy.rs new file mode 100644 index 0000000000..4224c1647f --- /dev/null +++ b/crates/core/executor/src/events/memcpy.rs @@ -0,0 +1,15 @@ +use super::{LookupId, MemoryLocalEvent, MemoryReadRecord, MemoryWriteRecord}; +use serde::{Deserialize, Serialize}; + +#[derive(Default, Clone, Debug, Serialize, Deserialize)] +pub struct MemCopyEvent { + pub lookup_id: LookupId, + pub shard: u32, + pub clk: u32, + pub src_ptr: u32, + pub dst_ptr: u32, + pub read_records: Vec, + pub write_records: Vec, + /// The local memory access records. + pub local_mem_access: Vec, +} diff --git a/crates/core/executor/src/events/mod.rs b/crates/core/executor/src/events/mod.rs index da38bb83c2..ffb53ef8ae 100644 --- a/crates/core/executor/src/events/mod.rs +++ b/crates/core/executor/src/events/mod.rs @@ -3,6 +3,7 @@ mod alu; mod byte; mod cpu; +mod memcpy; mod memory; mod precompiles; mod syscall; @@ -11,6 +12,7 @@ mod utils; pub use alu::*; pub use byte::*; pub use cpu::*; +pub use memcpy::*; pub use memory::*; pub use precompiles::*; pub use syscall::*; diff --git a/crates/core/executor/src/events/precompiles/bn254_scalar.rs b/crates/core/executor/src/events/precompiles/bn254_scalar.rs new file mode 100644 index 0000000000..b46afe5886 --- /dev/null +++ b/crates/core/executor/src/events/precompiles/bn254_scalar.rs @@ -0,0 +1,166 @@ +use num::BigUint; +use sp1_curves::{ + params::{FieldParameters, NumWords}, + weierstrass::bn254::Bn254ScalarField, +}; +use typenum::Unsigned; + +use serde::{Deserialize, Serialize}; + +use crate::{ + events::{LookupId, MemoryLocalEvent, MemoryReadRecord, MemoryWriteRecord}, + syscalls::SyscallContext, +}; + +use super::FieldOperation; + +pub const NUM_WORDS_PER_FE: usize = 8; + +#[derive(Default, PartialEq, Copy, Clone, Debug, Serialize, Deserialize)] +pub enum Bn254FieldOperation { + #[default] + Invalid = 0, + Mul = 2, + Mac = 4, +} + +impl Bn254FieldOperation { + pub const fn to_field_operation(&self) -> FieldOperation { + match self { + Bn254FieldOperation::Mul => FieldOperation::Mul, + Bn254FieldOperation::Mac => panic!("not supported"), + Bn254FieldOperation::Invalid => panic!("what??"), + } + } +} + +#[derive(Default, Clone, Debug, Serialize, Deserialize)] +pub struct Bn254FieldArithEvent { + pub lookup_id: LookupId, + pub shard: u32, + pub clk: u32, + pub op: Bn254FieldOperation, + pub arg1: FieldArithMemoryAccess, + pub arg2: FieldArithMemoryAccess, + pub a: Option>, + pub b: Option>, + /// The local memory access records. + pub local_mem_access: Vec, +} + +pub fn create_bn254_scalar_arith_event( + rt: &mut SyscallContext, + arg1: u32, + arg2: u32, + op: Bn254FieldOperation, +) -> Bn254FieldArithEvent { + let start_clk = rt.clk; + let p_ptr = arg1; + let q_ptr = arg2; + + assert_eq!(p_ptr % 4, 0, "p_ptr({p_ptr:x}) is not aligned"); + assert_eq!(q_ptr % 4, 0, "q_ptr({q_ptr:x}) is not aligned"); + + let nw_per_fe = ::WordsFieldElement::USIZE; + debug_assert_eq!(nw_per_fe, NUM_WORDS_PER_FE); + + let arg1: Vec = rt.slice_unsafe(p_ptr, nw_per_fe); + let arg2 = match op { + // 2 ptrs of real U256 values + Bn254FieldOperation::Mac => FieldArithMemoryAccess::read(rt, arg2, 2), + _ => FieldArithMemoryAccess::read(rt, arg2, nw_per_fe), + }; + + let bn_arg1 = BigUint::from_bytes_le( + &arg1.iter().copied().flat_map(u32::to_le_bytes).collect::>(), + ); + let modulus = Bn254ScalarField::modulus(); + + let (a, b, bn_arg1_out) = if matches!(op, Bn254FieldOperation::Mac) { + let a = FieldArithMemoryAccess::read(rt, arg2.memory_records[0].value, nw_per_fe); + let b = FieldArithMemoryAccess::read(rt, arg2.memory_records[1].value, nw_per_fe); + + let bn_a = a.value_as_biguint(); + let bn_b = b.value_as_biguint(); + let bn_arg1_out = (&bn_a * &bn_b + &bn_arg1) % modulus; + + (Some(a), Some(b), bn_arg1_out) + } else { + let bn_arg2 = arg2.value_as_biguint(); + + let bn_arg1_out = match op { + Bn254FieldOperation::Mul => (&bn_arg1 * &bn_arg2) % modulus, + _ => unimplemented!("not supported"), + }; + (None, None, bn_arg1_out) + }; + + log::trace!( + "shard: {}, clk: {}, op: {:?}, arg1: {:?}, arg2: {:?}, a: {:?}, b: {:?}", + rt.current_shard(), + rt.clk, + op, + arg1, + arg2, + a, + b + ); + rt.clk += 1; + + let mut result_words = bn_arg1_out.to_u32_digits(); + result_words.resize(nw_per_fe, 0); + + let arg1 = FieldArithMemoryAccess::write(rt, p_ptr, &result_words); + + let shard = rt.current_shard(); + Bn254FieldArithEvent { + lookup_id: rt.syscall_lookup_id, + shard, + clk: start_clk, + op, + arg1, + arg2, + a, + b, + local_mem_access: rt.postprocess(), + } +} + +#[derive(Default, Clone, Debug, Serialize, Deserialize)] +pub struct FieldArithMemoryAccess { + pub ptr: u32, + pub memory_records: Vec, +} + +impl FieldArithMemoryAccess { + pub fn read(rt: &mut SyscallContext, ptr: u32, len: usize) -> Self { + let (memory_records, _) = rt.mr_slice(ptr, len); + Self { ptr, memory_records } + } + + pub fn value_as_biguint(&self) -> BigUint { + BigUint::from_bytes_le( + &self + .memory_records + .iter() + .flat_map(|word| word.value.to_le_bytes()) + .collect::>(), + ) + } +} + +impl FieldArithMemoryAccess { + pub fn write(rt: &mut SyscallContext, ptr: u32, values: &[u32]) -> Self { + Self { ptr, memory_records: rt.mw_slice(ptr, values) } + } + + pub fn prev_value_as_biguint(&self) -> BigUint { + BigUint::from_bytes_le( + &self + .memory_records + .iter() + .flat_map(|word| word.prev_value.to_le_bytes()) + .collect::>(), + ) + } +} diff --git a/crates/core/executor/src/events/precompiles/mod.rs b/crates/core/executor/src/events/precompiles/mod.rs index cc4f54ed7d..71c7ac734d 100644 --- a/crates/core/executor/src/events/precompiles/mod.rs +++ b/crates/core/executor/src/events/precompiles/mod.rs @@ -1,3 +1,4 @@ +mod bn254_scalar; mod ec; mod edwards; mod fptower; @@ -6,6 +7,9 @@ mod sha256_compress; mod sha256_extend; mod uint256; +pub use bn254_scalar::{ + create_bn254_scalar_arith_event, Bn254FieldArithEvent, Bn254FieldOperation, NUM_WORDS_PER_FE, +}; pub use ec::*; pub use edwards::*; pub use fptower::*; @@ -19,6 +23,7 @@ pub use uint256::*; use crate::syscalls::SyscallCode; +use super::{MemCopyEvent}; use super::{MemoryLocalEvent, SyscallEvent}; #[derive(Clone, Debug, Serialize, Deserialize, EnumIter)] @@ -52,6 +57,12 @@ pub enum PrecompileEvent { Bn254Fp2AddSub(Fp2AddSubEvent), /// Bn254 quadratic field mul precompile event. Bn254Fp2Mul(Fp2MulEvent), + + Bn254ScalarMac(Bn254FieldArithEvent), + Bn254ScalarMul(Bn254FieldArithEvent), + MemCopy32(MemCopyEvent), + MemCopy64(MemCopyEvent), + /// Bls12-381 curve add precompile event. Bls12381Add(EllipticCurveAddEvent), /// Bls12-381 curve double precompile event. @@ -120,6 +131,12 @@ impl PrecompileLocalMemory for Vec<(SyscallEvent, PrecompileEvent)> { PrecompileEvent::Bls12381Fp2Mul(e) | PrecompileEvent::Bn254Fp2Mul(e) => { iterators.push(e.local_mem_access.iter()); } + PrecompileEvent::Bn254ScalarMac(e) | PrecompileEvent::Bn254ScalarMul(e) => { + iterators.push(e.local_mem_access.iter()); + } + PrecompileEvent::MemCopy32(e) | PrecompileEvent::MemCopy64(e) => { + iterators.push(e.local_mem_access.iter()); + } } } diff --git a/crates/core/executor/src/executor.rs b/crates/core/executor/src/executor.rs index cfbc8cacd5..6e3934af40 100644 --- a/crates/core/executor/src/executor.rs +++ b/crates/core/executor/src/executor.rs @@ -899,6 +899,10 @@ impl<'a> Executor<'a> { let value = (memory_read_value).to_le_bytes()[(addr % 4) as usize]; a = ((value as i8) as i32) as u32; memory_store_value = Some(memory_read_value); + //println!( + // "[clk: {}, pc: 0x{:x}] LB: {:?} <- {:x}", + // self.state.global_clk, self.state.pc, rd, a + //); self.rw(rd, a); } Opcode::LH => { @@ -913,6 +917,10 @@ impl<'a> Executor<'a> { }; a = ((value as i16) as i32) as u32; memory_store_value = Some(memory_read_value); + //println!( + // "[clk: {}, pc: 0x{:x}] LH: {:?} <- {:x}", + // self.state.global_clk, self.state.pc, rd, a + //); self.rw(rd, a); } Opcode::LW => { @@ -922,6 +930,10 @@ impl<'a> Executor<'a> { } a = memory_read_value; memory_store_value = Some(memory_read_value); + //println!( + // "[clk: {}, pc: 0x{:x}] LW: {:?} <- {}", + // self.state.global_clk, self.state.pc, rd, a + //); self.rw(rd, a); } Opcode::LBU => { @@ -957,6 +969,10 @@ impl<'a> Executor<'a> { _ => unreachable!(), }; memory_store_value = Some(value); + //println!( + // "[clk: {}, pc: 0x{:x}] SB 0x{:x} <- 0x{:x}", + // self.state.global_clk, pc, addr, value + //); self.mw_cpu(align(addr), value, MemoryAccessPosition::Memory); } Opcode::SH => { @@ -970,6 +986,10 @@ impl<'a> Executor<'a> { _ => unreachable!(), }; memory_store_value = Some(value); + //println!( + // "[clk: {}, pc: 0x{:x}] SH 0x{:x} <- 0x{:x}", + // self.state.global_clk, pc, addr, value + //); self.mw_cpu(align(addr), value, MemoryAccessPosition::Memory); } Opcode::SW => { @@ -979,6 +999,10 @@ impl<'a> Executor<'a> { } let value = a; memory_store_value = Some(value); + //println!( + // "[clk: {}, pc: 0x{:x}] SW 0x{:x} <- 0x{:x}", + // self.state.global_clk, pc, addr, value + //); self.mw_cpu(align(addr), value, MemoryAccessPosition::Memory); } @@ -1070,6 +1094,7 @@ impl<'a> Executor<'a> { return Err(ExecutionError::InvalidSyscallUsage(syscall_id as u64)); } + let global_clk = self.state.global_clk; // Update the syscall counts. let syscall_for_count = syscall.count_map(); let syscall_count = self.state.syscall_counts.entry(syscall_for_count).or_insert(0); @@ -1089,6 +1114,14 @@ impl<'a> Executor<'a> { } let mut precompile_rt = SyscallContext::new(self); precompile_rt.syscall_lookup_id = syscall_lookup_id; + log::trace!( + "[clk: {}, pc: 0x{:x}] ecall syscall_id=0x{:x}, b: 0x{:x}, c: 0x{:x}", + global_clk, + pc, + syscall_id, + b, + c, + ); let (precompile_next_pc, precompile_cycles, returned_exit_code) = if let Some(syscall_impl) = syscall_impl { // Executing a syscall optionally returns a value to write to the t0 @@ -1125,6 +1158,10 @@ impl<'a> Executor<'a> { next_pc = precompile_next_pc; self.state.clk += precompile_cycles; exit_code = returned_exit_code; + + //log::info!( + // "execute_instruction {syscall:?} {syscall_count} {nonce} {syscall_lookup_id}" + //); } Opcode::EBREAK => { return Err(ExecutionError::Breakpoint()); diff --git a/crates/core/executor/src/syscalls/code.rs b/crates/core/executor/src/syscalls/code.rs index ece5a06e99..e6c520b67f 100644 --- a/crates/core/executor/src/syscalls/code.rs +++ b/crates/core/executor/src/syscalls/code.rs @@ -128,6 +128,15 @@ pub enum SyscallCode { /// Executes the `BN254_FP2_MUL` precompile. BN254_FP2_MUL = 0x00_01_01_2B, + + /// Execute the `MEMCPY_32` precompile. + MEMCPY_32 = 0x00_01_01_90, + /// Execute the `MEMCPY_64` precompile. + MEMCPY_64 = 0x00_01_01_91, + /// Execute the `BN254_SCALAR_MUL` precompile. + BN254_SCALAR_MUL = 0x00_01_01_80, + /// Execute the `BN254_SCALAR_MAC` precompile. + BN254_SCALAR_MAC = 0x00_01_01_81, } impl SyscallCode { @@ -170,6 +179,10 @@ impl SyscallCode { 0x00_01_01_2A => SyscallCode::BN254_FP2_SUB, 0x00_01_01_2B => SyscallCode::BN254_FP2_MUL, 0x00_00_01_1C => SyscallCode::BLS12381_DECOMPRESS, + 0x00_01_01_90 => SyscallCode::MEMCPY_32, + 0x00_01_01_91 => SyscallCode::MEMCPY_64, + 0x00_01_01_80 => SyscallCode::BN254_SCALAR_MUL, + 0x00_01_01_81 => SyscallCode::BN254_SCALAR_MAC, _ => panic!("invalid syscall number: {value}"), } } diff --git a/crates/core/executor/src/syscalls/mod.rs b/crates/core/executor/src/syscalls/mod.rs index 6633633102..2bf28d9cae 100644 --- a/crates/core/executor/src/syscalls/mod.rs +++ b/crates/core/executor/src/syscalls/mod.rs @@ -22,9 +22,11 @@ pub use code::*; pub use context::*; use hint::{HintLenSyscall, HintReadSyscall}; use precompiles::{ + bn254_scalar::{Bn254ScalarMacSyscall, Bn254ScalarMulSyscall}, edwards::{add::EdwardsAddAssignSyscall, decompress::EdwardsDecompressSyscall}, fptower::{Fp2AddSubSyscall, Fp2MulSyscall, FpOpSyscall}, keccak256::permute::Keccak256PermuteSyscall, + memcopy::MemCopySyscall, sha256::{compress::Sha256CompressSyscall, extend::Sha256ExtendSyscall}, uint256::Uint256MulSyscall, weierstrass::{ @@ -206,5 +208,16 @@ pub fn default_syscall_map() -> HashMap> { Arc::new(WeierstrassDecompressSyscall::::new()), ); + syscall_map.insert(SyscallCode::BN254_SCALAR_MUL, Arc::new(Bn254ScalarMulSyscall)); + syscall_map.insert(SyscallCode::BN254_SCALAR_MAC, Arc::new(Bn254ScalarMacSyscall)); + syscall_map.insert( + SyscallCode::MEMCPY_32, + Arc::new(MemCopySyscall::::new()), + ); + syscall_map.insert( + SyscallCode::MEMCPY_64, + Arc::new(MemCopySyscall::::new()), + ); + syscall_map } diff --git a/crates/core/executor/src/syscalls/precompiles/bn254_scalar/mod.rs b/crates/core/executor/src/syscalls/precompiles/bn254_scalar/mod.rs new file mode 100644 index 0000000000..2fc024dad1 --- /dev/null +++ b/crates/core/executor/src/syscalls/precompiles/bn254_scalar/mod.rs @@ -0,0 +1,60 @@ +use crate::{ + events::{create_bn254_scalar_arith_event, Bn254FieldOperation, PrecompileEvent}, + syscalls::{Syscall, SyscallCode, SyscallContext}, +}; + +pub(crate) struct Bn254ScalarMacSyscall; + +impl Syscall for Bn254ScalarMacSyscall { + fn execute( + &self, + rt: &mut SyscallContext, + syscall_code: SyscallCode, + arg1: u32, + arg2: u32, + ) -> Option { + let start_clk = rt.clk; + let event = create_bn254_scalar_arith_event(rt, arg1, arg2, Bn254FieldOperation::Mac); + let syscall_event = + rt.rt.syscall_event(start_clk, syscall_code.syscall_id(), arg1, arg2, event.lookup_id); + + rt.record_mut().add_precompile_event( + syscall_code, + syscall_event, + PrecompileEvent::Bn254ScalarMac(event), + ); + + None + } + + fn num_extra_cycles(&self) -> u32 { + 1 + } +} + +pub(crate) struct Bn254ScalarMulSyscall; +impl Syscall for Bn254ScalarMulSyscall { + fn execute( + &self, + rt: &mut SyscallContext, + syscall_code: SyscallCode, + arg1: u32, + arg2: u32, + ) -> Option { + let start_clk = rt.clk; + let event = create_bn254_scalar_arith_event(rt, arg1, arg2, Bn254FieldOperation::Mul); + let syscall_event = + rt.rt.syscall_event(start_clk, syscall_code.syscall_id(), arg1, arg2, event.lookup_id); + rt.record_mut().add_precompile_event( + syscall_code, + syscall_event, + PrecompileEvent::Bn254ScalarMul(event), + ); + + None + } + + fn num_extra_cycles(&self) -> u32 { + 1 + } +} diff --git a/crates/core/executor/src/syscalls/precompiles/memcopy.rs b/crates/core/executor/src/syscalls/precompiles/memcopy.rs new file mode 100644 index 0000000000..e07a5761f0 --- /dev/null +++ b/crates/core/executor/src/syscalls/precompiles/memcopy.rs @@ -0,0 +1,69 @@ +use std::marker::PhantomData; + +use generic_array::ArrayLength; + +use crate::{ + events::{MemCopyEvent, PrecompileEvent}, + syscalls::{Syscall, SyscallCode, SyscallContext}, +}; + +pub struct MemCopySyscall { + _marker: PhantomData<(NumWords, NumBytes)>, +} + +impl MemCopySyscall { + pub const fn new() -> Self { + Self { _marker: PhantomData } + } +} + +impl Syscall + for MemCopySyscall +{ + fn execute( + &self, + rt: &mut SyscallContext, + syscall_code: SyscallCode, + src: u32, + dst: u32, + ) -> Option { + let start_clk = rt.clk; + let (read, read_bytes) = rt.mr_slice(src, NumWords::USIZE); + + // dst == src is supported, even it is actually a no-op. + rt.clk += 1; + + let write = rt.mw_slice(dst, &read_bytes); + + let event = MemCopyEvent { + lookup_id: rt.syscall_lookup_id, + shard: rt.current_shard(), + clk: start_clk, + src_ptr: src, + dst_ptr: dst, + read_records: read, + write_records: write, + local_mem_access: rt.postprocess(), + }; + let precompile_event = match NumWords::USIZE { + 8 => PrecompileEvent::MemCopy32(event), + 16 => PrecompileEvent::MemCopy64(event), + _ => panic!("invalid uszie {}", NumWords::USIZE), + }; + let syscall_event = rt.rt.syscall_event( + start_clk, + syscall_code.syscall_id(), + src, + dst, + rt.syscall_lookup_id, + ); + + rt.record_mut().add_precompile_event(syscall_code, syscall_event, precompile_event); + + None + } + + fn num_extra_cycles(&self) -> u32 { + 1 + } +} diff --git a/crates/core/executor/src/syscalls/precompiles/mod.rs b/crates/core/executor/src/syscalls/precompiles/mod.rs index f07da94609..b6d2d96b2d 100644 --- a/crates/core/executor/src/syscalls/precompiles/mod.rs +++ b/crates/core/executor/src/syscalls/precompiles/mod.rs @@ -1,6 +1,8 @@ +pub mod bn254_scalar; pub mod edwards; pub mod fptower; pub mod keccak256; +pub mod memcopy; pub mod sha256; pub mod uint256; pub mod weierstrass; diff --git a/crates/core/machine/src/cpu/trace.rs b/crates/core/machine/src/cpu/trace.rs index 01b1489127..a9c42b0bdf 100644 --- a/crates/core/machine/src/cpu/trace.rs +++ b/crates/core/machine/src/cpu/trace.rs @@ -516,9 +516,17 @@ impl CpuChip { } // Write the syscall nonce. - ecall_cols.syscall_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.syscall_lookup_id).copied().unwrap_or_default(), + let syscall_nonce = + nonce_lookup.get(&event.syscall_lookup_id).copied().unwrap_or_default(); + ecall_cols.syscall_nonce = F::from_canonical_u32(syscall_nonce); + + /* + log::info!( + "populate_ecall syscall_lookup_id {} syscall_nonce {} syscall_id {syscall_id:?}", + event.syscall_lookup_id, + syscall_nonce ); + */ is_halt = syscall_id == F::from_canonical_u32(SyscallCode::HALT.syscall_id()); @@ -544,7 +552,8 @@ impl CpuChip { fn pad_to_power_of_two(&self, shape: &Option, values: &mut Vec) { let n_real_rows = values.len() / NUM_CPU_COLS; let padded_nb_rows = if let Some(shape) = shape { - 1 << shape.inner[&MachineAir::::name(self)] + let name = MachineAir::::name(self); + 1 << shape.inner.get(&name).expect(&format!("fail to get shape of {}", name)) } else if n_real_rows < 16 { 16 } else { diff --git a/crates/core/machine/src/lib.rs b/crates/core/machine/src/lib.rs index 16fd571c95..dbd60ed936 100644 --- a/crates/core/machine/src/lib.rs +++ b/crates/core/machine/src/lib.rs @@ -30,7 +30,7 @@ pub mod utils; /// This string should be updated whenever any step in verifying an SP1 proof changes, including /// core, recursion, and plonk-bn254. This string is used to download SP1 artifacts and the gnark /// docker image. -pub const SP1_CIRCUIT_VERSION: &str = "v3.0.0"; +pub const SP1_CIRCUIT_VERSION: &str = "v3.0.0-scroll"; // Re-export the `SP1ReduceProof` struct from sp1_core_machine. // diff --git a/crates/core/machine/src/riscv/cost.rs b/crates/core/machine/src/riscv/cost.rs index 27469ee00a..71ad5f6f04 100644 --- a/crates/core/machine/src/riscv/cost.rs +++ b/crates/core/machine/src/riscv/cost.rs @@ -1,7 +1,7 @@ use p3_baby_bear::BabyBear; use sp1_core_executor::{syscalls::SyscallCode, ExecutionReport, Opcode}; -use crate::riscv::RiscvAirDiscriminants; +use crate::{riscv::RiscvAirDiscriminants, syscall::precompiles::bn254_scalar}; use super::RiscvAir; @@ -122,6 +122,24 @@ impl CostEstimator for ExecutionReport { total_area += (bn254_fp2_mul_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp2Mul]; total_chips += 1; + let bn254_scalar_mul_events = self.syscall_counts[SyscallCode::BN254_SCALAR_MUL]; + total_area += + (bn254_scalar_mul_events as u64) * costs[&RiscvAirDiscriminants::Bn254ScalarMul]; + total_chips += 1; + + let bn254_scalar_mac_events = self.syscall_counts[SyscallCode::BN254_SCALAR_MAC]; + total_area += + (bn254_scalar_mac_events as u64) * costs[&RiscvAirDiscriminants::Bn254ScalarMac]; + total_chips += 1; + + let mem_copy_32_events = self.syscall_counts[SyscallCode::MEMCPY_32]; + total_area += (mem_copy_32_events as u64) * costs[&RiscvAirDiscriminants::MemCopy32]; + total_chips += 1; + + let mem_copy_64_events = self.syscall_counts[SyscallCode::MEMCPY_64]; + total_area += (mem_copy_64_events as u64) * costs[&RiscvAirDiscriminants::MemCopy64]; + total_chips += 1; + let bls12381_decompress_events = self.syscall_counts[SyscallCode::BLS12381_DECOMPRESS]; total_area += (bls12381_decompress_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Decompress]; diff --git a/crates/core/machine/src/riscv/mod.rs b/crates/core/machine/src/riscv/mod.rs index b2395acf5e..6720f26718 100644 --- a/crates/core/machine/src/riscv/mod.rs +++ b/crates/core/machine/src/riscv/mod.rs @@ -13,7 +13,13 @@ use crate::{ MemoryChipType, MemoryLocalChip, MemoryProgramChip, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW, }, riscv::MemoryChipType::{Finalize, Initialize}, - syscall::precompiles::fptower::{Fp2AddSubAssignChip, Fp2MulAssignChip, FpOpChip}, + syscall::{ + memcpy::{self, MemCopy32Chip, MemCopy64Chip, MemCopyChip}, + precompiles::{ + bn254_scalar::{self, Bn254ScalarMacChip, Bn254ScalarMulChip}, + fptower::{Fp2AddSubAssignChip, Fp2MulAssignChip, FpOpChip}, + }, + }, }; use hashbrown::{HashMap, HashSet}; use p3_field::PrimeField32; @@ -140,6 +146,11 @@ pub enum RiscvAir { Bn254Fp2Mul(Fp2MulAssignChip), /// A precompile for BN-254 fp2 addition/subtraction. Bn254Fp2AddSub(Fp2AddSubAssignChip), + + Bn254ScalarMac(bn254_scalar::Bn254ScalarMacChip), + Bn254ScalarMul(bn254_scalar::Bn254ScalarMulChip), + MemCopy32(memcpy::MemCopy32Chip), + MemCopy64(memcpy::MemCopy64Chip), } impl RiscvAir { @@ -278,6 +289,19 @@ impl RiscvAir { costs.insert(RiscvAirDiscriminants::Bn254Fp2Mul, bn254_fp2_mul.cost()); chips.push(bn254_fp2_mul); + let bn254_scalar_mac = Chip::new(RiscvAir::Bn254ScalarMac(Bn254ScalarMacChip::new())); + costs.insert(RiscvAirDiscriminants::Bn254ScalarMac, bn254_scalar_mac.cost()); + chips.push(bn254_scalar_mac); + let bn254_scalar_mul = Chip::new(RiscvAir::Bn254ScalarMul(Bn254ScalarMulChip::new())); + costs.insert(RiscvAirDiscriminants::Bn254ScalarMul, bn254_scalar_mul.cost()); + chips.push(bn254_scalar_mul); + let mem_copy_32 = Chip::new(RiscvAir::MemCopy32(MemCopy32Chip::new())); + costs.insert(RiscvAirDiscriminants::MemCopy32, mem_copy_32.cost()); + chips.push(mem_copy_32); + let mem_copy_64 = Chip::new(RiscvAir::MemCopy64(MemCopy64Chip::new())); + costs.insert(RiscvAirDiscriminants::MemCopy64, mem_copy_64.cost()); + chips.push(mem_copy_64); + let bls12381_decompress = Chip::new(RiscvAir::Bls12381Decompress(WeierstrassDecompressChip::< SwCurve, @@ -461,6 +485,10 @@ impl RiscvAir { pub(crate) fn syscall_code(&self) -> SyscallCode { match self { + Self::Bn254ScalarMul(_) => SyscallCode::BN254_SCALAR_MUL, + Self::Bn254ScalarMac(_) => SyscallCode::BN254_SCALAR_MAC, + Self::MemCopy32(_) => SyscallCode::MEMCPY_32, + Self::MemCopy64(_) => SyscallCode::MEMCPY_64, Self::Bls12381Add(_) => SyscallCode::BLS12381_ADD, Self::Bn254Add(_) => SyscallCode::BN254_ADD, Self::Bn254Double(_) => SyscallCode::BN254_DOUBLE, diff --git a/crates/core/machine/src/runtime/syscall.rs b/crates/core/machine/src/runtime/syscall.rs index 789b073b46..ee909fa93f 100644 --- a/crates/core/machine/src/runtime/syscall.rs +++ b/crates/core/machine/src/runtime/syscall.rs @@ -4,9 +4,11 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use strum_macros::EnumIter; +use typenum::{U16, U32, U64, U8}; use crate::operations::field::field_op::FieldOperation; use crate::runtime::{Register, Runtime}; +use crate::syscall::precompiles::bn254_scalar::{Bn254ScalarMacChip, Bn254ScalarMulChip}; use crate::syscall::precompiles::edwards::EdAddAssignChip; use crate::syscall::precompiles::edwards::EdDecompressChip; use crate::syscall::precompiles::fptower::{Fp2AddSubSyscall, Fp2MulAssignChip, FpOpSyscall}; @@ -17,8 +19,9 @@ use crate::syscall::precompiles::weierstrass::WeierstrassAddAssignChip; use crate::syscall::precompiles::weierstrass::WeierstrassDecompressChip; use crate::syscall::precompiles::weierstrass::WeierstrassDoubleAssignChip; use crate::syscall::{ - SyscallCommit, SyscallCommitDeferred, SyscallEnterUnconstrained, SyscallExitUnconstrained, - SyscallHalt, SyscallHintLen, SyscallHintRead, SyscallVerifySP1Proof, SyscallWrite, + MemCopyChip, SyscallCommit, SyscallCommitDeferred, SyscallEnterUnconstrained, + SyscallExitUnconstrained, SyscallHalt, SyscallHintLen, SyscallHintRead, SyscallVerifySP1Proof, + SyscallWrite, }; use crate::utils::ec::edwards::ed25519::{Ed25519, Ed25519Parameters}; use crate::utils::ec::weierstrass::bls12_381::{Bls12381, Bls12381BaseField}; @@ -141,6 +144,15 @@ pub enum SyscallCode { /// Executes the `BN254_FP2_MUL` precompile. BN254_FP2_MUL = 0x00_01_01_2B, + + /// Execute the `MEMCPY_32` precompile. + MEMCPY_32 = 0x00_01_01_90, + /// Execute the `MEMCPY_64` precompile. + MEMCPY_64 = 0x00_01_01_91, + /// Execute the `BN254_SCALAR_MUL` precompile. + BN254_SCALAR_MUL = 0x00_01_01_80, + /// Execute the `BN254_SCALAR_MAC` precompile. + BN254_SCALAR_MAC = 0x00_01_01_81, } impl SyscallCode { @@ -182,6 +194,10 @@ impl SyscallCode { 0x00_01_01_2A => SyscallCode::BN254_FP2_SUB, 0x00_01_01_2B => SyscallCode::BN254_FP2_MUL, 0x00_00_01_1C => SyscallCode::BLS12381_DECOMPRESS, + 0x00_01_01_90 => SyscallCode::MEMCPY_32, + 0x00_01_01_91 => SyscallCode::MEMCPY_64, + 0x00_01_01_80 => SyscallCode::BN254_SCALAR_MUL, + 0x00_01_01_81 => SyscallCode::BN254_SCALAR_MAC, _ => panic!("invalid syscall number: {}", value), } } @@ -336,22 +352,12 @@ pub fn default_syscall_map() -> HashMap> { syscall_map.insert(SyscallCode::HALT, Arc::new(SyscallHalt {})); syscall_map.insert(SyscallCode::SHA_EXTEND, Arc::new(ShaExtendChip::new())); syscall_map.insert(SyscallCode::SHA_COMPRESS, Arc::new(ShaCompressChip::new())); - syscall_map.insert( - SyscallCode::ED_ADD, - Arc::new(EdAddAssignChip::::new()), - ); - syscall_map.insert( - SyscallCode::ED_DECOMPRESS, - Arc::new(EdDecompressChip::::new()), - ); - syscall_map.insert( - SyscallCode::KECCAK_PERMUTE, - Arc::new(KeccakPermuteChip::new()), - ); - syscall_map.insert( - SyscallCode::SECP256K1_ADD, - Arc::new(WeierstrassAddAssignChip::::new()), - ); + syscall_map.insert(SyscallCode::ED_ADD, Arc::new(EdAddAssignChip::::new())); + syscall_map + .insert(SyscallCode::ED_DECOMPRESS, Arc::new(EdDecompressChip::::new())); + syscall_map.insert(SyscallCode::KECCAK_PERMUTE, Arc::new(KeccakPermuteChip::new())); + syscall_map + .insert(SyscallCode::SECP256K1_ADD, Arc::new(WeierstrassAddAssignChip::::new())); syscall_map.insert( SyscallCode::SECP256K1_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new()), @@ -360,18 +366,11 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::SECP256K1_DECOMPRESS, Arc::new(WeierstrassDecompressChip::::with_lsb_rule()), ); - syscall_map.insert( - SyscallCode::BN254_ADD, - Arc::new(WeierstrassAddAssignChip::::new()), - ); - syscall_map.insert( - SyscallCode::BN254_DOUBLE, - Arc::new(WeierstrassDoubleAssignChip::::new()), - ); - syscall_map.insert( - SyscallCode::BLS12381_ADD, - Arc::new(WeierstrassAddAssignChip::::new()), - ); + syscall_map.insert(SyscallCode::BN254_ADD, Arc::new(WeierstrassAddAssignChip::::new())); + syscall_map + .insert(SyscallCode::BN254_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new())); + syscall_map + .insert(SyscallCode::BLS12381_ADD, Arc::new(WeierstrassAddAssignChip::::new())); syscall_map.insert( SyscallCode::BLS12381_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new()), @@ -391,15 +390,11 @@ pub fn default_syscall_map() -> HashMap> { ); syscall_map.insert( SyscallCode::BLS12381_FP2_ADD, - Arc::new(Fp2AddSubSyscall::::new( - FieldOperation::Add, - )), + Arc::new(Fp2AddSubSyscall::::new(FieldOperation::Add)), ); syscall_map.insert( SyscallCode::BLS12381_FP2_SUB, - Arc::new(Fp2AddSubSyscall::::new( - FieldOperation::Sub, - )), + Arc::new(Fp2AddSubSyscall::::new(FieldOperation::Sub)), ); syscall_map.insert( SyscallCode::BLS12381_FP2_MUL, @@ -425,28 +420,15 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::BN254_FP2_SUB, Arc::new(Fp2AddSubSyscall::::new(FieldOperation::Sub)), ); - syscall_map.insert( - SyscallCode::BN254_FP2_MUL, - Arc::new(Fp2MulAssignChip::::new()), - ); - syscall_map.insert( - SyscallCode::ENTER_UNCONSTRAINED, - Arc::new(SyscallEnterUnconstrained::new()), - ); - syscall_map.insert( - SyscallCode::EXIT_UNCONSTRAINED, - Arc::new(SyscallExitUnconstrained::new()), - ); + syscall_map + .insert(SyscallCode::BN254_FP2_MUL, Arc::new(Fp2MulAssignChip::::new())); + syscall_map + .insert(SyscallCode::ENTER_UNCONSTRAINED, Arc::new(SyscallEnterUnconstrained::new())); + syscall_map.insert(SyscallCode::EXIT_UNCONSTRAINED, Arc::new(SyscallExitUnconstrained::new())); syscall_map.insert(SyscallCode::WRITE, Arc::new(SyscallWrite::new())); syscall_map.insert(SyscallCode::COMMIT, Arc::new(SyscallCommit::new())); - syscall_map.insert( - SyscallCode::COMMIT_DEFERRED_PROOFS, - Arc::new(SyscallCommitDeferred::new()), - ); - syscall_map.insert( - SyscallCode::VERIFY_SP1_PROOF, - Arc::new(SyscallVerifySP1Proof::new()), - ); + syscall_map.insert(SyscallCode::COMMIT_DEFERRED_PROOFS, Arc::new(SyscallCommitDeferred::new())); + syscall_map.insert(SyscallCode::VERIFY_SP1_PROOF, Arc::new(SyscallVerifySP1Proof::new())); syscall_map.insert(SyscallCode::HINT_LEN, Arc::new(SyscallHintLen::new())); syscall_map.insert(SyscallCode::HINT_READ, Arc::new(SyscallHintRead::new())); syscall_map.insert( @@ -455,6 +437,11 @@ pub fn default_syscall_map() -> HashMap> { ); syscall_map.insert(SyscallCode::UINT256_MUL, Arc::new(Uint256MulChip::new())); + syscall_map.insert(SyscallCode::BN254_SCALAR_MUL, Arc::new(Bn254ScalarMulChip::new())); + syscall_map.insert(SyscallCode::BN254_SCALAR_MAC, Arc::new(Bn254ScalarMacChip::new())); + syscall_map.insert(SyscallCode::MEMCPY_32, Arc::new(MemCopyChip::::new())); + syscall_map.insert(SyscallCode::MEMCPY_64, Arc::new(MemCopyChip::::new())); + syscall_map } @@ -580,6 +567,14 @@ mod tests { SyscallCode::BN254_FP2_MUL => { assert_eq!(code as u32, sp1_zkvm::syscalls::BN254_FP2_MUL) } + SyscallCode::BN254_SCALAR_MUL => { + assert_eq!(code as u32, sp1_zkvm::syscalls::BN254_SCALAR_MUL) + } + SyscallCode::BN254_SCALAR_MAC => { + assert_eq!(code as u32, sp1_zkvm::syscalls::BN254_SCALAR_MAC) + } + SyscallCode::MEMCPY_32 => todo!(), + SyscallCode::MEMCPY_64 => todo!(), } } } diff --git a/crates/core/machine/src/runtime/utils.rs b/crates/core/machine/src/runtime/utils.rs index 7c0ad541e7..0d45d9c10c 100644 --- a/crates/core/machine/src/runtime/utils.rs +++ b/crates/core/machine/src/runtime/utils.rs @@ -68,6 +68,13 @@ impl<'a> Runtime<'a> { registers[18], ); + log::trace!( + "[clk: {}, pc: 0x{:x}] {:?}", + self.state.global_clk, + self.state.pc, + instruction, + ); + if !self.unconstrained && self.state.global_clk % 10_000_000 == 0 { log::info!( "clk = {} pc = 0x{:x?}", diff --git a/crates/core/machine/src/syscall/memcpy.rs b/crates/core/machine/src/syscall/memcpy.rs new file mode 100644 index 0000000000..8fa16efa66 --- /dev/null +++ b/crates/core/machine/src/syscall/memcpy.rs @@ -0,0 +1,207 @@ +use generic_array::{ArrayLength, GenericArray}; +use sp1_core_executor::events::ByteRecord; +use sp1_core_executor::events::MemCopyEvent; +use sp1_core_executor::events::PrecompileEvent; +use sp1_core_executor::syscalls::{Syscall, SyscallCode, SyscallContext}; +use sp1_core_executor::{ExecutionRecord, Program}; +use sp1_curves::params::Limbs; +use sp1_stark::air::InteractionScope; +use sp1_stark::air::{MachineAir, SP1AirBuilder}; +use std::borrow::{Borrow, BorrowMut}; +use std::marker::PhantomData; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::AbstractField; +use p3_field::{Field, PrimeField32}; +use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use sp1_derive::AlignedBorrow; + +use crate::air::MemoryAirBuilder; +use crate::memory::{MemoryReadCols, MemoryWriteCols}; +use crate::utils::pad_rows_fixed; +use crate::utils::{limbs_from_access, limbs_from_prev_access}; + +#[derive(Debug, Clone, AlignedBorrow)] +#[repr(C)] +pub struct MemCopyCols { + is_real: T, + shard: T, + channel: T, + clk: T, + nonce: T, + src_ptr: T, + dst_ptr: T, + src_access: GenericArray, NumWords>, + dst_access: GenericArray, NumWords>, +} + +pub struct MemCopyChip { + _marker: PhantomData<(NumWords, NumBytes)>, +} + +use typenum::{U16, U32, U64, U8}; +pub type MemCopy32Chip = MemCopyChip; +pub type MemCopy64Chip = MemCopyChip; + +impl MemCopyChip { + const NUM_COLS: usize = core::mem::size_of::>(); + + pub fn new() -> Self { + println!("MemCopyChip<{}> NUM_COLS = {}", NumWords::USIZE, Self::NUM_COLS); + assert_eq!(NumWords::USIZE * 4, NumBytes::USIZE); + Self { _marker: PhantomData } + } + + pub fn syscall_id() -> u32 { + match NumBytes::USIZE { + 32 => SyscallCode::MEMCPY_32.syscall_id(), + 64 => SyscallCode::MEMCPY_64.syscall_id(), + _ => unreachable!(), + } + } +} + +impl + MachineAir for MemCopyChip +{ + type Record = ExecutionRecord; + + type Program = Program; + + fn name(&self) -> String { + format!("MemCopy{}Chip", NumWords::USIZE) + } + + fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + let mut rows = vec![]; + let mut new_byte_lookup_events = vec![]; + let events = match NumWords::USIZE { + 8 => input.get_precompile_events(SyscallCode::MEMCPY_32), + 16 => input.get_precompile_events(SyscallCode::MEMCPY_64), + _ => unreachable!(), + }; + + for event in events { + let event: &MemCopyEvent = match NumWords::USIZE { + 8 => { + if let (_, PrecompileEvent::MemCopy32(event)) = event { + event + } else { + unreachable!(); + } + } + 16 => { + if let (_, PrecompileEvent::MemCopy64(event)) = event { + event + } else { + unreachable!(); + } + } + _ => unreachable!(), + }; + let mut row = Vec::with_capacity(Self::NUM_COLS); + row.resize(Self::NUM_COLS, F::zero()); + let cols: &mut MemCopyCols = row.as_mut_slice().borrow_mut(); + + cols.is_real = F::one(); + cols.shard = F::from_canonical_u32(event.shard); + cols.clk = F::from_canonical_u32(event.clk); + cols.src_ptr = F::from_canonical_u32(event.src_ptr); + cols.dst_ptr = F::from_canonical_u32(event.dst_ptr); + + //cols.nonce = F::from_canonical_u32( + // output.nonce_lookup.get(&event.lookup_id).copied().expect("should not be none"), + //); + + for i in 0..NumWords::USIZE { + cols.src_access[i].populate(event.read_records[i], &mut new_byte_lookup_events); + } + for i in 0..NumWords::USIZE { + cols.dst_access[i].populate(event.write_records[i], &mut new_byte_lookup_events); + } + + rows.push(row); + } + output.add_byte_lookup_events(new_byte_lookup_events); + + pad_rows_fixed( + &mut rows, + || vec![F::zero(); Self::NUM_COLS], + input.fixed_log2_rows::(self), + ); + + let mut trace = + RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), Self::NUM_COLS); + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut MemCopyCols = + trace.values[i * Self::NUM_COLS..(i + 1) * Self::NUM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace + } + + fn included(&self, shard: &Self::Record) -> bool { + match NumWords::USIZE { + 8 => !shard.get_precompile_events(SyscallCode::MEMCPY_32).is_empty(), + 16 => !shard.get_precompile_events(SyscallCode::MEMCPY_64).is_empty(), + _ => unreachable!(), + } + } +} + +impl BaseAir + for MemCopyChip +{ + fn width(&self) -> usize { + Self::NUM_COLS + } +} + +impl Air + for MemCopyChip +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &MemCopyCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &MemCopyCols = (*next).borrow(); + + // Check that nonce is incremented. + builder.when_first_row().assert_zero(local.nonce); + builder.when_transition().assert_eq(local.nonce + AB::Expr::one(), next.nonce); + + let src: Limbs<::Var, NumBytes> = + limbs_from_prev_access(&local.src_access); + let dst: Limbs<::Var, NumBytes> = limbs_from_access(&local.dst_access); + + // TODO assert eq + + builder.eval_memory_access_slice( + local.shard, + local.clk.into(), + local.src_ptr, + &local.src_access, + local.is_real, + ); + builder.eval_memory_access_slice( + local.shard, + local.clk.into() + AB::Expr::one(), + local.dst_ptr, + &local.dst_access, + local.is_real, + ); + + builder.receive_syscall( + local.shard, + local.clk, + local.nonce, + AB::F::from_canonical_u32(Self::syscall_id()), + local.src_ptr, + local.dst_ptr, + local.is_real, + InteractionScope::Local, + ); + } +} diff --git a/crates/core/machine/src/syscall/mod.rs b/crates/core/machine/src/syscall/mod.rs index ab4b7db7fc..1aa64ada0b 100644 --- a/crates/core/machine/src/syscall/mod.rs +++ b/crates/core/machine/src/syscall/mod.rs @@ -1,2 +1,3 @@ pub mod chip; +pub mod memcpy; pub mod precompiles; diff --git a/crates/core/machine/src/syscall/precompiles/bn254_scalar/general_field_op.rs b/crates/core/machine/src/syscall/precompiles/bn254_scalar/general_field_op.rs new file mode 100644 index 0000000000..d657390d75 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/bn254_scalar/general_field_op.rs @@ -0,0 +1,81 @@ +use num::BigUint; +use sp1_derive::AlignedBorrow; + +use crate::operations::field::field_op::FieldOperation; +use crate::operations::field::util_air::eval_field_operation; +use crate::{ + air::Polynomial, operations::field::field_op::FieldOpCols, stark::SP1AirBuilder, + utils::ec::field::FieldParameters, +}; +use p3_field::AbstractField; +use p3_field::PrimeField32; + +#[derive(Debug, Clone, AlignedBorrow)] +pub struct GeneralFieldOpCols { + pub is_sub_div: T, + pub is_mul_div: T, + pub cols: FieldOpCols, +} + +impl GeneralFieldOpCols { + pub fn populate(&mut self, a: &BigUint, b: &BigUint, op: FieldOperation) -> BigUint { + let (is_mul_div, is_sub_div) = match op { + FieldOperation::Add => (0, 0), + FieldOperation::Sub => (0, 1), + FieldOperation::Mul => (1, 0), + FieldOperation::Div => (1, 1), + }; + self.is_mul_div = F::from_canonical_u32(is_mul_div); + self.is_sub_div = F::from_canonical_u32(is_sub_div); + self.cols.populate(a, b, op) + } +} +impl GeneralFieldOpCols { + pub fn eval< + AB: SP1AirBuilder, + A: Into> + Clone, + B: Into> + Clone, + OP: Into, + >( + &self, + builder: &mut AB, + a: &A, + b: &B, + op: OP, + ) where + V: Into, + { + let one = AB::Expr::from(AB::F::one()); + let is_sub_div: AB::Expr = self.is_sub_div.into(); + let is_mul_div: AB::Expr = self.is_mul_div.into(); + let not_sub_div = one.clone() - is_sub_div.clone(); + let not_mul_div = one - is_mul_div.clone(); + builder.assert_bool(is_sub_div.clone()); + builder.assert_bool(is_mul_div.clone()); + + // mul: 1 0 + // div: 1 1 + // add: 0 0 + // sub: 0 1 + let assigned_op: AB::Expr = AB::Expr::from(AB::F::from_canonical_u8(0b01)) + * is_sub_div.clone() + + AB::Expr::from(AB::F::from_canonical_u8(0b10)) * is_mul_div.clone(); + builder.assert_eq(assigned_op, op.into()); + + let p_a_param: Polynomial = (*a).clone().into(); + let p_b: Polynomial = (*b).clone().into(); + + let result: Polynomial = self.cols.result.clone().into(); + let p_a = &result * is_sub_div.clone() + &p_a_param * not_sub_div.clone(); + let p_result = &p_a_param * is_sub_div.clone() + &result * not_sub_div.clone(); + let p_carry: Polynomial = self.cols.carry.clone().into(); + let p_op = &p_a * &p_b * is_mul_div.clone() + (&p_a + &p_b) * not_mul_div; + + let p_op_minus_result: Polynomial = p_op - p_result; + let p_limbs = Polynomial::from_iter(P::modulus_field_iter::().map(AB::Expr::from)); + let p_vanishing = p_op_minus_result - &(&p_carry * &p_limbs); + let p_witness_low = self.cols.witness_low.0.iter().into(); + let p_witness_high = self.cols.witness_high.0.iter().into(); + eval_field_operation::(builder, &p_vanishing, &p_witness_low, &p_witness_high); + } +} diff --git a/crates/core/machine/src/syscall/precompiles/bn254_scalar/mac.rs b/crates/core/machine/src/syscall/precompiles/bn254_scalar/mac.rs new file mode 100644 index 0000000000..68b918d6cd --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/bn254_scalar/mac.rs @@ -0,0 +1,278 @@ +use std::borrow::{Borrow, BorrowMut}; + +use num::BigUint; +use num::Zero; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::AbstractField; +use p3_field::{Field, PrimeField32}; +use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use sp1_core_executor::events::Bn254FieldOperation; +use sp1_core_executor::events::ByteRecord; +use sp1_core_executor::events::FieldOperation; +use sp1_core_executor::events::PrecompileEvent; +use sp1_core_executor::events::NUM_WORDS_PER_FE; +use sp1_core_executor::syscalls::SyscallCode; +use sp1_core_executor::ExecutionRecord; +use sp1_core_executor::Program; +use sp1_curves::params::FieldParameters; +use sp1_curves::params::Limbs; +use sp1_curves::params::NumLimbs; +use sp1_curves::weierstrass::bn254::Bn254ScalarField; +use sp1_derive::AlignedBorrow; +use sp1_stark::air::InteractionScope; +use sp1_stark::air::MachineAir; +use sp1_stark::air::SP1AirBuilder; +use typenum::U8; + +use crate::air::MemoryAirBuilder; +use crate::utils::limbs_from_prev_access; +use crate::utils::pad_rows_fixed; +use crate::{ + memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, + operations::field::field_op::FieldOpCols, +}; + +const NUM_COLS: usize = core::mem::size_of::>(); +const OP: Bn254FieldOperation = Bn254FieldOperation::Mac; + +#[derive(Debug, Clone, AlignedBorrow)] +#[repr(C)] +pub struct Bn254ScalarMacCols { + is_real: T, + shard: T, + channel: T, + nonce: T, + clk: T, + arg1_ptr: T, + arg2_ptr: T, + arg1_access: [MemoryWriteCols; NUM_WORDS_PER_FE], + arg2_access: [MemoryReadCols; 2], + a_access: [MemoryReadCols; NUM_WORDS_PER_FE], + b_access: [MemoryReadCols; NUM_WORDS_PER_FE], + mul_eval: FieldOpCols, + add_eval: FieldOpCols, +} + +pub struct Bn254ScalarMacChip; + +impl Bn254ScalarMacChip { + pub fn new() -> Self { + Self + } +} + +impl MachineAir for Bn254ScalarMacChip { + type Record = ExecutionRecord; + + type Program = Program; + + fn name(&self) -> String { + "Bn254ScalarMac".to_string() + } + + fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + let events = input.get_precompile_events(SyscallCode::BN254_SCALAR_MAC); + + let mut rows = vec![]; + let mut new_byte_lookup_events = vec![]; + + for event in events { + let event = if let (_, PrecompileEvent::Bn254ScalarMac(event)) = event { + event + } else { + unreachable!(); + }; + let mut row = [F::zero(); NUM_COLS]; + let cols: &mut Bn254ScalarMacCols = row.as_mut_slice().borrow_mut(); + + let arg1 = event.arg1.prev_value_as_biguint(); + let a = event.a.as_ref().unwrap().value_as_biguint(); + let b = event.b.as_ref().unwrap().value_as_biguint(); + + cols.is_real = F::one(); + cols.shard = F::from_canonical_u32(event.shard); + cols.clk = F::from_canonical_u32(event.clk); + cols.arg1_ptr = F::from_canonical_u32(event.arg1.ptr); + cols.arg2_ptr = F::from_canonical_u32(event.arg2.ptr); + + //cols.nonce = F::from_canonical_u32( + // output.nonce_lookup.get(&event.lookup_id).copied().expect("should not be none"), + //); + + let mul = cols.mul_eval.populate( + &mut new_byte_lookup_events, + event.shard, + &a, + &b, + FieldOperation::Mul, + ); + cols.add_eval.populate( + &mut new_byte_lookup_events, + event.shard, + &arg1, + &mul, + FieldOperation::Add, + ); + + for i in 0..cols.arg1_access.len() { + cols.arg1_access[i] + .populate(event.arg1.memory_records[i], &mut new_byte_lookup_events); + } + for i in 0..cols.arg2_access.len() { + cols.arg2_access[i] + .populate(event.arg2.memory_records[i], &mut new_byte_lookup_events); + } + for i in 0..cols.a_access.len() { + cols.a_access[i].populate( + event.a.as_ref().unwrap().memory_records[i], + &mut new_byte_lookup_events, + ); + } + for i in 0..cols.b_access.len() { + cols.b_access[i].populate( + event.b.as_ref().unwrap().memory_records[i], + &mut new_byte_lookup_events, + ); + } + + rows.push(row); + } + output.add_byte_lookup_events(new_byte_lookup_events); + + pad_rows_fixed( + &mut rows, + || { + let mut row = [F::zero(); NUM_COLS]; + let cols: &mut Bn254ScalarMacCols = row.as_mut_slice().borrow_mut(); + + let zero = BigUint::zero(); + cols.mul_eval.populate(&mut vec![], 0, &zero, &zero, FieldOperation::Mul); + cols.add_eval.populate(&mut vec![], 0, &zero, &zero, FieldOperation::Add); + + row + }, + input.fixed_log2_rows::(self), + ); + + let mut trace = + RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_COLS); + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut Bn254ScalarMacCols = + trace.values[i * NUM_COLS..(i + 1) * NUM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace + } + + fn included(&self, shard: &Self::Record) -> bool { + !shard.get_precompile_events(SyscallCode::BN254_SCALAR_MAC).is_empty() + } +} + +impl BaseAir for Bn254ScalarMacChip { + fn width(&self) -> usize { + NUM_COLS + } +} + +impl Air for Bn254ScalarMacChip +where + AB: SP1AirBuilder, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &Bn254ScalarMacCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &Bn254ScalarMacCols = (*next).borrow(); + + // Check that nonce is incremented. + builder.when_first_row().assert_zero(local.nonce); + builder.when_transition().assert_eq(local.nonce + AB::Expr::one(), next.nonce); + + builder.assert_bool(local.is_real); + + let arg1: Limbs<::Var, ::Limbs> = + limbs_from_prev_access(&local.arg1_access); + let arg2: Limbs<::Var, U8> = limbs_from_prev_access(&local.arg2_access); + let a: Limbs<::Var, ::Limbs> = + limbs_from_prev_access(&local.a_access); + let b: Limbs<::Var, ::Limbs> = + limbs_from_prev_access(&local.b_access); + + local.mul_eval.eval(builder, &a, &b, FieldOperation::Mul, local.is_real); + local.add_eval.eval( + builder, + &arg1, + &local.mul_eval.result, + FieldOperation::Add, + local.is_real, + ); + + for i in 0..Bn254ScalarField::NB_LIMBS { + builder + .when(local.is_real) + .assert_eq(local.add_eval.result[i], local.arg1_access[i / 4].value()[i % 4]); + } + + builder.eval_memory_access_slice( + local.shard, + local.clk.into() + AB::Expr::one(), + local.arg1_ptr, + &local.arg1_access, + local.is_real, + ); + + builder.eval_memory_access_slice( + local.shard, + local.clk.into(), + local.arg2_ptr, + &local.arg2_access, + local.is_real, + ); + + let a_ptr = arg2.0[0..4] + .iter() + .rev() + .cloned() + .map(|v| v.into()) + .fold(AB::Expr::zero(), |acc, b| acc * AB::Expr::from_canonical_u16(0x100) + b); + + let b_ptr = arg2.0[4..8] + .iter() + .rev() + .cloned() + .map(|v| v.into()) + .fold(AB::Expr::zero(), |acc, b| acc * AB::Expr::from_canonical_u16(0x100) + b); + + builder.eval_memory_access_slice( + local.shard, + local.clk.into(), + a_ptr, + &local.a_access, + local.is_real, + ); + + builder.eval_memory_access_slice( + local.shard, + local.clk.into(), + b_ptr, + &local.b_access, + local.is_real, + ); + + let syscall_id = AB::F::from_canonical_u32(SyscallCode::BN254_SCALAR_MAC.syscall_id()); + builder.receive_syscall( + local.shard, + local.clk, + local.nonce, + syscall_id, + local.arg1_ptr, + local.arg2_ptr, + local.is_real, + InteractionScope::Local, + ); + } +} diff --git a/crates/core/machine/src/syscall/precompiles/bn254_scalar/mod.rs b/crates/core/machine/src/syscall/precompiles/bn254_scalar/mod.rs new file mode 100644 index 0000000000..8315a2ec37 --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/bn254_scalar/mod.rs @@ -0,0 +1,6 @@ +mod mac; +mod mul; +// mod general_field_op; + +pub use mac::Bn254ScalarMacChip; +pub use mul::Bn254ScalarMulChip; diff --git a/crates/core/machine/src/syscall/precompiles/bn254_scalar/mul.rs b/crates/core/machine/src/syscall/precompiles/bn254_scalar/mul.rs new file mode 100644 index 0000000000..498a7b0d4e --- /dev/null +++ b/crates/core/machine/src/syscall/precompiles/bn254_scalar/mul.rs @@ -0,0 +1,213 @@ +use std::borrow::{Borrow, BorrowMut}; + +use num::BigUint; +use num::Zero; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::AbstractField; +use p3_field::{Field, PrimeField32}; +use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use sp1_core_executor::events::Bn254FieldOperation; +use sp1_core_executor::events::ByteRecord; +use sp1_core_executor::events::PrecompileEvent; +use sp1_core_executor::events::NUM_WORDS_PER_FE; +use sp1_core_executor::syscalls::SyscallCode; +use sp1_core_executor::ExecutionRecord; +use sp1_core_executor::Program; +use sp1_curves::params::FieldParameters; +use sp1_curves::params::Limbs; +use sp1_curves::params::NumLimbs; +use sp1_curves::weierstrass::bn254::Bn254ScalarField; +use sp1_derive::AlignedBorrow; +use sp1_stark::air::InteractionScope; +use sp1_stark::air::MachineAir; +use sp1_stark::air::SP1AirBuilder; + +use crate::air::MemoryAirBuilder; +use crate::utils::limbs_from_prev_access; +use crate::utils::pad_rows_fixed; +use crate::{ + memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, + operations::field::field_op::FieldOpCols, +}; + +const NUM_COLS: usize = core::mem::size_of::>(); +const OP: Bn254FieldOperation = Bn254FieldOperation::Mul; + +#[derive(Debug, Clone, AlignedBorrow)] +#[repr(C)] +pub struct Bn254ScalarMulCols { + is_real: T, + shard: T, + channel: T, + nonce: T, + clk: T, + p_ptr: T, + q_ptr: T, + p_access: [MemoryWriteCols; NUM_WORDS_PER_FE], + q_access: [MemoryReadCols; NUM_WORDS_PER_FE], + eval: FieldOpCols, +} + +pub struct Bn254ScalarMulChip; + +impl Bn254ScalarMulChip { + pub fn new() -> Self { + Self + } +} + +impl MachineAir for Bn254ScalarMulChip { + type Record = ExecutionRecord; + + type Program = Program; + + fn name(&self) -> String { + "Bn254ScalarMul".to_string() + } + + fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + let events = input.get_precompile_events(SyscallCode::BN254_SCALAR_MUL); + + let mut rows = vec![]; + let mut new_byte_lookup_events = vec![]; + + for event in events { + let event = if let (_, PrecompileEvent::Bn254ScalarMul(event)) = event { + event + } else { + unreachable!(); + }; + let mut row = [F::zero(); NUM_COLS]; + let cols: &mut Bn254ScalarMulCols = row.as_mut_slice().borrow_mut(); + + let p = event.arg1.prev_value_as_biguint(); + let q = event.arg2.value_as_biguint(); + + cols.is_real = F::one(); + cols.shard = F::from_canonical_u32(event.shard); + cols.clk = F::from_canonical_u32(event.clk); + cols.p_ptr = F::from_canonical_u32(event.arg1.ptr); + cols.q_ptr = F::from_canonical_u32(event.arg2.ptr); + + //cols.nonce = F::from_canonical_u32( + // output.nonce_lookup.get(&event.lookup_id).copied().expect("should not be none"), + //); + + cols.eval.populate( + &mut new_byte_lookup_events, + event.shard, + &p, + &q, + OP.to_field_operation(), + ); + + for i in 0..cols.p_access.len() { + cols.p_access[i] + .populate(event.arg1.memory_records[i], &mut new_byte_lookup_events); + } + for i in 0..cols.q_access.len() { + cols.q_access[i] + .populate(event.arg2.memory_records[i], &mut new_byte_lookup_events); + } + + rows.push(row); + } + output.add_byte_lookup_events(new_byte_lookup_events); + + pad_rows_fixed( + &mut rows, + || { + let mut row = [F::zero(); NUM_COLS]; + let cols: &mut Bn254ScalarMulCols = row.as_mut_slice().borrow_mut(); + + let zero = BigUint::zero(); + cols.eval.populate(&mut vec![], 0, &zero, &zero, OP.to_field_operation()); + + row + }, + input.fixed_log2_rows::(self), + ); + + let mut trace = + RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_COLS); + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut Bn254ScalarMulCols = + trace.values[i * NUM_COLS..(i + 1) * NUM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace + } + + fn included(&self, shard: &Self::Record) -> bool { + !shard.get_precompile_events(SyscallCode::BN254_SCALAR_MUL).is_empty() + } +} + +impl BaseAir for Bn254ScalarMulChip { + fn width(&self) -> usize { + NUM_COLS + } +} + +impl Air for Bn254ScalarMulChip +where + AB: SP1AirBuilder, + // AB::Expr: Copy, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local: &Bn254ScalarMulCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &Bn254ScalarMulCols = (*next).borrow(); + + // Check that nonce is incremented. + builder.when_first_row().assert_zero(local.nonce); + builder.when_transition().assert_eq(local.nonce + AB::Expr::one(), next.nonce); + + builder.assert_bool(local.is_real); + + let p: Limbs<::Var, ::Limbs> = + limbs_from_prev_access(&local.p_access); + let q: Limbs<::Var, ::Limbs> = + limbs_from_prev_access(&local.q_access); + + local.eval.eval(builder, &p, &q, OP.to_field_operation(), local.is_real); + + for i in 0..Bn254ScalarField::NB_LIMBS { + builder + .when(local.is_real) + .assert_eq(local.eval.result[i], local.p_access[i / 4].value()[i % 4]); + } + + builder.eval_memory_access_slice( + local.shard, + local.clk.into(), + local.q_ptr, + &local.q_access, + local.is_real, + ); + + builder.eval_memory_access_slice( + local.shard, + local.clk.into() + AB::Expr::one(), + local.p_ptr, + &local.p_access, + local.is_real, + ); + + let syscall_id = AB::F::from_canonical_u32(SyscallCode::BN254_SCALAR_MUL.syscall_id()); + builder.receive_syscall( + local.shard, + local.clk, + local.nonce, + syscall_id, + local.p_ptr, + local.q_ptr, + local.is_real, + InteractionScope::Local, + ); + } +} diff --git a/crates/core/machine/src/syscall/precompiles/mod.rs b/crates/core/machine/src/syscall/precompiles/mod.rs index f07da94609..a93d372dd7 100644 --- a/crates/core/machine/src/syscall/precompiles/mod.rs +++ b/crates/core/machine/src/syscall/precompiles/mod.rs @@ -1,3 +1,4 @@ +pub mod bn254_scalar; pub mod edwards; pub mod fptower; pub mod keccak256; diff --git a/crates/core/machine/src/utils/prove.rs b/crates/core/machine/src/utils/prove.rs index 4b18c2e54c..d93a9bba6d 100644 --- a/crates/core/machine/src/utils/prove.rs +++ b/crates/core/machine/src/utils/prove.rs @@ -313,7 +313,6 @@ where shape_config.fix_shape(record).unwrap(); } } - // Generate the traces. let mut traces = vec![]; tracing::debug_span!("generate traces", index).in_scope(|| { diff --git a/crates/curves/src/weierstrass/bn254.rs b/crates/curves/src/weierstrass/bn254.rs index a2dd2aa32c..346fcfad69 100644 --- a/crates/curves/src/weierstrass/bn254.rs +++ b/crates/curves/src/weierstrass/bn254.rs @@ -48,6 +48,32 @@ impl NumLimbs for Bn254BaseField { type Witness = U62; } +#[derive(Debug, Default, Clone, Copy, PartialEq, Serialize, Deserialize)] +pub struct Bn254ScalarField; + +impl FieldParameters for Bn254ScalarField { + const MODULUS: &'static [u8] = &[ + 1, 0, 0, 240, 147, 245, 225, 67, 145, 112, 185, 121, 72, 232, 51, 40, 93, 88, 129, 129, + 182, 69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48, + ]; + + // TODO: check this constant + const WITNESS_OFFSET: usize = 1usize << 13; + + fn modulus() -> BigUint { + BigUint::from_str_radix( + "21888242871839275222246405745257275088548364400416034343698204186575808495617", + 10, + ) + .unwrap() + } +} + +impl NumLimbs for Bn254ScalarField { + type Limbs = U32; + type Witness = U62; +} + impl EllipticCurveParameters for Bn254Parameters { type BaseField = Bn254BaseField; diff --git a/crates/primitives/Cargo.toml b/crates/primitives/Cargo.toml index 56ea6e4178..558d8f1ef9 100644 --- a/crates/primitives/Cargo.toml +++ b/crates/primitives/Cargo.toml @@ -19,4 +19,5 @@ p3-baby-bear = { workspace = true } p3-poseidon2 = { workspace = true } p3-symmetric = { workspace = true } serde = { version = "1.0.207", features = ["derive"] } -sha2 = "0.10.8" +#sha2 = "0.10.8" +sha3 = "0.10.8" diff --git a/crates/primitives/src/io.rs b/crates/primitives/src/io.rs index 0d4d89e957..b4c1c5d721 100644 --- a/crates/primitives/src/io.rs +++ b/crates/primitives/src/io.rs @@ -1,7 +1,7 @@ use crate::types::Buffer; use num_bigint::BigUint; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use sha2::{Digest, Sha256}; +use sha3::{Digest, Keccak256}; /// Public values for the prover. #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -54,7 +54,7 @@ impl SP1PublicValues { /// Hash the public values. pub fn hash(&self) -> Vec { - let mut hasher = Sha256::new(); + let mut hasher = Keccak256::new(); hasher.update(self.buffer.data.as_slice()); hasher.finalize().to_vec() } @@ -67,7 +67,7 @@ impl SP1PublicValues { /// ``` pub fn hash_bn254(&self) -> BigUint { // Hash the public values. - let mut hasher = Sha256::new(); + let mut hasher = Keccak256::new(); hasher.update(self.buffer.data.as_slice()); let hash_result = hasher.finalize(); let mut hash = hash_result.to_vec(); diff --git a/crates/recursion/gnark-ffi/Cargo.toml b/crates/recursion/gnark-ffi/Cargo.toml index 3276efb80d..afd2212936 100644 --- a/crates/recursion/gnark-ffi/Cargo.toml +++ b/crates/recursion/gnark-ffi/Cargo.toml @@ -33,4 +33,5 @@ cc = "1.1" cfg-if = "1.0" [features] +default = ["native"] native = [] diff --git a/crates/recursion/gnark-ffi/go/sp1/build.go b/crates/recursion/gnark-ffi/go/sp1/build.go index f1893833c2..205548a93f 100644 --- a/crates/recursion/gnark-ffi/go/sp1/build.go +++ b/crates/recursion/gnark-ffi/go/sp1/build.go @@ -96,7 +96,11 @@ func BuildPlonk(dataDir string) { _, err = srsLagrange.ReadFrom(srsLagrangeFile) if err != nil { - panic(err) + srsLagrange = trusted_setup.ToLagrange(scs, srs) + _, err = srsLagrange.WriteTo(srsLagrangeFile) + if err != nil { + panic(err) + } } } diff --git a/crates/sdk/src/provers/cuda.rs b/crates/sdk/src/provers/cuda.rs index 5f8ab983aa..a3ff7812cf 100644 --- a/crates/sdk/src/provers/cuda.rs +++ b/crates/sdk/src/provers/cuda.rs @@ -49,6 +49,11 @@ impl Prover for CudaProver { // Generate the core proof. let proof = self.cuda_prover.prove_core(pk, &stdin)?; + + { + tracing::info!("verify core"); + self.prover.verify(&proof.proof, &pk.vk).expect("prove core failed"); + } if kind == SP1ProofKind::Core { return Ok(SP1ProofWithPublicValues { proof: SP1Proof::Core(proof.proof.0), @@ -64,6 +69,10 @@ impl Prover for CudaProver { // Generate the compressed proof. let reduce_proof = self.cuda_prover.compress(&pk.vk, proof, deferred_proofs)?; + { + tracing::info!("verify compressed"); + self.prover.verify_compressed(&reduce_proof, &pk.vk).expect("prove compressed failed"); + } if kind == SP1ProofKind::Compressed { return Ok(SP1ProofWithPublicValues { proof: SP1Proof::Compressed(Box::new(reduce_proof)), @@ -75,9 +84,17 @@ impl Prover for CudaProver { // Generate the shrink proof. let compress_proof = self.cuda_prover.shrink(reduce_proof)?; + { + tracing::info!("verify shrink"); + self.prover.verify_shrink(&compress_proof, &pk.vk).expect("prove shrink failed"); + } // Genenerate the wrap proof. - let outer_proof = self.cuda_prover.wrap_bn254(compress_proof)?; + let outer_proof = self.prover.wrap_bn254(compress_proof, _opts.sp1_prover_opts)?; + { + tracing::info!("verify wrap bn254"); + self.prover.verify_wrap_bn254(&outer_proof, &pk.vk).expect("prove wrap bn254 failed"); + } if kind == SP1ProofKind::Plonk { let plonk_bn254_artifacts = if sp1_prover::build::sp1_dev_mode() { diff --git a/crates/zkvm/entrypoint/Cargo.toml b/crates/zkvm/entrypoint/Cargo.toml index 320340a7ea..ea06250a45 100644 --- a/crates/zkvm/entrypoint/Cargo.toml +++ b/crates/zkvm/entrypoint/Cargo.toml @@ -15,6 +15,7 @@ getrandom = { version = "0.2.15", features = ["custom"] } rand = "0.8.5" libm = { version = "0.2.8", optional = true } sha2 = { version = "0.10.8" } +sha3 = { version = "0.10.8" } lazy_static = "1.5.0" # optional diff --git a/crates/zkvm/entrypoint/src/lib.rs b/crates/zkvm/entrypoint/src/lib.rs index 46dbed730d..db192e3706 100644 --- a/crates/zkvm/entrypoint/src/lib.rs +++ b/crates/zkvm/entrypoint/src/lib.rs @@ -25,7 +25,7 @@ mod zkvm { use crate::syscalls::syscall_halt; use cfg_if::cfg_if; - use sha2::{Digest, Sha256}; + use sha3::{Digest, Keccak256}; cfg_if! { if #[cfg(feature = "verify")] { @@ -36,12 +36,12 @@ mod zkvm { } } - pub static mut PUBLIC_VALUES_HASHER: Option = None; + pub static mut PUBLIC_VALUES_HASHER: Option = None; #[no_mangle] unsafe extern "C" fn __start() { { - PUBLIC_VALUES_HASHER = Some(Sha256::new()); + PUBLIC_VALUES_HASHER = Some(Keccak256::new()); #[cfg(feature = "verify")] { DEFERRED_PROOFS_DIGEST = Some([BabyBear::zero(); 8]); diff --git a/crates/zkvm/entrypoint/src/memcpy.c b/crates/zkvm/entrypoint/src/memcpy.c new file mode 100644 index 0000000000..1c6df49d31 --- /dev/null +++ b/crates/zkvm/entrypoint/src/memcpy.c @@ -0,0 +1,145 @@ +// clang -target riscv32 -march=rv32im -O3 -S memcpy.c -nostdlib -fno-builtin -funroll-loops +// replace contents start from `memcpy:` to the end by new memcpy.s +// manually add `memcpy` suffix to all labels in memcpy.s +#include +#include + +#define MEMCPY_32 0x00010190 +#define MEMCPY_64 0x00010191 + +void *memcpy(void *restrict dest, const void *restrict src, size_t n) +{ + unsigned char *d = dest; + const unsigned char *s = src; + +#ifdef __GNUC__ +#define LS >> +#define RS << + + typedef uint32_t __attribute__((__may_alias__)) u32; + uint32_t w, x; + + for (; (uintptr_t)s % 4 && n; n--) *d++ = *s++; + + if ((uintptr_t)d % 4 == 0) { + for (; n>=64; s+=64, d+=64, n-=64) { + asm volatile( + "mv t0, %0\n" + "mv a0, %1\n" + "mv a1, %2\n" + "ecall" + : // No output operands + : "r"(MEMCPY_64), "r"(s), "r"(d) + : "t0", "a0", "a1" // Clobbered registers + ); + } + for (; n>=32; s+=32, d+=32, n-=32) { + asm volatile( + "mv t0, %0\n" + "mv a0, %1\n" + "mv a1, %2\n" + "ecall" + : // No output operands + : "r"(MEMCPY_32), "r"(s), "r"(d) + : "t0", "a0", "a1" // Clobbered registers + ); + } + for (; n>=16; s+=16, d+=16, n-=16) { + *(u32 *)(d+0) = *(u32 *)(s+0); + *(u32 *)(d+4) = *(u32 *)(s+4); + *(u32 *)(d+8) = *(u32 *)(s+8); + *(u32 *)(d+12) = *(u32 *)(s+12); + } + if (n&8) { + *(u32 *)(d+0) = *(u32 *)(s+0); + *(u32 *)(d+4) = *(u32 *)(s+4); + d += 8; s += 8; + } + if (n&4) { + *(u32 *)(d+0) = *(u32 *)(s+0); + d += 4; s += 4; + } + if (n&2) { + *d++ = *s++; *d++ = *s++; + } + if (n&1) { + *d = *s; + } + return dest; + } + + if (n >= 32) switch ((uintptr_t)d % 4) { + case 1: + w = *(u32 *)s; + *d++ = *s++; + *d++ = *s++; + *d++ = *s++; + n -= 3; + for (; n>=17; s+=16, d+=16, n-=16) { + x = *(u32 *)(s+1); + *(u32 *)(d+0) = (w LS 24) | (x RS 8); + w = *(u32 *)(s+5); + *(u32 *)(d+4) = (x LS 24) | (w RS 8); + x = *(u32 *)(s+9); + *(u32 *)(d+8) = (w LS 24) | (x RS 8); + w = *(u32 *)(s+13); + *(u32 *)(d+12) = (x LS 24) | (w RS 8); + } + break; + case 2: + w = *(u32 *)s; + *d++ = *s++; + *d++ = *s++; + n -= 2; + for (; n>=18; s+=16, d+=16, n-=16) { + x = *(u32 *)(s+2); + *(u32 *)(d+0) = (w LS 16) | (x RS 16); + w = *(u32 *)(s+6); + *(u32 *)(d+4) = (x LS 16) | (w RS 16); + x = *(u32 *)(s+10); + *(u32 *)(d+8) = (w LS 16) | (x RS 16); + w = *(u32 *)(s+14); + *(u32 *)(d+12) = (x LS 16) | (w RS 16); + } + break; + case 3: + w = *(u32 *)s; + *d++ = *s++; + n -= 1; + for (; n>=19; s+=16, d+=16, n-=16) { + x = *(u32 *)(s+3); + *(u32 *)(d+0) = (w LS 8) | (x RS 24); + w = *(u32 *)(s+7); + *(u32 *)(d+4) = (x LS 8) | (w RS 24); + x = *(u32 *)(s+11); + *(u32 *)(d+8) = (w LS 8) | (x RS 24); + w = *(u32 *)(s+15); + *(u32 *)(d+12) = (x LS 8) | (w RS 24); + } + break; + } + if (n&16) { + *d++ = *s++; *d++ = *s++; *d++ = *s++; *d++ = *s++; + *d++ = *s++; *d++ = *s++; *d++ = *s++; *d++ = *s++; + *d++ = *s++; *d++ = *s++; *d++ = *s++; *d++ = *s++; + *d++ = *s++; *d++ = *s++; *d++ = *s++; *d++ = *s++; + } + if (n&8) { + *d++ = *s++; *d++ = *s++; *d++ = *s++; *d++ = *s++; + *d++ = *s++; *d++ = *s++; *d++ = *s++; *d++ = *s++; + } + if (n&4) { + *d++ = *s++; *d++ = *s++; *d++ = *s++; *d++ = *s++; + } + if (n&2) { + *d++ = *s++; *d++ = *s++; + } + if (n&1) { + *d = *s; + } + return dest; +#endif + + for (; n; n--) *d++ = *s++; + return dest; +} \ No newline at end of file diff --git a/crates/zkvm/entrypoint/src/syscalls/bn254.rs b/crates/zkvm/entrypoint/src/syscalls/bn254.rs index 6ac4e98c1d..8ee6a8fa5d 100644 --- a/crates/zkvm/entrypoint/src/syscalls/bn254.rs +++ b/crates/zkvm/entrypoint/src/syscalls/bn254.rs @@ -50,3 +50,40 @@ pub extern "C" fn syscall_bn254_double(p: *mut [u32; 16]) { #[cfg(not(target_os = "zkvm"))] unreachable!() } + +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_bn254_scalar_mul(p: *mut u32, q: *const u32) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") crate::syscalls::BN254_SCALAR_MUL, + in("a0") p, + in("a1") q, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_bn254_scalar_mac(ret: *mut u32, a: *const u32, b: *const u32) { + let q = &[a, b]; + let q_ptr = q.as_ptr() as *const u32; + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") crate::syscalls::BN254_SCALAR_MAC, + in("a0") ret, + in("a1") q_ptr, + + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} diff --git a/crates/zkvm/entrypoint/src/syscalls/mod.rs b/crates/zkvm/entrypoint/src/syscalls/mod.rs index e40728d7ed..b920701609 100644 --- a/crates/zkvm/entrypoint/src/syscalls/mod.rs +++ b/crates/zkvm/entrypoint/src/syscalls/mod.rs @@ -141,3 +141,12 @@ pub const BN254_FP2_SUB: u32 = 0x00_01_01_2A; /// Executes the `BN254_FP2_MUL` precompile. pub const BN254_FP2_MUL: u32 = 0x00_01_01_2B; + +/// Executes the `MEMCPY_32` precompile +pub const MEMCPY_32: u32 = 0x00_01_01_90; +/// Executes the `MEMCPY_64` precompile +pub const MEMCPY_64: u32 = 0x00_01_01_91; +/// Executes the `BN254_SCALAR_MUL` precompile +pub const BN254_SCALAR_MUL: u32 = 0x00_01_01_80; +/// Executes the `BN254_SCALAR_MAC` precompile +pub const BN254_SCALAR_MAC: u32 = 0x00_01_01_81; diff --git a/crates/zkvm/lib/src/lib.rs b/crates/zkvm/lib/src/lib.rs index f849859a8b..c08763e355 100644 --- a/crates/zkvm/lib/src/lib.rs +++ b/crates/zkvm/lib/src/lib.rs @@ -83,6 +83,9 @@ extern "C" { /// Decompresses a BLS12-381 point. pub fn syscall_bls12381_decompress(point: &mut [u8; 96], is_odd: bool); + pub fn syscall_bn254_scalar_mul(p: *mut u32, q: *const u32); + pub fn syscall_bn254_scalar_mac(ret: *mut u32, a: *const u32, b: *const u32); + /// Computes a big integer operation with a modulus. pub fn sys_bigint( result: *mut [u32; 8],