Skip to content

Commit

Permalink
Warn if a user has defined a slot named "context" - [ENG 503] (#12829)
Browse files Browse the repository at this point in the history
warn restricted slot names in the domain
  • Loading branch information
varunshankar authored Sep 24, 2023
1 parent ed71677 commit ad3712d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
25 changes: 25 additions & 0 deletions rasa/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import rasa.shared.exceptions
import rasa.shared.utils.io
import rasa.shared.constants
from rasa.shared.constants import CONTEXT
import rasa.model

CODE_NEEDS_TO_BE_RETRAINED = 0b0001
Expand Down Expand Up @@ -124,6 +125,28 @@ def _check_unresolved_slots(domain: Domain, stories: StoryGraph) -> None:
return None


def _check_restricted_slots(domain: Domain) -> None:
"""Checks if there are any restricted slots.
Args:
domain: The domain.
Raises:
Warn user if there are any restricted slots.
Returns:
`None` if there are no restricted slots.
"""
restricted_slot_names = [CONTEXT]
for slot in domain.slots:
if slot.name in restricted_slot_names:
rasa.shared.utils.cli.print_warning(
f"Slot name - '{slot.name}' is reserved and can not be used. "
f"Please use another slot name."
)
return None


def train(
domain: Text,
config: Text,
Expand Down Expand Up @@ -210,6 +233,7 @@ def train(
training_type = TrainingType.CORE

_check_unresolved_slots(domain_object, stories)
_check_restricted_slots(domain_object)

with telemetry.track_model_training(file_importer, model_type="rasa"):
return _train_graph(
Expand Down Expand Up @@ -393,6 +417,7 @@ def train_core(
return None

_check_unresolved_slots(domain, stories_data)
_check_restricted_slots(domain)

return _train_graph(
file_importer,
Expand Down
1 change: 1 addition & 0 deletions rasa/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,4 @@
OPENAI_API_KEY_ENV_VAR = "OPENAI_API_KEY"

RASA_DEFAULT_FLOW_PATTERN_PREFIX = "pattern_"
CONTEXT = "context"
21 changes: 21 additions & 0 deletions tests/test_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from rasa.shared.constants import LATEST_TRAINING_DATA_FORMAT_VERSION
import rasa.shared.utils.io
from rasa.shared.core.domain import Domain
from rasa.shared.core.slots import AnySlot
from rasa.shared.exceptions import InvalidConfigException
from rasa.utils.tensorflow.constants import EPOCHS

Expand Down Expand Up @@ -1046,6 +1047,26 @@ def test_check_unresolved_slots(capsys: CaptureFixture):
assert rasa.model_training._check_unresolved_slots(domain, stories) is None


def test_check_restricted_slots(monkeypatch: MonkeyPatch):
domain_path = "data/test_domains/default_with_mapping.yml"
domain = Domain.load(domain_path)
mock = Mock()
monkeypatch.setattr(rasa.shared.utils.cli, "print_warning", mock)
rasa.model_training._check_restricted_slots(domain)
assert not mock.called

domain.slots.append(
AnySlot(
name="context",
mappings=[{}],
initial_value=None,
influence_conversation=False,
)
)
rasa.model_training._check_restricted_slots(domain)
assert mock.called


@pytest.mark.parametrize(
"fingerprint_results, expected_code",
[
Expand Down

0 comments on commit ad3712d

Please sign in to comment.