From f7d01268056783184a09f21f3deec3f79448408c Mon Sep 17 00:00:00 2001 From: callebtc <93376500+callebtc@users.noreply.github.com> Date: Fri, 24 Nov 2023 14:47:36 -0300 Subject: [PATCH] Fix: Cast keyset keys (amount) to int (#368) * load keys as integers * add tests * make format * revert format from newer branch --- cashu/core/base.py | 9 +++++++-- cashu/wallet/wallet.py | 11 +++++++---- tests/test_wallet.py | 24 +++++++++++++++++++++++- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/cashu/core/base.py b/cashu/core/base.py index 76492cf7..b5e14f8d 100644 --- a/cashu/core/base.py +++ b/cashu/core/base.py @@ -348,6 +348,11 @@ def __init__( self.public_keys = public_keys # overwrite id by deriving it from the public keys self.id = derive_keyset_id(self.public_keys) + logger.trace(f"Derived keyset id {self.id} from public keys.") + if id and id != self.id: + logger.warning( + f"WARNING: Keyset id {self.id} does not match the given id {id}." + ) def serialize(self): return json.dumps( @@ -356,9 +361,9 @@ def serialize(self): @classmethod def from_row(cls, row: Row): - def deserialize(serialized: str): + def deserialize(serialized: str) -> Dict[int, PublicKey]: return { - amount: PublicKey(bytes.fromhex(hex_key), raw=True) + int(amount): PublicKey(bytes.fromhex(hex_key), raw=True) for amount, hex_key in dict(json.loads(serialized)).items() } diff --git a/cashu/wallet/wallet.py b/cashu/wallet/wallet.py index 7c20ff81..532df1b0 100644 --- a/cashu/wallet/wallet.py +++ b/cashu/wallet/wallet.py @@ -171,21 +171,24 @@ async def _load_mint_keys(self, keyset_id: Optional[str] = None) -> None: keyset_local: Union[WalletKeyset, None] = None if keyset_id: # check if current keyset is in db + logger.trace(f"Checking if keyset {keyset_id} is in database.") keyset_local = await get_keyset(keyset_id, db=self.db) if keyset_local: - logger.debug(f"Found keyset {keyset_id} in database.") + logger.trace(f"Found keyset {keyset_id} in database.") else: - logger.debug( - f"Cannot find keyset {keyset_id} in database. Loading keyset from" - " mint." + logger.trace( + f"Could not find keyset {keyset_id} in database. Loading keyset" + " from mint." ) keyset = keyset_local if keyset_local is None and keyset_id: # get requested keyset from mint + logger.trace(f"Getting keyset {keyset_id} from mint.") keyset = await self._get_keys_of_keyset(self.url, keyset_id) else: # get current keyset + logger.trace("Getting current keyset from mint.") keyset = await self._get_keys(self.url) assert keyset diff --git a/tests/test_wallet.py b/tests/test_wallet.py index 6433009c..d719ce94 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -1,3 +1,4 @@ +import copy import shutil from pathlib import Path from typing import List, Union @@ -9,7 +10,7 @@ from cashu.core.errors import CashuError, KeysetNotFoundError from cashu.core.helpers import sum_proofs from cashu.core.settings import settings -from cashu.wallet.crud import get_lightning_invoice, get_proofs +from cashu.wallet.crud import get_keyset, get_lightning_invoice, get_proofs from cashu.wallet.wallet import Wallet from cashu.wallet.wallet import Wallet as Wallet1 from cashu.wallet.wallet import Wallet as Wallet2 @@ -114,6 +115,27 @@ async def test_get_keyset(wallet1: Wallet): assert len(keys1.public_keys) == len(keys2.public_keys) +@pytest.mark.asyncio +async def test_get_keyset_from_db(wallet1: Wallet): + # first load it from the mint + # await wallet1._load_mint_keys() + # NOTE: conftest already called wallet.load_mint() which got the keys from the mint + keyset1 = copy.copy(wallet1.keysets[wallet1.keyset_id]) + + # then load it from the db + await wallet1._load_mint_keys() + keyset2 = copy.copy(wallet1.keysets[wallet1.keyset_id]) + + assert keyset1.public_keys == keyset2.public_keys + assert keyset1.id == keyset2.id + + # load it directly from the db + keyset3 = await get_keyset(db=wallet1.db, id=keyset1.id) + assert keyset3 + assert keyset1.public_keys == keyset3.public_keys + assert keyset1.id == keyset3.id + + @pytest.mark.asyncio async def test_get_info(wallet1: Wallet): info = await wallet1._get_info(wallet1.url)