Skip to content

Commit

Permalink
Merge pull request #12722 from ottonemo/issue/OS-413
Browse files Browse the repository at this point in the history
Fix OSS-413: Proper intents in interactive training
  • Loading branch information
ancalita authored Sep 27, 2023
2 parents 6027058 + ee1dd88 commit e37774e
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 3 deletions.
4 changes: 4 additions & 0 deletions changelog/12722.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Intent names will not be falsely abbreviated in interactive training (fixes OSS-413).

This will also fix a bug where forced user utterances (using the regex matcher) will
be reverted even though they are present in the domain.
18 changes: 15 additions & 3 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,20 @@ def _print_help(skip_visualization: bool) -> None:
)


def intent_names_from_domain(domain: Any) -> List[Text]:
"""Get a list of the possible intents names from the domain specification.
This is its own function as intents are non-trivial to unpack and this
warrants testing.
"""
domain_intents = domain.get("intents", []) if domain is not None else []

# intents with properties such as `use_entities` or `ignore_entities`
# are a dictionary which needs unpacking. Other intents are strings
# and can be used as-is.
return [next(iter(i)) if isinstance(i, dict) else i for i in domain_intents]


async def record_messages(
endpoint: EndpointConfig,
file_importer: TrainingDataImporter,
Expand All @@ -1475,9 +1489,7 @@ async def record_messages(
)
return

domain_intents = domain.get("intents", []) if domain is not None else []

intents = [next(iter(i)) for i in domain_intents]
intents = intent_names_from_domain(domain)

num_messages = 0

Expand Down
32 changes: 32 additions & 0 deletions tests/core/training/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,38 @@ def mock_file_importer(
)


@pytest.mark.parametrize(
"domain_file, expected_intents",
[
(
"data/test_domains/default_unfeaturized_entities.yml",
[
"greet",
"default",
"goodbye",
"thank",
"ask",
"why",
"pure_intent",
],
),
(
"data/test_domains/default.yml",
[
"greet",
"default",
"goodbye",
],
),
],
)
def test_intent_names_from_domain(domain_file, expected_intents):
test_domain = Domain.load(domain_file)

intents = interactive.intent_names_from_domain(test_domain.as_dict())
assert set(intents) == set(expected_intents)


async def test_send_message(mock_endpoint: EndpointConfig):
sender_id = uuid.uuid4().hex

Expand Down

0 comments on commit e37774e

Please sign in to comment.