Skip to content

Commit

Permalink
Implement ?algorithm=best as an option
Browse files Browse the repository at this point in the history
  • Loading branch information
pudo committed Sep 17, 2023
1 parent 925272a commit f2c56ab
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 20 deletions.
2 changes: 2 additions & 0 deletions yente/data/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,5 @@ class Algorithm(BaseModel):

class AlgorithmResponse(BaseModel):
algorithms: List[Algorithm]
default: str
best: str
6 changes: 5 additions & 1 deletion yente/routers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ async def algorithms() -> AlgorithmResponse:
features=algo.explain(),
)
algorithms.append(desc)
return AlgorithmResponse(algorithms=algorithms)
return AlgorithmResponse(
algorithms=algorithms,
default=settings.DEFAULT_ALGORITHM,
best=settings.BEST_ALGORITHM,
)


@router.post(
Expand Down
14 changes: 4 additions & 10 deletions yente/routers/match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
from typing import Dict, List, Optional
from fastapi import APIRouter, Query, Response, HTTPException
from nomenklatura.matching import get_algorithm

from yente import settings
from yente.logs import get_logger
Expand All @@ -13,8 +12,8 @@
from yente.data.entity import Entity
from yente.util import limit_window
from yente.scoring import score_results
from yente.routers.util import get_dataset
from yente.routers.util import PATH_DATASET, TS_PATTERN, ALGO_LIST
from yente.routers.util import get_dataset, get_algorithm_by_name
from yente.routers.util import PATH_DATASET, TS_PATTERN, ALGO_HELP

log = get_logger(__name__)
router = APIRouter()
Expand Down Expand Up @@ -47,10 +46,7 @@ async def match(
settings.SCORE_CUTOFF,
title="Lower bound of score for results to be returned at all",
),
algorithm: str = Query(
settings.DEFAULT_ALGORITHM,
title=f"Scoring algorithm to use, options: {ALGO_LIST}, best: {settings.BEST_ALGORITHM}", # noqa
),
algorithm: str = Query(settings.DEFAULT_ALGORITHM, title=ALGO_HELP),
exclude_schema: List[str] = Query(
[], title="Remove the given types of entities from results"
),
Expand Down Expand Up @@ -122,9 +118,7 @@ async def match(
"""
ds = await get_dataset(dataset)
limit, _ = limit_window(limit, 0, settings.MATCH_PAGE)
algorithm_type = get_algorithm(algorithm)
if algorithm_type is None:
raise HTTPException(400, detail=f"Unknown algorithm: {algorithm}")
algorithm_type = get_algorithm_by_name(algorithm)

if len(match.queries) > settings.MAX_BATCH:
msg = "Too many queries in one batch (limit: %d)" % settings.MAX_BATCH
Expand Down
13 changes: 6 additions & 7 deletions yente/routers/reconcile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi import HTTPException
from followthemoney import model
from followthemoney.types import registry
from nomenklatura.matching import get_algorithm


from yente import settings
from yente.data.common import ErrorResponse, EntityExample
Expand All @@ -34,8 +34,9 @@
from yente.search.search import get_matchable_schemata
from yente.scoring import score_results
from yente.util import match_prefix, limit_window, typed_url
from yente.routers.util import PATH_DATASET, QUERY_PREFIX, get_dataset
from yente.routers.util import TS_PATTERN, ALGO_LIST
from yente.routers.util import PATH_DATASET, QUERY_PREFIX
from yente.routers.util import TS_PATTERN, ALGO_HELP
from yente.routers.util import get_algorithm_by_name, get_dataset


log = get_logger(__name__)
Expand Down Expand Up @@ -116,7 +117,7 @@ async def reconcile_post(
queries: str = Form(None, description="JSON-encoded reconciliation queries"),
algorithm: str = Query(
settings.BEST_ALGORITHM,
title=f"Scoring algorithm to use, options: {ALGO_LIST}",
title=ALGO_HELP,
),
changed_since: Optional[str] = Query(
None,
Expand Down Expand Up @@ -180,9 +181,7 @@ async def reconcile_query(
proxy = Entity.from_example(example)
query = entity_query(dataset, proxy, fuzzy=False, changed_since=changed_since)
resp = await search_entities(query, limit=limit, offset=offset)
algorithm_ = get_algorithm(algorithm)
if algorithm_ is None:
raise HTTPException(400, detail=f"Unknown algorithm: {algorithm}")
algorithm_ = get_algorithm_by_name(algorithm)
entities = result_entities(resp)
scoreds = [s for s in score_results(algorithm_, proxy, entities, limit=limit)]
results = [FreebaseScoredEntity.from_scored(s) for s in scoreds]
Expand Down
20 changes: 18 additions & 2 deletions yente/routers/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Type
from fastapi import Path, Query
from fastapi import HTTPException
from nomenklatura.matching import ALGORITHMS
from nomenklatura.matching import ALGORITHMS, ScoringAlgorithm, get_algorithm

from yente.data.dataset import Dataset
from yente import settings
from yente.data import get_catalog
from yente.data.dataset import Dataset


PATH_DATASET = Path(
Expand All @@ -13,6 +15,20 @@
QUERY_PREFIX = Query("", min_length=1, description="Search prefix")
TS_PATTERN = r"^\d{4}-\d{2}-\d{2}(T\d{2}(:\d{2}(:\d{2})?)?)?$"
ALGO_LIST = ", ".join([a.NAME for a in ALGORITHMS])
ALGO_HELP = (
f"Scoring algorithm to use, options: {ALGO_LIST} (best: {settings.BEST_ALGORITHM})"
)


def get_algorithm_by_name(name: str) -> Type[ScoringAlgorithm]:
"""Return the scoring algorithm class with the given name."""
name = name.lower().strip()
if name == "best":
name = settings.BEST_ALGORITHM
algorithm = get_algorithm(name)
if algorithm is None:
raise HTTPException(400, detail=f"Invalid algorithm: {name}")
return algorithm


async def get_dataset(name: str) -> Dataset:
Expand Down

0 comments on commit f2c56ab

Please sign in to comment.