diff --git a/examples/router.py b/examples/router.py index d88f9ce06..a648a46ef 100644 --- a/examples/router.py +++ b/examples/router.py @@ -1,11 +1,7 @@ """ This example to use router to implement read/write separation """ - -from typing import Type - -from tortoise import Tortoise, fields, run_async -from tortoise.models import Model +from tortoise import Model, Tortoise, fields, run_async class Event(Model): @@ -21,10 +17,10 @@ def __str__(self): class Router: - def db_for_read(self, model: Type[Model]): + def db_for_read(self, model: type[Model]): return "slave" - def db_for_write(self, model: Type[Model]): + def db_for_write(self, model: type[Model]): return "master" diff --git a/examples/signals.py b/examples/signals.py index 929d0f6f6..ea3fc692c 100644 --- a/examples/signals.py +++ b/examples/signals.py @@ -2,7 +2,7 @@ This example demonstrates model signals usage """ -from typing import Optional, Type +from typing import Optional from tortoise import BaseDBAsyncClient, Tortoise, fields, run_async from tortoise.models import Model @@ -22,14 +22,14 @@ def __str__(self): @pre_save(Signal) async def signal_pre_save( - sender: "Type[Signal]", instance: Signal, using_db, update_fields + sender: "type[Signal]", instance: Signal, using_db, update_fields ) -> None: print(sender, instance, using_db, update_fields) @post_save(Signal) async def signal_post_save( - sender: "Type[Signal]", + sender: "type[Signal]", instance: Signal, created: bool, using_db: "Optional[BaseDBAsyncClient]", @@ -40,14 +40,14 @@ async def signal_post_save( @pre_delete(Signal) async def signal_pre_delete( - sender: "Type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]" + sender: "type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]" ) -> None: print(sender, instance, using_db) @post_delete(Signal) async def signal_post_delete( - sender: "Type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]" + sender: "type[Signal]", instance: Signal, using_db: "Optional[BaseDBAsyncClient]" ) -> None: print(sender, instance, using_db) diff --git a/tests/fields/subclass_fields.py b/tests/fields/subclass_fields.py index 265127737..e82a9609d 100644 --- a/tests/fields/subclass_fields.py +++ b/tests/fields/subclass_fields.py @@ -1,5 +1,5 @@ from enum import Enum, IntEnum -from typing import Any, Type +from typing import Any from tortoise import ConfigurationError from tortoise.fields import CharField, IntField @@ -13,7 +13,7 @@ class EnumField(CharField): __slots__ = ("enum_type",) - def __init__(self, enum_type: Type[Enum], **kwargs): + def __init__(self, enum_type: type[Enum], **kwargs): super().__init__(128, **kwargs) if not issubclass(enum_type, Enum): raise ConfigurationError(f"{enum_type} is not a subclass of Enum!") @@ -48,7 +48,7 @@ class IntEnumField(IntField): __slots__ = ("enum_type",) - def __init__(self, enum_type: Type[IntEnum], **kwargs): + def __init__(self, enum_type: type[IntEnum], **kwargs): super().__init__(**kwargs) if not issubclass(enum_type, IntEnum): raise ConfigurationError(f"{enum_type} is not a subclass of IntEnum!") diff --git a/tests/fields/test_time.py b/tests/fields/test_time.py index 94f11cf3a..8a6c2af1f 100644 --- a/tests/fields/test_time.py +++ b/tests/fields/test_time.py @@ -2,7 +2,6 @@ from datetime import date, datetime, time, timedelta from datetime import timezone as dt_timezone from time import sleep -from typing import Type from unittest.mock import patch import pytz @@ -18,7 +17,7 @@ class TestEmpty(test.TestCase): - model: Type[Model] = testmodels.DatetimeFields + model: type[Model] = testmodels.DatetimeFields async def test_empty(self): with self.assertRaises(IntegrityError): diff --git a/tests/test_queryset.py b/tests/test_queryset.py index 8daf7fd8c..acaf34a6a 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -1,5 +1,3 @@ -from typing import Type - from tests.testmodels import ( Author, Book, @@ -845,7 +843,7 @@ async def test_joins_in_arithmetic_expressions(self): class TestNotExist(test.TestCase): - exp_cls: Type[NotExistOrMultiple] = DoesNotExist + exp_cls: type[NotExistOrMultiple] = DoesNotExist @test.requireCapability(dialect="sqlite") def test_does_not_exist(self): diff --git a/tests/test_signals.py b/tests/test_signals.py index c75000834..c915c3f4c 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Optional from tests.testmodels import Signals from tortoise import BaseDBAsyncClient @@ -8,7 +8,7 @@ @pre_save(Signals) async def signal_pre_save( - sender: "Type[Signals]", instance: Signals, using_db, update_fields + sender: "type[Signals]", instance: Signals, using_db, update_fields ) -> None: await Signals.filter(name="test1").update(name="test_pre-save") await Signals.filter(name="test5").update(name="test_pre-save") @@ -16,7 +16,7 @@ async def signal_pre_save( @post_save(Signals) async def signal_post_save( - sender: "Type[Signals]", + sender: "type[Signals]", instance: Signals, created: bool, using_db: "Optional[BaseDBAsyncClient]", @@ -28,14 +28,14 @@ async def signal_post_save( @pre_delete(Signals) async def signal_pre_delete( - sender: "Type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]" + sender: "type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]" ) -> None: await Signals.filter(name="test3").update(name="test_pre-delete") @post_delete(Signals) async def signal_post_delete( - sender: "Type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]" + sender: "type[Signals]", instance: Signals, using_db: "Optional[BaseDBAsyncClient]" ) -> None: await Signals.filter(name="test4").update(name="test_post-delete") diff --git a/tests/test_table_name.py b/tests/test_table_name.py index 9d00c0629..04ca9197d 100644 --- a/tests/test_table_name.py +++ b/tests/test_table_name.py @@ -1,11 +1,9 @@ -from typing import Type - from tortoise import Tortoise, fields from tortoise.contrib.test import SimpleTestCase from tortoise.models import Model -def table_name_generator(model_cls: Type[Model]): +def table_name_generator(model_cls: type[Model]): return f"test_{model_cls.__name__.lower()}" diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 88eb7f458..9cf7e4627 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -11,7 +11,7 @@ from copy import deepcopy from inspect import isclass from types import ModuleType -from typing import Any, Type, cast +from typing import Any, cast from pypika_tortoise import Query, Table @@ -34,8 +34,8 @@ class Tortoise: - apps: dict[str, dict[str, Type["Model"]]] = {} - table_name_generator: Callable[[Type["Model"]], str] | None = None + apps: dict[str, dict[str, type["Model"]]] = {} + table_name_generator: Callable[[type["Model"]], str] | None = None _inited: bool = False @classmethod @@ -53,7 +53,7 @@ def get_connection(cls, connection_name: str) -> BaseDBAsyncClient: @classmethod def describe_model( - cls, model: Type["Model"], serializable: bool = True + cls, model: type["Model"], serializable: bool = True ) -> dict[str, Any]: # pragma: nocoverage """ Describes the given list of models or ALL registered models. @@ -79,7 +79,7 @@ def describe_model( @classmethod def describe_models( - cls, models: list[Type["Model"]] | None = None, serializable: bool = True + cls, models: list[type["Model"]] | None = None, serializable: bool = True ) -> dict[str, dict[str, Any]]: """ Describes the given list of models or ALL registered models. @@ -115,7 +115,7 @@ def describe_models( @classmethod def _init_relations(cls) -> None: - def get_related_model(related_app_name: str, related_model_name: str) -> Type["Model"]: + def get_related_model(related_app_name: str, related_model_name: str) -> type["Model"]: """ Test, if app and model really exist. Throws a ConfigurationError with a hopefully helpful message. If successful, returns the requested model. @@ -151,7 +151,7 @@ def split_reference(reference: str) -> tuple[str, str]: ) return items[0], items[1] - def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None: + def init_fk_o2o_field(model: type["Model"], field: str, is_o2o=False) -> None: fk_object = cast( "OneToOneFieldInstance | ForeignKeyFieldInstance", model._meta.fields_map[field] ) @@ -284,7 +284,7 @@ def init_fk_o2o_field(model: Type["Model"], field: str, is_o2o=False) -> None: related_model._meta.add_field(backward_relation_name, m2m_relation) @classmethod - def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[Type["Model"]]: + def _discover_models(cls, models_path: ModuleType | str, app_label: str) -> list[type["Model"]]: if isinstance(models_path, ModuleType): module = models_path else: @@ -329,7 +329,7 @@ def init_models( :raises ConfigurationError: If models are invalid. """ - app_models: list[Type[Model]] = [] + app_models: list[type[Model]] = [] for models_path in models_paths: app_models += cls._discover_models(models_path, app_label) @@ -399,7 +399,7 @@ async def init( use_tz: bool = False, timezone: str = "UTC", routers: list[str | type] | None = None, - table_name_generator: Callable[[Type["Model"]], str] | None = None, + table_name_generator: Callable[[type["Model"]], str] | None = None, ) -> None: """ Sets up Tortoise-ORM: loads apps and models, configures database connections but does not diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index 7c6aa3f8f..98bfafd5e 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -3,7 +3,7 @@ import abc import asyncio from collections.abc import Sequence -from typing import Any, Generic, Type, TypeVar, cast +from typing import Any, Generic, TypeVar, cast from pypika_tortoise import Query @@ -85,17 +85,17 @@ class BaseDBAsyncClient(abc.ABC): Parameters get passed as kwargs, and is mostly driver specific. .. attribute:: query_class - :annotation: Type[pypika_tortoise.Query] + :annotation: type[pypika_tortoise.Query] The PyPika Query dialect (low level dialect) .. attribute:: executor_class - :annotation: Type[BaseExecutor] + :annotation: type[BaseExecutor] The executor dialect class (high level dialect) .. attribute:: schema_generator - :annotation: Type[BaseSchemaGenerator] + :annotation: type[BaseSchemaGenerator] The DDL schema generator @@ -109,9 +109,9 @@ class BaseDBAsyncClient(abc.ABC): _parent: "BaseDBAsyncClient" _pool: Any connection_name: str - query_class: Type[Query] = Query - executor_class: Type[BaseExecutor] = BaseExecutor - schema_generator: Type[BaseSchemaGenerator] = BaseSchemaGenerator + query_class: type[Query] = Query + executor_class: type[BaseExecutor] = BaseExecutor + schema_generator: type[BaseSchemaGenerator] = BaseSchemaGenerator capabilities: Capabilities = Capabilities("") def __init__(self, connection_name: str, fetch_inserted: bool = True, **kwargs: Any) -> None: diff --git a/tortoise/backends/base/executor.py b/tortoise/backends/base/executor.py index b0619eb24..23f190086 100644 --- a/tortoise/backends/base/executor.py +++ b/tortoise/backends/base/executor.py @@ -3,7 +3,7 @@ import decimal from collections.abc import Callable, Iterable, Sequence from copy import copy -from typing import TYPE_CHECKING, Any, Optional, Type, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from pypika_tortoise import JoinType, Parameter, Table from pypika_tortoise.queries import QueryBuilder @@ -38,12 +38,12 @@ class BaseExecutor: def __init__( self, - model: "Type[Model]", + model: "type[Model]", db: "BaseDBAsyncClient", prefetch_map: "Optional[dict[str, set[Union[str, Prefetch]]]]" = None, prefetch_queries: Optional[dict[str, list[tuple[Optional[str], "QuerySet"]]]] = None, select_related_idx: Optional[ - list[tuple["Type[Model]", int, str, "Type[Model]", Iterable[Optional[str]]]] + list[tuple["type[Model]", int, str, "type[Model]", Iterable[Optional[str]]]] ] = None, ) -> None: self.model = model @@ -262,7 +262,7 @@ def get_update_sql( return sql async def execute_update( - self, instance: "Union[Type[Model], Model]", update_fields: Optional[Iterable[str]] + self, instance: "Union[type[Model], Model]", update_fields: Optional[Iterable[str]] ) -> int: values = [] expressions = {} @@ -279,7 +279,7 @@ async def execute_update( await self.db.execute_query(self.get_update_sql(update_fields, expressions), values) )[0] - async def execute_delete(self, instance: "Union[Type[Model], Model]") -> int: + async def execute_delete(self, instance: "Union[type[Model], Model]") -> int: return ( await self.db.execute_query( self.delete_query, [self.model._meta.pk.to_db_value(instance.pk, instance)] @@ -466,7 +466,7 @@ async def _prefetch_direct_relation( to_attr, related_queryset = related_query related_objects_for_fetch: dict[str, list] = {} relation_key_field = f"{field}_id" - model_to_field: dict["Type[Model]", str] = {} + model_to_field: dict["type[Model]", str] = {} for instance in instance_list: if (value := getattr(instance, relation_key_field)) is not None: if (model_cls := instance.__class__) in model_to_field: @@ -510,7 +510,7 @@ def _make_prefetch_queries(self) -> None: to_attr, related_query = self._prefetch_queries[field_name][0] else: relation_field = self.model._meta.fields_map[field_name] - related_model: "Type[Model]" = relation_field.related_model # type: ignore + related_model: "type[Model]" = relation_field.related_model # type: ignore related_query = related_model.all().using_db(self.db) related_query.query = copy( related_query.model._meta.basequery diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index 070112ebf..47d77163c 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -2,7 +2,7 @@ import re from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Type, cast +from typing import TYPE_CHECKING, Any, cast from tortoise.exceptions import ConfigurationError from tortoise.fields import JSONField, TextField, UUIDField @@ -144,7 +144,7 @@ def _make_hash(*args: str, length: int) -> str: return sha256(";".join(args).encode("utf-8")).hexdigest()[:length] def _generate_index_name( - self, prefix: str, model: "Type[Model] | str", field_names: list[str] + self, prefix: str, model: "type[Model] | str", field_names: list[str] ) -> str: # NOTE: for compatibility, index name should not be longer than 30 # characters (Oracle limit). @@ -173,7 +173,7 @@ def _generate_fk_name( def _get_index_sql( self, - model: "Type[Model]", + model: "type[Model]", field_names: list[str], safe: bool, index_name: str | None = None, @@ -200,7 +200,7 @@ def _get_unique_index_sql(self, exists: str, table_name: str, field_names: list[ extra="", ) - def _get_unique_constraint_sql(self, model: "Type[Model]", field_names: list[str]) -> str: + def _get_unique_constraint_sql(self, model: "type[Model]", field_names: list[str]) -> str: return self.UNIQUE_CONSTRAINT_CREATE_TEMPLATE.format( index_name=self._generate_index_name("uid", model, field_names), fields=", ".join([self.quote(f) for f in field_names]), @@ -213,12 +213,12 @@ def _get_pk_field_sql_type(self, pk_field: "Field") -> str: return sql_type raise ConfigurationError(f"Can't get SQL type of {pk_field} for {self.DIALECT}") - def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: + def _get_table_sql(self, model: "type[Model]", safe: bool = True) -> dict: fields_to_create = [] fields_with_index = [] m2m_tables_for_create = [] references = set() - models_to_create: "list[Type[Model]]" = [] + models_to_create: "list[type[Model]]" = [] self._get_models_to_create(models_to_create) models_tables = [model._meta.db_table for model in models_to_create] @@ -458,7 +458,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: "m2m_tables": m2m_tables_for_create, } - def _get_models_to_create(self, models_to_create: "list[Type[Model]]") -> None: + def _get_models_to_create(self, models_to_create: "list[type[Model]]") -> None: from tortoise import Tortoise for app in Tortoise.apps.values(): @@ -468,7 +468,7 @@ def _get_models_to_create(self, models_to_create: "list[Type[Model]]") -> None: models_to_create.append(model) def get_create_schema_sql(self, safe: bool = True) -> str: - models_to_create: "list[Type[Model]]" = [] + models_to_create: "list[type[Model]]" = [] self._get_models_to_create(models_to_create) diff --git a/tortoise/backends/base_postgres/client.py b/tortoise/backends/base_postgres/client.py index 9eb588e94..c7ffdbf8f 100644 --- a/tortoise/backends/base_postgres/client.py +++ b/tortoise/backends/base_postgres/client.py @@ -3,7 +3,7 @@ from asyncio.events import AbstractEventLoop from collections.abc import Callable, Coroutine from functools import wraps -from typing import TYPE_CHECKING, Any, Optional, SupportsInt, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, SupportsInt, TypeVar, Union from pypika_tortoise import PostgreSQLQuery @@ -39,9 +39,9 @@ class BasePostgresPool: class BasePostgresClient(BaseDBAsyncClient, abc.ABC): DSN_TEMPLATE = "postgres://{user}:{password}@{host}:{port}/{database}" - query_class: Type[PostgreSQLQuery] = PostgreSQLQuery - executor_class: Type[BasePostgresExecutor] = BasePostgresExecutor - schema_generator: Type[BasePostgresSchemaGenerator] = BasePostgresSchemaGenerator + query_class: type[PostgreSQLQuery] = PostgreSQLQuery + executor_class: type[BasePostgresExecutor] = BasePostgresExecutor + schema_generator: type[BasePostgresSchemaGenerator] = BasePostgresSchemaGenerator capabilities = Capabilities( "postgres", support_update_limit_order_by=False, support_for_posix_regex_queries=True ) diff --git a/tortoise/backends/base_postgres/schema_generator.py b/tortoise/backends/base_postgres/schema_generator.py index fe6090d4c..bf22c80d8 100644 --- a/tortoise/backends/base_postgres/schema_generator.py +++ b/tortoise/backends/base_postgres/schema_generator.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.converters import encoders @@ -71,7 +71,7 @@ def _escape_default_value(self, default: Any): def _get_index_sql( self, - model: "Type[Model]", + model: "type[Model]", field_names: list[str], safe: bool, index_name: str | None = None, diff --git a/tortoise/backends/mssql/schema_generator.py b/tortoise/backends/mssql/schema_generator.py index d9c0102ae..1b4c8edac 100644 --- a/tortoise/backends/mssql/schema_generator.py +++ b/tortoise/backends/mssql/schema_generator.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.converters import encoders @@ -63,7 +63,7 @@ def _escape_default_value(self, default: Any): def _get_index_sql( self, - model: "Type[Model]", + model: "type[Model]", field_names: list[str], safe: bool, index_name: str | None = None, @@ -74,7 +74,7 @@ def _get_index_sql( model, field_names, False, index_name=index_name, index_type=index_type, extra=extra ) - def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: + def _get_table_sql(self, model: "type[Model]", safe: bool = True) -> dict: return super()._get_table_sql(model, False) def _create_fk_string( diff --git a/tortoise/backends/mysql/schema_generator.py b/tortoise/backends/mysql/schema_generator.py index e1a506345..a07ba9978 100644 --- a/tortoise/backends/mysql/schema_generator.py +++ b/tortoise/backends/mysql/schema_generator.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.converters import encoders @@ -72,7 +72,7 @@ def _escape_default_value(self, default: Any): def _get_index_sql( self, - model: "Type[Model]", + model: "type[Model]", field_names: list[str], safe: bool, index_name: str | None = None, diff --git a/tortoise/backends/oracle/schema_generator.py b/tortoise/backends/oracle/schema_generator.py index f294a6cd8..d3cb54300 100644 --- a/tortoise/backends/oracle/schema_generator.py +++ b/tortoise/backends/oracle/schema_generator.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.converters import encoders @@ -89,7 +89,7 @@ def _escape_default_value(self, default: Any): def _get_index_sql( self, - model: "Type[Model]", + model: "type[Model]", field_names: list[str], safe: bool, index_name: str | None = None, @@ -100,7 +100,7 @@ def _get_index_sql( model, field_names, False, index_name=index_name, index_type=index_type, extra=extra ) - def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: + def _get_table_sql(self, model: "type[Model]", safe: bool = True) -> dict: return super()._get_table_sql(model, False) def _create_fk_string( diff --git a/tortoise/backends/psycopg/client.py b/tortoise/backends/psycopg/client.py index 4623b10f4..2a2f59bd1 100644 --- a/tortoise/backends/psycopg/client.py +++ b/tortoise/backends/psycopg/client.py @@ -4,7 +4,7 @@ from collections.abc import Callable from contextlib import _AsyncGeneratorContextManager from ssl import SSLContext -from typing import Any, Type, TypeVar, cast +from typing import Any, TypeVar, cast import psycopg import psycopg.conninfo @@ -54,9 +54,9 @@ def get_parameterized_sql(self, ctx: SqlContext | None = None) -> tuple[str, lis class PsycopgClient(postgres_client.BasePostgresClient): - query_class: Type[PsycopgSQLQuery] = PsycopgSQLQuery - executor_class: Type[executor.PsycopgExecutor] = executor.PsycopgExecutor - schema_generator: Type[PsycopgSchemaGenerator] = PsycopgSchemaGenerator + query_class: type[PsycopgSQLQuery] = PsycopgSQLQuery + executor_class: type[executor.PsycopgExecutor] = executor.PsycopgExecutor + schema_generator: type[PsycopgSchemaGenerator] = PsycopgSchemaGenerator _pool: AsyncConnectionPool | None = None _connection: psycopg.AsyncConnection default_timeout: float = 30 diff --git a/tortoise/connection.py b/tortoise/connection.py index a0723b6a4..44e166fb7 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -3,7 +3,7 @@ import importlib from contextvars import ContextVar from copy import copy -from typing import TYPE_CHECKING, Any, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Union from tortoise.backends.base.config_generator import expand_db_url from tortoise.exceptions import ConfigurationError @@ -65,7 +65,7 @@ def _copy_storage(self) -> dict[str, "BaseDBAsyncClient"]: def _clear_storage(self) -> None: self._get_storage().clear() - def _discover_client_class(self, db_info: dict) -> Type["BaseDBAsyncClient"]: + def _discover_client_class(self, db_info: dict) -> type["BaseDBAsyncClient"]: # Let exception bubble up for transparency engine_str = db_info.get("engine", "") engine_module = importlib.import_module(engine_str) diff --git a/tortoise/contrib/mysql/fields.py b/tortoise/contrib/mysql/fields.py index 5d593e1e6..30717bd10 100644 --- a/tortoise/contrib/mysql/fields.py +++ b/tortoise/contrib/mysql/fields.py @@ -1,10 +1,4 @@ -from typing import ( # noqa pylint: disable=unused-import - TYPE_CHECKING, - Any, - Optional, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Optional, Union from uuid import UUID, uuid4 from tortoise.fields import Field @@ -44,7 +38,7 @@ def __init__(self, binary_compression: bool = True, **kwargs: Any) -> None: self.SQL_TYPE = "BINARY(16)" self._binary_compression = binary_compression - def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Optional[Union[str, bytes]]: # type: ignore + def to_db_value(self, value: Any, instance: "Union[type[Model], Model]") -> Optional[Union[str, bytes]]: # type: ignore # Make sure that value is a UUIDv4 # If not, raise an error # This is to prevent UUIDv1 or any other version from being stored in the database diff --git a/tortoise/contrib/pydantic/base.py b/tortoise/contrib/pydantic/base.py index 5f8fd8a3e..cecf8255a 100644 --- a/tortoise/contrib/pydantic/base.py +++ b/tortoise/contrib/pydantic/base.py @@ -1,5 +1,5 @@ import sys -from typing import TYPE_CHECKING, List, Type, Union +from typing import TYPE_CHECKING, List, Union import pydantic from pydantic import BaseModel, ConfigDict, RootModel @@ -17,7 +17,7 @@ def _get_fetch_fields( - pydantic_class: "Type[PydanticModel]", model_class: "Type[Model]" + pydantic_class: "type[PydanticModel]", model_class: "type[Model]" ) -> list[str]: """ Recursively collect fields needed to fetch diff --git a/tortoise/contrib/pydantic/creator.py b/tortoise/contrib/pydantic/creator.py index 2a8f59a29..e4a98cba6 100644 --- a/tortoise/contrib/pydantic/creator.py +++ b/tortoise/contrib/pydantic/creator.py @@ -4,7 +4,7 @@ from copy import copy from enum import Enum, IntEnum from hashlib import sha3_224 -from typing import TYPE_CHECKING, Any, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import ConfigDict from pydantic import Field as PydanticField @@ -30,7 +30,7 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model -_MODEL_INDEX: dict[str, Type[PydanticModel]] = {} +_MODEL_INDEX: dict[str, type[PydanticModel]] = {} """ The index works as follows: 1. the hash is calculated from the following: @@ -52,7 +52,7 @@ def _cleandoc(obj: Any) -> str: def _pydantic_recursion_protector( - cls: "Type[Model]", + cls: "type[Model]", *, stack: tuple, exclude: tuple[str, ...] = (), @@ -61,7 +61,7 @@ def _pydantic_recursion_protector( name=None, allow_cycles: bool = False, sort_alphabetically: Optional[bool] = None, -) -> Optional[Type[PydanticModel]]: +) -> Optional[type[PydanticModel]]: """ It is an inner function to protect pydantic model creator against cyclic recursion """ @@ -125,7 +125,7 @@ def __setitem__(self, __key, __value): def sort_alphabetically(self) -> None: self._field_map = {k: self._field_map[k] for k in sorted(self._field_map)} - def sort_definition_order(self, cls: "Type[Model]", computed: tuple[str, ...]) -> None: + def sort_definition_order(self, cls: "type[Model]", computed: tuple[str, ...]) -> None: self._field_map = { k: self._field_map[k] for k in tuple(cls._meta.fields_map.keys()) + computed @@ -149,7 +149,7 @@ def field_map_update(self, fields: list[Field], meta: PydanticMetaData) -> None: self.pop(raw_field, None) self[name] = field - def computed_field_map_update(self, computed: tuple[str, ...], cls: "Type[Model]"): + def computed_field_map_update(self, computed: tuple[str, ...], cls: "type[Model]"): self._field_map.update( { k: ComputedFieldDescription( @@ -163,7 +163,7 @@ def computed_field_map_update(self, computed: tuple[str, ...], cls: "Type[Model] def pydantic_queryset_creator( - cls: "Type[Model]", + cls: "type[Model]", *, name=None, exclude: tuple[str, ...] = (), @@ -171,7 +171,7 @@ def pydantic_queryset_creator( computed: tuple[str, ...] = (), allow_cycles: Optional[bool] = None, sort_alphabetically: Optional[bool] = None, -) -> Type[PydanticListModel]: +) -> type[PydanticListModel]: """ Function to build a `Pydantic Model `__ list off Tortoise Model. @@ -224,7 +224,7 @@ def pydantic_queryset_creator( class PydanticModelCreator: def __init__( self, - cls: "Type[Model]", + cls: "type[Model]", name: Optional[str] = None, exclude: Optional[tuple[str, ...]] = None, include: Optional[tuple[str, ...]] = None, @@ -233,16 +233,16 @@ def __init__( allow_cycles: Optional[bool] = None, sort_alphabetically: Optional[bool] = None, exclude_readonly: bool = False, - meta_override: Optional[Type] = None, + meta_override: Optional[type] = None, model_config: Optional[ConfigDict] = None, validators: Optional[dict[str, Any]] = None, module: str = __name__, _stack: tuple = (), _as_submodel: bool = False, ) -> None: - self._cls: "Type[Model]" = cls - self._stack: tuple[tuple["Type[Model]", str, int], ...] = ( - _stack # ((Type[Model], field_name, max_recursion),) + self._cls: "type[Model]" = cls + self._stack: tuple[tuple["type[Model]", str, int], ...] = ( + _stack # ((type[Model], field_name, max_recursion),) ) self._is_default: bool = ( exclude is None @@ -368,7 +368,7 @@ def _construct_field_map(self) -> None: else: self._field_map.sort_definition_order(self._cls, self.meta.computed) - def create_pydantic_model(self) -> Type[PydanticModel]: + def create_pydantic_model(self) -> type[PydanticModel]: for field_name, field in self._field_map.items(): self._process_field(field_name, field) @@ -464,9 +464,9 @@ def _process_single_field_relation( field_name: str, field: Union[ForeignKeyFieldInstance, OneToOneFieldInstance, BackwardOneToOneRelation], json_schema_extra: dict[str, Any], - ) -> Optional[Type[PydanticModel]]: + ) -> Optional[type[PydanticModel]]: python_type = getattr(field, "related_model", field.field_type) - model: Optional[Type[PydanticModel]] = self._get_submodel(python_type, field_name) + model: Optional[type[PydanticModel]] = self._get_submodel(python_type, field_name) if model: self._relational_fields_index.append((field_name, model.__name__)) if field.null: @@ -481,7 +481,7 @@ def _process_many_field_relation( self, field_name: str, field: Union[BackwardFKRelation, ManyToManyFieldInstance], - ) -> Optional[Type[list[Type[PydanticModel]]]]: + ) -> Optional[type[list[type[PydanticModel]]]]: python_type = field.related_model model = self._get_submodel(python_type, field_name) if model: @@ -502,7 +502,7 @@ def _process_data_field( json_schema_extra["readOnly"] = constraints["readOnly"] del constraints["readOnly"] fconfig.update(constraints) - python_type: Union[Type[Enum], Type[IntEnum], Type] + python_type: Union[type[Enum], type[IntEnum], type] if isinstance(field, (IntEnumFieldInstance, CharEnumFieldInstance)): python_type = field.enum_type else: @@ -532,8 +532,8 @@ def _process_computed_field( return None def _get_submodel( - self, _model: Optional["Type[Model]"], field_name: str - ) -> Optional[Type[PydanticModel]]: + self, _model: Optional["type[Model]"], field_name: str + ) -> Optional[type[PydanticModel]]: """Get Pydantic model for the submodel""" if _model: @@ -567,7 +567,7 @@ def get_fields_to_carry_on(field_tuple: tuple[str, ...]) -> tuple[str, ...]: def pydantic_model_creator( - cls: "Type[Model]", + cls: "type[Model]", *, name=None, exclude: Optional[tuple[str, ...]] = None, @@ -577,11 +577,11 @@ def pydantic_model_creator( allow_cycles: Optional[bool] = None, sort_alphabetically: Optional[bool] = None, exclude_readonly: bool = False, - meta_override: Optional[Type] = None, + meta_override: Optional[type] = None, model_config: Optional[ConfigDict] = None, validators: Optional[dict[str, Any]] = None, module: str = __name__, -) -> Type[PydanticModel]: +) -> type[PydanticModel]: """ Function to build `Pydantic Model `__ off Tortoise Model. diff --git a/tortoise/contrib/pydantic/descriptions.py b/tortoise/contrib/pydantic/descriptions.py index 48eedb99e..a0c7b6c66 100644 --- a/tortoise/contrib/pydantic/descriptions.py +++ b/tortoise/contrib/pydantic/descriptions.py @@ -1,7 +1,7 @@ import dataclasses import sys from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Optional, Type +from typing import TYPE_CHECKING, Any, Optional if sys.version_info >= (3, 11): from typing import Self @@ -27,7 +27,7 @@ class ModelDescription: m2m_fields: list[Field] = dataclasses.field(default_factory=list) @classmethod - def from_model(cls, model: Type["Model"]) -> Self: + def from_model(cls, model: type["Model"]) -> Self: return cls( pk_field=model._meta.fields_map[model._meta.pk_attr], data_fields=[ @@ -140,7 +140,7 @@ def get_param_from_pydantic_meta(attr: str, default: Any) -> Any: ) return pmd - def construct_pydantic_meta(self, meta_override: Type) -> "PydanticMetaData": + def construct_pydantic_meta(self, meta_override: type) -> "PydanticMetaData": def get_param_from_meta_override(attr: str) -> Any: return getattr(meta_override, attr, getattr(self, attr)) diff --git a/tortoise/contrib/pydantic/utils.py b/tortoise/contrib/pydantic/utils.py index e7410f0e5..e2d577326 100644 --- a/tortoise/contrib/pydantic/utils.py +++ b/tortoise/contrib/pydantic/utils.py @@ -1,16 +1,15 @@ -import typing from collections.abc import Callable -from typing import Any, Optional, Type +from typing import TYPE_CHECKING, Any, Optional, get_type_hints -if typing.TYPE_CHECKING: # pragma: nocoverage +if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model -def get_annotations(cls: "Type[Model]", method: Optional[Callable] = None) -> dict[str, Any]: +def get_annotations(cls: "type[Model]", method: Optional[Callable] = None) -> dict[str, Any]: """ Get all annotations including base classes :param cls: The model class we need annotations from :param method: If specified, we try to get the annotations for the callable :return: The list of annotations """ - return typing.get_type_hints(method or cls) + return get_type_hints(method or cls) diff --git a/tortoise/exceptions.py b/tortoise/exceptions.py index edfef4ebf..0662e29cc 100644 --- a/tortoise/exceptions.py +++ b/tortoise/exceptions.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union if TYPE_CHECKING: - from tortoise import Model, Type + from tortoise import Model class BaseORMException(Exception): @@ -55,8 +55,8 @@ class NoValuesFetched(OperationalError): class NotExistOrMultiple(OperationalError): TEMPLATE = "" - def __init__(self, model: "Union[Type[Model], str]", *args) -> None: - self.model: "Optional[Type[Model]]" = None + def __init__(self, model: "Union[type[Model], str]", *args) -> None: + self.model: "Optional[type[Model]]" = None if isinstance(model, str): args = (model,) + args else: @@ -83,8 +83,8 @@ class ObjectDoesNotExistError(OperationalError, KeyError): The DoesNotExist exception is raised when an item with the passed primary key does not exist """ - def __init__(self, model: "Type[Model]", pk_name: str, pk_val: Any) -> None: - self.model: "Type[Model]" = model + def __init__(self, model: "type[Model]", pk_name: str, pk_val: Any) -> None: + self.model: "type[Model]" = model self.pk_name: str = pk_name self.pk_val: Any = pk_val diff --git a/tortoise/expressions.py b/tortoise/expressions.py index 76329c4b1..9e0cc05be 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from dataclasses import field as dataclass_field from enum import Enum -from typing import TYPE_CHECKING, Any, Type, cast +from typing import TYPE_CHECKING, Any, cast from pypika_tortoise import Case as PypikaCase from pypika_tortoise import Field as PypikaField @@ -36,7 +36,7 @@ @dataclass(frozen=True) class ResolveContext: - model: Type["Model"] + model: type["Model"] table: Table annotations: dict[str, Any] custom_filters: dict[str, FilterInfoDict] @@ -350,7 +350,7 @@ def _resolve_custom_kwarg( return modifier def _process_filter_kwarg( - self, model: "Type[Model]", key: str, value: Any, table: Table + self, model: "type[Model]", key: str, value: Any, table: Table ) -> tuple[Criterion, tuple[Table, Criterion] | None]: join = None @@ -500,7 +500,7 @@ class Function(Expression): __slots__ = ("field", "field_object", "default_values") - database_func: Type[PypikaFunction] = PypikaFunction + database_func: type[PypikaFunction] = PypikaFunction # Enable populate_field_object where we want to try and preserve the field type. populate_field_object = False @@ -570,7 +570,7 @@ class Aggregate(Function): :param is_distinct: Flag for aggregate with distinction """ - database_func: Type[AggregateFunction] = DistinctOptionFunction + database_func: type[AggregateFunction] = DistinctOptionFunction def __init__( self, diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 6912ea4cd..de2c92dbd 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Callable from enum import Enum -from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload from pypika_tortoise.terms import Term @@ -40,7 +40,7 @@ class OnDelete(StrEnum): class _FieldMeta(type): # TODO: Require functions to return field instances instead of this hack - def __new__(mcs, name: str, bases: tuple[Type, ...], attrs: dict) -> type: + def __new__(mcs, name: str, bases: tuple[type, ...], attrs: dict) -> type: if len(bases) > 1 and bases[0] is Field: # Instantiate class with only the 1st base class (should be Field) cls = type.__new__(mcs, name, (bases[0],), attrs) @@ -73,7 +73,7 @@ class Field(Generic[VALUE], metaclass=_FieldMeta): These attributes needs to be defined when defining an actual field type. .. attribute:: field_type - :annotation: Type[Any] + :annotation: type[Any] The Python type the field is. If adding a type as a mixin, _FieldMeta will automatically set this to that. @@ -137,7 +137,7 @@ def function_cast(self, term: Term) -> Term: """ # Field_type is a readonly property for the instance, it is set by _FieldMeta - field_type: Type[Any] = None # type: ignore + field_type: type[Any] = None # type: ignore indexable: bool = True has_db_field: bool = True skip_to_python_if_native: bool = False @@ -153,13 +153,13 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "Field[VALUE]": return super().__new__(cls) @overload - def __get__(self, instance: None, owner: Type["Model"]) -> "Field[VALUE]": ... + def __get__(self, instance: None, owner: type["Model"]) -> "Field[VALUE]": ... @overload - def __get__(self, instance: "Model", owner: Type["Model"]) -> VALUE: ... + def __get__(self, instance: "Model", owner: type["Model"]) -> VALUE: ... def __get__( - self, instance: Optional["Model"], owner: Type["Model"] + self, instance: Optional["Model"], owner: type["Model"] ) -> "Field[VALUE] | VALUE": ... def __set__(self, instance: "Model", value: VALUE) -> None: ... @@ -228,10 +228,10 @@ def __init__( self.docstring: Optional[str] = None self.validators: list[Union[Validator, Callable]] = validators or [] # TODO: consider making this not be set from constructor - self.model: Type["Model"] = model # type: ignore + self.model: type["Model"] = model # type: ignore self.reference: "Optional[Field]" = None - def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Any: + def to_db_value(self, value: Any, instance: "Union[type[Model], Model]") -> Any: """ Converts from the Python type to the DB type. @@ -390,7 +390,7 @@ def describe(self, serializable: bool) -> dict: } """ - def _type_name(typ: Type) -> str: + def _type_name(typ: type) -> str: if typ.__module__ == "builtins": return typ.__name__ if typ.__module__ == "typing": diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index b1588817f..307fb6df0 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -5,7 +5,7 @@ from collections.abc import Callable from decimal import Decimal from enum import Enum, IntEnum -from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from uuid import UUID, uuid4 from pypika_tortoise import functions @@ -360,7 +360,7 @@ def to_python_value(self, value: Any) -> Optional[datetime.datetime]: return value def to_db_value( - self, value: Optional[DatetimeFieldQueryValueType], instance: "Union[Type[Model], Model]" + self, value: Optional[DatetimeFieldQueryValueType], instance: "Union[type[Model], Model]" ) -> Optional[DatetimeFieldQueryValueType]: # Only do this if it is a Model instance, not class. Test for guaranteed instance var if hasattr(instance, "_saved_in_db") and ( @@ -410,7 +410,7 @@ def to_python_value(self, value: Any) -> Optional[datetime.date]: return value def to_db_value( - self, value: Optional[Union[datetime.date, str]], instance: "Union[Type[Model], Model]" + self, value: Optional[Union[datetime.date, str]], instance: "Union[type[Model], Model]" ) -> Optional[datetime.date]: if value is not None and not isinstance(value, datetime.date): value = parse_datetime(value).date() @@ -449,7 +449,7 @@ def to_python_value(self, value: Any) -> Optional[Union[datetime.time, datetime. def to_db_value( self, value: Optional[Union[datetime.time, datetime.timedelta]], - instance: "Union[Type[Model], Model]", + instance: "Union[type[Model], Model]", ) -> Optional[Union[datetime.time, datetime.timedelta]]: # Only do this if it is a Model instance, not class. Test for guaranteed instance var if hasattr(instance, "_saved_in_db") and ( @@ -497,7 +497,7 @@ def to_python_value(self, value: Any) -> Optional[datetime.timedelta]: return datetime.timedelta(microseconds=value) def to_db_value( - self, value: Optional[datetime.timedelta], instance: "Union[Type[Model], Model]" + self, value: Optional[datetime.timedelta], instance: "Union[type[Model], Model]" ) -> Optional[int]: self.validate(value) @@ -571,7 +571,7 @@ def __init__( def to_db_value( self, value: Optional[Union[T, dict, list, str, bytes]], - instance: "Union[Type[Model], Model]", + instance: "Union[type[Model], Model]", ) -> Optional[str]: self.validate(value) if value is None: @@ -639,7 +639,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs["default"] = uuid4 super().__init__(**kwargs) - def to_db_value(self, value: Any, instance: "Union[Type[Model], Model]") -> Optional[str]: + def to_db_value(self, value: Any, instance: "Union[type[Model], Model]") -> Optional[str]: return value and str(value) def to_python_value(self, value: Any) -> Optional[UUID]: @@ -672,7 +672,7 @@ class _db_mssql: class IntEnumFieldInstance(SmallIntField): def __init__( self, - enum_type: Type[IntEnum], + enum_type: type[IntEnum], description: Optional[str] = None, generated: bool = False, **kwargs: Any, @@ -701,7 +701,7 @@ def to_python_value(self, value: Union[int, None]) -> Union[IntEnum, None]: return value def to_db_value( - self, value: Union[IntEnum, None, int], instance: "Union[Type[Model], Model]" + self, value: Union[IntEnum, None, int], instance: "Union[type[Model], Model]" ) -> Union[int, None]: if isinstance(value, IntEnum): value = int(value.value) @@ -715,7 +715,7 @@ def to_db_value( def IntEnumField( - enum_type: Type[IntEnumType], + enum_type: type[IntEnumType], description: Optional[str] = None, **kwargs: Any, ) -> IntEnumType: @@ -742,7 +742,7 @@ def IntEnumField( class CharEnumFieldInstance(CharField): def __init__( self, - enum_type: Type[Enum], + enum_type: type[Enum], description: Optional[str] = None, max_length: int = 0, **kwargs: Any, @@ -765,7 +765,7 @@ def to_python_value(self, value: Union[str, None]) -> Union[Enum, None]: return self.enum_type(value) if value is not None else None def to_db_value( - self, value: Union[Enum, None, str], instance: "Union[Type[Model], Model]" + self, value: Union[Enum, None, str], instance: "Union[type[Model], Model]" ) -> Union[str, None]: self.validate(value) if isinstance(value, Enum): @@ -779,7 +779,7 @@ def to_db_value( def CharEnumField( - enum_type: Type[CharEnumType], + enum_type: type[CharEnumType], description: Optional[str] = None, max_length: int = 0, **kwargs: Any, diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index 9df05b947..016824203 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -1,8 +1,16 @@ from collections.abc import AsyncGenerator, Generator, Iterator -from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + Optional, + TypeVar, + Union, + overload, +) from pypika_tortoise import Table -from typing_extensions import Literal from tortoise.exceptions import ConfigurationError, NoValuesFetched, OperationalError from tortoise.fields.base import CASCADE, SET_NULL, Field, OnDelete @@ -35,7 +43,7 @@ class ReverseRelation(Generic[MODEL]): def __init__( self, - remote_model: Type[MODEL], + remote_model: type[MODEL], relation_field: str, instance: "Model", from_field: str, @@ -234,13 +242,13 @@ class RelationalField(Field[MODEL]): def __init__( self, - related_model: "Type[MODEL]", + related_model: "type[MODEL]", to_field: Optional[str] = None, db_constraint: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.related_model: "Type[MODEL]" = related_model + self.related_model: "type[MODEL]" = related_model self.to_field: str = to_field # type: ignore self.to_field_instance: Field = None # type: ignore self.db_constraint = db_constraint @@ -248,13 +256,13 @@ def __init__( if TYPE_CHECKING: @overload - def __get__(self, instance: None, owner: Type["Model"]) -> "RelationalField[MODEL]": ... + def __get__(self, instance: None, owner: type["Model"]) -> "RelationalField[MODEL]": ... @overload - def __get__(self, instance: "Model", owner: Type["Model"]) -> MODEL: ... + def __get__(self, instance: "Model", owner: type["Model"]) -> MODEL: ... def __get__( - self, instance: Optional["Model"], owner: Type["Model"] + self, instance: Optional["Model"], owner: type["Model"] ) -> "RelationalField[MODEL] | MODEL": ... def __set__(self, instance: "Model", value: MODEL) -> None: ... @@ -302,7 +310,7 @@ def describe(self, serializable: bool) -> dict: class BackwardFKRelation(RelationalField[MODEL]): def __init__( self, - field_type: "Type[MODEL]", + field_type: "type[MODEL]", relation_field: str, relation_source_field: str, null: bool, @@ -342,7 +350,7 @@ def __init__( backward_key: str = "", related_name: str = "", on_delete: OnDelete = CASCADE, - field_type: "Type[MODEL]" = None, # type: ignore + field_type: "type[MODEL]" = None, # type: ignore create_unique_index: bool = True, **kwargs: Any, ) -> None: diff --git a/tortoise/indexes.py b/tortoise/indexes.py index 807eba7f5..4b1eb43fe 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.terms import Term, ValueWrapper @@ -51,12 +51,12 @@ def describe(self) -> dict: "extra": self.extra, } - def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str: + def index_name(self, schema_generator: "BaseSchemaGenerator", model: "type[Model]") -> str: # This function is required by aerich return self.name or schema_generator._generate_index_name("idx", model, self.field_names) def get_sql( - self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool + self, schema_generator: "BaseSchemaGenerator", model: "type[Model]", safe: bool ) -> str: # This function is required by aerich return schema_generator._get_index_sql( diff --git a/tortoise/models.py b/tortoise/models.py index 6a7dfaef7..e4b2c82b5 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -4,7 +4,7 @@ from collections.abc import Awaitable, Callable, Generator, Iterable from copy import copy, deepcopy from functools import partial -from typing import Any, Optional, Type, TypedDict, TypeVar, Union, cast +from typing import Any, Optional, TypedDict, TypeVar, Union, cast from pypika_tortoise import Order, Query, Table from pypika_tortoise.terms import Term @@ -94,7 +94,7 @@ def _fk_setter( def _fk_getter( - self: "Model", _key: str, ftype: "Type[Model]", relation_field: str, to_field: str + self: "Model", _key: str, ftype: "type[Model]", relation_field: str, to_field: str ) -> Awaitable: try: return getattr(self, _key) @@ -106,7 +106,7 @@ def _fk_getter( def _rfk_getter( - self: "Model", _key: str, ftype: "Type[Model]", frelfield: str, from_field: str + self: "Model", _key: str, ftype: "type[Model]", frelfield: str, from_field: str ) -> ReverseRelation: val = getattr(self, _key, None) if val is None: @@ -116,7 +116,7 @@ def _rfk_getter( def _ro2o_getter( - self: "Model", _key: str, ftype: "Type[Model]", frelfield: str, from_field: str + self: "Model", _key: str, ftype: "type[Model]", frelfield: str, from_field: str ) -> "QuerySetSingle[Optional[Model]]": if hasattr(self, _key): return getattr(self, _key) @@ -136,7 +136,7 @@ def _m2m_getter( return val -def _get_comments(cls: "Type[Model]") -> dict[str, str]: +def _get_comments(cls: "type[Model]") -> dict[str, str]: """ Get comments exactly before attributes @@ -237,7 +237,7 @@ def __init__(self, meta: "Model.Meta") -> None: self.basetable: Table = Table("") self.pk_attr: str = getattr(meta, "pk_attr", "") self.generated_db_fields: tuple[str, ...] = None # type: ignore - self._model: Type["Model"] = None # type: ignore + self._model: type["Model"] = None # type: ignore self.table_description: str = getattr(meta, "table_description", "") self.pk: Field = None # type: ignore self.db_pk_column: str = "" @@ -474,7 +474,7 @@ def _generate_filters(self) -> None: class ModelMeta(type): __slots__ = () - def __new__(cls, name: str, bases: tuple[Type, ...], attrs: dict[str, Any]) -> "ModelMeta": + def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> "ModelMeta": fields_db_projection: dict[str, str] = {} meta_class: "Model.Meta" = attrs.get("Meta", type("Meta", (), {})) pk_attr: str = "id" @@ -528,7 +528,7 @@ def __new__(cls, name: str, bases: tuple[Type, ...], attrs: dict[str, Any]) -> " return new_class @classmethod - def _search_for_field_attributes(cls, base: Type, attrs: dict) -> None: + def _search_for_field_attributes(cls, base: type, attrs: dict) -> None: """ Searching for class attributes of type fields.Field in the given class. @@ -666,7 +666,7 @@ def build_meta( meta.abstract = True return meta - def __getitem__(cls: Type[MODEL], key: Any) -> QuerySetSingle[MODEL]: # type: ignore + def __getitem__(cls: type[MODEL], key: Any) -> QuerySetSingle[MODEL]: # type: ignore return cls._getbypk(key) # type: ignore @@ -677,7 +677,7 @@ class Model(metaclass=ModelMeta): # I don' like this here, but it makes auto completion and static analysis much happier _meta = MetaInfo(None) # type: ignore - _listeners: dict[Signals, dict[Type[MODEL], list[Callable]]] = { # type: ignore + _listeners: dict[Signals, dict[type[MODEL], list[Callable]]] = { # type: ignore Signals.pre_save: {}, Signals.post_save: {}, Signals.pre_delete: {}, @@ -749,7 +749,7 @@ def _set_kwargs(self, kwargs: dict) -> set[str]: return passed_fields @classmethod - def _init_from_db(cls: Type[MODEL], **kwargs: Any) -> MODEL: + def _init_from_db(cls: type[MODEL], **kwargs: Any) -> MODEL: self = cls.__new__(cls) self._partial = False self._saved_in_db = True @@ -852,7 +852,7 @@ def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> No ) @classmethod - async def _getbypk(cls: Type[MODEL], key: Any) -> MODEL: + async def _getbypk(cls: type[MODEL], key: Any) -> MODEL: try: return await cls.get(pk=key) except (DoesNotExist, ValueError): @@ -1153,7 +1153,7 @@ def select_for_update( @classmethod async def update_or_create( - cls: Type[MODEL], + cls: type[MODEL], defaults: Optional[dict] = None, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any, @@ -1177,7 +1177,7 @@ async def update_or_create( @classmethod async def create( - cls: Type[MODEL], using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any + cls: type[MODEL], using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any ) -> MODEL: """ Create a record in the DB and returns the object. @@ -1204,7 +1204,7 @@ async def create( @classmethod def bulk_update( - cls: Type[MODEL], + cls: type[MODEL], objects: Iterable[MODEL], fields: Iterable[str], batch_size: Optional[int] = None, @@ -1234,7 +1234,7 @@ def bulk_update( @classmethod async def in_bulk( - cls: Type[MODEL], + cls: type[MODEL], id_list: Iterable[Union[str, int]], field_name: str = "pk", using_db: Optional[BaseDBAsyncClient] = None, @@ -1251,7 +1251,7 @@ async def in_bulk( @classmethod def bulk_create( - cls: Type[MODEL], + cls: type[MODEL], objects: Iterable[MODEL], batch_size: Optional[int] = None, ignore_conflicts: bool = False, @@ -1394,7 +1394,7 @@ def raw(cls, sql: str, using_db: Optional[BaseDBAsyncClient] = None) -> "RawSQLQ @classmethod def exists( - cls: Type[MODEL], *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any + cls: type[MODEL], *args: Q, using_db: Optional[BaseDBAsyncClient] = None, **kwargs: Any ) -> ExistsQuery: """ Return True/False whether record exists with the provided filter parameters. diff --git a/tortoise/query_utils.py b/tortoise/query_utils.py index 0ac562bb4..6cfe38640 100644 --- a/tortoise/query_utils.py +++ b/tortoise/query_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import copy -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, Optional, cast from pypika_tortoise import Table from pypika_tortoise.terms import Criterion, Term @@ -79,7 +79,7 @@ def get_joins_for_related_field( def resolve_nested_field( - model: Type["Model"], table: Table, field: str + model: type["Model"], table: Table, field: str ) -> tuple[Term, list[TableCriterionTuple], Optional[Field]]: """ Resolves a nested field string like events__participants__name and diff --git a/tortoise/queryset.py b/tortoise/queryset.py index a627f7f9d..2d1fc4d42 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1,17 +1,7 @@ import types from collections.abc import AsyncIterator, Callable, Generator, Iterable from copy import copy -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Optional, - Type, - TypeVar, - Union, - cast, - overload, -) +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast, overload from pypika_tortoise import JoinType, Order, Table from pypika_tortoise.analytics import Count @@ -97,9 +87,9 @@ class AwaitableQuery(Generic[MODEL]): "_q_objects", ) - def __init__(self, model: Type[MODEL]) -> None: + def __init__(self, model: type[MODEL]) -> None: self._joined_tables: list[Table] = [] - self.model: "Type[MODEL]" = model + self.model: "type[MODEL]" = model self.query: QueryBuilder = QUERY self._db: BaseDBAsyncClient = None # type: ignore self.capabilities: Capabilities = model._meta.db.capabilities @@ -184,7 +174,7 @@ def _resolve_ordering_string(ordering: str, reverse: bool = False) -> tuple[str, def resolve_ordering( self, - model: "Type[Model]", + model: "type[Model]", table: Table, orderings: Iterable[tuple[str, Union[str, Order]]], annotations: dict[str, Any], @@ -326,7 +316,7 @@ class QuerySet(AwaitableQuery[MODEL]): "_force_indexes", ) - def __init__(self, model: Type[MODEL]) -> None: + def __init__(self, model: type[MODEL]) -> None: super().__init__(model) self.fields: set[str] = model._meta.db_fields self._prefetch_map: dict[str, set[Union[str, Prefetch]]] = {} @@ -347,7 +337,7 @@ def __init__(self, model: Type[MODEL]) -> None: self._select_for_update_of: set[str] = set() self._select_related: set[str] = set() self._select_related_idx: list[ - tuple["Type[Model]", int, Union[Table, str], "Type[Model]", Iterable[Optional[str]]] + tuple["type[Model]", int, Union[Table, str], "type[Model]", Iterable[Optional[str]]] ] = [] # format with: model,idx,model_name,parent_model self._force_indexes: set[str] = set() self._use_indexes: set[str] = set() @@ -1017,7 +1007,7 @@ def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]": def _join_table_with_select_related( self, - model: "Type[Model]", + model: "type[Model]", table: Table, field: str, forwarded_fields: str, @@ -1162,7 +1152,7 @@ class UpdateQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: type[MODEL], update_kwargs: dict[str, Any], db: BaseDBAsyncClient, q_objects: list[Q], @@ -1241,7 +1231,7 @@ class DeleteQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: type[MODEL], db: BaseDBAsyncClient, q_objects: list[Q], annotations: dict[str, Any], @@ -1288,7 +1278,7 @@ class ExistsQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: type[MODEL], db: BaseDBAsyncClient, q_objects: list[Q], annotations: dict[str, Any], @@ -1339,7 +1329,7 @@ class CountQuery(AwaitableQuery): def __init__( self, - model: Type[MODEL], + model: type[MODEL], db: BaseDBAsyncClient, q_objects: list[Q], annotations: dict[str, Any], @@ -1395,12 +1385,12 @@ async def _execute(self) -> int: class FieldSelectQuery(AwaitableQuery): # pylint: disable=W0223 - def __init__(self, model: Type[MODEL], annotations: dict[str, Any]) -> None: + def __init__(self, model: type[MODEL], annotations: dict[str, Any]) -> None: super().__init__(model) self._annotations = annotations def _join_table_with_forwarded_fields( - self, model: Type[MODEL], table: Table, field: str, forwarded_fields: str + self, model: type[MODEL], table: Table, field: str, forwarded_fields: str ) -> tuple[Table, str]: if field in model._meta.fields_db_projection and not forwarded_fields: return table, model._meta.fields_db_projection[field] @@ -1459,7 +1449,7 @@ def add_field_to_select_query(self, field: str, return_as: str) -> None: raise FieldError(f'Unknown field "{field}" for model "{self.model.__name__}"') - def resolve_to_python_value(self, model: Type[MODEL], field: str) -> Callable: + def resolve_to_python_value(self, model: type[MODEL], field: str) -> Callable: if field in model._meta.fetch_fields: # return as is to get whole model objects return lambda x: x @@ -1522,7 +1512,7 @@ class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]): def __init__( self, - model: Type[MODEL], + model: type[MODEL], db: BaseDBAsyncClient, q_objects: list[Q], single: bool, @@ -1650,7 +1640,7 @@ class ValuesQuery(FieldSelectQuery, Generic[SINGLE]): def __init__( self, - model: Type[MODEL], + model: type[MODEL], db: BaseDBAsyncClient, q_objects: list[Q], single: bool, @@ -1768,7 +1758,7 @@ async def _execute(self) -> Union[list[dict], dict]: class RawSQLQuery(AwaitableQuery): __slots__ = ("_sql", "_db") - def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: + def __init__(self, model: type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: super().__init__(model) self._sql = sql self._db = db @@ -1790,7 +1780,7 @@ class BulkUpdateQuery(UpdateQuery, Generic[MODEL]): def __init__( self, - model: Type[MODEL], + model: type[MODEL], db: BaseDBAsyncClient, q_objects: list[Q], annotations: dict[str, Any], @@ -1890,7 +1880,7 @@ class BulkCreateQuery(AwaitableQuery, Generic[MODEL]): def __init__( self, - model: Type[MODEL], + model: type[MODEL], db: BaseDBAsyncClient, objects: Iterable[MODEL], batch_size: Optional[int] = None, diff --git a/tortoise/router.py b/tortoise/router.py index 8a77b74c2..e428dc052 100644 --- a/tortoise/router.py +++ b/tortoise/router.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from tortoise.connection import connections from tortoise.exceptions import ConfigurationError @@ -17,7 +17,7 @@ def __init__(self) -> None: def init_routers(self, routers: list[Callable]) -> None: self._routers = [r() for r in routers] - def _router_func(self, model: Type["Model"], action: str) -> Any: + def _router_func(self, model: type["Model"], action: str) -> Any: for r in self._routers: try: method = getattr(r, action) @@ -29,16 +29,16 @@ def _router_func(self, model: Type["Model"], action: str) -> Any: if chosen_db: return chosen_db - def _db_route(self, model: Type["Model"], action: str) -> "BaseDBAsyncClient" | None: + def _db_route(self, model: type["Model"], action: str) -> "BaseDBAsyncClient" | None: try: return connections.get(self._router_func(model, action)) except ConfigurationError: return None - def db_for_read(self, model: Type["Model"]) -> "BaseDBAsyncClient" | None: + def db_for_read(self, model: type["Model"]) -> "BaseDBAsyncClient" | None: return self._db_route(model, "db_for_read") - def db_for_write(self, model: Type["Model"]) -> "BaseDBAsyncClient" | None: + def db_for_write(self, model: type["Model"]) -> "BaseDBAsyncClient" | None: return self._db_route(model, "db_for_write")