Skip to content

Commit

Permalink
Check model fields on filtering methods of queryset types (#2277)
Browse files Browse the repository at this point in the history
  • Loading branch information
flaeppe authored Jul 26, 2024
1 parent 7dd81cc commit a28717d
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 67 deletions.
12 changes: 1 addition & 11 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
DynamicClassDefContext,
FunctionContext,
MethodContext,
SemanticAnalyzerPluginInterface,
)
from mypy.semanal import SemanticAnalyzer
from mypy.types import AnyType, Instance, LiteralType, NoneTyp, TupleType, TypedDictType, TypeOfAny, UnionType
Expand Down Expand Up @@ -63,9 +62,7 @@ def get_django_metadata(model_info: TypeInfo) -> DjangoTypeMetadata:
return cast(DjangoTypeMetadata, model_info.metadata.setdefault("django", {}))


def get_django_metadata_bases(
model_info: TypeInfo, key: Literal["baseform_bases", "manager_bases", "queryset_bases"]
) -> Dict[str, int]:
def get_django_metadata_bases(model_info: TypeInfo, key: Literal["baseform_bases", "queryset_bases"]) -> Dict[str, int]:
return get_django_metadata(model_info).setdefault(key, cast(Dict[str, int], {}))


Expand Down Expand Up @@ -422,13 +419,6 @@ def add_new_sym_for_info(
info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True, no_serialize=no_serialize)


def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
if sym is not None and isinstance(sym.node, TypeInfo):
bases = get_django_metadata_bases(sym.node, "manager_bases")
bases[fullname] = 1


def is_abstract_model(model: TypeInfo) -> bool:
if model.fullname in fullnames.DJANGO_ABSTRACT_MODELS:
return True
Expand Down
74 changes: 25 additions & 49 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
import sys
from functools import partial
from functools import cached_property, partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from mypy.build import PRI_MED, PRI_MYPY
Expand All @@ -19,7 +19,6 @@
)
from mypy.types import Type as MypyType

import mypy_django_plugin.transformers.orm_lookups
from mypy_django_plugin.config import DjangoPluginConfig
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
Expand All @@ -31,6 +30,7 @@
manytomany,
manytoone,
meta,
orm_lookups,
querysets,
request,
settings,
Expand Down Expand Up @@ -60,10 +60,6 @@ def transform_form_class(ctx: ClassDefContext) -> None:
forms.make_meta_nested_class_inherit_from_any(ctx)


def add_new_manager_base_hook(ctx: ClassDefContext) -> None:
helpers.add_new_manager_base(ctx.api, ctx.cls.fullname)


class NewSemanalDjangoPlugin(Plugin):
def __init__(self, options: Options) -> None:
super().__init__(options)
Expand All @@ -83,15 +79,6 @@ def _get_current_queryset_bases(self) -> Dict[str, int]:
else:
return {}

def _get_current_manager_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
bases = helpers.get_django_metadata_bases(model_sym.node, "manager_bases")
bases[fullnames.MANAGER_CLASS_FULLNAME] = 1
return bases
else:
return {}

def _get_current_form_bases(self) -> Dict[str, int]:
model_sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME)
if model_sym is not None and isinstance(model_sym.node, TypeInfo):
Expand Down Expand Up @@ -165,10 +152,6 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
if fullname == "django.contrib.auth.get_user_model":
return partial(settings.get_user_model_hook, django_context=self.django_context)

manager_bases = self._get_current_manager_bases()
if fullname in manager_bases:
return querysets.determine_proper_manager_type

info = self._get_typeinfo_or_none(fullname)
if info:
if info.has_base(fullnames.FIELD_FULLNAME):
Expand All @@ -177,8 +160,26 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
if helpers.is_model_type(info):
return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)

if info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
return querysets.determine_proper_manager_type

return None

