Skip to content

Commit

Permalink
TST: Add test to ensure all optional fields have defaults (#834)
Browse files Browse the repository at this point in the history
  • Loading branch information
tnatt authored Oct 10, 2024
1 parent 1a1054e commit 1cc79a8
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 3 deletions.
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

import fmu.dataio as dio
from fmu.config import utilities as ut
from fmu.dataio._model import fields, global_configuration
from fmu.dataio._model import Root, fields, global_configuration
from fmu.dataio.dataio import ExportData, read_metadata
from fmu.dataio.providers._fmu import FmuEnv

from .utils import _metadata_examples
from .utils import _get_nested_pydantic_models, _metadata_examples

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -699,3 +699,9 @@ def fixture_drogon_volumes(rootpath):
rootpath / "tests/data/drogon/tabular/geogrid--vol.csv",
)
)


@pytest.fixture(scope="session")
def pydantic_models_from_root():
"""Return all nested pydantic models from Root and downwards"""
return _get_nested_pydantic_models(Root)
17 changes: 17 additions & 0 deletions tests/test_schema/test_pydantic_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from copy import deepcopy
from typing import get_args

import pytest
from pydantic import ValidationError
Expand Down Expand Up @@ -32,6 +33,22 @@ def test_validate(file, example):
Root.model_validate(example)


def test_for_optional_fields_without_default(pydantic_models_from_root):
"""Test that all optional fields have a default value"""
optionals_without_default = []
for model in pydantic_models_from_root:
for field_name, field_info in model.model_fields.items():
if (
type(None) in get_args(field_info.annotation)
and field_info.is_required()
):
optionals_without_default.append(
f"{model.__module__}.{model.__name__}.{field_name}"
)

assert not optionals_without_default


def test_schema_file_block(metadata_examples):
"""Test variations on the file block."""

Expand Down
30 changes: 29 additions & 1 deletion tests/test_units/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Union

import numpy as np
import pytest
from xtgeo import Grid, Polygons, RegularSurface

from fmu.dataio import _utils as utils
from fmu.dataio._model import fields

from ..utils import inside_rms
from ..utils import _get_pydantic_models_from_annotation, inside_rms


@pytest.mark.parametrize(
Expand Down Expand Up @@ -148,3 +150,29 @@ def test_read_named_envvar():

os.environ["MYTESTENV"] = "mytestvalue"
assert utils.read_named_envvar("MYTESTENV") == "mytestvalue"


def test_get_pydantic_models_from_annotation():
annotation = Union[List[fields.Access], fields.File]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
]
annotation = Optional[Union[Dict[str, fields.Access], List[fields.File]]]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
]

annotation = List[Union[fields.Access, fields.File, fields.Tracklog]]
assert _get_pydantic_models_from_annotation(annotation) == [
fields.Access,
fields.File,
fields.Tracklog,
]

annotation = List[List[List[List[fields.Tracklog]]]]
assert _get_pydantic_models_from_annotation(annotation) == [fields.Tracklog]

annotation = Union[str, List[int], Dict[str, int]]
assert not _get_pydantic_models_from_annotation(annotation)
29 changes: 29 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations

import datetime
from functools import wraps
from pathlib import Path
from typing import Any, get_args

import pytest
import yaml
from pydantic import BaseModel


def inside_rms(func):
Expand Down Expand Up @@ -43,3 +47,28 @@ def _metadata_examples():
path.name: _isoformat_all_datetimes(_parse_yaml(path))
for path in Path(".").absolute().glob("schema/definitions/0.8.0/examples/*.yml")
}


def _get_pydantic_models_from_annotation(annotation: Any) -> list[Any]:
"""
Get a list of all pydantic models defined inside an annotation.
Example: Union[Model1, list[dict[str, Model2]]] returns [Model1, Model2]
"""
if isinstance(annotation, type(BaseModel)):
return [annotation]

annotations = []
for ann in get_args(annotation):
annotations += _get_pydantic_models_from_annotation(ann)
return annotations


def _get_nested_pydantic_models(model: type[BaseModel]) -> set[type[BaseModel]]:
"""Get a set of all nested pydantic models from a pydantic model"""
models = {model}

for field_info in model.model_fields.values():
for model in _get_pydantic_models_from_annotation(field_info.annotation):
if model not in models:
models.update(_get_nested_pydantic_models(model))
return models

0 comments on commit 1cc79a8

Please sign in to comment.