Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache EndpointConfig session #12886

Merged
merged 11 commits into from
Oct 5, 2023
20 changes: 14 additions & 6 deletions .github/workflows/continous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,13 @@ jobs:
(Get-ItemProperty "HKLM:System\CurrentControlSet\Control\FileSystem").LongPathsEnabled
Set-ItemProperty 'HKLM:\System\CurrentControlSet\Control\FileSystem' -Name 'LongPathsEnabled' -value 0

- name: Install ddtrace
if: needs.changes.outputs.backend == 'true'
run: poetry run pip install -U ddtrace
- name: Install ddtrace on Linux
if: needs.changes.outputs.backend == 'true' && matrix.os == 'ubuntu-22.04'
run: poetry run pip install -U 'ddtrace<2.0.0'

- name: Install ddtrace on Windows
if: needs.changes.outputs.backend == 'true' && matrix.os == 'windows-2019'
run: py -m pip install -U 'ddtrace<2.0.0'
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@miraai Is there a unified command provided by Github actions that allows us to install a dependency with a constraint for both OS types?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ancalita We can use RUNNER_OS envar to put this into a single step, your call though if it is worth it

- name:  Install ddtrace
  if: needs.changes.outputs.backend == 'true'
  run:   |
         if [ "$RUNNER_OS" == "Linux" ]; then
              poetry run pip install -U 'ddtrace<2.0.0'
         elif [ "$RUNNER_OS" == "Windows" ]; then
              py -m pip install -U 'ddtrace<2.0.0'
         else
              echo "$RUNNER_OS not supported"
              exit 1
         fi

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though I'm not sure if matrix.os == runner.os

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@miraai what approach do you recommend?


- name: Test Code 🔍 (multi-process)
if: needs.changes.outputs.backend == 'true'
Expand Down Expand Up @@ -492,9 +496,13 @@ jobs:
(Get-ItemProperty "HKLM:System\CurrentControlSet\Control\FileSystem").LongPathsEnabled
Set-ItemProperty 'HKLM:\System\CurrentControlSet\Control\FileSystem' -Name 'LongPathsEnabled' -value 0

- name: Install ddtrace
if: needs.changes.outputs.backend == 'true'
run: poetry run pip install -U ddtrace
- name: Install ddtrace on Linux
if: needs.changes.outputs.backend == 'true' && matrix.os == 'ubuntu-22.04'
run: poetry run pip install -U 'ddtrace<2.0.0'

- name: Install ddtrace on Windows
if: needs.changes.outputs.backend == 'true' && matrix.os == 'windows-2019'
run: py -m pip install -U 'ddtrace<2.0.0'

- name: Test Code 🔍 (multi-process)
if: needs.changes.outputs.backend == 'true'
Expand Down
96 changes: 45 additions & 51 deletions rasa/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,59 +112,53 @@ async def _pull_model_and_fingerprint(

logger.debug(f"Requesting model from server {model_server.url}...")

async with model_server.session() as session:
try:
params = model_server.combine_parameters()
async with session.request(
"GET",
model_server.url,
timeout=DEFAULT_REQUEST_TIMEOUT,
headers=headers,
params=params,
) as resp:

if resp.status in [204, 304]:
logger.debug(
"Model server returned {} status code, "
"indicating that no new model is available. "
"Current fingerprint: {}"
"".format(resp.status, fingerprint)
)
return None
elif resp.status == 404:
logger.debug(
"Model server could not find a model at the requested "
"endpoint '{}'. It's possible that no model has been "
"trained, or that the requested tag hasn't been "
"assigned.".format(model_server.url)
)
return None
elif resp.status != 200:
logger.debug(
"Tried to fetch model from server, but server response "
"status code is {}. We'll retry later..."
"".format(resp.status)
)
return None

model_path = Path(model_directory) / resp.headers.get(
"filename", "model.tar.gz"
try:
params = model_server.combine_parameters()
async with model_server.session().request(
"GET",
model_server.url,
timeout=DEFAULT_REQUEST_TIMEOUT,
headers=headers,
params=params,
) as resp:
if resp.status in [204, 304]:
logger.debug(
"Model server returned {} status code, "
"indicating that no new model is available. "
"Current fingerprint: {}"
"".format(resp.status, fingerprint)
)
with open(model_path, "wb") as file:
file.write(await resp.read())

logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))

# return the new fingerprint
return resp.headers.get("ETag")

except aiohttp.ClientError as e:
logger.debug(
"Tried to fetch model from server, but "
"couldn't reach server. We'll retry later... "
"Error: {}.".format(e)
return None
elif resp.status == 404:
logger.debug(
"Model server could not find a model at the requested "
"endpoint '{}'. It's possible that no model has been "
"trained, or that the requested tag hasn't been "
"assigned.".format(model_server.url)
)
return None
elif resp.status != 200:
logger.debug(
"Tried to fetch model from server, but server response "
"status code is {}. We'll retry later..."
"".format(resp.status)
)
return None
model_path = Path(model_directory) / resp.headers.get(
"filename", "model.tar.gz"
)
return None
with open(model_path, "wb") as file:
file.write(await resp.read())
logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))
# return the new fingerprint
return resp.headers.get("ETag")
except aiohttp.ClientError as e:
logger.debug(
"Tried to fetch model from server, but "
"couldn't reach server. We'll retry later... "
"Error: {}.".format(e)
)
vcidst marked this conversation as resolved.
Show resolved Hide resolved
return None


async def _run_model_pulling_worker(model_server: EndpointConfig, agent: Agent) -> None:
Expand Down
22 changes: 21 additions & 1 deletion rasa/core/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
import os
from functools import partial
from typing import Any, List, Optional, Text, Union, Dict
from typing import Any, List, Optional, TYPE_CHECKING, Text, Union, Dict

import rasa.core.utils
from rasa.plugin import plugin_manager
Expand All @@ -23,6 +23,9 @@
from sanic import Sanic
from asyncio import AbstractEventLoop

if TYPE_CHECKING:
from aiohttp import ClientSession

logger = logging.getLogger() # get the root logger


Expand Down Expand Up @@ -214,6 +217,7 @@ def serve_application(
partial(load_agent_on_start, model_path, endpoints, remote_storage),
"before_server_start",
)
app.register_listener(create_connections, "after_server_start")
app.register_listener(close_resources, "after_server_stop")

number_of_workers = rasa.core.utils.number_of_sanic_workers(
Expand Down Expand Up @@ -275,3 +279,19 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None:
event_broker = current_agent.tracker_store.event_broker
if event_broker:
await event_broker.close()


async def create_connections(
app: Sanic, _: AbstractEventLoop
) -> Optional["ClientSession"]:
ancalita marked this conversation as resolved.
Show resolved Hide resolved
current_agent = getattr(app.ctx, "agent", None)
if not current_agent:
logger.debug("No agent found after server start.")
return None

action_endpoint = current_agent.action_endpoint
if not action_endpoint:
logger.debug("No action endpoint found after server start.")
return None

return action_endpoint.session()
53 changes: 32 additions & 21 deletions rasa/utils/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ssl
from functools import lru_cache

import aiohttp
import logging
Expand All @@ -19,9 +20,7 @@
def read_endpoint_config(
filename: Text, endpoint_type: Text
) -> Optional["EndpointConfig"]:
"""Read an endpoint configuration file from disk and extract one

config."""
"""Read an endpoint configuration file from disk and extract one config.""" # noqa: E501
if not filename:
return None

Expand Down Expand Up @@ -96,6 +95,7 @@ def __init__(
self.cafile = cafile
self.kwargs = kwargs

@lru_cache
ancalita marked this conversation as resolved.
Show resolved Hide resolved
def session(self) -> aiohttp.ClientSession:
"""Creates and returns a configured aiohttp client session."""
# create authentication parameters
Expand Down Expand Up @@ -161,24 +161,23 @@ async def request(
f"'{os.path.abspath(self.cafile)}' does not exist."
) from e

async with self.session() as session:
async with session.request(
method,
url,
headers=headers,
params=self.combine_parameters(kwargs),
compress=compress,
ssl=sslcontext,
**kwargs,
) as response:
if response.status >= 400:
raise ClientResponseError(
response.status, response.reason, await response.content.read()
)
try:
return await response.json()
except ContentTypeError:
return None
async with self.session().request(
method,
url,
headers=headers,
params=self.combine_parameters(kwargs),
compress=compress,
ssl=sslcontext,
**kwargs,
) as response:
if response.status >= 400:
raise ClientResponseError(
response.status, response.reason, await response.content.read()
)
try:
return await response.json()
except ContentTypeError:
return None

@classmethod
def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig":
Expand Down Expand Up @@ -211,6 +210,18 @@ def __eq__(self, other: Any) -> bool:
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

def __hash__(self) -> int:
return hash(
(
self.url,
tuple(self.params.items()),
tuple(self.headers.items()),
tuple(self.basic_auth.items()),
self.token,
self.token_name,
)
)


class ClientResponseError(aiohttp.ClientError):
def __init__(self, status: int, message: Text, text: Text) -> None:
Expand Down
7 changes: 7 additions & 0 deletions tests/utils/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,10 @@ def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result
if value is not None:
request.args = {"key": value}
assert endpoint_utils.int_arg(request, "key", default) == expected_result


async def test_endpoint_config_caches_session() -> None:
endpoint = endpoint_utils.EndpointConfig("https://example.com/")
session = endpoint.session()

assert session is endpoint.session()
vcidst marked this conversation as resolved.
Show resolved Hide resolved
Loading