Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn if a user has defined a slot named "context" - [ENG 503] #12829

Merged
merged 3 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions rasa/model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,28 @@ def _check_unresolved_slots(domain: Domain, stories: StoryGraph) -> None:
return None


def _check_restricted_slots(domain: Domain) -> None:
"""Checks if there are any restricted slots.

Args:
domain: The domain.

Raises:
Warn user if there are any restricted slots.

Returns:
`None` if there are no restricted slots.
"""
restricted_slot_names = [rasa.shared.constants.CONTEXT]
varunshankar marked this conversation as resolved.
Show resolved Hide resolved
for slot in domain.slots:
if slot.name in restricted_slot_names:
rasa.shared.utils.cli.print_warning(
f"Slot name - '{slot.name}' is reserved and can not be used. "
f"Please use another slot name."
)
return None


def train(
domain: Text,
config: Text,
Expand Down Expand Up @@ -210,6 +232,7 @@ def train(
training_type = TrainingType.CORE

_check_unresolved_slots(domain_object, stories)
_check_restricted_slots(domain_object)

with telemetry.track_model_training(file_importer, model_type="rasa"):
return _train_graph(
Expand Down Expand Up @@ -393,6 +416,7 @@ def train_core(
return None

_check_unresolved_slots(domain, stories_data)
_check_restricted_slots(domain)

return _train_graph(
file_importer,
Expand Down
1 change: 1 addition & 0 deletions rasa/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,4 @@
OPENAI_API_KEY_ENV_VAR = "OPENAI_API_KEY"

RASA_DEFAULT_FLOW_PATTERN_PREFIX = "pattern_"
CONTEXT = "context"
21 changes: 21 additions & 0 deletions tests/test_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from rasa.shared.constants import LATEST_TRAINING_DATA_FORMAT_VERSION
import rasa.shared.utils.io
from rasa.shared.core.domain import Domain
from rasa.shared.core.slots import AnySlot
from rasa.shared.exceptions import InvalidConfigException
from rasa.utils.tensorflow.constants import EPOCHS

Expand Down Expand Up @@ -1046,6 +1047,26 @@ def test_check_unresolved_slots(capsys: CaptureFixture):
assert rasa.model_training._check_unresolved_slots(domain, stories) is None


def test_check_restricted_slots(monkeypatch: MonkeyPatch):
domain_path = "data/test_domains/default_with_mapping.yml"
domain = Domain.load(domain_path)
mock = Mock()
monkeypatch.setattr(rasa.shared.utils.cli, "print_warning", mock)
rasa.model_training._check_restricted_slots(domain)
assert not mock.called

domain.slots.append(
AnySlot(
name="context",
mappings=[{}],
initial_value=None,
influence_conversation=False,
)
)
rasa.model_training._check_restricted_slots(domain)
assert mock.called


@pytest.mark.parametrize(
"fingerprint_results, expected_code",
[
Expand Down
Loading