Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: improve type hints (use type instead of typing.Type) #1864

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions examples/router.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
"""
This example to use router to implement read/write separation
"""

from typing import Type

from tortoise import Tortoise, fields, run_async
from tortoise.models import Model
from tortoise import Model, Tortoise, fields, run_async


class Event(Model):
Expand All @@ -21,10 +17,10 @@ def __str__(self):


class Router:
def db_for_read(self, model: Type[Model]):
def db_for_read(self, model: type[Model]):
return "slave"

def db_for_write(self, model: Type[Model]):
def db_for_write(self, model: type[Model]):
return "master"


Expand Down
10 changes: 5 additions & 5 deletions examples/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This example demonstrates model signals usage
"""

from typing import Optional, Type
from typing import Optional

from tortoise import BaseDBAsyncClient, Tortoise, fields, run_async
from tortoise.models import Model
Expand All @@ -22,14 +22,14 @@ def __str__(self):

@pre_save(Signal)
async def signal_pre_save(
sender: "Type[Signal]", instance: Signal, using_db, update_fields
sender: "type[Signal]", instance: Signal, using_db, update_fields
) -> None:
print(sender, instance, using_db, update_fields)


@post_save(Signal)
async def signal_post_save(
sender: "Type[Signal]",
sender: "type[Signal]",
instance: Signal,
created: bool,
using_db: "Optional[BaseDBAsyncClient]",
Expand All @@ -40,14 +40,14 @@ async def signal_post_save(

@pre_delete(Signal)
async def signal_pre_delete(
sender: "Type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]"
sender: "type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]"
) -> None:
print(sender, instance, using_db)


@post_delete(Signal)
async def signal_post_delete(
sender: "Type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]"
sender: "type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]"
) -> None:
print(sender, instance, using_db)

Expand Down
6 changes: 3 additions & 3 deletions tests/fields/subclass_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum, IntEnum
from typing import Any, Type
from typing import Any

from tortoise import ConfigurationError
from tortoise.fields import CharField, IntField
Expand All @@ -13,7 +13,7 @@ class EnumField(CharField):

__slots__ = ("enum_type",)

def __init__(self, enum_type: Type[Enum], **kwargs):
def __init__(self, enum_type: type[Enum], **kwargs):
super().__init__(128, **kwargs)
if not issubclass(enum_type, Enum):
raise ConfigurationError(f"{enum_type} is not a subclass of Enum!")
Expand Down Expand Up @@ -48,7 +48,7 @@ class IntEnumField(IntField):

__slots__ = ("enum_type",)

def __init__(self, enum_type: Type[IntEnum], **kwargs):
def __init__(self, enum_type: type[IntEnum], **kwargs):
super().__init__(**kwargs)
if not issubclass(enum_type, IntEnum):
raise ConfigurationError(f"{enum_type} is not a subclass of IntEnum!")
Expand Down
3 changes: 1 addition & 2 deletions tests/fields/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from datetime import date, datetime, time, timedelta
from datetime import timezone as dt_timezone
from time import sleep
from typing import Type
from unittest.mock import patch

import pytz
Expand All @@ -18,7 +17,7 @@


class TestEmpty(test.TestCase):
model: Type[Model] = testmodels.DatetimeFields
model: type[Model] = testmodels.DatetimeFields

async def test_empty(self):
with self.assertRaises(IntegrityError):
Expand Down
4 changes: 1 addition & 3 deletions tests/test_queryset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Type

from tests.testmodels import (
Author,
Book,
Expand Down Expand Up @@ -845,7 +843,7 @@ async def test_joins_in_arithmetic_expressions(self):


class TestNotExist(test.TestCase):
exp_cls: Type[NotExistOrMultiple] = DoesNotExist
exp_cls: type[NotExistOrMultiple] = DoesNotExist

@test.requireCapability(dialect="sqlite")
def test_does_not_exist(self):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Type
from typing import Optional

from tests.testmodels import Signals
from tortoise import BaseDBAsyncClient
Expand All @@ -8,15 +8,15 @@

@pre_save(Signals)
async def signal_pre_save(
sender: "Type[Signals]", instance: Signals, using_db, update_fields
sender: "type[Signals]", instance: Signals, using_db, update_fields
) -> None:
await Signals.filter(name="test1").update(name="test_pre-save")
await Signals.filter(name="test5").update(name="test_pre-save")


@post_save(Signals)
async def signal_post_save(
sender: "Type[Signals]",
sender: "type[Signals]",
instance: Signals,
created: bool,
using_db: "Optional[BaseDBAsyncClient]",
Expand All @@ -28,14 +28,14 @@ async def signal_post_save(

@pre_delete(Signals)
async def signal_pre_delete(
sender: "Type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]"
sender: "type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]"
) -> None:
await Signals.filter(name="test3").update(name="test_pre-delete")


@post_delete(Signals)
async def signal_post_delete(
sender: "Type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]"
sender: "type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]"
) -> None:
await Signals.filter(name="test4").update(name="test_post-delete")

Expand Down
4 changes: 1 addition & 3 deletions tests/test_table_name.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Type

from tortoise import Tortoise, fields
from tortoise.contrib.test import SimpleTestCase
from tortoise.models import Model


def table_name_generator(model_cls: Type[Model]):
def table_name_generator(model_cls: type[Model]):
return f"test_{model_cls.__name__.lower()}"


Expand Down
20 changes: 10 additions & 10 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from copy import deepcopy
from inspect import isclass
from types import ModuleType
from typing import Any, Type, cast
from typing import Any, cast

from pypika_tortoise import Query, Table

Expand All @@ -34,8 +34,8 @@


class Tortoise:
apps: dict[str, dict[str, Type["Model"]]] = {}
table_name_generator: Callable[[Type["Model"]], str] | None = None
apps: dict[str, dict[str, type["Model"]]] = {}
table_name_generator: Callable[[type["Model"]], str] | None = None
_inited: bool = False

@classmethod
Expand All @@ -53,7 +53,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient:

@classmethod
def describe_model(
cls, model: Type["Model"], serializable: bool = True
cls, model: type["Model"], serializable: bool = True
) -> dict[str, Any]: # pragma: nocoverage
"""
Describes the given list of models or ALL registered models.
Expand All @@ -79,7 +79,7 @@ def describe_model(

@classmethod
def describe_models(
cls, models: list[Type["Model"]] | None = None, serializable: bool = True
cls, models: list[type["Model"]] | None = None, serializable: bool = True
) -> dict[str, dict[str, Any]]:
"""
Describes the given list of models or ALL registered models.
Expand Down Expand Up @@ -115,7 +115,7 @@ def describe_models(

@classmethod
def _init_relations(cls) -> None:
def get_related_model(related_app_name: str, related_model_name: str) -> Type["Model"]:
def get_related_model(related_app_name: str, related_model_name: str) -> type["Model"]:
"""
Test, if app and model really exist. Throws a ConfigurationError with a hopefully
helpful message. If successful, returns the requested model.
Expand Down Expand Up @@ -151,7 +151,7 @@ def split_reference(reference: str) -> tuple[str, str]:
)
return items[0], items[1]

