diff --git a/app/main.py b/app/main.py index fe68003fb..a554d3af7 100644 --- a/app/main.py +++ b/app/main.py @@ -16,7 +16,12 @@ from app.errors import http_error_handler from .application import app -from .middleware import no_cache_response_header, redirect_latest, set_db_mode +from .middleware import ( + no_cache_response_header, + redirect_alias_to_version, + redirect_latest, + set_db_mode, +) from .routes import health from .routes.analysis import analysis from .routes.assets import asset, assets @@ -91,7 +96,12 @@ async def rve_error_handler( # MIDDLEWARE ################# -MIDDLEWARE = (set_db_mode, redirect_latest, no_cache_response_header) +MIDDLEWARE = ( + set_db_mode, + redirect_latest, + no_cache_response_header, + redirect_alias_to_version, +) for m in MIDDLEWARE: app.add_middleware(BaseHTTPMiddleware, dispatch=m) diff --git a/app/middleware.py b/app/middleware.py index ed3990188..d9b784d62 100644 --- a/app/middleware.py +++ b/app/middleware.py @@ -3,6 +3,7 @@ from fastapi.responses import ORJSONResponse, RedirectResponse from .application import ContextEngine +from .crud.aliases import get_alias from .crud.versions import get_latest_version from .errors import BadRequestError, RecordNotFoundError, http_error_handler @@ -25,10 +26,9 @@ async def set_db_mode(request: Request, call_next): async def redirect_latest(request: Request, call_next): """Redirect all GET requests using latest version to actual version number. - Redirect only POST requests to for query and download endpoints, as + Redirect only POST requests for query and download endpoints, as other POST endpoints will require to list version number explicitly. """ - if (request.method == "GET" and "latest" in request.url.path) or ( request.method == "POST" and "latest" in request.url.path @@ -77,6 +77,37 @@ async def redirect_latest(request: Request, call_next): return response +async def redirect_alias_to_version(request: Request, call_next): + """Redirect version request by alias to the actual dataset version.""" + + path_items = request.url.path.split("/") + is_dataset_version_path = len(path_items) >= 4 and path_items[1] == "dataset" + is_allowed_post_request = request.method == "POST" and ( + "query" in request.url.path or "download" in request.url.path + ) + if not is_dataset_version_path: + response = await call_next(request) + return response + if request.method not in ["GET", "POST"] or ( + request.method == "POST" and not is_allowed_post_request + ): + response = await call_next(request) + return response + + dataset, version = path_items[2:4] + try: + alias = await get_alias(dataset, version) + path_items[3] = alias.version + url = "/".join(path_items) + if request.query_params: + url = f"{url}?{request.query_params}" + return RedirectResponse(url=url) + + except RecordNotFoundError: + response = await call_next(request) + return response + + async def no_cache_response_header(request: Request, call_next): """This middleware adds a cache control response header. diff --git a/app/routes/__init__.py b/app/routes/__init__.py index 7a53f2667..50d241d13 100644 --- a/app/routes/__init__.py +++ b/app/routes/__init__.py @@ -47,8 +47,29 @@ async def dataset_version_dependency( except RecordNotFoundError as e: try: version_alias = await get_alias(dataset, version) - await get_version(dataset, version_alias.version) + if version_alias is not None: + raise HTTPException( + status_code=400, + detail="Getting version by alias is not supported for this operation.", + ) except RecordNotFoundError: raise HTTPException(status_code=404, detail=str(e)) return dataset, version + + +async def create_dataset_version_dependency( + dataset: str = Depends(dataset_dependency), + version: str = Depends(version_dependency), +) -> Tuple[str, str]: + try: + version_alias = await get_alias(dataset, version) + if version_alias is not None: + raise HTTPException( + status_code=400, + detail="Conflicts with existing version alias and can not overwrite it.", + ) + except RecordNotFoundError: + pass + + return dataset, version diff --git a/app/routes/datasets/versions.py b/app/routes/datasets/versions.py index a353746ec..2a8959f3d 100644 --- a/app/routes/datasets/versions.py +++ b/app/routes/datasets/versions.py @@ -17,11 +17,10 @@ from fastapi.responses import ORJSONResponse from ...authentication.token import is_admin -from ...crud import aliases, assets, versions +from ...crud import assets, versions from ...errors import RecordAlreadyExistsError, RecordNotFoundError from ...models.enum.assets import AssetStatus, AssetType from ...models.enum.pixetl import Grid -from ...models.orm.aliases import Alias as ORMAlias from ...models.orm.assets import Asset as ORMAsset from ...models.orm.versions import Version as ORMVersion from ...models.pydantic.change_log import ChangeLog, ChangeLogResponse @@ -44,7 +43,7 @@ VersionResponse, VersionUpdateIn, ) -from ...routes import dataset_dependency, dataset_version_dependency, version_dependency +from ...routes import create_dataset_version_dependency, dataset_version_dependency from ...settings.globals import TILE_CACHE_CLOUDFRONT_ID from ...tasks.aws_tasks import flush_cloudfront_cache from ...tasks.default_assets import append_default_asset, create_default_asset @@ -68,11 +67,7 @@ async def get_version( """Get basic metadata for a given version.""" dataset, version = dv - try: - row: ORMVersion = await versions.get_version(dataset, version) - except RecordNotFoundError: - version_alias: ORMAlias = await aliases.get_alias(dataset, version) - row = await versions.get_version(dataset, version_alias.version) + row: ORMVersion = await versions.get_version(dataset, version) return await _version_response(dataset, version, row) @@ -86,8 +81,7 @@ async def get_version( ) async def add_new_version( *, - dataset: str = Depends(dataset_dependency), - version: str = Depends(version_dependency), + dataset_version: Tuple[str, str] = Depends(create_dataset_version_dependency), request: VersionCreateIn, background_tasks: BackgroundTasks, is_authorized: bool = Depends(is_admin), @@ -95,6 +89,7 @@ async def add_new_version( ): """Create or update a version for a given dataset.""" + dataset, version = dataset_version input_data = request.dict(exclude_none=True, by_alias=True) creation_options = input_data.pop("creation_options")