From 05d627aef49a3f936857a5d31bb419310b30470d Mon Sep 17 00:00:00 2001 From: "raoul.vonmetzen@telekom.de" Date: Mon, 8 Jan 2024 14:48:49 +0100 Subject: [PATCH] speed up domain loading Domain loading is extremely slow for large bots - about 5 Mins for a 3.3 MBytes domain (including responses), and that is due to multiple times loading and parsing the same file. This commit addresses that, be removing the validation if loading from a model tar.gz and loading the file(s) only once if reading from a directory. --- .../providers/domain_provider.py | 2 +- rasa/shared/core/domain.py | 36 +++++++++++++------ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/rasa/graph_components/providers/domain_provider.py b/rasa/graph_components/providers/domain_provider.py index e959c7ef0f37..3b9350928187 100644 --- a/rasa/graph_components/providers/domain_provider.py +++ b/rasa/graph_components/providers/domain_provider.py @@ -45,7 +45,7 @@ def load( ) -> DomainProvider: """Creates provider using a persisted version of itself.""" with model_storage.read_from(resource) as resource_directory: - domain = Domain.from_path(resource_directory) + domain = Domain.from_path(resource_directory, is_validated=True) return cls(model_storage, resource, domain) def _persist(self, domain: Domain) -> None: diff --git a/rasa/shared/core/domain.py b/rasa/shared/core/domain.py index 8ec59cd6f111..2532e7644288 100644 --- a/rasa/shared/core/domain.py +++ b/rasa/shared/core/domain.py @@ -188,20 +188,21 @@ def load(cls, paths: Union[List[Union[Path, Text]], Text, Path]) -> "Domain": return domain @classmethod - def from_path(cls, path: Union[Text, Path]) -> "Domain": + def from_path(cls, path: Union[Text, Path], is_validated: bool = False) -> "Domain": """Loads the `Domain` from a path.""" + logger.debug(f"Loading from {path}") path = os.path.abspath(path) if os.path.isfile(path): domain = cls.from_file(path) elif os.path.isdir(path): - domain = cls.from_directory(path) + domain = cls.from_directory(path, is_validated=is_validated) else: raise InvalidDomain( "Failed to load domain specification from '{}'. " "File not found!".format(os.path.abspath(path)) ) - + logger.debug(f"done loading domain from {path}") return domain @classmethod @@ -287,20 +288,30 @@ def _get_session_config(session_config: Dict) -> SessionConfig: return SessionConfig(session_expiration_time_min, carry_over_slots) @classmethod - def from_directory(cls, path: Text) -> "Domain": + def from_directory(cls, path: Text, is_validated: bool = False) -> "Domain": """Loads and merges multiple domain files recursively from a directory tree.""" combined: Dict[Text, Any] = {} for root, _, files in os.walk(path, followlinks=True): for file in files: + logger.debug(f"Processing {file=}") full_path = os.path.join(root, file) - if Domain.is_domain_file(full_path): - _ = Domain.from_file(full_path) # does the validation here only - other_dict = rasa.shared.utils.io.read_yaml( - rasa.shared.utils.io.read_file(full_path) - ) + logger.debug(f"Checking file type of {file=}") + if other_dict := Domain.is_domain_file(full_path): + if not is_validated: + logger.debug(f"Validating {file=}") + _ = Domain.from_dict( + other_dict + ) # does the validation here only + # logger.debug(f"Reading {file=}") + # other_dict = rasa.shared.utils.io.read_yaml( + # rasa.shared.utils.io.read_file(full_path) + # ) + logger.debug(f"Merging {file=}") combined = Domain.merge_domain_dicts(other_dict, combined) + logger.debug("Building domain from dict.") domain = Domain.from_dict(combined) + logger.debug(f"Merged domain from directory {path=}") return domain def merge( @@ -1802,7 +1813,7 @@ def is_domain_file(filename: Union[Text, Path]) -> bool: filename: Path of the file which should be checked. Returns: - `True` if it's a domain file, otherwise `False`. + a domain file, otherwise `False`. Raises: YamlException: if the file seems to be a YAML file (extension) but @@ -1824,7 +1835,10 @@ def is_domain_file(filename: Union[Text, Path]) -> bool: ) return False - return any(key in content for key in ALL_DOMAIN_KEYS) + if any(key in content for key in ALL_DOMAIN_KEYS): + return content + else: + return False def required_slots_for_form(self, form_name: Text) -> List[Text]: """Retrieve the list of required slot names for a form defined in the domain.