diff --git a/rasa/core/training/interactive.py b/rasa/core/training/interactive.py index dae4aad17516..2978b94099af 100644 --- a/rasa/core/training/interactive.py +++ b/rasa/core/training/interactive.py @@ -1457,6 +1457,20 @@ def _print_help(skip_visualization: bool) -> None: ) +def intent_names_from_domain(domain: Any) -> List[Text]: + """Get a list of the possible intents names from the domain specification. + + This is its own function as intents are non-trivial to unpack and this + warrants testing. + """ + domain_intents = domain.get("intents", []) if domain is not None else [] + + # intents with properties such as `use_entities` or `ignore_entities` + # are a dictionary which needs unpacking. Other intents are strings + # and can be used as-is. + return [next(iter(i)) if isinstance(i, dict) else i for i in domain_intents] + + async def record_messages( endpoint: EndpointConfig, file_importer: TrainingDataImporter, @@ -1475,17 +1489,7 @@ async def record_messages( ) return - domain_intents = domain.get("intents", []) if domain is not None else [] - - # intents with properties such as `use_entities` or `ignore_entities` - # are a dictionary which needs unpacking. Other intents are strings - # and can be used as-is. - intents = [ - next(iter(i)) - if isinstance(i, dict) - else i - for i in domain_intents - ] + intents = intent_names_from_domain(domain) num_messages = 0 diff --git a/tests/core/training/test_interactive.py b/tests/core/training/test_interactive.py index 215bd09a8f1a..172322842b8b 100644 --- a/tests/core/training/test_interactive.py +++ b/tests/core/training/test_interactive.py @@ -54,6 +54,38 @@ def mock_file_importer( ) +@pytest.mark.parametrize( + "domain_file, expected_intents", + [ + ( + "data/test_domains/default_unfeaturized_entities.yml", + [ + "greet", + "default", + "goodbye", + "thank", + "ask", + "why", + "pure_intent", + ], + ), + ( + "data/test_domains/default.yml", + [ + "greet", + "default", + "goodbye", + ], + ), + ], +) +def test_intent_names_from_domain(domain_file, expected_intents): + test_domain = Domain.load(domain_file) + + intents = interactive.intent_names_from_domain(test_domain.as_dict()) + assert set(intents) == set(expected_intents) + + async def test_send_message(mock_endpoint: EndpointConfig): sender_id = uuid.uuid4().hex