Skip to content

Commit

Permalink
Refactor and test function to retrieve intent names
Browse files Browse the repository at this point in the history
Since the main problem of bug OSS-413 is that intents with
attributes are not retrieved well the test was implemented to
use domains with and without intent definitions using attributes.
  • Loading branch information
ottonemo committed Sep 26, 2023
1 parent b0a7c6d commit ee1dd88
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 11 deletions.
26 changes: 15 additions & 11 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,20 @@ def _print_help(skip_visualization: bool) -> None:
)


def intent_names_from_domain(domain: Any) -> List[Text]:
"""Get a list of the possible intents names from the domain specification.
This is its own function as intents are non-trivial to unpack and this
warrants testing.
"""
domain_intents = domain.get("intents", []) if domain is not None else []

# intents with properties such as `use_entities` or `ignore_entities`
# are a dictionary which needs unpacking. Other intents are strings
# and can be used as-is.
return [next(iter(i)) if isinstance(i, dict) else i for i in domain_intents]


async def record_messages(
endpoint: EndpointConfig,
file_importer: TrainingDataImporter,
Expand All @@ -1475,17 +1489,7 @@ async def record_messages(
)
return

domain_intents = domain.get("intents", []) if domain is not None else []

# intents with properties such as `use_entities` or `ignore_entities`
# are a dictionary which needs unpacking. Other intents are strings
# and can be used as-is.
intents = [
next(iter(i))
if isinstance(i, dict)
else i
for i in domain_intents
]
intents = intent_names_from_domain(domain)

num_messages = 0

Expand Down
32 changes: 32 additions & 0 deletions tests/core/training/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,38 @@ def mock_file_importer(
)


@pytest.mark.parametrize(
"domain_file, expected_intents",
[
(
"data/test_domains/default_unfeaturized_entities.yml",
[
"greet",
"default",
"goodbye",
"thank",
"ask",
"why",
"pure_intent",
],
),
(
"data/test_domains/default.yml",
[
"greet",
"default",
"goodbye",
],
),
],
)
def test_intent_names_from_domain(domain_file, expected_intents):
test_domain = Domain.load(domain_file)

intents = interactive.intent_names_from_domain(test_domain.as_dict())
assert set(intents) == set(expected_intents)


async def test_send_message(mock_endpoint: EndpointConfig):
sender_id = uuid.uuid4().hex

Expand Down

0 comments on commit ee1dd88

Please sign in to comment.