Skip to content

Commit

Permalink
Return deferred db init (#278)
Browse files Browse the repository at this point in the history
* fix: return deferred db init
---------

Co-authored-by: kalombo <[email protected]>
  • Loading branch information
F1int0m and kalombos authored Aug 3, 2024
1 parent bcc25cf commit b77ea75
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 45 deletions.
118 changes: 74 additions & 44 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import logging
from typing import Type, Optional, Any, AsyncIterator, Iterator
from typing import Type, Optional, Any, AsyncIterator, Iterator, Dict

import peewee
from playhouse import postgres_ext as ext
Expand All @@ -11,15 +11,18 @@
from .utils import psycopg2, aiopg, pymysql, aiomysql, __log__


class AioDatabase:
class AioDatabase(peewee.Database):
_allow_sync = True # whether sync queries are allowed

pool_backend_cls: Type[PoolBackend]
pool_backend: PoolBackend

def __init__(self, database: Optional[str], **kwargs: Any) -> None:
super().__init__(database, **kwargs)
if not database:
raise Exception("Deferred initialization is not supported")
@property
def connect_params_async(self) -> Dict[str, Any]:
...

def init(self, database: Optional[str], **kwargs: Any) -> None:
super().init(database, **kwargs)
self.pool_backend = self.pool_backend_cls(
database=self.database,
**self.connect_params_async
Expand All @@ -28,6 +31,8 @@ def __init__(self, database: Optional[str], **kwargs: Any) -> None:
async def aio_connect(self) -> None:
"""Creates a connection pool
"""
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')
await self.pool_backend.connect()

@property
Expand All @@ -39,6 +44,9 @@ def is_connected(self) -> bool:
async def aio_close(self) -> None:
"""Terminate pool backend. The pool is closed until you run aio_connect manually
"""
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')

await self.pool_backend.terminate()

@contextlib.asynccontextmanager
Expand Down Expand Up @@ -91,12 +99,17 @@ def execute_sql(self, *args, **kwargs):
"Error, sync query is not allowed! Call the `.set_allow_sync()` "
"or use the `.allow_sync()` context manager.")
if self._allow_sync in (logging.ERROR, logging.WARNING):
logging.log(self._allow_sync,
"Error, sync query is not allowed: %s %s" %
(str(args), str(kwargs)))
logging.log(
self._allow_sync,
"Error, sync query is not allowed: %s %s" %
(str(args), str(kwargs))
)
return super().execute_sql(*args, **kwargs)

def aio_connection(self) -> ConnectionContextManager:
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')

return ConnectionContextManager(self.pool_backend)

async def aio_execute_sql(self, sql: str, params=None, fetch_results=None):
Expand All @@ -123,38 +136,28 @@ async def aio_execute(self, query, fetch_results=None):
return await self.aio_execute_sql(sql, params, fetch_results=fetch_results)


class AioPostgresqlMixin(AioDatabase):
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods
class AioPostgresqlMixin(AioDatabase, peewee.PostgresqlDatabase):
"""Extension for `peewee.PostgresqlDatabase` providing extra methods
for managing async connection.
"""

_enable_json: bool
_enable_hstore: bool

pool_backend_cls = PostgresqlPoolBackend

if psycopg2:
Error = psycopg2.Error

def init_async(self, enable_json: bool = False, enable_hstore: bool =False) -> None:
def init_async(self, enable_json: bool = False, enable_hstore: bool = False) -> None:
if not aiopg:
raise Exception("Error, aiopg is not installed!")
self._enable_json = enable_json
self._enable_hstore = enable_hstore

@property
def connect_params_async(self):
"""Connection parameters for `aiopg.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update({
'minsize': self.min_connections,
'maxsize': self.max_connections,
'enable_json': self._enable_json,
'enable_hstore': self._enable_hstore,
})
return kwargs


class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync**
"""PostgreSQL database driver providing **single drop-in sync**
connection and **async connections pool** interface.
:param max_connections: connections pool size
Expand All @@ -166,15 +169,37 @@ class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""
min_connections: int = 1
max_connections: int = 20

def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
super().init(database, **kwargs)
if min_connections := kwargs.pop('min_connections', False):
self.min_connections = min_connections

if max_connections := kwargs.pop('max_connections', False):
self.max_connections = max_connections

self.init_async()
super().init(database, **kwargs)

@property
def connect_params_async(self):
"""Connection parameters for `aiopg.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update(
{
'minsize': self.min_connections,
'maxsize': self.max_connections,
'enable_json': self._enable_json,
'enable_hstore': self._enable_hstore,
}
)
return kwargs


class PooledPostgresqlExtDatabase(
AioPostgresqlMixin,
PooledPostgresqlDatabase,
ext.PostgresqlExtDatabase
):
"""PosgreSQL database extended driver providing **single drop-in sync**
Expand All @@ -183,8 +208,6 @@ class PooledPostgresqlExtDatabase(
JSON fields support is always enabled, HStore supports is enabled by
default, but can be disabled with ``register_hstore=False`` argument.
:param max_connections: connections pool size
Example::
database = PooledPostgresqlExtDatabase('test', register_hstore=False,
Expand All @@ -195,14 +218,11 @@ class PooledPostgresqlExtDatabase(
"""

def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
connection_timeout = kwargs.pop('connection_timeout', None)
super().init(database, **kwargs)
self.init_async(
enable_json=True,
enable_hstore=self._register_hstore
)
super().init(database, **kwargs)


class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
Expand All @@ -218,6 +238,9 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
"""
min_connections: int = 1
max_connections: int = 20

pool_backend_cls = MysqlPoolBackend

if pymysql:
Expand All @@ -226,18 +249,25 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
def init(self, database: Optional[str], **kwargs: Any) -> None:
if not aiomysql:
raise Exception("Error, aiomysql is not installed!")
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)

if min_connections := kwargs.pop('min_connections', False):
self.min_connections = min_connections

if max_connections := kwargs.pop('max_connections', False):
self.max_connections = max_connections

super().init(database, **kwargs)

@property
def connect_params_async(self):
def connect_params_async(self) -> Dict[str, Any]:
"""Connection parameters for `aiomysql.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update({
'minsize': self.min_connections,
'maxsize': self.max_connections,
'autocommit': True,
})
kwargs.update(
{
'minsize': self.min_connections,
'maxsize': self.max_connections,
'autocommit': True,
}
)
return kwargs
1 change: 1 addition & 0 deletions peewee_async/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class PoolBackend(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""

def __init__(self, *, database: str, **kwargs: Any) -> None:
self.pool: Optional[PoolProtocol] = None
self.database = database
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def db(request):

params = DB_DEFAULTS[db]
database = DB_CLASSES[db](**params)

database._allow_sync = False
with database.allow_sync():
for model in ALL_MODELS:
Expand Down
16 changes: 15 additions & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest

from peewee_async import connection_context
from tests.conftest import dbs_all
from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, MYSQL_DBS, PG_DBS
from tests.db_config import DB_DEFAULTS, DB_CLASSES
from tests.models import TestModel


Expand Down Expand Up @@ -53,3 +55,15 @@ async def test_aio_close_idempotent(db):

await db.aio_close()
assert db.is_connected is False


@pytest.mark.parametrize('db_name', PG_DBS + MYSQL_DBS)
async def test_deferred_init(db_name):
database: AioDatabase = DB_CLASSES[db_name](None)

with pytest.raises(Exception, match='Error, database must be initialized before creating a connection pool'):
await database.aio_execute_sql(sql='SELECT 1;')

database.init(**DB_DEFAULTS[db_name])

await database.aio_execute_sql(sql='SELECT 1;')

0 comments on commit b77ea75

Please sign in to comment.