Skip to content

Commit

Permalink
feat: Add ValueDecoder trait
Browse files Browse the repository at this point in the history
ValueDecoder allows all Boolean iterators to decode Simplicity values.
This is necessary for alternative implementations of the Bit Machine,
such as on the web IDE.

This commit also fixes the code to correctly decode sum values with
padding.

We can extend the trait to cover natural numbers, CMRs and so on in the
future. This commit is limited to decoding values, which is what I need.
  • Loading branch information
uncomputable committed Sep 27, 2024
1 parent 24367c8 commit 0719071
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 51 deletions.
4 changes: 2 additions & 2 deletions jets-bench/src/data_structures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use simplicity::{
hashes::Hash,
hex::FromHex,
types::Final,
BitIter, Error, Value,
BitIter, Error, Value, ValueDecoder,
};

/// Engine to compute SHA256 hash function.
Expand Down Expand Up @@ -63,7 +63,7 @@ pub fn var_len_buf_from_slice(v: &[u8], mut n: usize) -> Result<Value, Error> {
while n > 0 {
let ty = Final::two_two_n(n);
let v = if v.len() >= (1 << (n + 1)) {
let val = iter.read_value(&ty)?;
let val = iter.decode_padded_value(&ty)?;
Value::some(val)
} else {
Value::none(ty)
Expand Down
97 changes: 56 additions & 41 deletions src/bit_encoding/bititer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,52 @@ impl<I: Iterator<Item = u8>> BitIter<I> {
Ok(FailEntropy::from_byte_array(ret))
}

/// Decode a value from bits, based on the given type.
pub fn read_value(&mut self, ty: &Final) -> Result<Value, EarlyEndOfStreamError> {
/// Decode a natural number from bits.
///
/// If a bound is specified, then the decoding terminates before trying to
/// decode a larger number.
pub fn read_natural(&mut self, bound: Option<usize>) -> Result<usize, decode::Error> {
decode::decode_natural(self, bound)
}

/// Accessor for the number of bits which have been read,
/// in total, from this iterator
pub fn n_total_read(&self) -> usize {
self.total_read
}

/// Consumes the bit iterator, checking that there are no remaining
/// bytes and that any unread bits are zero.
pub fn close(mut self) -> Result<(), CloseError> {
if let Some(first_byte) = self.iter.next() {
return Err(CloseError::TrailingBytes { first_byte });
}

debug_assert!(self.read_bits >= 1);
debug_assert!(self.read_bits <= 8);
let n_bits = 8 - self.read_bits;
let masked_padding = self.cached_byte & ((1u8 << n_bits) - 1);
if masked_padding != 0 {
Err(CloseError::IllegalPadding {
masked_padding,
n_bits,
})
} else {
Ok(())
}
}
}

/// Functionality for Boolean iterators to decode Simplicity values.
pub trait ValueDecoder {
/// Decode a value of the given type from its padded bit encoding.
///
/// Return `None` if there are not enough bits.
fn decode_padded_value(&mut self, ty: &Final) -> Result<Value, EarlyEndOfStreamError>;
}

impl<I: Iterator<Item = bool>> ValueDecoder for I {
fn decode_padded_value(&mut self, ty: &Final) -> Result<Value, EarlyEndOfStreamError> {
enum State<'a> {
ProcessType(&'a Final),
DoSumL(Arc<Final>),
Expand All @@ -237,12 +281,18 @@ impl<I: Iterator<Item = u8>> BitIter<I> {
State::ProcessType(ty) => match ty.bound() {
types::CompleteBound::Unit => result_stack.push(Value::unit()),
types::CompleteBound::Sum(ref l, ref r) => {
if self.read_bit()? {
stack.push(State::DoSumR(Arc::clone(l)));
stack.push(State::ProcessType(r));
} else {
if !self.next().ok_or(EarlyEndOfStreamError)? {
for _ in 0..l.pad_left(r) {
let _padding = self.next().ok_or(EarlyEndOfStreamError)?;
}
stack.push(State::DoSumL(Arc::clone(r)));
stack.push(State::ProcessType(l));
} else {
for _ in 0..l.pad_right(r) {
let _padding = self.next().ok_or(EarlyEndOfStreamError)?;
}
stack.push(State::DoSumR(Arc::clone(l)));
stack.push(State::ProcessType(r));
}
}
types::CompleteBound::Product(ref l, ref r) => {
Expand All @@ -269,41 +319,6 @@ impl<I: Iterator<Item = u8>> BitIter<I> {
debug_assert_eq!(result_stack.len(), 1);
Ok(result_stack.pop().unwrap())
}

/// Decode a natural number from bits.
///
/// If a bound is specified, then the decoding terminates before trying to
/// decode a larger number.
pub fn read_natural(&mut self, bound: Option<usize>) -> Result<usize, decode::Error> {
decode::decode_natural(self, bound)
}

/// Accessor for the number of bits which have been read,
/// in total, from this iterator
pub fn n_total_read(&self) -> usize {
self.total_read
}

/// Consumes the bit iterator, checking that there are no remaining
/// bytes and that any unread bits are zero.
pub fn close(mut self) -> Result<(), CloseError> {
if let Some(first_byte) = self.iter.next() {
return Err(CloseError::TrailingBytes { first_byte });
}

debug_assert!(self.read_bits >= 1);
debug_assert!(self.read_bits <= 8);
let n_bits = 8 - self.read_bits;
let masked_padding = self.cached_byte & ((1u8 << n_bits) - 1);
if masked_padding != 0 {
Err(CloseError::IllegalPadding {
masked_padding,
n_bits,
})
} else {
Ok(())
}
}
}

/// Functionality for Boolean iterators to collect their bits or bytes.
Expand Down
2 changes: 1 addition & 1 deletion src/bit_encoding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ mod bitwriter;
pub mod decode;
pub mod encode;

pub use bititer::{u2, BitCollector, BitIter, CloseError, EarlyEndOfStreamError};
pub use bititer::{u2, BitCollector, BitIter, CloseError, EarlyEndOfStreamError, ValueDecoder};
pub use bitwriter::{write_to_vec, BitWriter};
4 changes: 2 additions & 2 deletions src/bit_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ use std::error;
use std::fmt;
use std::sync::Arc;

use crate::analysis;
use crate::jet::{Jet, JetFailed};
use crate::node::{self, RedeemNode};
use crate::types::Final;
use crate::{analysis, ValueDecoder};
use crate::{Cmr, FailEntropy, Value};
use frame::Frame;

Expand Down Expand Up @@ -360,7 +360,7 @@ impl BitMachine {
out_frame.reset_cursor();
let value = out_frame
.as_bit_iter(&self.data)
.read_value(&program.arrow().target)
.decode_padded_value(&program.arrow().target)
.expect("Decode value of output frame");

Ok(value)
Expand Down
4 changes: 2 additions & 2 deletions src/human_encoding/parse/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::Arc;

use crate::human_encoding::{Error, ErrorSet, Position, WitnessOrHole};
use crate::jet::Jet;
use crate::{node, types};
use crate::{node, types, ValueDecoder};
use crate::{BitIter, Cmr, FailEntropy};
use santiago::grammar::{Associativity, Grammar};
use santiago::lexer::{Lexeme, LexerRules};
Expand Down Expand Up @@ -647,7 +647,7 @@ fn grammar<J: Jet + 'static>() -> Grammar<Ast<J>> {
let ty = types::Final::two_two_n(bit_length.trailing_zeros() as usize);
// unwrap ok here since literally every sequence of bits is a valid
// value for the given type
let value = iter.read_value(&ty).unwrap();
let value = iter.decode_padded_value(&ty).unwrap();
Ast::Expression(Expression {
inner: ExprInner::Inline(node::Inner::Word(value)),
position,
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mod value;
pub use bit_encoding::decode;
pub use bit_encoding::encode;
pub use bit_encoding::{
u2, BitCollector, BitIter, CloseError as BitIterCloseError, EarlyEndOfStreamError,
u2, BitCollector, BitIter, CloseError as BitIterCloseError, EarlyEndOfStreamError, ValueDecoder,
};
pub use bit_encoding::{write_to_vec, BitWriter};

Expand Down
6 changes: 4 additions & 2 deletions src/node/redeem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::analysis::NodeBounds;
use crate::dag::{DagLike, InternalSharing, MaxSharing, PostOrderIterItem};
use crate::jet::Jet;
use crate::types::{self, arrow::FinalArrow};
use crate::{encode, WitnessNode};
use crate::{encode, ValueDecoder, WitnessNode};
use crate::{Amr, BitIter, BitWriter, Cmr, Error, FirstPassImr, Imr, Value};

use super::{
Expand Down Expand Up @@ -299,7 +299,9 @@ impl<J: Jet> RedeemNode<J> {
) -> Result<Value, Self::Error> {
let arrow = data.node.data.arrow();
let target_ty = arrow.target.finalize()?;
self.bits.read_value(&target_ty).map_err(Error::from)
self.bits
.decode_padded_value(&target_ty)
.map_err(Error::from)
}

fn convert_disconnect(
Expand Down

0 comments on commit 0719071

Please sign in to comment.