Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alexs hands #3

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,10 @@ implicit_reexport = false
strict_equality = true
namespace_packages = true
explicit_package_bases = true
exclude = [
'example.py'
]

[[tool.mypy.overrides]]
module = ["bson.*", "aiogram.*"]
ignore_missing_imports = true
162 changes: 55 additions & 107 deletions sqlitestorage/storage.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,24 @@
import sqlite3
from sqlite3 import connect, Connection
from aiogram.dispatcher.storage import BaseStorage
from typing import Any, Dict, Optional, Tuple
import json
import typing
from typing import Any
from json import dumps, loads

class SQLiteStorage(BaseStorage):

class SQLiteStorage(BaseStorage): # type: ignore
"""
Simple SQLite based storage for Finite State Machine.

Intended to replace MemoryStorage for simple cases where you want to keep states
between bot restarts.
"""

async def update_data(self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict[Any, Any] | None = None,
**kwargs: Any) -> None:
existing_data = await self.get_data(chat=chat, user=user)
if data:
existing_data.update(data)
existing_data.update(**kwargs)

conn = self._get_connection()
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO fsm_data (key, state, data)
VALUES (?, COALESCE((SELECT state FROM fsm_data WHERE key = ?), '{}'), ?)
""", (str(chat) + ":" + str(user), str(chat) + ":" + str(user), json.dumps(existing_data)))
conn.commit()

async def update_bucket(self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict | None = None,
**kwargs):
pass

async def set_bucket(self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
bucket: typing.Dict | None = None) -> None:
pass

async def get_bucket(self,
*,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[dict] | None = None) -> dict | None:
pass

def __init__(self, db_path: str = "fsm_storage.db"):
self.db_path = db_path
self._conn = None
self._conn: Connection | None = None
self._init_db()

def _init_db(self):
conn = sqlite3.connect(self.db_path)
def _init_db(self) -> None:
conn = connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS fsm_data (
Expand All @@ -71,9 +30,9 @@ def _init_db(self):
conn.commit()
conn.close()

def _get_connection(self) :
def _get_connection(self) -> Connection:
if self._conn is None:
self._conn = sqlite3.connect(self.db_path)
self._conn = connect(self.db_path)
return self._conn

async def close(self) -> None:
Expand All @@ -84,17 +43,13 @@ async def close(self) -> None:
async def wait_closed(self) -> None:
pass

async def set_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
state: typing.Optional[typing.AnyStr] = None,
**kwargs):
async def set_state(self,
chat: str | int | None = None,
user: str | int | None = None,
state: Any | str | None = None,
**kwargs: dict[Any, Any]) -> None:
conn = self._get_connection()
cursor = conn.cursor()
# print('Set state')
# print(f'chat: {chat}')
# print(f'user: {user}')
# print(f'state: {self.resolve_state(state)}')
cursor.execute("""
INSERT OR REPLACE INTO fsm_data
(key, state, data)
Expand All @@ -106,68 +61,61 @@ async def get_state(self,
*,
chat: str | int | None = None,
user: str | int | None = None,
default: str | None = None) -> typing.Coroutine[Any, Any, str | None]:
default: str | None = None) -> Any:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute("SELECT state FROM fsm_data WHERE key = ?", (str(chat) + ":" + str(user),))
result = cursor.fetchone()
# print('Get state')
# print(f'chat: {chat}')
# print(f'user: {user}')
# print(f'raw state: {result[0]}')
# print(f'resolved state: {self.resolve_state(result[0])}')
if result:
what = result
else:
what = None
if result and len(result[0]) > 0:
state = result[0]
else:
state = None
# print(f'result: {what}')
# print(f'state: {state}')
return state
return result[0] if result else None

