From 0b8127bc67f8ec102e2ca6bed0c0ad9b27c9d878 Mon Sep 17 00:00:00 2001 From: Alexey Kinev Date: Thu, 9 May 2024 21:32:20 +0400 Subject: [PATCH] Apply optional compat patching from Database.aio_execute() --- peewee_async.py | 4 ++++ peewee_async_compat.py | 21 --------------------- tests/aio_model/test_selecting.py | 6 ++---- 3 files changed, 6 insertions(+), 25 deletions(-) diff --git a/peewee_async.py b/peewee_async.py index 5998f03..d8b3c2c 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -24,6 +24,7 @@ from importlib.metadata import version from playhouse.db_url import register_database from peewee_async_compat import Manager, count, execute, prefetch, scalar +from peewee_async_compat import _patch_query_with_compat_methods try: import aiopg @@ -318,6 +319,9 @@ async def aio_execute(self, query, fetch_results=None): don't need to close cursor It will be closed automatically. :return: result depends on query type, it's the same as for sync `query.execute()` """ + # To make `Database.aio_execute` compatible with peewee's sync queries we + # apply optional patching, it will do nothing for Aio-counterparts: + _patch_query_with_compat_methods(query, None) sql, params = query.sql() fetch_results = fetch_results or getattr(query, 'fetch_results', None) return await self.aio_execute_sql(sql, params, fetch_results=fetch_results) diff --git a/peewee_async_compat.py b/peewee_async_compat.py index 1e6abff..72cfde2 100644 --- a/peewee_async_compat.py +++ b/peewee_async_compat.py @@ -113,11 +113,8 @@ async def fetch_results(cursor): async def prefetch(sq, *subqueries, prefetch_type): """Asynchronous version of the `prefetch()` from peewee.""" - from peewee_async import AioModelSelect - database = _query_db(sq) if not subqueries: - _patch_query_with_compat_methods(sq, AioModelSelect) result = await database.aio_execute(sq) return result @@ -136,7 +133,6 @@ async def prefetch(sq, *subqueries, prefetch_type): id_map = deps[query_model] has_relations = bool(rel_map.get(query_model)) database = _query_db(pq.query) - _patch_query_with_compat_methods(pq.query, AioModelSelect) result = await database.aio_execute(pq.query) for instance in result: @@ -230,8 +226,6 @@ async def my_async_func(): All will return `MyModel` instance with `id = 1` """ - from peewee_async import AioModelSelect # noqa - await self.connect() if isinstance(source_, peewee.Query): @@ -246,8 +240,6 @@ async def my_async_func(): if conditions: query = query.where(*conditions) - _patch_query_with_compat_methods(query, AioModelSelect) - try: result = await self.execute(query) return list(result)[0] @@ -256,13 +248,9 @@ async def my_async_func(): async def create(self, model_, **data): """Create a new object saved to database.""" - from peewee_async import AioModelInsert - obj = model_(**data) query = model_.insert(**dict(obj.__data__)) - _patch_query_with_compat_methods(query, AioModelInsert) - pk = await self.execute(query) if obj._pk is None: obj._pk = pk @@ -298,8 +286,6 @@ async def update(self, obj, only=None): :param only: (optional) the list/tuple of fields or field names to update """ - from peewee_async import AioModelUpdate # noqa - field_dict = dict(obj.__data__) pk_field = obj._meta.primary_key @@ -317,7 +303,6 @@ async def update(self, obj, only=None): query = obj.update(**field_dict).where(obj._pk_expr()) - _patch_query_with_compat_methods(query, AioModelUpdate) result = await self.execute(query) obj._dirty.clear() @@ -326,22 +311,17 @@ async def update(self, obj, only=None): async def delete(self, obj, recursive=False, delete_nullable=False): """Delete object from database.""" - from peewee_async import AioModelDelete, AioModelUpdate if recursive: dependencies = obj.dependencies(delete_nullable) for cond, fk in reversed(list(dependencies)): model = fk.model if fk.null and not delete_nullable: sq = model.update(**{fk.name: None}).where(cond) - _patch_query_with_compat_methods(sq, AioModelUpdate) else: sq = model.delete().where(cond) - _patch_query_with_compat_methods(sq, AioModelDelete) await self.execute(sq) query = obj.delete().where(obj._pk_expr()) - _patch_query_with_compat_methods(query, AioModelDelete) - return (await self.execute(query)) async def create_or_get(self, model_, **kwargs): @@ -361,7 +341,6 @@ async def create_or_get(self, model_, **kwargs): async def execute(self, query): """Execute query asyncronously.""" - _patch_query_with_compat_methods(query, None) return await self.database.aio_execute(query) async def prefetch(self, query, *subqueries, prefetch_type=peewee.PREFETCH_TYPE.JOIN): diff --git a/tests/aio_model/test_selecting.py b/tests/aio_model/test_selecting.py index 9749067..749cac2 100644 --- a/tests/aio_model/test_selecting.py +++ b/tests/aio_model/test_selecting.py @@ -1,5 +1,4 @@ import peewee -import pytest from tests.conftest import all_dbs from tests.models import TestModel, TestModelAlpha, TestModelBeta @@ -18,7 +17,6 @@ async def test_select_w_join(manager): assert result.joined_alpha.id == alpha.id -@pytest.mark.skip @all_dbs async def test_select_compound(manager): obj1 = await manager.create(TestModel, text="Test 1") @@ -29,8 +27,8 @@ async def test_select_compound(manager): ) assert isinstance(query, peewee.ModelCompoundSelectQuery) # NOTE: Two `AioModelSelect` when joining via `|` produce `ModelCompoundSelectQuery` - # without `aio_execute()` method, so only compat mode is available for now. - result = await query.aio_execute() + # without `aio_execute()` method, so using `database.aio_execute()` here. + result = await manager.database.aio_execute(query) assert len(list(result)) == 2 assert obj1 in list(result) assert obj2 in list(result)