Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix OSS-413: Proper intents in interactive training #12722

Merged
merged 5 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog/12722.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Intent names will not be falsely abbreviated in interactive training (fixes OSS-413).

This will also fix a bug where forced user utterances (using the regex matcher) will
be reverted even though they are present in the domain.
Comment on lines +3 to +4
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add the information you gave in the PR description, as that was clearer than these 2 lines 🙏🏻

18 changes: 15 additions & 3 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,20 @@ def _print_help(skip_visualization: bool) -> None:
)


def intent_names_from_domain(domain: Any) -> List[Text]:
"""Get a list of the possible intents names from the domain specification.

This is its own function as intents are non-trivial to unpack and this
warrants testing.
"""
domain_intents = domain.get("intents", []) if domain is not None else []

# intents with properties such as `use_entities` or `ignore_entities`
# are a dictionary which needs unpacking. Other intents are strings
# and can be used as-is.
return [next(iter(i)) if isinstance(i, dict) else i for i in domain_intents]


async def record_messages(
endpoint: EndpointConfig,
file_importer: TrainingDataImporter,
Expand All @@ -1475,9 +1489,7 @@ async def record_messages(
)
return

domain_intents = domain.get("intents", []) if domain is not None else []

intents = [next(iter(i)) for i in domain_intents]
intents = intent_names_from_domain(domain)

num_messages = 0

Expand Down
32 changes: 32 additions & 0 deletions tests/core/training/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,38 @@ def mock_file_importer(
)


@pytest.mark.parametrize(
"domain_file, expected_intents",
[
(
"data/test_domains/default_unfeaturized_entities.yml",
[
"greet",
"default",
"goodbye",
"thank",
"ask",
"why",
"pure_intent",
],
),
(
"data/test_domains/default.yml",
[
"greet",
"default",
"goodbye",
],
),
],
)
def test_intent_names_from_domain(domain_file, expected_intents):
test_domain = Domain.load(domain_file)

intents = interactive.intent_names_from_domain(test_domain.as_dict())
assert set(intents) == set(expected_intents)


async def test_send_message(mock_endpoint: EndpointConfig):
sender_id = uuid.uuid4().hex

Expand Down