Skip to content

Commit

Permalink
chore: AsyncQueryWrapper removed (#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos authored Aug 17, 2024
1 parent b77ea75 commit 6ef315c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 71 deletions.
18 changes: 7 additions & 11 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import peewee

from .result_wrappers import AsyncQueryWrapper
from .result_wrappers import fetch_models
from .utils import CursorProtocol


Expand Down Expand Up @@ -45,29 +45,29 @@ class AioQueryMixin:
async def aio_execute(self, database):
return await database.aio_execute(self)

async def make_async_query_wrapper(self, cursor: CursorProtocol):
return await AsyncQueryWrapper.make_for_all_rows(cursor, self)
async def fetch_results(self, cursor: CursorProtocol):
return await fetch_models(cursor, self)


class AioModelDelete(peewee.ModelDelete, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol):
if self._returning:
return await self.make_async_query_wrapper(cursor)
return await fetch_models(cursor, self)
return cursor.rowcount


class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin):

async def fetch_results(self, cursor: CursorProtocol):
if self._returning:
return await self.make_async_query_wrapper(cursor)
return await fetch_models(cursor, self)
return cursor.rowcount


class AioModelInsert(peewee.ModelInsert, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol):
if self._returning is not None and len(self._returning) > 1:
return await self.make_async_query_wrapper(cursor)
return await fetch_models(cursor, self)

if self._returning:
row = await cursor.fetchone()
Expand All @@ -77,15 +77,11 @@ async def fetch_results(self, cursor: CursorProtocol):


class AioModelRaw(peewee.ModelRaw, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol):
return await self.make_async_query_wrapper(cursor)
pass


class AioSelectMixin(AioQueryMixin):

async def fetch_results(self, cursor: CursorProtocol):
return await self.make_async_query_wrapper(cursor)

@peewee.database_required
async def aio_scalar(self, database, as_tuple=False):
"""
Expand Down
68 changes: 8 additions & 60 deletions peewee_async/result_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Any, List, Iterator
from typing import Any, List
from typing import Optional, Sequence

from peewee import CursorWrapper, BaseQuery
from peewee import BaseQuery

from .utils import CursorProtocol


class RowsCursor(object):
class SyncCursorAdapter(object):
def __init__(self, rows: List[Any], description: Optional[Sequence[Any]]) -> None:
self._rows = rows
self.description = description
Expand All @@ -23,60 +23,8 @@ def close(self) -> None:
pass


class AsyncQueryWrapper:
"""Async query results wrapper for async `select()`. Internally uses
results wrapper produced by sync peewee select query.
Arguments:
result_wrapper -- empty results wrapper produced by sync `execute()`
call cursor -- async cursor just executed query
To retrieve results after async fetching just iterate over this class
instance, like you generally iterate over sync results wrapper.
"""
def __init__(self, *, cursor: CursorProtocol, query: BaseQuery) -> None:
self._cursor = cursor
self._rows: List[Any] = []
self._result_cache: Optional[List[Any]] = None
self._result_wrapper = self._get_result_wrapper(query)

def __iter__(self) -> Iterator[Any]:
return iter(self._result_wrapper)

def __len__(self) -> int:
return len(self._rows)

def __getitem__(self, idx: int) -> Any:
# NOTE: side effects will appear when both
# iterating and accessing by index!
if self._result_cache is None:
self._result_cache = list(self)
return self._result_cache[idx]

def _get_result_wrapper(self, query: BaseQuery) -> CursorWrapper:
"""Get result wrapper class.
"""
cursor = RowsCursor(self._rows, self._cursor.description)
return query._get_cursor_wrapper(cursor)

async def fetchone(self) -> None:
"""Fetch single row from the cursor.
"""
row = await self._cursor.fetchone()
if not row:
raise GeneratorExit
self._rows.append(row)

async def fetchall(self) -> None:
try:
while True:
await self.fetchone()
except GeneratorExit:
pass

@classmethod
async def make_for_all_rows(cls, cursor: CursorProtocol, query: BaseQuery) -> 'AsyncQueryWrapper':
result = AsyncQueryWrapper(cursor=cursor, query=query)
await result.fetchall()
return result
async def fetch_models(cursor: CursorProtocol, query: BaseQuery):

This comment has been minimized.

Copy link
@gshmu

gshmu Aug 21, 2024

I want support iter_models, aio fetch row from db one by one.

rows = await cursor.fetchall()
sync_cursor = SyncCursorAdapter(rows, cursor.description)
_result_wrapper = query._get_cursor_wrapper(sync_cursor)
return list(_result_wrapper)

0 comments on commit 6ef315c

Please sign in to comment.