Skip to content

Commit

Permalink
Apply optional compat patching from Database.aio_execute()
Browse files Browse the repository at this point in the history
  • Loading branch information
rudyryk committed May 9, 2024
1 parent b7b1255 commit 72f033f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 25 deletions.
4 changes: 4 additions & 0 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from importlib.metadata import version
from playhouse.db_url import register_database
from peewee_async_compat import Manager, count, execute, prefetch, scalar
from peewee_async_compat import _patch_query_with_compat_methods

try:
import aiopg
Expand Down Expand Up @@ -318,6 +319,9 @@ async def aio_execute(self, query, fetch_results=None):
don't need to close cursor It will be closed automatically.
:return: result depends on query type, it's the same as for sync `query.execute()`
"""
# To make `Database.aio_execute` compatible with peewee's sync queries we
# apply optional patching, it will do nothing for Aio-counterparts:
_patch_query_with_compat_methods(query, None)
sql, params = query.sql()
fetch_results = fetch_results or getattr(query, 'fetch_results', None)
return await self.aio_execute_sql(sql, params, fetch_results=fetch_results)
Expand Down
21 changes: 0 additions & 21 deletions peewee_async_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,8 @@ async def fetch_results(cursor):

async def prefetch(sq, *subqueries, prefetch_type):
"""Asynchronous version of the `prefetch()` from peewee."""
from peewee_async import AioModelSelect

database = _query_db(sq)
if not subqueries:
_patch_query_with_compat_methods(sq, AioModelSelect)
result = await database.aio_execute(sq)
return result

Expand All @@ -136,7 +133,6 @@ async def prefetch(sq, *subqueries, prefetch_type):
id_map = deps[query_model]
has_relations = bool(rel_map.get(query_model))
database = _query_db(pq.query)
_patch_query_with_compat_methods(pq.query, AioModelSelect)
result = await database.aio_execute(pq.query)

for instance in result:
Expand Down Expand Up @@ -230,8 +226,6 @@ async def my_async_func():
All will return `MyModel` instance with `id = 1`
"""
from peewee_async import AioModelSelect # noqa

await self.connect()

if isinstance(source_, peewee.Query):
Expand All @@ -246,8 +240,6 @@ async def my_async_func():
if conditions:
query = query.where(*conditions)

_patch_query_with_compat_methods(query, AioModelSelect)

try:
result = await self.execute(query)
return list(result)[0]
Expand All @@ -256,13 +248,9 @@ async def my_async_func():

async def create(self, model_, **data):
"""Create a new object saved to database."""
from peewee_async import AioModelInsert

obj = model_(**data)
query = model_.insert(**dict(obj.__data__))

_patch_query_with_compat_methods(query, AioModelInsert)

pk = await self.execute(query)
if obj._pk is None:
obj._pk = pk
Expand Down Expand Up @@ -298,8 +286,6 @@ async def update(self, obj, only=None):
:param only: (optional) the list/tuple of fields or
field names to update
"""
from peewee_async import AioModelUpdate # noqa

field_dict = dict(obj.__data__)
pk_field = obj._meta.primary_key

Expand All @@ -317,7 +303,6 @@ async def update(self, obj, only=None):

query = obj.update(**field_dict).where(obj._pk_expr())

_patch_query_with_compat_methods(query, AioModelUpdate)

result = await self.execute(query)
obj._dirty.clear()
Expand All @@ -326,22 +311,17 @@ async def update(self, obj, only=None):

async def delete(self, obj, recursive=False, delete_nullable=False):
"""Delete object from database."""
from peewee_async import AioModelDelete, AioModelUpdate
if recursive:
dependencies = obj.dependencies(delete_nullable)
for cond, fk in reversed(list(dependencies)):
model = fk.model
if fk.null and not delete_nullable:
sq = model.update(**{fk.name: None}).where(cond)
_patch_query_with_compat_methods(sq, AioModelUpdate)
else:
sq = model.delete().where(cond)
_patch_query_with_compat_methods(sq, AioModelDelete)
await self.execute(sq)

query = obj.delete().where(obj._pk_expr())
_patch_query_with_compat_methods(query, AioModelDelete)

return (await self.execute(query))

async def create_or_get(self, model_, **kwargs):
Expand All @@ -361,7 +341,6 @@ async def create_or_get(self, model_, **kwargs):

async def execute(self, query):
"""Execute query asyncronously."""
_patch_query_with_compat_methods(query, None)
return await self.database.aio_execute(query)

async def prefetch(self, query, *subqueries, prefetch_type=peewee.PREFETCH_TYPE.JOIN):
Expand Down
6 changes: 2 additions & 4 deletions tests/aio_model/test_selecting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import peewee
import pytest
from tests.conftest import all_dbs
from tests.models import TestModel, TestModelAlpha, TestModelBeta

Expand All @@ -18,7 +17,6 @@ async def test_select_w_join(manager):
assert result.joined_alpha.id == alpha.id


@pytest.mark.skip
@all_dbs
async def test_select_compound(manager):
obj1 = await manager.create(TestModel, text="Test 1")
Expand All @@ -29,8 +27,8 @@ async def test_select_compound(manager):
)
assert isinstance(query, peewee.ModelCompoundSelectQuery)
# NOTE: Two `AioModelSelect` when joining via `|` produce `ModelCompoundSelectQuery`
# without `aio_execute()` method, so only compat mode is available for now.
result = await query.aio_execute()
# without `aio_execute()` method, so using `database.aio_execute()` here.
result = await manager.database.aio_execute(query)
assert len(list(result)) == 2
assert obj1 in list(result)
assert obj2 in list(result)

0 comments on commit 72f033f

Please sign in to comment.