Skip to content
This repository has been archived by the owner on Jun 13, 2023. It is now read-only.

Commit

Permalink
feat(fastapi): support async endpoint handlers (#416)
Browse files Browse the repository at this point in the history
  • Loading branch information
sagivr2020 authored Jul 20, 2022
1 parent 58a6168 commit 5fb0295
Show file tree
Hide file tree
Showing 6 changed files with 358 additions and 18 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,8 @@ Advanced options can be configured as a parameter to the init() method or as env
|- |EPSAGON_LAMBDA_TIMEOUT_THRESHOLD_MS |Integer|`200` |The threshold in millieseconds to send the trace before a Lambda timeout occurs |
|- |EPSAGON_PAYLOADS_TO_IGNORE |List |- |Array of dictionaries to not instrument. Example: `'[{"source": "serverless-plugin-warmup"}]'` |
|- |EPSAGON_REMOVE_EXCEPTION_FRAMES|Boolean|`False` |Disable the automatic capture of exception frames data (Python 3) |
|- |EPSAGON_FASTAPI_ASYNC_MODE|Boolean|`False` |Enable capturing of Fast API async endpoint handlers calls(Python 3) |




Expand Down
87 changes: 86 additions & 1 deletion epsagon/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from epsagon.trace_transports import NoneTransport, HTTPTransport, LogTransport
from .constants import (
TIMEOUT_GRACE_TIME_MS,
EPSAGON_MARKER,
MAX_LABEL_SIZE,
DEFAULT_SAMPLE_RATE,
TRACE_URL_PREFIX,
Expand All @@ -35,6 +36,8 @@
DEFAULT_MAX_TRACE_SIZE_BYTES = 64 * (2 ** 10)
MAX_METADATA_FIELD_SIZE_LIMIT = 1024 * 3
FAILED_TO_SERIALIZE_MESSAGE = 'Failed to serialize returned object to JSON'
# check if python version is 3.7 and above
IS_PY_VERSION_ABOVE_3_6 = sys.version_info[0] == 3 and sys.version_info[1] > 6


# pylint: disable=invalid-name
Expand Down Expand Up @@ -95,6 +98,7 @@ def __init__(self):
self.keys_to_ignore = None
self.keys_to_allow = None
self.use_single_trace = True
self.use_async_tracer = False
self.singleton_trace = None
self.local_thread_to_unique_id = {}
self.transport = NoneTransport()
Expand Down Expand Up @@ -200,11 +204,25 @@ def update_tracers(self):
tracer.step_dict_output_path = self.step_dict_output_path
tracer.sample_rate = self.sample_rate

def switch_to_async_tracer(self):
"""
Set the use_async_tracer flag to True.
:return: None
"""
self.use_async_tracer = True

def is_async_tracer(self):
"""
Returns whether using an async tracer
"""
return self.use_async_tracer

def switch_to_multiple_traces(self):
"""
Set the use_single_trace flag to False.
:return: None
"""
self.use_async_tracer = False
self.use_single_trace = False

def _create_new_trace(self, unique_id=None):
Expand Down Expand Up @@ -233,6 +251,58 @@ def _create_new_trace(self, unique_id=None):
unique_id=unique_id,
)

@staticmethod
def _get_current_task():
"""
Gets the current asyncio task safely
:return: The task.
"""
# Dynamic import since this is only valid in Python3+
asyncio = __import__('asyncio')

#check if python version 3.7 and above
if IS_PY_VERSION_ABOVE_3_6:
get_event_loop = asyncio.get_event_loop
get_current_task = asyncio.current_task
else:
get_event_loop = asyncio.events._get_running_loop # pylint: disable=W0212
get_current_task = asyncio.events._get_running_loop # pylint: disable=W0212
try:
if not get_event_loop():
return None
return get_current_task()
except Exception: # pylint: disable=broad-except
return None

def _get_tracer_async_mode(self, should_create):
"""
Get trace assuming async tracer.
:return: The trace.
"""
task = type(self)._get_current_task()
if not task:
return None

trace = getattr(task, EPSAGON_MARKER, None)
if not trace and should_create:
trace = self._create_new_trace()
setattr(task, EPSAGON_MARKER, trace)
return trace

def _pop_trace_async_mode(self):
"""
Pops the trace from the current task, assuming async tracer
:return: The trace.
"""
task = type(self)._get_current_task()
if not task:
return None

trace = getattr(task, EPSAGON_MARKER, None)
if trace: # can safely remove tracer from async task
delattr(task, EPSAGON_MARKER)
return trace

def get_or_create_trace(self, unique_id=None):
"""
Gets or create a trace - thread-safe
Expand Down Expand Up @@ -267,6 +337,9 @@ def _get_trace(self, unique_id=None, should_create=False):
:return: The trace.
"""
with TraceFactory.LOCK:
if self.use_async_tracer:
return self._get_tracer_async_mode(should_create=should_create)

unique_id = self.get_thread_local_unique_id(unique_id)
if unique_id:
trace = (
Expand Down Expand Up @@ -321,6 +394,8 @@ def pop_trace(self, trace=None):
:return: unique id
"""
with self.LOCK:
if self.use_async_tracer:
return self._pop_trace_async_mode()
if self.traces:
trace = self.traces.pop(self.get_trace_identifier(trace), None)
if not self.traces:
Expand All @@ -338,6 +413,11 @@ def get_thread_local_unique_id(self, unique_id=None):
:param unique_id: input unique id
:return: active id if there's an active unique id or given one
"""
if self.is_async_tracer():
return self.local_thread_to_unique_id.get(
type(self)._get_current_task(), unique_id
)

return self.local_thread_to_unique_id.get(
get_thread_id(), unique_id
)
Expand All @@ -353,7 +433,12 @@ def set_thread_local_unique_id(self, unique_id=None):
self.singleton_trace.unique_id if self.singleton_trace else None
)
)
self.local_thread_to_unique_id[get_thread_id()] = unique_id

