forked from OpenRailwayMap/OpenRailwayMap-CartoCSS
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Async API implementation with FastAPI and
asyncpg
(#76)
Docs in https://fastapi.tiangolo.com/tutorial/ and https://magicstack.github.io/asyncpg/current/usage.html Change the current werkzeug implementation from a development-only WSGI server to a production ready async implementation that can run quickly using only a single CPU core
- Loading branch information
Showing
9 changed files
with
408 additions
and
296 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,69 +1,70 @@ | ||
#! /usr/bin/env python3 | ||
# SPDX-License-Identifier: GPL-2.0-or-later | ||
|
||
import psycopg2 | ||
import psycopg2.extras | ||
import json | ||
import sys | ||
import contextlib | ||
import os | ||
from werkzeug.exceptions import HTTPException, NotFound, InternalServerError | ||
from werkzeug.routing import Map, Rule | ||
from werkzeug.wrappers import Request, Response | ||
from typing import Annotated | ||
|
||
import asyncpg | ||
from fastapi import FastAPI | ||
from fastapi import Query | ||
|
||
from openrailwaymap_api.facility_api import FacilityAPI | ||
from openrailwaymap_api.milestone_api import MilestoneAPI | ||
from openrailwaymap_api.status_api import StatusAPI | ||
|
||
def connect_db(): | ||
conn = psycopg2.connect(dbname=os.environ['POSTGRES_DB'], user=os.environ['POSTGRES_USER'], host=os.environ['POSTGRES_HOST']) | ||
return conn | ||
|
||
class OpenRailwayMapAPI: | ||
|
||
db_conn = connect_db() | ||
|
||
def __init__(self): | ||
self.url_map = Map([ | ||
Rule('/api/facility', endpoint=FacilityAPI, methods=('GET',)), | ||
Rule('/api/milestone', endpoint=MilestoneAPI, methods=('GET',)), | ||
Rule('/api/status', endpoint=StatusAPI, methods=('GET',)), | ||
]) | ||
|
||
def ensure_db_connection_alive(self): | ||
if self.db_conn.closed != 0: | ||
self.db_conn = connect_db() | ||
|
||
def dispatch_request(self, environ, start_response): | ||
request = Request(environ) | ||
urls = self.url_map.bind_to_environ(environ) | ||
response = None | ||
try: | ||
endpoint, args = urls.match() | ||
self.ensure_db_connection_alive() | ||
response = endpoint(self.db_conn)(request.args) | ||
except HTTPException as e: | ||
return e | ||
except Exception as e: | ||
print('Error during request:', e, file=sys.stderr) | ||
return InternalServerError() | ||
finally: | ||
if not response: | ||
self.db_conn.close() | ||
self.db_conn = connect_db() | ||
return response | ||
|
||
def wsgi_app(self, environ, start_response): | ||
request = Request(environ) | ||
response = self.dispatch_request(request) | ||
return response(environ, start_response) | ||
|
||
|
||
def application(environ, start_response): | ||
openrailwaymap_api = OpenRailwayMapAPI() | ||
response = openrailwaymap_api.dispatch_request(environ, start_response) | ||
return response(environ, start_response) | ||
|
||
|
||
if __name__ == '__main__': | ||
openrailwaymap_api = OpenRailwayMapAPI() | ||
from werkzeug.serving import run_simple | ||
run_simple('::', int(os.environ['PORT']), application, use_debugger=True, use_reloader=True) | ||
|
||
@contextlib.asynccontextmanager | ||
async def lifespan(app): | ||
async with asyncpg.create_pool( | ||
user=os.environ['POSTGRES_USER'], | ||
host=os.environ['POSTGRES_HOST'], | ||
database=os.environ['POSTGRES_DB'], | ||
command_timeout=10, | ||
min_size=1, | ||
max_size=20, | ||
) as pool: | ||
print('Connected to database') | ||
app.state.database = pool | ||
|
||
yield | ||
|
||
app.state.database = None | ||
|
||
print('Disconnected from database') | ||
|
||
|
||
app = FastAPI( | ||
title="OpenRailwayMap API", | ||
lifespan=lifespan, | ||
) | ||
|
||
DEFAULT_FACILITY_LIMIT = 20 | ||
DEFAULT_MILESTONE_LIMIT = 2 | ||
MIN_LIMIT = 1 | ||
MAX_LIMIT = 200 | ||
|
||
|
||
@app.get("/api/status") | ||
async def status(): | ||
api = StatusAPI() | ||
return await api() | ||
|
||
|
||
@app.get("/api/facility") | ||
async def facility( | ||
q: Annotated[str | None, Query()] = None, | ||
name: Annotated[str | None, Query()] = None, | ||
ref: Annotated[str | None, Query()] = None, | ||
uic_ref: Annotated[str | None, Query()] = None, | ||
limit: Annotated[int, Query(ge=MIN_LIMIT, le=MAX_LIMIT)] = DEFAULT_FACILITY_LIMIT, | ||
): | ||
api = FacilityAPI(app.state.database) | ||
return await api(q=q, name=name, ref=ref, uic_ref=uic_ref, limit=limit) | ||
|
||
|
||
@app.get("/api/milestone") | ||
async def milestone( | ||
ref: Annotated[str, Query()], | ||
position: Annotated[float, Query()], | ||
limit: Annotated[int | None, Query(ge=MIN_LIMIT, le=MAX_LIMIT)] = DEFAULT_MILESTONE_LIMIT, | ||
): | ||
api = MilestoneAPI(app.state.database) | ||
return await api(ref=ref, position=position, limit=limit) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,118 +1,109 @@ | ||
# SPDX-License-Identifier: GPL-2.0-or-later | ||
from openrailwaymap_api.abstract_api import AbstractAPI | ||
from fastapi import HTTPException | ||
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_422_UNPROCESSABLE_ENTITY | ||
|
||
class FacilityAPI(AbstractAPI): | ||
def __init__(self, db_conn): | ||
self.db_conn = db_conn | ||
self.search_args = ['q', 'name', 'ref', 'uic_ref'] | ||
self.data = [] | ||
self.status_code = 200 | ||
self.limit = 20 | ||
QUERY_PARAMETERS = ['q', 'name', 'ref', 'uic_ref'] | ||
|
||
def eliminate_duplicates(self, data): | ||
class FacilityAPI: | ||
def __init__(self, database): | ||
self.database = database | ||
|
||
def eliminate_duplicates(self, data, limit): | ||
data.sort(key=lambda k: k['osm_id']) | ||
i = 1 | ||
while i < len(data): | ||
if data[i]['osm_id'] == data[i-1]['osm_id']: | ||
if data[i]['osm_id'] == data[i - 1]['osm_id']: | ||
data.pop(i) | ||
i += 1 | ||
if len(data) > self.limit: | ||
return data[:self.limit] | ||
if len(data) > limit: | ||
return data[:limit] | ||
return data | ||
|
||
def __call__(self, args): | ||
data = [] | ||
async def __call__(self, *, q, name, ref, uic_ref, limit): | ||
# Validate search arguments | ||
search_args_count = 0 | ||
for search_arg in self.search_args: | ||
if search_arg in args and args[search_arg]: | ||
search_args_count += 1 | ||
search_args_count = sum(1 for search_arg in [q, name, ref, uic_ref] if search_arg) | ||
|
||
if search_args_count > 1: | ||
args = ', '.join(self.search_args) | ||
self.data = {'type': 'multiple_query_args', 'error': 'More than one argument with a search term provided.', 'detail': f'Provide only one of the following arguments: {args}'} | ||
self.status_code = 400 | ||
return self.build_response() | ||
args = ', '.join(QUERY_PARAMETERS) | ||
raise HTTPException( | ||
HTTP_422_UNPROCESSABLE_ENTITY, | ||
{'type': 'multiple_query_args', 'error': 'More than one argument with a search term provided.', 'detail': f'Provide only one of the following query parameters: {args}'} | ||
) | ||
elif search_args_count == 0: | ||
args = ', '.join(self.search_args) | ||
self.data = {'type': 'no_query_arg', 'error': 'No argument with a search term provided.', 'detail': f'Provide one of the following arguments: {args}'} | ||
self.status_code = 400 | ||
return self.build_response() | ||
if 'limit' in args: | ||
try: | ||
self.limit = int(args['limit']) | ||
except ValueError: | ||
self.data = {'type': 'limit_not_integer', 'error': 'Invalid parameter value provided for parameter "limit".', 'detail': 'The provided limit cannot be parsed as an integer value.'} | ||
self.status_code = 400 | ||
return self.build_response() | ||
if self.limit > self.MAX_LIMIT: | ||
self.data = {'type': 'limit_too_high', 'error': 'Invalid parameter value provided for parameter "limit".', 'detail': 'Limit is too high. Please set up your own instance to query everything.'} | ||
self.status_code = 400 | ||
return self.build_response() | ||
if args.get('name'): | ||
self.data = self.search_by_name(args['name']) | ||
if args.get('ref'): | ||
self.data = self.search_by_ref(args['ref']) | ||
if args.get('uic_ref'): | ||
self.data = self.search_by_uic_ref(args['uic_ref']) | ||
if args.get('q'): | ||
self.data = self.eliminate_duplicates(self.search_by_name(args['q']) + self.search_by_ref(args['q']) + self.search_by_uic_ref(args['q'])) | ||
return self.build_response() | ||
args = ', '.join(QUERY_PARAMETERS) | ||
raise HTTPException( | ||
HTTP_422_UNPROCESSABLE_ENTITY, | ||
{'type': 'no_query_arg', 'error': 'No argument with a search term provided.', 'detail': f'Provide one of the following query parameters: {args}'} | ||
) | ||
|
||
if name: | ||
return await self.search_by_name(name, limit) | ||
if ref: | ||
return await self.search_by_ref(ref, limit) | ||
if uic_ref: | ||
return await self.search_by_uic_ref(uic_ref, limit) | ||
if q: | ||
return self.eliminate_duplicates((await self.search_by_name(q, limit)) + (await self.search_by_ref(q, limit)) + (await self.search_by_uic_ref(q, limit)), limit) | ||
|
||
def query_has_no_wildcards(self, q): | ||
if '%' in q or '_' in q: | ||
return False | ||
return True | ||
|
||
def search_by_name(self, q): | ||
async def search_by_name(self, q, limit): | ||
if not self.query_has_no_wildcards(q): | ||
self.status_code = 400 | ||
return {'type': 'wildcard_in_query', 'error': 'Wildcard in query.', 'detail': 'Query contains any of the wildcard characters: %_'} | ||
with self.db_conn.cursor() as cursor: | ||
data = [] | ||
# TODO support filtering on state of feature: abandoned, in construction, disused, preserved, etc. | ||
# We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates. | ||
fields = self.sql_select_fieldlist() | ||
sql_query = f"""SELECT | ||
raise HTTPException( | ||
HTTP_400_BAD_REQUEST, | ||
{'type': 'wildcard_in_query', 'error': 'Wildcard in query.', 'detail': 'Query contains any of the wildcard characters: %_'} | ||
) | ||
|
||
# TODO support filtering on state of feature: abandoned, in construction, disused, preserved, etc. | ||
# We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates. | ||
fields = self.sql_select_fieldlist() | ||
sql_query = f"""SELECT | ||
{fields}, latitude, longitude, rank | ||
FROM ( | ||
SELECT DISTINCT ON (osm_id) | ||
{fields}, latitude, longitude, rank | ||
FROM ( | ||
SELECT DISTINCT ON (osm_id) | ||
{fields}, latitude, longitude, rank | ||
FROM ( | ||
SELECT | ||
{fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude, openrailwaymap_name_rank(phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space(%s))), terms, route_count, railway, station) AS rank | ||
FROM openrailwaymap_facilities_for_search | ||
WHERE terms @@ phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space(%s))) | ||
) AS a | ||
) AS b | ||
ORDER BY rank DESC NULLS LAST | ||
LIMIT %s;""" | ||
cursor.execute(sql_query, (q, q, self.limit)) | ||
results = cursor.fetchall() | ||
for r in results: | ||
data.append(self.build_result_item_dict(cursor.description, r)) | ||
return data | ||
FROM ( | ||
SELECT | ||
{fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude, openrailwaymap_name_rank(phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space($1))), terms, route_count, railway, station) AS rank | ||
FROM openrailwaymap_facilities_for_search | ||
WHERE terms @@ phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space($1))) | ||
) AS a | ||
) AS b | ||
ORDER BY rank DESC NULLS LAST | ||
LIMIT $2;""" | ||
|
||
def _search_by_ref(self, search_key, ref): | ||
with self.db_conn.cursor() as cursor: | ||
data = [] | ||
# We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates. | ||
fields = self.sql_select_fieldlist() | ||
sql_query = f"""SELECT DISTINCT ON (osm_id) | ||
{fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude | ||
FROM openrailwaymap_ref | ||
WHERE {search_key} = %s | ||
LIMIT %s;""" | ||
cursor.execute(sql_query, (ref, self.limit)) | ||
results = cursor.fetchall() | ||
for r in results: | ||
data.append(self.build_result_item_dict(cursor.description, r)) | ||
return data | ||
async with self.database.acquire() as connection: | ||
statement = await connection.prepare(sql_query) | ||
async with connection.transaction(): | ||
data = [] | ||
async for record in statement.cursor(q, limit): | ||
data.append(dict(record)) | ||
return data | ||
|
||
async def _search_by_ref(self, search_key, ref, limit): | ||
# We do not sort the result, although we use DISTINCT ON because osm_id is sufficient to sort out duplicates. | ||
fields = self.sql_select_fieldlist() | ||
sql_query = f"""SELECT DISTINCT ON (osm_id) | ||
{fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude | ||
FROM openrailwaymap_ref | ||
WHERE {search_key} = $1 | ||
LIMIT $2;""" | ||
|
||
async with self.database.acquire() as connection: | ||
statement = await connection.prepare(sql_query) | ||
async with connection.transaction(): | ||
data = [] | ||
async for record in statement.cursor(ref, limit): | ||
data.append(dict(record)) | ||
return data | ||
|
||
def search_by_ref(self, ref): | ||
return self._search_by_ref("railway_ref", ref) | ||
async def search_by_ref(self, ref, limit): | ||
return await self._search_by_ref("railway_ref", ref, limit) | ||
|
||
def search_by_uic_ref(self, ref): | ||
return self._search_by_ref("uic_ref", ref) | ||
async def search_by_uic_ref(self, ref, limit): | ||
return await self._search_by_ref("uic_ref", ref, limit) | ||
|
||
def sql_select_fieldlist(self): | ||
return "osm_id, name, railway, railway_ref" |
Oops, something went wrong.