Skip to content

Commit

Permalink
Merge pull request #167 from hotosm/feature/terminate_tasks
Browse files Browse the repository at this point in the history
Feature Terminate Tasks  &  Authentication
  • Loading branch information
kshitijrajsharma authored Nov 15, 2023
2 parents c93fde1 + 9fe3213 commit 2d9e34d
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 102 deletions.
2 changes: 1 addition & 1 deletion API/api_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from src.query_builder.builder import format_file_name_str
from src.validation.models import RawDataOutputType

celery = Celery(__name__)
celery = Celery("Raw Data API")
celery.conf.broker_url = celery_broker_uri
celery.conf.result_backend = celery_backend
celery.conf.task_serializer = "pickle"
Expand Down
41 changes: 37 additions & 4 deletions API/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,53 @@
from enum import Enum
from typing import Union

from fastapi import Header
from fastapi import Depends, Header, HTTPException
from osm_login_python.core import Auth
from pydantic import BaseModel
from pydantic import BaseModel, Field

from src.config import get_oauth_credentials
from src.config import ADMIN_IDS, get_oauth_credentials


class UserRole(Enum):
ADMIN = 1
STAFF = 2
GUEST = 3


class AuthUser(BaseModel):
id: int
username: str
img_url: Union[str, None]
role: UserRole = Field(default=UserRole.GUEST.value)


osm_auth = Auth(*get_oauth_credentials())


def is_admin(osm_id: int):
admin_ids = [int(admin_id) for admin_id in ADMIN_IDS]
return osm_id in admin_ids


def login_required(access_token: str = Header(...)):
return osm_auth.deserialize_access_token(access_token)
user = AuthUser(**osm_auth.deserialize_access_token(access_token))
if is_admin(user.id):
user.role = UserRole.ADMIN
return user


def get_optional_user(access_token: str = Header(default=None)) -> AuthUser:
if access_token:
user = AuthUser(**osm_auth.deserialize_access_token(access_token))
if is_admin(user.id):
user.role = UserRole.ADMIN
return user
else:
# If no token provided, return a user with limited options or guest user
return AuthUser(id=0, username="guest", img_url=None)


def admin_required(user: AuthUser = Depends(login_required)):
if not is_admin(user.id):
raise HTTPException(status_code=403, detail="User is not an admin")
return user
2 changes: 1 addition & 1 deletion API/auth/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import APIRouter, Depends, Request

from . import AuthUser, login_required, osm_auth
from . import AuthUser, admin_required, login_required, osm_auth

router = APIRouter(prefix="/auth")

Expand Down
11 changes: 9 additions & 2 deletions API/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from src.config import logger as logging
from src.db_session import database_instance

from .auth.routers import router as auth_router
from .raw_data import router as raw_data_router
from .tasks import router as tasks_router

Expand All @@ -58,10 +59,16 @@
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"

app = FastAPI(title="Raw Data API ")
# app.include_router(auth_router)
app.include_router(auth_router)
app.include_router(raw_data_router)
app.include_router(tasks_router)

app.openapi = {
"info": {
"title": "Raw Data API",
"version": "1.0",
},
"security": [{"OAuth2PasswordBearer": []}],
}

app = VersionedFastAPI(
app, enable_latest=False, version_format="{major}", prefix_format="/v{major}"
Expand Down
54 changes: 50 additions & 4 deletions API/raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@

"""[Router Responsible for Raw data API ]
"""
import json
import os
import shutil
import time

import requests
from fastapi import APIRouter, Body, Request
from area import area
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi_versioning import version
from geojson import FeatureCollection

from src.app import RawData
from src.config import ALLOW_BIND_ZIP_FILTER, EXPORT_MAX_AREA_SQKM
from src.config import LIMITER as limiter
from src.config import RATE_LIMIT_PER_MIN as export_rate_limit
from src.config import logger as logging
Expand All @@ -41,6 +44,7 @@
)

from .api_worker import process_raw_data
from .auth import AuthUser, UserRole, get_optional_user

router = APIRouter(prefix="")

Expand Down Expand Up @@ -421,6 +425,7 @@ def get_osm_current_snapshot_as_file(
},
},
),
user: AuthUser = Depends(get_optional_user),
):
"""Generates the current raw OpenStreetMap data available on database based on the input geometry, query and spatial features.
Expand All @@ -434,17 +439,45 @@ def get_osm_current_snapshot_as_file(
2. Now navigate to /tasks/ with your task id to track progress and result
"""