@cached_property
def manager_and_queryset_method_hooks(self) -> Dict[str, Callable[[MethodContext], MypyType]]:
typecheck_filtering_method = partial(orm_lookups.typecheck_queryset_filter, django_context=self.django_context)
return {
"values": partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context),
"values_list": partial(
querysets.extract_proper_type_queryset_values_list, django_context=self.django_context
),
"annotate": partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context),
"create": partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context),
"filter": typecheck_filtering_method,
"get": typecheck_filtering_method,
"exclude": typecheck_filtering_method,
}

def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
class_fullname, _, method_name = fullname.rpartition(".")
# Methods called very often -- short circuit for minor speed up
Expand Down Expand Up @@ -208,38 +209,17 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M
}
return hooks.get(class_fullname)

manager_classes = self._get_current_manager_bases()

if method_name == "values":
if method_name in self.manager_and_queryset_method_hooks:
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)

elif method_name == "values_list":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)

elif method_name == "annotate":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME) or class_fullname in manager_classes:
return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context)

if info and helpers.has_any_of_bases(
info, [fullnames.QUERYSET_CLASS_FULLNAME, fullnames.MANAGER_CLASS_FULLNAME]
):
return self.manager_and_queryset_method_hooks[method_name]
elif method_name == "get_field":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME):
return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context)

elif method_name == "create":
# We need `BASE_MANAGER_CLASS_FULLNAME` to check abstract models.
if class_fullname in manager_classes or class_fullname == fullnames.BASE_MANAGER_CLASS_FULLNAME:
return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context)
elif method_name in {"filter", "get", "exclude"} and class_fullname in manager_classes:
return partial(
mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter,
django_context=self.django_context,
)

return None

def get_customize_class_mro_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
Expand All @@ -262,10 +242,6 @@ def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefConte
if sym is not None and isinstance(sym.node, TypeInfo) and helpers.is_model_type(sym.node):
return partial(process_model_class, django_context=self.django_context)

# Base class is a Manager class definition
if fullname in self._get_current_manager_bases():
return add_new_manager_base_hook

# Base class is a Form class definition
if fullname in self._get_current_form_bases():
return transform_form_class
Expand Down
6 changes: 0 additions & 6 deletions mypy_django_plugin/transformers/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte
ctx.api.defer()
return

# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)


def register_dynamically_created_manager(fullname: str, manager_name: str, manager_base: TypeInfo) -> None:
manager_base.metadata.setdefault("from_queryset_managers", {})
Expand Down Expand Up @@ -558,9 +555,6 @@ def create_new_manager_class_from_as_manager_method(ctx: DynamicClassDefContext)
manager_base=manager_base,
)

# So that the plugin will reparameterize the manager when it is constructed inside of a Model definition
helpers.add_new_manager_base(semanal_api, new_manager_info.fullname)

# Whenever `<QuerySet>.as_manager()` isn't called at class level, we want to ensure
# that the variable is an instance of our generated manager. Instead of the return
# value of `.as_manager()`. Though model argument is populated as `Any`.
Expand Down
1 change: 0 additions & 1 deletion mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,6 @@ def create_many_related_manager(self, model: Instance) -> None:
helpers.set_many_to_many_manager_info(
to=model.type, derived_from="_default_manager", manager_info=related_manager_info
)
helpers.add_new_manager_base(self.api, related_manager_info.fullname)


class MetaclassAdjustments(ModelClassInitializer):
Expand Down
17 changes: 17 additions & 0 deletions tests/typecheck/managers/test_managers.yml
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,20 @@
def get_instance(self) -> int:
pass
objects = MyManager()
- case: test_typechecks_filter_methods_of_queryset_type
main: |
from myapp.models import MyModel
MyModel.objects.filter(id=1).filter(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc]
MyModel.objects.filter(id=1).get(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc]
MyModel.objects.filter(id=1).exclude(invalid=1) # E: Cannot resolve keyword 'invalid' into field. Choices are: id [misc]
MyModel.objects.filter(id=1).create(invalid=1) # E: Unexpected attribute "invalid" for model "MyModel" [misc]
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class MyModel(models.Model): ...

0 comments on commit a28717d

Please sign in to comment.