Skip to content

Commit

Permalink
Fix: Cast keyset keys (amount) to int (#368)
Browse files Browse the repository at this point in the history
* load keys as integers

* add tests

* make format

* revert format from newer branch
  • Loading branch information
callebtc authored Nov 24, 2023
1 parent b519c7d commit f7d0126
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
9 changes: 7 additions & 2 deletions cashu/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ def __init__(
self.public_keys = public_keys
# overwrite id by deriving it from the public keys
self.id = derive_keyset_id(self.public_keys)
logger.trace(f"Derived keyset id {self.id} from public keys.")
if id and id != self.id:
logger.warning(
f"WARNING: Keyset id {self.id} does not match the given id {id}."
)

def serialize(self):
return json.dumps(
Expand All @@ -356,9 +361,9 @@ def serialize(self):

@classmethod
def from_row(cls, row: Row):
def deserialize(serialized: str):
def deserialize(serialized: str) -> Dict[int, PublicKey]:
return {
amount: PublicKey(bytes.fromhex(hex_key), raw=True)
int(amount): PublicKey(bytes.fromhex(hex_key), raw=True)
for amount, hex_key in dict(json.loads(serialized)).items()
}

Expand Down
11 changes: 7 additions & 4 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,21 +171,24 @@ async def _load_mint_keys(self, keyset_id: Optional[str] = None) -> None:
keyset_local: Union[WalletKeyset, None] = None
if keyset_id:
# check if current keyset is in db
logger.trace(f"Checking if keyset {keyset_id} is in database.")
keyset_local = await get_keyset(keyset_id, db=self.db)
if keyset_local:
logger.debug(f"Found keyset {keyset_id} in database.")
logger.trace(f"Found keyset {keyset_id} in database.")
else:
logger.debug(
f"Cannot find keyset {keyset_id} in database. Loading keyset from"
" mint."
logger.trace(
f"Could not find keyset {keyset_id} in database. Loading keyset"
" from mint."
)
keyset = keyset_local

if keyset_local is None and keyset_id:
# get requested keyset from mint
logger.trace(f"Getting keyset {keyset_id} from mint.")
keyset = await self._get_keys_of_keyset(self.url, keyset_id)
else:
# get current keyset
logger.trace("Getting current keyset from mint.")
keyset = await self._get_keys(self.url)

assert keyset
Expand Down
24 changes: 23 additions & 1 deletion tests/test_wallet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import shutil
from pathlib import Path
from typing import List, Union
Expand All @@ -9,7 +10,7 @@
from cashu.core.errors import CashuError, KeysetNotFoundError
from cashu.core.helpers import sum_proofs
from cashu.core.settings import settings
from cashu.wallet.crud import get_lightning_invoice, get_proofs
from cashu.wallet.crud import get_keyset, get_lightning_invoice, get_proofs
from cashu.wallet.wallet import Wallet
from cashu.wallet.wallet import Wallet as Wallet1
from cashu.wallet.wallet import Wallet as Wallet2
Expand Down Expand Up @@ -114,6 +115,27 @@ async def test_get_keyset(wallet1: Wallet):
assert len(keys1.public_keys) == len(keys2.public_keys)


@pytest.mark.asyncio
async def test_get_keyset_from_db(wallet1: Wallet):
# first load it from the mint
# await wallet1._load_mint_keys()
# NOTE: conftest already called wallet.load_mint() which got the keys from the mint
keyset1 = copy.copy(wallet1.keysets[wallet1.keyset_id])

# then load it from the db
await wallet1._load_mint_keys()
keyset2 = copy.copy(wallet1.keysets[wallet1.keyset_id])

assert keyset1.public_keys == keyset2.public_keys
assert keyset1.id == keyset2.id

# load it directly from the db
keyset3 = await get_keyset(db=wallet1.db, id=keyset1.id)
assert keyset3
assert keyset1.public_keys == keyset3.public_keys
assert keyset1.id == keyset3.id


@pytest.mark.asyncio
async def test_get_info(wallet1: Wallet):
info = await wallet1._get_info(wallet1.url)
Expand Down

0 comments on commit f7d0126

Please sign in to comment.