Skip to content

Commit

Permalink
Fixes a couple of mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
cyrillkuettel committed Jul 2, 2024
1 parent 6b76406 commit efa9ba9
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/privatim/i18n/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
'LocaleNegotiator',
'pluralize',
'translate',
locales,
'locales',
)
13 changes: 5 additions & 8 deletions src/privatim/models/searchable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@


from typing import Iterator, TYPE_CHECKING

from privatim.types import HasSearchableFields

if TYPE_CHECKING:
from sqlalchemy.orm import Session

Expand All @@ -27,18 +30,13 @@ def searchable_text(self) -> str:
)


def searchable_models() -> tuple[type[Base], ...]:
def searchable_models() -> tuple[type[HasSearchableFields], ...]:
"""Retrieve all models inheriting from SearchableMixin."""
model_classes = set()
for _ in Base.metadata.tables.values():
for mapper in Base.registry.mappers:
cls = mapper.class_
if (
inspect.isclass(cls)
and issubclass(cls, SearchableMixin)
and issubclass(cls, Base)
and cls != SearchableMixin
):
if issubclass(cls, SearchableMixin):
model_classes.add(cls)
return tuple(model_classes)

Expand All @@ -62,7 +60,6 @@ def reindex_full_text_search(session: 'Session') -> None:
# todo: remove later
assert len(models) != 0, "No models with searchable fields found"
for model in models:
assert issubclass(model, SearchableMixin)
for locale, language in locales.items():
assert language == 'german' # todo: remove later
if hasattr(model, f'searchable_text_{locale}'):
Expand Down
18 changes: 17 additions & 1 deletion src/privatim/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Type, Union, Iterable

from sqlalchemy import ColumnElement

from privatim.models import SearchableMixin
from privatim.orm import Base
from privatim.orm.meta import UUIDStrPK

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

Expand Down Expand Up @@ -77,3 +84,12 @@ class LaxFileDict(TypedDict):

class Callback(Protocol[_Tco]):
def __call__(self, context: Any, request: IRequest) -> _Tco: ...


class HasSearchableFields(Protocol):
id: UUIDStrPK

@classmethod
def searchable_fields(cls) -> Iterable[ColumnElement[Any]]:
...

88 changes: 53 additions & 35 deletions src/privatim/views/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,43 @@
from privatim.i18n import locales, translate
from sqlalchemy import or_

from privatim.models.searchable import searchable_models
from privatim.models.comment import Comment


from typing import TYPE_CHECKING, List, NamedTuple, TypeVar
from typing import (
TYPE_CHECKING,
List,
NamedTuple,
Dict,
Any,
Optional,
Iterator,
)

from privatim.models.searchable import searchable_models

if TYPE_CHECKING:
from pyramid.interfaces import IRequest
from sqlalchemy.orm import Session

from privatim.orm.meta import UUIDStrPK
from privatim.types import HasSearchableFields
from builtins import type as type_t
from privatim.orm import Base
from privatim.models.searchable import SearchableMixin


class SearchResult(NamedTuple):
id: int
""" headlines are key value pairs of fields on various models that matched
the search query."""
headlines: dict[str, str]
headlines: Dict[str, str]
type: str
model: 'Model | None' # Note that this is not loaded by default for
model: 'Optional[type_t[HasSearchableFields]]' # Note that this is not loaded by
# default for
# performance reasons.


Model = TypeVar('Model')


class SearchCollection:
"""A class for searching the database for a given term.
Expand All @@ -54,26 +66,27 @@ class SearchCollection:
"""

def __init__(self, term: str, session: 'Session', language='de_CH'):
self.lang = locales[language]
self.session = session
self.web_search = term
def __init__(self, term: str, session: 'Session', language: str = 'de_CH'):
self.lang: str = locales[language]
self.session: 'Session' = session
self.web_search: str = term
self.ts_query = func.websearch_to_tsquery(self.lang, self.web_search)
self.results: List[SearchResult] = []
self.models = [Consultation, Meeting, Comment]

def do_search(self) -> None:
for model in searchable_models():
self.results.extend(self.search_model(model, self.ts_query))

