diff --git a/API/api_worker.py b/API/api_worker.py index 930ac111..92e24822 100644 --- a/API/api_worker.py +++ b/API/api_worker.py @@ -19,7 +19,7 @@ from src.query_builder.builder import format_file_name_str from src.validation.models import RawDataOutputType -celery = Celery(__name__) +celery = Celery("Raw Data API") celery.conf.broker_url = celery_broker_uri celery.conf.result_backend = celery_backend celery.conf.task_serializer = "pickle" diff --git a/API/auth/__init__.py b/API/auth/__init__.py index e6f9bcba..0839309d 100644 --- a/API/auth/__init__.py +++ b/API/auth/__init__.py @@ -1,20 +1,53 @@ +from enum import Enum from typing import Union -from fastapi import Header +from fastapi import Depends, Header, HTTPException from osm_login_python.core import Auth -from pydantic import BaseModel +from pydantic import BaseModel, Field -from src.config import get_oauth_credentials +from src.config import ADMIN_IDS, get_oauth_credentials + + +class UserRole(Enum): + ADMIN = 1 + STAFF = 2 + GUEST = 3 class AuthUser(BaseModel): id: int username: str img_url: Union[str, None] + role: UserRole = Field(default=UserRole.GUEST.value) osm_auth = Auth(*get_oauth_credentials()) +def is_admin(osm_id: int): + admin_ids = [int(admin_id) for admin_id in ADMIN_IDS] + return osm_id in admin_ids + + def login_required(access_token: str = Header(...)): - return osm_auth.deserialize_access_token(access_token) + user = AuthUser(**osm_auth.deserialize_access_token(access_token)) + if is_admin(user.id): + user.role = UserRole.ADMIN + return user + + +def get_optional_user(access_token: str = Header(default=None)) -> AuthUser: + if access_token: + user = AuthUser(**osm_auth.deserialize_access_token(access_token)) + if is_admin(user.id): + user.role = UserRole.ADMIN + return user + else: + # If no token provided, return a user with limited options or guest user + return AuthUser(id=0, username="guest", img_url=None) + + +def admin_required(user: AuthUser = Depends(login_required)): + if not is_admin(user.id): + raise HTTPException(status_code=403, detail="User is not an admin") + return user diff --git a/API/auth/routers.py b/API/auth/routers.py index a2f5a89f..5d428b7c 100644 --- a/API/auth/routers.py +++ b/API/auth/routers.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends, Request -from . import AuthUser, login_required, osm_auth +from . import AuthUser, admin_required, login_required, osm_auth router = APIRouter(prefix="/auth") diff --git a/API/main.py b/API/main.py index 42c8480d..aa6a2a92 100644 --- a/API/main.py +++ b/API/main.py @@ -38,6 +38,7 @@ from src.config import logger as logging from src.db_session import database_instance +from .auth.routers import router as auth_router from .raw_data import router as raw_data_router from .tasks import router as tasks_router @@ -58,10 +59,16 @@ os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" app = FastAPI(title="Raw Data API ") -# app.include_router(auth_router) +app.include_router(auth_router) app.include_router(raw_data_router) app.include_router(tasks_router) - +app.openapi = { + "info": { + "title": "Raw Data API", + "version": "1.0", + }, + "security": [{"OAuth2PasswordBearer": []}], +} app = VersionedFastAPI( app, enable_latest=False, version_format="{major}", prefix_format="/v{major}" diff --git a/API/raw_data.py b/API/raw_data.py index 3b3cbdfe..ae4b484b 100644 --- a/API/raw_data.py +++ b/API/raw_data.py @@ -19,17 +19,20 @@ """[Router Responsible for Raw data API ] """ +import json import os import shutil import time import requests -from fastapi import APIRouter, Body, Request +from area import area +from fastapi import APIRouter, Body, Depends, HTTPException, Request from fastapi.responses import JSONResponse from fastapi_versioning import version from geojson import FeatureCollection from src.app import RawData +from src.config import ALLOW_BIND_ZIP_FILTER, EXPORT_MAX_AREA_SQKM from src.config import LIMITER as limiter from src.config import RATE_LIMIT_PER_MIN as export_rate_limit from src.config import logger as logging @@ -41,6 +44,7 @@ ) from .api_worker import process_raw_data +from .auth import AuthUser, UserRole, get_optional_user router = APIRouter(prefix="") @@ -421,6 +425,7 @@ def get_osm_current_snapshot_as_file( }, }, ), + user: AuthUser = Depends(get_optional_user), ): """Generates the current raw OpenStreetMap data available on database based on the input geometry, query and spatial features. @@ -434,17 +439,45 @@ def get_osm_current_snapshot_as_file( 2. Now navigate to /tasks/ with your task id to track progress and result """ + + if not (user.role == UserRole.STAFF or user.role == UserRole.ADMIN): + area_m2 = area(json.loads(params.geometry.json())) + area_km2 = area_m2 * 1e-6 + RAWDATA_CURRENT_POLYGON_AREA = int(EXPORT_MAX_AREA_SQKM) + if area_km2 > RAWDATA_CURRENT_POLYGON_AREA: + raise HTTPException( + status_code=400, + detail=[ + { + "msg": f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : {RAWDATA_CURRENT_POLYGON_AREA} Sq.KM""" + } + ], + ) + if not params.uuid: + raise HTTPException( + status_code=403, + detail=[{"msg": "Insufficient Permission for uuid = False"}], + ) + if ALLOW_BIND_ZIP_FILTER: + if not params.bind_zip: + raise HTTPException( + status_code=403, + detail=[{"msg": "Insufficient Permission for bind_zip"}], + ) + queue_name = "recurring_queue" if not params.uuid else "raw_default" task = process_raw_data.apply_async(args=(params,), queue=queue_name) return JSONResponse({"task_id": task.id, "track_link": f"/tasks/status/{task.id}/"}) -@router.get("/snapshot/plain/", response_model=FeatureCollection) +@router.post("/snapshot/plain/", response_model=FeatureCollection) @version(1) def get_osm_current_snapshot_as_plain_geojson( - request: Request, params: RawDataCurrentParamsBase + request: Request, + params: RawDataCurrentParamsBase, + user: AuthUser = Depends(get_optional_user), ): - """Generates the Plain geojson for the polygon within 100 Sqkm and returns the result right away + """Generates the Plain geojson for the polygon within 30 Sqkm and returns the result right away Args: request (Request): _description_ @@ -453,6 +486,19 @@ def get_osm_current_snapshot_as_plain_geojson( Returns: Featurecollection: Geojson """ + if not (user.role == UserRole.STAFF or user.role == UserRole.ADMIN): + area_m2 = area(json.loads(params.geometry.json())) + area_km2 = area_m2 * 1e-6 + if area_km2 > 30: + raise HTTPException( + status_code=400, + detail=[ + { + "msg": f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : 30 Sq.KM""" + } + ], + ) + params.output_type = "geojson" # always geojson result = RawData(params).extract_plain_geojson() return result diff --git a/API/tasks.py b/API/tasks.py index c6686e98..1554a685 100644 --- a/API/tasks.py +++ b/API/tasks.py @@ -1,11 +1,12 @@ from celery.result import AsyncResult -from fastapi import APIRouter +from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse from fastapi_versioning import version from src.validation.models import SnapshotTaskResponse from .api_worker import celery +from .auth import AuthUser, admin_required, login_required router = APIRouter(prefix="/tasks") @@ -35,3 +36,55 @@ def get_task_status(task_id): "result": task_result.result if task_result.status == "SUCCESS" else None, } return JSONResponse(result) + + +@router.get("/revoke/{task_id}/") +@version(1) +def revoke_task(task_id, user: AuthUser = Depends(login_required)): + """Revokes task , Terminates if it is executing + + Args: + task_id (_type_): task id of raw data task + + Returns: + status: status of revoked task + """ + revoked_task = celery.control.revoke(task_id=task_id, terminate=True) + return JSONResponse({"id": task_id, "status": revoked_task}) + + +@router.get("/inspect/") +@version(1) +def inspect_workers(): + """Inspects tasks assigned to workers + + Returns: + scheduled: All scheduled tasks to be picked up by workers + active: Current Active tasks ongoing on workers + """ + inspected = celery.control.inspect() + return JSONResponse( + {"scheduled": str(inspected.scheduled()), "active": str(inspected.active())} + ) + + +@router.get("/ping/") +@version(1) +def ping_workers(): + """Pings available workers + + Returns: {worker_name : return_result} + """ + inspected_ping = celery.control.inspect().ping() + return JSONResponse(inspected_ping) + + +@router.get("/purge/") +@version(1) +def discard_all_waiting_tasks(user: AuthUser = Depends(admin_required)): + """ + Discards all waiting tasks from the queue + Returns : Number of tasks discarded + """ + purged = celery.control.purge() + return JSONResponse({"tasks_discarded": purged}) diff --git a/README.md b/README.md index 99a80fb1..1b341e9c 100644 --- a/README.md +++ b/README.md @@ -113,11 +113,11 @@ You should be able to start [celery](https://docs.celeryq.dev/en/stable/getting- - Start for default queue ``` - celery --app API.api_worker worker --loglevel=INFO --queues="raw_default" + celery --app API.api_worker worker --loglevel=INFO --queues="raw_default" -n 'default_worker' ``` - Start for recurring queue ``` - celery --app API.api_worker worker --loglevel=INFO --queues="recurring_queue" + celery --app API.api_worker worker --loglevel=INFO --queues="recurring_queue" -n 'recurring_worker' ``` Set no of request that a worker can take at a time by using --concurrency diff --git a/docs/src/installation/configurations.md b/docs/src/installation/configurations.md index 3aa4dd67..a118c42f 100644 --- a/docs/src/installation/configurations.md +++ b/docs/src/installation/configurations.md @@ -34,6 +34,7 @@ The following are the different configuration options that are accepted. | `LOGIN_REDIRECT_URI` | `LOGIN_REDIRECT_URI` | `[OAUTH]` | _none_ | Redirect URL set in the OAuth2 application | REQUIRED | | `APP_SECRET_KEY` | `APP_SECRET_KEY` | `[OAUTH]` | _none_ | High-entropy string generated for the application | REQUIRED | | `OSM_URL` | `OSM_URL` | `[OAUTH]` | `https://www.openstreetmap.org` | OSM instance Base URL | OPTIONAL | +| `ADMIN_IDS` | `ADMIN_IDS` | `[OAUTH]` | `00000` | List of Admin OSMId separated by , | OPTIONAL | | `LOG_LEVEL` | `LOG_LEVEL` | `[API_CONFIG]` | `debug` | Application log level; info,debug,warning,error | OPTIONAL | | `RATE_LIMITER_STORAGE_URI` | `RATE_LIMITER_STORAGE_URI` | `[API_CONFIG]` | `redis://redis:6379` | Redis connection string for rate-limiter data | OPTIONAL | | `RATE_LIMIT_PER_MIN` | `RATE_LIMIT_PER_MIN` | `[API_CONFIG]` | `5` | Number of requests per minute before being rate limited | OPTIONAL | @@ -67,6 +68,7 @@ The following are the different configuration options that are accepted. | `LOGIN_REDIRECT_URI` | TBD | Yes | No | | `APP_SECRET_KEY` | TBD | Yes | No | | `OSM_URL` | TBD | Yes | No | +| `ADMIN_IDS` | TBD | Yes | No | | `LOG_LEVEL` | `[API_CONFIG]` | Yes | Yes | | `RATE_LIMITER_STORAGE_URI` | `[API_CONFIG]` | Yes | No | | `RATE_LIMIT_PER_MIN` | `[API_CONFIG]` | Yes | No | diff --git a/src/app.py b/src/app.py index ba0457dc..1e019432 100644 --- a/src/app.py +++ b/src/app.py @@ -39,6 +39,7 @@ AWS_SECRET_ACCESS_KEY, BUCKET_NAME, ENABLE_TILES, + EXPORT_MAX_AREA_SQKM, ) from src.config import EXPORT_PATH as export_path from src.config import INDEX_THRESHOLD as index_threshold @@ -465,15 +466,21 @@ def get_grid_id(geom, cur): countries = backend_match[0] country_export = True logging.debug(f"Using Country Export Mode with id : {countries[0]}") - else: - if int(geom_area) > int(index_threshold): - # this will be applied only when polygon gets bigger we will be slicing index size to search - country_query = get_country_id_query(geometry_dump) - cur.execute(country_query) - result_country = cur.fetchall() - countries = [int(f[0]) for f in result_country] - logging.debug(f"Intersected Countries : {countries}") - cur.close() + else : + if int(geom_area) > int(EXPORT_MAX_AREA_SQKM): + raise ValueError( + f"""Polygon Area {int(geom_area)} Sq.KM is higher than Threshold : {EXPORT_MAX_AREA_SQKM} Sq.KM""" + ) + + # else: + # if int(geom_area) > int(index_threshold): + # # this will be applied only when polygon gets bigger we will be slicing index size to search + # country_query = get_country_id_query(geometry_dump) + # cur.execute(country_query) + # result_country = cur.fetchall() + # countries = [int(f[0]) for f in result_country] + # logging.debug(f"Intersected Countries : {countries}") + # cur.close() return ( g_id, geometry_dump, @@ -663,7 +670,7 @@ def get_osm_feature(self, osm_id): return FeatureCollection(features=features) def extract_plain_geojson(self): - """Gets geojson for small area : Performs direct query with/without geometry""" + """Gets geojson for small area Returns plain geojson without binding""" extraction_query = raw_currentdata_extraction_query(self.params) features = [] diff --git a/src/config.py b/src/config.py index a7b4b621..290f5d9c 100644 --- a/src/config.py +++ b/src/config.py @@ -76,6 +76,13 @@ "API_CONFIG", "ENABLE_TILES", fallback=None ) +###### + +ADMIN_IDS = os.environ.get("ADMIN_IDS") or config.get( + "OAUTH", "ADMIN_IDS", fallback="00000" +).split(",") + + #################### ### EXPORT_UPLOAD CONFIG BLOCK diff --git a/src/validation/models.py b/src/validation/models.py index 842f4f02..367e8691 100644 --- a/src/validation/models.py +++ b/src/validation/models.py @@ -21,7 +21,6 @@ from enum import Enum from typing import Dict, List, Optional, Union -from area import area from geojson_pydantic import MultiPolygon, Polygon from geojson_pydantic.types import BBox from pydantic import BaseModel as PydanticModel @@ -85,22 +84,6 @@ class JoinFilterType(Enum): AND = "AND" -# -# "tags": { # no of rows returned -# "point" : {"amenity":["shop"]}, -# "line" : {}, -# "polygon" : {"key":["value"]}, -# "all_geometry" : {"building":['yes']} -# }, -# "attributes": { # no of columns / name -# "point": [], column -# "line" : [], -# "polygon" : [], -# "all_geometry" : [], -# } -# } - - class SQLFilter(BaseModel): join_or: Optional[Dict[str, List[str]]] join_and: Optional[Dict[str, List[str]]] @@ -126,6 +109,9 @@ class Filters(BaseModel): class RawDataCurrentParamsBase(BaseModel): + output_type: Optional[RawDataOutputType] = Field( + default=RawDataOutputType.GEOJSON.value, example="geojson" + ) geometry_type: Optional[List[SupportedGeometryFilters]] = Field( default=None, example=["point", "polygon"] ) @@ -160,18 +146,6 @@ class RawDataCurrentParamsBase(BaseModel): }, ) - @validator("geometry", always=True) - def check_geometry_area(cls, value, values): - """Validates geom area_m2""" - area_m2 = area(json.loads(value.json())) - area_km2 = area_m2 * 1e-6 - RAWDATA_CURRENT_POLYGON_AREA = int(EXPORT_MAX_AREA_SQKM) - if area_km2 > 100: # 100 square km - raise ValueError( - f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : {RAWDATA_CURRENT_POLYGON_AREA} Sq.KM for {output_type}""" - ) - return value - @validator("geometry_type", allow_reuse=True) def return_unique_value(cls, value): """return unique list""" @@ -179,9 +153,6 @@ def return_unique_value(cls, value): class RawDataCurrentParams(RawDataCurrentParamsBase): - output_type: Optional[RawDataOutputType] = Field( - default=RawDataOutputType.GEOJSON.value, example="geojson" - ) if ENABLE_TILES: min_zoom: Optional[int] = Field( default=None, description="Only for mbtiles" @@ -206,51 +177,6 @@ def check_bind_option(cls, value, values): ) return value - @validator("geometry", always=True) - def check_geometry_area(cls, value, values): - """Validates geom area_m2""" - area_m2 = area(json.loads(value.json())) - area_km2 = area_m2 * 1e-6 - RAWDATA_CURRENT_POLYGON_AREA = int(EXPORT_MAX_AREA_SQKM) - if area_km2 > RAWDATA_CURRENT_POLYGON_AREA: - raise ValueError( - f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : {RAWDATA_CURRENT_POLYGON_AREA} Sq.KM for {output_type}""" - ) - return value - - -class WhereCondition(TypedDict): - key: str - value: List[str] - - -class OsmFeatureType(Enum): - NODES = "nodes" - WAYS_LINE = "ways_line" - WAYS_POLY = "ways_poly" - RELATIONS = "relations" - - -class SnapshotParamsPlain(BaseModel): - bbox: Optional[ - BBox - ] = None # xmin: NumType, ymin: NumType, xmax: NumType, ymax: NumType , srid:4326 - select: Optional[List[str]] = ["*"] - where: List[WhereCondition] = [{"key": "building", "value": ["*"]}] - join_by: Optional[JoinFilterType] = JoinFilterType.OR.value - look_in: Optional[List[OsmFeatureType]] = ["nodes", "ways_poly"] - geometry_type: SupportedGeometryFilters = None - - @validator("select", always=True) - def validate_select_statement(cls, value, values): - """Validates geom area_m2""" - for v in value: - if v != "*" and len(v) < 2: - raise ValueError( - "length of select attribute must be greater than 2 letters" - ) - return value - class SnapshotResponse(BaseModel): task_id: str