Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Observability 0.4 #76

Open
wants to merge 2 commits into
base: 0.4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ray/deployments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def check_health(self):
def pre(self) -> Graph:
"""Logic to execute before execution."""
graph = self.request.deserialize(self.model)

self.respond(
status=BackendResponseModel.JobStatus.RUNNING,
description="Your job has started running.",
Expand Down Expand Up @@ -293,7 +293,7 @@ def exception(self, exception: Exception) -> None:
sys.tracebacklimit = None
self.respond(
status=BackendResponseModel.JobStatus.NNSIGHT_ERROR,
description="An error has occured during the execution of the intervention graph.",
description=f"An error has occured during the execution of the intervention graph.\n{exception.traceback_content}",
data={
"err_message": exception.message,
"node_id": exception.node_id,
Expand Down
39 changes: 29 additions & 10 deletions schema/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@

import logging
from io import BytesIO
from typing import ClassVar
from typing import ClassVar, TYPE_CHECKING

import torch
from minio import Minio
from pydantic import BaseModel
from typing_extensions import Self
from urllib3.response import HTTPResponse

from nnsight.schema.response import ResponseModel
if TYPE_CHECKING:
from metrics import NDIFGauge
from nnsight.schema.response import ResponseModel
from nnsight.schema.request import RequestModel


class ObjectStorageMixin(BaseModel):
"""
Expand Down Expand Up @@ -146,7 +150,7 @@ class TelemetryMixin:
backend_log(logger: logging.Logger, message: str, level: str = 'info') -> Self:
Logs a message with the specified logging level (info, error, exception).

update_gauge(gauge: NDIFGauge, request: RequestModel, status: ResponseModel.JobStatus, gpu_mem: int = 0) -> Self:
update_gauge(gauge: NDIFGauge) -> Self:
Updates the telemetry gauge to track the status of a request or response.
"""
def backend_log(self, logger: logging.Logger, message: str, level: str = 'info'):
Expand All @@ -158,11 +162,26 @@ def backend_log(self, logger: logging.Logger, message: str, level: str = 'info')
logger.exception(message)
return self

def update_gauge(self, gauge: "NDIFGauge", request: "BackendRequestModel", status: ResponseModel.JobStatus, api_key:str = " ", gpu_mem: int = 0):
gauge.update(
request=request,
api_key=api_key,
status=status,
gpu_mem=gpu_mem
)
def update_gauge(
self,
gauge: "NDIFGauge",
request: "RequestModel",
status: "ResponseModel.JobStatus",
**kwargs,
) -> Self:
""" Updates the telemetry gauge to track the status of a request or response.

Args:

- gauge (NDIFGauge): Telemetry Gauge.
- request (RequestModel): user request.
- status (ResponseModel.JobStatus): status of the user request.
- kwargs: key word arguments to NDIFGauge.update().

Returns:
Self.
"""

gauge.update(request, status, **kwargs)

return self
38 changes: 28 additions & 10 deletions schema/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@

import logging
import uuid
import zlib
from datetime import datetime
from io import BytesIO
from typing import ClassVar, Optional, Union
from typing import TYPE_CHECKING, ClassVar, Optional, Union

import msgspec
import ray
import torch
from fastapi import Request
from pydantic import ConfigDict
from typing_extensions import Self
Expand All @@ -21,8 +17,26 @@
from .mixins import ObjectStorageMixin
from .response import BackendResponseModel

if TYPE_CHECKING:
from .metrics import NDIFGauge


class BackendRequestModel(ObjectStorageMixin):
"""

Attributes:
- model_config: model configuration.
- graph (Union[bytes, ray.ObjectRef]): intervention graph object, could be in multiple forms.
- model_key (str): model key name.
- session_id (Optional[str]): connection session id.
- format (str): format of the request body.
- zlib (bool): is the request body compressed.
- id (str): request id.
- received (datetime.datetime): time of the request being received.
- api_key (str): api key associated with this request.
- _bucket_name (str): request result bucket storage name.
- _file_extension (str): file extension.
"""

model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

Expand All @@ -38,6 +52,8 @@ class BackendRequestModel(ObjectStorageMixin):

id: str
received: datetime

api_key: str


def deserialize(self, model: NNsight) -> Graph:
Expand All @@ -51,7 +67,7 @@ def deserialize(self, model: NNsight) -> Graph:
return RequestModel.deserialize(model, graph, 'json', self.zlib)

@classmethod
async def from_request(cls, request: Request, put: bool = True) -> Self:
async def from_request(cls, request: Request, api_key: str, put: bool = True) -> Self:

headers = request.headers

Expand All @@ -68,20 +84,21 @@ async def from_request(cls, request: Request, put: bool = True) -> Self:
zlib=headers["zlib"],
id=str(uuid.uuid4()),
received=datetime.now(),
api_key=api_key,
)

def create_response(
self,
status: ResponseModel.JobStatus,
logger: logging.Logger,
gauge: "NDIFGauge",
description: str = None,
description: str = "",
data: bytes = None,
gpu_mem: int = 0,
) -> BackendResponseModel:
"""Generates a BackendResponseModel given a change in status to an ongoing request."""

msg = f"{self.id} - {status.name}: {description}"
log_msg = f"{self.id} - {status.name}: {description}"

response = (
BackendResponseModel(
Expand All @@ -94,14 +111,15 @@ def create_response(
)
.backend_log(
logger=logger,
message=msg,
message=log_msg,
)
.update_gauge(
gauge=gauge,
request=self,
status=status,
api_key=" ",
api_key=self.api_key,
gpu_mem=gpu_mem,
msg=description,
)
)

Expand Down
2 changes: 1 addition & 1 deletion services/api/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ dependencies:
- requests
- torch
- python-slugify
- git+https://github.com/ndif-team/nnsight@dev
- git+https://github.com/ndif-team/nnsight@0.4
50 changes: 21 additions & 29 deletions services/api/src/api_key.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import os

import uuid
import firebase_admin
from cachetools import TTLCache, cached
from datetime import datetime
from fastapi import HTTPException, Request, Security
from fastapi.security.api_key import APIKeyHeader
from typing import TYPE_CHECKING
from firebase_admin import credentials, firestore
from starlette.status import HTTP_401_UNAUTHORIZED
from fastapi import HTTPException

from nnsight.schema.response import ResponseModel
from .schema import BackendRequestModel
from .metrics import NDIFGauge
from .util import check_valid_email
from .logging import load_logger

if TYPE_CHECKING:
from fastapi import Request
from .schema import BackendRequestModel

logger = load_logger(service_name='api', logger_name='gunicorn.error')
gauge = NDIFGauge(service='app')

Expand Down Expand Up @@ -55,9 +56,7 @@ def get_uid(self, doc):

api_key_store = ApiKeyStore(FIREBASE_CREDS_PATH)

api_key_header = APIKeyHeader(name="ndif-api-key", auto_error=False)

def extract_request_metadata(raw_request: Request) -> dict:
def extract_request_metadata(raw_request: "Request") -> dict:
"""
Extracts relevant metadata from the incoming raw request, such as IP address,
user agent, and content length, and returns them as a dictionary.
Expand All @@ -69,52 +68,45 @@ def extract_request_metadata(raw_request: Request) -> dict:
}
return metadata

def api_key_auth(
raw_request: "Request",
request: "BackendRequestModel",
) -> None:
"""
Authenticates the API request by extracting metadata and initializing the BackendRequestModel
with relevant information, including API key, client details, and headers.

Args:
- raw_request (Request): user request.
- request (BackendRequestModel): user request object.

async def api_key_auth(
raw_request: Request,
api_key: str = Security(api_key_header)
):
Returns:
"""
Authenticates the API request by extracting metadata and initializing the
BackendRequestModel with relevant information, including API key, client details, and headers.
"""
# Extract metadata from the raw request

metadata = extract_request_metadata(raw_request)

# TODO: Update the RequestModel to include additional fields (e.g. API key)
request = await BackendRequestModel.from_request(raw_request)

gauge.update(request, api_key, ResponseModel.JobStatus.RECEIVED)

ip_address, user_agent, content_length = metadata.values()
gauge.update_network(request.id, ip_address, user_agent, content_length)

if FIREBASE_CREDS_PATH is not None:
check_405b = True if request.model_key == llama_405b else False

doc = api_key_store.fetch_document(api_key)
doc = api_key_store.fetch_document(request.api_key)

# Check if the API key exists and is valid
if api_key_store.does_api_key_exist(doc, check_405b):
# Check if the document contains a valid email
user_id = api_key_store.get_uid(doc)
logger.info(user_id)
if check_valid_email(user_id):
gauge.update(request=request, api_key=api_key, status=ResponseModel.JobStatus.APPROVED, user_id=user_id)
return request
else:
if not check_valid_email(user_id):
# Handle case where API key exists but doesn't contain a valid email
gauge.update(request, api_key, ResponseModel.JobStatus.ERROR)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid API key: A valid API key must contain an email. Please visit https://login.ndif.us/ to create a new one."
)
else:
# Handle case where API key does not exist or is invalid
gauge.update(request, api_key, ResponseModel.JobStatus.ERROR)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Missing or invalid API key. Please visit https://login.ndif.us/ to create a new one."
)
return request
Loading