From 216048ba30b8bb13ab5175fd5ffdd66545421b97 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Tue, 21 Jan 2025 14:33:22 +0800 Subject: [PATCH 1/5] Fix index class argument `name` not work --- tests/fields/test_db_index.py | 7 +++- tests/testmodels.py | 2 + tests/utils/test_describe_model.py | 11 ++++- tortoise/backends/base/schema_generator.py | 47 +++++++++++++--------- 4 files changed, 44 insertions(+), 23 deletions(-) diff --git a/tests/fields/test_db_index.py b/tests/fields/test_db_index.py index 8d61779f1..6b3a15039 100644 --- a/tests/fields/test_db_index.py +++ b/tests/fields/test_db_index.py @@ -2,11 +2,11 @@ from pypika_tortoise.terms import Field +from tests.testmodels import ModelWithIndexes from tortoise import fields from tortoise.contrib import test from tortoise.exceptions import ConfigurationError from tortoise.indexes import Index -from tests.testmodels import ModelWithIndexes class CustomIndex(Index): @@ -99,7 +99,10 @@ class TestIndexAliasChar(TestIndexAlias): class TestModelWithIndexes(test.TestCase): def test_meta(self): - self.assertEqual(ModelWithIndexes._meta.indexes, [Index(fields=("f1", "f2"))]) + self.assertEqual( + ModelWithIndexes._meta.indexes, + [Index(fields=("f1", "f2")), Index(fields=("f3",), name="model_with_indexes__f3")], + ) self.assertTrue(ModelWithIndexes._meta.fields_map["id"].index) self.assertTrue(ModelWithIndexes._meta.fields_map["indexed"].index) self.assertTrue(ModelWithIndexes._meta.fields_map["unique_indexed"].unique) diff --git a/tests/testmodels.py b/tests/testmodels.py index afd4b82fe..f5cfc6692 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -1059,11 +1059,13 @@ class ModelWithIndexes(Model): unique_indexed = fields.CharField(max_length=16, unique=True) f1 = fields.CharField(max_length=16) f2 = fields.CharField(max_length=16) + f3 = fields.CharField(max_length=16) u1 = fields.IntField() u2 = fields.IntField() class Meta: indexes = [ Index(fields=["f1", "f2"]), + Index(fields=["f3"], name="model_with_indexes__f3"), ] unique_together = [("u1", "u2")] diff --git a/tests/utils/test_describe_model.py b/tests/utils/test_describe_model.py index d1be42986..8c21fc1e7 100644 --- a/tests/utils/test_describe_model.py +++ b/tests/utils/test_describe_model.py @@ -1568,7 +1568,16 @@ def test_describe_indexes_serializable(self): self.assertEqual( val["indexes"], - [{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}], + [ + {"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}, + { + "fields": ["f3"], + "expressions": [], + "name": "model_with_indexes__f3", + "type": "", + "extra": "", + }, + ], ) def test_describe_indexes_not_serializable(self): diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index 1f56e7cf6..bac90266d 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -171,6 +171,27 @@ def _generate_fk_name( ) return index_name + def _generate_custom_index_sql(self, index: Index, model: "Type[Model]", safe: bool) -> str: + if index.fields: + fields = list(index.fields) + elif index.expressions: + fields = [ + f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" for expression in index.expressions + ] + else: + raise ConfigurationError( + "At least one field or expression is required to define an index." + ) + + return self._get_index_sql( + model, + fields, + safe=safe, + index_name=index.name, + index_type=index.INDEX_TYPE, + extra=index.extra, + ) + def _get_index_sql( self, model: "Type[Model]", @@ -346,31 +367,17 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: if model._meta.indexes: for index in model._meta.indexes: - if not isinstance(index, Index): + if isinstance(index, Index): + idx_sql = self._generate_custom_index_sql(index, model, safe) + else: fields = [] for field in index: field_object = model._meta.fields_map[field] fields.append(field_object.source_field or field) + idx_sql = self._get_index_sql(model, fields, safe=safe) - _indexes.append(self._get_index_sql(model, fields, safe=safe)) - else: - if index.fields: - fields = [f for f in index.fields] - elif index.expressions: - fields = [ - f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" - for expression in index.expressions - ] - else: - raise ConfigurationError( - "At least one field or expression is required to define an index." - ) - - _indexes.append( - self._get_index_sql( - model, fields, safe=safe, index_type=index.INDEX_TYPE, extra=index.extra - ) - ) + if idx_sql: + _indexes.append(idx_sql) field_indexes_sqls = [val for val in list(dict.fromkeys(_indexes)) if val] From 7b1f697213f5827f7749f92f91e7d0a422a3af30 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Tue, 21 Jan 2025 20:05:49 +0800 Subject: [PATCH 2/5] refactor: add `field_names` property to Index class --- tortoise/backends/base/schema_generator.py | 14 +------------- tortoise/indexes.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index bac90266d..dff68dda8 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -2,8 +2,6 @@ from hashlib import sha256 from typing import TYPE_CHECKING, Any, List, Optional, Set, Type, Union, cast -from pypika_tortoise.context import DEFAULT_SQL_CONTEXT - from tortoise.exceptions import ConfigurationError from tortoise.fields import JSONField, TextField, UUIDField from tortoise.fields.relational import OneToOneFieldInstance @@ -172,20 +170,10 @@ def _generate_fk_name( return index_name def _generate_custom_index_sql(self, index: Index, model: "Type[Model]", safe: bool) -> str: - if index.fields: - fields = list(index.fields) - elif index.expressions: - fields = [ - f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" for expression in index.expressions - ] - else: - raise ConfigurationError( - "At least one field or expression is required to define an index." - ) return self._get_index_sql( model, - fields, + index.field_names, safe=safe, index_name=index.name, index_type=index.INDEX_TYPE, diff --git a/tortoise/indexes.py b/tortoise/indexes.py index bba96e1aa..6e8008aaf 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -2,6 +2,7 @@ from typing import Any +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.terms import Term, ValueWrapper from tortoise.exceptions import ConfigurationError @@ -46,6 +47,19 @@ def describe(self) -> dict: "extra": self.extra, } + @property + def field_names(self) -> list[str]: + if self.fields: + return list(self.fields) + elif self.expressions: + return [ + f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" for expression in self.expressions + ] + else: + raise ConfigurationError( + "At least one field or expression is required to define an index." + ) + def __repr__(self) -> str: argument = "" if self.expressions: From 94b7539e9fe2af44b9515c22e3152d3288fc7951 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Tue, 21 Jan 2025 20:35:43 +0800 Subject: [PATCH 3/5] Remove `_generate_custom_index_sql` function in favor of Index.field_names --- tortoise/backends/base/schema_generator.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index dff68dda8..e3c3c1d2b 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -169,17 +169,6 @@ def _generate_fk_name( ) return index_name - def _generate_custom_index_sql(self, index: Index, model: "Type[Model]", safe: bool) -> str: - - return self._get_index_sql( - model, - index.field_names, - safe=safe, - index_name=index.name, - index_type=index.INDEX_TYPE, - extra=index.extra, - ) - def _get_index_sql( self, model: "Type[Model]", @@ -356,7 +345,14 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: if model._meta.indexes: for index in model._meta.indexes: if isinstance(index, Index): - idx_sql = self._generate_custom_index_sql(index, model, safe) + idx_sql = self._get_index_sql( + model, + index.field_names, + safe=safe, + index_name=index.name, + index_type=index.INDEX_TYPE, + extra=index.extra, + ) else: fields = [] for field in index: From 81892fe276b7cb703ab8bcc9b18e7ff9191af091 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Tue, 21 Jan 2025 20:52:50 +0800 Subject: [PATCH 4/5] Check custom index name in generated schema --- tests/schema/test_generate_schema.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index 0b45385e5..9bb141f4d 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -174,6 +174,11 @@ async def test_create_index(self): sql = self.get_sql("CREATE INDEX") self.assertIsNotNone(re.search(r"idx_tournament_created_\w+", sql)) + async def test_create_index_with_custom_name(self): + await self.init_for("tests.testmodels") + sql = self.get_sql("f3") + self.assertIn("model_with_indexes__f3", sql) + async def test_fk_bad_model_name(self): with self.assertRaisesRegex( ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"' From a6c26f52caeaab10552074eaf21d92c0c980af0b Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Wed, 22 Jan 2025 09:54:47 +0800 Subject: [PATCH 5/5] Add `index_name` and `get_sql` back to Index class for aerich --- tortoise/backends/base/schema_generator.py | 9 +-------- tortoise/indexes.py | 23 +++++++++++++++++++++- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index e3c3c1d2b..11c166727 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -345,14 +345,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: if model._meta.indexes: for index in model._meta.indexes: if isinstance(index, Index): - idx_sql = self._get_index_sql( - model, - index.field_names, - safe=safe, - index_name=index.name, - index_type=index.INDEX_TYPE, - extra=index.extra, - ) + idx_sql = index.get_sql(self, model, safe) else: fields = [] for field in index: diff --git a/tortoise/indexes.py b/tortoise/indexes.py index 6e8008aaf..807eba7f5 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -1,12 +1,16 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any, Type from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.terms import Term, ValueWrapper from tortoise.exceptions import ConfigurationError +if TYPE_CHECKING: + from tortoise.backends.base.schema_generator import BaseSchemaGenerator + from tortoise.models import Model + class Index: INDEX_TYPE = "" @@ -47,6 +51,23 @@ def describe(self) -> dict: "extra": self.extra, } + def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str: + # This function is required by aerich + return self.name or schema_generator._generate_index_name("idx", model, self.field_names) + + def get_sql( + self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool + ) -> str: + # This function is required by aerich + return schema_generator._get_index_sql( + model, + self.field_names, + safe, + index_name=self.name, + index_type=self.INDEX_TYPE, + extra=self.extra, + ) + @property def field_names(self) -> list[str]: if self.fields: