diff --git a/rasa/cli/arguments/run.py b/rasa/cli/arguments/run.py index f982672700d1..2f807ad74718 100644 --- a/rasa/cli/arguments/run.py +++ b/rasa/cli/arguments/run.py @@ -1,8 +1,11 @@ +import os + import argparse from typing import Union from rasa.cli.arguments.default_arguments import add_model_param, add_endpoint_param from rasa.core import constants +from rasa.env import DEFAULT_JWT_METHOD, JWT_METHOD_ENV, JWT_SECRET_ENV, JWT_PRIVATE_KEY_ENV, AUTH_TOKEN_ENV def set_run_arguments(parser: argparse.ArgumentParser) -> None: @@ -82,18 +85,27 @@ def add_server_arguments(parser: argparse.ArgumentParser) -> None: "yml file.", ) + add_server_settings_arguments(parser) + + +def add_server_settings_arguments(parser: argparse.ArgumentParser) -> None: + """Add arguments for the API server. + + Args: + parser: Argument parser. + """ server_arguments = parser.add_argument_group("Server Settings") add_interface_argument(server_arguments) - add_port_argument(server_arguments) server_arguments.add_argument( "-t", "--auth-token", type=str, + default=os.getenv(AUTH_TOKEN_ENV), help="Enable token based authentication. Requests need to provide " - "the token to be accepted.", + "the token to be accepted.", ) server_arguments.add_argument( "--cors", @@ -132,13 +144,13 @@ def add_server_arguments(parser: argparse.ArgumentParser) -> None: server_arguments.add_argument( "--ssl-ca-file", help="If your SSL certificate needs to be verified, " - "you can specify the CA file " - "using this parameter.", + "you can specify the CA file " + "using this parameter.", ) server_arguments.add_argument( "--ssl-password", help="If your ssl-keyfile is protected by a password, you can specify it " - "using this paramer.", + "using this paramer.", ) channel_arguments = parser.add_argument_group("Channels") channel_arguments.add_argument( @@ -150,26 +162,37 @@ def add_server_arguments(parser: argparse.ArgumentParser) -> None: "--connector", type=str, help="Service to connect to." ) + add_jwt_arguments(parser) + + +def add_jwt_arguments(parser: argparse.ArgumentParser) -> None: + """Adds arguments related to JWT authentication. + + Args: + parser: Argument parser. + """ jwt_auth = parser.add_argument_group("JWT Authentication") jwt_auth.add_argument( "--jwt-secret", type=str, + default=os.getenv(JWT_SECRET_ENV), help="Public key for asymmetric JWT methods or shared secret" - "for symmetric methods. Please also make sure to use " - "--jwt-method to select the method of the signature, " - "otherwise this argument will be ignored." - "Note that this key is meant for securing the HTTP API.", + "for symmetric methods. Please also make sure to use " + "--jwt-method to select the method of the signature, " + "otherwise this argument will be ignored." + "Note that this key is meant for securing the HTTP API.", ) jwt_auth.add_argument( "--jwt-method", type=str, - default="HS256", + default=os.getenv(JWT_METHOD_ENV, DEFAULT_JWT_METHOD), help="Method used for the signature of the JWT authentication payload.", ) jwt_auth.add_argument( "--jwt-private-key", type=str, + default=os.getenv(JWT_PRIVATE_KEY_ENV), help="A private key used for generating web tokens, dependent upon " - "which hashing algorithm is used. It must be used together with " - "--jwt-secret for providing the public key.", - ) + "which hashing algorithm is used. It must be used together with " + "--jwt-secret for providing the public key.", + ) \ No newline at end of file diff --git a/rasa/env.py b/rasa/env.py new file mode 100644 index 000000000000..3415487c3807 --- /dev/null +++ b/rasa/env.py @@ -0,0 +1,5 @@ +AUTH_TOKEN_ENV = "AUTH_TOKEN" +JWT_SECRET_ENV = "JWT_SECRET" +JWT_METHOD_ENV = "JWT_METHOD" +DEFAULT_JWT_METHOD = "HS256" +JWT_PRIVATE_KEY_ENV = "JWT_PRIVATE_KEY" diff --git a/tests/cli/arguments/__init__.py b/tests/cli/arguments/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/cli/arguments/test_run.py b/tests/cli/arguments/test_run.py new file mode 100644 index 000000000000..5d555e6db2b2 --- /dev/null +++ b/tests/cli/arguments/test_run.py @@ -0,0 +1,188 @@ +from typing import List, Dict + +import argparse +import pytest +from _pytest.monkeypatch import MonkeyPatch + +from rasa.cli.arguments.run import add_jwt_arguments, add_server_settings_arguments +from rasa.env import ( + JWT_SECRET_ENV, + JWT_METHOD_ENV, + JWT_PRIVATE_KEY_ENV, + DEFAULT_JWT_METHOD, + AUTH_TOKEN_ENV, +) + + +@pytest.mark.parametrize( + "env_variables, input_args, expected", + [ + ( + # all env variables are set + { + JWT_SECRET_ENV: "secret", + JWT_METHOD_ENV: "HS256", + JWT_PRIVATE_KEY_ENV: "private_key", + }, + [], + argparse.Namespace( + jwt_secret="secret", + jwt_method="HS256", + jwt_private_key="private_key", + ), + ), + ( + # no JWT_SECRET_ENV and --jwt-secret is set + { + JWT_METHOD_ENV: "HS256", + JWT_PRIVATE_KEY_ENV: "private_key", + }, + ["--jwt-secret", "secret"], + argparse.Namespace( + jwt_secret="secret", + jwt_method="HS256", + jwt_private_key="private_key", + ), + ), + ( + # no JWT_METHOD_ENV and --jwt-method is set + { + JWT_SECRET_ENV: "secret", + JWT_PRIVATE_KEY_ENV: "private_key", + }, + ["--jwt-method", "HS256"], + argparse.Namespace( + jwt_secret="secret", + jwt_method="HS256", + jwt_private_key="private_key", + ), + ), + ( + # no JWT_PRIVATE_KEY_ENV and --jwt-private-key is set + { + JWT_SECRET_ENV: "secret", + JWT_METHOD_ENV: "HS256", + }, + ["--jwt-private-key", "private_key"], + argparse.Namespace( + jwt_secret="secret", + jwt_method="HS256", + jwt_private_key="private_key", + ), + ), + ( + # no JWT_SECRET_ENV and no --jwt-secret + { + JWT_METHOD_ENV: "HS256", + JWT_PRIVATE_KEY_ENV: "private_key", + }, + [], + argparse.Namespace( + jwt_secret=None, + jwt_method="HS256", + jwt_private_key="private_key", + ), + ), + ( + # no JWT_METHOD_ENV and no --jwt-method + { + JWT_SECRET_ENV: "secret", + JWT_PRIVATE_KEY_ENV: "private_key", + }, + [], + argparse.Namespace( + jwt_secret="secret", + jwt_method=DEFAULT_JWT_METHOD, + jwt_private_key="private_key", + ), + ), + ( + # no JWT_PRIVATE_KEY_ENV and no --jwt-private-key + { + JWT_SECRET_ENV: "secret", + JWT_METHOD_ENV: "HS256", + }, + [], + argparse.Namespace( + jwt_secret="secret", + jwt_method="HS256", + jwt_private_key=None, + ), + ), + ( + # no env variables and no arguments + {}, + [], + argparse.Namespace( + jwt_secret=None, + jwt_method="HS256", + jwt_private_key=None, + ), + ), + ], +) +def test_jwt_argument_parsing( + env_variables: Dict[str, str], + input_args: List[str], + expected: argparse.Namespace, + monkeypatch: MonkeyPatch, +) -> None: + """Tests parsing of the JWT arguments.""" + parser = argparse.ArgumentParser() + + for env_name, env_value in env_variables.items(): + monkeypatch.setenv(env_name, env_value) + + add_jwt_arguments(parser) + args = parser.parse_args(input_args) + + assert args.jwt_secret == expected.jwt_secret + assert args.jwt_method == expected.jwt_method + assert args.jwt_private_key == expected.jwt_private_key + + +@pytest.mark.parametrize( + "env_variables, input_args, expected", + [ + ( + { + AUTH_TOKEN_ENV: "secret", + }, + [], + argparse.Namespace( + auth_token="secret", + ), + ), + ( + {}, + ["--auth-token", "secret"], + argparse.Namespace( + auth_token="secret", + ), + ), + ( + {}, + [], + argparse.Namespace( + auth_token=None, + ), + ), + ], +) +def test_add_server_settings_arguments( + env_variables: Dict[str, str], + input_args: List[str], + expected: argparse.Namespace, + monkeypatch: MonkeyPatch, +) -> None: + """Tests parsing of the server settings arguments.""" + parser = argparse.ArgumentParser() + + for env_name, env_value in env_variables.items(): + monkeypatch.setenv(env_name, env_value) + + add_server_settings_arguments(parser) + + args = parser.parse_args(input_args) + + assert args.auth_token == expected.auth_token