Skip to content

Commit

Permalink
Merge upstream/develop
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Jan 20, 2025
2 parents 601713f + 71624e7 commit dda50a2
Show file tree
Hide file tree
Showing 14 changed files with 212 additions and 91 deletions.
13 changes: 11 additions & 2 deletions tests/fields/test_db_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tortoise.contrib import test
from tortoise.exceptions import ConfigurationError
from tortoise.indexes import Index
from tests.testmodels import ModelWithIndexes


class CustomIndex(Index):
Expand All @@ -14,7 +15,7 @@ def __init__(self, *args, **kw):
self._foo = ""


class TestIndexHashEqualRepr(test.TestCase):
class TestIndexHashEqualRepr(test.SimpleTestCase):
def test_index_eq(self):
assert Index(fields=("id",)) == Index(fields=("id",))
assert CustomIndex(fields=("id",)) == CustomIndex(fields=("id",))
Expand Down Expand Up @@ -46,7 +47,7 @@ def test_index_repr(self):
assert repr(Index(fields=("id",), name="MyIndex")) == "Index(fields=['id'], name='MyIndex')"
assert repr(Index(Field("id"))) == f'Index({str(Field("id"))})'
assert repr(Index(Field("a"), name="Id")) == f"Index({str(Field('a'))}, name='Id')"
with self.assertRaises(ValueError):
with self.assertRaises(ConfigurationError):
Index(Field("id"), fields=("name",))


Expand Down Expand Up @@ -94,3 +95,11 @@ class TestIndexAliasUUID(TestIndexAlias):
class TestIndexAliasChar(TestIndexAlias):
Field = fields.CharField
init_kwargs = {"max_length": 10}


class TestModelWithIndexes(test.TestCase):
def test_meta(self):
self.assertEqual(ModelWithIndexes._meta.indexes, [Index(fields=("f1", "f2"))])
self.assertTrue(ModelWithIndexes._meta.fields_map["id"].index)
self.assertTrue(ModelWithIndexes._meta.fields_map["indexed"].index)
self.assertTrue(ModelWithIndexes._meta.fields_map["unique_indexed"].unique)
20 changes: 10 additions & 10 deletions tests/schema/test_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,10 +724,10 @@ async def test_index_safe(self):
"""CREATE TABLE IF NOT EXISTS `index` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`full_text` LONGTEXT NOT NULL,
`geometry` GEOMETRY NOT NULL
) CHARACTER SET utf8mb4;
CREATE FULLTEXT INDEX IF NOT EXISTS `idx_index_full_te_3caba4` ON `index` (`full_text`) WITH PARSER ngram;
CREATE SPATIAL INDEX IF NOT EXISTS `idx_index_geometr_0b4dfb` ON `index` (`geometry`);""",
`geometry` GEOMETRY NOT NULL,
FULLTEXT KEY `idx_index_full_te_3caba4` (`full_text`) WITH PARSER ngram,
SPATIAL KEY `idx_index_geometr_0b4dfb` (`geometry`)
) CHARACTER SET utf8mb4;""",
)

async def test_index_unsafe(self):
Expand All @@ -738,10 +738,10 @@ async def test_index_unsafe(self):
"""CREATE TABLE `index` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`full_text` LONGTEXT NOT NULL,
`geometry` GEOMETRY NOT NULL
) CHARACTER SET utf8mb4;
CREATE FULLTEXT INDEX `idx_index_full_te_3caba4` ON `index` (`full_text`) WITH PARSER ngram;
CREATE SPATIAL INDEX `idx_index_geometr_0b4dfb` ON `index` (`geometry`);""",
`geometry` GEOMETRY NOT NULL,
FULLTEXT KEY `idx_index_full_te_3caba4` (`full_text`) WITH PARSER ngram,
SPATIAL KEY `idx_index_geometr_0b4dfb` (`geometry`)
) CHARACTER SET utf8mb4;""",
)

