Skip to content

Commit

Permalink
facility API, simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
hiddewie committed Jul 21, 2024
1 parent a0ab3d8 commit ab3b767
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 115 deletions.
23 changes: 17 additions & 6 deletions api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,37 @@ async def lifespan(app):
lifespan=lifespan,
)

DEFAULT_LIMIT = 20
MIN_LIMIT = 1
MAX_LIMIT = 200


@app.get("/api/status")
async def status():
api = StatusAPI()
return await api({})
return await api()


@app.get("/api/facility")
async def facility():
async def facility(
q: Annotated[str | None, Query()] = None,
name: Annotated[str | None, Query()] = None,
ref: Annotated[str | None, Query()] = None,
uic_ref: Annotated[str | None, Query()] = None,
limit: Annotated[int, Query(default=DEFAULT_LIMIT, ge=MIN_LIMIT, le=MAX_LIMIT)] = DEFAULT_LIMIT,
):
api = FacilityAPI(app.state.database)
return await api({})
return await api(q=q, name=name, ref=ref, uic_ref=uic_ref, limit=limit)


@app.get("/api/milestone")
async def milestone(
ref: Annotated[str | None, Query()] = None,
position: Annotated[str | None, Query()] = None,
ref: Annotated[str, Query()],
position: Annotated[float, Query()],
limit: Annotated[int | None, Query(default=DEFAULT_LIMIT, ge=MIN_LIMIT, le=MAX_LIMIT)] = DEFAULT_LIMIT,
):
api = MilestoneAPI(app.state.database)
return await api({'ref': ref, 'position': position})
return await api(ref=ref, position=position, limit=limit)

#
# def connect_db():
Expand Down
15 changes: 0 additions & 15 deletions api/openrailwaymap_api/abstract_api.py

This file was deleted.

149 changes: 75 additions & 74 deletions api/openrailwaymap_api/facility_api.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,125 @@
# SPDX-License-Identifier: GPL-2.0-or-later
from openrailwaymap_api.abstract_api import AbstractAPI

from fastapi import HTTPException
from starlette.status import HTTP_400_BAD_REQUEST

from api import MAX_LIMIT

QUERY_PARAMETERS = ['q', 'name', 'ref', 'uic_ref']

class FacilityAPI(AbstractAPI):
def __init__(self, db_conn):
self.db_conn = db_conn
self.search_args = ['q', 'name', 'ref', 'uic_ref']
self.data = []
self.status_code = 200
self.limit = 20
class FacilityAPI:
def __init__(self, database):
self.database = database

def eliminate_duplicates(self, data):
data.sort(key=lambda k: k['osm_id'])
i = 1
while i < len(data):
if data[i]['osm_id'] == data[i-1]['osm_id']:
if data[i]['osm_id'] == data[i - 1]['osm_id']:
data.pop(i)
i += 1
if len(data) > self.limit:
return data[:self.limit]
return data

