Skip to content

Commit

Permalink
fix(ARCO-212): Cumulative fee validation uses GetMempoolAncestors fun…
Browse files Browse the repository at this point in the history
…ction
  • Loading branch information
boecklim committed Nov 14, 2024
1 parent c3bdb0d commit cd5ee3e
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 444 deletions.
2 changes: 1 addition & 1 deletion cmd/arc/services/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func StartAPIServer(logger *slog.Logger, arcConfig *config.ArcConfig) (func(), e
// get the transaction from the bitcoin node rpc
bitcoinClient, err := bitcoin.NewFromURL(rpcURL, false)
if err != nil {
return nil, fmt.Errorf("failed to create node client: %w", err)
return nil, fmt.Errorf("failed to create bitcoin client: %w", err)
}

nodeClient, err := node_client.New(bitcoinClient, nodeClientOpts...)
Expand Down
13 changes: 10 additions & 3 deletions internal/validator/default/default_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,16 @@ func cumulativeCheckFees(ctx context.Context, txFinder validator.TxFinderI, tx *
cumulativeSize := 0
cumulativePaidFee := uint64(0)

for _, tx := range txSet {
cumulativeSize += tx.Size()
cumulativePaidFee += tx.TotalInputSatoshis() - tx.TotalOutputSatoshis()
for _, txFromSet := range txSet {
cumulativeSize += txFromSet.Size()
totalInput := txFromSet.TotalInputSatoshis()
totalOutput := txFromSet.TotalOutputSatoshis()

if totalOutput > totalInput {
return validator.NewError(fmt.Errorf("total outputs %d is larger than total inputs %d for tx %s", totalOutput, totalInput, tx.TxID()), api.ErrStatusCumulativeFees)
}

cumulativePaidFee += totalInput - totalOutput
}

expectedFee, err := feeModel.ComputeFeeBasedOnSize(uint64(cumulativeSize))
Expand Down
138 changes: 46 additions & 92 deletions internal/validator/default/default_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ func TestValidator(t *testing.T) {
t.Run("valid Raw Format tx - expect success", func(t *testing.T) {
// given
txFinder := mocks.TxFinderIMock{
GetRawTxsFunc: func(_ context.Context, _ validation.FindSourceFlag, _ []string) ([]validation.RawTx, error) {
res := []validation.RawTx{fixture.ParentTx1, fixture.ParentTx2}
GetRawTxsFunc: func(_ context.Context, _ validation.FindSourceFlag, _ []string) ([]*sdkTx.Transaction, error) {
res := []*sdkTx.Transaction{fixture.ParentTx1}
return res, nil
},
}

rawTx, _ := sdkTx.NewTransactionFromHex(fixture.ValidTxRawHex)
rawTx := fixture.ValidTx

sut := New(getPolicy(5), &txFinder)

Expand Down Expand Up @@ -352,139 +352,93 @@ func TestNeedExtension(t *testing.T) {
}

func TestCumulativeCheckFees(t *testing.T) {
txMap := map[string]*sdkTx.Transaction{
fixture.ParentTxID1: fixture.ParentTx1,
fixture.AncestorTxID1: fixture.AncestorTx1,
fixture.AncestorOfAncestorTx1ID1: fixture.AncestorOfAncestor1Tx1,
}

tcs := []struct {
name string
hex string
feeModel *fees.SatoshisPerKilobyte
getTxFinderFn func(t *testing.T) mocks.TxFinderIMock
name string
feeModel *fees.SatoshisPerKilobyte
mempoolAncestors []string
getMempoolAncestorsErr error
getRawTxsErr error

expectedErr *validation.Error
}{
{
name: "no unmined ancestors - valid fee",
hex: fixture.ValidTxRawHex,
feeModel: func() *fees.SatoshisPerKilobyte {
return &fees.SatoshisPerKilobyte{Satoshis: 1}
}(),
getTxFinderFn: func(_ *testing.T) mocks.TxFinderIMock {
return mocks.TxFinderIMock{
GetRawTxsFunc: func(_ context.Context, _ validation.FindSourceFlag, _ []string) ([]validation.RawTx, error) {
return []validation.RawTx{fixture.ParentTx1, fixture.ParentTx2}, nil
},
}
},
mempoolAncestors: []string{},
},
{
name: "no unmined ancestors - to low fee",
hex: fixture.ValidTxRawHex,
name: "no unmined ancestors - too low fee",
feeModel: func() *fees.SatoshisPerKilobyte {
return &fees.SatoshisPerKilobyte{Satoshis: 50}
}(),
getTxFinderFn: func(_ *testing.T) mocks.TxFinderIMock {
return mocks.TxFinderIMock{
GetRawTxsFunc: func(_ context.Context, _ validation.FindSourceFlag, _ []string) ([]validation.RawTx, error) {
return []validation.RawTx{fixture.ParentTx1, fixture.ParentTx2}, nil
},
}
},
mempoolAncestors: []string{},

expectedErr: validation.NewError(ErrTxFeeTooLow, api.ErrStatusCumulativeFees),
},
{
name: "cumulative fees too low",
hex: fixture.ValidTxRawHex,
feeModel: func() *fees.SatoshisPerKilobyte {
return &fees.SatoshisPerKilobyte{Satoshis: 50}
}(),
getTxFinderFn: func(t *testing.T) mocks.TxFinderIMock {
var getRawTxCount = 0
var counterPtr = &getRawTxCount

return mocks.TxFinderIMock{
GetRawTxsFunc: func(_ context.Context, _ validation.FindSourceFlag, _ []string) ([]validation.RawTx, error) {
i := *counterPtr
*counterPtr = i + 1

if i == 0 {
p1 := validation.RawTx{
TxID: fixture.ParentTx1.TxID,
Bytes: fixture.ParentTx1.Bytes,
IsMined: false,
}
return []validation.RawTx{p1, fixture.ParentTx2}, nil
}
mempoolAncestors: []string{fixture.AncestorTxID1},

if i == 1 {
return []validation.RawTx{fixture.AncestorTx1, fixture.AncestorTx2}, nil
}

t.Fatal("to many calls")
return nil, nil
},
}
},
expectedErr: validation.NewError(ErrTxFeeTooLow, api.ErrStatusCumulativeFees),
},
{
name: "cumulative fees sufficient",
hex: fixture.ValidTxRawHex,
feeModel: func() *fees.SatoshisPerKilobyte {
return &fees.SatoshisPerKilobyte{Satoshis: 1}
}(),
getTxFinderFn: func(t *testing.T) mocks.TxFinderIMock {
var getRawTxCount = 0
var counterPtr = &getRawTxCount

return mocks.TxFinderIMock{
GetRawTxsFunc: func(_ context.Context, _ validation.FindSourceFlag, _ []string) ([]validation.RawTx, error) {
i := *counterPtr
*counterPtr = i + 1

if i == 0 {
p1 := validation.RawTx{
TxID: fixture.ParentTx1.TxID,
Bytes: fixture.ParentTx1.Bytes,
IsMined: false,
}
return []validation.RawTx{p1, fixture.ParentTx2}, nil
}

if i == 1 {
return []validation.RawTx{fixture.AncestorTx1, fixture.AncestorTx2}, nil
}

t.Fatal("to many calls")
return nil, nil
},
}
},
mempoolAncestors: []string{fixture.AncestorTxID1},
},
{
name: "issue with getUnminedAncestors",
hex: fixture.ValidTxRawHex,
name: "failed to get mempool ancestors",
feeModel: func() *fees.SatoshisPerKilobyte {
return &fees.SatoshisPerKilobyte{Satoshis: 5}
}(),
getTxFinderFn: func(_ *testing.T) mocks.TxFinderIMock {
return mocks.TxFinderIMock{
GetRawTxsFunc: func(_ context.Context, _ validation.FindSourceFlag, _ []string) ([]validation.RawTx, error) {
return nil, errors.New("test error")
},
}
},
getMempoolAncestorsErr: errors.New("some error"),

expectedErr: validation.NewError(
ErrFailedToGetRawTxs,
ErrFailedToGetMempoolAncestors,
api.ErrStatusCumulativeFees),
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
// given
txFinder := tc.getTxFinderFn(t)
tx, _ := sdkTx.NewTransactionFromHex(tc.hex)
txFinder := &mocks.TxFinderIMock{
GetMempoolAncestorsFunc: func(_ context.Context, _ []string) ([]string, error) {
return tc.mempoolAncestors, tc.getMempoolAncestorsErr
},
GetRawTxsFunc: func(_ context.Context, _ validation.FindSourceFlag, ids []string) ([]*sdkTx.Transaction, error) {
rawTxs := make([]*sdkTx.Transaction, len(ids))
for i, id := range ids {
rawTx, ok := txMap[id]
if !ok {
t.Fatalf("tx id %s not found", id)
}
rawTxs[i] = rawTx
}

return rawTxs, tc.getRawTxsErr
},
}
tx := fixture.ValidTx

err := extendTx(context.TODO(), txFinder, tx, false)
require.NoError(t, err)

// when
actualError := cumulativeCheckFees(context.TODO(), &txFinder, tx, tc.feeModel, false)
actualError := cumulativeCheckFees(context.TODO(), txFinder, tx, tc.feeModel, false)

// then
if tc.expectedErr == nil {
Expand Down
106 changes: 34 additions & 72 deletions internal/validator/default/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import (
)

var (
ErrParentNotFound = errors.New("parent transaction not found")
ErrFailedToGetRawTxs = errors.New("failed to get raw transactions for parent")
ErrParentNotFound = errors.New("parent transaction not found")
ErrFailedToGetRawTxs = errors.New("failed to get raw transactions for parent")
ErrFailedToGetMempoolAncestors = errors.New("failed to get mempool ancestors")
)

func extendTx(ctx context.Context, f validator.TxFinderI, rawTx *sdkTx.Transaction, tracingEnabled bool, tracingAttributes ...attribute.KeyValue) error {
func extendTx(ctx context.Context, txFinder validator.TxFinderI, rawTx *sdkTx.Transaction, tracingEnabled bool, tracingAttributes ...attribute.KeyValue) error {
ctx, span := tracing.StartTracing(ctx, "extendTx", tracingEnabled, tracingAttributes...)
defer tracing.EndTracing(span)

Expand Down Expand Up @@ -45,9 +46,9 @@ func extendTx(ctx context.Context, f validator.TxFinderI, rawTx *sdkTx.Transacti
// get parents
const finderSource = validator.SourceTransactionHandler | validator.SourceNodes | validator.SourceWoC

parentsTxs, err := f.GetRawTxs(ctx, finderSource, parentsIDs)
parentsTxs, err := txFinder.GetRawTxs(ctx, finderSource, parentsIDs)
if err != nil {
return fmt.Errorf("failed to get raw transactions for parent: %v. Reason: %w", parentsIDs, err)
return errors.Join(ErrFailedToGetRawTxs, fmt.Errorf("failed to get raw transactions for parents %v: %w", parentsIDs, err))
}

if len(parentsTxs) != len(parentsIDs) {
Expand All @@ -56,110 +57,71 @@ func extendTx(ctx context.Context, f validator.TxFinderI, rawTx *sdkTx.Transacti

// extend inputs with parents data
for _, p := range parentsTxs {
childInputs, found := parentInputMap[p.TxID]
childInputs, found := parentInputMap[p.TxID()]
if !found {
return ErrParentNotFound
}

bTx, err := sdkTx.NewTransactionFromBytes(p.Bytes)
if err != nil {
return fmt.Errorf("cannot parse parent tx: %w", err)
if err = extendInputs(p, childInputs); err != nil {
return err
}
}

if err = extendInputs(bTx, childInputs); err != nil {
return err
return nil
}

func extendInputs(tx *sdkTx.Transaction, childInputs []*sdkTx.TransactionInput) error {
for _, input := range childInputs {
if len(tx.Outputs) < int(input.SourceTxOutIndex) {
return fmt.Errorf("output %d not found in transaction %s", input.SourceTxOutIndex, input.PreviousTxIDStr())
}
output := tx.Outputs[input.SourceTxOutIndex]

input.SetPrevTxFromOutput(output)
}

return nil
}

// getUnminedAncestors returns unmined ancestors with data necessary to perform Deep Fee validation
func getUnminedAncestors(ctx context.Context, w validator.TxFinderI, tx *sdkTx.Transaction, tracingEnabled bool, tracingAttributes ...attribute.KeyValue) (map[string]*sdkTx.Transaction, error) {
// getUnminedAncestors returns unmined ancestors with data necessary to perform cumulative fee validation
func getUnminedAncestors(ctx context.Context, txFinder validator.TxFinderI, tx *sdkTx.Transaction, tracingEnabled bool, tracingAttributes ...attribute.KeyValue) (map[string]*sdkTx.Transaction, error) {
ctx, span := tracing.StartTracing(ctx, "getUnminedAncestors", tracingEnabled, tracingAttributes...)
defer tracing.EndTracing(span)
unmindedAncestorsSet := make(map[string]*sdkTx.Transaction)

// get distinct parents
// map parentID with inputs collection to avoid duplication and simplify later processing
parentInputMap := make(map[string][]*sdkTx.TransactionInput)
parentInputMap := make(map[string]struct{})
parentsIDs := make([]string, 0, len(tx.Inputs))

for _, in := range tx.Inputs {
prevTxID := in.PreviousTxIDStr()

inputs, found := parentInputMap[prevTxID]
_, found := parentInputMap[prevTxID]
if !found {
// first occurrence of the parent
inputs = make([]*sdkTx.TransactionInput, 0)
parentsIDs = append(parentsIDs, prevTxID)
}

inputs = append(inputs, in)
parentInputMap[prevTxID] = inputs
}

// get parents
const finderSource = validator.SourceTransactionHandler | validator.SourceWoC
parentsTxs, err := w.GetRawTxs(ctx, finderSource, parentsIDs)
mempoolAncestorTxIDs, err := txFinder.GetMempoolAncestors(ctx, parentsIDs)
if err != nil {
return nil, errors.Join(ErrFailedToGetRawTxs, fmt.Errorf("parent: %v", parentsIDs), err)
return nil, errors.Join(ErrFailedToGetMempoolAncestors, err)
}

if len(parentsTxs) != len(parentsIDs) {
return nil, ErrParentNotFound
}

for _, p := range parentsTxs {
if _, found := unmindedAncestorsSet[p.TxID]; found {
continue // parent was proccesed already
}
allTxIDs := append(parentsIDs, mempoolAncestorTxIDs...)

childInputs, found := parentInputMap[p.TxID]
if !found {
return nil, ErrParentNotFound
}

// fulfill data about the parent for further validation
bTx, err := sdkTx.NewTransactionFromBytes(p.Bytes)
if err != nil {
return nil, fmt.Errorf("cannot parse parent tx: %w", err)
}
ancestorTxs, err := txFinder.GetRawTxs(ctx, validator.SourceNodes, allTxIDs)
if err != nil {
return nil, errors.Join(ErrFailedToGetRawTxs, err)
}

err = extendInputs(bTx, childInputs)
for _, ancestorTx := range ancestorTxs {
err = extendTx(ctx, txFinder, ancestorTx, tracingEnabled, tracingAttributes...)
if err != nil {
return nil, err
}

if p.IsMined {
continue // we don't need its ancestors
}

unmindedAncestorsSet[p.TxID] = bTx

// get parent ancestors
parentAncestorsSet, err := getUnminedAncestors(ctx, w, bTx, tracingEnabled, tracingAttributes...)
for aID, aTx := range parentAncestorsSet {
unmindedAncestorsSet[aID] = aTx
}

if err != nil {
return unmindedAncestorsSet, err
}
unmindedAncestorsSet[ancestorTx.TxID()] = ancestorTx
}

return unmindedAncestorsSet, nil
}

func extendInputs(tx *sdkTx.Transaction, childInputs []*sdkTx.TransactionInput) error {
for _, input := range childInputs {
if len(tx.Outputs) < int(input.SourceTxOutIndex) {
return fmt.Errorf("output %d not found in transaction %s", input.SourceTxOutIndex, input.PreviousTxIDStr())
}
output := tx.Outputs[input.SourceTxOutIndex]

input.SetPrevTxFromOutput(output)
}

return nil
}
Loading

0 comments on commit cd5ee3e

Please sign in to comment.