Skip to content

Commit

Permalink
prevent the creation of embedded models
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Jan 21, 2025
1 parent d85d5a5 commit 05a3ee1
Show file tree
Hide file tree
Showing 16 changed files with 236 additions and 23 deletions.
11 changes: 11 additions & 0 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@ def __init__(self, embedded_model, *args, **kwargs):
super().__init__(*args, **kwargs)

def check(self, **kwargs):
from ..models import EmbeddedModel

errors = super().check(**kwargs)
if not issubclass(self.embedded_model, EmbeddedModel):
return [
checks.Error(
"Embedded models must be a subclass of "
"django_mongodb_backend.models.EmbeddedModel.",
obj=self,
id="django_mongodb_backend.embedded_model.E002",
)
]
for field in self.embedded_model._meta.fields:
if field.remote_field:
errors.append(
Expand Down
30 changes: 30 additions & 0 deletions django_mongodb_backend/managers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,37 @@
from django.db import NotSupportedError
from django.db.models.manager import BaseManager

from .queryset import MongoQuerySet


class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
pass


class EmbeddedModelManager(BaseManager):
"""
Prevent all queryset operations on embedded models since they don't have
their own collection.
Raise a helpful error message for some basic QuerySet methods. Subclassing
BaseManager means that other methods raise, e.g. AttributeError:
'EmbeddedModelManager' object has no attribute 'update_or_create'".
"""

def all(self):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def get(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def filter(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def create(self, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be created.")

def update(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be updated.")

def delete(self):
raise NotSupportedError("EmbeddedModels cannot be deleted.")
16 changes: 16 additions & 0 deletions django_mongodb_backend/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import NotSupportedError, models

from .managers import EmbeddedModelManager


class EmbeddedModel(models.Model):
objects = EmbeddedModelManager()

class Meta:
abstract = True

def delete(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be deleted.")

def save(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be saved.")
30 changes: 30 additions & 0 deletions django_mongodb_backend/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@
from .utils import OperationCollector


def ignore_embedded_models(func):
"""
Make a SchemaEditor method a no-op if model is an EmbeddedModel (unless
parent_model isn't None, in which case this is a valid recursive operation
such as adding an index on an embedded model's field).
"""

def wrapper(self, model, *args, **kwargs):
parent_model = kwargs.get("parent_model")
from .models import EmbeddedModel

if issubclass(model, EmbeddedModel) and parent_model is None:
return
func(self, model, *args, **kwargs)

return wrapper


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def get_collection(self, name):
if self.collect_sql:
Expand All @@ -22,6 +40,7 @@ def get_database(self):
return self.connection.get_database()

@wrap_database_errors
@ignore_embedded_models
def create_model(self, model):
self.get_database().create_collection(model._meta.db_table)
self._create_model_indexes(model)
Expand Down Expand Up @@ -75,13 +94,15 @@ def _create_model_indexes(self, model, column_prefix="", parent_model=None):
for index in model._meta.indexes:
self.add_index(model, index, column_prefix=column_prefix, parent_model=parent_model)

@ignore_embedded_models
def delete_model(self, model):
# Delete implicit M2m tables.
for field in model._meta.local_many_to_many:
if field.remote_field.through._meta.auto_created:
self.delete_model(field.remote_field.through)
self.get_collection(model._meta.db_table).drop()

@ignore_embedded_models
def add_field(self, model, field):
# Create implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
Expand All @@ -103,6 +124,7 @@ def add_field(self, model, field):
elif self._field_should_have_unique(field):
self._add_field_unique(model, field)

@ignore_embedded_models
def _alter_field(
self,
model,
Expand Down Expand Up @@ -149,6 +171,7 @@ def _alter_field(
if not old_field_unique and new_field_unique:
self._add_field_unique(model, new_field)

@ignore_embedded_models
def remove_field(self, model, field):
# Remove implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
Expand Down Expand Up @@ -210,6 +233,7 @@ def _remove_model_indexes(self, model, column_prefix="", parent_model=None):
for index in model._meta.indexes:
self.remove_index(parent_model or model, index)

@ignore_embedded_models
def alter_index_together(self, model, old_index_together, new_index_together, column_prefix=""):
olds = {tuple(fields) for fields in old_index_together}
news = {tuple(fields) for fields in new_index_together}
Expand All @@ -222,6 +246,7 @@ def alter_index_together(self, model, old_index_together, new_index_together, co
for field_names in news.difference(olds):
self._add_composed_index(model, field_names, column_prefix=column_prefix)

@ignore_embedded_models
def alter_unique_together(
self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None
):
Expand Down Expand Up @@ -249,6 +274,7 @@ def alter_unique_together(
model, constraint, parent_model=parent_model, column_prefix=column_prefix
)

@ignore_embedded_models
def add_index(
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
):
Expand Down Expand Up @@ -302,6 +328,7 @@ def _add_field_index(self, model, field, *, column_prefix=""):
index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column])
self.add_index(model, index, field=field, column_prefix=column_prefix)

@ignore_embedded_models
def remove_index(self, model, index):
if index.contains_expressions:
return
Expand Down Expand Up @@ -355,6 +382,7 @@ def _remove_field_index(self, model, field, column_prefix=""):
)
collection.drop_index(index_names[0])

@ignore_embedded_models
def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
Expand Down Expand Up @@ -384,6 +412,7 @@ def _add_field_unique(self, model, field, column_prefix=""):
constraint = UniqueConstraint(fields=[field.name], name=name)
self.add_constraint(model, constraint, field=field, column_prefix=column_prefix)

@ignore_embedded_models
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
Expand Down Expand Up @@ -417,6 +446,7 @@ def _remove_field_unique(self, model, field, column_prefix=""):
)
self.get_collection(model._meta.db_table).drop_index(constraint_names[0])

@ignore_embedded_models
def alter_db_table(self, model, old_db_table, new_db_table):
if old_db_table == new_db_table:
return
Expand Down
3 changes: 2 additions & 1 deletion docs/source/embedded-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ The basics
Let's consider this example::

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

class Customer(models.Model):
name = models.CharField(...)
address = EmbeddedModelField("Address")
...

class Address(models.Model):
class Address(EmbeddedModel):
...
city = models.CharField(...)

Expand Down
12 changes: 8 additions & 4 deletions docs/source/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,11 @@ Stores a model of type ``embedded_model``.

This is a required argument.

Specifies the model class to embed. It can be either a concrete model
class or a :ref:`lazy reference <lazy-relationships>` to a model class.
Specifies the model class to embed. It must be a subclass of
:class:`django_mongodb_backend.models.EmbeddedModel`.

It can be either a concrete model class or a :ref:`lazy reference
<lazy-relationships>` to a model class.

The embedded model cannot have relational fields
(:class:`~django.db.models.ForeignKey`,
Expand All @@ -234,11 +237,12 @@ Stores a model of type ``embedded_model``.

from django.db import models
from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

class Address(models.Model):
class Address(EmbeddedModel):
...

class Author(models.Model):
class Author(EmbeddedModel):
address = EmbeddedModelField(Address)

class Book(models.Model):
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ django-mongodb-backend 5.0.x documentation
fields
querysets
forms
models
embedded-models

Indices and tables
Expand Down
15 changes: 15 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Model reference
===============

.. module:: django_mongodb_backend.models

One MongoDB-specific model is available in ``django_mongodb_backend.models``.

.. class:: EmbeddedModel

An abstract model which all :doc:`embedded models <embedded-models>` must
subclass.

Since these models are not stored in their own collection, they do not have
any of the normal ``QuerySet`` methods (``all()``, ``filter()``, ``delete()``,
etc.) You also cannot call ``Model.save()`` and ``delete()`` on them.
7 changes: 4 additions & 3 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db import models

from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField
from django_mongodb_backend.models import EmbeddedModel


# ObjectIdField
Expand Down Expand Up @@ -98,19 +99,19 @@ class Holder(models.Model):
data = EmbeddedModelField("Data", null=True, blank=True)


class Data(models.Model):
class Data(EmbeddedModel):
integer = models.IntegerField(db_column="custom_column")
auto_now = models.DateTimeField(auto_now=True)
auto_now_add = models.DateTimeField(auto_now_add=True)


class Address(models.Model):
class Address(EmbeddedModel):
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
zip_code = models.IntegerField(db_index=True)


class Author(models.Model):
class Author(EmbeddedModel):
name = models.CharField(max_length=10)
age = models.IntegerField()
address = EmbeddedModelField(Address)
Expand Down
19 changes: 18 additions & 1 deletion tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.test.utils import isolate_apps

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

from .models import (
Address,
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_nested(self):
@isolate_apps("model_fields_")
class CheckTests(SimpleTestCase):
def test_no_relational_fields(self):
class Target(models.Model):
class Target(EmbeddedModel):
key = models.ForeignKey("MyModel", models.CASCADE)

class MyModel(models.Model):
Expand All @@ -121,3 +122,19 @@ class MyModel(models.Model):
self.assertEqual(
msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)."
)

def test_embedded_model_subclass(self):
class Target(models.Model):
pass

class MyModel(models.Model):
field = EmbeddedModelField(Target)

errors = MyModel().check()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E002")
msg = errors[0].msg
self.assertEqual(
msg,
"Embedded models must be a subclass of django_mongodb_backend.models.EmbeddedModel.",
)
8 changes: 2 additions & 6 deletions tests/model_forms_/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from django.db import models

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel


class Address(models.Model):
class Address(EmbeddedModel):
po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box")
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
Expand All @@ -15,8 +16,3 @@ class Author(models.Model):
age = models.IntegerField()
address = EmbeddedModelField(Address)
billing_address = EmbeddedModelField(Address, blank=True, null=True)


class Book(models.Model):
name = models.CharField(max_length=100)
author = EmbeddedModelField(Author)
Empty file added tests/models_/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions tests/models_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django_mongodb_backend.models import EmbeddedModel


class Embed(EmbeddedModel):
pass
Loading

0 comments on commit 05a3ee1

Please sign in to comment.