diff --git a/src/bit_encoding/bititer.rs b/src/bit_encoding/bititer.rs index c7e8d811..3aaeb2d5 100644 --- a/src/bit_encoding/bititer.rs +++ b/src/bit_encoding/bititer.rs @@ -221,8 +221,52 @@ impl> BitIter { Ok(FailEntropy::from_byte_array(ret)) } + /// Decode a natural number from bits. + /// + /// If a bound is specified, then the decoding terminates before trying to + /// decode a larger number. + pub fn read_natural(&mut self, bound: Option) -> Result { + decode::decode_natural(self, bound) + } + + /// Accessor for the number of bits which have been read, + /// in total, from this iterator + pub fn n_total_read(&self) -> usize { + self.total_read + } + + /// Consumes the bit iterator, checking that there are no remaining + /// bytes and that any unread bits are zero. + pub fn close(mut self) -> Result<(), CloseError> { + if let Some(first_byte) = self.iter.next() { + return Err(CloseError::TrailingBytes { first_byte }); + } + + debug_assert!(self.read_bits >= 1); + debug_assert!(self.read_bits <= 8); + let n_bits = 8 - self.read_bits; + let masked_padding = self.cached_byte & ((1u8 << n_bits) - 1); + if masked_padding != 0 { + Err(CloseError::IllegalPadding { + masked_padding, + n_bits, + }) + } else { + Ok(()) + } + } +} + +/// Functionality for Boolean iterators to decode Simplicity values. +pub trait ValueDecoder { /// Decode a value from bits, based on the given type. - pub fn read_value(&mut self, ty: &Final) -> Result { + /// + /// Return `None` if there are not enough bits. + fn decode_value(&mut self, ty: &Final) -> Result; +} + +impl> ValueDecoder for I { + fn decode_value(&mut self, ty: &Final) -> Result { enum State<'a> { ProcessType(&'a Final), DoSumL(Arc), @@ -237,7 +281,7 @@ impl> BitIter { State::ProcessType(ty) => match ty.bound() { types::CompleteBound::Unit => result_stack.push(Value::unit()), types::CompleteBound::Sum(ref l, ref r) => { - if self.read_bit()? { + if self.next().ok_or(EarlyEndOfStreamError)? { stack.push(State::DoSumR(Arc::clone(l))); stack.push(State::ProcessType(r)); } else { @@ -269,41 +313,6 @@ impl> BitIter { debug_assert_eq!(result_stack.len(), 1); Ok(result_stack.pop().unwrap()) } - - /// Decode a natural number from bits. - /// - /// If a bound is specified, then the decoding terminates before trying to - /// decode a larger number. - pub fn read_natural(&mut self, bound: Option) -> Result { - decode::decode_natural(self, bound) - } - - /// Accessor for the number of bits which have been read, - /// in total, from this iterator - pub fn n_total_read(&self) -> usize { - self.total_read - } - - /// Consumes the bit iterator, checking that there are no remaining - /// bytes and that any unread bits are zero. - pub fn close(mut self) -> Result<(), CloseError> { - if let Some(first_byte) = self.iter.next() { - return Err(CloseError::TrailingBytes { first_byte }); - } - - debug_assert!(self.read_bits >= 1); - debug_assert!(self.read_bits <= 8); - let n_bits = 8 - self.read_bits; - let masked_padding = self.cached_byte & ((1u8 << n_bits) - 1); - if masked_padding != 0 { - Err(CloseError::IllegalPadding { - masked_padding, - n_bits, - }) - } else { - Ok(()) - } - } } /// Functionality for Boolean iterators to collect their bits or bytes. diff --git a/src/bit_encoding/mod.rs b/src/bit_encoding/mod.rs index 6463933d..35626027 100644 --- a/src/bit_encoding/mod.rs +++ b/src/bit_encoding/mod.rs @@ -13,5 +13,5 @@ mod bitwriter; pub mod decode; pub mod encode; -pub use bititer::{u2, BitCollector, BitIter, CloseError, EarlyEndOfStreamError}; +pub use bititer::{u2, BitCollector, BitIter, CloseError, EarlyEndOfStreamError, ValueDecoder}; pub use bitwriter::{write_to_vec, BitWriter}; diff --git a/src/bit_machine/mod.rs b/src/bit_machine/mod.rs index 8b56873c..5fd7404a 100644 --- a/src/bit_machine/mod.rs +++ b/src/bit_machine/mod.rs @@ -12,10 +12,10 @@ use std::error; use std::fmt; use std::sync::Arc; -use crate::analysis; use crate::jet::{Jet, JetFailed}; use crate::node::{self, RedeemNode}; use crate::types::Final; +use crate::{analysis, ValueDecoder}; use crate::{Cmr, FailEntropy, Value}; use frame::Frame; @@ -360,7 +360,7 @@ impl BitMachine { out_frame.reset_cursor(); let value = out_frame .as_bit_iter(&self.data) - .read_value(&program.arrow().target) + .decode_value(&program.arrow().target) .expect("Decode value of output frame"); Ok(value) diff --git a/src/human_encoding/parse/ast.rs b/src/human_encoding/parse/ast.rs index d34153e2..677ae6bd 100644 --- a/src/human_encoding/parse/ast.rs +++ b/src/human_encoding/parse/ast.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use crate::human_encoding::{Error, ErrorSet, Position, WitnessOrHole}; use crate::jet::Jet; -use crate::{node, types}; +use crate::{node, types, ValueDecoder}; use crate::{BitIter, Cmr, FailEntropy}; use santiago::grammar::{Associativity, Grammar}; use santiago::lexer::{Lexeme, LexerRules}; @@ -647,7 +647,7 @@ fn grammar() -> Grammar> { let ty = types::Final::two_two_n(bit_length.trailing_zeros() as usize); // unwrap ok here since literally every sequence of bits is a valid // value for the given type - let value = iter.read_value(&ty).unwrap(); + let value = iter.decode_value(&ty).unwrap(); Ast::Expression(Expression { inner: ExprInner::Inline(node::Inner::Word(value)), position, diff --git a/src/lib.rs b/src/lib.rs index 6c0a6746..41d4bcfd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,7 +42,7 @@ mod value; pub use bit_encoding::decode; pub use bit_encoding::encode; pub use bit_encoding::{ - u2, BitCollector, BitIter, CloseError as BitIterCloseError, EarlyEndOfStreamError, + u2, BitCollector, BitIter, CloseError as BitIterCloseError, EarlyEndOfStreamError, ValueDecoder, }; pub use bit_encoding::{write_to_vec, BitWriter}; diff --git a/src/node/redeem.rs b/src/node/redeem.rs index 30880419..d4d26af5 100644 --- a/src/node/redeem.rs +++ b/src/node/redeem.rs @@ -4,7 +4,7 @@ use crate::analysis::NodeBounds; use crate::dag::{DagLike, InternalSharing, MaxSharing, PostOrderIterItem}; use crate::jet::Jet; use crate::types::{self, arrow::FinalArrow}; -use crate::{encode, WitnessNode}; +use crate::{encode, ValueDecoder, WitnessNode}; use crate::{Amr, BitIter, BitWriter, Cmr, Error, FirstPassImr, Imr, Value}; use super::{ @@ -299,7 +299,7 @@ impl RedeemNode { ) -> Result { let arrow = data.node.data.arrow(); let target_ty = arrow.target.finalize()?; - self.bits.read_value(&target_ty).map_err(Error::from) + self.bits.decode_value(&target_ty).map_err(Error::from) } fn convert_disconnect(