async def test_m2m_no_auto_create(self):
Expand Down Expand Up @@ -1102,7 +1102,7 @@ async def test_index_unsafe(self):
CREATE INDEX "idx_index_gist_c807bf" ON "index" USING GIST ("gist");
CREATE INDEX "idx_index_sp_gist_2c0bad" ON "index" USING SPGIST ("sp_gist");
CREATE INDEX "idx_index_hash_cfe6b5" ON "index" USING HASH ("hash");
CREATE INDEX "idx_index_partial_c5be6a" ON "index" USING ("partial") WHERE id = 1;""",
CREATE INDEX "idx_index_partial_c5be6a" ON "index" ("partial") WHERE id = 1;""",
)

async def test_index_safe(self):
Expand All @@ -1126,7 +1126,7 @@ async def test_index_safe(self):
CREATE INDEX IF NOT EXISTS "idx_index_gist_c807bf" ON "index" USING GIST ("gist");
CREATE INDEX IF NOT EXISTS "idx_index_sp_gist_2c0bad" ON "index" USING SPGIST ("sp_gist");
CREATE INDEX IF NOT EXISTS "idx_index_hash_cfe6b5" ON "index" USING HASH ("hash");
CREATE INDEX IF NOT EXISTS "idx_index_partial_c5be6a" ON "index" USING ("partial") WHERE id = 1;""",
CREATE INDEX IF NOT EXISTS "idx_index_partial_c5be6a" ON "index" ("partial") WHERE id = 1;""",
)

async def test_m2m_no_auto_create(self):
Expand Down
27 changes: 11 additions & 16 deletions tests/test_two_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,30 @@ async def asyncTearDown(self) -> None:
await Tortoise._drop_databases()
await super().asyncTearDown()

def build_select_sql(self) -> str:
if isinstance(self.db, OracleClient):
return 'SELECT * FROM "eventtwo"'
return "SELECT * FROM eventtwo"

async def test_two_databases(self):
tournament = await Tournament.create(name="Tournament")
await EventTwo.create(name="Event", tournament_id=tournament.id)

select_sql = self.build_select_sql()
with self.assertRaises(OperationalError):
if isinstance(self.db, OracleClient):
await self.db.execute_query('SELECT * FROM "eventtwo"')
else:
await self.db.execute_query("SELECT * FROM eventtwo")
if isinstance(self.db, OracleClient):
_, results = await self.second_db.execute_query('SELECT * FROM "eventtwo"')
else:
_, results = await self.second_db.execute_query("SELECT * FROM eventtwo")
await self.db.execute_query(select_sql)
_, results = await self.second_db.execute_query(select_sql)
self.assertEqual(dict(results[0]), {"id": 1, "name": "Event", "tournament_id": 1})

async def test_two_databases_relation(self):
tournament = await Tournament.create(name="Tournament")
event = await EventTwo.create(name="Event", tournament_id=tournament.id)

select_sql = self.build_select_sql()
with self.assertRaises(OperationalError):
if isinstance(self.db, OracleClient):
await self.db.execute_query('SELECT * FROM "eventtwo"')
else:
await self.db.execute_query("SELECT * FROM eventtwo")
await self.db.execute_query(select_sql)

if isinstance(self.db, OracleClient):
_, results = await self.second_db.execute_query('SELECT * FROM "eventtwo"')
else:
_, results = await self.second_db.execute_query("SELECT * FROM eventtwo")
_, results = await self.second_db.execute_query(select_sql)
self.assertEqual(dict(results[0]), {"id": 1, "name": "Event", "tournament_id": 1})

teams = []
Expand Down
17 changes: 17 additions & 0 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from tortoise import fields
from tortoise.exceptions import ValidationError
from tortoise.fields import NO_ACTION
from tortoise.indexes import Index
from tortoise.manager import Manager
from tortoise.models import Model
from tortoise.queryset import QuerySet
Expand Down Expand Up @@ -1050,3 +1051,19 @@ class BenchmarkManyFields(Model):
col_text4 = fields.TextField(null=True)
col_decimal4 = fields.DecimalField(12, 8, null=True)
col_json4 = fields.JSONField[dict](null=True)