if self.is_async_tracer():
self.local_thread_to_unique_id[
type(self)._get_current_task()] = unique_id
else:
self.local_thread_to_unique_id[get_thread_id()] = unique_id
return unique_id

def unset_thread_local_unique_id(self):
Expand Down
105 changes: 98 additions & 7 deletions epsagon/wrappers/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import json.decoder
import asyncio
import os

import warnings
from fastapi import Request, Response
Expand All @@ -29,9 +30,17 @@
SCOPE_UNIQUE_ID = 'trace_unique_id'
SCOPE_CONTAINER_METADATA_COLLECTED = 'container_metadata'
SCOPE_IGNORE_REQUEST = 'ignore_request'
IS_ASYNC_MODE = False

def _initialize_async_mode(mode):
global IS_ASYNC_MODE # pylint: disable=global-statement
IS_ASYNC_MODE = mode

_initialize_async_mode(os.getenv(
'EPSAGON_FASTAPI_ASYNC_MODE', 'FALSE') == 'TRUE')

def _handle_wrapper_params(_args, kwargs, original_request_param_name):
"""
"""f
Handles the sync/async given parameters - gets the request object
If original handler is set to get the Request object, then getting the
request using this param. Otherwise, trying to get the Request object using
Expand Down Expand Up @@ -222,6 +231,71 @@ def _fastapi_handler(
raised_err
)

async def _async_fastapi_handler(
original_handler,
request,
status_code,
args,
kwargs
):
"""
FastAPI generic handler - for callbacks executed by a threadpool
:param original_handler: the wrapped original handler
:param request: the given handler request
:param status_code: the default configured response status code.
Can be None when called by exception handlers wrapper, as there's
no status code configuration for exception handlers.
"""
has_setup_succeeded = False
should_ignore_request = False

try:
epsagon_scope, trace = _setup_handler(request)
if epsagon_scope and trace:
has_setup_succeeded = True
if (
ignore_request('', request.url.path.lower())
or
is_ignored_endpoint(request.url.path.lower())
):
should_ignore_request = True
epsagon_scope[SCOPE_IGNORE_REQUEST] = True

except Exception: # pylint: disable=broad-except
has_setup_succeeded = False

if not has_setup_succeeded or should_ignore_request:
return await original_handler(*args, **kwargs)

created_runner = False
response = None
if not trace.runner:
if not _setup_trace_runner(epsagon_scope, trace, request):
return await original_handler(*args, **kwargs)

raised_err = None
try:
response = await original_handler(*args, **kwargs)
except Exception as exception: # pylint: disable=W0703
raised_err = exception
finally:
try:
epsagon.trace.trace_factory.unset_thread_local_unique_id()
except Exception: # pylint: disable=broad-except
pass
# no need to update request body if runner already created before
if created_runner:
_extract_request_body(trace, request)

return _handle_response(
epsagon_scope,
response,
status_code,
trace,
raised_err
)



# pylint: disable=too-many-statements
def _wrap_handler(dependant, status_code):
Expand All @@ -230,9 +304,6 @@ def _wrap_handler(dependant, status_code):
"""
original_handler = dependant.call
is_async = asyncio.iscoroutinefunction(original_handler)
if is_async:
# async endpoints are not supported
return

original_request_param_name = dependant.request_param_name
if not original_request_param_name:
Expand All @@ -249,7 +320,23 @@ def wrapped_handler(*args, **kwargs):
original_handler, request, status_code, args, kwargs
)

dependant.call = wrapped_handler
async def async_wrapped_handler(*args, **kwargs):
"""
Asynchronous wrapper handler
"""
request: Request = _handle_wrapper_params(
args, kwargs, original_request_param_name
)
return await _async_fastapi_handler(
original_handler, request, status_code, args, kwargs
)

if is_async and IS_ASYNC_MODE:
# async endpoints
dependant.call = async_wrapped_handler

elif not is_async and not IS_ASYNC_MODE:
dependant.call = wrapped_handler


def route_class_wrapper(wrapped, instance, args, kwargs):
Expand Down Expand Up @@ -280,7 +367,7 @@ def exception_handler_wrapper(original_handler):
Wraps an exception handler
"""
is_async = asyncio.iscoroutinefunction(original_handler)
if is_async:
if is_async or IS_ASYNC_MODE:
# async handlers are not supported
return original_handler

Expand Down Expand Up @@ -323,12 +410,16 @@ async def server_call_wrapper(wrapped, _instance, args, kwargs):

trace = None
try:
epsagon.trace.trace_factory.switch_to_multiple_traces()
if IS_ASYNC_MODE:
epsagon.trace.trace_factory.switch_to_async_tracer()
else:
epsagon.trace.trace_factory.switch_to_multiple_traces()
unique_id = str(uuid.uuid4())
trace = epsagon.trace.trace_factory.get_or_create_trace(
unique_id=unique_id
)
trace.prepare()

scope[EPSAGON_MARKER] = {
SCOPE_UNIQUE_ID: unique_id,
}
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pytest-asyncio; python_version >= '3.5'
pytest-aiohttp; python_version >= '3.5'
httpx; python_version >= '3.5'
asynctest; python_version >= '3.5'
pytest-lazy-fixture; python_version >= '3.5'
moto; python_version >= '3.5'
moto==2.1.0; python_version < '3.5'
tornado
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,4 @@ def reset_tracer_mode():
Resets trace factory tracer mode to a single trace.
"""
epsagon.trace_factory.use_single_trace = True
epsagon.use_async_tracer = False
Loading

0 comments on commit 5fb0295

Please sign in to comment.