async def set_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
data: typing.Dict | None = None):
chat: str | int | None = None,
user: str | int | None = None,
data: dict[Any, Any] | None = None) -> None:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute("""
INSERT OR REPLACE INTO fsm_data (key, state, data)
VALUES (?, COALESCE((SELECT state FROM fsm_data WHERE key = ?), ''), ?)
""", (str(chat) + ":" + str(user), str(chat) + ":" + str(user), json.dumps(data)))
""", (str(chat) + ":" + str(user), str(chat) + ":" + str(user), dumps(data)))
conn.commit()

async def get_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
default: typing.Optional[typing.Dict] = None) -> typing.Dict:
chat: str | int | None = None,
user: str | int | None = None,
default: dict[Any, Any] | None = None) -> Any:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute("SELECT data FROM fsm_data WHERE key = ?", (str(chat) + ":" + str(user),))
result = cursor.fetchone()
return json.loads(result[0]) if result else {}

async def reset_data(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None):
await self.set_data(chat=chat, user=user, data={})

async def reset_state(self, *,
chat: typing.Union[str, int, None] = None,
user: typing.Union[str, int, None] = None,
with_data: typing.Optional[bool] = True):
# await self.set_state(chat=chat, user=user, state=None)
# if with_data:
# await self.set_data(chat=chat, user=user, data={})
self._cleanup(chat, user)

def _cleanup(self, chat, user):
# chat, user = self.resolve_address(chat=chat, user=user)
# if self.get_state(chat=chat, user=user) == None:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute("DELETE FROM fsm_data WHERE key = ?", (str(chat) + ":" + str(user),))
conn.commit()
return loads(result[0]) if result else {}

async def update_data(self,
*,
chat: str | int | None = None,
user: str | int | None = None,
data: dict[Any, Any] | None = None,
**kwargs: Any) -> None:
buffer = await self.get_data(chat=chat, user=user)
buffer.update(**kwargs)
await self.set_data(chat=chat, user=user, data=buffer | data if data else buffer)

async def update_bucket(self,
chat: str | int | None = None,
user: str | int | None = None,
bucket: dict[Any, Any] | None = None,
**kwargs: dict[Any, Any]) -> None:
pass

async def set_bucket(self,
chat: str | int | None = None,
user: str | int | None = None,
bucket: dict[Any, Any] | None = None) -> None:
pass

async def get_bucket(self,
chat: str | int | None = None,
user: str | int | None = None,
default: dict[Any, Any] | None = None) -> dict[Any, Any] | None:
pass

47 changes: 34 additions & 13 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,51 @@
from aiogram.dispatcher.storage import BaseStorage


async def test_empty_id(
storage: BaseStorage
) -> None:
try:
await storage.get_data() == {}
except Exception as exp:
assert isinstance(exp, ValueError)


async def test_set_get(
storage: BaseStorage
) -> None:
assert await storage.get_data() == {}
await storage.set_data(data={'foo': 'bar'})
assert await storage.get_data() == {'foo': 'bar'}
assert await storage.get_data(chat='1') == {}
await storage.set_data(chat='1', data={'foo': 'bar'})
assert await storage.get_data(chat='1') == {'foo': 'bar'}


async def test_update(
storage: BaseStorage
) -> None:
await storage.update_data(chat='1', data={'foo': 'bar'})
assert await storage.get_data(chat='1') == {'foo': 'bar'}

await storage.update_data(chat='1', **{'first': 'second'})
assert await storage.get_data(chat='1') == {'foo': 'bar', 'first': 'second'}


async def test_reset(
storage: BaseStorage
) -> None:
await storage.set_data(data={'foo': 'bar'})
await storage.set_state(state='SECOND')
await storage.set_data(chat='1', data={'foo': 'bar'})
await storage.set_state(chat='1', state='SECOND')

await storage.reset_data()
assert await storage.get_state() == 'SECOND'
await storage.set_data(data={'foo': 'bar'})
await storage.reset_data(chat='1')
assert await storage.get_state(chat='1') == 'SECOND'
await storage.set_data(chat='1', data={'foo': 'bar'})

await storage.reset_state()
assert await storage.get_data() == {'foo': 'bar'}
await storage.reset_state(chat='1')
assert await storage.get_data(chat='1') == {'foo': 'bar'}


async def test_reset_empty(
async def test_test_empty(
storage: BaseStorage
) -> None:
await storage.reset_data()
assert await storage.get_data() == {}
await storage.reset_data(chat='1')
assert await storage.get_data(chat='1') == {}

assert not await storage.get_state(chat='1')