Skip to content

Commit

Permalink
Ensure get_all query returns all results when datasets have many tags…
Browse files Browse the repository at this point in the history
… or formats (#321)
  • Loading branch information
florimondmanca authored Jul 1, 2022
1 parent 450258a commit fb27c4e
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 22 deletions.
25 changes: 13 additions & 12 deletions server/infrastructure/datasets/queries/get_all.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from typing import List

from sqlalchemy import desc, func, select, text
from sqlalchemy.engine import Row
from sqlalchemy.orm import contains_eager
from sqlalchemy.sql import ColumnElement
from sqlalchemy.orm import contains_eager, selectinload

from server.domain.datasets.repositories import DatasetGetAllExtras
from server.domain.datasets.specifications import DatasetSpec
Expand All @@ -18,7 +15,8 @@

class GetAllQuery:
def __init__(self, spec: DatasetSpec) -> None:
columns: List[ColumnElement] = []
columns = []
joinclauses = []
whereclauses = []
orderbyclauses = []

Expand Down Expand Up @@ -70,23 +68,26 @@ def __init__(self, spec: DatasetSpec) -> None:
whereclauses.append(DatasetModel.service.in_(services))

if (formats := spec.format__in) is not None:
joinclauses.append((DatasetModel.formats, {"isouter": True}))
whereclauses.append(DataFormatModel.name.in_(formats))

if (technical_sources := spec.technical_source__in) is not None:
whereclauses.append(DatasetModel.technical_source.in_(technical_sources))

if (tag_ids := spec.tag__id__in) is not None:
joinclauses.append((DatasetModel.tags, {"isouter": True}))
whereclauses.append(TagModel.id.in_(tag_ids))

stmt = select(DatasetModel, *columns).join(DatasetModel.catalog_record)

for target, kwargs in joinclauses:
stmt = stmt.join(target, **kwargs)

self.statement = (
select(DatasetModel, *columns)
.join(DatasetModel.catalog_record)
.join(DatasetModel.formats, isouter=True)
.join(DatasetModel.tags, isouter=True)
.options(
stmt.options(
contains_eager(DatasetModel.catalog_record),
contains_eager(DatasetModel.formats),
contains_eager(DatasetModel.tags),
selectinload(DatasetModel.formats),
selectinload(DatasetModel.tags),
)
.where(*whereclauses)
.order_by(*orderbyclauses, CatalogRecordModel.created_at.desc())
Expand Down
2 changes: 1 addition & 1 deletion server/infrastructure/datasets/repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def get_all(
result = await session.stream(stmt.limit(limit).offset(offset))
items = [
(make_entity(query.instance(row)), query.extras(row))
async for row in result.unique()
async for row in result
]
return items, count

Expand Down
12 changes: 9 additions & 3 deletions tests/api/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from typing import Any, List

import httpx
Expand Down Expand Up @@ -112,6 +113,7 @@ async def test_dataset_crud(
update_frequency=UpdateFrequency.WEEKLY,
last_updated_at=last_updated_at,
published_url=None,
tag_ids=[],
)
)

Expand Down Expand Up @@ -200,11 +202,14 @@ async def test_delete_not_admin(
assert response.status_code == 403


async def add_dataset_pagination_corpus(n: int) -> None:
async def add_dataset_pagination_corpus(n: int, tags: list) -> None:
bus = resolve(MessageBus)

for k in range(1, n + 1):
await bus.execute(CreateDatasetFactory.build(title=f"Dataset {k}"))
tag_ids = [tag.id for tag in random.choices(tags, k=random.randint(0, 2))]
await bus.execute(
CreateDatasetFactory.build(title=f"Dataset {k}", tag_ids=tag_ids)
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -251,12 +256,13 @@ async def add_dataset_pagination_corpus(n: int) -> None:
async def test_dataset_pagination(
client: httpx.AsyncClient,
temp_user: TestUser,
tags: list,
params: dict,
expected_total_pages: int,
expected_num_items: int,
expected_dataset_titles: List[str],
) -> None:
await add_dataset_pagination_corpus(n=13)
await add_dataset_pagination_corpus(n=13, tags=tags)

response = await client.get("/datasets/", params=params, auth=temp_user.auth)
assert response.status_code == 200
Expand Down
20 changes: 19 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import os
from typing import TYPE_CHECKING, AsyncIterator, Iterator
from typing import TYPE_CHECKING, AsyncIterator, Iterator, List

import httpx
import pytest
Expand All @@ -11,11 +11,14 @@
from sqlalchemy_utils import create_database, database_exists, drop_database

from server.application.datasets.queries import GetAllDatasets
from server.application.tags.queries import GetAllTags
from server.application.tags.views import TagView
from server.config import Settings
from server.config.di import bootstrap, resolve
from server.domain.auth.entities import UserRole
from server.infrastructure.database import Database
from server.seedwork.application.messages import MessageBus
from tests.factories import CreateTagFactory

from .helpers import TestUser, create_client, create_test_user

Expand Down Expand Up @@ -61,6 +64,21 @@ async def warmup_db() -> None:
await bus.execute(GetAllDatasets())


@pytest_asyncio.fixture
async def tags(transaction: None) -> List[TagView]:
bus = resolve(MessageBus)

for name in [
"Monument historique",
"Lieu culturel",
"Musée de France",
"Statistiques",
]:
await bus.execute(CreateTagFactory.build(name=name))

return await bus.execute(GetAllTags())


@pytest.fixture(scope="session")
def event_loop() -> Iterator[asyncio.AbstractEventLoop]:
loop = asyncio.new_event_loop()
Expand Down
7 changes: 2 additions & 5 deletions tests/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from server.application.datasets.commands import CreateDataset, UpdateDataset
from server.application.tags.commands import CreateTag
from server.domain.common import datetime as dtutil
from server.domain.datasets.entities import DataFormat

T = TypeVar("T", bound=BaseModel)

Expand Down Expand Up @@ -41,15 +42,11 @@ class CreateUserFactory(Factory[CreateUser]):
class CreateTagFactory(Factory[CreateTag]):
__model__ = CreateTag

name = Use(
random.choice,
("Monument historique", "Lieu culturel", "Musée de France", "Statistiques"),
)


class CreateDatasetFactory(Factory[CreateDataset]):
__model__ = CreateDataset

formats = Use(lambda: random.choices(list(DataFormat), k=random.randint(1, 3)))
tag_ids = Use(lambda: [])


Expand Down

0 comments on commit fb27c4e

Please sign in to comment.