Skip to content

Commit

Permalink
Add index_name and get_sql back to Index class for aerich
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Jan 22, 2025
1 parent 81892fe commit a6c26f5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
9 changes: 1 addition & 8 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
if model._meta.indexes:
for index in model._meta.indexes:
if isinstance(index, Index):
idx_sql = self._get_index_sql(
model,
index.field_names,
safe=safe,
index_name=index.name,
index_type=index.INDEX_TYPE,
extra=index.extra,
)
idx_sql = index.get_sql(self, model, safe)
else:
fields = []
for field in index:
Expand Down
23 changes: 22 additions & 1 deletion tortoise/indexes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any, Type

from pypika_tortoise.context import DEFAULT_SQL_CONTEXT
from pypika_tortoise.terms import Term, ValueWrapper

from tortoise.exceptions import ConfigurationError

if TYPE_CHECKING:
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.models import Model


class Index:
INDEX_TYPE = ""
Expand Down Expand Up @@ -47,6 +51,23 @@ def describe(self) -> dict:
"extra": self.extra,
}

def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str:
# This function is required by aerich
return self.name or schema_generator._generate_index_name("idx", model, self.field_names)

def get_sql(
self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool
) -> str:
# This function is required by aerich
return schema_generator._get_index_sql(
model,
self.field_names,
safe,
index_name=self.name,
index_type=self.INDEX_TYPE,
extra=self.extra,
)

@property
def field_names(self) -> list[str]:
if self.fields:
Expand Down

0 comments on commit a6c26f5

Please sign in to comment.