Skip to content

Commit

Permalink
feat: improve transaction building (#39)
Browse files Browse the repository at this point in the history
* feat: support custom ledger impl
* fix: set estimated_fee to 2 ada
* fix: calculate change correctly
* feat: support search_utxos on ExtLedgerFacade
* feat: give UtxoSet an iterator
* feat: give UtxoPattern a default
  • Loading branch information
SupernaviX authored Nov 20, 2024
1 parent cae513c commit 864ab73
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 52 deletions.
15 changes: 15 additions & 0 deletions balius-runtime/src/ledgers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
use std::sync::Arc;

use tokio::sync::Mutex;

use crate::wit::balius::app::ledger as wit;

pub mod mock;
pub mod u5c;

pub use wit::{Host as CustomLedger, LedgerError, TxoRef, Utxo, UtxoPage, UtxoPattern};

#[derive(Clone)]
pub enum Ledger {
Mock(mock::Ledger),
U5C(u5c::Ledger),
Custom(Arc<Mutex<dyn wit::Host + Send + Sync>>),
}

impl From<mock::Ledger> for Ledger {
Expand All @@ -30,6 +37,10 @@ impl wit::Host for Ledger {
match self {
Ledger::Mock(ledger) => ledger.read_utxos(refs).await,
Ledger::U5C(ledger) => ledger.read_utxos(refs).await,
Ledger::Custom(ledger) => {
let mut lock = ledger.lock().await;
lock.read_utxos(refs).await
}
}
}

Expand All @@ -42,6 +53,10 @@ impl wit::Host for Ledger {
match self {
Ledger::Mock(ledger) => ledger.search_utxos(pattern, start, max_items).await,
Ledger::U5C(ledger) => ledger.search_utxos(pattern, start, max_items).await,
Ledger::Custom(ledger) => {
let mut lock = ledger.lock().await;
lock.search_utxos(pattern, start, max_items).await
}
}
}
}
1 change: 1 addition & 0 deletions balius-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod ledgers;
pub mod submit;

pub use store::Store;
pub use wit::Response;

pub type WorkerId = String;

Expand Down
57 changes: 56 additions & 1 deletion balius-sdk/src/txbuilder/asset_math.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use pallas_crypto::hash::Hash;
use pallas_primitives::{
conway::{self, Value},
NonEmptyKeyValuePairs, NonZeroInt, PositiveCoin,
AssetName, NonEmptyKeyValuePairs, NonZeroInt, PolicyId, PositiveCoin,
};
use std::collections::{hash_map::Entry, HashMap};

Expand Down Expand Up @@ -88,6 +88,61 @@ pub fn aggregate_values(items: impl IntoIterator<Item = Value>) -> Value {
}
}

pub fn subtract_value(lhs: &Value, rhs: &Value) -> Result<Value, BuildError> {
let (lhs_coin, lhs_assets) = match lhs {
Value::Coin(c) => (*c, vec![]),
Value::Multiasset(c, a) => (*c, a.iter().collect()),
};

let (rhs_coin, mut rhs_assets) = match rhs {
Value::Coin(c) => (*c, HashMap::new()),
Value::Multiasset(c, a) => {
let flattened: HashMap<(&PolicyId, &AssetName), u64> = a
.iter()
.flat_map(|(policy, assets)| {
assets
.iter()
.map(move |(name, value)| ((policy, name), value.into()))
})
.collect();
(*c, flattened)
}
};

let Some(final_coin) = lhs_coin.checked_sub(rhs_coin) else {
return Err(BuildError::OutputsTooHigh);
};

let mut final_assets = vec![];
for (policy, assets) in lhs_assets {
let mut policy_assets = vec![];
for (name, value) in assets.iter() {
let lhs_value: u64 = value.into();
let rhs_value: u64 = rhs_assets.remove(&(policy, name)).unwrap_or_default();
let Some(final_value) = lhs_value.checked_sub(rhs_value) else {
return Err(BuildError::OutputsTooHigh);
};
if let Ok(final_coin) = final_value.try_into() {
policy_assets.push((name.clone(), final_coin));
}
}
if let Some(assets) = NonEmptyKeyValuePairs::from_vec(policy_assets) {
final_assets.push((*policy, assets));
}
}

if !rhs_assets.is_empty() {
// We have an output which didn't come from any inputs
return Err(BuildError::OutputsTooHigh);
}

if let Some(assets) = NonEmptyKeyValuePairs::from_vec(final_assets) {
Ok(Value::Multiasset(final_coin, assets))
} else {
Ok(Value::Coin(final_coin))
}
}