# Fetch Comment objects after the search
comment_ids = [
comment_ids: List[int] = [
result.id for result in self.results if result.type == 'Comment'
]
if comment_ids:
stmt = select(Comment).filter(Comment.id.in_(comment_ids))
comments = self.session.scalars(stmt).all()
comment_dict = {comment.id: comment for comment in comments}
comment_dict: dict[UUIDStrPK, Comment] = {
comment.id: comment for comment in comments
}

# Update results with fetched Comment objects
self.results = [
Expand All @@ -85,16 +98,18 @@ def do_search(self) -> None:
for result in self.results
]

def search_model(self, model: type[Model], ts_query) -> List[SearchResult]:
def search_model(
self, model: 'type[HasSearchableFields]', ts_query: Any
) -> List[SearchResult]:
query = self.build_query(model, ts_query)
raw_results = self.session.execute(query).all()
return self.process_results(raw_results, model)

def build_query(self, model: type[Model], ts_query):
def build_query(self, model: 'type[HasSearchableFields]', ts_query: Any) -> Any:

headline_expression = self.generate_headlines(model, ts_query)

select_fields = [
select_fields: List[Any] = [
model.id,
*headline_expression,
cast(literal(model.__name__), String).label('type'), # noqa: MS001
Expand All @@ -104,7 +119,9 @@ def build_query(self, model: type[Model], ts_query):
or_(*self.term_filter_text_for_model(model, self.lang))
)

def generate_headlines(self, model: type[Model], ts_query):
def generate_headlines(
self, model:'type[HasSearchableFields]', ts_query: Any
) -> List[ColumnElement]:
"""
Generate headline expressions for all searchable fields of the model.
Expand All @@ -114,7 +131,7 @@ def generate_headlines(self, model: type[Model], ts_query):
field.
Args:
model (type[Model]): The model class to generate headlines for.
model (type['type[HasSearchableFields]']): The model class to generate headlines for.
ts_query: The text search query to use for highlighting.
Returns:
Expand All @@ -139,15 +156,15 @@ def generate_headlines(self, model: type[Model], ts_query):
]

def term_filter_text_for_model(
self, model: type[Model], language: str
self, model:'type[HasSearchableFields]', language: str
) -> List[ColumnElement[bool]]:
def match(
column: ColumnElement[str],
column: ColumnElement[str],
) -> ColumnElement[bool]:
return column.op('@@')(self.ts_query)

def match_convert(
column: ColumnElement[str], language: str
column: ColumnElement[str], language: str
) -> ColumnElement[bool]:
return match(func.to_tsvector(language, column))

Expand All @@ -157,11 +174,11 @@ def match_convert(
]

def process_results(
self, raw_results, model: type[Model]
self, raw_results: List[Any], model:'type[HasSearchableFields]'
) -> List[SearchResult]:
processed_results = []
processed_results: List[SearchResult] = []
for result in raw_results:
headlines = {
headlines: Dict[str, str] = {
translate(field.name.capitalize()): value
for field in model.searchable_fields()
if (value := getattr(result, field.name, None)) is not None
Expand All @@ -171,25 +188,25 @@ def process_results(
id=result.id,
headlines=headlines,
type=result.type,
model=None
model=None,
)
)
return processed_results

def __len__(self):
def __len__(self) -> int:
return len(self.results)

def __iter__(self):
def __iter__(self) -> Iterator[SearchResult]:
return iter(self.results)

def __getitem__(self, index):
def __getitem__(self, index: int) -> SearchResult:
return self.results[index]

def __repr__(self) -> str:
return f'<SearchResultCollection {self.results[:4]}>'


def search(request: 'IRequest'):
def search(request: 'IRequest') -> Dict[str, Any]:
"""
Handle search form submission using POST/Redirect/GET pattern.
Expand All @@ -199,20 +216,21 @@ def search(request: 'IRequest'):
"""
session = request.dbsession
form = SearchForm(request)
session: 'Session' = request.dbsession
form: SearchForm = SearchForm(request)
if request.method == 'POST' and form.validate():
query = form.term.data
return HTTPFound(
location=request.route_url(
'search',
_query={'q': query},
_query={'q': form.term.data},
)
)

query = request.GET.get('q')
if query:
result_collection = SearchCollection(term=query, session=session)
result_collection: SearchCollection = SearchCollection(
term=query, session=session
)
result_collection.do_search()
return {
'search_results': result_collection.results,
Expand Down

0 comments on commit efa9ba9

Please sign in to comment.