diff --git a/tests/fields/test_db_index.py b/tests/fields/test_db_index.py index 60ef15d36..6b3a15039 100644 --- a/tests/fields/test_db_index.py +++ b/tests/fields/test_db_index.py @@ -99,7 +99,10 @@ class TestIndexAliasChar(TestIndexAlias): class TestModelWithIndexes(test.TestCase): def test_meta(self): - self.assertEqual(ModelWithIndexes._meta.indexes, [Index(fields=("f1", "f2"))]) + self.assertEqual( + ModelWithIndexes._meta.indexes, + [Index(fields=("f1", "f2")), Index(fields=("f3",), name="model_with_indexes__f3")], + ) self.assertTrue(ModelWithIndexes._meta.fields_map["id"].index) self.assertTrue(ModelWithIndexes._meta.fields_map["indexed"].index) self.assertTrue(ModelWithIndexes._meta.fields_map["unique_indexed"].unique) diff --git a/tests/schema/test_generate_schema.py b/tests/schema/test_generate_schema.py index 0b45385e5..9bb141f4d 100644 --- a/tests/schema/test_generate_schema.py +++ b/tests/schema/test_generate_schema.py @@ -174,6 +174,11 @@ async def test_create_index(self): sql = self.get_sql("CREATE INDEX") self.assertIsNotNone(re.search(r"idx_tournament_created_\w+", sql)) + async def test_create_index_with_custom_name(self): + await self.init_for("tests.testmodels") + sql = self.get_sql("f3") + self.assertIn("model_with_indexes__f3", sql) + async def test_fk_bad_model_name(self): with self.assertRaisesRegex( ConfigurationError, 'ForeignKeyField accepts model name in format "app.Model"' diff --git a/tests/testmodels.py b/tests/testmodels.py index 319bae26c..dbb8b2ba1 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -1059,11 +1059,13 @@ class ModelWithIndexes(Model): unique_indexed = fields.CharField(max_length=16, unique=True) f1 = fields.CharField(max_length=16) f2 = fields.CharField(max_length=16) + f3 = fields.CharField(max_length=16) u1 = fields.IntField() u2 = fields.IntField() class Meta: indexes = [ Index(fields=["f1", "f2"]), + Index(fields=["f3"], name="model_with_indexes__f3"), ] unique_together = [("u1", "u2")] diff --git a/tests/utils/test_describe_model.py b/tests/utils/test_describe_model.py index d1be42986..8c21fc1e7 100644 --- a/tests/utils/test_describe_model.py +++ b/tests/utils/test_describe_model.py @@ -1568,7 +1568,16 @@ def test_describe_indexes_serializable(self): self.assertEqual( val["indexes"], - [{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}], + [ + {"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}, + { + "fields": ["f3"], + "expressions": [], + "name": "model_with_indexes__f3", + "type": "", + "extra": "", + }, + ], ) def test_describe_indexes_not_serializable(self): diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index 4d26af19c..070112ebf 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -4,8 +4,6 @@ from hashlib import sha256 from typing import TYPE_CHECKING, Any, Type, cast -from pypika_tortoise.context import DEFAULT_SQL_CONTEXT - from tortoise.exceptions import ConfigurationError from tortoise.fields import JSONField, TextField, UUIDField from tortoise.fields.relational import OneToOneFieldInstance @@ -348,31 +346,17 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: if model._meta.indexes: for index in model._meta.indexes: - if not isinstance(index, Index): + if isinstance(index, Index): + idx_sql = index.get_sql(self, model, safe) + else: fields = [] for field in index: field_object = model._meta.fields_map[field] fields.append(field_object.source_field or field) + idx_sql = self._get_index_sql(model, fields, safe=safe) - _indexes.append(self._get_index_sql(model, fields, safe=safe)) - else: - if index.fields: - fields = [f for f in index.fields] - elif index.expressions: - fields = [ - f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" - for expression in index.expressions - ] - else: - raise ConfigurationError( - "At least one field or expression is required to define an index." - ) - - _indexes.append( - self._get_index_sql( - model, fields, safe=safe, index_type=index.INDEX_TYPE, extra=index.extra - ) - ) + if idx_sql: + _indexes.append(idx_sql) field_indexes_sqls = [val for val in list(dict.fromkeys(_indexes)) if val] diff --git a/tortoise/indexes.py b/tortoise/indexes.py index bba96e1aa..807eba7f5 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -1,11 +1,16 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any, Type +from pypika_tortoise.context import DEFAULT_SQL_CONTEXT from pypika_tortoise.terms import Term, ValueWrapper from tortoise.exceptions import ConfigurationError +if TYPE_CHECKING: + from tortoise.backends.base.schema_generator import BaseSchemaGenerator + from tortoise.models import Model + class Index: INDEX_TYPE = "" @@ -46,6 +51,36 @@ def describe(self) -> dict: "extra": self.extra, } + def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str: + # This function is required by aerich + return self.name or schema_generator._generate_index_name("idx", model, self.field_names) + + def get_sql( + self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool + ) -> str: + # This function is required by aerich + return schema_generator._get_index_sql( + model, + self.field_names, + safe, + index_name=self.name, + index_type=self.INDEX_TYPE, + extra=self.extra, + ) + + @property + def field_names(self) -> list[str]: + if self.fields: + return list(self.fields) + elif self.expressions: + return [ + f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})" for expression in self.expressions + ] + else: + raise ConfigurationError( + "At least one field or expression is required to define an index." + ) + def __repr__(self) -> str: argument = "" if self.expressions: