Skip to content

Commit

Permalink
Cache EndpointConfig session (#12886)
Browse files Browse the repository at this point in the history
* cache session property of endpoint config

* fix session is closed runtime error

* ref: remove async manager for session

* revert cached_property, implement sanic listener + lru cache

* trigger CI

* pin ddtrace in CI workflow

* separate pip install commands for ddtrace per OS

* update flaky tests jobs; remove poetry command

* address review suggestions

* make failing unit test async

* add changelog, add closing of connection pools to after_server_stop listener

---------

Co-authored-by: souvik ghosh <[email protected]>
  • Loading branch information
ancalita and souvikg10 authored Oct 5, 2023
1 parent e37774e commit 7717472
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 83 deletions.
20 changes: 14 additions & 6 deletions .github/workflows/continous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,13 @@ jobs:
(Get-ItemProperty "HKLM:System\CurrentControlSet\Control\FileSystem").LongPathsEnabled
Set-ItemProperty 'HKLM:\System\CurrentControlSet\Control\FileSystem' -Name 'LongPathsEnabled' -value 0
- name: Install ddtrace
if: needs.changes.outputs.backend == 'true'
run: poetry run pip install -U ddtrace
- name: Install ddtrace on Linux
if: needs.changes.outputs.backend == 'true' && matrix.os == 'ubuntu-22.04'
run: poetry run pip install -U 'ddtrace<2.0.0'

- name: Install ddtrace on Windows
if: needs.changes.outputs.backend == 'true' && matrix.os == 'windows-2019'
run: py -m pip install -U 'ddtrace<2.0.0'

- name: Test Code 🔍 (multi-process)
if: needs.changes.outputs.backend == 'true'
Expand Down Expand Up @@ -492,9 +496,13 @@ jobs:
(Get-ItemProperty "HKLM:System\CurrentControlSet\Control\FileSystem").LongPathsEnabled
Set-ItemProperty 'HKLM:\System\CurrentControlSet\Control\FileSystem' -Name 'LongPathsEnabled' -value 0
- name: Install ddtrace
if: needs.changes.outputs.backend == 'true'
run: poetry run pip install -U ddtrace
- name: Install ddtrace on Linux
if: needs.changes.outputs.backend == 'true' && matrix.os == 'ubuntu-22.04'
run: poetry run pip install -U 'ddtrace<2.0.0'

- name: Install ddtrace on Windows
if: needs.changes.outputs.backend == 'true' && matrix.os == 'windows-2019'
run: py -m pip install -U 'ddtrace<2.0.0'

- name: Test Code 🔍 (multi-process)
if: needs.changes.outputs.backend == 'true'
Expand Down
3 changes: 3 additions & 0 deletions changelog/12886.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Cache the `EndpointConfig` session object using the `cached_property` decorator instead of recreating it on every request.
Initialize the connection pools for the action server and model server endpoints in the Sanic `after_server_start` listener.
Also close these connection pools in the Sanic `after_server_stop` listener.
96 changes: 45 additions & 51 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,59 +112,53 @@ async def _pull_model_and_fingerprint(

logger.debug(f"Requesting model from server {model_server.url}...")

async with model_server.session() as session:
try:
params = model_server.combine_parameters()
async with session.request(
"GET",
model_server.url,
timeout=DEFAULT_REQUEST_TIMEOUT,
headers=headers,
params=params,
) as resp:

if resp.status in [204, 304]:
logger.debug(
"Model server returned {} status code, "
"indicating that no new model is available. "
"Current fingerprint: {}"
"".format(resp.status, fingerprint)
)
return None
elif resp.status == 404:
logger.debug(
"Model server could not find a model at the requested "
"endpoint '{}'. It's possible that no model has been "
"trained, or that the requested tag hasn't been "
"assigned.".format(model_server.url)
)
return None
elif resp.status != 200:
logger.debug(
"Tried to fetch model from server, but server response "
"status code is {}. We'll retry later..."
"".format(resp.status)
)
return None

model_path = Path(model_directory) / resp.headers.get(
"filename", "model.tar.gz"
try:
params = model_server.combine_parameters()
async with model_server.session.request(
"GET",
model_server.url,
timeout=DEFAULT_REQUEST_TIMEOUT,
headers=headers,
params=params,
) as resp:
if resp.status in [204, 304]:
logger.debug(
"Model server returned {} status code, "
"indicating that no new model is available. "
"Current fingerprint: {}"
"".format(resp.status, fingerprint)
)
with open(model_path, "wb") as file:
file.write(await resp.read())

logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))

# return the new fingerprint
return resp.headers.get("ETag")

except aiohttp.ClientError as e:
logger.debug(
"Tried to fetch model from server, but "
"couldn't reach server. We'll retry later... "
"Error: {}.".format(e)
return None
elif resp.status == 404:
logger.debug(
"Model server could not find a model at the requested "
"endpoint '{}'. It's possible that no model has been "
"trained, or that the requested tag hasn't been "
"assigned.".format(model_server.url)
)
return None
elif resp.status != 200:
logger.debug(
"Tried to fetch model from server, but server response "
"status code is {}. We'll retry later..."
"".format(resp.status)
)
return None
model_path = Path(model_directory) / resp.headers.get(
"filename", "model.tar.gz"
)
return None
with open(model_path, "wb") as file:
file.write(await resp.read())
logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))
# return the new fingerprint
return resp.headers.get("ETag")
except aiohttp.ClientError as e:
logger.debug(
"Tried to fetch model from server, but "
"couldn't reach server. We'll retry later... "
"Error: {}.".format(e)
)
return None


