From c147de5e111d201717c2563dba4decba80742a4c Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 3 Oct 2023 11:22:41 +0100 Subject: [PATCH 01/11] cache session property of endpoint config --- rasa/core/agent.py | 2 +- rasa/utils/endpoints.py | 8 ++++---- tests/utils/test_endpoints.py | 16 +++++++++++++++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/rasa/core/agent.py b/rasa/core/agent.py index 47e3360f6a5b..73d3e8ac38d8 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -112,7 +112,7 @@ async def _pull_model_and_fingerprint( logger.debug(f"Requesting model from server {model_server.url}...") - async with model_server.session() as session: + async with model_server.session as session: try: params = model_server.combine_parameters() async with session.request( diff --git a/rasa/utils/endpoints.py b/rasa/utils/endpoints.py index cffc7523a2e3..c0ea5b35d4f9 100644 --- a/rasa/utils/endpoints.py +++ b/rasa/utils/endpoints.py @@ -1,4 +1,5 @@ import ssl +from functools import cached_property import aiohttp import logging @@ -19,9 +20,7 @@ def read_endpoint_config( filename: Text, endpoint_type: Text ) -> Optional["EndpointConfig"]: - """Read an endpoint configuration file from disk and extract one - - config.""" + """Read an endpoint configuration file from disk and extract one config.""" # noqa: E501 if not filename: return None @@ -96,6 +95,7 @@ def __init__( self.cafile = cafile self.kwargs = kwargs + @cached_property def session(self) -> aiohttp.ClientSession: """Creates and returns a configured aiohttp client session.""" # create authentication parameters @@ -161,7 +161,7 @@ async def request( f"'{os.path.abspath(self.cafile)}' does not exist." ) from e - async with self.session() as session: + async with self.session as session: async with session.request( method, url, diff --git a/tests/utils/test_endpoints.py b/tests/utils/test_endpoints.py index 847c940622b9..17c3821f8352 100644 --- a/tests/utils/test_endpoints.py +++ b/tests/utils/test_endpoints.py @@ -88,7 +88,7 @@ async def test_endpoint_config(): # unfortunately, the mock library won't report any headers stored on # the session object, so we need to verify them separately - async with endpoint.session() as s: + async with endpoint.session as s: assert s._default_headers.get("X-Powered-By") == "Rasa" assert s._default_auth.login == "user" assert s._default_auth.password == "pass" @@ -231,3 +231,17 @@ def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result if value is not None: request.args = {"key": value} assert endpoint_utils.int_arg(request, "key", default) == expected_result + + +async def test_endpoint_config_does_not_create_session_cached_property() -> None: + """Test the instantiation of EndpointConfig does not create the session cached property.""" # noqa: E501 + endpoint = endpoint_utils.EndpointConfig("https://example.com/") + + assert endpoint.__dict__.get("url") == "https://example.com/" + assert endpoint.__dict__.get("session") is None + + # the property is created when it is accessed + async with endpoint.session as session: + assert session is not None + + assert endpoint.__dict__.get("session") is session From 6a2310f5bfc40dd68c29a5bf00fc8a6830e6cca9 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Tue, 3 Oct 2023 11:52:02 +0100 Subject: [PATCH 02/11] fix session is closed runtime error --- rasa/utils/endpoints.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/rasa/utils/endpoints.py b/rasa/utils/endpoints.py index c0ea5b35d4f9..821cdbffa3e9 100644 --- a/rasa/utils/endpoints.py +++ b/rasa/utils/endpoints.py @@ -161,24 +161,23 @@ async def request( f"'{os.path.abspath(self.cafile)}' does not exist." ) from e - async with self.session as session: - async with session.request( - method, - url, - headers=headers, - params=self.combine_parameters(kwargs), - compress=compress, - ssl=sslcontext, - **kwargs, - ) as response: - if response.status >= 400: - raise ClientResponseError( - response.status, response.reason, await response.content.read() - ) - try: - return await response.json() - except ContentTypeError: - return None + async with self.session.request( + method, + url, + headers=headers, + params=self.combine_parameters(kwargs), + compress=compress, + ssl=sslcontext, + **kwargs, + ) as response: + if response.status >= 400: + raise ClientResponseError( + response.status, response.reason, await response.content.read() + ) + try: + return await response.json() + except ContentTypeError: + return None @classmethod def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig": From eb8d2757010b6b6e9c5d86e95262e1576015df46 Mon Sep 17 00:00:00 2001 From: souvik ghosh Date: Wed, 4 Oct 2023 10:12:41 +0200 Subject: [PATCH 03/11] ref: remove async manager for session --- rasa/core/agent.py | 97 ++++++++++++++++++++++------------------------ 1 file changed, 46 insertions(+), 51 deletions(-) diff --git a/rasa/core/agent.py b/rasa/core/agent.py index 73d3e8ac38d8..8e579849abda 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -112,59 +112,54 @@ async def _pull_model_and_fingerprint( logger.debug(f"Requesting model from server {model_server.url}...") - async with model_server.session as session: - try: - params = model_server.combine_parameters() - async with session.request( - "GET", - model_server.url, - timeout=DEFAULT_REQUEST_TIMEOUT, - headers=headers, - params=params, - ) as resp: - - if resp.status in [204, 304]: - logger.debug( - "Model server returned {} status code, " - "indicating that no new model is available. " - "Current fingerprint: {}" - "".format(resp.status, fingerprint) - ) - return None - elif resp.status == 404: - logger.debug( - "Model server could not find a model at the requested " - "endpoint '{}'. It's possible that no model has been " - "trained, or that the requested tag hasn't been " - "assigned.".format(model_server.url) - ) - return None - elif resp.status != 200: - logger.debug( - "Tried to fetch model from server, but server response " - "status code is {}. We'll retry later..." - "".format(resp.status) - ) - return None - - model_path = Path(model_directory) / resp.headers.get( - "filename", "model.tar.gz" + + try: + params = model_server.combine_parameters() + async with model_server.session.request( + "GET", + model_server.url, + timeout=DEFAULT_REQUEST_TIMEOUT, + headers=headers, + params=params, + ) as resp: + if resp.status in [204, 304]: + logger.debug( + "Model server returned {} status code, " + "indicating that no new model is available. " + "Current fingerprint: {}" + "".format(resp.status, fingerprint) ) - with open(model_path, "wb") as file: - file.write(await resp.read()) - - logger.debug("Saved model to '{}'".format(os.path.abspath(model_path))) - - # return the new fingerprint - return resp.headers.get("ETag") - - except aiohttp.ClientError as e: - logger.debug( - "Tried to fetch model from server, but " - "couldn't reach server. We'll retry later... " - "Error: {}.".format(e) + return None + elif resp.status == 404: + logger.debug( + "Model server could not find a model at the requested " + "endpoint '{}'. It's possible that no model has been " + "trained, or that the requested tag hasn't been " + "assigned.".format(model_server.url) + ) + return None + elif resp.status != 200: + logger.debug( + "Tried to fetch model from server, but server response " + "status code is {}. We'll retry later..." + "".format(resp.status) + ) + return None + model_path = Path(model_directory) / resp.headers.get( + "filename", "model.tar.gz" ) - return None + with open(model_path, "wb") as file: + file.write(await resp.read()) + logger.debug("Saved model to '{}'".format(os.path.abspath(model_path))) + # return the new fingerprint + return resp.headers.get("ETag") + except aiohttp.ClientError as e: + logger.debug( + "Tried to fetch model from server, but " + "couldn't reach server. We'll retry later... " + "Error: {}.".format(e) + ) + return None async def _run_model_pulling_worker(model_server: EndpointConfig, agent: Agent) -> None: From 15426e755ff97a632286543086db95b4b3b85aa0 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 4 Oct 2023 10:25:02 +0100 Subject: [PATCH 04/11] revert cached_property, implement sanic listener + lru cache --- rasa/core/agent.py | 3 +-- rasa/core/run.py | 22 +++++++++++++++++++++- rasa/utils/endpoints.py | 18 +++++++++++++++--- tests/utils/test_endpoints.py | 15 ++++----------- 4 files changed, 41 insertions(+), 17 deletions(-) diff --git a/rasa/core/agent.py b/rasa/core/agent.py index 8e579849abda..a05e810060f4 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -112,10 +112,9 @@ async def _pull_model_and_fingerprint( logger.debug(f"Requesting model from server {model_server.url}...") - try: params = model_server.combine_parameters() - async with model_server.session.request( + async with model_server.session().request( "GET", model_server.url, timeout=DEFAULT_REQUEST_TIMEOUT, diff --git a/rasa/core/run.py b/rasa/core/run.py index 6c9463d2c515..2866af06adf3 100644 --- a/rasa/core/run.py +++ b/rasa/core/run.py @@ -3,7 +3,7 @@ import uuid import os from functools import partial -from typing import Any, List, Optional, Text, Union, Dict +from typing import Any, List, Optional, TYPE_CHECKING, Text, Union, Dict import rasa.core.utils from rasa.plugin import plugin_manager @@ -23,6 +23,9 @@ from sanic import Sanic from asyncio import AbstractEventLoop +if TYPE_CHECKING: + from aiohttp import ClientSession + logger = logging.getLogger() # get the root logger @@ -214,6 +217,7 @@ def serve_application( partial(load_agent_on_start, model_path, endpoints, remote_storage), "before_server_start", ) + app.register_listener(create_connections, "after_server_start") app.register_listener(close_resources, "after_server_stop") number_of_workers = rasa.core.utils.number_of_sanic_workers( @@ -275,3 +279,19 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None: event_broker = current_agent.tracker_store.event_broker if event_broker: await event_broker.close() + + +async def create_connections( + app: Sanic, _: AbstractEventLoop +) -> Optional["ClientSession"]: + current_agent = getattr(app.ctx, "agent", None) + if not current_agent: + logger.debug("No agent found after server start.") + return None + + action_endpoint = current_agent.action_endpoint + if not action_endpoint: + logger.debug("No action endpoint found after server start.") + return None + + return action_endpoint.session() diff --git a/rasa/utils/endpoints.py b/rasa/utils/endpoints.py index 821cdbffa3e9..8c4b92de341c 100644 --- a/rasa/utils/endpoints.py +++ b/rasa/utils/endpoints.py @@ -1,5 +1,5 @@ import ssl -from functools import cached_property +from functools import lru_cache import aiohttp import logging @@ -95,7 +95,7 @@ def __init__( self.cafile = cafile self.kwargs = kwargs - @cached_property + @lru_cache def session(self) -> aiohttp.ClientSession: """Creates and returns a configured aiohttp client session.""" # create authentication parameters @@ -161,7 +161,7 @@ async def request( f"'{os.path.abspath(self.cafile)}' does not exist." ) from e - async with self.session.request( + async with self.session().request( method, url, headers=headers, @@ -210,6 +210,18 @@ def __eq__(self, other: Any) -> bool: def __ne__(self, other: Any) -> bool: return not self.__eq__(other) + def __hash__(self) -> int: + return hash( + ( + self.url, + tuple(self.params.items()), + tuple(self.headers.items()), + tuple(self.basic_auth.items()), + self.token, + self.token_name, + ) + ) + class ClientResponseError(aiohttp.ClientError): def __init__(self, status: int, message: Text, text: Text) -> None: diff --git a/tests/utils/test_endpoints.py b/tests/utils/test_endpoints.py index 17c3821f8352..408fffd0258e 100644 --- a/tests/utils/test_endpoints.py +++ b/tests/utils/test_endpoints.py @@ -88,7 +88,7 @@ async def test_endpoint_config(): # unfortunately, the mock library won't report any headers stored on # the session object, so we need to verify them separately - async with endpoint.session as s: + async with endpoint.session() as s: assert s._default_headers.get("X-Powered-By") == "Rasa" assert s._default_auth.login == "user" assert s._default_auth.password == "pass" @@ -233,15 +233,8 @@ def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result assert endpoint_utils.int_arg(request, "key", default) == expected_result -async def test_endpoint_config_does_not_create_session_cached_property() -> None: - """Test the instantiation of EndpointConfig does not create the session cached property.""" # noqa: E501 +async def test_endpoint_config_caches_session() -> None: endpoint = endpoint_utils.EndpointConfig("https://example.com/") + session = endpoint.session() - assert endpoint.__dict__.get("url") == "https://example.com/" - assert endpoint.__dict__.get("session") is None - - # the property is created when it is accessed - async with endpoint.session as session: - assert session is not None - - assert endpoint.__dict__.get("session") is session + assert session is endpoint.session() From 136dbd1bf3a1d9471f2af8d3d9711c9c7223dd75 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 4 Oct 2023 10:48:17 +0100 Subject: [PATCH 05/11] trigger CI From ebad913dfb1ab4ae35628fdc8d5d6b06c5e2680e Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:02:56 +0100 Subject: [PATCH 06/11] pin ddtrace in CI workflow --- .github/workflows/continous-integration.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/continous-integration.yml b/.github/workflows/continous-integration.yml index 44f84b6b5757..952197b8a8cb 100644 --- a/.github/workflows/continous-integration.yml +++ b/.github/workflows/continous-integration.yml @@ -354,7 +354,7 @@ jobs: - name: Install ddtrace if: needs.changes.outputs.backend == 'true' - run: poetry run pip install -U ddtrace + run: poetry run pip install -U "ddtrace<2.0.0" - name: Test Code 🔍 (multi-process) if: needs.changes.outputs.backend == 'true' From fa6c4b4c7adf1f0d345b2c7c2521ca1d859f3cef Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:17:01 +0100 Subject: [PATCH 07/11] separate pip install commands for ddtrace per OS --- .github/workflows/continous-integration.yml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/continous-integration.yml b/.github/workflows/continous-integration.yml index 952197b8a8cb..d0c4864f7f11 100644 --- a/.github/workflows/continous-integration.yml +++ b/.github/workflows/continous-integration.yml @@ -352,9 +352,13 @@ jobs: (Get-ItemProperty "HKLM:System\CurrentControlSet\Control\FileSystem").LongPathsEnabled Set-ItemProperty 'HKLM:\System\CurrentControlSet\Control\FileSystem' -Name 'LongPathsEnabled' -value 0 - - name: Install ddtrace - if: needs.changes.outputs.backend == 'true' - run: poetry run pip install -U "ddtrace<2.0.0" + - name: Install ddtrace on Linux + if: needs.changes.outputs.backend == 'true' && matrix.os == 'ubuntu-22.04' + run: poetry run pip install -U 'ddtrace<2.0.0' + + - name: Install ddtrace on Windows + if: needs.changes.outputs.backend == 'true' && matrix.os == 'windows-2019' + run: poetry run py -m pip install -U 'ddtrace<2.0.0' - name: Test Code 🔍 (multi-process) if: needs.changes.outputs.backend == 'true' From 82fe65dce99c53e1db252125ab0c3de9659a5175 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:23:05 +0100 Subject: [PATCH 08/11] update flaky tests jobs; remove poetry command --- .github/workflows/continous-integration.yml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/continous-integration.yml b/.github/workflows/continous-integration.yml index d0c4864f7f11..453107c6d335 100644 --- a/.github/workflows/continous-integration.yml +++ b/.github/workflows/continous-integration.yml @@ -358,7 +358,7 @@ jobs: - name: Install ddtrace on Windows if: needs.changes.outputs.backend == 'true' && matrix.os == 'windows-2019' - run: poetry run py -m pip install -U 'ddtrace<2.0.0' + run: py -m pip install -U 'ddtrace<2.0.0' - name: Test Code 🔍 (multi-process) if: needs.changes.outputs.backend == 'true' @@ -496,9 +496,13 @@ jobs: (Get-ItemProperty "HKLM:System\CurrentControlSet\Control\FileSystem").LongPathsEnabled Set-ItemProperty 'HKLM:\System\CurrentControlSet\Control\FileSystem' -Name 'LongPathsEnabled' -value 0 - - name: Install ddtrace - if: needs.changes.outputs.backend == 'true' - run: poetry run pip install -U ddtrace + - name: Install ddtrace on Linux + if: needs.changes.outputs.backend == 'true' && matrix.os == 'ubuntu-22.04' + run: poetry run pip install -U 'ddtrace<2.0.0' + + - name: Install ddtrace on Windows + if: needs.changes.outputs.backend == 'true' && matrix.os == 'windows-2019' + run: py -m pip install -U 'ddtrace<2.0.0' - name: Test Code 🔍 (multi-process) if: needs.changes.outputs.backend == 'true' From 749fa5007b4cefd41532a70416ba9fc3dfbfa664 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 4 Oct 2023 14:38:06 +0100 Subject: [PATCH 09/11] address review suggestions --- rasa/core/agent.py | 2 +- rasa/core/run.py | 29 +++++++++++++++++++++++------ rasa/utils/endpoints.py | 18 +++--------------- tests/utils/test_endpoints.py | 27 +++++++++++++++++++++++---- 4 files changed, 50 insertions(+), 26 deletions(-) diff --git a/rasa/core/agent.py b/rasa/core/agent.py index a05e810060f4..bf3d42236e70 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -114,7 +114,7 @@ async def _pull_model_and_fingerprint( try: params = model_server.combine_parameters() - async with model_server.session().request( + async with model_server.session.request( "GET", model_server.url, timeout=DEFAULT_REQUEST_TIMEOUT, diff --git a/rasa/core/run.py b/rasa/core/run.py index 2866af06adf3..e7db50a79dbe 100644 --- a/rasa/core/run.py +++ b/rasa/core/run.py @@ -217,7 +217,7 @@ def serve_application( partial(load_agent_on_start, model_path, endpoints, remote_storage), "before_server_start", ) - app.register_listener(create_connections, "after_server_start") + app.register_listener(create_connection_pools, "after_server_start") app.register_listener(close_resources, "after_server_stop") number_of_workers = rasa.core.utils.number_of_sanic_workers( @@ -281,17 +281,34 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None: await event_broker.close() -async def create_connections( - app: Sanic, _: AbstractEventLoop -) -> Optional["ClientSession"]: +async def create_connection_pools(app: Sanic, _: AbstractEventLoop) -> None: + """Create connection pools for the agent's action server and model server.""" current_agent = getattr(app.ctx, "agent", None) if not current_agent: logger.debug("No agent found after server start.") return None - action_endpoint = current_agent.action_endpoint + create_action_endpoint_connection_pool(current_agent) + create_model_server_connection_pool(current_agent) + + return None + + +def create_action_endpoint_connection_pool(agent: Agent) -> Optional["ClientSession"]: + """Create a connection pool for the action endpoint.""" + action_endpoint = agent.action_endpoint if not action_endpoint: logger.debug("No action endpoint found after server start.") return None - return action_endpoint.session() + return action_endpoint.session + + +def create_model_server_connection_pool(agent: Agent) -> Optional["ClientSession"]: + """Create a connection pool for the model server.""" + model_server = agent.model_server + if not model_server: + logger.debug("No model server endpoint found after server start.") + return None + + return model_server.session diff --git a/rasa/utils/endpoints.py b/rasa/utils/endpoints.py index 8c4b92de341c..821cdbffa3e9 100644 --- a/rasa/utils/endpoints.py +++ b/rasa/utils/endpoints.py @@ -1,5 +1,5 @@ import ssl -from functools import lru_cache +from functools import cached_property import aiohttp import logging @@ -95,7 +95,7 @@ def __init__( self.cafile = cafile self.kwargs = kwargs - @lru_cache + @cached_property def session(self) -> aiohttp.ClientSession: """Creates and returns a configured aiohttp client session.""" # create authentication parameters @@ -161,7 +161,7 @@ async def request( f"'{os.path.abspath(self.cafile)}' does not exist." ) from e - async with self.session().request( + async with self.session.request( method, url, headers=headers, @@ -210,18 +210,6 @@ def __eq__(self, other: Any) -> bool: def __ne__(self, other: Any) -> bool: return not self.__eq__(other) - def __hash__(self) -> int: - return hash( - ( - self.url, - tuple(self.params.items()), - tuple(self.headers.items()), - tuple(self.basic_auth.items()), - self.token, - self.token_name, - ) - ) - class ClientResponseError(aiohttp.ClientError): def __init__(self, status: int, message: Text, text: Text) -> None: diff --git a/tests/utils/test_endpoints.py b/tests/utils/test_endpoints.py index 408fffd0258e..b620d445ceb3 100644 --- a/tests/utils/test_endpoints.py +++ b/tests/utils/test_endpoints.py @@ -88,7 +88,7 @@ async def test_endpoint_config(): # unfortunately, the mock library won't report any headers stored on # the session object, so we need to verify them separately - async with endpoint.session() as s: + async with endpoint.session as s: assert s._default_headers.get("X-Powered-By") == "Rasa" assert s._default_auth.login == "user" assert s._default_auth.password == "pass" @@ -233,8 +233,27 @@ def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result assert endpoint_utils.int_arg(request, "key", default) == expected_result -async def test_endpoint_config_caches_session() -> None: +def test_endpoint_config_caches_session() -> None: + """Test that the EndpointConfig session is cached. + + Assert identity of the session object, which should not be recreated when calling + the property `session` multiple times. + """ endpoint = endpoint_utils.EndpointConfig("https://example.com/") - session = endpoint.session() + session = endpoint.session + + assert session is endpoint.session + + +async def test_endpoint_config_constructor_does_not_create_session_cached_property() -> None: # noqa: E501 + """Test that the instantiation of EndpointConfig does not create the session cached property.""" # noqa: E501 + endpoint = endpoint_utils.EndpointConfig("https://example.com/") + + assert endpoint.__dict__.get("url") == "https://example.com/" + assert endpoint.__dict__.get("session") is None + + # the property is created when it is accessed + async with endpoint.session as session: + assert session is not None - assert session is endpoint.session() + assert endpoint.__dict__.get("session") is session From 085ef71f295b92a69dd93fa3889738bb42d3d52c Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 4 Oct 2023 14:56:53 +0100 Subject: [PATCH 10/11] make failing unit test async --- tests/utils/test_endpoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_endpoints.py b/tests/utils/test_endpoints.py index b620d445ceb3..7ff7e008ccac 100644 --- a/tests/utils/test_endpoints.py +++ b/tests/utils/test_endpoints.py @@ -233,7 +233,7 @@ def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result assert endpoint_utils.int_arg(request, "key", default) == expected_result -def test_endpoint_config_caches_session() -> None: +async def test_endpoint_config_caches_session() -> None: """Test that the EndpointConfig session is cached. Assert identity of the session object, which should not be recreated when calling @@ -242,7 +242,7 @@ def test_endpoint_config_caches_session() -> None: endpoint = endpoint_utils.EndpointConfig("https://example.com/") session = endpoint.session - assert session is endpoint.session + assert endpoint.session is session async def test_endpoint_config_constructor_does_not_create_session_cached_property() -> None: # noqa: E501 From eeed9c35b2067a8b9903b66d124ea58c56198993 Mon Sep 17 00:00:00 2001 From: Anca Lita <27920906+ancalita@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:58:26 +0100 Subject: [PATCH 11/11] add changelog, add closing of connection pools to after_server_stop listener --- changelog/12886.bugfix.md | 3 +++ rasa/core/run.py | 8 ++++++++ tests/core/test_run.py | 9 ++++++--- tests/utils/test_endpoints.py | 3 +++ 4 files changed, 20 insertions(+), 3 deletions(-) create mode 100644 changelog/12886.bugfix.md diff --git a/changelog/12886.bugfix.md b/changelog/12886.bugfix.md new file mode 100644 index 000000000000..e327578350d7 --- /dev/null +++ b/changelog/12886.bugfix.md @@ -0,0 +1,3 @@ +Cache `EndpointConfig` session object using `cached_property` decorator instead of recreating this object on every request. +Initialize these connection pools for action server and model server endpoints as part of the Sanic `after_server_start` listener. +Also close connection pools during Sanic `after_server_stop` listener. diff --git a/rasa/core/run.py b/rasa/core/run.py index e7db50a79dbe..5270162809dd 100644 --- a/rasa/core/run.py +++ b/rasa/core/run.py @@ -280,6 +280,14 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None: if event_broker: await event_broker.close() + action_endpoint = current_agent.action_endpoint + if action_endpoint: + await action_endpoint.session.close() + + model_server = current_agent.model_server + if model_server: + await model_server.session.close() + async def create_connection_pools(app: Sanic, _: AbstractEventLoop) -> None: """Create connection pools for the agent's action server and model server.""" diff --git a/tests/core/test_run.py b/tests/core/test_run.py index d57f1263c74c..1ac276d43772 100644 --- a/tests/core/test_run.py +++ b/tests/core/test_run.py @@ -1,5 +1,7 @@ +import warnings from unittest.mock import Mock +import aiohttp import pytest from typing import Text @@ -82,8 +84,9 @@ async def test_close_resources(loop: AbstractEventLoop): broker = SQLEventBroker() app = Mock() app.ctx.agent.tracker_store.event_broker = broker + app.ctx.agent.action_endpoint.session = aiohttp.ClientSession() + app.ctx.agent.model_server.session = aiohttp.ClientSession() - with pytest.warns(None) as warnings: + with warnings.catch_warnings() as record: await run.close_resources(app, loop) - - assert len(warnings) == 0 + assert record is None diff --git a/tests/utils/test_endpoints.py b/tests/utils/test_endpoints.py index 7ff7e008ccac..071e54ee9318 100644 --- a/tests/utils/test_endpoints.py +++ b/tests/utils/test_endpoints.py @@ -244,6 +244,9 @@ async def test_endpoint_config_caches_session() -> None: assert endpoint.session is session + # teardown + await endpoint.session.close() + async def test_endpoint_config_constructor_does_not_create_session_cached_property() -> None: # noqa: E501 """Test that the instantiation of EndpointConfig does not create the session cached property.""" # noqa: E501