diff --git a/tests/conftest.py b/tests/conftest.py index 90788a045..f4cc2dfcb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,11 +15,11 @@ import fmu.dataio as dio from fmu.config import utilities as ut -from fmu.dataio._model import fields, global_configuration +from fmu.dataio._model import Root, fields, global_configuration from fmu.dataio.dataio import ExportData, read_metadata from fmu.dataio.providers._fmu import FmuEnv -from .utils import _metadata_examples +from .utils import _get_nested_pydantic_models, _metadata_examples logger = logging.getLogger(__name__) @@ -699,3 +699,9 @@ def fixture_drogon_volumes(rootpath): rootpath / "tests/data/drogon/tabular/geogrid--vol.csv", ) ) + + +@pytest.fixture(scope="session") +def pydantic_models_from_root(): + """Return all nested pydantic models from Root and downwards""" + return _get_nested_pydantic_models(Root) diff --git a/tests/test_schema/test_pydantic_logic.py b/tests/test_schema/test_pydantic_logic.py index d7fc7fff6..3674c01af 100644 --- a/tests/test_schema/test_pydantic_logic.py +++ b/tests/test_schema/test_pydantic_logic.py @@ -2,6 +2,7 @@ import logging from copy import deepcopy +from typing import get_args import pytest from pydantic import ValidationError @@ -32,6 +33,22 @@ def test_validate(file, example): Root.model_validate(example) +def test_for_optional_fields_without_default(pydantic_models_from_root): + """Test that all optional fields have a default value""" + optionals_without_default = [] + for model in pydantic_models_from_root: + for field_name, field_info in model.model_fields.items(): + if ( + type(None) in get_args(field_info.annotation) + and field_info.is_required() + ): + optionals_without_default.append( + f"{model.__module__}.{model.__name__}.{field_name}" + ) + + assert not optionals_without_default + + def test_schema_file_block(metadata_examples): """Test variations on the file block.""" diff --git a/tests/test_units/test_utils.py b/tests/test_units/test_utils.py index 1b24ef9cd..9e9c70b06 100644 --- a/tests/test_units/test_utils.py +++ b/tests/test_units/test_utils.py @@ -3,14 +3,16 @@ import os from pathlib import Path from tempfile import NamedTemporaryFile +from typing import Dict, List, Optional, Union import numpy as np import pytest from xtgeo import Grid, Polygons, RegularSurface from fmu.dataio import _utils as utils +from fmu.dataio._model import fields -from ..utils import inside_rms +from ..utils import _get_pydantic_models_from_annotation, inside_rms @pytest.mark.parametrize( @@ -148,3 +150,29 @@ def test_read_named_envvar(): os.environ["MYTESTENV"] = "mytestvalue" assert utils.read_named_envvar("MYTESTENV") == "mytestvalue" + + +def test_get_pydantic_models_from_annotation(): + annotation = Union[List[fields.Access], fields.File] + assert _get_pydantic_models_from_annotation(annotation) == [ + fields.Access, + fields.File, + ] + annotation = Optional[Union[Dict[str, fields.Access], List[fields.File]]] + assert _get_pydantic_models_from_annotation(annotation) == [ + fields.Access, + fields.File, + ] + + annotation = List[Union[fields.Access, fields.File, fields.Tracklog]] + assert _get_pydantic_models_from_annotation(annotation) == [ + fields.Access, + fields.File, + fields.Tracklog, + ] + + annotation = List[List[List[List[fields.Tracklog]]]] + assert _get_pydantic_models_from_annotation(annotation) == [fields.Tracklog] + + annotation = Union[str, List[int], Dict[str, int]] + assert not _get_pydantic_models_from_annotation(annotation) diff --git a/tests/utils.py b/tests/utils.py index 42a490670..092371319 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,9 +1,13 @@ +from __future__ import annotations + import datetime from functools import wraps from pathlib import Path +from typing import Any, get_args import pytest import yaml +from pydantic import BaseModel def inside_rms(func): @@ -43,3 +47,28 @@ def _metadata_examples(): path.name: _isoformat_all_datetimes(_parse_yaml(path)) for path in Path(".").absolute().glob("schema/definitions/0.8.0/examples/*.yml") } + + +def _get_pydantic_models_from_annotation(annotation: Any) -> list[Any]: + """ + Get a list of all pydantic models defined inside an annotation. + Example: Union[Model1, list[dict[str, Model2]]] returns [Model1, Model2] + """ + if isinstance(annotation, type(BaseModel)): + return [annotation] + + annotations = [] + for ann in get_args(annotation): + annotations += _get_pydantic_models_from_annotation(ann) + return annotations + + +def _get_nested_pydantic_models(model: type[BaseModel]) -> set[type[BaseModel]]: + """Get a set of all nested pydantic models from a pydantic model""" + models = {model} + + for field_info in model.model_fields.values(): + for model in _get_pydantic_models_from_annotation(field_info.annotation): + if model not in models: + models.update(_get_nested_pydantic_models(model)) + return models