if not (user.role == UserRole.STAFF or user.role == UserRole.ADMIN):
area_m2 = area(json.loads(params.geometry.json()))
area_km2 = area_m2 * 1e-6
RAWDATA_CURRENT_POLYGON_AREA = int(EXPORT_MAX_AREA_SQKM)
if area_km2 > RAWDATA_CURRENT_POLYGON_AREA:
raise HTTPException(
status_code=400,
detail=[
{
"msg": f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : {RAWDATA_CURRENT_POLYGON_AREA} Sq.KM"""
}
],
)
if not params.uuid:
raise HTTPException(
status_code=403,
detail=[{"msg": "Insufficient Permission for uuid = False"}],
)
if ALLOW_BIND_ZIP_FILTER:
if not params.bind_zip:
raise HTTPException(
status_code=403,
detail=[{"msg": "Insufficient Permission for bind_zip"}],
)

queue_name = "recurring_queue" if not params.uuid else "raw_default"
task = process_raw_data.apply_async(args=(params,), queue=queue_name)
return JSONResponse({"task_id": task.id, "track_link": f"/tasks/status/{task.id}/"})


@router.get("/snapshot/plain/", response_model=FeatureCollection)
@router.post("/snapshot/plain/", response_model=FeatureCollection)
@version(1)
def get_osm_current_snapshot_as_plain_geojson(
request: Request, params: RawDataCurrentParamsBase
request: Request,
params: RawDataCurrentParamsBase,
user: AuthUser = Depends(get_optional_user),
):
"""Generates the Plain geojson for the polygon within 100 Sqkm and returns the result right away
"""Generates the Plain geojson for the polygon within 30 Sqkm and returns the result right away
Args:
request (Request): _description_
Expand All @@ -453,6 +486,19 @@ def get_osm_current_snapshot_as_plain_geojson(
Returns:
Featurecollection: Geojson
"""
if not (user.role == UserRole.STAFF or user.role == UserRole.ADMIN):
area_m2 = area(json.loads(params.geometry.json()))
area_km2 = area_m2 * 1e-6
if area_km2 > 30:
raise HTTPException(
status_code=400,
detail=[
{
"msg": f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : 30 Sq.KM"""
}
],
)
params.output_type = "geojson" # always geojson
result = RawData(params).extract_plain_geojson()
return result

Expand Down
55 changes: 54 additions & 1 deletion API/tasks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from celery.result import AsyncResult
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from fastapi_versioning import version

from src.validation.models import SnapshotTaskResponse

from .api_worker import celery
from .auth import AuthUser, admin_required, login_required

router = APIRouter(prefix="/tasks")

Expand Down Expand Up @@ -35,3 +36,55 @@ def get_task_status(task_id):
"result": task_result.result if task_result.status == "SUCCESS" else None,
}
return JSONResponse(result)


@router.get("/revoke/{task_id}/")
@version(1)
def revoke_task(task_id, user: AuthUser = Depends(login_required)):
"""Revokes task , Terminates if it is executing
Args:
task_id (_type_): task id of raw data task
Returns:
status: status of revoked task
"""
revoked_task = celery.control.revoke(task_id=task_id, terminate=True)
return JSONResponse({"id": task_id, "status": revoked_task})


@router.get("/inspect/")
@version(1)
def inspect_workers():
"""Inspects tasks assigned to workers
Returns:
scheduled: All scheduled tasks to be picked up by workers
active: Current Active tasks ongoing on workers
"""
inspected = celery.control.inspect()
return JSONResponse(
{"scheduled": str(inspected.scheduled()), "active": str(inspected.active())}
)


@router.get("/ping/")
@version(1)
def ping_workers():
"""Pings available workers
Returns: {worker_name : return_result}
"""
inspected_ping = celery.control.inspect().ping()
return JSONResponse(inspected_ping)


@router.get("/purge/")
@version(1)
def discard_all_waiting_tasks(user: AuthUser = Depends(admin_required)):
"""
Discards all waiting tasks from the queue
Returns : Number of tasks discarded
"""
purged = celery.control.purge()
return JSONResponse({"tasks_discarded": purged})
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,11 @@ You should be able to start [celery](https://docs.celeryq.dev/en/stable/getting-

- Start for default queue
```
celery --app API.api_worker worker --loglevel=INFO --queues="raw_default"
celery --app API.api_worker worker --loglevel=INFO --queues="raw_default" -n 'default_worker'
```
- Start for recurring queue
```
celery --app API.api_worker worker --loglevel=INFO --queues="recurring_queue"
celery --app API.api_worker worker --loglevel=INFO --queues="recurring_queue" -n 'recurring_worker'
```

Set no of request that a worker can take at a time by using --concurrency
Expand Down
2 changes: 2 additions & 0 deletions docs/src/installation/configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The following are the different configuration options that are accepted.
| `LOGIN_REDIRECT_URI` | `LOGIN_REDIRECT_URI` | `[OAUTH]` | _none_ | Redirect URL set in the OAuth2 application | REQUIRED |
| `APP_SECRET_KEY` | `APP_SECRET_KEY` | `[OAUTH]` | _none_ | High-entropy string generated for the application | REQUIRED |
| `OSM_URL` | `OSM_URL` | `[OAUTH]` | `https://www.openstreetmap.org` | OSM instance Base URL | OPTIONAL |
| `ADMIN_IDS` | `ADMIN_IDS` | `[OAUTH]` | `00000` | List of Admin OSMId separated by , | OPTIONAL |
| `LOG_LEVEL` | `LOG_LEVEL` | `[API_CONFIG]` | `debug` | Application log level; info,debug,warning,error | OPTIONAL |
| `RATE_LIMITER_STORAGE_URI` | `RATE_LIMITER_STORAGE_URI` | `[API_CONFIG]` | `redis://redis:6379` | Redis connection string for rate-limiter data | OPTIONAL |
| `RATE_LIMIT_PER_MIN` | `RATE_LIMIT_PER_MIN` | `[API_CONFIG]` | `5` | Number of requests per minute before being rate limited | OPTIONAL |
Expand Down Expand Up @@ -67,6 +68,7 @@ The following are the different configuration options that are accepted.
| `LOGIN_REDIRECT_URI` | TBD | Yes | No |
| `APP_SECRET_KEY` | TBD | Yes | No |
| `OSM_URL` | TBD | Yes | No |
| `ADMIN_IDS` | TBD | Yes | No |
| `LOG_LEVEL` | `[API_CONFIG]` | Yes | Yes |
| `RATE_LIMITER_STORAGE_URI` | `[API_CONFIG]` | Yes | No |
| `RATE_LIMIT_PER_MIN` | `[API_CONFIG]` | Yes | No |
Expand Down
27 changes: 17 additions & 10 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
AWS_SECRET_ACCESS_KEY,
BUCKET_NAME,
ENABLE_TILES,
EXPORT_MAX_AREA_SQKM,
)
from src.config import EXPORT_PATH as export_path
from src.config import INDEX_THRESHOLD as index_threshold
Expand Down Expand Up @@ -465,15 +466,21 @@ def get_grid_id(geom, cur):
countries = backend_match[0]
country_export = True
logging.debug(f"Using Country Export Mode with id : {countries[0]}")
else:
if int(geom_area) > int(index_threshold):
# this will be applied only when polygon gets bigger we will be slicing index size to search
country_query = get_country_id_query(geometry_dump)
cur.execute(country_query)
result_country = cur.fetchall()
countries = [int(f[0]) for f in result_country]
logging.debug(f"Intersected Countries : {countries}")
cur.close()
else :
if int(geom_area) > int(EXPORT_MAX_AREA_SQKM):
raise ValueError(
f"""Polygon Area {int(geom_area)} Sq.KM is higher than Threshold : {EXPORT_MAX_AREA_SQKM} Sq.KM"""
)

# else:
# if int(geom_area) > int(index_threshold):
# # this will be applied only when polygon gets bigger we will be slicing index size to search
# country_query = get_country_id_query(geometry_dump)
# cur.execute(country_query)
# result_country = cur.fetchall()
# countries = [int(f[0]) for f in result_country]
# logging.debug(f"Intersected Countries : {countries}")
# cur.close()
return (
g_id,
geometry_dump,
Expand Down Expand Up @@ -663,7 +670,7 @@ def get_osm_feature(self, osm_id):
return FeatureCollection(features=features)

def extract_plain_geojson(self):
"""Gets geojson for small area : Performs direct query with/without geometry"""
"""Gets geojson for small area Returns plain geojson without binding"""
extraction_query = raw_currentdata_extraction_query(self.params)
features = []

Expand Down
7 changes: 7 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@
"API_CONFIG", "ENABLE_TILES", fallback=None
)

######

ADMIN_IDS = os.environ.get("ADMIN_IDS") or config.get(
"OAUTH", "ADMIN_IDS", fallback="00000"
).split(",")


####################

### EXPORT_UPLOAD CONFIG BLOCK
Expand Down
Loading

0 comments on commit 2d9e34d

Please sign in to comment.