async def _run_model_pulling_worker(model_server: EndpointConfig, agent: Agent) -> None:
Expand Down
47 changes: 46 additions & 1 deletion rasa/core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
import os
from functools import partial
from typing import Any, List, Optional, Text, Union, Dict
from typing import Any, List, Optional, TYPE_CHECKING, Text, Union, Dict

import rasa.core.utils
from rasa.plugin import plugin_manager
Expand All @@ -23,6 +23,9 @@
from sanic import Sanic
from asyncio import AbstractEventLoop

if TYPE_CHECKING:
from aiohttp import ClientSession

logger = logging.getLogger() # get the root logger


Expand Down Expand Up @@ -214,6 +217,7 @@ def serve_application(
partial(load_agent_on_start, model_path, endpoints, remote_storage),
"before_server_start",
)
app.register_listener(create_connection_pools, "after_server_start")
app.register_listener(close_resources, "after_server_stop")

number_of_workers = rasa.core.utils.number_of_sanic_workers(
Expand Down Expand Up @@ -275,3 +279,44 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None:
event_broker = current_agent.tracker_store.event_broker
if event_broker:
await event_broker.close()

action_endpoint = current_agent.action_endpoint
if action_endpoint:
await action_endpoint.session.close()

model_server = current_agent.model_server
if model_server:
await model_server.session.close()


async def create_connection_pools(app: Sanic, _: AbstractEventLoop) -> None:
    """Warm up the agent's action server and model server sessions.

    Registered as a Sanic ``after_server_start`` listener. Looks up the agent
    stored on ``app.ctx`` and, if one is present, asks each helper to touch
    the corresponding endpoint's ``session``.

    Args:
        app: The Sanic application whose ``ctx`` may carry an ``agent``.
        _: The event loop passed by Sanic; unused.
    """
    agent = getattr(app.ctx, "agent", None)

    if agent:
        create_action_endpoint_connection_pool(agent)
        create_model_server_connection_pool(agent)
    else:
        # Nothing to warm up without a loaded agent.
        logger.debug("No agent found after server start.")


def create_action_endpoint_connection_pool(agent: Agent) -> Optional["ClientSession"]:
    """Touch the action endpoint's session so that it exists up front.

    Args:
        agent: The agent whose ``action_endpoint`` should be inspected.

    Returns:
        The endpoint's ``session`` if the agent has an action endpoint,
        otherwise ``None``.
    """
    endpoint = agent.action_endpoint

    if endpoint:
        # Reading the attribute is the whole point: presumably `session` is
        # the cached property on the endpoint config, so the first read
        # builds the client session.
        return endpoint.session

    logger.debug("No action endpoint found after server start.")
    return None


def create_model_server_connection_pool(agent: Agent) -> Optional["ClientSession"]:
    """Touch the model server's session so that it exists up front.

    Args:
        agent: The agent whose ``model_server`` should be inspected.

    Returns:
        The model server endpoint's ``session`` if the agent has a model
        server configured, otherwise ``None``.
    """
    endpoint = agent.model_server

    if endpoint:
        # Reading the attribute is the whole point: presumably `session` is
        # the cached property on the endpoint config, so the first read
        # builds the client session.
        return endpoint.session

    logger.debug("No model server endpoint found after server start.")
    return None
41 changes: 20 additions & 21 deletions rasa/utils/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ssl
from functools import cached_property

import aiohttp
import logging
Expand All @@ -19,9 +20,7 @@
def read_endpoint_config(
filename: Text, endpoint_type: Text
) -> Optional["EndpointConfig"]:
"""Read an endpoint configuration file from disk and extract one
config."""
"""Read an endpoint configuration file from disk and extract one config.""" # noqa: E501
if not filename:
return None

Expand Down Expand Up @@ -96,6 +95,7 @@ def __init__(
self.cafile = cafile
self.kwargs = kwargs

@cached_property
def session(self) -> aiohttp.ClientSession:
"""Creates and returns a configured aiohttp client session."""
# create authentication parameters
Expand Down Expand Up @@ -161,24 +161,23 @@ async def request(
f"'{os.path.abspath(self.cafile)}' does not exist."
) from e

async with self.session() as session:
async with session.request(
method,
url,
headers=headers,
params=self.combine_parameters(kwargs),
compress=compress,
ssl=sslcontext,
**kwargs,
) as response:
if response.status >= 400:
raise ClientResponseError(
response.status, response.reason, await response.content.read()
)
try:
return await response.json()
except ContentTypeError:
return None
async with self.session.request(
method,
url,
headers=headers,
params=self.combine_parameters(kwargs),
compress=compress,
ssl=sslcontext,
**kwargs,
) as response:
if response.status >= 400:
raise ClientResponseError(
response.status, response.reason, await response.content.read()
)
try:
return await response.json()
except ContentTypeError:
return None

@classmethod
def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig":
Expand Down
9 changes: 6 additions & 3 deletions tests/core/test_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import warnings
from unittest.mock import Mock

import aiohttp
import pytest
from typing import Text

Expand Down Expand Up @@ -82,8 +84,9 @@ async def test_close_resources(loop: AbstractEventLoop):
broker = SQLEventBroker()
app = Mock()
app.ctx.agent.tracker_store.event_broker = broker
app.ctx.agent.action_endpoint.session = aiohttp.ClientSession()
app.ctx.agent.model_server.session = aiohttp.ClientSession()

with pytest.warns(None) as warnings:
with warnings.catch_warnings() as record:
await run.close_resources(app, loop)

assert len(warnings) == 0
assert record is None
31 changes: 30 additions & 1 deletion tests/utils/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def test_endpoint_config():

# unfortunately, the mock library won't report any headers stored on
# the session object, so we need to verify them separately
async with endpoint.session() as s:
async with endpoint.session as s:
assert s._default_headers.get("X-Powered-By") == "Rasa"
assert s._default_auth.login == "user"
assert s._default_auth.password == "pass"
Expand Down Expand Up @@ -231,3 +231,32 @@ def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result
if value is not None:
request.args = {"key": value}
assert endpoint_utils.int_arg(request, "key", default) == expected_result


async def test_endpoint_config_caches_session() -> None:
    """Check that repeated reads of `EndpointConfig.session` yield one object.

    The session should be created once and then reused, so two consecutive
    reads of the property must return the identical object.
    """
    config = endpoint_utils.EndpointConfig("https://example.com/")

    first_read = config.session
    second_read = config.session

    assert second_read is first_read

    # teardown
    await config.session.close()


async def test_endpoint_config_constructor_does_not_create_session_cached_property() -> None:  # noqa: E501
    """Check that constructing an EndpointConfig does not populate `session`.

    The instance dictionary must hold no session right after construction and
    must hold the created session once the property has been read.
    """
    config = endpoint_utils.EndpointConfig("https://example.com/")
    # `vars` returns the live instance dictionary, so later lookups see
    # entries added after this point.
    attributes = vars(config)

    assert attributes.get("url") == "https://example.com/"
    assert attributes.get("session") is None

    # the property is created when it is accessed
    async with config.session as opened:
        assert opened is not None

    assert attributes.get("session") is opened

0 comments on commit 7717472

Please sign in to comment.