async def __call__(self, args):
async def __call__(self, *, q, name, ref, uic_ref, limit):
# Validate search arguments
search_args_count = 0
for search_arg in self.search_args:
if search_arg in args and args[search_arg]:
for search_arg in [q, name, ref, uic_ref]:
if search_arg:
search_args_count += 1
if search_args_count > 1:
args = ', '.join(self.search_args)
args = ', '.join(QUERY_PARAMETERS)
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'multiple_query_args', 'error': 'More than one argument with a search term provided.', 'detail': f'Provide only one of the following arguments: {args}'}
{'type': 'multiple_query_args', 'error': 'More than one argument with a search term provided.', 'detail': f'Provide only one of the following query parameters: {args}'}
)
elif search_args_count == 0:
args = ', '.join(self.search_args)
args = ', '.join(QUERY_PARAMETERS)
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'no_query_arg', 'error': 'No argument with a search term provided.', 'detail': f'Provide one of the following arguments: {args}'}
{'type': 'no_query_arg', 'error': 'No argument with a search term provided.', 'detail': f'Provide one of the following query parameters: {args}'}
)
if 'limit' in args:
if limit is not None:
try:
self.limit = int(args['limit'])
self.limit = int(limit)
except ValueError:
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'limit_not_integer', 'error': 'Invalid parameter value provided for parameter "limit".', 'detail': 'The provided limit cannot be parsed as an integer value.'}
)
if self.limit > self.MAX_LIMIT:
if self.limit > MAX_LIMIT:
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'limit_too_high', 'error': 'Invalid parameter value provided for parameter "limit".', 'detail': 'Limit is too high. Please set up your own instance to query everything.'}
)
if args.get('name'):
return self.search_by_name(args['name'])
if args.get('ref'):
return self.search_by_ref(args['ref'])
if args.get('uic_ref'):
return self.search_by_uic_ref(args['uic_ref'])
if args.get('q'):
return self.eliminate_duplicates(self.search_by_name(args['q']) + self.search_by_ref(args['q']) + self.search_by_uic_ref(args['q']))
if name:
return await self.search_by_name(name)
if ref:
return await self.search_by_ref(ref)
if uic_ref:
return await self.search_by_uic_ref(uic_ref)
if q:
return self.eliminate_duplicates((await self.search_by_name(q)) + (await self.search_by_ref(q)) + (await self.search_by_uic_ref(q)))

def query_has_no_wildcards(self, q):
if '%' in q or '_' in q:
return False
return True

def search_by_name(self, q):
async def search_by_name(self, q):
if not self.query_has_no_wildcards(q):
self.status_code = 400
return {'type': 'wildcard_in_query', 'error': 'Wildcard in query.', 'detail': 'Query contains any of the wildcard characters: %_'}
with self.db_conn.cursor() as cursor:
data = []
# TODO support filtering on state of feature: abandoned, in construction, disused, preserved, etc.
# We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
fields = self.sql_select_fieldlist()
sql_query = f"""SELECT
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'wildcard_in_query', 'error': 'Wildcard in query.', 'detail': 'Query contains any of the wildcard characters: %_'}
)

# TODO support filtering on state of feature: abandoned, in construction, disused, preserved, etc.
# We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
fields = self.sql_select_fieldlist()
sql_query = f"""SELECT
{fields}, latitude, longitude, rank
FROM (
SELECT DISTINCT ON (osm_id)
{fields}, latitude, longitude, rank
FROM (
SELECT DISTINCT ON (osm_id)
{fields}, latitude, longitude, rank
FROM (
SELECT
{fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude, openrailwaymap_name_rank(phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space(%s))), terms, route_count, railway, station) AS rank
FROM openrailwaymap_facilities_for_search
WHERE terms @@ phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space(%s)))
) AS a
) AS b
ORDER BY rank DESC NULLS LAST
LIMIT %s;"""
cursor.execute(sql_query, (q, q, self.limit))
results = cursor.fetchall()
for r in results:
data.append(self.build_result_item_dict(cursor.description, r))
return data
FROM (
SELECT
{fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude, openrailwaymap_name_rank(phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space($1))), terms, route_count, railway, station) AS rank
FROM openrailwaymap_facilities_for_search
WHERE terms @@ phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space($1)))
) AS a
) AS b
ORDER BY rank DESC NULLS LAST
LIMIT %2;"""

def _search_by_ref(self, search_key, ref):
with self.db_conn.cursor() as cursor:
data = []
# We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
fields = self.sql_select_fieldlist()
sql_query = f"""SELECT DISTINCT ON (osm_id)
{fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude
FROM openrailwaymap_ref
WHERE {search_key} = %s
LIMIT %s;"""
cursor.execute(sql_query, (ref, self.limit))
results = cursor.fetchall()
for r in results:
data.append(self.build_result_item_dict(cursor.description, r))
return data
async with self.database.acquire() as connection:
statement = await connection.prepare(sql_query)
async with connection.transaction():
data = []
async for record in statement.cursor(q, self.limit):
data.append(dict(record))
return data