class ModelWithIndexes(Model):
id = fields.IntField(primary_key=True)
indexed = fields.CharField(max_length=16, index=True)
unique_indexed = fields.CharField(max_length=16, unique=True)
f1 = fields.CharField(max_length=16)
f2 = fields.CharField(max_length=16)
u1 = fields.IntField()
u2 = fields.IntField()

class Meta:
indexes = [
Index(fields=["f1", "f2"]),
]
unique_together = [("u1", "u2")]
17 changes: 17 additions & 0 deletions tests/utils/test_describe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tests.testmodels import (
Event,
JSONFields,
ModelWithIndexes,
Reporter,
SourceFields,
StraightFields,
Expand Down Expand Up @@ -1561,3 +1562,19 @@ def test_describe_model_json_native(self):
"m2m_fields": [],
},
)

def test_describe_indexes_serializable(self):
val = ModelWithIndexes.describe()

self.assertEqual(
val["indexes"],
[{"fields": ["f1", "f2"], "expressions": [], "name": None, "type": "", "extra": ""}],
)

def test_describe_indexes_not_serializable(self):
val = ModelWithIndexes.describe(serializable=False)

self.assertEqual(
val["indexes"],
ModelWithIndexes._meta.indexes,
)
61 changes: 47 additions & 14 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from __future__ import annotations

import re
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Type, Union, cast
from typing import TYPE_CHECKING, Any, Type, cast

from pypika_tortoise.context import DEFAULT_SQL_CONTEXT

from tortoise.exceptions import ConfigurationError
from tortoise.fields import JSONField, TextField, UUIDField
Expand All @@ -23,8 +27,10 @@ class BaseSchemaGenerator:
DIALECT = "sql"
TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
FIELD_TEMPLATE = '"{name}" {type}{nullable}{unique}{primary}{default}{comment}'
INDEX_CREATE_TEMPLATE = 'CREATE INDEX {exists}"{index_name}" ON "{table_name}" ({fields});'
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(" INDEX", " UNIQUE INDEX")
INDEX_CREATE_TEMPLATE = (
'CREATE {index_type}INDEX {exists}"{index_name}" ON "{table_name}" ({fields}){extra};'
)
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace("INDEX", "UNIQUE INDEX")
UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
Expand Down Expand Up @@ -140,7 +146,7 @@ def _make_hash(*args: str, length: int) -> str:
return sha256(";".join(args).encode("utf-8")).hexdigest()[:length]

def _generate_index_name(
self, prefix: str, model: "Union[Type[Model], str]", field_names: list[str]
self, prefix: str, model: "Type[Model] | str", field_names: list[str]
) -> str:
# NOTE: for compatibility, index name should not be longer than 30
# characters (Oracle limit).
Expand All @@ -167,21 +173,33 @@ def _generate_fk_name(
)
return index_name

def _get_index_sql(self, model: "Type[Model]", field_names: list[str], safe: bool) -> str:
def _get_index_sql(
self,
model: "Type[Model]",
field_names: list[str],
safe: bool,
index_name: str | None = None,
index_type: str | None = None,
extra: str | None = None,
) -> str:
return self.INDEX_CREATE_TEMPLATE.format(
exists="IF NOT EXISTS " if safe else "",
index_name=self._generate_index_name("idx", model, field_names),
index_name=index_name or self._generate_index_name("idx", model, field_names),
index_type=f"{index_type} " if index_type else "",
table_name=model._meta.db_table,
fields=", ".join([self.quote(f) for f in field_names]),
extra=f"{extra}" if extra else "",
)

def _get_unique_index_sql(self, exists: str, table_name: str, field_names: list[str]) -> str:
index_name = self._generate_index_name("uidx", table_name, field_names)
return self.UNIQUE_INDEX_CREATE_TEMPLATE.format(
exists=exists,
index_name=index_name,
index_type="",
table_name=table_name,
fields=", ".join([self.quote(f) for f in field_names]),
extra="",
)

def _get_unique_constraint_sql(self, model: "Type[Model]", field_names: list[str]) -> str:
Expand Down Expand Up @@ -324,22 +342,37 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
self._get_unique_constraint_sql(model, unique_together_to_create)
)

