Skip to content

Commit

Permalink
Multi-threaded prover part 1 (#1155)
Browse files Browse the repository at this point in the history
* Add clone for Risc0Host

* Adapt Risc0Host for multi-threaded prover

* Fix build

* Fix lint
  • Loading branch information
bkolad authored Nov 15, 2023
1 parent d9b0a17 commit 1286c6b
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 59 deletions.
25 changes: 13 additions & 12 deletions adapters/risc0/src/host.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
//! This module implements the `ZkvmHost` trait for the RISC0 VM.
use std::sync::Mutex;
use risc0_zkvm::serde::to_vec;
use risc0_zkvm::{Executor, ExecutorEnvBuilder, InnerReceipt, Receipt, Session};
use serde::de::DeserializeOwned;
use serde::Serialize;
use sov_rollup_interface::zk::{Zkvm, ZkvmHost};
use sov_rollup_interface::zk::{Proof, Zkvm, ZkvmHost};
#[cfg(feature = "bench")]
use sov_zk_cycle_utils::{cycle_count_callback, get_syscall_name, get_syscall_name_cycles};

Expand All @@ -16,8 +15,9 @@ use crate::Risc0MethodId;

/// A Risc0Host stores a binary to execute in the Risc0 VM, and accumulates hints to be
/// provided to its execution.
#[derive(Clone)]
pub struct Risc0Host<'a> {
env: Mutex<Vec<u32>>,
env: Vec<u32>,
elf: &'a [u8],
}

