diff --git a/contrib/test_bench.py b/contrib/test_bench.py index 3fa77d46..b1baf024 100644 --- a/contrib/test_bench.py +++ b/contrib/test_bench.py @@ -4,10 +4,8 @@ from yente.data.entity import Entity from yente.data.common import EntityExample -from yente.search.base import close_es from yente.search.queries import entity_query from yente.search.search import search_entities, result_entities -from yente.data.entity import Entity from yente.scoring import score_results from yente.routers.util import get_dataset @@ -34,7 +32,7 @@ async def test_example(): print(ent.id, ent.caption, ent.schema.name) algorithm = get_algorithm("name-based") - scored = score_results( + total, scored = score_results( algorithm, entity, ents, @@ -43,11 +41,9 @@ async def test_example(): limit=LIMIT, ) - print("\n\nSCORED RESULTS:") + print("\n\nSCORED RESULTS [%d]:" % total) for res in scored: print(res.id, res.caption, res.schema_, res.score) - await close_es() - asyncio.run(test_example()) diff --git a/yente/routers/match.py b/yente/routers/match.py index b6cb1100..b9653817 100644 --- a/yente/routers/match.py +++ b/yente/routers/match.py @@ -6,10 +6,10 @@ from yente.logs import get_logger from yente.data.common import ErrorResponse from yente.data.common import EntityMatchQuery, EntityMatchResponse, EntityExample -from yente.data.common import EntityMatches +from yente.data.common import EntityMatches, TotalSpec from yente.provider import SearchProvider, get_provider from yente.search.queries import entity_query, FilterDict -from yente.search.search import search_entities, result_entities, result_total +from yente.search.search import search_entities, result_entities from yente.data.entity import Entity from yente.util import limit_window from yente.scoring import score_results @@ -173,7 +173,7 @@ async def match( for (name, entity), resp in zip(entities, results): ents = result_entities(resp) - scored = score_results( + total, scored = score_results( algorithm_type, entity, ents, @@ -182,17 +182,16 @@ async def match( limit=limit, weights=match.weights, ) - total = result_total(resp) log.info( f"/match/{ds.name}", action="match", schema=entity.schema.name, - results=total.value, + results=total, ) responses[name] = EntityMatches( status=200, results=scored, - total=total, + total=TotalSpec(value=total, relation="eq"), query=EntityExample.model_validate(entity.to_dict()), ) response.headers["x-batch-size"] = str(len(responses)) diff --git a/yente/routers/reconcile.py b/yente/routers/reconcile.py index d9e30540..20b58941 100644 --- a/yente/routers/reconcile.py +++ b/yente/routers/reconcile.py @@ -188,13 +188,13 @@ async def reconcile_query( resp = await search_entities(provider, query, limit=limit, offset=offset) algorithm_ = get_algorithm_by_name(algorithm) entities = result_entities(resp) - scoreds = [s for s in score_results(algorithm_, proxy, entities, limit=limit)] + total, scoreds = score_results(algorithm_, proxy, entities, limit=limit) results = [FreebaseScoredEntity.from_scored(s) for s in scoreds] log.info( f"/reconcile/{dataset.name}", action="reconcile", schema=proxy.schema.name, - results=result_total(resp).value, + matches=total, ) return name, FreebaseEntityResult(result=results) diff --git a/yente/scoring.py b/yente/scoring.py index 8fcc2287..a8ef96d6 100644 --- a/yente/scoring.py +++ b/yente/scoring.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Type, Dict +from typing import Iterable, List, Optional, Type, Dict, Tuple from nomenklatura.matching.types import ScoringAlgorithm from yente import settings @@ -14,7 +14,7 @@ def score_results( cutoff: float = 0.0, limit: Optional[int] = None, weights: Dict[str, float] = {}, -) -> List[ScoredEntityResponse]: +) -> Tuple[int, List[ScoredEntityResponse]]: scored: List[ScoredEntityResponse] = [] matches = 0 for proxy in results: @@ -29,4 +29,4 @@ def score_results( scored = sorted(scored, key=lambda r: r.score, reverse=True) if limit is not None: scored = scored[:limit] - return scored + return matches, scored