async def _search_by_ref(self, search_key, ref):
# We do not sort the result, although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
fields = self.sql_select_fieldlist()
sql_query = f"""SELECT DISTINCT ON (osm_id)
{fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude
FROM openrailwaymap_ref
WHERE {search_key} = %s
LIMIT %s;"""

async with self.database.acquire() as connection:
statement = await connection.prepare(sql_query)
async with connection.transaction():
data = []
async for record in statement.cursor(ref, self.limit):
data.append(dict(record))
return data

def search_by_ref(self, ref):
return self._search_by_ref("railway_ref", ref)
async def search_by_ref(self, ref):
return await self._search_by_ref("railway_ref", ref)

def search_by_uic_ref(self, ref):
return self._search_by_ref("uic_ref", ref)
async def search_by_uic_ref(self, ref):
return await self._search_by_ref("uic_ref", ref)

def sql_select_fieldlist(self):
return "osm_id, name, railway, railway_ref"
24 changes: 10 additions & 14 deletions api/openrailwaymap_api/milestone_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# SPDX-License-Identifier: GPL-2.0-or-later
from fastapi import HTTPException
from openrailwaymap_api.abstract_api import AbstractAPI
from starlette.status import HTTP_400_BAD_REQUEST


class MilestoneAPI(AbstractAPI):
class MilestoneAPI:
def __init__(self, database):
self.database = database
self.route_ref = None
Expand All @@ -13,40 +11,38 @@ def __init__(self, database):
self.status_code = 200
self.limit = 2

async def __call__(self, args):
async def __call__(self, *, ref, position, limit):
# Validate search arguments
ref = args.get('ref')
position = args.get('position')
if ref is None or position is None:
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'no_query_arg', 'error': 'One or multiple mandatory parameters are missing.', 'detail': 'You have to provide both "ref" and "position".'}
)
self.route_ref = args.get('ref')

try:
self.position = float(args.get('position'))
position = float(position)
except ValueError:
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'position_not_float', 'error': 'Invalid value provided for parameter "position".', 'detail': 'The provided position cannot be parsed as a float.'}
)
if 'limit' in args:
if limit is not None:
try:
self.limit = int(args['limit'])
limit = int(limit)
except ValueError:
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'limit_not_integer', 'error': 'Invalid parameter value provided for parameter "limit".', 'detail': 'The provided limit cannot be parsed as an integer value.'}
)
if self.limit > self.MAX_LIMIT:
if limit > self.MAX_LIMIT:
raise HTTPException(
HTTP_400_BAD_REQUEST,
{'type': 'limit_too_high', 'error': 'Invalid parameter value provided for parameter "limit".', 'detail': 'Limit is too high. Please set up your own instance to query everything.'}
)
self.data = await self.get_milestones()
self.data = await self.get_milestones(position, ref, limit)
return self.data

async def get_milestones(self):
async def get_milestones(self, position, route_ref, limit):
# We do not sort the result, although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
sql_query = """SELECT
osm_id,
Expand Down Expand Up @@ -116,6 +112,6 @@ async def get_milestones(self):
statement = await connection.prepare(sql_query)
async with connection.transaction():
data = []
async for record in statement.cursor(self.position, self.route_ref, self.limit):
async for record in statement.cursor(position, route_ref, limit):
data.append(dict(record))
return data
8 changes: 2 additions & 6 deletions api/openrailwaymap_api/status_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# SPDX-License-Identifier: GPL-2.0-or-later
from openrailwaymap_api.abstract_api import AbstractAPI


class StatusAPI(AbstractAPI):
async def __call__(self, args):
class StatusAPI:
async def __call__(self):
return 'OK'

0 comments on commit ab3b767

Please sign in to comment.