diff --git a/craft_application/commands/init.py b/craft_application/commands/init.py index 90b8ea7f..3067de85 100644 --- a/craft_application/commands/init.py +++ b/craft_application/commands/init.py @@ -101,8 +101,18 @@ def profiles(self) -> list[str]: def run(self, parsed_args: argparse.Namespace) -> None: """Run the command.""" - project_dir = self._get_project_dir(parsed_args) + # If the user provided a "name" and it's not valid, the command fails. + if parsed_args.name is not None: + self._services.init.validate_project_name(parsed_args.name) + + # However, if the name comes from the directory, we don't fail and + # instead fallback to its default. project_name = self._get_name(parsed_args) + project_name = self._services.init.validate_project_name( + project_name, use_default=True + ) + + project_dir = self._get_project_dir(parsed_args) template_dir = pathlib.Path(self.parent_template_dir / parsed_args.profile) craft_cli.emit.progress("Checking project directory.") diff --git a/craft_application/models/constraints.py b/craft_application/models/constraints.py index edec9a5f..f7d8b483 100644 --- a/craft_application/models/constraints.py +++ b/craft_application/models/constraints.py @@ -91,8 +91,8 @@ def validate(value: str) -> str: * May not have two hyphens in a row """ -_PROJECT_NAME_REGEX = r"^([a-z0-9][a-z0-9-]?)*[a-z]+([a-z0-9-]?[a-z0-9])*$" -_PROJECT_NAME_COMPILED_REGEX = re.compile(_PROJECT_NAME_REGEX) +PROJECT_NAME_REGEX = r"^([a-z0-9][a-z0-9-]?)*[a-z]+([a-z0-9-]?[a-z0-9])*$" +PROJECT_NAME_COMPILED_REGEX = re.compile(PROJECT_NAME_REGEX) MESSAGE_INVALID_NAME = ( "invalid name: Names can only use ASCII lowercase letters, numbers, and hyphens. " "They must have at least one letter, may not start or end with a hyphen, " @@ -102,13 +102,13 @@ def validate(value: str) -> str: ProjectName = Annotated[ str, pydantic.BeforeValidator( - get_validator_by_regex(_PROJECT_NAME_COMPILED_REGEX, MESSAGE_INVALID_NAME) + get_validator_by_regex(PROJECT_NAME_COMPILED_REGEX, MESSAGE_INVALID_NAME) ), pydantic.Field( min_length=1, max_length=40, strict=True, - pattern=_PROJECT_NAME_REGEX, + pattern=PROJECT_NAME_REGEX, description=_PROJECT_NAME_DESCRIPTION, title="Project Name", examples=[ diff --git a/craft_application/services/init.py b/craft_application/services/init.py index a183491f..dfdd5eae 100644 --- a/craft_application/services/init.py +++ b/craft_application/services/init.py @@ -15,10 +15,13 @@ # along with this program. If not, see . """Service for initializing a project.""" +from __future__ import annotations import os import pathlib import shutil +import typing +from re import Pattern from typing import Any import jinja2 @@ -27,12 +30,47 @@ from craft_application.errors import InitError from craft_application.git import GitError, GitRepo, is_repo, parse_describe +from ..models.constraints import MESSAGE_INVALID_NAME, PROJECT_NAME_COMPILED_REGEX from . import base +if typing.TYPE_CHECKING: # pragma: no cover + from craft_application.application import AppMetadata + from craft_application.services import ServiceFactory + class InitService(base.AppService): """Service class for initializing a project.""" + def __init__( + self, + app: AppMetadata, + services: ServiceFactory, + *, + default_name: str = "my-project", + name_regex: Pattern[str] = PROJECT_NAME_COMPILED_REGEX, + invalid_name_message: str = MESSAGE_INVALID_NAME, + ) -> None: + super().__init__(app, services) + self._default_name = default_name + self._name_regex = name_regex + self._invalid_name_message = invalid_name_message + + def validate_project_name(self, name: str, *, use_default: bool = False) -> str: + """Validate that ``name`` is valid as a project name. + + If ``name`` is invalid and ``use_default`` is False, then an InitError + is raised. If ``use_default`` is True, the default project name provided + to the service's constructor is returned. + + If ``name`` is valid, it is returned. + """ + if not self._name_regex.match(name): + if use_default: + return self._default_name + raise InitError(self._invalid_name_message) + + return name + def initialise_project( self, *, diff --git a/docs/reference/changelog.rst b/docs/reference/changelog.rst index e037748f..e7d9c2bd 100644 --- a/docs/reference/changelog.rst +++ b/docs/reference/changelog.rst @@ -17,6 +17,13 @@ Commands ======== - Provide a documentation link in help messages. +- Updates to the ``init`` command: + + - If the ``--name`` argument is provided, the command now checks if the value + is a valid project name, and returns an error if it isn't. + - If the ``--name`` argument is *not* provided, the command now checks whether + the project directory is a valid project name. If it isn't, the command sets + the project name to ``my-project``. Services ======== diff --git a/tests/integration/commands/test_init.py b/tests/integration/commands/test_init.py index 13cab7ba..111502ba 100644 --- a/tests/integration/commands/test_init.py +++ b/tests/integration/commands/test_init.py @@ -64,8 +64,17 @@ def fake_template_dirs(tmp_path): ) @pytest.mark.parametrize("project_dir", [None, "project-dir"]) @pytest.mark.usefixtures("fake_template_dirs") -def test_init(app, capsys, monkeypatch, profile, expected_file, project_dir): +def test_init( + app, + capsys, + monkeypatch, + profile, + expected_file, + project_dir, + empty_working_directory, +): """Initialise a project.""" + monkeypatch.chdir(empty_working_directory) expected_output = "Successfully initialised project" command = ["testcraft", "init"] if profile: @@ -156,3 +165,18 @@ def test_init_nonoverlapping_file(app, capsys, monkeypatch): assert return_code == os.EX_OK assert expected_output in stdout assert pathlib.Path("simple-file").is_file() + + +@pytest.mark.usefixtures("fake_template_dirs") +def test_init_invalid_directory(app, monkeypatch, tmp_path): + """A default name is used if the project dir is not a valid project name.""" + invalid_dir = tmp_path / "invalid--name" + invalid_dir.mkdir() + monkeypatch.chdir(invalid_dir) + + monkeypatch.setattr("sys.argv", ["testcraft", "init", "--profile", "simple"]) + return_code = app.run() + + assert return_code == os.EX_OK + expected_file = invalid_dir / "simple-file" + assert expected_file.read_text() == "name=my-project" diff --git a/tests/unit/commands/test_init.py b/tests/unit/commands/test_init.py index 3f92f596..57299e5e 100644 --- a/tests/unit/commands/test_init.py +++ b/tests/unit/commands/test_init.py @@ -61,6 +61,7 @@ def test_init_in_cwd(init_command, name, new_dir, mock_services, emitter): name=name, profile="test-profile", ) + mock_services.init.validate_project_name.return_value = expected_name init_command.run(parsed_args) @@ -82,6 +83,7 @@ def test_init_run_project_dir(init_command, name, mock_services, emitter): name=name, profile="test-profile", ) + mock_services.init.validate_project_name.return_value = expected_name init_command.run(parsed_args) @@ -113,3 +115,36 @@ def test_existing_files(init_command, tmp_path, mock_services): init_command.run(parsed_args) mock_services.init.initialise_project.assert_not_called() + + +def test_invalid_name(init_command, mock_services): + mock_services.init.validate_project_name.side_effect = InitError("test-error") + parsed_args = argparse.Namespace( + name="invalid--name", + ) + with pytest.raises(InitError, match="test-error"): + init_command.run(parsed_args) + + +def test_invalid_name_directory(init_command, mock_services): + def _validate_project_name(_name: str, *, use_default: bool = False): + if use_default: + return "my-project" + raise InitError("test-error") + + mock_services.init.validate_project_name = _validate_project_name + + project_dir = pathlib.Path("invalid--name") + parsed_args = argparse.Namespace( + project_dir=project_dir, + name=None, + profile="simple", + ) + + init_command.run(parsed_args) + + mock_services.init.initialise_project.assert_called_once_with( + project_dir=project_dir.expanduser().resolve(), + project_name="my-project", + template_dir=init_command.parent_template_dir / "simple", + ) diff --git a/tests/unit/services/test_init.py b/tests/unit/services/test_init.py index 82b09eff..3e1ffdbc 100644 --- a/tests/unit/services/test_init.py +++ b/tests/unit/services/test_init.py @@ -26,6 +26,7 @@ import pytest_mock from craft_application import errors, services from craft_application.git import GitRepo, short_commit_sha +from craft_application.models.constraints import MESSAGE_INVALID_NAME from craft_cli.pytest_plugin import RecordingEmitter @@ -327,3 +328,25 @@ def test_initialise_project( render_project_mock.assert_called_once_with( fake_env, project_dir, templates_dir, fake_context ) + + +@pytest.mark.parametrize( + "invalid_name", ["invalid--name", "-invalid-name", "invalid-name-", "0", "0-0", ""] +) +def test_validate_name_invalid(init_service, invalid_name): + with pytest.raises(errors.InitError, match=MESSAGE_INVALID_NAME): + init_service.validate_project_name(invalid_name) + + +@pytest.mark.parametrize("valid_name", ["valid-name", "a", "a-a", "aaa", "0a"]) +def test_validate_name_valid(init_service, valid_name): + obtained = init_service.validate_project_name(valid_name) + assert obtained == valid_name + + +def test_valid_name_invalid_use_default(init_service): + invalid_name = "invalid--name" + init_service._default_name = "my-default-name" + + obtained = init_service.validate_project_name(invalid_name, use_default=True) + assert obtained == "my-default-name"