diff --git a/ray/deployments/base.py b/ray/deployments/base.py
index 8bdca1a..bf6ee9a 100755
--- a/ray/deployments/base.py
+++ b/ray/deployments/base.py
@@ -226,7 +226,7 @@ def check_health(self):
     def pre(self) -> Graph:
         """Logic to execute before execution."""
         graph = self.request.deserialize(self.model)
-
+
         self.respond(
             status=BackendResponseModel.JobStatus.RUNNING,
             description="Your job has started running.",
@@ -293,7 +293,7 @@ def exception(self, exception: Exception) -> None:
         sys.tracebacklimit = None
         self.respond(
             status=BackendResponseModel.JobStatus.NNSIGHT_ERROR,
-            description="An error has occured during the execution of the intervention graph.",
+            description=f"An error has occurred during the execution of the intervention graph.\n{exception.traceback_content}",
             data={
                 "err_message": exception.message,
                 "node_id": exception.node_id,
diff --git a/schema/mixins.py b/schema/mixins.py
index bdc6eed..306f9a4 100644
--- a/schema/mixins.py
+++ b/schema/mixins.py
@@ -2,7 +2,7 @@
 import logging
 from io import BytesIO
-from typing import ClassVar
+from typing import ClassVar, TYPE_CHECKING
 
 import torch
 from minio import Minio
@@ -10,7 +10,11 @@
 from typing_extensions import Self
 from urllib3.response import HTTPResponse
 
-from nnsight.schema.response import ResponseModel
+if TYPE_CHECKING:
+    from metrics import NDIFGauge
+    from nnsight.schema.response import ResponseModel
+    from nnsight.schema.request import RequestModel
+
 
 class ObjectStorageMixin(BaseModel):
     """
@@ -146,7 +150,7 @@ class TelemetryMixin:
         backend_log(logger: logging.Logger, message: str, level: str = 'info') -> Self:
             Logs a message with the specified logging level (info, error, exception).
 
-        update_gauge(gauge: NDIFGauge, request: RequestModel, status: ResponseModel.JobStatus, gpu_mem: int = 0) -> Self:
+        update_gauge(gauge: NDIFGauge, request: RequestModel, status: ResponseModel.JobStatus, **kwargs) -> Self:
             Updates the telemetry gauge to track the status of a request or response.
     """
     def backend_log(self, logger: logging.Logger, message: str, level: str = 'info'):
@@ -158,11 +162,26 @@ def backend_log(self, logger: logging.Logger, message: str, level: str = 'info'):
             logger.exception(message)
         return self
 
-    def update_gauge(self, gauge: "NDIFGauge", request: "BackendRequestModel", status: ResponseModel.JobStatus, api_key:str = " ", gpu_mem: int = 0):
-        gauge.update(
-            request=request,
-            api_key=api_key,
-            status=status,
-            gpu_mem=gpu_mem
-        )
+    def update_gauge(
+        self,
+        gauge: "NDIFGauge",
+        request: "RequestModel",
+        status: "ResponseModel.JobStatus",
+        **kwargs,
+    ) -> Self:
+        """Updates the telemetry gauge to track the status of a request or response.
+
+        Args:
+            - gauge (NDIFGauge): telemetry gauge.
+            - request (RequestModel): user request.
+            - status (ResponseModel.JobStatus): status of the user request.
+            - kwargs: keyword arguments forwarded to NDIFGauge.update().
+
+        Returns:
+            Self.
+ """ + + gauge.update(request, status, **kwargs) + return self \ No newline at end of file diff --git a/schema/request.py b/schema/request.py index f85fab7..bf00fda 100644 --- a/schema/request.py +++ b/schema/request.py @@ -2,14 +2,10 @@ import logging import uuid -import zlib from datetime import datetime -from io import BytesIO -from typing import ClassVar, Optional, Union +from typing import TYPE_CHECKING, ClassVar, Optional, Union -import msgspec import ray -import torch from fastapi import Request from pydantic import ConfigDict from typing_extensions import Self @@ -21,8 +17,26 @@ from .mixins import ObjectStorageMixin from .response import BackendResponseModel +if TYPE_CHECKING: + from .metrics import NDIFGauge + class BackendRequestModel(ObjectStorageMixin): + """ + + Attributes: + - model_config: model configuration. + - graph (Union[bytes, ray.ObjectRef]): intervention graph object, could be in multiple forms. + - model_key (str): model key name. + - session_id (Optional[str]): connection session id. + - format (str): format of the request body. + - zlib (bool): is the request body compressed. + - id (str): request id. + - received (datetime.datetime): time of the request being received. + - api_key (str): api key associated with this request. + - _bucket_name (str): request result bucket storage name. + - _file_extension (str): file extension. + """ model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) @@ -38,6 +52,8 @@ class BackendRequestModel(ObjectStorageMixin): id: str received: datetime + + api_key: str def deserialize(self, model: NNsight) -> Graph: @@ -51,7 +67,7 @@ def deserialize(self, model: NNsight) -> Graph: return RequestModel.deserialize(model, graph, 'json', self.zlib) @classmethod - async def from_request(cls, request: Request, put: bool = True) -> Self: + async def from_request(cls, request: Request, api_key: str, put: bool = True) -> Self: headers = request.headers @@ -68,6 +84,7 @@ async def from_request(cls, request: Request, put: bool = True) -> Self: zlib=headers["zlib"], id=str(uuid.uuid4()), received=datetime.now(), + api_key=api_key, ) def create_response( @@ -75,13 +92,13 @@ def create_response( status: ResponseModel.JobStatus, logger: logging.Logger, gauge: "NDIFGauge", - description: str = None, + description: str = "", data: bytes = None, gpu_mem: int = 0, ) -> BackendResponseModel: """Generates a BackendResponseModel given a change in status to an ongoing request.""" - msg = f"{self.id} - {status.name}: {description}" + log_msg = f"{self.id} - {status.name}: {description}" response = ( BackendResponseModel( @@ -94,14 +111,15 @@ def create_response( ) .backend_log( logger=logger, - message=msg, + message=log_msg, ) .update_gauge( gauge=gauge, request=self, status=status, - api_key=" ", + api_key=self.api_key, gpu_mem=gpu_mem, + msg=description, ) ) diff --git a/services/api/environment.yml b/services/api/environment.yml index 698203e..b884e4a 100755 --- a/services/api/environment.yml +++ b/services/api/environment.yml @@ -34,4 +34,4 @@ dependencies: - requests - torch - python-slugify - - git+https://github.com/ndif-team/nnsight@dev \ No newline at end of file + - git+https://github.com/ndif-team/nnsight@0.4 \ No newline at end of file diff --git a/services/api/src/api_key.py b/services/api/src/api_key.py index 8894495..bc6a0d6 100644 --- a/services/api/src/api_key.py +++ b/services/api/src/api_key.py @@ -1,20 +1,21 @@ import os -import uuid import firebase_admin from cachetools import TTLCache, cached -from datetime 
-from fastapi import HTTPException, Request, Security
-from fastapi.security.api_key import APIKeyHeader
+from typing import TYPE_CHECKING
 from firebase_admin import credentials, firestore
 from starlette.status import HTTP_401_UNAUTHORIZED
+from fastapi import HTTPException
 
-from nnsight.schema.response import ResponseModel
 from .schema import BackendRequestModel
 from .metrics import NDIFGauge
 from .util import check_valid_email
 from .logging import load_logger
 
+if TYPE_CHECKING:
+    from fastapi import Request
+    from .schema import BackendRequestModel
+
 logger = load_logger(service_name='api', logger_name='gunicorn.error')
 gauge = NDIFGauge(service='app')
@@ -55,9 +56,7 @@ def get_uid(self, doc):
 
 api_key_store = ApiKeyStore(FIREBASE_CREDS_PATH)
 
-api_key_header = APIKeyHeader(name="ndif-api-key", auto_error=False)
-
-def extract_request_metadata(raw_request: Request) -> dict:
+def extract_request_metadata(raw_request: "Request") -> dict:
     """
     Extracts relevant metadata from the incoming raw request, such as IP address,
     user agent, and content length, and returns them as a dictionary.
@@ -69,52 +68,45 @@ def extract_request_metadata(raw_request: Request) -> dict:
     }
     return metadata
 
+def api_key_auth(
+    raw_request: "Request",
+    request: "BackendRequestModel",
+) -> None:
+    """
+    Authenticates the API request by validating its API key against the key store
+    and recording network metadata for telemetry.
+
+    Args:
+        - raw_request (Request): user request.
+        - request (BackendRequestModel): user request object.
 
-async def api_key_auth(
-    raw_request: Request,
-    api_key: str = Security(api_key_header)
-):
+    Raises:
+        HTTPException: if the API key is missing or invalid.
     """
-    Authenticates the API request by extracting metadata and initializing the
-    BackendRequestModel with relevant information, including API key, client details, and headers.
-    """
-    # Extract metadata from the raw request
+
     metadata = extract_request_metadata(raw_request)
 
-    # TODO: Update the RequestModel to include additional fields (e.g. API key)
-    request = await BackendRequestModel.from_request(raw_request)
-
-    gauge.update(request, api_key, ResponseModel.JobStatus.RECEIVED)
-
     ip_address, user_agent, content_length = metadata.values()
     gauge.update_network(request.id, ip_address, user_agent, content_length)
 
     if FIREBASE_CREDS_PATH is not None:
         check_405b = True if request.model_key == llama_405b else False
-        doc = api_key_store.fetch_document(api_key)
+        doc = api_key_store.fetch_document(request.api_key)
 
         # Check if the API key exists and is valid
         if api_key_store.does_api_key_exist(doc, check_405b):
             # Check if the document contains a valid email
             user_id = api_key_store.get_uid(doc)
             logger.info(user_id)
-            if check_valid_email(user_id):
-                gauge.update(request=request, api_key=api_key, status=ResponseModel.JobStatus.APPROVED, user_id=user_id)
-                return request
-            else:
+            if not check_valid_email(user_id):
                 # Handle case where API key exists but doesn't contain a valid email
-                gauge.update(request, api_key, ResponseModel.JobStatus.ERROR)
                 raise HTTPException(
                     status_code=HTTP_401_UNAUTHORIZED,
                     detail="Invalid API key: A valid API key must contain an email. Please visit https://login.ndif.us/ to create a new one."
                 )
         else:
             # Handle case where API key does not exist or is invalid
-            gauge.update(request, api_key, ResponseModel.JobStatus.ERROR)
             raise HTTPException(
                 status_code=HTTP_401_UNAUTHORIZED,
                 detail="Missing or invalid API key. Please visit https://login.ndif.us/ to create a new one."
             )
 
-    return request
diff --git a/services/api/src/app.py b/services/api/src/app.py
index 9f6fdfa..47a63a1 100755
--- a/services/api/src/app.py
+++ b/services/api/src/app.py
@@ -1,12 +1,13 @@
 import asyncio
 import os
 from contextlib import asynccontextmanager
-from typing import Annotated, Any, Dict
+from typing import Any, Dict
 
 import ray
 import socketio
 import uvicorn
-from fastapi import Depends, FastAPI, Path, Request
+from fastapi import FastAPI, Security, Request
+from fastapi.security.api_key import APIKeyHeader
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi_cache import FastAPICache
@@ -17,6 +18,7 @@
 from prometheus_fastapi_instrumentator import Instrumentator
 from ray import serve
 from slugify import slugify
+from datetime import datetime
 
 from nnsight.schema.request import StreamValueModel
 from nnsight.schema.response import ResponseModel
@@ -76,26 +78,46 @@ async def lifespan(app: FastAPI):
 # Prometheus instrumentation (for metrics)
 Instrumentator().instrument(app).expose(app)
 
+api_key_header = APIKeyHeader(name="ndif-api-key", auto_error=False)
 
 @app.post("/request")
-async def request(
-    request: BackendRequestModel = Depends(api_key_auth),
-) -> BackendResponseModel:
+async def request(raw_request: Request, api_key: str = Security(api_key_header)) -> BackendResponseModel:
     """Endpoint to submit request.
 
-    Args:
-        request (BackendRequestModel): _description_
+    Header:
+        - api_key: user API key.
+
+    Request Body:
+        raw_request (Request): user request containing the intervention graph.
 
     Returns:
-        BackendResponseModel: _description_
+        BackendResponseModel: response to the user request.
     """
 
-    try:
-        # Send to request workers waiting to process requests on the "request" queue.
-        # Forget as we don't care about the response.
-        serve.get_app_handle("Request").remote(request)
 
-        # Create response object.
-        # Log and save to data backend.
+    # Extract the request data.
+    try:
+        request: BackendRequestModel = await BackendRequestModel.from_request(raw_request, api_key)
+    except Exception as e:
+        logger.error(f"{ResponseModel.JobStatus.ERROR.name}: {str(e)}")
+        labels = {
+            "request_id": "",
+            "api_key": api_key,
+            "model_key": "",
+            "gpu_mem": "",
+            "timestamp": str(
+                datetime.now()
+            ),  # Ensure timestamp is string for consistency
+            "user_id": "",
+            "msg": str(e),
+        }
+
+        gauge._gauge.labels(**labels).set(gauge.NumericJobStatus[ResponseModel.JobStatus.ERROR].value)
+
+        raise e
+
+    # Process the request.
+    try:
+        response = request.create_response(
             status=ResponseModel.JobStatus.RECEIVED,
             description="Your job has been received and is waiting approval.",
@@ -103,11 +125,17 @@ async def request(
             gauge=gauge,
         )
 
+        # Authenticate the API key.
+        api_key_auth(raw_request, request)
+
+        # Send to request workers waiting to process requests on the "request" queue.
+        # Forget as we don't care about the response.
+        serve.get_app_handle("Request").remote(request)
+
         # Back up request object by default (to be deleted on successful completion)
         # request = request.model_copy()
         # request.object = object
         # request.save(object_store)
-
     except Exception as exception:
         # Create exception response object.
diff --git a/telemetry/grafana/dashboards/telemetry/request-table.json b/telemetry/grafana/dashboards/telemetry/request-table.json
index f738d8d..89d2f03 100644
--- a/telemetry/grafana/dashboards/telemetry/request-table.json
+++ b/telemetry/grafana/dashboards/telemetry/request-table.json
@@ -23,7 +23,7 @@
       {
         "datasource": {
           "type": "prometheus",
-          "uid": "cdwtb84j7ug3ka"
+          "uid": "PBFA97CFB590B2093"
         },
         "description": "",
         "fieldConfig": {
@@ -176,7 +176,7 @@
       {
         "datasource": {
           "type": "prometheus",
-          "uid": "cdwtb84j7ug3ka"
+          "uid": "PBFA97CFB590B2093"
         },
         "disableTextWrap": false,
         "editorMode": "code",
diff --git a/telemetry/grafana/provisioning/datasources/prometheus.yml b/telemetry/grafana/provisioning/datasources/prometheus.yml
index 4aa9fb5..8ef9169 100644
--- a/telemetry/grafana/provisioning/datasources/prometheus.yml
+++ b/telemetry/grafana/provisioning/datasources/prometheus.yml
@@ -5,5 +5,3 @@ datasources:
     type: prometheus
     access: proxy
     url: http://localhost:9090
-    isDefault: true
-    editable: true
diff --git a/telemetry/metrics/gauge.py b/telemetry/metrics/gauge.py
index 9eb593e..bb9c614 100644
--- a/telemetry/metrics/gauge.py
+++ b/telemetry/metrics/gauge.py
@@ -14,6 +14,7 @@
     "gpu_mem",
     "timestamp",
     "user_id",
+    "msg",
 )
 
 network_labels = ("request_id", "ip_address", "user_agent")
@@ -82,14 +83,25 @@ def _initialize_network_gauge(self):
     def update(
         self,
         request: RequestModel,
-        api_key: str,
         status: ResponseModel.JobStatus,
-        user_id=None,
+        api_key: str = "",
+        user_id: str = None,
         gpu_mem: int = 0,
+        msg: str = "",
     ) -> None:
         """
         Update the values of the gauge to reflect the current status of a request.
         Handles both Ray and Prometheus Gauge APIs.
+
+        Args:
+            - request (RequestModel): request object.
+            - status (ResponseModel.JobStatus): user request job status.
+            - api_key (str): user API key.
+            - user_id (str): identifier of the user associated with the API key, if known.
+            - gpu_mem (int): GPU memory utilization.
+            - msg (str): description of the current job status of the request.
         """
 
         numeric_status = int(self.NumericJobStatus[status.value].value)
@@ -102,6 +114,7 @@ def update(
                 request.received
             ),  # Ensure timestamp is string for consistency
             "user_id": str(user_id) if user_id is not None else " ",
+            "msg": msg,
         }
 
         if self.service == "ray":
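For reviewers, here is the net request flow produced by the app.py and api_key.py hunks above, condensed into one sketch. It only restates code from this patch; the trailing return statement is an assumption inferred from the endpoint's return annotation and is not shown in the diff.

# Illustrative sketch, not part of the patch. All names come from
# services/api/src/app.py and services/api/src/api_key.py above.
api_key_header = APIKeyHeader(name="ndif-api-key", auto_error=False)

@app.post("/request")
async def request(raw_request: Request, api_key: str = Security(api_key_header)) -> BackendResponseModel:
    # 1. Build the request model; the API key now travels on the model itself.
    request = await BackendRequestModel.from_request(raw_request, api_key)

    # 2. Log RECEIVED before authentication so every request reaches telemetry;
    #    create_response() forwards the description as the new "msg" gauge label.
    response = request.create_response(
        status=ResponseModel.JobStatus.RECEIVED,
        description="Your job has been received and is waiting approval.",
        logger=logger,
        gauge=gauge,
    )

    # 3. Authenticate. api_key_auth() no longer returns the request; it reads
    #    request.api_key and raises HTTPException(401) on a missing or invalid key.
    api_key_auth(raw_request, request)

    # 4. Fire-and-forget hand-off to the Ray Serve "Request" deployment.
    serve.get_app_handle("Request").remote(request)

    # Assumed: the RECEIVED response object is returned to the client.
    return response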