def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
def init_fk_o2o_field(model: type["Model"], field: str, is_o2o=False) -> None:
fk_object = cast(
"OneToOneFieldInstance | ForeignKeyFieldInstance", model._meta.fields_map[field]
)
Expand Down Expand Up @@ -284,7 +284,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None:
related_model._meta.add_field(backward_relation_name, m2m_relation)

@classmethod
def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[Type["Model"]]:
def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[type["Model"]]:
if isinstance(models_path, ModuleType):
module = models_path
else:
Expand Down Expand Up @@ -329,7 +329,7 @@ def init_models(

:raises ConfigurationError: If models are invalid.
"""
app_models: list[Type[Model]] = []
app_models: list[type[Model]] = []
for models_path in models_paths:
app_models += cls._discover_models(models_path, app_label)

Expand Down Expand Up @@ -399,7 +399,7 @@ async def init(
use_tz: bool = False,
timezone: str = "UTC",
routers: list[str | type] | None = None,
table_name_generator: Callable[[Type["Model"]], str] | None = None,
table_name_generator: Callable[[type["Model"]], str] | None = None,
) -> None:
"""
Sets up Tortoise-ORM: loads apps and models, configures database connections but does not
Expand Down
14 changes: 7 additions & 7 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
import asyncio
from collections.abc import Sequence
from typing import Any, Generic, Type, TypeVar, cast
from typing import Any, Generic, TypeVar, cast

from pypika_tortoise import Query

Expand Down Expand Up @@ -85,17 +85,17 @@ class BaseDBAsyncClient(abc.ABC):
Parameters get passed as kwargs, and is mostly driver specific.

.. attribute:: query_class
:annotation: Type[pypika_tortoise.Query]
:annotation: type[pypika_tortoise.Query]

The PyPika Query dialect (low level dialect)

.. attribute:: executor_class
:annotation: Type[BaseExecutor]
:annotation: type[BaseExecutor]

The executor dialect class (high level dialect)

.. attribute:: schema_generator
:annotation: Type[BaseSchemaGenerator]
:annotation: type[BaseSchemaGenerator]

The DDL schema generator

Expand All @@ -109,9 +109,9 @@ class BaseDBAsyncClient(abc.ABC):
_parent: "BaseDBAsyncClient"
_pool: Any
connection_name: str
query_class: Type[Query] = Query
executor_class: Type[BaseExecutor] = BaseExecutor
schema_generator: Type[BaseSchemaGenerator] = BaseSchemaGenerator
query_class: type[Query] = Query
executor_class: type[BaseExecutor] = BaseExecutor
schema_generator: type[BaseSchemaGenerator] = BaseSchemaGenerator
capabilities: Capabilities = Capabilities("")

def __init__(self, connection_name: str, fetch_inserted: bool = True, **kwargs: Any) -> None:
Expand Down
14 changes: 7 additions & 7 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import decimal
from collections.abc import Callable, Iterable, Sequence
from copy import copy
from typing import TYPE_CHECKING, Any, Optional, Type, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast

from pypika_tortoise import JoinType, Parameter, Table
from pypika_tortoise.queries import QueryBuilder
Expand Down Expand Up @@ -38,12 +38,12 @@ class BaseExecutor:

def __init__(
self,
model: "Type[Model]",
model: "type[Model]",
db: "BaseDBAsyncClient",
prefetch_map: "Optional[dict[str, set[Union[str, Prefetch]]]]" = None,
prefetch_queries: Optional[dict[str, list[tuple[Optional[str], "QuerySet"]]]] = None,
select_related_idx: Optional[
list[tuple["Type[Model]", int, str, "Type[Model]", Iterable[Optional[str]]]]
list[tuple["type[Model]", int, str, "type[Model]", Iterable[Optional[str]]]]
] = None,
) -> None:
self.model = model
Expand Down Expand Up @@ -262,7 +262,7 @@ def get_update_sql(
return sql

async def execute_update(
self, instance: "Union[Type[Model], Model]", update_fields: Optional[Iterable[str]]
self, instance: "Union[type[Model], Model]", update_fields: Optional[Iterable[str]]
) -> int:
values = []
expressions = {}
Expand All @@ -279,7 +279,7 @@ async def execute_update(
await self.db.execute_query(self.get_update_sql(update_fields, expressions), values)
)[0]

async def execute_delete(self, instance: "Union[Type[Model], Model]") -> int:
async def execute_delete(self, instance: "Union[type[Model], Model]") -> int:
return (
await self.db.execute_query(
self.delete_query, [self.model._meta.pk.to_db_value(instance.pk, instance)]
Expand Down Expand Up @@ -466,7 +466,7 @@ async def _prefetch_direct_relation(
to_attr, related_queryset = related_query
related_objects_for_fetch: dict[str, list] = {}
relation_key_field = f"{field}_id"
model_to_field: dict["Type[Model]", str] = {}
model_to_field: dict["type[Model]", str] = {}
for instance in instance_list:
if (value := getattr(instance, relation_key_field)) is not None:
if (model_cls := instance.__class__) in model_to_field:
Expand Down Expand Up @@ -510,7 +510,7 @@ def _make_prefetch_queries(self) -> None:
to_attr, related_query = self._prefetch_queries[field_name][0]
else:
relation_field = self.model._meta.fields_map[field_name]
related_model: "Type[Model]" = relation_field.related_model # type: ignore
related_model: "type[Model]" = relation_field.related_model # type: ignore
related_query = related_model.all().using_db(self.db)
related_query.query = copy(
related_query.model._meta.basequery
Expand Down
Loading
Loading