diff --git a/README.md b/README.md
index d5f1762..89731d2 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ For example:
 
     data = load_pipeline(f.read())
 
-The returned object is a Pydantic model, `Pipeline`, defined in `pipeline/models.py`.
+The returned object is an instance of `pipeline.models.Pipeline`.
 
 ## Developer docs
 
diff --git a/pipeline/exceptions.py b/pipeline/exceptions.py
index 5befe71..313aace 100644
--- a/pipeline/exceptions.py
+++ b/pipeline/exceptions.py
@@ -8,3 +8,7 @@ class InvalidPatternError(ProjectValidationError):
 
 class YAMLError(Exception):
     pass
+
+
+class ValidationError(Exception):
+    pass
diff --git a/pipeline/legacy.py b/pipeline/legacy.py
index 2ea83cf..ff795d1 100644
--- a/pipeline/legacy.py
+++ b/pipeline/legacy.py
@@ -5,8 +5,8 @@ def get_all_output_patterns_from_project_file(project_file: str) -> list[str]:
     config = load_pipeline(project_file)
 
-    all_patterns = set()
+    all_patterns: set[str] = set()
 
     for action in config.actions.values():
-        for patterns in action.outputs.dict(exclude_unset=True).values():
+        for patterns in action.outputs.dict().values():
             all_patterns.update(patterns.values())
 
     return list(all_patterns)
diff --git a/pipeline/main.py b/pipeline/main.py
index 824a78d..a2ce221 100644
--- a/pipeline/main.py
+++ b/pipeline/main.py
@@ -2,9 +2,7 @@
 
 from pathlib import Path
 
-import pydantic
-
-from .exceptions import ProjectValidationError, YAMLError
+from .exceptions import ProjectValidationError, ValidationError, YAMLError
 from .loading import parse_yaml_file
 from .models import Pipeline
 
@@ -27,8 +25,8 @@ def load_pipeline(pipeline_config: str | Path, filename: str | None = None) -> P
 
     # validate
     try:
-        return Pipeline(**parsed_data)
-    except pydantic.ValidationError as exc:
+        return Pipeline.build(**parsed_data)
+    except ValidationError as exc:
         raise ProjectValidationError(
             f"Invalid project: {filename or ''}\n{exc}"
         ) from exc
diff --git a/pipeline/models.py b/pipeline/models.py
index a89c595..65fc257 100644
--- a/pipeline/models.py
+++ b/pipeline/models.py
@@ -1,25 +1,24 @@
+from __future__ import annotations
+
 import pathlib
 import re
 import shlex
 from collections import defaultdict
-from typing import Any, Dict, Iterable, List, Optional, Set, TypedDict
-
-from pydantic import BaseModel, root_validator, validator
+from dataclasses import dataclass
+from typing import Any
 
 from .constants import RUN_ALL_COMMAND
-from .exceptions import InvalidPatternError
+from .exceptions import InvalidPatternError, ValidationError
 from .features import LATEST_VERSION, get_feature_flags_for_version
-from .types import RawOutputs, RawPipeline
 from .validation import (
-    assert_valid_glob_pattern,
     validate_cohortextractor_outputs,
     validate_databuilder_outputs,
+    validate_glob_pattern,
+    validate_no_kwargs,
+    validate_type,
 )
 
-cohortextractor_pat = re.compile(r"cohortextractor:\S+ generate_cohort")
-databuilder_pat = re.compile(r"databuilder|ehrql:\S+ generate[-_]dataset")
-
 # ordered by most common, going forwards
 DB_COMMANDS = {
     "ehrql": ("generate-dataset", "generate-measures"),
@@ -29,7 +28,7 @@ }
 
 
-def is_database_action(args: List[str]) -> bool:
+def is_database_action(args: list[str]) -> bool:
     """
     By default actions do not have database access, but certain trusted actions require it
     """
@@ -50,59 +49,96 @@ def is_database_action(args: List[str]) -> bool:
     return args[1] in db_commands
 
 
-class Expectations(BaseModel):
+@dataclass(frozen=True)
+class Expectations:
     population_size: int
 
-    @validator("population_size", pre=True)
-    def validate_population_size(cls, population_size: str) -> int:
+    @classmethod
+    def build(
+        cls,
+        population_size: Any = None,
+        **kwargs: Any,
+    ) -> Expectations:
+        validate_no_kwargs(kwargs, "project `expectations` section")
         try:
-            return int(population_size)
+            population_size = int(population_size)
         except (TypeError, ValueError):
-            raise ValueError(
+            raise ValidationError(
                 "Project expectations population size must be a number",
             )
-
-
-class Outputs(BaseModel):
-    highly_sensitive: Optional[Dict[str, str]]
-    moderately_sensitive: Optional[Dict[str, str]]
-    minimally_sensitive: Optional[Dict[str, str]]
-
-    def __len__(self) -> int:
-        return len(self.dict(exclude_unset=True))
-
-    @root_validator()
-    def at_least_one_output(cls, outputs: Dict[str, str]) -> Dict[str, str]:
-        if not any(outputs.values()):
-            raise ValueError(
-                f"must specify at least one output of: {', '.join(outputs)}"
+        return cls(population_size)
+
+
+@dataclass(frozen=True)
+class Outputs:
+    highly_sensitive: dict[str, str] | None
+    moderately_sensitive: dict[str, str] | None
+    minimally_sensitive: dict[str, str] | None
+
+    @classmethod
+    def build(
+        cls,
+        action_id: str,
+        highly_sensitive: Any = None,
+        moderately_sensitive: Any = None,
+        minimally_sensitive: Any = None,
+        **kwargs: Any,
+    ) -> Outputs:
+        if (
+            highly_sensitive is None
+            and moderately_sensitive is None
+            and minimally_sensitive is None
+        ):
+            raise ValidationError(
+                "must specify at least one output of: highly_sensitive, moderately_sensitive, minimally_sensitive"
             )
-        return outputs
-
-    @root_validator(pre=True)
-    def validate_output_filenames_are_valid(cls, outputs: RawOutputs) -> RawOutputs:
-        # we use pre=True here so that we only get the outputs specified in the
-        # input data. With Optional[…] wrapped fields pydantic will set None
-        # for us and that just makes the logic a little fiddler with no
-        # benefit.
-        for privacy_level, output in outputs.items():
-            for output_id, filename in output.items():
-                try:
-                    assert_valid_glob_pattern(filename, privacy_level)
-                except InvalidPatternError as e:
-                    raise ValueError(f"Output path {filename} is invalid: {e}")
-
-        return outputs
+        validate_no_kwargs(kwargs, f"`outputs` section for action {action_id}")
+        cls.validate_output_filenames_are_valid(
+            action_id, "highly_sensitive", highly_sensitive
+        )
+        cls.validate_output_filenames_are_valid(
+            action_id, "moderately_sensitive", moderately_sensitive
+        )
+        cls.validate_output_filenames_are_valid(
+            action_id, "minimally_sensitive", minimally_sensitive
+        )
 
-class Command(BaseModel):
-    raw: str  # original string
+        return cls(highly_sensitive, moderately_sensitive, minimally_sensitive)
 
-    class Config:
-        # this makes Command hashable, which for some reason due to the
-        # Action.parse_run_string works, pydantic requires.
- frozen = True + def __len__(self) -> int: + return len(self.dict()) + + def dict(self) -> dict[str, dict[str, str]]: + d = { + k: getattr(self, k) + for k in [ + "highly_sensitive", + "moderately_sensitive", + "minimally_sensitive", + ] + } + return {k: v for k, v in d.items() if v is not None} + + @classmethod + def validate_output_filenames_are_valid( + cls, action_id: str, privacy_level: str, output: Any + ) -> None: + if output is None: + return + validate_type(output, dict, f"`{privacy_level}` section for action {action_id}") + for output_id, filename in output.items(): + validate_type(filename, str, f"`{output_id}` output for action {action_id}") + try: + validate_glob_pattern(filename, privacy_level) + except InvalidPatternError as e: + raise ValidationError(f"Output path {filename} is invalid: {e}") + + +@dataclass(frozen=True) +class Command: + raw: str @property def args(self) -> str: @@ -114,7 +150,7 @@ def name(self) -> str: return self.parts[0].split(":")[0] @property - def parts(self) -> List[str]: + def parts(self) -> list[str]: return shlex.split(self.raw) @property @@ -123,20 +159,71 @@ def version(self) -> str: return self.parts[0].split(":")[1] -class Action(BaseModel): - config: Optional[Dict[Any, Any]] = None - run: Command - needs: List[str] = [] +@dataclass(frozen=True) +class Action: + action_id: str outputs: Outputs - dummy_data_file: Optional[pathlib.Path] + run: Command + needs: list[str] + config: dict[Any, Any] | None + dummy_data_file: pathlib.Path | None + + @classmethod + def build( + cls, + action_id: str, + outputs: Any, + run: Any, + needs: Any = None, + config: Any = None, + dummy_data_file: Any = None, + **kwargs: Any, + ) -> Action: + validate_no_kwargs(kwargs, f"action {action_id}") + validate_type(outputs, dict, f"`outputs` section for action {action_id}") + validate_type(run, str, f"`run` section for action {action_id}") + validate_type( + needs, list, f"`needs` section for action {action_id}", optional=True + ) + validate_type( + config, dict, f"`config` section for action {action_id}", optional=True + ) + validate_type( + dummy_data_file, + str, + f"`dummy_data_file` section for action {action_id}", + optional=True, + ) + + outputs = Outputs.build(action_id=action_id, **outputs) + run = cls.parse_run_string(action_id, run) + needs = needs or [] + for n in needs: + if " " in n: + raise ValidationError( + f"`needs` actions should be separated with commas, but {action_id} needs `{n}`" + ) + action = cls(action_id, outputs, run, needs, config, dummy_data_file) + + if re.match(r"cohortextractor:\S+ generate_cohort", run.raw): + validate_cohortextractor_outputs(action_id, action) + if re.match(r"databuilder|ehrql:\S+ generate[-_]dataset", run.raw): + validate_databuilder_outputs(action_id, action) + + return action + + @classmethod + def parse_run_string(cls, action_id: str, run: str) -> Command: + if run == "": + raise ValidationError( + f"run must have a value, {action_id} has an empty run key" + ) - @validator("run", pre=True) - def parse_run_string(cls, run: str) -> Command: parts = shlex.split(run) name, _, version = parts[0].partition(":") if not version: - raise ValueError( + raise ValidationError( f"{name} must have a version specified (e.g. 
{name}:0.5.2)", ) @@ -147,217 +234,93 @@ def is_database_action(self) -> bool: return is_database_action(self.run.parts) -class PartiallyValidatedPipeline(TypedDict): - """ - A custom type to type-check the values in "post" root validators - - A root_validator with pre=False (or no kwargs) runs after the values have - been ingested already, and the `values` arg is a dictionary of model types. - - Note: This is defined here so we don't have to deal with forward reference - types. - """ - +@dataclass(frozen=True) +class Pipeline: version: float + actions: dict[str, Action] expectations: Expectations - actions: Dict[str, Action] - - -class Pipeline(BaseModel): - version: float - expectations: Expectations - actions: Dict[str, Action] - - @property - def all_actions(self) -> List[str]: - """ - Get all actions for this Pipeline instance - - We ignore any manually defined run_all action (in later project - versions this will be an error). We use a list comprehension rather - than set operators as previously so we preserve the original order. - """ - return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND] - - @root_validator() - def validate_actions( - cls, values: PartiallyValidatedPipeline - ) -> PartiallyValidatedPipeline: - # TODO: move to Action when we move name onto it - validators = { - cohortextractor_pat: validate_cohortextractor_outputs, - databuilder_pat: validate_databuilder_outputs, - } - for action_id, config in values.get("actions", {}).items(): - for cmd, validator_func in validators.items(): - if cmd.match(config.run.raw): - validator_func(action_id, config) - return values + @classmethod + def build( + cls, + version: Any = None, + actions: Any = None, + expectations: Any = None, + **kwargs: Any, + ) -> Pipeline: + validate_no_kwargs(kwargs, "project") + if version is None: + raise ValidationError( + f"Project file must have a `version` attribute specifying which " + f"version of the project configuration format it uses (current " + f"latest version is {LATEST_VERSION})" + ) - @root_validator(pre=True) - def validate_expectations_per_version(cls, values: RawPipeline) -> RawPipeline: - """Ensure the expectations key exists for version 3 onwards""" try: - version = float(values["version"]) - except (KeyError, TypeError, ValueError): - # this is handled in the validate_version_exists and - # validate_version_value validators - return values - - feat = get_feature_flags_for_version(version) - - if not feat.EXPECTATIONS_POPULATION: - # set the default here because pydantic doesn't seem to set it - # otherwise - values["expectations"] = {"population_size": 1000} - return values - - if "expectations" not in values: - raise ValueError("Project must include `expectations` section") - - if "population_size" not in values["expectations"]: - raise ValueError( - "Project `expectations` section must include `population_size` section", + version = float(version) + except (TypeError, ValueError): + raise ValidationError( + f"`version` must be a number between 1 and {LATEST_VERSION}" ) - return values - - @root_validator() - def validate_outputs_per_version( - cls, values: PartiallyValidatedPipeline - ) -> PartiallyValidatedPipeline: - """ - Ensure outputs are unique for version 2 onwards - - We validate this on Pipeline so we can get the version - """ - - # we're not using pre=True in the validator so we can rely on the - # version and action keys being the correct type but we have to handle - # them not existing - if not (version := values.get("version")): - return 
values  # handle missing version
-
-        if (actions := values.get("actions")) is None:
-            return values  # hand no actions
-
-        feat = get_feature_flags_for_version(version)
-        if not feat.UNIQUE_OUTPUT_PATH:
-            return values
-
-        # find duplicate paths defined in the outputs section
-        seen_files = []
-        for config in actions.values():
-            for output in config.outputs.dict(exclude_unset=True).values():
-                for filename in output.values():
-                    if filename in seen_files:
-                        raise ValueError(f"Output path {filename} is not unique")
-
-                    seen_files.append(filename)
-
-        return values
-
-    @root_validator(pre=True)
-    def validate_actions_run(cls, values: RawPipeline) -> RawPipeline:
-        # TODO: move to Action when we move name onto it
-        for action_id, config in values.get("actions", {}).items():
-            if config["run"] == "":
-                # key is present but empty
-                raise ValueError(
-                    f"run must have a value, {action_id} has an empty run key"
-                )
-
-        return values
+        validate_type(actions, dict, "Project `actions` section")
+        actions = {
+            action_id: Action.build(action_id, **action_config)
+            for action_id, action_config in actions.items()
+        }
 
-    @validator("actions")
-    def validate_unique_commands(cls, actions: Dict[str, Action]) -> Dict[str, Action]:
-        seen: Dict[Command, List[str]] = defaultdict(list)
+        seen: dict[Command, list[str]] = defaultdict(list)
         for name, config in actions.items():
             run = config.run
             if run in seen:
-                raise ValueError(
+                raise ValidationError(
                     f"Action {name} has the same 'run' command as other actions: {', '.join(seen[run])}"
                 )
             seen[run].append(name)
-        return actions
-
-    @validator("actions")
-    def validate_needs_are_comma_delimited(
-        cls, actions: Dict[str, Action]
-    ) -> Dict[str, Action]:
-        space_delimited = {}
-        for name, action in actions.items():
-            # find needs definitions with spaces in them
-            incorrect = [dep for dep in action.needs if " " in dep]
-            if incorrect:
-                space_delimited[name] = incorrect
-
-        if not space_delimited:
-            return actions
-
-        def iter_incorrect_needs(
-            space_delimited: Dict[str, List[str]]
-        ) -> Iterable[str]:
-            for name, needs in space_delimited.items():
-                yield f"Action: {name}"
-                for need in needs:
-                    yield f"  - {need}"
-
-        msg = [
-            "`needs` actions should be separated with commas. The following actions need fixing:",
-            *iter_incorrect_needs(space_delimited),
-        ]
-
-        raise ValueError("\n".join(msg))
-
-    @validator("actions")
-    def validate_needs_exist(cls, actions: Dict[str, Action]) -> Dict[str, Action]:
-        missing = {}
-        for name, action in actions.items():
-            unknown_needs = set(action.needs) - set(actions)
-            if unknown_needs:
-                missing[name] = unknown_needs
-
-        if not missing:
-            return actions
-
-        def iter_missing_needs(missing: Dict[str, Set[str]]) -> Iterable[str]:
-            for name, needs in missing.items():
-                yield f"Action: {name}"
-                for need in needs:
-                    yield f"  - {need}"
-
-        msg = [
-            "One or more actions is referencing unknown actions in its needs list:",
-            *iter_missing_needs(missing),
-        ]
-        raise ValueError("\n".join(msg))
-
-    @root_validator(pre=True)
-    def validate_version_exists(cls, values: RawPipeline) -> RawPipeline:
-        """
-        Ensure the version key exists.
+ if get_feature_flags_for_version(version).UNIQUE_OUTPUT_PATH: + # find duplicate paths defined in the outputs section + seen_files = [] + for config in actions.values(): + for output in config.outputs.dict().values(): + for filename in output.values(): + if filename in seen_files: + raise ValidationError( + f"Output path {filename} is not unique" + ) + + seen_files.append(filename) + + for a in actions.values(): + for n in a.needs: + if n not in actions: + raise ValidationError( + f"Action `{a.action_id}` references an unknown action in its `needs` list: {n}" + ) - This is a re-implementation of pydantic's field validation so we can - get a custom error message. This can be removed when we add a wrapper - around the models to generate more UI friendly error messages. - """ - if "version" in values: - return values + feat = get_feature_flags_for_version(version) + if feat.EXPECTATIONS_POPULATION: + if expectations is None: + raise ValidationError("Project must include `expectations` section") + else: + expectations = {"population_size": 1000} + + validate_type(expectations, dict, "Project `expectations` section") + if "population_size" not in expectations: + raise ValidationError( + "Project `expectations` section must include `population_size` section", + ) + expectations = Expectations.build(**expectations) - raise ValueError( - f"Project file must have a `version` attribute specifying which " - f"version of the project configuration format it uses (current " - f"latest version is {LATEST_VERSION})" - ) + return cls(version, actions, expectations) - @validator("version", pre=True) - def validate_version_value(cls, value: str) -> float: - try: - return float(value) - except (TypeError, ValueError): - raise ValueError( - f"`version` must be a number between 1 and {LATEST_VERSION}" - ) + @property + def all_actions(self) -> list[str]: + """ + Get all actions for this Pipeline instance + + We ignore any manually defined run_all action (in later project + versions this will be an error). We use a list comprehension rather + than set operators as previously so we preserve the original order. + """ + return [action for action in self.actions.keys() if action != RUN_ALL_COMMAND] diff --git a/pipeline/outputs.py b/pipeline/outputs.py index 3699cc1..c76ec3a 100644 --- a/pipeline/outputs.py +++ b/pipeline/outputs.py @@ -23,5 +23,5 @@ def get_output_dirs(output_spec: Outputs) -> list[PurePosixPath]: def iter_all_outputs(output_spec: Outputs) -> Iterator[str]: - for group in output_spec.dict(exclude_unset=True).values(): + for group in output_spec.dict().values(): yield from group.values() diff --git a/pipeline/types.py b/pipeline/types.py deleted file mode 100644 index aa359e5..0000000 --- a/pipeline/types.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Custom types to type check the raw input data - -When loading data from YAML we get dictionaries which pydantic attempts to -validate. Some of our validation is done via custom methods using the raw -dictionary data. 
-""" - -from __future__ import annotations - -import pathlib -from typing import Any, Dict, TypedDict - - -RawOutputs = Dict[str, Dict[str, str]] - - -class RawAction(TypedDict): - config: dict[Any, Any] | None - run: str - needs: list[str] | None - outputs: RawOutputs - dummy_data_file: pathlib.Path | None - - -class RawExpectations(TypedDict): - population_size: str | int | None - - -class RawPipeline(TypedDict): - version: str | float | int - expectations: RawExpectations - actions: dict[str, RawAction] diff --git a/pipeline/validation.py b/pipeline/validation.py index e74ae93..d1cb2f2 100644 --- a/pipeline/validation.py +++ b/pipeline/validation.py @@ -2,10 +2,10 @@ import posixpath from pathlib import Path, PurePosixPath, PureWindowsPath -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from .constants import LEVEL4_FILE_TYPES -from .exceptions import InvalidPatternError +from .exceptions import InvalidPatternError, ValidationError from .outputs import get_first_output_file, get_output_dirs @@ -13,7 +13,24 @@ from .models import Action -def assert_valid_glob_pattern(pattern: str, privacy_level: str) -> None: +def validate_type(val: Any, exp_type: type, loc: str, optional: bool = False) -> None: + type_lookup: dict[type, str] = { + str: "string", + dict: "dictionary of key/value pairs", + list: "list", + } + if optional and val is None: + return + if not isinstance(val, exp_type): + raise ValidationError(f"{loc} must be a {type_lookup[exp_type]}") + + +def validate_no_kwargs(kwargs: dict[str, Any], loc: str) -> None: + if kwargs: + raise ValidationError(f"Unexpected parameters ({', '.join(kwargs)}) in {loc}") + + +def validate_glob_pattern(pattern: str, privacy_level: str) -> None: """ These patterns get converted into regular expressions and matched with a `find` command so there shouldn't be any possibility of a path @@ -72,7 +89,7 @@ def validate_cohortextractor_outputs(action_id: str, action: Action) -> None: # ensure we only have output level defined. num_output_levels = len(action.outputs) if num_output_levels != 1: - raise ValueError( + raise ValidationError( "A `generate_cohort` action must have exactly one output; " f"{action_id} had {num_output_levels}" ) @@ -90,7 +107,7 @@ def validate_cohortextractor_outputs(action_id: str, action: Action) -> None: arg == flag or arg.startswith(f"{flag}=") for arg in action.run.parts ) if not has_output_dir: - raise ValueError( + raise ValidationError( f"generate_cohort command should produce output in only one " f"directory, found {len(output_dirs)}:\n" + "\n".join([f" - {d}/" for d in output_dirs]) @@ -107,11 +124,11 @@ def validate_databuilder_outputs(action_id: str, action: Action) -> None: # TODO: should this be checking output _paths_ instead of levels? 
num_output_levels = len(action.outputs) if num_output_levels != 1: - raise ValueError( + raise ValidationError( "A `generate-dataset` action must have exactly one output; " f"{action_id} had {num_output_levels}" ) first_output_file = get_first_output_file(action.outputs) if first_output_file not in action.run.raw: - raise ValueError("--output in run command and outputs must match") + raise ValidationError("--output in run command and outputs must match") diff --git a/pyproject.toml b/pyproject.toml index 4a91f9e..3c5ae3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ classifiers = [ ] requires-python = ">=3.8" dependencies = [ - "pydantic<2", "ruyaml", ] dynamic = ["version"] @@ -51,9 +50,6 @@ skip_glob = [".direnv", "venv", ".venv"] [tool.mypy] files = "pipeline" exclude = "^pipeline/__main__.py$" -plugins = [ - "pydantic.mypy", -] strict = true warn_redundant_casts = true warn_unused_ignores = true diff --git a/requirements.prod.txt b/requirements.prod.txt index d770e2e..fb5ad3c 100644 --- a/requirements.prod.txt +++ b/requirements.prod.txt @@ -6,12 +6,8 @@ # distro==1.9.0 # via ruyaml -pydantic==1.10.17 - # via opensafely-pipeline (pyproject.toml) ruyaml==0.91.0 # via opensafely-pipeline (pyproject.toml) -typing-extensions==4.12.2 - # via pydantic # The following packages are considered to be unsafe in a requirements file: setuptools==71.1.0 diff --git a/tests/test_main.py b/tests/test_main.py index 7e3f1c3..03aeae2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,7 +1,7 @@ -import pydantic import pytest from pipeline import ProjectValidationError, load_pipeline +from pipeline.exceptions import ValidationError from pipeline.models import Pipeline @@ -67,4 +67,4 @@ def test_load_pipeline_with_project_error_raises_projectvalidationerror(): with pytest.raises(ProjectValidationError, match="Invalid project") as exc: load_pipeline(config) - assert isinstance(exc.value.__cause__, pydantic.ValidationError) + assert isinstance(exc.value.__cause__, ValidationError) diff --git a/tests/test_models.py b/tests/test_models.py index 922fd19..55e2465 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,7 +1,7 @@ import pytest -from pydantic import ValidationError from pipeline import load_pipeline +from pipeline.exceptions import ValidationError from pipeline.models import Pipeline @@ -19,7 +19,7 @@ def test_success(): }, } - Pipeline(**data) + Pipeline.build(**data) def test_action_has_a_version(): @@ -37,7 +37,7 @@ def test_action_has_a_version(): msg = "test must have a version specified" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_action_cohortextractor_multiple_outputs_with_output_flag(): @@ -56,7 +56,7 @@ def test_action_cohortextractor_multiple_outputs_with_output_flag(): }, } - run_command = Pipeline(**data).actions["generate_cohort"].run.raw + run_command = Pipeline.build(**data).actions["generate_cohort"].run.raw assert run_command == "cohortextractor:latest generate_cohort --output-dir=output" @@ -81,7 +81,7 @@ def test_action_cohortextractor_multiple_ouputs_without_output_flag(): "generate_cohort command should produce output in only one directory, found 2:" ) with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) @pytest.mark.parametrize( @@ -108,7 +108,7 @@ def test_action_extraction_command_with_multiple_outputs(image, command): msg = f"A `{command}` action must have exactly one output" with pytest.raises(ValidationError, match=msg): - 
Pipeline(**data) + Pipeline.build(**data) def test_action_extraction_command_with_one_outputs(): @@ -124,9 +124,9 @@ def test_action_extraction_command_with_one_outputs(): }, } - config = Pipeline(**data) + config = Pipeline.build(**data) - outputs = config.actions["generate_cohort"].outputs.dict(exclude_unset=True) + outputs = config.actions["generate_cohort"].outputs.dict() assert len(outputs.values()) == 1 @@ -141,7 +141,7 @@ def test_command_properties(): }, } - action = Pipeline(**data).actions["generate_cohort"] + action = Pipeline.build(**data).actions["generate_cohort"] assert action.run.args == "generate_cohort another_arg" assert action.run.name == "cohortextractor" assert action.run.parts == [ @@ -164,7 +164,7 @@ def test_expectations_before_v3_has_a_default_set(): }, } - config = Pipeline(**data) + config = Pipeline.build(**data) assert config.expectations.population_size == 1000 @@ -183,7 +183,7 @@ def test_expectations_exists(): msg = "Project must include `expectations` section" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_expectations_population_size_exists(): @@ -200,7 +200,7 @@ def test_expectations_population_size_exists(): msg = "Project `expectations` section must include `population_size` section" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_expectations_population_size_is_a_number(): @@ -217,7 +217,7 @@ def test_expectations_population_size_is_a_number(): msg = "Project expectations population size must be a number" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_all_actions(test_file): @@ -252,7 +252,7 @@ def test_pipeline_needs_success(): }, } - config = Pipeline(**data) + config = Pipeline.build(**data) assert config.actions["do_analysis"].needs == ["generate_cohort"] @@ -277,9 +277,9 @@ def test_pipeline_needs_with_non_comma_delimited_actions(): }, } - msg = "`needs` actions should be separated with commas. 
The following actions need fixing:" + msg = "`needs` actions should be separated with commas, but do_further_analysis needs `generate_cohort do_analysis`" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_needs_with_unknown_action(): @@ -296,9 +296,9 @@ def test_pipeline_needs_with_unknown_action(): }, } - match = "One or more actions is referencing unknown actions in its needs list" + match = "Action `action1` references an unknown action in its `needs` list: action2" with pytest.raises(ValidationError, match=match): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_with_duplicated_action_run_commands(): @@ -322,7 +322,7 @@ def test_pipeline_with_duplicated_action_run_commands(): match = "Action action2 has the same 'run' command as other actions: action1" with pytest.raises(ValidationError, match=match): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_with_empty_run_command(): @@ -340,7 +340,7 @@ def test_pipeline_with_empty_run_command(): match = "run must have a value, action1 has an empty run key" with pytest.raises(ValidationError, match=match): - Pipeline(**data) + Pipeline.build(**data) def test_pipeline_with_missing_or_none_version(): @@ -357,7 +357,11 @@ def test_pipeline_with_missing_or_none_version(): msg = "Project file must have a `version` attribute" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) + + with pytest.raises(ValidationError, match=msg): + data["version"] = None + Pipeline.build(**data) def test_pipeline_with_non_numeric_version(): @@ -374,11 +378,7 @@ def test_pipeline_with_non_numeric_version(): with pytest.raises(ValidationError, match=msg): data["version"] = "test" - Pipeline(**data) - - with pytest.raises(ValidationError, match=msg): - data["version"] = None - Pipeline(**data) + Pipeline.build(**data) def test_outputs_files_are_unique(): @@ -399,7 +399,7 @@ def test_outputs_files_are_unique(): msg = "Output path output/input.csv is not unique" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) def test_outputs_duplicate_files_in_v1(): @@ -418,7 +418,7 @@ def test_outputs_duplicate_files_in_v1(): }, } - generate_cohort = Pipeline(**data).actions["generate_cohort"] + generate_cohort = Pipeline.build(**data).actions["generate_cohort"] cohort = generate_cohort.outputs.highly_sensitive["cohort"] test = generate_cohort.outputs.highly_sensitive["test"] @@ -431,7 +431,7 @@ def test_outputs_with_unknown_privacy_level(): with pytest.raises(ValidationError, match=msg): # no outputs - Pipeline( + Pipeline.build( **{ "version": 1, "actions": { @@ -444,7 +444,7 @@ def test_outputs_with_unknown_privacy_level(): ) with pytest.raises(ValidationError, match=msg): - Pipeline( + Pipeline.build( **{ "version": 1, "actions": { @@ -470,7 +470,7 @@ def test_outputs_with_invalid_pattern(): msg = "Output path test/foo is invalid:" with pytest.raises(ValidationError, match=msg): - Pipeline(**data) + Pipeline.build(**data) @pytest.mark.parametrize("image,tag", [("databuilder", "latest"), ("ehrql", "v0")]) @@ -485,7 +485,7 @@ def test_pipeline_databuilder_specifies_same_output(image, tag): }, } - Pipeline(**data) + Pipeline.build(**data) @pytest.mark.parametrize("image,tag", [("databuilder", "latest"), ("ehrql", "v0")]) @@ -502,7 +502,7 @@ def test_pipeline_databuilder_specifies_different_output(image, tag): msg = "--output in run command and outputs must match" with pytest.raises(ValidationError, match=msg): - 
Pipeline(**data) + Pipeline.build(**data) def test_pipeline_databuilder_recognizes_old_action_spelling(): @@ -519,7 +519,7 @@ def test_pipeline_databuilder_recognizes_old_action_spelling(): } with pytest.raises(ValidationError): - Pipeline(**data) + Pipeline.build(**data) @pytest.mark.parametrize( @@ -575,5 +575,5 @@ def test_action_is_database_action(name, run, is_database_action): }, } - action = Pipeline(**data).actions[name] + action = Pipeline.build(**data).actions[name] assert action.is_database_action == is_database_action diff --git a/tests/test_outputs.py b/tests/test_outputs.py index 5a02cbb..00014e0 100644 --- a/tests/test_outputs.py +++ b/tests/test_outputs.py @@ -5,7 +5,8 @@ def test_get_output_dirs_with_duplicates(): - outputs = Outputs( + outputs = Outputs.build( + action_id="test", highly_sensitive={ "a": "output/1a.csv", "b": "output/2a.csv", @@ -18,7 +19,8 @@ def test_get_output_dirs_with_duplicates(): def test_get_output_dirs_without_duplicates(): - outputs = Outputs( + outputs = Outputs.build( + action_id="test", highly_sensitive={ "a": "1a/output.csv", "b": "2a/output.csv", diff --git a/tests/test_type_validation.py b/tests/test_type_validation.py new file mode 100644 index 0000000..2ae5720 --- /dev/null +++ b/tests/test_type_validation.py @@ -0,0 +1,176 @@ +import re + +import pytest + +from pipeline.exceptions import ValidationError +from pipeline.models import Pipeline + + +def test_missing_actions(): + with pytest.raises( + ValidationError, match="Project `actions` section must be a dictionary" + ): + Pipeline.build(version=3, expectations={"population_size": 10}) + + +def test_actions_incorrect_type(): + with pytest.raises( + ValidationError, match="Project `actions` section must be a dictionary" + ): + Pipeline.build(version=3, actions=[], expectations={"population_size": 10}) + + +def test_expectations_incorrect_type(): + with pytest.raises( + ValidationError, match="Project `expectations` section must be a dictionary" + ): + Pipeline.build(version=3, actions={}, expectations=[]) + + +def test_outputs_incorrect_type(): + with pytest.raises( + ValidationError, + match="`outputs` section for action action1 must be a dictionary", + ): + Pipeline.build( + version=3, + actions={"action1": {"outputs": [], "run": "test:v1"}}, + expectations={"population_size": 10}, + ) + + +def test_run_incorrect_type(): + with pytest.raises( + ValidationError, match="`run` section for action action1 must be a string" + ): + Pipeline.build( + version=3, + actions={"action1": {"outputs": {}, "run": ["test:v1"]}}, + expectations={"population_size": 10}, + ) + + +def test_needs_incorrect_type(): + with pytest.raises( + ValidationError, match="`needs` section for action action1 must be a list" + ): + Pipeline.build( + version=3, + actions={"action1": {"outputs": {}, "run": "test:v1", "needs": ""}}, + expectations={"population_size": 10}, + ) + + +def test_config_incorrect_type(): + with pytest.raises( + ValidationError, + match="`config` section for action action1 must be a dictionary", + ): + Pipeline.build( + version=3, + actions={"action1": {"outputs": {}, "run": "test:v1", "config": []}}, + expectations={"population_size": 10}, + ) + + +def test_dummy_data_file_incorrect_type(): + with pytest.raises( + ValidationError, + match="`dummy_data_file` section for action action1 must be a string", + ): + Pipeline.build( + version=3, + actions={ + "action1": {"outputs": {}, "run": "test:v1", "dummy_data_file": []} + }, + expectations={"population_size": 10}, + ) + + +def 
test_output_files_incorrect_type(): + with pytest.raises( + ValidationError, + match="`highly_sensitive` section for action action1 must be a dictionary", + ): + Pipeline.build( + version=3, + actions={ + "action1": {"outputs": {"highly_sensitive": []}, "run": "test:v1"} + }, + expectations={"population_size": 10}, + ) + + +def test_output_filename_incorrect_type(): + with pytest.raises( + ValidationError, + match="`dataset` output for action action1 must be a string", + ): + Pipeline.build( + version=3, + actions={ + "action1": { + "outputs": {"highly_sensitive": {"dataset": {}}}, + "run": "test:v1", + } + }, + expectations={"population_size": 10}, + ) + + +def test_project_extra_parameters(): + with pytest.raises( + ValidationError, match=re.escape("Unexpected parameters (extra) in project") + ): + Pipeline.build(extra=123) + + +def test_action_extra_parameters(): + with pytest.raises( + ValidationError, + match=re.escape("Unexpected parameters (extra) in action action1"), + ): + Pipeline.build( + version=3, + actions={ + "action1": { + "outputs": {}, + "run": "test:v1", + "extra": 123, + } + }, + expectations={"population_size": 10}, + ) + + +def test_outputs_extra_parameters(): + with pytest.raises( + ValidationError, + match=re.escape( + "Unexpected parameters (extra) in `outputs` section for action action1" + ), + ): + Pipeline.build( + version=3, + actions={ + "action1": { + "outputs": {"highly_sensitive": {"dataset": {}}, "extra": 123}, + "run": "test:v1", + } + }, + expectations={"population_size": 10}, + ) + + +def test_expectations_extra_parameters(): + with pytest.raises( + ValidationError, + match=re.escape( + "Unexpected parameters (extra) in project `expectations` section" + ), + ): + Pipeline.build( + version=3, + actions={}, + expectations={"population_size": 10, "extra": 123}, + ) diff --git a/tests/test_validation.py b/tests/test_validation.py index 5a6c5bc..83a61f4 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,12 +1,12 @@ import pytest from pipeline.exceptions import InvalidPatternError -from pipeline.validation import assert_valid_glob_pattern +from pipeline.validation import validate_glob_pattern -def test_assert_valid_glob_pattern(): - assert_valid_glob_pattern("foo/bar/*.txt", "highly_sensitive") - assert_valid_glob_pattern("foo/bar/*.txt", "moderately_sensitive") +def test_validate_glob_pattern(): + validate_glob_pattern("foo/bar/*.txt", "highly_sensitive") + validate_glob_pattern("foo/bar/*.txt", "moderately_sensitive") bad_patterns = [ ("/abs/path.txt", "highly_sensitive"), ("not//canonical.txt", "highly_sensitive"), @@ -25,4 +25,4 @@ def test_assert_valid_glob_pattern(): ] for pattern, sensitivity in bad_patterns: with pytest.raises(InvalidPatternError): - assert_valid_glob_pattern(pattern, sensitivity) + validate_glob_pattern(pattern, sensitivity)
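
Reviewer note: a minimal sketch of the resulting API, grounded in the updated tests above. The `generate_dataset` action id and the file paths are illustrative only:

    from pipeline import ProjectValidationError, load_pipeline
    from pipeline.exceptions import ValidationError
    from pipeline.models import Pipeline

    # Pipeline.build() replaces the old Pipeline(**data) pydantic constructor
    # and raises the library's own ValidationError on bad input.
    config = Pipeline.build(
        version=3,
        actions={
            "generate_dataset": {
                "run": "ehrql:v1 generate-dataset --output output/dataset.csv",
                "outputs": {"highly_sensitive": {"dataset": "output/dataset.csv"}},
            },
        },
        expectations={"population_size": 1000},
    )
    assert config.actions["generate_dataset"].run.name == "ehrql"

    # load_pipeline() still wraps validation failures in ProjectValidationError;
    # the chained cause is now pipeline.exceptions.ValidationError rather than
    # pydantic.ValidationError.
    try:
        load_pipeline("version: not-a-number")
    except ProjectValidationError as exc:
        assert isinstance(exc.__cause__, ValidationError)

This mirrors what `tests/test_main.py` and `tests/test_models.py` assert after the change.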
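
The helpers that replace pydantic's field coercion can also be exercised directly; a minimal sketch, assuming only the signatures added in `pipeline/validation.py` above (the `my_action` location strings are made up for the example):

    from pipeline.exceptions import ValidationError
    from pipeline.validation import validate_no_kwargs, validate_type

    # Passes: the value is a dict, which is the expected type.
    validate_type({"population_size": 10}, dict, "Project `expectations` section")

    # Passes: None is allowed when optional=True.
    validate_type(None, list, "`needs` section for action my_action", optional=True)

    # Raises: "`run` section for action my_action must be a string"
    try:
        validate_type(["test:v1"], str, "`run` section for action my_action")
    except ValidationError as exc:
        print(exc)

    # Raises: "Unexpected parameters (extra) in action my_action"
    try:
        validate_no_kwargs({"extra": 123}, "action my_action")
    except ValidationError as exc:
        print(exc)

Passing the location string as a plain argument is what lets `Action.build` and `Pipeline.build` reuse one helper for every section while still producing the section-specific messages the new tests match on.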