Expand Down Expand Up @@ -51,7 +51,7 @@ impl<'a> Risc0Host<'a> {
/// This creates the "Session" trace without invoking the heavy cryptographic machinery.
pub fn run_without_proving(&mut self) -> anyhow::Result<Session> {
let env = add_benchmarking_callbacks(ExecutorEnvBuilder::default())
.add_input(&self.env.lock().unwrap())
.add_input(&self.env)
.build()
.unwrap();
let mut executor = Executor::from_elf(env, self.elf)?;
Expand All @@ -65,24 +65,25 @@ impl<'a> Risc0Host<'a> {
}

impl<'a> ZkvmHost for Risc0Host<'a> {
fn add_hint<T: serde::Serialize>(&self, item: T) {
type Guest = Risc0Guest;

fn add_hint<T: serde::Serialize>(&mut self, item: T) {
let serialized = to_vec(&item).expect("Serialization to vec is infallible");
self.env.lock().unwrap().extend_from_slice(&serialized[..]);
self.env.extend_from_slice(&serialized[..]);
}

type Guest = Risc0Guest;

fn simulate_with_hints(&mut self) -> Self::Guest {
Risc0Guest::with_hints(std::mem::take(&mut self.env.lock().unwrap()))
Risc0Guest::with_hints(std::mem::take(&mut self.env))
}

fn run(&mut self, with_proof: bool) -> Result<(), anyhow::Error> {
fn run(&mut self, with_proof: bool) -> Result<Proof, anyhow::Error> {
if with_proof {
self.run()?;
let jurnal = self.run()?.journal;
Ok(Proof::Data(jurnal))
} else {
self.run_without_proving()?;
Ok(Proof::Empty)
}
Ok(())
}
}

Expand Down
16 changes: 8 additions & 8 deletions examples/demo-rollup/provers/risc0/guest-celestia/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn main() {
let stf: StfBlueprint<ZkDefaultContext, _, _, Runtime<_, _>, BasicKernel<_>> =
StfBlueprint::new();

let mut stf_verifier = StfVerifier::new(
let stf_verifier = StfVerifier::new(
stf,
CelestiaVerifier {
rollup_namespace: ROLLUP_NAMESPACE,
Expand Down
16 changes: 8 additions & 8 deletions examples/demo-rollup/provers/risc0/guest-mock/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub fn main() {
let stf: StfBlueprint<ZkDefaultContext, _, _, Runtime<_, _>, BasicKernel<_>> =
StfBlueprint::new();

let mut stf_verifier = StfVerifier::new(stf, MockDaVerifier {});
let stf_verifier = StfVerifier::new(stf, MockDaVerifier {});

stf_verifier
.run_block(guest, storage)
Expand Down
37 changes: 24 additions & 13 deletions full-node/sov-stf-runner/src/prover_service/blocking_prover.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::sync::Mutex;
use std::ops::Deref;
use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use serde::de::DeserializeOwned;
Expand All @@ -11,7 +12,7 @@ use sov_rollup_interface::zk::ZkvmHost;

use super::{Hash, ProverService, ProverServiceError};
use crate::verifier::StateTransitionVerifier;
use crate::{ProofGenConfig, Prover, RollupProverConfig, StateTransitionData};
use crate::{ProofGenConfig, RollupProverConfig, StateTransitionData};

/// Prover that blocks the current thread and creates a ZKP proof.
pub struct BlockingProver<StateRoot, Witness, Da, Vm, V>
Expand All @@ -22,10 +23,12 @@ where
Vm: ZkvmHost,
V: StateTransitionFunction<Vm::Guest, Da::Spec> + Send + Sync,
{
prover: Mutex<Option<Prover<V, Da, Vm>>>,
vm: Vm,
prover_config: Arc<Option<ProofGenConfig<V, Da, Vm>>>,
zk_storage: V::PreState,

#[allow(clippy::type_complexity)]
witness: Mutex<HashMap<Hash, StateTransitionData<StateRoot, Witness, Da::Spec>>>,
zk_storage: V::PreState,
}

impl<StateRoot, Witness, Da, Vm, V> BlockingProver<StateRoot, Witness, Da, Vm, V>
Expand All @@ -45,7 +48,7 @@ where
config: Option<RollupProverConfig>,
zk_storage: V::PreState,
) -> Self {
let prover = config.map(|config| {
let prover_config = config.map(|config| {
let stf_verifier =
StateTransitionVerifier::<V, Da::Verifier, Vm::Guest>::new(zk_stf, da_verifier);

Expand All @@ -55,11 +58,12 @@ where
RollupProverConfig::Prove => ProofGenConfig::Prover,
};

Prover { vm, config }
config
});

Self {
prover: Mutex::new(prover),
vm,
prover_config: Arc::new(prover_config),
witness: Mutex::new(HashMap::new()),
zk_storage,
}
Expand All @@ -72,8 +76,8 @@ where
StateRoot: Serialize + DeserializeOwned + Clone + AsRef<[u8]> + Send + Sync,
Witness: Serialize + DeserializeOwned + Send + Sync,
Da: DaService,
Vm: ZkvmHost,
V: StateTransitionFunction<Vm::Guest, Da::Spec> + Send + Sync,
Vm: ZkvmHost + 'static,
V: StateTransitionFunction<Vm::Guest, Da::Spec> + Send + Sync + 'static,
V::PreState: Clone + Send + Sync,
{
type StateRoot = StateRoot;
Expand All @@ -98,8 +102,7 @@ where
}

async fn prove(&self, block_header_hash: Hash) -> Result<(), ProverServiceError> {
if let Some(Prover { vm, config }) = self.prover.lock().expect("Lock was poisoned").as_mut()
{
if let Some(config) = self.prover_config.clone().deref() {
let transition_data = {
self.witness
.lock()
Expand All @@ -108,16 +111,24 @@ where
.unwrap()
};

let mut vm = self.vm.clone();
vm.add_hint(transition_data);

tracing::info_span!("guest_execution").in_scope(|| match config {
ProofGenConfig::Simulate(verifier) => verifier
.run_block(vm.simulate_with_hints(), self.zk_storage.clone())
.map_err(|e| {
anyhow::anyhow!("Guest execution must succeed but failed with {:?}", e)
})
.map(|_| ()),
ProofGenConfig::Execute => vm.run(false),
ProofGenConfig::Prover => vm.run(true),
ProofGenConfig::Execute => {
let _ = vm.run(false)?;
Ok(())
}
ProofGenConfig::Prover => {
let _ = vm.run(true)?;
Ok(())
}
})?;
}

Expand Down
11 changes: 0 additions & 11 deletions full-node/sov-stf-runner/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,6 @@ where
Prover,
}

/// A prover for the demo rollup. Consists of a VM and a config
pub struct Prover<Stf, Da: DaService, Vm: ZkvmHost>
where
Stf: StateTransitionFunction<Vm::Guest, Da::Spec>,
{
/// The Zkvm Host to use
pub vm: Vm,
/// The prover configuration
pub config: ProofGenConfig<Stf, Da, Vm>,
}

impl<Stf, Sm, Da, Vm, Ps> StateTransitionRunner<Stf, Sm, Da, Vm, Ps>
where
Da: DaService<Error = anyhow::Error> + Clone + Send + Sync + 'static,
Expand Down
3 changes: 2 additions & 1 deletion full-node/sov-stf-runner/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ where
da_verifier: Da,
phantom: PhantomData<Zk>,
}

impl<Stf, Da, Zk> StateTransitionVerifier<Stf, Da, Zk>
where
Da: DaVerifier,
Expand All @@ -34,7 +35,7 @@ where

/// Verify the next block
pub fn run_block(
&mut self,
&self,
zkvm: Zk,
pre_state: Stf::PreState,
) -> Result<Stf::StateRoot, Da::Error> {
Expand Down
17 changes: 13 additions & 4 deletions rollup-interface/src/state_machine/zk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//! For a detailed example showing how to implement these traits, see the
//! [risc0 adapter](https://github.com/Sovereign-Labs/sovereign-sdk/tree/main/adapters/risc0)
//! maintained by the Sovereign Labs team.
use alloc::vec::Vec;
use core::fmt::Debug;

use borsh::{BorshDeserialize, BorshSerialize};
Expand All @@ -16,12 +17,20 @@ use serde::{Deserialize, Serialize};
use crate::da::DaSpec;
use crate::RollupAddress;

/// The ZK proof generated by the [`ZkvmHost::run`] method.
pub enum Proof {
/// Proof generation was skipped.
Empty,
/// The serialized ZK proof.
Data(Vec<u8>),
}

/// A trait implemented by the prover ("host") of a zkVM program.
pub trait ZkvmHost: Zkvm {
pub trait ZkvmHost: Zkvm + Clone {
/// The associated guest type
type Guest: ZkvmGuest;
/// Give the guest a piece of advice non-deterministically
fn add_hint<T: Serialize>(&self, item: T);
fn add_hint<T: Serialize>(&mut self, item: T);

/// Simulate running the guest using the provided hints.
///
Expand All @@ -34,7 +43,7 @@ pub trait ZkvmHost: Zkvm {
/// This runs the guest binary compiled for the zkVM target, optionally
/// creating a SNARK of correct execution. Running the true guest binary comes
/// with some mild performance overhead and is not as easy to debug as [`simulate_with_hints`](ZkvmHost::simulate_with_hints).
fn run(&mut self, with_proof: bool) -> Result<(), anyhow::Error>;
fn run(&mut self, with_proof: bool) -> Result<Proof, anyhow::Error>;
}

/// A Zk proof system capable of proving and verifying arbitrary Rust code
Expand Down Expand Up @@ -71,7 +80,7 @@ pub trait Zkvm: Send + Sync {
}

/// A trait which is accessible from within a zkVM program.
pub trait ZkvmGuest: Zkvm + Send {
pub trait ZkvmGuest: Zkvm + Send + Sync {
/// Obtain "advice" non-deterministically from the host
fn read_from_host<T: DeserializeOwned>(&self) -> T;
/// Add a public output to the zkVM proof
Expand Down

0 comments on commit 1286c6b

Please sign in to comment.