Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add some typing #284

Merged
merged 2 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import peewee
from peewee import PREFETCH_TYPE

from .databases import AioDatabase
from .result_wrappers import fetch_models
from .utils import CursorProtocol
from typing_extensions import Self
from typing import Tuple, List, Any, cast


async def aio_prefetch(sq, *subqueries, prefetch_type):
async def aio_prefetch(sq, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE) -> List[Any]:
"""Asynchronous version of `prefetch()`.

See also:
Expand Down Expand Up @@ -42,10 +46,10 @@ async def aio_prefetch(sq, *subqueries, prefetch_type):

class AioQueryMixin:
@peewee.database_required
async def aio_execute(self, database):
async def aio_execute(self, database: AioDatabase) -> Any:
return await database.aio_execute(self)

async def fetch_results(self, cursor: CursorProtocol):
async def fetch_results(self, cursor: CursorProtocol) -> List[Any]:
return await fetch_models(cursor, self)


Expand Down Expand Up @@ -116,7 +120,7 @@ async def aio_get(self, database=None):
(clone.model, sql, params))

@peewee.database_required
async def aio_count(self, database, clear_limit=False):
async def aio_count(self, database, clear_limit=False) -> int:
"""
Async version of **peewee.SelectBase.count**

Expand All @@ -133,7 +137,10 @@ async def aio_count(self, database, clear_limit=False):
clone = clone.select(peewee.SQL('1'))
except AttributeError:
pass
return await AioSelect([clone], [peewee.fn.COUNT(peewee.SQL('1'))]).aio_scalar(database)
return cast(
int,
await AioSelect([clone], [peewee.fn.COUNT(peewee.SQL('1'))]).aio_scalar(database)
)

@peewee.database_required
async def aio_exists(self, database):
Expand Down Expand Up @@ -164,14 +171,14 @@ def except_(self, rhs):
return AioModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)
__sub__ = except_

def aio_prefetch(self, *subqueries, **kwargs):
def aio_prefetch(self, *subqueries, prefetch_type: PREFETCH_TYPE = PREFETCH_TYPE.WHERE):
"""
Async version of **peewee.ModelSelect.prefetch**

See also:
http://docs.peewee-orm.com/en/3.15.3/peewee/api.html#ModelSelect.prefetch
"""
return aio_prefetch(self, *subqueries, **kwargs)
return aio_prefetch(self, *subqueries, prefetch_type=prefetch_type)


class AioSelect(AioSelectMixin, peewee.Select):
Expand Down Expand Up @@ -207,39 +214,39 @@ class User(peewee_async.AioModel):
"""

@classmethod
def select(cls, *fields):
def select(cls, *fields) -> AioModelSelect:
is_default = not fields
if not fields:
fields = cls._meta.sorted_fields
return AioModelSelect(cls, fields, is_default=is_default)

@classmethod
def update(cls, __data=None, **update):
def update(cls, __data=None, **update) -> AioModelUpdate:
return AioModelUpdate(cls, cls._normalize_data(__data, update))

@classmethod
def insert(cls, __data=None, **insert):
def insert(cls, __data=None, **insert) -> AioModelInsert:
return AioModelInsert(cls, cls._normalize_data(__data, insert))

@classmethod
def insert_many(cls, rows, fields=None):
def insert_many(cls, rows, fields=None) -> AioModelInsert:
return AioModelInsert(cls, insert=rows, columns=fields)

@classmethod
def insert_from(cls, query, fields):
def insert_from(cls, query, fields) -> AioModelInsert:
columns = [getattr(cls, field) if isinstance(field, str)
else field for field in fields]
return AioModelInsert(cls, insert=query, columns=columns)

@classmethod
def raw(cls, sql, *params):
def raw(cls, sql, *params) -> AioModelRaw:
return AioModelRaw(cls, sql, params)

@classmethod
def delete(cls):
def delete(cls) -> AioModelDelete:
return AioModelDelete(cls)

async def aio_delete_instance(self, recursive=False, delete_nullable=False):
async def aio_delete_instance(self, recursive: bool = False, delete_nullable: bool = False) -> int:
"""
Async version of **peewee.Model.delete_instance**

Expand All @@ -254,9 +261,9 @@ async def aio_delete_instance(self, recursive=False, delete_nullable=False):
await model.update(**{fk.name: None}).where(query).aio_execute()
else:
await model.delete().where(query).aio_execute()
return await type(self).delete().where(self._pk_expr()).aio_execute()
return cast(int, await type(self).delete().where(self._pk_expr()).aio_execute())