# Indexes.
_indexes = [
self._get_index_sql(model, [field_name], safe=safe) for field_name in fields_with_index
]

if model._meta.indexes:
for indexes_list in model._meta.indexes:
if not isinstance(indexes_list, Index):
indexes_to_create = []
for field in indexes_list:
for index in model._meta.indexes:
if not isinstance(index, Index):
fields = []
for field in index:
field_object = model._meta.fields_map[field]
indexes_to_create.append(field_object.source_field or field)
fields.append(field_object.source_field or field)

_indexes.append(self._get_index_sql(model, indexes_to_create, safe=safe))
_indexes.append(self._get_index_sql(model, fields, safe=safe))
else:
_indexes.append(indexes_list.get_sql(self, model, safe))
if index.fields:
fields = [f for f in index.fields]
elif index.expressions:
fields = [
f"({expression.get_sql(DEFAULT_SQL_CONTEXT)})"
for expression in index.expressions
]
else:
raise ConfigurationError(
"At least one field or expression is required to define an index."
)

_indexes.append(
self._get_index_sql(
model, fields, safe=safe, index_type=index.INDEX_TYPE, extra=index.extra
)
)

field_indexes_sqls = [val for val in list(dict.fromkeys(_indexes)) if val]

Expand Down
25 changes: 24 additions & 1 deletion tortoise/backends/base_postgres/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
from typing import TYPE_CHECKING, Any
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Type

from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.converters import encoders
from tortoise.models import Model

if TYPE_CHECKING: # pragma: nocoverage
from .client import BasePostgresClient


class BasePostgresSchemaGenerator(BaseSchemaGenerator):
DIALECT = "postgres"
INDEX_CREATE_TEMPLATE = (
'CREATE INDEX {exists}"{index_name}" ON "{table_name}" {index_type}({fields}){extra};'
)
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace("INDEX", "UNIQUE INDEX")
TABLE_COMMENT_TEMPLATE = "COMMENT ON TABLE \"{table}\" IS '{comment}';"
COLUMN_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table}"."{column}" IS \'{comment}\';'
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}'
Expand Down Expand Up @@ -61,3 +68,19 @@ def _escape_default_value(self, default: Any):
if isinstance(default, bool):
return default
return encoders.get(type(default))(default) # type: ignore

def _get_index_sql(
self,
model: "Type[Model]",
field_names: list[str],
safe: bool,
index_name: str | None = None,
index_type: str | None = None,
extra: str | None = None,
) -> str:
if index_type:
index_type = f"USING {index_type}"

return super()._get_index_sql(
model, field_names, safe, index_name=index_name, index_type=index_type, extra=extra
)
20 changes: 16 additions & 4 deletions tortoise/backends/mssql/schema_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Type

from tortoise.backends.base.schema_generator import BaseSchemaGenerator
Expand Down Expand Up @@ -59,11 +61,21 @@ def _column_default_generator(
def _escape_default_value(self, default: Any):
return encoders.get(type(default))(default) # type: ignore

def _get_index_sql(self, model: "Type[Model]", field_names: list[str], safe: bool) -> str:
return super(MSSQLSchemaGenerator, self)._get_index_sql(model, field_names, False)
def _get_index_sql(
self,
model: "Type[Model]",
field_names: list[str],
safe: bool,
index_name: str | None = None,
index_type: str | None = None,
extra: str | None = None,
) -> str:
return super()._get_index_sql(
model, field_names, False, index_name=index_name, index_type=index_type, extra=extra
)

def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
return super(MSSQLSchemaGenerator, self)._get_table_sql(model, False)
return super()._get_table_sql(model, False)

def _create_fk_string(
self,
Expand Down Expand Up @@ -99,7 +111,7 @@ def _create_string(
) -> str:
if nullable == "":
unique = ""
return super(MSSQLSchemaGenerator, self)._create_string(
return super()._create_string(
db_column=db_column,
field_type=field_type,
nullable=nullable,
Expand Down
Loading

0 comments on commit dda50a2

Please sign in to comment.