Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Make root_path for app configurable #400

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions app/api/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
# Request constants
EnvVar = namedtuple("EnvVar", ["name", "val"])

ROOT_PATH = EnvVar(
"NB_NAPI_ROOT_PATH", os.environ.get("NB_NAPI_ROOT_PATH", "")
)

ALLOWED_ORIGINS = EnvVar(
"NB_API_ALLOWED_ORIGINS", os.environ.get("NB_API_ALLOWED_ORIGINS", "")
)
Expand All @@ -21,12 +25,12 @@
"NB_GRAPH_PASSWORD", os.environ.get("NB_GRAPH_PASSWORD")
)
GRAPH_ADDRESS = EnvVar(
"NB_GRAPH_ADDRESS", os.environ.get("NB_GRAPH_ADDRESS", "206.12.99.17")
"NB_GRAPH_ADDRESS", os.environ.get("NB_GRAPH_ADDRESS", "127.0.0.1")
)
GRAPH_DB = EnvVar(
"NB_GRAPH_DB", os.environ.get("NB_GRAPH_DB", "test_data/query")
"NB_GRAPH_DB", os.environ.get("NB_GRAPH_DB", "repositories/my_db")
)
GRAPH_PORT = EnvVar("NB_GRAPH_PORT", os.environ.get("NB_GRAPH_PORT", 5820))
GRAPH_PORT = EnvVar("NB_GRAPH_PORT", os.environ.get("NB_GRAPH_PORT", 7200))
# TODO: Environment variables can't be parsed as bool so this is a workaround but isn't ideal.
# Another option is to switch this to a command-line argument, but that would require changing the
# Dockerfile also since Uvicorn can't accept custom command-line args.
Expand Down
19 changes: 11 additions & 8 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tempfile import TemporaryDirectory

import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.responses import HTMLResponse, ORJSONResponse, RedirectResponse
Expand All @@ -16,6 +16,7 @@
from .api.security import check_client_id

app = FastAPI(
root_path=util.ROOT_PATH.val,
default_response_class=ORJSONResponse,
docs_url=None,
redoc_url=None,
Expand All @@ -33,15 +34,15 @@


@app.get("/", response_class=HTMLResponse)
def root():
def root(request: Request):
"""
Display a welcome message and a link to the API documentation.
"""
return """
return f"""
<html>
<body>
<h1>Welcome to the Neurobagel REST API!</h1>
<p>Please visit the <a href="/docs">documentation</a> to view available API endpoints.</p>
<p>Please visit the <a href="{request.scope.get("root_path", "")}/docs">documentation</a> to view available API endpoints.</p>
</body>
</html>
"""
Expand All @@ -56,24 +57,24 @@ async def favicon():


@app.get("/docs", include_in_schema=False)
def overridden_swagger():
def overridden_swagger(request: Request):
"""
Overrides the Swagger UI HTML for the "/docs" endpoint.
"""
return get_swagger_ui_html(
openapi_url="/openapi.json",
openapi_url=f"{request.scope.get('root_path', '')}/openapi.json",
title="Neurobagel API",
swagger_favicon_url=favicon_url,
)


@app.get("/redoc", include_in_schema=False)
def overridden_redoc():
def overridden_redoc(request: Request):
"""
Overrides the Redoc HTML for the "/redoc" endpoint.
"""
return get_redoc_html(
openapi_url="/openapi.json",
openapi_url=f"{request.scope.get('root_path', '')}/openapi.json",
title="Neurobagel API",
redoc_favicon_url=favicon_url,
)
Expand Down Expand Up @@ -103,6 +104,8 @@ async def auth_check():
async def allowed_origins_check():
"""Raises warning if allowed origins environment variable has not been set or is an empty string."""
if os.environ.get(util.ALLOWED_ORIGINS.name, "") == "":
# TODO: For debugging - remove
print(util.ROOT_PATH.val)
warnings.warn(
f"The API was launched without providing any values for the {util.ALLOWED_ORIGINS.name} environment variable. "
"This means that the API will only be accessible from the same origin it is hosted from: https://developer.mozilla.org/en-US/docs/Web/Security/Same-origin_policy. "
Expand Down
30 changes: 26 additions & 4 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import pytest

from app.api import crud
from app.main import app


@pytest.mark.parametrize(
"root_path",
"route",
["/", ""],
)
def test_root(test_app, root_path):
def test_root(test_app, route, monkeypatch):
"""Given a GET request to the root endpoint, Check for 200 status and expected content."""

response = test_app.get(root_path, follow_redirects=False)
# root_path determines the docs link on the welcome page
monkeypatch.setattr(app, "root_path", "")
response = test_app.get(route, follow_redirects=False)

assert response.status_code == 200
assert "Welcome to the Neurobagel REST API!" in response.text
Expand Down Expand Up @@ -43,3 +45,23 @@ def test_request_with_trailing_slash_not_redirected(
"""
response = test_app.get(invalid_route)
assert response.status_code == 404


@pytest.mark.parametrize(
"test_root_path,expected_status_code",
[("", 200), ("/api/v1", 200), ("/wrongroot", 404)],
)
def test_docs_work_using_defined_root_path(
test_app, test_root_path, expected_status_code, monkeypatch
):
monkeypatch.setattr(app, "root_path", "/api/v1")
docs_response = test_app.get(
f"{test_root_path}/docs", follow_redirects=False
)
# When the root path is not set correctly, the docs can break due to failure to fetch openapi.json
# See also https://fastapi.tiangolo.com/advanced/behind-a-proxy/#proxy-with-a-stripped-path-prefix
schema_response = test_app.get(
f"{test_root_path}/openapi.json", follow_redirects=False
)
assert docs_response.status_code == expected_status_code
assert schema_response.status_code == expected_status_code
Loading