fn try_to_mint<F>(
assets: conway::Multiasset<PositiveCoin>,
f: F,
Expand Down
109 changes: 106 additions & 3 deletions balius-sdk/src/txbuilder/build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use std::{collections::HashMap, ops::Deref as _};

use pallas_traverse::MultiEraValue;

use super::{
primitives, BuildContext, BuildError, Ledger, PParams, TxExpr, TxoRef, UtxoPattern, UtxoSet,
asset_math, primitives, BuildContext, BuildError, Ledger, PParams, TxExpr, TxoRef, UtxoPattern,
UtxoSet,
};

impl BuildContext {
Expand Down Expand Up @@ -49,7 +54,25 @@ impl crate::txbuilder::Ledger for ExtLedgerFacade {
}

fn search_utxos(&self, pattern: &UtxoPattern) -> Result<UtxoSet, BuildError> {
todo!()
let pattern = pattern.clone().into();
let mut utxos = HashMap::new();
let max_items = 32;
let mut utxo_page = Some(crate::wit::balius::app::ledger::search_utxos(
&pattern, None, max_items,
)?);
while let Some(page) = utxo_page.take() {
for utxo in page.utxos {
utxos.insert(utxo.ref_.into(), utxo.body);
}
if let Some(next) = page.next_token {
utxo_page = Some(crate::wit::balius::app::ledger::search_utxos(
&pattern,
Some(&next),
max_items,
)?);
}
}
Ok(utxos.into())
}
}

Expand All @@ -65,13 +88,34 @@ where
min_fee_b: 3,
min_utxo_value: 2,
},
estimated_fee: 1,
total_input: primitives::Value::Coin(0),
spent_output: primitives::Value::Coin(0),
estimated_fee: 0,
ledger: Box::new(ledger),
tx_body: None,
};

// Build the raw transaction, so we have the info needed to estimate fees and
// compute change.
let body = tx.eval_body(&ctx)?;

let input_refs: Vec<_> = body
.inputs
.iter()
.map(|i| TxoRef {
hash: i.transaction_id,
index: i.index,
})
.collect();
let utxos = ctx.ledger.read_utxos(&input_refs)?;
ctx.total_input =
asset_math::aggregate_values(utxos.txos().map(|txo| input_into_conway(&txo.value())));
ctx.spent_output = asset_math::aggregate_values(body.outputs.iter().map(output_into_conway));
// TODO: estimate the fee
ctx.estimated_fee = 2_000_000;

// Now that we know the inputs/outputs/fee, build the "final" (unsigned)tx
let body = tx.eval_body(&ctx)?;
ctx.tx_body = Some(body);

let wit = tx.eval_witness_set(&ctx).unwrap();
Expand All @@ -85,3 +129,62 @@ where

Ok(tx)
}

// TODO: this belongs in pallas-traverse
// https://github.com/txpipe/pallas/pull/545
fn input_into_conway(value: &MultiEraValue) -> primitives::Value {
use pallas_primitives::{alonzo, conway};
match value {
MultiEraValue::Byron(x) => conway::Value::Coin(*x),
MultiEraValue::AlonzoCompatible(x) => match x.deref() {
alonzo::Value::Coin(x) => conway::Value::Coin(*x),
alonzo::Value::Multiasset(x, assets) => {
let coin = *x;
let assets = assets
.iter()
.filter_map(|(k, v)| {
let v: Vec<(conway::Bytes, conway::PositiveCoin)> = v
.iter()
.filter_map(|(k, v)| Some((k.clone(), (*v).try_into().ok()?)))
.collect();
Some((k.clone(), conway::NonEmptyKeyValuePairs::from_vec(v)?))
})
.collect();
if let Some(assets) = conway::NonEmptyKeyValuePairs::from_vec(assets) {
conway::Value::Multiasset(coin, assets)
} else {
conway::Value::Coin(coin)
}
}
},
MultiEraValue::Conway(x) => x.deref().clone(),
_ => panic!("unrecognized value"),
}
}

fn output_into_conway(output: &primitives::TransactionOutput) -> primitives::Value {
use pallas_primitives::{alonzo, conway};
match output {
primitives::TransactionOutput::Legacy(o) => match &o.amount {
alonzo::Value::Coin(c) => primitives::Value::Coin(*c),
alonzo::Value::Multiasset(c, assets) => {
let assets = assets
.iter()
.filter_map(|(k, v)| {
let v: Vec<(conway::Bytes, conway::PositiveCoin)> = v
.iter()
.filter_map(|(k, v)| Some((k.clone(), (*v).try_into().ok()?)))
.collect();
Some((k.clone(), conway::NonEmptyKeyValuePairs::from_vec(v)?))
})
.collect();
if let Some(assets) = conway::NonEmptyKeyValuePairs::from_vec(assets) {
primitives::Value::Multiasset(*c, assets)
} else {
primitives::Value::Coin(*c)
}
}
},
primitives::TransactionOutput::PostAlonzo(o) => o.value.clone(),
}
}
98 changes: 52 additions & 46 deletions balius-sdk/src/txbuilder/dsl.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
use pallas_primitives::conway;
use pallas_traverse::{MultiEraOutput, MultiEraValue};
use pallas_traverse::MultiEraOutput;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use std::{
collections::{HashMap, HashSet},
ops::Deref as _,
};
use std::collections::{HashMap, HashSet};

