Async API implementation with FastAPI and asyncpg (#76)
Docs in https://fastapi.tiangolo.com/tutorial/ and
https://magicstack.github.io/asyncpg/current/usage.html

Change the current werkzeug implementation from a development-only WSGI
server to a production-ready async implementation that can serve traffic
efficiently on a single CPU core.
hiddewie authored Jul 21, 2024
1 parent f6b6bd7 commit 978e0f8
Showing 9 changed files with 408 additions and 296 deletions.
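The diff does not show which ASGI server runs the new app, so the launch command below is a minimal sketch only, assuming uvicorn and the PORT environment variable that the old werkzeug entry point read:

    # Hypothetical launcher, not part of this commit: uvicorn as the ASGI
    # server and the port fallback are assumptions.
    import os

    import uvicorn

    if __name__ == '__main__':
        # A single worker matches the commit's goal of serving from one CPU
        # core; the async event loop multiplexes concurrent requests.
        uvicorn.run('api:app', host='::', port=int(os.environ.get('PORT', '8000')), workers=1)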
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -37,7 +37,7 @@ jobs:
             data/filtered/germany.osm.pbf
           key: ${{ runner.os }}-data-${{ steps.get-date.outputs.date }}-berlin
 
-      - name: Download Germany
+      - name: Download Berlin
        if: ${{ steps.cache.outputs.cache-hit != 'true' }}
        run: |
          curl --location --fail --output data/berlin.osm.pbf https://download.geofabrik.de/europe/germany/berlin-latest.osm.pbf
11 changes: 6 additions & 5 deletions api/Dockerfile
@@ -15,11 +15,12 @@ HEALTHCHECK CMD ["pg_isready", "--host", "localhost", "--user", "postgres", "--d
 FROM postgis/postgis:16-3.4-alpine as runtime
 
 RUN apk add --no-cache \
-  curl \
-  python3 \
-  py3-pip \
-  py3-werkzeug \
-  py3-psycopg2
+  curl \
+  python3 \
+  py3-pip \
+  && python3 -m pip install --no-cache-dir --no-color --no-python-version-warning --disable-pip-version-check --break-system-packages \
+  fastapi \
+  asyncpg
 
 WORKDIR /app
 COPY openrailwaymap_api openrailwaymap_api
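A note on the pip flags: --break-system-packages is presumably required because Alpine marks its system Python as an externally managed environment (PEP 668), so pip would otherwise refuse a system-wide install; the remaining flags only quiet pip's output and skip its version checks. A hypothetical smoke test for the built image, confirming both libraries resolve:

    # Hypothetical smoke test, not part of this commit; run inside the image.
    import asyncpg
    import fastapi

    print(fastapi.__version__, asyncpg.__version__)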
129 changes: 65 additions & 64 deletions api/api.py
@@ -1,69 +1,70 @@
 #! /usr/bin/env python3
 # SPDX-License-Identifier: GPL-2.0-or-later
 
-import psycopg2
-import psycopg2.extras
-import json
-import sys
+import contextlib
 import os
-from werkzeug.exceptions import HTTPException, NotFound, InternalServerError
-from werkzeug.routing import Map, Rule
-from werkzeug.wrappers import Request, Response
+from typing import Annotated
+
+import asyncpg
+from fastapi import FastAPI
+from fastapi import Query
 
 from openrailwaymap_api.facility_api import FacilityAPI
 from openrailwaymap_api.milestone_api import MilestoneAPI
 from openrailwaymap_api.status_api import StatusAPI
 
-def connect_db():
-    conn = psycopg2.connect(dbname=os.environ['POSTGRES_DB'], user=os.environ['POSTGRES_USER'], host=os.environ['POSTGRES_HOST'])
-    return conn
-
-class OpenRailwayMapAPI:
-
-    db_conn = connect_db()
-
-    def __init__(self):
-        self.url_map = Map([
-            Rule('/api/facility', endpoint=FacilityAPI, methods=('GET',)),
-            Rule('/api/milestone', endpoint=MilestoneAPI, methods=('GET',)),
-            Rule('/api/status', endpoint=StatusAPI, methods=('GET',)),
-        ])
-
-    def ensure_db_connection_alive(self):
-        if self.db_conn.closed != 0:
-            self.db_conn = connect_db()
-
-    def dispatch_request(self, environ, start_response):
-        request = Request(environ)
-        urls = self.url_map.bind_to_environ(environ)
-        response = None
-        try:
-            endpoint, args = urls.match()
-            self.ensure_db_connection_alive()
-            response = endpoint(self.db_conn)(request.args)
-        except HTTPException as e:
-            return e
-        except Exception as e:
-            print('Error during request:', e, file=sys.stderr)
-            return InternalServerError()
-        finally:
-            if not response:
-                self.db_conn.close()
-                self.db_conn = connect_db()
-        return response
-
-    def wsgi_app(self, environ, start_response):
-        request = Request(environ)
-        response = self.dispatch_request(request)
-        return response(environ, start_response)
-
-
-def application(environ, start_response):
-    openrailwaymap_api = OpenRailwayMapAPI()
-    response = openrailwaymap_api.dispatch_request(environ, start_response)
-    return response(environ, start_response)
-
-
-if __name__ == '__main__':
-    openrailwaymap_api = OpenRailwayMapAPI()
-    from werkzeug.serving import run_simple
-    run_simple('::', int(os.environ['PORT']), application, use_debugger=True, use_reloader=True)
+@contextlib.asynccontextmanager
+async def lifespan(app):
+    async with asyncpg.create_pool(
+        user=os.environ['POSTGRES_USER'],
+        host=os.environ['POSTGRES_HOST'],
+        database=os.environ['POSTGRES_DB'],
+        command_timeout=10,
+        min_size=1,
+        max_size=20,
+    ) as pool:
+        print('Connected to database')
+        app.state.database = pool
+
+        yield
+
+        app.state.database = None
+
+    print('Disconnected from database')
+
+
+app = FastAPI(
+    title="OpenRailwayMap API",
+    lifespan=lifespan,
+)
+
+DEFAULT_FACILITY_LIMIT = 20
+DEFAULT_MILESTONE_LIMIT = 2
+MIN_LIMIT = 1
+MAX_LIMIT = 200
+
+
+@app.get("/api/status")
+async def status():
+    api = StatusAPI()
+    return await api()
+
+
+@app.get("/api/facility")
+async def facility(
+    q: Annotated[str | None, Query()] = None,
+    name: Annotated[str | None, Query()] = None,
+    ref: Annotated[str | None, Query()] = None,
+    uic_ref: Annotated[str | None, Query()] = None,
+    limit: Annotated[int, Query(ge=MIN_LIMIT, le=MAX_LIMIT)] = DEFAULT_FACILITY_LIMIT,
+):
+    api = FacilityAPI(app.state.database)
+    return await api(q=q, name=name, ref=ref, uic_ref=uic_ref, limit=limit)
+
+
+@app.get("/api/milestone")
+async def milestone(
+    ref: Annotated[str, Query()],
+    position: Annotated[float, Query()],
+    limit: Annotated[int | None, Query(ge=MIN_LIMIT, le=MAX_LIMIT)] = DEFAULT_MILESTONE_LIMIT,
+):
+    api = MilestoneAPI(app.state.database)
+    return await api(ref=ref, position=position, limit=limit)
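With these routes, parameter validation moves into the framework: FastAPI parses and bounds-checks the query parameters before the handlers run. A hypothetical client call against a locally running instance (host and port are assumptions; standard library only):

    # Hypothetical client, for illustration; the host and port are assumptions.
    import json
    import urllib.parse
    import urllib.request

    params = urllib.parse.urlencode({'name': 'Berlin Hauptbahnhof', 'limit': 5})
    with urllib.request.urlopen(f'http://localhost:8000/api/facility?{params}') as response:
        print(json.load(response))

    # A limit outside the 1..200 range never reaches FacilityAPI: FastAPI
    # rejects it with a 422 response due to the Query(ge=..., le=...) bounds.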
22 changes: 0 additions & 22 deletions api/openrailwaymap_api/abstract_api.py

This file was deleted.

175 changes: 83 additions & 92 deletions api/openrailwaymap_api/facility_api.py
@@ -1,118 +1,109 @@
 # SPDX-License-Identifier: GPL-2.0-or-later
-from openrailwaymap_api.abstract_api import AbstractAPI
+from fastapi import HTTPException
+from starlette.status import HTTP_400_BAD_REQUEST, HTTP_422_UNPROCESSABLE_ENTITY
 
-class FacilityAPI(AbstractAPI):
-    def __init__(self, db_conn):
-        self.db_conn = db_conn
-        self.search_args = ['q', 'name', 'ref', 'uic_ref']
-        self.data = []
-        self.status_code = 200
-        self.limit = 20
+QUERY_PARAMETERS = ['q', 'name', 'ref', 'uic_ref']
 
-    def eliminate_duplicates(self, data):
+class FacilityAPI:
+    def __init__(self, database):
+        self.database = database
+
+    def eliminate_duplicates(self, data, limit):
         data.sort(key=lambda k: k['osm_id'])
         i = 1
         while i < len(data):
-            if data[i]['osm_id'] == data[i-1]['osm_id']:
+            if data[i]['osm_id'] == data[i - 1]['osm_id']:
                 data.pop(i)
             i += 1
-        if len(data) > self.limit:
-            return data[:self.limit]
+        if len(data) > limit:
+            return data[:limit]
         return data
 
-    def __call__(self, args):
-        data = []
+    async def __call__(self, *, q, name, ref, uic_ref, limit):
         # Validate search arguments
-        search_args_count = 0
-        for search_arg in self.search_args:
-            if search_arg in args and args[search_arg]:
-                search_args_count += 1
+        search_args_count = sum(1 for search_arg in [q, name, ref, uic_ref] if search_arg)
 
         if search_args_count > 1:
-            args = ', '.join(self.search_args)
-            self.data = {'type': 'multiple_query_args', 'error': 'More than one argument with a search term provided.', 'detail': f'Provide only one of the following arguments: {args}'}
-            self.status_code = 400
-            return self.build_response()
+            args = ', '.join(QUERY_PARAMETERS)
+            raise HTTPException(
+                HTTP_422_UNPROCESSABLE_ENTITY,
+                {'type': 'multiple_query_args', 'error': 'More than one argument with a search term provided.', 'detail': f'Provide only one of the following query parameters: {args}'}
+            )
         elif search_args_count == 0:
-            args = ', '.join(self.search_args)
-            self.data = {'type': 'no_query_arg', 'error': 'No argument with a search term provided.', 'detail': f'Provide one of the following arguments: {args}'}
-            self.status_code = 400
-            return self.build_response()
-        if 'limit' in args:
-            try:
-                self.limit = int(args['limit'])
-            except ValueError:
-                self.data = {'type': 'limit_not_integer', 'error': 'Invalid parameter value provided for parameter "limit".', 'detail': 'The provided limit cannot be parsed as an integer value.'}
-                self.status_code = 400
-                return self.build_response()
-            if self.limit > self.MAX_LIMIT:
-                self.data = {'type': 'limit_too_high', 'error': 'Invalid parameter value provided for parameter "limit".', 'detail': 'Limit is too high. Please set up your own instance to query everything.'}
-                self.status_code = 400
-                return self.build_response()
-        if args.get('name'):
-            self.data = self.search_by_name(args['name'])
-        if args.get('ref'):
-            self.data = self.search_by_ref(args['ref'])
-        if args.get('uic_ref'):
-            self.data = self.search_by_uic_ref(args['uic_ref'])
-        if args.get('q'):
-            self.data = self.eliminate_duplicates(self.search_by_name(args['q']) + self.search_by_ref(args['q']) + self.search_by_uic_ref(args['q']))
-        return self.build_response()
+            args = ', '.join(QUERY_PARAMETERS)
+            raise HTTPException(
+                HTTP_422_UNPROCESSABLE_ENTITY,
+                {'type': 'no_query_arg', 'error': 'No argument with a search term provided.', 'detail': f'Provide one of the following query parameters: {args}'}
+            )
+
+        if name:
+            return await self.search_by_name(name, limit)
+        if ref:
+            return await self.search_by_ref(ref, limit)
+        if uic_ref:
+            return await self.search_by_uic_ref(uic_ref, limit)
+        if q:
+            return self.eliminate_duplicates((await self.search_by_name(q, limit)) + (await self.search_by_ref(q, limit)) + (await self.search_by_uic_ref(q, limit)), limit)
 
     def query_has_no_wildcards(self, q):
         if '%' in q or '_' in q:
             return False
         return True
 
-    def search_by_name(self, q):
+    async def search_by_name(self, q, limit):
         if not self.query_has_no_wildcards(q):
-            self.status_code = 400
-            return {'type': 'wildcard_in_query', 'error': 'Wildcard in query.', 'detail': 'Query contains any of the wildcard characters: %_'}
-        with self.db_conn.cursor() as cursor:
-            data = []
-            # TODO support filtering on state of feature: abandoned, in construction, disused, preserved, etc.
-            # We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
-            fields = self.sql_select_fieldlist()
-            sql_query = f"""SELECT
-                {fields}, latitude, longitude, rank
-              FROM (
-                SELECT DISTINCT ON (osm_id)
-                  {fields}, latitude, longitude, rank
-                FROM (
-                  SELECT
-                    {fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude, openrailwaymap_name_rank(phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space(%s))), terms, route_count, railway, station) AS rank
-                  FROM openrailwaymap_facilities_for_search
-                  WHERE terms @@ phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space(%s)))
-                ) AS a
-              ) AS b
-              ORDER BY rank DESC NULLS LAST
-              LIMIT %s;"""
-            cursor.execute(sql_query, (q, q, self.limit))
-            results = cursor.fetchall()
-            for r in results:
-                data.append(self.build_result_item_dict(cursor.description, r))
-            return data
+            raise HTTPException(
+                HTTP_400_BAD_REQUEST,
+                {'type': 'wildcard_in_query', 'error': 'Wildcard in query.', 'detail': 'Query contains any of the wildcard characters: %_'}
+            )
+
+        # TODO support filtering on state of feature: abandoned, in construction, disused, preserved, etc.
+        # We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
+        fields = self.sql_select_fieldlist()
+        sql_query = f"""SELECT
+            {fields}, latitude, longitude, rank
+          FROM (
+            SELECT DISTINCT ON (osm_id)
+              {fields}, latitude, longitude, rank
+            FROM (
+              SELECT
+                {fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude, openrailwaymap_name_rank(phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space($1))), terms, route_count, railway, station) AS rank
+              FROM openrailwaymap_facilities_for_search
+              WHERE terms @@ phraseto_tsquery('simple', unaccent(openrailwaymap_hyphen_to_space($1)))
+            ) AS a
+          ) AS b
+          ORDER BY rank DESC NULLS LAST
+          LIMIT $2;"""
+
+        async with self.database.acquire() as connection:
+            statement = await connection.prepare(sql_query)
+            async with connection.transaction():
+                data = []
+                async for record in statement.cursor(q, limit):
+                    data.append(dict(record))
+                return data
 
-    def _search_by_ref(self, search_key, ref):
-        with self.db_conn.cursor() as cursor:
-            data = []
-            # We do not sort the result although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
-            fields = self.sql_select_fieldlist()
-            sql_query = f"""SELECT DISTINCT ON (osm_id)
-                {fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude
-              FROM openrailwaymap_ref
-              WHERE {search_key} = %s
-              LIMIT %s;"""
-            cursor.execute(sql_query, (ref, self.limit))
-            results = cursor.fetchall()
-            for r in results:
-                data.append(self.build_result_item_dict(cursor.description, r))
-            return data
+    async def _search_by_ref(self, search_key, ref, limit):
+        # We do not sort the result, although we use DISTINCT ON because osm_id is sufficient to sort out duplicates.
+        fields = self.sql_select_fieldlist()
+        sql_query = f"""SELECT DISTINCT ON (osm_id)
+            {fields}, ST_X(ST_Transform(geom, 4326)) AS latitude, ST_Y(ST_Transform(geom, 4326)) AS longitude
+          FROM openrailwaymap_ref
+          WHERE {search_key} = $1
+          LIMIT $2;"""
+
+        async with self.database.acquire() as connection:
+            statement = await connection.prepare(sql_query)
+            async with connection.transaction():
+                data = []
+                async for record in statement.cursor(ref, limit):
+                    data.append(dict(record))
+                return data
 
-    def search_by_ref(self, ref):
-        return self._search_by_ref("railway_ref", ref)
+    async def search_by_ref(self, ref, limit):
+        return await self._search_by_ref("railway_ref", ref, limit)
 
-    def search_by_uic_ref(self, ref):
-        return self._search_by_ref("uic_ref", ref)
+    async def search_by_uic_ref(self, ref, limit):
+        return await self._search_by_ref("uic_ref", ref, limit)
 
     def sql_select_fieldlist(self):
         return "osm_id, name, railway, railway_ref"