Skip to content

Commit

Permalink
Mint: invalidate and generate promises in single db transaction for s…
Browse files Browse the repository at this point in the history
…plit (#374)

* test for spending output again

* first gernerate (which can fail) then invalidate (db and memory)

* use external get_db_connection function to be compatible with existing Database class in LNbits
  • Loading branch information
callebtc authored Dec 2, 2023
1 parent 0ec3af9 commit 34a2e7e
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 36 deletions.
23 changes: 23 additions & 0 deletions cashu/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,26 @@ def lock_table(db: Database, table: str) -> str:
elif db.type == SQLITE:
return "BEGIN EXCLUSIVE TRANSACTION;"
return "<nothing>"


@asynccontextmanager
async def get_db_connection(db: Database, conn: Optional[Connection] = None):
"""Either yield the existing database connection or create a new one.
Note: This should be implemented as Database.get_db_connection(self, conn) but
since we want to use it in LNbits, we can't change the Database class their.
Args:
db (Database): Database object.
conn (Optional[Connection], optional): Connection object. Defaults to None.
Yields:
Connection: Connection object.
"""
if conn is not None:
# Yield the existing connection
yield conn
else:
# Create and yield a new connection
async with db.connect() as new_conn:
yield new_conn
52 changes: 19 additions & 33 deletions cashu/mint/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..core.crypto import b_dhke
from ..core.crypto.keys import derive_pubkey, random_hash
from ..core.crypto.secp import PublicKey
from ..core.db import Connection, Database
from ..core.db import Connection, Database, get_db_connection
from ..core.errors import (
KeysetError,
KeysetNotFoundError,
Expand Down Expand Up @@ -151,17 +151,17 @@ async def get_balance(self) -> int:

# ------- ECASH -------

async def _invalidate_proofs(self, proofs: List[Proof]) -> None:
async def _invalidate_proofs(
self, proofs: List[Proof], conn: Optional[Connection] = None
) -> None:
"""Adds secrets of proofs to the list of known secrets and stores them in the db.
Removes proofs from pending table. This is executed if the ecash has been redeemed.
Args:
proofs (List[Proof]): Proofs to add to known secret table.
"""
# Mark proofs as used and prepare new promises
secrets = set([p.secret for p in proofs])
self.secrets_used |= secrets
async with self.db.connect() as conn:
async with get_db_connection(self.db, conn) as conn:
# store in db
for p in proofs:
await self.crud.invalidate_proof(proof=p, db=self.db, conn=conn)
Expand Down Expand Up @@ -450,14 +450,12 @@ async def split(
proofs: List[Proof],
outputs: List[BlindedMessage],
keyset: Optional[MintKeyset] = None,
amount: Optional[int] = None, # backwards compatibility < 0.13.0
):
"""Consumes proofs and prepares new promises based on the amount split. Used for splitting tokens
Before sending or for redeeming tokens for new ones that have been received by another wallet.
Args:
proofs (List[Proof]): Proofs to be invalidated for the split.
amount (int): Amount at which the split should happen.
outputs (List[BlindedMessage]): New outputs that should be signed in return.
keyset (Optional[MintKeyset], optional): Keyset to use. Uses default keyset if not given. Defaults to None.
Expand All @@ -471,33 +469,18 @@ async def split(

await self._set_proofs_pending(proofs)
try:
# explicitly check that amount of inputs is equal to amount of outputs
# note: we check this again in verify_inputs_and_outputs but only if any
# outputs are provided at all. To make sure of that before calling
# verify_inputs_and_outputs, we check it here.
self._verify_equation_balanced(proofs, outputs)
# verify spending inputs, outputs, and spending conditions
await self.verify_inputs_and_outputs(proofs, outputs)

# BEGIN backwards compatibility < 0.13.0
if amount is not None:
logger.debug(
"Split: Client provided `amount` - backwards compatibility response"
" pre 0.13.0"
)
# split outputs according to amount
total = sum_proofs(proofs)
if amount > total:
raise Exception("split amount is higher than the total sum.")
outs_fst = amount_split(total - amount)
B_fst = [od for od in outputs[: len(outs_fst)]]
B_snd = [od for od in outputs[len(outs_fst) :]]

# Mark proofs as used and prepare new promises
await self._invalidate_proofs(proofs)
prom_fst = await self._generate_promises(B_fst, keyset)
prom_snd = await self._generate_promises(B_snd, keyset)
promises = prom_fst + prom_snd
# END backwards compatibility < 0.13.0
else:
# Mark proofs as used and prepare new promises
await self._invalidate_proofs(proofs)
promises = await self._generate_promises(outputs, keyset)
# Mark proofs as used and prepare new promises
async with get_db_connection(self.db) as conn:
promises = await self._generate_promises(outputs, keyset, conn)
await self._invalidate_proofs(proofs, conn)

except Exception as e:
logger.trace(f"split failed: {e}")
Expand Down Expand Up @@ -535,7 +518,10 @@ async def restore(
# ------- BLIND SIGNATURES -------

async def _generate_promises(
self, B_s: List[BlindedMessage], keyset: Optional[MintKeyset] = None
self,
B_s: List[BlindedMessage],
keyset: Optional[MintKeyset] = None,
conn: Optional[Connection] = None,
) -> list[BlindedSignature]:
"""Generates a promises (Blind signatures) for given amount and returns a pair (amount, C').
Expand All @@ -557,7 +543,7 @@ async def _generate_promises(
promises.append((B_, amount, C_, e, s))

signatures = []
async with self.db.connect() as conn:
async with get_db_connection(self.db, conn) as conn:
for promise in promises:
B_, amount, C_, e, s = promise
logger.trace(f"crud: _generate_promise storing promise for {amount}")
Expand Down
4 changes: 1 addition & 3 deletions cashu/mint/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,7 @@ async def split(
logger.trace(f"> POST /split: {payload}")
assert payload.outputs, Exception("no outputs provided.")

promises = await ledger.split(
proofs=payload.proofs, outputs=payload.outputs, amount=payload.amount
)
promises = await ledger.split(proofs=payload.proofs, outputs=payload.outputs)

if payload.amount:
# BEGIN backwards compatibility < 0.13
Expand Down
32 changes: 32 additions & 0 deletions tests/test_mint_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,38 @@ async def test_split_with_input_more_than_outputs(wallet1: Wallet, ledger: Ledge
print(keep_proofs, send_proofs)


@pytest.mark.asyncio
async def test_split_twice_with_same_outputs(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(128)
pay_if_regtest(invoice.bolt11)
await wallet1.mint(128, [64, 64], id=invoice.id)
inputs1 = wallet1.proofs[:1]
inputs2 = wallet1.proofs[1:]

output_amounts = [64]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
len(output_amounts)
)
outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs)

await ledger.split(proofs=inputs1, outputs=outputs)

# try to spend other proofs with the same outputs again
await assert_err(
ledger.split(proofs=inputs2, outputs=outputs),
"UNIQUE constraint failed: promises.B_b",
)

# try to spend inputs2 again with new outputs
output_amounts = [64]
secrets, rs, derivation_paths = await wallet1.generate_n_secrets(
len(output_amounts)
)
outputs, rs = wallet1._construct_outputs(output_amounts, secrets, rs)

await ledger.split(proofs=inputs2, outputs=outputs)


@pytest.mark.asyncio
async def test_check_proof_state(wallet1: Wallet, ledger: Ledger):
invoice = await wallet1.request_mint(64)
Expand Down

0 comments on commit 34a2e7e

Please sign in to comment.