Skip to content

Commit

Permalink
address review comments, some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ancalita committed Sep 27, 2023
1 parent 008a9c2 commit a1714b0
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 193 deletions.
8 changes: 2 additions & 6 deletions rasa/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def validate_files(
if stories_only:
all_good = _validate_story_structure(validator, max_history, fail_on_warnings)
elif flows_only:
all_good = _validate_flows(validator)
all_good = validator.verify_flows()
else:
if importer.get_domain().is_empty():
rasa.shared.utils.cli.print_error_and_exit(
Expand All @@ -247,7 +247,7 @@ def validate_files(
valid_stories = _validate_story_structure(
validator, max_history, fail_on_warnings
)
valid_flows = _validate_flows(validator)
valid_flows = validator.verify_flows()

all_good = valid_domain and valid_nlu and valid_stories and valid_flows

Expand Down Expand Up @@ -293,10 +293,6 @@ def _validate_story_structure(
)


def _validate_flows(validator: "Validator") -> bool:
return validator.verify_flows()


def cancel_cause_not_found(
current: Optional[Union["Path", Text]],
parameter: Text,
Expand Down
82 changes: 41 additions & 41 deletions rasa/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,18 @@ def warn_if_config_mandatory_keys_are_not_set(self) -> None:
f"placeholder value with a unique identifier."
)

@staticmethod
def _raise_exception_if_slot_not_in_domain(
slot_name: str, domain_slots: Dict[Text, Slot], step_id: str, flow_id: str
) -> None:
if slot_name not in domain_slots:
raise RasaException(
f"The slot '{slot_name}' is used in the "
f"step '{step_id}' of flow id '{flow_id}', but it "
f"is not listed in the domain slots. "
f"You should add it to your domain file!",
)

@staticmethod
def _raise_exception_if_list_slot(slot: Slot, step_id: str, flow_id: str) -> None:
if isinstance(slot, ListSlot):
Expand All @@ -518,18 +530,14 @@ def _raise_exception_if_dialogue_stack_slot(

def verify_flows_steps_against_domain(self, user_flows: List[Flow]) -> bool:
"""Checks flows steps' references against the domain file."""
all_good = True
domain_slots = {slot.name: slot for slot in self.domain.slots}
for flow in user_flows:
for step in flow.steps:
if isinstance(step, CollectInformationFlowStep):
if step.collect_information not in domain_slots:
raise RasaException(
f"The slot '{step.collect_information}' is used in the "
f"step '{step.id}' of flow id '{flow.id}', but it "
f"is not listed in the domain slots. "
f"You should add it to your domain file!",
)
self._raise_exception_if_slot_not_in_domain(
step.collect_information, domain_slots, step.id, flow.id
)

current_slot = domain_slots[step.collect_information]
self._raise_exception_if_list_slot(current_slot, step.id, flow.id)
self._raise_exception_if_dialogue_stack_slot(
Expand All @@ -539,13 +547,10 @@ def verify_flows_steps_against_domain(self, user_flows: List[Flow]) -> bool:
elif isinstance(step, SetSlotsFlowStep):
for slot in step.slots:
slot_name = slot["key"]
if slot_name not in domain_slots:
raise RasaException(
f"The slot '{slot_name}' is used in the step "
f"'{step.id}' of flow id '{flow.id}', but it "
f"is not listed in the domain slots. "
f"You should add it to your domain file!",
)
self._raise_exception_if_slot_not_in_domain(
slot_name, domain_slots, step.id, flow.id
)

current_slot = domain_slots[slot_name]
self._raise_exception_if_list_slot(
current_slot, step.id, flow.id
Expand All @@ -562,13 +567,11 @@ def verify_flows_steps_against_domain(self, user_flows: List[Flow]) -> bool:
f"is not listed in the domain file. "
f"You should add it to your domain file!",
)
return all_good
return True

@staticmethod
def verify_unique_flows(user_flows: List[Flow]) -> bool:
"""Checks if all flows have unique names and descriptions."""
all_good = True

flows_mapping: Dict[str, str] = {}
punctuation_table = str.maketrans({i: "" for i in string.punctuation})

Expand All @@ -591,7 +594,20 @@ def verify_unique_flows(user_flows: List[Flow]) -> bool:

flows_mapping[flow.name] = cleaned_description

return all_good
return True

@staticmethod
def _construct_predicate(predicate: Optional[str], step_id: str) -> Predicate:
try:
pred = Predicate(predicate)
except (TypeError, Exception) as exception:
raise RasaException(
f"Could not initialize the predicate found under step "
f"'{step_id}'. Please make sure that all predicates "
f"are strings."
) from exception

return pred

@staticmethod
def verify_predicates(user_flows: List[Flow]) -> bool:
Expand All @@ -602,17 +618,10 @@ def verify_predicates(user_flows: List[Flow]) -> bool:
if isinstance(step, BranchFlowStep):
for link in step.next.links:
if isinstance(link, IfFlowLink):
try:
predicate = Predicate(link.condition)
except (TypeError, Exception) as exception:
raise RasaException(
f"Could not initialize the predicate found "
f"under step '{step.id}'. Please make sure "
f"that all predicates are strings."
) from exception

is_valid = predicate.is_valid()
if not is_valid:
predicate = Validator._construct_predicate(
link.condition, step.id
)
if not predicate.is_valid():
raise RasaException(
f"Detected invalid condition '{link.condition}' "
f"at step '{step.id}' for flow id '{flow.id}'. "
Expand All @@ -621,17 +630,8 @@ def verify_predicates(user_flows: List[Flow]) -> bool:
elif isinstance(step, CollectInformationFlowStep):
predicates = [predicate.if_ for predicate in step.rejections]
for predicate in predicates:
try:
pred = Predicate(predicate)
except (TypeError, Exception) as exception:
raise RasaException(
f"Could not initialize the predicate found under step "
f"'{step.id}'. Please make sure that all predicates "
f"are strings."
) from exception

is_valid = pred.is_valid()
if not is_valid:
pred = Validator._construct_predicate(predicate, step.id)
if not pred.is_valid():
raise RasaException(
f"Detected invalid rejection '{predicate}' "
f"at `collect_information` step '{step.id}' "
Expand Down
184 changes: 38 additions & 146 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import textwrap
import warnings
from typing import Text
from typing import Any, Dict, List, Text

import pytest
from _pytest.logging import LogCaptureFixture
Expand Down Expand Up @@ -859,137 +859,42 @@ def test_warn_if_config_mandatory_keys_are_not_set_invalid_paths(
validator.warn_if_config_mandatory_keys_are_not_set()


def test_verify_flow_steps_against_domain_missing_slot_in_domain(
tmp_path: Path,
nlu_data_path: Path,
) -> None:
missing_slot_in_domain = "transfer_amount"
flows_file = tmp_path / "flows.yml"
with open(flows_file, "w") as file:
file.write(
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
flows:
transfer_money:
description: This flow lets users send money.
name: transfer money
steps:
- id: "ask_recipient"
collect_information: transfer_recipient
next: "ask_amount"
- id: "ask_amount"
collect_information: {missing_slot_in_domain}
next: "execute_transfer"
- id: "execute_transfer"
action: action_transfer_money
"""
)
domain_file = tmp_path / "domain.yml"
with open(domain_file, "w") as file:
file.write(
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
intents:
- greet
slots:
transfer_recipient:
type: text
mappings: []
actions:
- action_transfer_money
"""
)
importer = RasaFileImporter(
config_file="data/test_moodbot/config.yml",
domain_path=str(domain_file),
training_data_paths=[str(flows_file), str(nlu_data_path)],
)

validator = Validator.from_importer(importer)
user_flows = [
flow
for flow in validator.flows.underlying_flows
if not flow.id.startswith("pattern_")
]

with pytest.raises(RasaException) as e:
validator.verify_flows_steps_against_domain(user_flows)

assert (
f"The slot '{missing_slot_in_domain}' is used in the step 'ask_amount' of "
f"flow id 'transfer_money', but it is not listed in the domain slots."
) in str(e.value)


def test_verify_flow_steps_against_domain_missing_action_in_domain(
tmp_path: Path,
nlu_data_path: Path,
) -> None:
missing_action_in_domain = "action_transfer_money"
flows_file = tmp_path / "flows.yml"
with open(flows_file, "w") as file:
file.write(
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
flows:
transfer_money:
description: This flow lets users send money.
name: transfer money
steps:
- id: "ask_recipient"
collect_information: transfer_recipient
next: "ask_amount"
- id: "ask_amount"
collect_information: transfer_amount
next: "execute_transfer"
- id: "execute_transfer"
action: {missing_action_in_domain}
"""
)
domain_file = tmp_path / "domain.yml"
with open(domain_file, "w") as file:
file.write(
f"""
version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}"
intents:
- greet
slots:
transfer_recipient:
type: text
mappings: []
transfer_amount:
type: float
mappings: []
"""
)
importer = RasaFileImporter(
config_file="data/test_moodbot/config.yml",
domain_path=str(domain_file),
training_data_paths=[str(flows_file), str(nlu_data_path)],
)

validator = Validator.from_importer(importer)
user_flows = [
flow
for flow in validator.flows.underlying_flows
if not flow.id.startswith("pattern_")
]

with pytest.raises(RasaException) as e:
validator.verify_flows_steps_against_domain(user_flows)

assert (
f"The action '{missing_action_in_domain}' is used in the step "
f"'execute_transfer' of flow id 'transfer_money', but it "
f"is not listed in the domain file."
) in str(e.value)


def test_verify_flow_steps_against_domain_missing_slot_from_set_slot_step(
@pytest.mark.parametrize(
"domain_actions, domain_slots, exception_message",
[
# set_slot slot is not listed in the domain
(
["action_transfer_money"],
{"transfer_amount": {"type": "float", "mappings": []}},
"The slot 'account_type' is used in the step 'set_account_type' "
"of flow id 'transfer_money', but it is not listed in the domain slots.",
),
# collect_information slot is not listed in the domain
(
["action_transfer_money"],
{"account_type": {"type": "text", "mappings": []}},
"The slot 'transfer_amount' is used in the step 'ask_amount' "
"of flow id 'transfer_money', but it is not listed in the domain slots.",
),
# action name is not listed in the domain
(
[],
{
"account_type": {"type": "text", "mappings": []},
"transfer_amount": {"type": "float", "mappings": []},
},
"The action 'action_transfer_money' is used in the step 'execute_transfer' "
"of flow id 'transfer_money', but it is not listed in the domain file.",
),
],
)
def test_verify_flow_steps_against_domain_fail(
tmp_path: Path,
nlu_data_path: Path,
domain_actions: List[Text],
domain_slots: Dict[Text, Any],
exception_message: Text,
) -> None:
missing_slot_in_domain = "account_type"
flows_file = tmp_path / "flows.yml"
with open(flows_file, "w") as file:
file.write(
Expand All @@ -1000,15 +905,12 @@ def test_verify_flow_steps_against_domain_missing_slot_from_set_slot_step(
description: This flow lets users send money.
name: transfer money
steps:
- id: "ask_recipient"
collect_information: transfer_recipient
next: "ask_amount"
- id: "ask_amount"
collect_information: transfer_amount
next: "set_account_type"
- id: "set_account_type"
set_slots:
- {missing_slot_in_domain}: "debit"
- account_type: "debit"
next: "execute_transfer"
- id: "execute_transfer"
action: action_transfer_money
Expand All @@ -1022,14 +924,8 @@ def test_verify_flow_steps_against_domain_missing_slot_from_set_slot_step(
intents:
- greet
slots:
transfer_recipient:
type: text
mappings: []
transfer_amount:
type: float
mappings: []
actions:
- action_transfer_money
{domain_slots}
actions: {domain_actions}
"""
)
importer = RasaFileImporter(
Expand All @@ -1048,11 +944,7 @@ def test_verify_flow_steps_against_domain_missing_slot_from_set_slot_step(
with pytest.raises(RasaException) as e:
validator.verify_flows_steps_against_domain(user_flows)

assert (
f"The slot '{missing_slot_in_domain}' is used in the step "
f"'set_account_type' of flow id 'transfer_money', "
f"but it is not listed in the domain slots."
) in str(e.value)
assert exception_message in str(e.value)


def test_verify_flow_steps_against_domain_disallowed_list_slot(
Expand Down

0 comments on commit a1714b0

Please sign in to comment.