async def aio_save(self, force_insert=False, only=None):
async def aio_save(self, force_insert: bool = False, only=None) -> int:
"""
Async version of **peewee.Model.save**

Expand Down Expand Up @@ -306,7 +313,7 @@ async def aio_save(self, force_insert=False, only=None):
return rows

@classmethod
async def aio_get(cls, *query, **filters):
async def aio_get(cls, *query, **filters) -> Self:
"""Async version of **peewee.Model.get**

See also:
Expand All @@ -323,7 +330,7 @@ async def aio_get(cls, *query, **filters):
return await sq.aio_get()

@classmethod
async def aio_get_or_none(cls, *query, **filters):
async def aio_get_or_none(cls, *query, **filters) -> Self | None:
"""
Async version of **peewee.Model.get_or_none**

Expand All @@ -336,7 +343,7 @@ async def aio_get_or_none(cls, *query, **filters):
return None

@classmethod
async def aio_create(cls, **query):
async def aio_create(cls, **query) -> "Self":
"""
Async version of **peewee.Model.create**

Expand All @@ -348,7 +355,7 @@ async def aio_create(cls, **query):
return inst

@classmethod
async def aio_get_or_create(cls, **kwargs):
async def aio_get_or_create(cls, **kwargs) -> Tuple[Self, bool]:
"""
Async version of **peewee.Model.get_or_create**

Expand Down
2 changes: 1 addition & 1 deletion peewee_async/result_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def close(self) -> None:
pass


async def fetch_models(cursor: CursorProtocol, query: BaseQuery):
async def fetch_models(cursor: CursorProtocol, query: BaseQuery) -> List[Any]:
rows = await cursor.fetchall()
sync_cursor = SyncCursorAdapter(rows, cursor.description)
_result_wrapper = query._get_cursor_wrapper(sync_cursor)
Expand Down
5 changes: 4 additions & 1 deletion peewee_async/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager
from typing import Any, Protocol, Optional, Sequence, Set, AsyncContextManager, List

try:
import aiopg
Expand All @@ -23,6 +23,9 @@ class CursorProtocol(Protocol):
async def fetchone(self) -> Any:
...

async def fetchall(self) -> List[Any]:
...

@property
def lastrowid(self) -> int:
...
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ disallow_any_generics = True
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
exclude = (venv|load-testing|examples)
exclude = (venv|load-testing|examples|docs)
7 changes: 4 additions & 3 deletions tests/aio_model/test_deleting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid

from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, dbs_postgres
from tests.models import TestModel
from tests.utils import model_has_fields


@dbs_all
async def test_delete__count(db):
async def test_delete__count(db: AioDatabase) -> None:
query = TestModel.insert_many([
{'text': "Test %s" % uuid.uuid4()},
{'text': "Test %s" % uuid.uuid4()},
Expand All @@ -19,7 +20,7 @@ async def test_delete__count(db):


@dbs_all
async def test_delete__by_condition(db):
async def test_delete__by_condition(db: AioDatabase) -> None:
expected_text = "text1"
deleted_text = "text2"
query = TestModel.insert_many([
Expand All @@ -36,7 +37,7 @@ async def test_delete__by_condition(db):


@dbs_postgres
async def test_delete__return_model(db):
async def test_delete__return_model(db: AioDatabase) -> None:
m = await TestModel.aio_create(text="text", data="data")

res = await TestModel.delete().returning(TestModel).aio_execute()
Expand Down
17 changes: 9 additions & 8 deletions tests/aio_model/test_inserting.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import uuid

from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, dbs_postgres
from tests.models import TestModel, UUIDTestModel
from tests.utils import model_has_fields


@dbs_all
async def test_insert_many(db):
async def test_insert_many(db: AioDatabase) -> None:
last_id = await TestModel.insert_many([
{'text': "Test %s" % uuid.uuid4()},
{'text': "Test %s" % uuid.uuid4()},
Expand All @@ -19,7 +20,7 @@ async def test_insert_many(db):


@dbs_all
async def test_insert__return_id(db):
async def test_insert__return_id(db: AioDatabase) -> None:
last_id = await TestModel.insert(text="Test %s" % uuid.uuid4()).aio_execute()

res = await TestModel.select().aio_execute()
Expand All @@ -28,7 +29,7 @@ async def test_insert__return_id(db):


@dbs_postgres
async def test_insert_on_conflict_ignore__last_id_is_none(db):
async def test_insert_on_conflict_ignore__last_id_is_none(db: AioDatabase) -> None:
query = TestModel.insert(text="text").on_conflict_ignore()
await query.aio_execute()

Expand All @@ -38,7 +39,7 @@ async def test_insert_on_conflict_ignore__last_id_is_none(db):


@dbs_postgres
async def test_insert_on_conflict_ignore__return_model(db):
async def test_insert_on_conflict_ignore__return_model(db: AioDatabase) -> None:
query = TestModel.insert(text="text", data="data").on_conflict_ignore().returning(TestModel)

res = await query.aio_execute()
Expand All @@ -55,7 +56,7 @@ async def test_insert_on_conflict_ignore__return_model(db):


@dbs_postgres
async def test_insert_on_conflict_ignore__inserted_once(db):
async def test_insert_on_conflict_ignore__inserted_once(db: AioDatabase) -> None:
query = TestModel.insert(text="text").on_conflict_ignore()
last_id = await query.aio_execute()

Expand All @@ -67,14 +68,14 @@ async def test_insert_on_conflict_ignore__inserted_once(db):


@dbs_postgres
async def test_insert__uuid_pk(db):
async def test_insert__uuid_pk(db: AioDatabase) -> None:
query = UUIDTestModel.insert(text="Test %s" % uuid.uuid4())
last_id = await query.aio_execute()
assert len(str(last_id)) == 36


@dbs_postgres
async def test_insert__return_model(db):
async def test_insert__return_model(db: AioDatabase) -> None:
text = "Test %s" % uuid.uuid4()
data = "data"
query = TestModel.insert(text=text, data=data).returning(TestModel)
Expand All @@ -88,7 +89,7 @@ async def test_insert__return_model(db):


@dbs_postgres
async def test_insert_many__return_model(db):
async def test_insert_many__return_model(db: AioDatabase) -> None:
texts = [f"text{n}" for n in range(2)]
query = TestModel.insert_many([
{"text": text} for text in texts
Expand Down
17 changes: 9 additions & 8 deletions tests/aio_model/test_selecting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from peewee_async.aio_model import AioModelCompoundSelectQuery, AioModelRaw
from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all
from tests.models import TestModel, TestModelAlpha, TestModelBeta


@dbs_all
async def test_select_w_join(db):
async def test_select_w_join(db: AioDatabase) -> None:
alpha = await TestModelAlpha.aio_create(text="Test 1")
beta = await TestModelBeta.aio_create(alpha_id=alpha.id, text="text")

Expand All @@ -18,7 +19,7 @@ async def test_select_w_join(db):


@dbs_all
async def test_raw_select(db):
async def test_raw_select(db: AioDatabase) -> None:
obj1 = await TestModel.aio_create(text="Test 1")
obj2 = await TestModel.aio_create(text="Test 2")
query = TestModel.raw(
Expand All @@ -30,23 +31,23 @@ async def test_raw_select(db):


@dbs_all
async def test_tuples(db):
async def test_tuples(db: AioDatabase) -> None:
obj = await TestModel.aio_create(text="Test 1")

result = await TestModel.select(TestModel.id, TestModel.text).tuples().aio_execute()
assert result[0] == (obj.id, obj.text)


@dbs_all
async def test_dicts(db):
async def test_dicts(db: AioDatabase) -> None:
obj = await TestModel.aio_create(text="Test 1")

result = await TestModel.select(TestModel.id, TestModel.text).dicts().aio_execute()
assert result[0] == {"id": obj.id, "text": obj.text}


@dbs_all
async def test_union_all(db):
async def test_union_all(db: AioDatabase) -> None:
obj1 = await TestModel.aio_create(text="1")
obj2 = await TestModel.aio_create(text="2")
query = (
Expand All @@ -59,7 +60,7 @@ async def test_union_all(db):


@dbs_all
async def test_union(db):
async def test_union(db: AioDatabase) -> None:
obj1 = await TestModel.aio_create(text="1")
obj2 = await TestModel.aio_create(text="2")
query = (
Expand All @@ -73,7 +74,7 @@ async def test_union(db):


@dbs_all
async def test_intersect(db):
async def test_intersect(db: AioDatabase) -> None:
await TestModel.aio_create(text="1")
await TestModel.aio_create(text="2")
await TestModel.aio_create(text="3")
Expand All @@ -90,7 +91,7 @@ async def test_intersect(db):


@dbs_all
async def test_except(db):
async def test_except(db: AioDatabase) -> None:
await TestModel.aio_create(text="1")
await TestModel.aio_create(text="2")
await TestModel.aio_create(text="3")
Expand Down
Loading
Loading