From c40a5caac4a19fedb27a79e58170a0ad5d88e3e9 Mon Sep 17 00:00:00 2001 From: Slesa Adhikari Date: Tue, 23 Jul 2024 17:02:51 -0500 Subject: [PATCH] Use `pystac` for validation instead --- stac_api/runtime/src/validation.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/stac_api/runtime/src/validation.py b/stac_api/runtime/src/validation.py index e6939b55..ffad6273 100644 --- a/stac_api/runtime/src/validation.py +++ b/stac_api/runtime/src/validation.py @@ -5,27 +5,23 @@ from typing import Dict from pydantic import BaseModel, Field, ValidationError -from src.config import ApiSettings -from stac_pydantic import Collection, Item +from src.config import api_settings + +from pystac import STACObjectType +from pystac.validation import validate_dict +from pystac.errors import STACValidationError from fastapi import Request from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware -api_settings = ApiSettings() path_prefix = api_settings.root_path or "" -class Items(BaseModel): - """Validation model for items used in BulkItems""" - - items: Dict[str, Item] - - class BulkItems(BaseModel): """Validation model for bulk-items endpoint request""" - items: Items + items: Dict[str, dict] method: str = Field(default="insert") @@ -42,21 +38,23 @@ async def dispatch(self, request: Request, call_next): f"^{path_prefix}/collections(?:/[^/]+)?$", request.url.path, ): - Collection(**request_data) + validate_dict(request_data, STACObjectType.COLLECTION) elif re.match( f"^{path_prefix}/collections/[^/]+/items(?:/[^/]+)?$", request.url.path, ): - Item(**request_data) + validate_dict(request_data, STACObjectType.ITEM) elif re.match( f"^{path_prefix}/collections/[^/]+/bulk-items$", request.url.path, ): - BulkItems(**request_data) - except ValidationError as e: + bulk_items = BulkItems(**request_data) + for item_data in bulk_items.items.items.values(): + validate_dict(item_data, STACObjectType.ITEM) + except STACValidationError as e: return JSONResponse( - status_code=400, - content={"detail": "Validation Error", "errors": e.errors()}, + status_code=422, + content={"detail": "Validation Error", "errors": str(e)}, ) response = await call_next(request)