diff --git a/rasa/core/agent.py b/rasa/core/agent.py index bf3d42236e70..47e3360f6a5b 100644 --- a/rasa/core/agent.py +++ b/rasa/core/agent.py @@ -112,53 +112,59 @@ async def _pull_model_and_fingerprint( logger.debug(f"Requesting model from server {model_server.url}...") - try: - params = model_server.combine_parameters() - async with model_server.session.request( - "GET", - model_server.url, - timeout=DEFAULT_REQUEST_TIMEOUT, - headers=headers, - params=params, - ) as resp: - if resp.status in [204, 304]: - logger.debug( - "Model server returned {} status code, " - "indicating that no new model is available. " - "Current fingerprint: {}" - "".format(resp.status, fingerprint) - ) - return None - elif resp.status == 404: - logger.debug( - "Model server could not find a model at the requested " - "endpoint '{}'. It's possible that no model has been " - "trained, or that the requested tag hasn't been " - "assigned.".format(model_server.url) - ) - return None - elif resp.status != 200: - logger.debug( - "Tried to fetch model from server, but server response " - "status code is {}. We'll retry later..." - "".format(resp.status) + async with model_server.session() as session: + try: + params = model_server.combine_parameters() + async with session.request( + "GET", + model_server.url, + timeout=DEFAULT_REQUEST_TIMEOUT, + headers=headers, + params=params, + ) as resp: + + if resp.status in [204, 304]: + logger.debug( + "Model server returned {} status code, " + "indicating that no new model is available. " + "Current fingerprint: {}" + "".format(resp.status, fingerprint) + ) + return None + elif resp.status == 404: + logger.debug( + "Model server could not find a model at the requested " + "endpoint '{}'. It's possible that no model has been " + "trained, or that the requested tag hasn't been " + "assigned.".format(model_server.url) + ) + return None + elif resp.status != 200: + logger.debug( + "Tried to fetch model from server, but server response " + "status code is {}. We'll retry later..." + "".format(resp.status) + ) + return None + + model_path = Path(model_directory) / resp.headers.get( + "filename", "model.tar.gz" ) - return None - model_path = Path(model_directory) / resp.headers.get( - "filename", "model.tar.gz" + with open(model_path, "wb") as file: + file.write(await resp.read()) + + logger.debug("Saved model to '{}'".format(os.path.abspath(model_path))) + + # return the new fingerprint + return resp.headers.get("ETag") + + except aiohttp.ClientError as e: + logger.debug( + "Tried to fetch model from server, but " + "couldn't reach server. We'll retry later... " + "Error: {}.".format(e) ) - with open(model_path, "wb") as file: - file.write(await resp.read()) - logger.debug("Saved model to '{}'".format(os.path.abspath(model_path))) - # return the new fingerprint - return resp.headers.get("ETag") - except aiohttp.ClientError as e: - logger.debug( - "Tried to fetch model from server, but " - "couldn't reach server. We'll retry later... " - "Error: {}.".format(e) - ) - return None + return None async def _run_model_pulling_worker(model_server: EndpointConfig, agent: Agent) -> None: diff --git a/rasa/core/constants.py b/rasa/core/constants.py index 973e4e7b3a99..40d65c3299bb 100644 --- a/rasa/core/constants.py +++ b/rasa/core/constants.py @@ -24,6 +24,8 @@ DEFAULT_LOCK_LIFETIME = 60 # in seconds +DEFAULT_KEEP_ALIVE_TIMEOUT = 120 # in seconds + BEARER_TOKEN_PREFIX = "Bearer " # The lowest priority is intended to be used by machine learning policies. diff --git a/rasa/core/run.py b/rasa/core/run.py index 5270162809dd..3a8133613c3f 100644 --- a/rasa/core/run.py +++ b/rasa/core/run.py @@ -1,9 +1,19 @@ import asyncio import logging import uuid +import platform import os from functools import partial -from typing import Any, List, Optional, TYPE_CHECKING, Text, Union, Dict +from typing import ( + Any, + Callable, + List, + Optional, + Text, + Tuple, + Union, + Dict, +) import rasa.core.utils from rasa.plugin import plugin_manager @@ -23,8 +33,6 @@ from sanic import Sanic from asyncio import AbstractEventLoop -if TYPE_CHECKING: - from aiohttp import ClientSession logger = logging.getLogger() # get the root logger @@ -80,6 +88,14 @@ def _create_app_without_api(cors: Optional[Union[Text, List[Text]]] = None) -> S return app +def _is_apple_silicon_system() -> bool: + # check if the system is MacOS + if platform.system().lower() != "darwin": + return False + # check for arm architecture, indicating apple silicon + return platform.machine().startswith("arm") or os.uname().machine.startswith("arm") + + def configure_app( input_channels: Optional[List["InputChannel"]] = None, cors: Optional[Union[Text, List[Text], None]] = None, @@ -99,6 +115,9 @@ def configure_app( syslog_port: Optional[int] = None, syslog_protocol: Optional[Text] = None, request_timeout: Optional[int] = None, + server_listeners: Optional[List[Tuple[Callable, Text]]] = None, + use_uvloop: Optional[bool] = True, + keep_alive_timeout: int = constants.DEFAULT_KEEP_ALIVE_TIMEOUT, ) -> Sanic: """Run the agent.""" rasa.core.utils.configure_file_logging( @@ -118,6 +137,14 @@ def configure_app( else: app = _create_app_without_api(cors) + app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout + if _is_apple_silicon_system() or not use_uvloop: + app.config.USE_UVLOOP = False + # some library still sets the loop to uvloop, even if disabled for sanic + # using uvloop leads to breakingio errors, see + # https://rasahq.atlassian.net/browse/ENG-667 + asyncio.set_event_loop_policy(None) + if input_channels: channels.channel.register(input_channels, app, route=route) else: @@ -150,6 +177,10 @@ async def run_cmdline_io(running_app: Sanic) -> None: app.add_task(run_cmdline_io) + if server_listeners: + for (listener, event) in server_listeners: + app.register_listener(listener, event) + return app @@ -179,6 +210,7 @@ def serve_application( syslog_port: Optional[int] = None, syslog_protocol: Optional[Text] = None, request_timeout: Optional[int] = None, + server_listeners: Optional[List[Tuple[Callable, Text]]] = None, ) -> None: """Run the API entrypoint.""" if not channel and not credentials: @@ -204,6 +236,7 @@ def serve_application( syslog_port=syslog_port, syslog_protocol=syslog_protocol, request_timeout=request_timeout, + server_listeners=server_listeners, ) ssl_context = server.create_ssl_context( @@ -217,7 +250,7 @@ def serve_application( partial(load_agent_on_start, model_path, endpoints, remote_storage), "before_server_start", ) - app.register_listener(create_connection_pools, "after_server_start") + app.register_listener(close_resources, "after_server_stop") number_of_workers = rasa.core.utils.number_of_sanic_workers( @@ -279,44 +312,3 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None: event_broker = current_agent.tracker_store.event_broker if event_broker: await event_broker.close() - - action_endpoint = current_agent.action_endpoint - if action_endpoint: - await action_endpoint.session.close() - - model_server = current_agent.model_server - if model_server: - await model_server.session.close() - - -async def create_connection_pools(app: Sanic, _: AbstractEventLoop) -> None: - """Create connection pools for the agent's action server and model server.""" - current_agent = getattr(app.ctx, "agent", None) - if not current_agent: - logger.debug("No agent found after server start.") - return None - - create_action_endpoint_connection_pool(current_agent) - create_model_server_connection_pool(current_agent) - - return None - - -def create_action_endpoint_connection_pool(agent: Agent) -> Optional["ClientSession"]: - """Create a connection pool for the action endpoint.""" - action_endpoint = agent.action_endpoint - if not action_endpoint: - logger.debug("No action endpoint found after server start.") - return None - - return action_endpoint.session - - -def create_model_server_connection_pool(agent: Agent) -> Optional["ClientSession"]: - """Create a connection pool for the model server.""" - model_server = agent.model_server - if not model_server: - logger.debug("No model server endpoint found after server start.") - return None - - return model_server.session diff --git a/rasa/utils/endpoints.py b/rasa/utils/endpoints.py index 5e1032778e6b..31d1ea7228bc 100644 --- a/rasa/utils/endpoints.py +++ b/rasa/utils/endpoints.py @@ -1,8 +1,6 @@ import ssl -from functools import cached_property import aiohttp -import logging import os from aiohttp.client_exceptions import ContentTypeError from sanic.request import Request @@ -11,10 +9,11 @@ from rasa.shared.exceptions import FileNotFoundException import rasa.shared.utils.io import rasa.utils.io +import structlog from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT -logger = logging.getLogger(__name__) +structlogger = structlog.get_logger() def read_endpoint_config( @@ -32,9 +31,13 @@ def read_endpoint_config( return EndpointConfig.from_dict(content[endpoint_type]) except FileNotFoundError: - logger.error( - "Failed to read endpoint configuration " - "from {}. No such file.".format(os.path.abspath(filename)) + structlogger.error( + "endpoint.read.failed_no_such_file", + filename=os.path.abspath(filename), + event_info=( + "Failed to read endpoint configuration file - " + "the file was not found." + ), ) return None @@ -56,9 +59,13 @@ def concat_url(base: Text, subpath: Optional[Text]) -> Text: """ if not subpath: if base.endswith("/"): - logger.debug( - f"The URL '{base}' has a trailing slash. Please make sure the " - f"target server supports trailing slashes for this endpoint." + structlogger.debug( + "endpoint.concat_url.trailing_slash", + url=base, + event_info=( + "The URL has a trailing slash. Please make sure the " + "target server supports trailing slashes for this endpoint." + ), ) return base @@ -95,7 +102,6 @@ def __init__( self.cafile = cafile self.kwargs = kwargs - @cached_property def session(self) -> aiohttp.ClientSession: """Creates and returns a configured aiohttp client session.""" # create authentication parameters @@ -164,23 +170,26 @@ async def request( f"'{os.path.abspath(self.cafile)}' does not exist." ) from e - async with self.session.request( - method, - url, - headers=headers, - params=self.combine_parameters(kwargs), - compress=compress, - ssl=sslcontext, - **kwargs, - ) as response: - if response.status >= 400: - raise ClientResponseError( - response.status, response.reason, await response.content.read() - ) - try: - return await response.json() - except ContentTypeError: - return None + async with self.session() as session: + async with session.request( + method, + url, + headers=headers, + params=self.combine_parameters(kwargs), + compress=compress, + ssl=sslcontext, + **kwargs, + ) as response: + if response.status >= 400: + raise ClientResponseError( + response.status, + response.reason, + await response.content.read(), + ) + try: + return await response.json() + except ContentTypeError: + return None @classmethod def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig": @@ -263,7 +272,7 @@ def float_arg( try: return float(str(arg)) except (ValueError, TypeError): - logger.warning(f"Failed to convert '{arg}' to float.") + structlogger.warning("endpoint.float_arg.convert_failed", arg=arg, key=key) return default @@ -291,5 +300,6 @@ def int_arg( try: return int(str(arg)) except (ValueError, TypeError): - logger.warning(f"Failed to convert '{arg}' to int.") + + structlogger.warning("endpoint.int_arg.convert_failed", arg=arg, key=key) return default diff --git a/tests/core/test_run.py b/tests/core/test_run.py index 1ac276d43772..8eda15058c0d 100644 --- a/tests/core/test_run.py +++ b/tests/core/test_run.py @@ -1,7 +1,6 @@ import warnings from unittest.mock import Mock -import aiohttp import pytest from typing import Text @@ -84,8 +83,6 @@ async def test_close_resources(loop: AbstractEventLoop): broker = SQLEventBroker() app = Mock() app.ctx.agent.tracker_store.event_broker = broker - app.ctx.agent.action_endpoint.session = aiohttp.ClientSession() - app.ctx.agent.model_server.session = aiohttp.ClientSession() with warnings.catch_warnings() as record: await run.close_resources(app, loop) diff --git a/tests/utils/test_endpoints.py b/tests/utils/test_endpoints.py index 071e54ee9318..711f2fd25faa 100644 --- a/tests/utils/test_endpoints.py +++ b/tests/utils/test_endpoints.py @@ -1,4 +1,4 @@ -import logging +import structlog from pathlib import Path from typing import Text, Optional, Union from unittest.mock import Mock @@ -35,13 +35,14 @@ def test_concat_url(base, subpath, expected_result): assert endpoint_utils.concat_url(base, subpath) == expected_result -def test_warning_for_base_paths_with_trailing_slash(caplog): +def test_warning_for_base_paths_with_trailing_slash(): test_path = "base/" - - with caplog.at_level(logging.DEBUG, logger="rasa.utils.endpoints"): + with structlog.testing.capture_logs() as caplog: assert endpoint_utils.concat_url(test_path, None) == test_path - assert len(caplog.records) == 1 + assert len(caplog) == 1 + assert caplog[0]["event"] == "endpoint.concat_url.trailing_slash" + assert caplog[0]["log_level"] == "debug" async def test_endpoint_config(): @@ -88,7 +89,7 @@ async def test_endpoint_config(): # unfortunately, the mock library won't report any headers stored on # the session object, so we need to verify them separately - async with endpoint.session as s: + async with endpoint.session() as s: assert s._default_headers.get("X-Powered-By") == "Rasa" assert s._default_auth.login == "user" assert s._default_auth.password == "pass" @@ -231,32 +232,3 @@ def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result if value is not None: request.args = {"key": value} assert endpoint_utils.int_arg(request, "key", default) == expected_result - - -async def test_endpoint_config_caches_session() -> None: - """Test that the EndpointConfig session is cached. - - Assert identity of the session object, which should not be recreated when calling - the property `session` multiple times. - """ - endpoint = endpoint_utils.EndpointConfig("https://example.com/") - session = endpoint.session - - assert endpoint.session is session - - # teardown - await endpoint.session.close() - - -async def test_endpoint_config_constructor_does_not_create_session_cached_property() -> None: # noqa: E501 - """Test that the instantiation of EndpointConfig does not create the session cached property.""" # noqa: E501 - endpoint = endpoint_utils.EndpointConfig("https://example.com/") - - assert endpoint.__dict__.get("url") == "https://example.com/" - assert endpoint.__dict__.get("session") is None - - # the property is created when it is accessed - async with endpoint.session as session: - assert session is not None - - assert endpoint.__dict__.get("session") is session