diff --git a/cashu/core/settings.py b/cashu/core/settings.py index 274e1098..c84f66fb 100644 --- a/cashu/core/settings.py +++ b/cashu/core/settings.py @@ -56,6 +56,7 @@ class MintSettings(CashuSettings): mint_peg_out_only: bool = Field(default=False) mint_max_peg_in: int = Field(default=None) mint_max_peg_out: int = Field(default=None) + mint_max_balance: int = Field(default=None) mint_lnbits_endpoint: str = Field(default=None) mint_lnbits_key: str = Field(default=None) diff --git a/cashu/mint/crud.py b/cashu/mint/crud.py index 5a189710..29d84c9b 100644 --- a/cashu/mint/crud.py +++ b/cashu/mint/crud.py @@ -158,6 +158,16 @@ async def update_lightning_invoice( conn=conn, ) + async def get_balance( + self, + db: Database, + conn: Optional[Connection] = None, + ) -> int: + return await get_balance( + db=db, + conn=conn, + ) + async def store_promise( *, @@ -394,3 +404,14 @@ async def get_keyset( tuple(values), ) return [MintKeyset(**row) for row in rows] + + +async def get_balance( + db: Database, + conn: Optional[Connection] = None, +) -> int: + row = await (conn or db).fetchone(f""" + SELECT * from {table_with_schema(db, 'balance')} + """) + assert row, "Balance not found" + return int(row[0]) diff --git a/cashu/mint/ledger.py b/cashu/mint/ledger.py index 371da7dd..79faf966 100644 --- a/cashu/mint/ledger.py +++ b/cashu/mint/ledger.py @@ -146,6 +146,10 @@ def get_keyset(self, keyset_id: Optional[str] = None) -> Dict[int, str]: assert keyset.public_keys, KeysetError("no public keys for this keyset") return {a: p.serialize().hex() for a, p in keyset.public_keys.items()} + async def get_balance(self) -> int: + """Returns the balance of the mint.""" + return await self.crud.get_balance(db=self.db) + # ------- ECASH ------- async def _invalidate_proofs(self, proofs: List[Proof]) -> None: @@ -245,6 +249,10 @@ async def request_mint(self, amount: int) -> Tuple[str, str]: ) if settings.mint_peg_out_only: raise NotAllowedError("Mint does not allow minting new tokens.") + if settings.mint_max_balance: + balance = await self.get_balance() + if balance + amount > settings.mint_max_balance: + raise NotAllowedError("Mint has reached maximum balance.") logger.trace(f"requesting invoice for {amount} satoshis") invoice_response = await self._request_lightning_invoice(amount) diff --git a/tests/conftest.py b/tests/conftest.py index 9906b832..72dfb635 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,7 @@ settings.mint_database = "./test_data/test_mint" settings.mint_derivation_path = "0/0/0/0" settings.mint_private_key = "TEST_PRIVATE_KEY" +settings.mint_max_balance = 0 shutil.rmtree(settings.cashu_dir, ignore_errors=True) Path(settings.cashu_dir).mkdir(parents=True, exist_ok=True) diff --git a/tests/test_cli.py b/tests/test_cli.py index 3ec460e3..64119ae4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ import asyncio +from typing import Tuple import pytest from click.testing import CliRunner @@ -15,7 +16,7 @@ def cli_prefix(): yield ["--wallet", "test_cli_wallet", "--host", settings.mint_url, "--tests"] -def get_bolt11_and_invoice_id_from_invoice_command(output: str) -> (str, str): +def get_bolt11_and_invoice_id_from_invoice_command(output: str) -> Tuple[str, str]: invoice = [ line.split(" ")[1] for line in output.split("\n") if line.startswith("Invoice") ][0] diff --git a/tests/test_mint.py b/tests/test_mint.py index 404f42ab..c646bbba 100644 --- a/tests/test_mint.py +++ b/tests/test_mint.py @@ -182,3 +182,20 @@ async def test_generate_change_promises_returns_empty_if_no_outputs(ledger: Ledg total_provided, invoice_amount, actual_fee_msat, outputs ) assert len(promises) == 0 + + +@pytest.mark.asyncio +async def test_get_balance(ledger: Ledger): + balance = await ledger.get_balance() + assert balance == 0 + + +@pytest.mark.asyncio +async def test_maximum_balance(ledger: Ledger): + settings.mint_max_balance = 1000 + invoice, id = await ledger.request_mint(8) + await assert_err( + ledger.request_mint(8000), + "Mint has reached maximum balance.", + ) + settings.mint_max_balance = 0