use super::*;

Expand All @@ -27,6 +24,10 @@ impl UtxoSet {
self.0.is_empty()
}

pub fn iter(&self) -> impl Iterator<Item = (&TxoRef, MultiEraOutput<'_>)> {
self.0.iter().map(|(k, v)| (k, MultiEraOutput::decode(pallas_traverse::Era::Conway, v).unwrap()))
}

pub fn refs(&self) -> impl Iterator<Item = &TxoRef> {
self.0.keys()
}
Expand Down Expand Up @@ -65,8 +66,48 @@ impl Ledger for UtxoSet {
}
}

#[derive(Clone, Default, Serialize, Deserialize)]
pub struct UtxoPattern {
pub address: Option<AddressPattern>,
pub asset: Option<AssetPattern>,
}

impl From<UtxoPattern> for crate::wit::balius::app::ledger::UtxoPattern {
fn from(value: UtxoPattern) -> Self {
Self {
address: value.address.map(Into::into),
asset: value.asset.map(Into::into),
}
}
}

#[derive(Clone, Serialize, Deserialize)]
pub struct UtxoPattern;
pub struct AddressPattern {
pub exact_address: Vec<u8>,
}

impl From<AddressPattern> for crate::wit::balius::app::ledger::AddressPattern {
fn from(value: AddressPattern) -> Self {
Self {
exact_address: value.exact_address,
}
}
}

#[derive(Clone, Serialize, Deserialize)]
pub struct AssetPattern {
pub policy: Vec<u8>,
pub name: Option<Vec<u8>>,
}

impl From<AssetPattern> for crate::wit::balius::app::ledger::AssetPattern {
fn from(value: AssetPattern) -> Self {
Self {
policy: value.policy,
name: value.name,
}
}
}

pub trait InputExpr: 'static + Send + Sync {
fn eval(&self, ctx: &BuildContext) -> Result<Vec<conway::TransactionInput>, BuildError>;
Expand Down Expand Up @@ -384,59 +425,24 @@ impl AddressExpr for ChangeAddress {
}
}

pub struct TotalLovelaceMinusFee(pub UtxoSource);
pub struct TotalChange;

impl ValueExpr for TotalLovelaceMinusFee {
impl ValueExpr for TotalChange {
fn eval(&self, ctx: &BuildContext) -> Result<conway::Value, BuildError> {
let utxo_set = &self.0.resolve(ctx)?;
let values = utxo_set.txos().map(|o| into_conway(&o.value()));
let total = asset_math::aggregate_values(values);

let change = asset_math::subtract_value(&ctx.total_input, &ctx.spent_output)?;
let fee = ctx.estimated_fee;
let diff = asset_math::value_saturating_add_coin(total, -(fee as i64));

let diff = asset_math::value_saturating_add_coin(change, -(fee as i64));
Ok(diff)
}
}

// TODO: this belongs in pallas-traverse
// https://github.com/txpipe/pallas/pull/545
fn into_conway(value: &MultiEraValue) -> conway::Value {
match value {
MultiEraValue::Byron(x) => conway::Value::Coin(*x),
MultiEraValue::AlonzoCompatible(x) => match x.deref() {
pallas_primitives::alonzo::Value::Coin(x) => conway::Value::Coin(*x),
pallas_primitives::alonzo::Value::Multiasset(x, assets) => {
let coin = *x;
let assets = assets
.iter()
.filter_map(|(k, v)| {
let v: Vec<(conway::Bytes, conway::PositiveCoin)> = v
.iter()
.filter_map(|(k, v)| Some((k.clone(), (*v).try_into().ok()?)))
.collect();
Some((k.clone(), conway::NonEmptyKeyValuePairs::from_vec(v)?))
})
.collect();
if let Some(assets) = conway::NonEmptyKeyValuePairs::from_vec(assets) {
conway::Value::Multiasset(coin, assets)
} else {
conway::Value::Coin(coin)
}
}
},
MultiEraValue::Conway(x) => x.deref().clone(),
_ => panic!("unrecognized value"),
}
}

pub struct FeeChangeReturn(pub UtxoSource);

impl OutputExpr for FeeChangeReturn {
fn eval(&self, ctx: &BuildContext) -> Result<conway::TransactionOutput, BuildError> {
OutputBuilder::new()
.address(ChangeAddress(self.0.clone()))
.with_value(TotalLovelaceMinusFee(self.0.clone()))
.with_value(TotalChange)
.eval(ctx)
}
}
Expand Down
Loading

0 comments on commit 864ab73

Please sign in to comment.