diff --git a/.github/workflows/security-scans.yml b/.github/workflows/security-scans.yml
index 7348db48311f..b4cbf69cc92f 100644
--- a/.github/workflows/security-scans.yml
+++ b/.github/workflows/security-scans.yml
@@ -129,24 +129,3 @@ jobs:
       - name: Run Bandit 🔪
         if: needs.changes.outputs.backend == 'true'
         run: make lint-security
-
-  snyk:
-    runs-on: ubuntu-22.04
-    steps:
-      - uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c
-      - name: Run Snyk Open Source to check for Python vulnerabilities
-        uses: snyk/actions/python-3.8@master
-        continue-on-error: true
-        env:
-          SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
-        with:
-          command: monitor
-          args: --all-projects --org=rasa --skip-unresolved
-      - name: Run Snyk Open Source to check for JS vulnerabilities
-        uses: snyk/actions/node@master
-        continue-on-error: true
-        env:
-          SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
-        with:
-          command: monitor
-          args: --org=rasa --yarn-workspaces --strict-out-of-sync=false --prune-repeated-subdependencies
diff --git a/CHANGELOG.mdx b/CHANGELOG.mdx
index 14c7da88d3b8..fce125567d98 100644
--- a/CHANGELOG.mdx
+++ b/CHANGELOG.mdx
@@ -16,6 +16,27 @@ https://github.com/RasaHQ/rasa/tree/main/changelog/ . -->
 
+## [3.6.15] - 2023-11-30
+
+Rasa 3.6.15 (2023-11-30)
+### Bugfixes
+- [#12965](https://github.com/rasahq/rasa/issues/12965): Fixed connection timeouts to the action server by setting `KEEP_ALIVE_TIMEOUT` to 120 seconds and reverting the changes introduced in #12886.
+
+
+## [3.6.14] - 2023-11-17
+
+Rasa 3.6.14 (2023-11-17)
+### Bugfixes
+- [#12948](https://github.com/rasahq/rasa/issues/12948): Fixed `UnexpecTEDIntentlessPolicy` training errors that resulted from a change to batching behavior. Changed the batching behavior back to the original for all components, and made the changed batching behavior accessible in `DIETClassifier` via `drop_small_last_batch: True`.
+
+
+## [3.6.13] - 2023-10-23
+
+Rasa 3.6.13 (2023-10-23)
+### Bugfixes
+- [#12927](https://github.com/rasahq/rasa/issues/12927): Fixed spurious conflicts reported when `rasa validate stories` is run with slots that have `active_loop` set to `null` in mapping conditions.
+
+
 ## [3.6.12] - 2023-10-10
 
 Rasa 3.6.12 (2023-10-10)
diff --git a/changelog/12927.bugfix.md b/changelog/12927.bugfix.md
deleted file mode 100644
index 7b9ff9d69410..000000000000
--- a/changelog/12927.bugfix.md
+++ /dev/null
@@ -1 +0,0 @@
-Fix wrong conflicts that occur when rasa validate stories is run with slots that have active_loop set to null in mapping conditions.
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index efa2bce4abc0..bd34f9e7ec4b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ exclude = "((.eggs | .git | .pytest_cache | build | dist))"
 
 [tool.poetry]
 name = "rasa"
-version = "3.6.12"
+version = "3.6.15"
 description = "Open source machine learning framework to automate text- and voice-based conversations: NLU, dialogue management, connect to Slack, Facebook, and more - Create chatbots and voice assistants"
 authors = [ "Rasa Technologies GmbH <hi@rasa.com>",]
 maintainers = [ "Tom Bocklisch <tom@rasa.com>",]
diff --git a/rasa/core/agent.py b/rasa/core/agent.py
index bf3d42236e70..47e3360f6a5b 100644
--- a/rasa/core/agent.py
+++ b/rasa/core/agent.py
@@ -112,53 +112,59 @@ async def _pull_model_and_fingerprint(
     logger.debug(f"Requesting model from server {model_server.url}...")
 
-    try:
-        params = model_server.combine_parameters()
-        async with model_server.session.request(
-            "GET",
-            model_server.url,
-            timeout=DEFAULT_REQUEST_TIMEOUT,
-            headers=headers,
-            params=params,
-        ) as resp:
-            if resp.status in [204, 304]:
-                logger.debug(
-                    "Model server returned {} status code, "
-                    "indicating that no new model is available. "
-                    "Current fingerprint: {}"
-                    "".format(resp.status, fingerprint)
-                )
-                return None
-            elif resp.status == 404:
-                logger.debug(
-                    "Model server could not find a model at the requested "
-                    "endpoint '{}'. It's possible that no model has been "
-                    "trained, or that the requested tag hasn't been "
-                    "assigned.".format(model_server.url)
-                )
-                return None
-            elif resp.status != 200:
-                logger.debug(
-                    "Tried to fetch model from server, but server response "
-                    "status code is {}. We'll retry later..."
-                    "".format(resp.status)
+    async with model_server.session() as session:
+        try:
+            params = model_server.combine_parameters()
+            async with session.request(
+                "GET",
+                model_server.url,
+                timeout=DEFAULT_REQUEST_TIMEOUT,
+                headers=headers,
+                params=params,
+            ) as resp:
+
+                if resp.status in [204, 304]:
+                    logger.debug(
+                        "Model server returned {} status code, "
+                        "indicating that no new model is available. "
+                        "Current fingerprint: {}"
+                        "".format(resp.status, fingerprint)
+                    )
+                    return None
+                elif resp.status == 404:
+                    logger.debug(
+                        "Model server could not find a model at the requested "
+                        "endpoint '{}'. It's possible that no model has been "
+                        "trained, or that the requested tag hasn't been "
+                        "assigned.".format(model_server.url)
+                    )
+                    return None
+                elif resp.status != 200:
+                    logger.debug(
+                        "Tried to fetch model from server, but server response "
+                        "status code is {}. We'll retry later..."
+                        "".format(resp.status)
+                    )
+                    return None
+
+                model_path = Path(model_directory) / resp.headers.get(
+                    "filename", "model.tar.gz"
                 )
-                return None
-
-            model_path = Path(model_directory) / resp.headers.get(
-                "filename", "model.tar.gz"
+                with open(model_path, "wb") as file:
+                    file.write(await resp.read())
+
+                logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))
+
+                # return the new fingerprint
+                return resp.headers.get("ETag")
+
+        except aiohttp.ClientError as e:
+            logger.debug(
+                "Tried to fetch model from server, but "
+                "couldn't reach server. We'll retry later... "
+                "Error: {}.".format(e)
             )
-            with open(model_path, "wb") as file:
-                file.write(await resp.read())
-            logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))
-            # return the new fingerprint
-            return resp.headers.get("ETag")
-    except aiohttp.ClientError as e:
-        logger.debug(
-            "Tried to fetch model from server, but "
-            "couldn't reach server. We'll retry later... "
-            "Error: {}.".format(e)
-        )
-        return None
+            return None
 
 
 async def _run_model_pulling_worker(model_server: EndpointConfig, agent: Agent) -> None:
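The hunk above replaces the endpoint's long-lived, cached `aiohttp.ClientSession` with a factory that is opened and closed around each request. A minimal sketch of the pattern under illustrative names (`Endpoint`, `fetch_status`, and the timeout value are not Rasa's actual API):

```python
import asyncio

import aiohttp


class Endpoint:
    """Illustrative stand-in for EndpointConfig's session handling."""

    def __init__(self, url: str) -> None:
        self.url = url

    def session(self) -> aiohttp.ClientSession:
        # A fresh session (and connection pool) per call. The caller owns its
        # lifecycle via `async with`, so a pool closed elsewhere can never be
        # reused by accident; the trade-off is a new TCP connection per request.
        return aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=30)  # illustrative timeout
        )


async def fetch_status(endpoint: Endpoint) -> int:
    async with endpoint.session() as session:
        async with session.request("GET", endpoint.url) as resp:
            return resp.status


if __name__ == "__main__":
    print(asyncio.run(fetch_status(Endpoint("https://example.com/"))))
```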
diff --git a/rasa/core/constants.py b/rasa/core/constants.py
index 973e4e7b3a99..40d65c3299bb 100644
--- a/rasa/core/constants.py
+++ b/rasa/core/constants.py
@@ -24,6 +24,8 @@
 
 DEFAULT_LOCK_LIFETIME = 60  # in seconds
 
+DEFAULT_KEEP_ALIVE_TIMEOUT = 120  # in seconds
+
 BEARER_TOKEN_PREFIX = "Bearer "
 
 # The lowest priority is intended to be used by machine learning policies.
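`KEEP_ALIVE_TIMEOUT` is a standard Sanic setting (its default is 5 seconds) that controls how long the server keeps idle keep-alive connections open. A hedged sketch of how the new constant is applied, mirroring the `configure_app` change below:

```python
from sanic import Sanic

DEFAULT_KEEP_ALIVE_TIMEOUT = 120  # in seconds, as defined above

app = Sanic("timeout_sketch")
# Raising the keep-alive window keeps idle connections to the action server
# open long enough to avoid the connection timeouts described in the
# 3.6.15 changelog entry.
app.config.KEEP_ALIVE_TIMEOUT = DEFAULT_KEEP_ALIVE_TIMEOUT
```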
diff --git a/rasa/core/run.py b/rasa/core/run.py
index 5270162809dd..3a8133613c3f 100644
--- a/rasa/core/run.py
+++ b/rasa/core/run.py
@@ -1,9 +1,19 @@
 import asyncio
 import logging
 import uuid
+import platform
 import os
 from functools import partial
-from typing import Any, List, Optional, TYPE_CHECKING, Text, Union, Dict
+from typing import (
+    Any,
+    Callable,
+    List,
+    Optional,
+    Text,
+    Tuple,
+    Union,
+    Dict,
+)
 
 import rasa.core.utils
 from rasa.plugin import plugin_manager
@@ -23,8 +33,6 @@
 from sanic import Sanic
 from asyncio import AbstractEventLoop
 
-if TYPE_CHECKING:
-    from aiohttp import ClientSession
 
 logger = logging.getLogger()  # get the root logger
 
@@ -80,6 +88,14 @@ def _create_app_without_api(cors: Optional[Union[Text, List[Text]]] = None) -> Sanic:
     return app
 
 
+def _is_apple_silicon_system() -> bool:
+    # check if the system is MacOS
+    if platform.system().lower() != "darwin":
+        return False
+    # check for arm architecture, indicating apple silicon
+    return platform.machine().startswith("arm") or os.uname().machine.startswith("arm")
+
+
 def configure_app(
     input_channels: Optional[List["InputChannel"]] = None,
     cors: Optional[Union[Text, List[Text], None]] = None,
@@ -99,6 +115,9 @@ def configure_app(
     syslog_port: Optional[int] = None,
     syslog_protocol: Optional[Text] = None,
     request_timeout: Optional[int] = None,
+    server_listeners: Optional[List[Tuple[Callable, Text]]] = None,
+    use_uvloop: Optional[bool] = True,
+    keep_alive_timeout: int = constants.DEFAULT_KEEP_ALIVE_TIMEOUT,
 ) -> Sanic:
     """Run the agent."""
     rasa.core.utils.configure_file_logging(
@@ -118,6 +137,14 @@ def configure_app(
     else:
         app = _create_app_without_api(cors)
 
+    app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout
+    if _is_apple_silicon_system() or not use_uvloop:
+        app.config.USE_UVLOOP = False
+        # some library still sets the loop to uvloop, even if disabled for sanic
+        # using uvloop leads to breaking io errors, see
+        # https://rasahq.atlassian.net/browse/ENG-667
+        asyncio.set_event_loop_policy(None)
+
     if input_channels:
         channels.channel.register(input_channels, app, route=route)
     else:
@@ -150,6 +177,10 @@ async def run_cmdline_io(running_app: Sanic) -> None:
 
     app.add_task(run_cmdline_io)
 
+    if server_listeners:
+        for (listener, event) in server_listeners:
+            app.register_listener(listener, event)
+
     return app
 
 
@@ -179,6 +210,7 @@ def serve_application(
     syslog_port: Optional[int] = None,
     syslog_protocol: Optional[Text] = None,
     request_timeout: Optional[int] = None,
+    server_listeners: Optional[List[Tuple[Callable, Text]]] = None,
 ) -> None:
     """Run the API entrypoint."""
     if not channel and not credentials:
@@ -204,6 +236,7 @@ def serve_application(
         syslog_port=syslog_port,
         syslog_protocol=syslog_protocol,
         request_timeout=request_timeout,
+        server_listeners=server_listeners,
     )
 
     ssl_context = server.create_ssl_context(
@@ -217,7 +250,7 @@ def serve_application(
         partial(load_agent_on_start, model_path, endpoints, remote_storage),
         "before_server_start",
     )
-    app.register_listener(create_connection_pools, "after_server_start")
+
     app.register_listener(close_resources, "after_server_stop")
 
     number_of_workers = rasa.core.utils.number_of_sanic_workers(
@@ -279,44 +312,3 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None:
     event_broker = current_agent.tracker_store.event_broker
     if event_broker:
         await event_broker.close()
-
-    action_endpoint = current_agent.action_endpoint
-    if action_endpoint:
-        await action_endpoint.session.close()
-
-    model_server = current_agent.model_server
-    if model_server:
-        await model_server.session.close()
-
-
-async def create_connection_pools(app: Sanic, _: AbstractEventLoop) -> None:
-    """Create connection pools for the agent's action server and model server."""
-    current_agent = getattr(app.ctx, "agent", None)
-    if not current_agent:
-        logger.debug("No agent found after server start.")
-        return None
-
-    create_action_endpoint_connection_pool(current_agent)
-    create_model_server_connection_pool(current_agent)
-
-    return None
-
-
-def create_action_endpoint_connection_pool(agent: Agent) -> Optional["ClientSession"]:
-    """Create a connection pool for the action endpoint."""
-    action_endpoint = agent.action_endpoint
-    if not action_endpoint:
-        logger.debug("No action endpoint found after server start.")
-        return None
-
-    return action_endpoint.session
-
-
-def create_model_server_connection_pool(agent: Agent) -> Optional["ClientSession"]:
-    """Create a connection pool for the model server."""
-    model_server = agent.model_server
-    if not model_server:
-        logger.debug("No model server endpoint found after server start.")
-        return None
-
-    return model_server.session
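The new `server_listeners` argument lets callers hand `configure_app` a list of `(listener, event)` pairs to register on the Sanic app, as the hunk above shows. A small sketch of the shape it expects; the `announce_start` listener is hypothetical:

```python
from typing import Callable, List, Text, Tuple

from sanic import Sanic

app = Sanic("listener_sketch")


async def announce_start(app: Sanic, loop) -> None:
    # hypothetical listener; runs once the server is up
    print(f"{app.name} started")


server_listeners: List[Tuple[Callable, Text]] = [
    (announce_start, "after_server_start"),
]

# configure_app does exactly this with the argument:
for listener, event in server_listeners:
    app.register_listener(listener, event)
```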
diff --git a/rasa/nlu/classifiers/diet_classifier.py b/rasa/nlu/classifiers/diet_classifier.py
index 1cc65c89b3c9..bea4735da6fe 100644
--- a/rasa/nlu/classifiers/diet_classifier.py
+++ b/rasa/nlu/classifiers/diet_classifier.py
@@ -50,6 +50,7 @@
 from rasa.shared.nlu.training_data.training_data import TrainingData
 from rasa.shared.nlu.training_data.message import Message
 from rasa.utils.tensorflow.constants import (
+    DROP_SMALL_LAST_BATCH,
     LABEL,
     IDS,
     HIDDEN_LAYERS_SIZES,
@@ -288,6 +289,9 @@ def get_default_config() -> Dict[Text, Any]:
             # a few steps, as the compilation of the graph tends to take more time than
             # running it. It is recommended to not adjust the optimization parameter.
             RUN_EAGERLY: False,
+            # Determines whether the last batch should be dropped if it contains fewer
+            # than half a batch size of examples
+            DROP_SMALL_LAST_BATCH: False,
         }
 
     def __init__(
@@ -931,6 +935,7 @@ def train(self, training_data: TrainingData) -> Resource:
             self.component_config[BATCH_STRATEGY],
             self.component_config[EVAL_NUM_EXAMPLES],
             self.component_config[RANDOM_SEED],
+            drop_small_last_batch=self.component_config[DROP_SMALL_LAST_BATCH],
         )
         callbacks = train_utils.create_common_callbacks(
             self.component_config[EPOCHS],
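Per the 3.6.14 changelog entry, the old small-batch dropping stays opt-in for DIET. A hedged sketch of switching it on; the component config is a plain dict, and every key except `drop_small_last_batch` is illustrative:

```python
# Keys mirror rasa/utils/tensorflow/constants.py. In a config.yml this would
# sit under the DIETClassifier pipeline entry as `drop_small_last_batch: true`;
# the default added above is False.
diet_component_config = {
    "epochs": 100,                  # illustrative value
    "drop_small_last_batch": True,  # re-enable 3.6.12-style small-batch dropping
}
```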
diff --git a/rasa/utils/endpoints.py b/rasa/utils/endpoints.py
index 5e1032778e6b..31d1ea7228bc 100644
--- a/rasa/utils/endpoints.py
+++ b/rasa/utils/endpoints.py
@@ -1,8 +1,6 @@
 import ssl
-from functools import cached_property
 
 import aiohttp
-import logging
 import os
 from aiohttp.client_exceptions import ContentTypeError
 from sanic.request import Request
@@ -11,10 +9,11 @@
 from rasa.shared.exceptions import FileNotFoundException
 import rasa.shared.utils.io
 import rasa.utils.io
+import structlog
 
 from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 
-logger = logging.getLogger(__name__)
+structlogger = structlog.get_logger()
 
 
 def read_endpoint_config(
@@ -32,9 +31,13 @@ def read_endpoint_config(
 
         return EndpointConfig.from_dict(content[endpoint_type])
     except FileNotFoundError:
-        logger.error(
-            "Failed to read endpoint configuration "
-            "from {}. No such file.".format(os.path.abspath(filename))
+        structlogger.error(
+            "endpoint.read.failed_no_such_file",
+            filename=os.path.abspath(filename),
+            event_info=(
+                "Failed to read endpoint configuration file - "
+                "the file was not found."
+            ),
         )
         return None
 
@@ -56,9 +59,13 @@ def concat_url(base: Text, subpath: Optional[Text]) -> Text:
     """
     if not subpath:
         if base.endswith("/"):
-            logger.debug(
-                f"The URL '{base}' has a trailing slash. Please make sure the "
-                f"target server supports trailing slashes for this endpoint."
+            structlogger.debug(
+                "endpoint.concat_url.trailing_slash",
+                url=base,
+                event_info=(
+                    "The URL has a trailing slash. Please make sure the "
+                    "target server supports trailing slashes for this endpoint."
+                ),
             )
         return base
 
@@ -95,7 +102,6 @@ def __init__(
         self.cafile = cafile
         self.kwargs = kwargs
 
-    @cached_property
     def session(self) -> aiohttp.ClientSession:
         """Creates and returns a configured aiohttp client session."""
         # create authentication parameters
@@ -164,23 +170,26 @@ async def request(
                     f"'{os.path.abspath(self.cafile)}' does not exist."
                 ) from e
 
-        async with self.session.request(
-            method,
-            url,
-            headers=headers,
-            params=self.combine_parameters(kwargs),
-            compress=compress,
-            ssl=sslcontext,
-            **kwargs,
-        ) as response:
-            if response.status >= 400:
-                raise ClientResponseError(
-                    response.status, response.reason, await response.content.read()
-                )
-            try:
-                return await response.json()
-            except ContentTypeError:
-                return None
+        async with self.session() as session:
+            async with session.request(
+                method,
+                url,
+                headers=headers,
+                params=self.combine_parameters(kwargs),
+                compress=compress,
+                ssl=sslcontext,
+                **kwargs,
+            ) as response:
+                if response.status >= 400:
+                    raise ClientResponseError(
+                        response.status,
+                        response.reason,
+                        await response.content.read(),
+                    )
+                try:
+                    return await response.json()
+                except ContentTypeError:
+                    return None
 
     @classmethod
     def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig":
@@ -263,7 +272,7 @@ def float_arg(
     try:
         return float(str(arg))
     except (ValueError, TypeError):
-        logger.warning(f"Failed to convert '{arg}' to float.")
+        structlogger.warning("endpoint.float_arg.convert_failed", arg=arg, key=key)
         return default
 
 
@@ -291,5 +300,6 @@ def int_arg(
     try:
         return int(str(arg))
     except (ValueError, TypeError):
-        logger.warning(f"Failed to convert '{arg}' to int.")
+
+        structlogger.warning("endpoint.int_arg.convert_failed", arg=arg, key=key)
         return default
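The module above migrates from stdlib `logging` to `structlog`'s key-value events: the first argument is a stable event name, and context travels as structured fields rather than interpolated strings. A minimal sketch of the convention, runnable as-is with `structlog` installed (the argument values are illustrative):

```python
import structlog

structlogger = structlog.get_logger()

# Event name first ("module.function.what_happened"), context as kwargs;
# `event_info` carries the human-readable message, as in the hunks above.
structlogger.warning(
    "endpoint.float_arg.convert_failed",
    arg="not-a-number",
    key="confidence",
)
```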
diff --git a/rasa/utils/tensorflow/constants.py b/rasa/utils/tensorflow/constants.py
index 047db9878c67..39d5ea6d0560 100644
--- a/rasa/utils/tensorflow/constants.py
+++ b/rasa/utils/tensorflow/constants.py
@@ -113,3 +113,4 @@
 USE_GPU = "use_gpu"
 
 RUN_EAGERLY = "run_eagerly"
+DROP_SMALL_LAST_BATCH = "drop_small_last_batch"
diff --git a/rasa/utils/tensorflow/data_generator.py b/rasa/utils/tensorflow/data_generator.py
index a696f607c026..e54b95dad335 100644
--- a/rasa/utils/tensorflow/data_generator.py
+++ b/rasa/utils/tensorflow/data_generator.py
@@ -344,6 +344,7 @@ def __init__(
         epochs: int = 1,
         batch_strategy: Text = SEQUENCE,
         shuffle: bool = True,
+        drop_small_last_batch: bool = False,
     ):
         """Initializes the increasing batch size data generator.
 
@@ -353,6 +354,8 @@ def __init__(
             epochs: The total number of epochs.
             batch_strategy: The batch strategy.
             shuffle: If 'True', data will be shuffled.
+            drop_small_last_batch: If 'True', the last batch in an epoch will be
+                dropped if it has fewer examples than half the batch size.
         """
         super().__init__(model_data, batch_size, batch_strategy, shuffle)
 
@@ -370,6 +373,7 @@ def __init__(
         self._current_batch_size = 0
         # create separate data variable that will store modified data for each batch
         self._data: Data = {}
+        self.drop_small_last_batch = drop_small_last_batch
 
         self.on_epoch_end()
 
     def __len__(self) -> int:
@@ -381,11 +385,16 @@ def __len__(self) -> int:
         # data was rebalanced, so need to recalculate number of examples
         num_examples = self.model_data.number_of_examples(self._data)
         batch_size = self._current_batch_size
-        # keep last batch only if it has at least half a batch size of examples
-        last_batch_half_full = num_examples % batch_size >= math.ceil(batch_size / 2)
-        num_batches = num_examples // batch_size + int(last_batch_half_full)
-        # Return at least 1 if there is an example
-        return max(num_batches, int(num_examples > 0))
+        if self.drop_small_last_batch:
+            # keep last batch only if it has at least half a batch size of examples
+            last_batch_half_full = num_examples % batch_size >= math.ceil(
+                batch_size / 2
+            )
+            num_batches = num_examples // batch_size + int(last_batch_half_full)
+            # Return at least 1 if there is an example
+            return max(num_batches, int(num_examples > 0))
+        else:
+            return num_examples // batch_size + int(num_examples % batch_size > 0)
 
     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """Gets batch at position `index`.
diff --git a/rasa/utils/train_utils.py b/rasa/utils/train_utils.py
index 36de0370d210..764507d7e39d 100644
--- a/rasa/utils/train_utils.py
+++ b/rasa/utils/train_utils.py
@@ -302,6 +302,7 @@ def create_data_generators(
     eval_num_examples: int = 0,
     random_seed: Optional[int] = None,
     shuffle: bool = True,
+    drop_small_last_batch: bool = False,
 ) -> Tuple[RasaBatchDataGenerator, Optional[RasaBatchDataGenerator]]:
     """Create data generators for train and optional validation data.
 
@@ -313,6 +314,8 @@ def create_data_generators(
         eval_num_examples: Number of examples to use for validation data.
         random_seed: The random seed.
         shuffle: Whether to shuffle data inside the data generator.
+        drop_small_last_batch: Whether to drop the last batch if it has fewer
+            than half a batch size of examples.
 
     Returns:
         The training data generator and optional validation data generator.
@@ -328,6 +331,7 @@ def create_data_generators(
             epochs=epochs,
             batch_strategy=batch_strategy,
             shuffle=shuffle,
+            drop_small_last_batch=drop_small_last_batch,
         )
 
     data_generator = RasaBatchDataGenerator(
@@ -336,6 +340,7 @@ def create_data_generators(
         epochs=epochs,
         batch_strategy=batch_strategy,
         shuffle=shuffle,
+        drop_small_last_batch=drop_small_last_batch,
     )
 
     return data_generator, validation_data_generator
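The two branches of `__len__` differ only in how a trailing partial batch is counted. A standalone sketch of the arithmetic, checked against the 48-example dataset used in the DIET tests further down:

```python
import math


def num_batches(num_examples: int, batch_size: int, drop_small_last_batch: bool) -> int:
    """Mirrors RasaBatchDataGenerator.__len__ for a fixed batch size."""
    if drop_small_last_batch:
        # keep the last batch only if it is at least half full
        last_batch_half_full = num_examples % batch_size >= math.ceil(batch_size / 2)
        batches = num_examples // batch_size + int(last_batch_half_full)
        # return at least 1 if there is any example at all
        return max(batches, int(num_examples > 0))
    # keep any non-empty trailing batch
    return num_examples // batch_size + int(num_examples % batch_size > 0)


# 48 examples, batch size 15: remainder 3 < ceil(15 / 2) = 8, so it is dropped
assert num_batches(48, 15, drop_small_last_batch=True) == 3
assert num_batches(48, 15, drop_small_last_batch=False) == 4
# 48 examples, batch size 20: remainder 8 < 10
assert num_batches(48, 20, drop_small_last_batch=True) == 2
assert num_batches(48, 20, drop_small_last_batch=False) == 3
```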
-__version__ = "3.6.12" +__version__ = "3.6.15" diff --git a/scripts/release.py b/scripts/release.py index d1ac98325f80..068bd14e5fd2 100644 --- a/scripts/release.py +++ b/scripts/release.py @@ -30,6 +30,10 @@ RELEASE_BRANCH_PATTERN = re.compile(r"^\d+\.\d+\.x$") +PUBLIC_REMOTE = "public" +DEFAULT_REMOTE = "origin" +FIRST_CALM_VERSION = "3.7.0" + def create_argument_parser() -> argparse.ArgumentParser: """Parse all the command line arguments for the release script.""" @@ -247,9 +251,9 @@ def create_commit(version: Version) -> None: check_call(["git", "commit", "-m", f"prepared release of version {version}"]) -def push_changes() -> None: - """Pushes the current branch to origin.""" - check_call(["git", "push", "origin", "HEAD"]) +def push_changes(remote: str = DEFAULT_REMOTE) -> None: + """Pushes the current branch to the specified remote.""" + check_call(["git", "push", remote, "HEAD"]) def ensure_clean_git() -> None: @@ -337,10 +341,11 @@ def main(args: argparse.Namespace) -> None: # never update changelog on a prerelease version generate_changelog(version) + remote = PUBLIC_REMOTE if str(version) < FIRST_CALM_VERSION else DEFAULT_REMOTE # alpha workflow on feature branch when a version bump is required if version.is_alpha and not git_current_branch_is_main_or_release(): create_commit(version) - push_changes() + push_changes(remote) print_done_message_same_branch(version) else: @@ -348,7 +353,7 @@ def main(args: argparse.Namespace) -> None: branch = create_release_branch(version) create_commit(version) - push_changes() + push_changes(remote) print_done_message(branch, base, version) diff --git a/tests/core/test_run.py b/tests/core/test_run.py index 1ac276d43772..8eda15058c0d 100644 --- a/tests/core/test_run.py +++ b/tests/core/test_run.py @@ -1,7 +1,6 @@ import warnings from unittest.mock import Mock -import aiohttp import pytest from typing import Text @@ -84,8 +83,6 @@ async def test_close_resources(loop: AbstractEventLoop): broker = SQLEventBroker() app = Mock() app.ctx.agent.tracker_store.event_broker = broker - app.ctx.agent.action_endpoint.session = aiohttp.ClientSession() - app.ctx.agent.model_server.session = aiohttp.ClientSession() with warnings.catch_warnings() as record: await run.close_resources(app, loop) diff --git a/tests/nlu/classifiers/test_diet_classifier.py b/tests/nlu/classifiers/test_diet_classifier.py index 1f0c37a85faa..1fd84fdac47d 100644 --- a/tests/nlu/classifiers/test_diet_classifier.py +++ b/tests/nlu/classifiers/test_diet_classifier.py @@ -971,24 +971,35 @@ async def test_no_bilou_when_entity_recognition_off( @pytest.mark.timeout(120, func_only=True) @pytest.mark.parametrize( - "batch_size, expected_num_batches", + "batch_size, expected_num_batches, drop_small_last_batch", # the training dataset has 48 NLU examples [ - (1, 48), - (8, 6), - (15, 3), - (16, 3), - (18, 3), - (20, 2), - (32, 2), - (64, 1), - (128, 1), - (256, 1), + (1, 48, True), + (8, 6, True), + (15, 3, True), + (16, 3, True), + (18, 3, True), + (20, 2, True), + (32, 2, True), + (64, 1, True), + (128, 1, True), + (256, 1, True), + (1, 48, False), + (8, 6, False), + (15, 4, False), + (16, 3, False), + (18, 3, False), + (20, 3, False), + (32, 2, False), + (64, 1, False), + (128, 1, False), + (256, 1, False), ], ) async def test_dropping_of_last_partial_batch( batch_size: int, expected_num_batches: int, + drop_small_last_batch: bool, create_diet: Callable[..., DIETClassifier], train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]], ): @@ -1012,7 +1023,9 @@ async def 
diff --git a/tests/nlu/classifiers/test_diet_classifier.py b/tests/nlu/classifiers/test_diet_classifier.py
index 1f0c37a85faa..1fd84fdac47d 100644
--- a/tests/nlu/classifiers/test_diet_classifier.py
+++ b/tests/nlu/classifiers/test_diet_classifier.py
@@ -971,24 +971,35 @@ async def test_no_bilou_when_entity_recognition_off(
 
 @pytest.mark.timeout(120, func_only=True)
 @pytest.mark.parametrize(
-    "batch_size, expected_num_batches",
+    "batch_size, expected_num_batches, drop_small_last_batch",
     # the training dataset has 48 NLU examples
     [
-        (1, 48),
-        (8, 6),
-        (15, 3),
-        (16, 3),
-        (18, 3),
-        (20, 2),
-        (32, 2),
-        (64, 1),
-        (128, 1),
-        (256, 1),
+        (1, 48, True),
+        (8, 6, True),
+        (15, 3, True),
+        (16, 3, True),
+        (18, 3, True),
+        (20, 2, True),
+        (32, 2, True),
+        (64, 1, True),
+        (128, 1, True),
+        (256, 1, True),
+        (1, 48, False),
+        (8, 6, False),
+        (15, 4, False),
+        (16, 3, False),
+        (18, 3, False),
+        (20, 3, False),
+        (32, 2, False),
+        (64, 1, False),
+        (128, 1, False),
+        (256, 1, False),
     ],
 )
 async def test_dropping_of_last_partial_batch(
     batch_size: int,
     expected_num_batches: int,
+    drop_small_last_batch: bool,
     create_diet: Callable[..., DIETClassifier],
     train_and_preprocess: Callable[..., Tuple[TrainingData, List[GraphComponent]]],
 ):
@@ -1012,7 +1023,9 @@ async def test_dropping_of_last_partial_batch(
     )
     model_data = diet.preprocess_train_data(training_data)
 
-    data_generator, _ = train_utils.create_data_generators(model_data, batch_size, 1)
+    data_generator, _ = train_utils.create_data_generators(
+        model_data, batch_size, 1, drop_small_last_batch=drop_small_last_batch
+    )
 
     assert len(data_generator) == expected_num_batches
 
@@ -1041,6 +1054,8 @@ async def test_dropping_of_last_partial_batch_empty_data(
     )
     model_data = diet.preprocess_train_data(training_data)
 
-    data_generator, _ = train_utils.create_data_generators(model_data, 64, 1)
+    data_generator, _ = train_utils.create_data_generators(
+        model_data, 64, 1, drop_small_last_batch=True
+    )
 
     assert len(data_generator) == 0
diff --git a/tests/utils/test_endpoints.py b/tests/utils/test_endpoints.py
index 071e54ee9318..711f2fd25faa 100644
--- a/tests/utils/test_endpoints.py
+++ b/tests/utils/test_endpoints.py
@@ -1,4 +1,4 @@
-import logging
+import structlog
 from pathlib import Path
 from typing import Text, Optional, Union
 from unittest.mock import Mock
@@ -35,13 +35,14 @@ def test_concat_url(base, subpath, expected_result):
     assert endpoint_utils.concat_url(base, subpath) == expected_result
 
 
-def test_warning_for_base_paths_with_trailing_slash(caplog):
+def test_warning_for_base_paths_with_trailing_slash():
     test_path = "base/"
-
-    with caplog.at_level(logging.DEBUG, logger="rasa.utils.endpoints"):
+    with structlog.testing.capture_logs() as caplog:
         assert endpoint_utils.concat_url(test_path, None) == test_path
 
-    assert len(caplog.records) == 1
+    assert len(caplog) == 1
+    assert caplog[0]["event"] == "endpoint.concat_url.trailing_slash"
+    assert caplog[0]["log_level"] == "debug"
 
 
 async def test_endpoint_config():
@@ -88,7 +89,7 @@ async def test_endpoint_config():
 
     # unfortunately, the mock library won't report any headers stored on
     # the session object, so we need to verify them separately
-    async with endpoint.session as s:
+    async with endpoint.session() as s:
         assert s._default_headers.get("X-Powered-By") == "Rasa"
         assert s._default_auth.login == "user"
         assert s._default_auth.password == "pass"
@@ -231,32 +232,3 @@ def test_int_arg(value: Optional[Union[int, str]], default: int, expected_result
     if value is not None:
         request.args = {"key": value}
     assert endpoint_utils.int_arg(request, "key", default) == expected_result
-
-
-async def test_endpoint_config_caches_session() -> None:
-    """Test that the EndpointConfig session is cached.
-
-    Assert identity of the session object, which should not be recreated when calling
-    the property `session` multiple times.
-    """
-    endpoint = endpoint_utils.EndpointConfig("https://example.com/")
-    session = endpoint.session
-
-    assert endpoint.session is session
-
-    # teardown
-    await endpoint.session.close()
-
-
-async def test_endpoint_config_constructor_does_not_create_session_cached_property() -> None:  # noqa: E501
-    """Test that the instantiation of EndpointConfig does not create the session cached property."""  # noqa: E501
-    endpoint = endpoint_utils.EndpointConfig("https://example.com/")
-
-    assert endpoint.__dict__.get("url") == "https://example.com/"
-    assert endpoint.__dict__.get("session") is None
-
-    # the property is created when it is accessed
-    async with endpoint.session as session:
-        assert session is not None
-
-    assert endpoint.__dict__.get("session") is session
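The reworked trailing-slash test above shows the matching test-side idiom: `structlog.testing.capture_logs()` yields a list of event dicts, so assertions target structured fields instead of formatted strings. The same pattern generalizes; a self-contained sketch with an illustrative event name:

```python
import structlog


def do_work() -> None:
    structlog.get_logger().debug("sketch.event", detail="illustrative")


def test_do_work_logs_event() -> None:
    # capture_logs collects every emitted event as a plain dict
    with structlog.testing.capture_logs() as caplog:
        do_work()

    assert len(caplog) == 1
    assert caplog[0]["event"] == "sketch.event"
    assert caplog[0]["log_level"] == "debug"
```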