Skip to content

Commit

Permalink
⬆️Pydantic V2: Migrate director v0 + some fixes from query PR (#6755)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderegg authored Nov 19, 2024
1 parent 373c314 commit f442e63
Show file tree
Hide file tree
Showing 28 changed files with 262 additions and 121 deletions.
31 changes: 14 additions & 17 deletions api/specs/web-server/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from typing import Any, ClassVar, NamedTuple

import yaml
from common_library.json_serialization import json_dumps
from common_library.pydantic_fields_extension import get_type
from fastapi import FastAPI, Query
from models_library.basic_types import LogLevel
from models_library.utils.json_serialization import json_dumps
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo
from servicelib.fastapi.openapi import override_fastapi_openapi_method
Expand All @@ -38,31 +38,28 @@ def __modify_schema__(cls, field_schema: dict[str, Any]) -> None:

def as_query(model_class: type[BaseModel]) -> type[BaseModel]:
fields = {}
for field_name, model_field in model_class.__fields__.items():
for field_name, field_info in model_class.model_fields.items():

field_type = model_field.type_
default_value = model_field.default
field_type = get_type(field_info)
default_value = field_info.default

kwargs = {
"alias": model_field.field_info.alias,
"title": model_field.field_info.title,
"description": model_field.field_info.description,
"gt": model_field.field_info.gt,
"ge": model_field.field_info.ge,
"lt": model_field.field_info.lt,
"le": model_field.field_info.le,
"min_length": model_field.field_info.min_length,
"max_length": model_field.field_info.max_length,
"regex": model_field.field_info.regex,
**model_field.field_info.extra,
"alias": field_info.alias,
"title": field_info.title,
"description": field_info.description,
"metadata": field_info.metadata,
"json_schema_extra": field_info.json_schema_extra,
}

if issubclass(field_type, BaseModel):
# Complex fields
assert "json_schema_extra" in kwargs # nosec
assert kwargs["json_schema_extra"] # nosec
field_type = _create_json_type(
description=kwargs["description"],
example=kwargs.get("example_json"),
example=kwargs.get("json_schema_extra", {}).get("example_json"),
)

default_value = json_dumps(default_value) if default_value else None

fields[field_name] = (field_type, Query(default=default_value, **kwargs))
Expand Down Expand Up @@ -148,7 +145,7 @@ def create_and_save_openapi_specs(
)
with file_path.open("wt") as fh:
yaml.safe_dump(openapi, fh, indent=1, sort_keys=False)
print("Saved OAS to", file_path)
print("Saved OAS to", file_path) # noqa: T201


class ParamSpec(NamedTuple):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def get_type(info: FieldInfo) -> Any:
field_type = info.annotation
if args := get_args(info.annotation):
field_type = next(a for a in args if a != type(None))
field_type = next(a for a in args if a is not type(None))
return field_type


Expand Down
4 changes: 2 additions & 2 deletions packages/models-library/src/models_library/basic_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from decimal import Decimal
from enum import StrEnum
from re import Pattern
from typing import Annotated, Final, TypeAlias
from typing import Annotated, ClassVar, Final, TypeAlias

from pydantic import Field, HttpUrl, PositiveInt, StringConstraints
from pydantic_core import core_schema
Expand Down Expand Up @@ -137,7 +137,7 @@ class LongTruncatedStr(ConstrainedStr):

# https e.g. https://techterms.com/definition/https
class HttpSecureUrl(HttpUrl):
allowed_schemes = {"https"}
allowed_schemes: ClassVar[set[str]] = {"https"}


class HttpUrlWithCustomMinLength(HttpUrl):
Expand Down
55 changes: 30 additions & 25 deletions packages/models-library/src/models_library/rest_ordering.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import Enum
from typing import Any, ClassVar
from typing import Annotated

from models_library.utils.json_serialization import json_dumps
from pydantic import BaseModel, Extra, Field, field_validator
from common_library.json_serialization import json_dumps
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, field_validator

from .basic_types import IDStr
from .rest_base import RequestParameters
Expand Down Expand Up @@ -62,17 +62,19 @@ def create_ordering_query_model_classes(
msg_direction_options = "|".join(sorted(OrderDirection))

class _OrderBy(OrderBy):
class Config:
schema_extra: ClassVar[dict[str, Any]] = {
"example": {
"field": next(iter(ordering_fields)),
"direction": OrderDirection.DESC.value,
}
}
extra = Extra.forbid
# Necessary to run _check_ordering_field_and_map in defaults and assignments
validate_all = True
validate_assignment = True
model_config = ConfigDict(
extra="forbid",
json_schema_extra={
"examples": [
{
"field": next(iter(ordering_fields)),
"direction": OrderDirection.DESC.value,
}
]
},
validate_assignment=True, # Necessary to run _check_ordering_field_and_map in defaults and assignments
validate_default=True,
)

@field_validator("field", mode="before")
@classmethod
Expand All @@ -87,28 +89,31 @@ def _check_ordering_field_and_map(cls, v):
# API field name -> DB column_name conversion
return _ordering_fields_api_to_column_map.get(v) or v

order_by_example: dict[str, Any] = _OrderBy.Config.schema_extra["example"]
assert "json_schema_extra" in _OrderBy.model_config # nosec
assert isinstance(_OrderBy.model_config["json_schema_extra"], dict) # nosec
assert isinstance(
_OrderBy.model_config["json_schema_extra"]["examples"], list
) # nosec
order_by_example = _OrderBy.model_config["json_schema_extra"]["examples"][0]
order_by_example_json = json_dumps(order_by_example)
assert _OrderBy.parse_obj(order_by_example), "Example is invalid" # nosec
assert _OrderBy.model_validate(order_by_example), "Example is invalid" # nosec

converted_default = _OrderBy.parse_obj(
converted_default = _OrderBy.model_validate(
# NOTE: enforces ordering_fields_api_to_column_map
default.dict()
default.model_dump()
)

class _OrderQueryParams(_BaseOrderQueryParams):
order_by: _OrderBy = Field(
order_by: Annotated[
_OrderBy, BeforeValidator(parse_json_pre_validator)
] = Field(
default=converted_default,
description=(
f"Order by field (`{msg_field_options}`) and direction (`{msg_direction_options}`). "
f"The default sorting order is `{json_dumps(default)}`."
),
example=order_by_example,
example_json=order_by_example_json,
)

_pre_parse_string = field_validator("order_by", mode="before")(
parse_json_pre_validator
examples=[order_by_example],
json_schema_extra={"example_json": order_by_example_json},
)

return _OrderQueryParams
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ class MyModel(BaseModel):
import operator
from typing import Any

from common_library.json_serialization import json_loads
from orjson import JSONDecodeError

from .json_serialization import json_loads


def empty_str_to_none_pre_validator(value: Any):
if isinstance(value, str) and value.strip() == "":
Expand Down
45 changes: 29 additions & 16 deletions packages/models-library/tests/test_rest_ordering.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import pytest
from common_library.json_serialization import json_dumps
from models_library.basic_types import IDStr
from models_library.rest_ordering import (
OrderBy,
OrderDirection,
create_ordering_query_model_classes,
)
from models_library.utils.json_serialization import json_dumps
from pydantic import BaseModel, Extra, Field, Json, ValidationError, validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
Json,
ValidationError,
field_validator,
)


class ReferenceOrderQueryParamsClass(BaseModel):
Expand All @@ -18,10 +25,10 @@ class ReferenceOrderQueryParamsClass(BaseModel):
order_by: Json[OrderBy] = Field(
default=OrderBy(field=IDStr("modified_at"), direction=OrderDirection.DESC),
description="Order by field (modified_at|name|description) and direction (asc|desc). The default sorting order is ascending.",
example='{"field": "name", "direction": "desc"}',
json_schema_extra={"examples": ['{"field": "name", "direction": "desc"}']},
)

@validator("order_by", check_fields=False)
@field_validator("order_by", check_fields=False)
@classmethod
def _validate_order_by_field(cls, v):
if v.field not in {
Expand All @@ -35,8 +42,9 @@ def _validate_order_by_field(cls, v):
v.field = "modified_column"
return v

class Config:
extra = Extra.forbid
model_config = ConfigDict(
extra="forbid",
)


def test_ordering_query_model_class_factory():
Expand All @@ -52,16 +60,19 @@ class OrderQueryParamsModel(BaseOrderingQueryModel):

# normal
data = {"order_by": {"field": "modified_at", "direction": "asc"}}
model = OrderQueryParamsModel.parse_obj(data)
model = OrderQueryParamsModel.model_validate(data)

assert model.order_by
assert model.order_by.dict() == {"field": "modified_column", "direction": "asc"}
assert model.order_by.model_dump() == {
"field": "modified_column",
"direction": "asc",
}

# test against reference
expected = ReferenceOrderQueryParamsClass.parse_obj(
expected = ReferenceOrderQueryParamsClass.model_validate(
{"order_by": json_dumps({"field": "modified_at", "direction": "asc"})}
)
assert expected.dict() == model.dict()
assert expected.model_dump() == model.model_dump()


def test_ordering_query_model_class__fails_with_invalid_fields():
Expand All @@ -73,7 +84,7 @@ def test_ordering_query_model_class__fails_with_invalid_fields():

# fails with invalid field to sort
with pytest.raises(ValidationError) as err_info:
OrderQueryParamsModel.parse_obj({"order_by": {"field": "INVALID"}})
OrderQueryParamsModel.model_validate({"order_by": {"field": "INVALID"}})

error = err_info.value.errors()[0]

Expand All @@ -89,7 +100,7 @@ def test_ordering_query_model_class__fails_with_invalid_direction():
)

with pytest.raises(ValidationError) as err_info:
OrderQueryParamsModel.parse_obj(
OrderQueryParamsModel.model_validate(
{"order_by": {"field": "modified", "direction": "INVALID"}}
)

Expand All @@ -110,18 +121,19 @@ def test_ordering_query_model_class__defaults():
# checks all defaults
model = OrderQueryParamsModel()
assert model.order_by
assert isinstance(model.order_by, OrderBy) # nosec
assert model.order_by.field == "modified_at" # NOTE that this was mapped!
assert model.order_by.direction == OrderDirection.DESC

# partial defaults
model = OrderQueryParamsModel.parse_obj({"order_by": {"field": "name"}})
model = OrderQueryParamsModel.model_validate({"order_by": {"field": "name"}})
assert model.order_by
assert model.order_by.field == "name"
assert model.order_by.direction == OrderBy.__fields__["direction"].default
assert model.order_by.direction == OrderBy.model_fields["direction"].default

# direction alone is invalid
with pytest.raises(ValidationError) as err_info:
OrderQueryParamsModel.parse_obj({"order_by": {"direction": "asc"}})
OrderQueryParamsModel.model_validate({"order_by": {"direction": "asc"}})

error = err_info.value.errors()[0]
assert error["loc"] == ("order_by", "field")
Expand All @@ -135,5 +147,6 @@ def test_ordering_query_model_with_map():
ordering_fields_api_to_column_map={"modified": "some_db_column_name"},
)

model = OrderQueryParamsModel.parse_obj({"order_by": {"field": "modified"}})
model = OrderQueryParamsModel.model_validate({"order_by": {"field": "modified"}})
assert model.order_by
assert model.order_by.field == "some_db_column_name"
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import TypeAlias, TypeVar, Union

from aiohttp import web
from models_library.utils.json_serialization import json_dumps
from common_library.json_serialization import json_dumps
from pydantic import BaseModel, TypeAdapter, ValidationError

from ..mimetype_constants import MIMETYPE_APPLICATION_JSON
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
import pytest
from aiohttp import web
from aiohttp.test_utils import TestClient, make_mocked_request
from common_library.json_serialization import json_dumps
from faker import Faker
from models_library.rest_base import RequestParameters, StrictRequestParameters
from models_library.rest_ordering import (
OrderBy,
OrderDirection,
create_ordering_query_model_classes,
)
from models_library.utils.json_serialization import json_dumps
from pydantic import BaseModel, ConfigDict, Field
from servicelib.aiohttp import status
from servicelib.aiohttp.requests_validation import (
Expand Down
5 changes: 3 additions & 2 deletions packages/settings-library/src/settings_library/application.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from models_library.basic_types import BootModeEnum
from pydantic import Field, PositiveInt

from .base import BaseCustomSettings
from .basic_types import BootMode, BuildTargetEnum
from .basic_types import BuildTargetEnum


class BaseApplicationSettings(BaseCustomSettings):
Expand All @@ -16,7 +17,7 @@ class BaseApplicationSettings(BaseCustomSettings):
SC_VCS_URL: str | None = None

# @Dockerfile
SC_BOOT_MODE: BootMode | None = None
SC_BOOT_MODE: BootModeEnum | None = None
SC_BOOT_TARGET: BuildTargetEnum | None = None
SC_HEALTHCHECK_TIMEOUT: PositiveInt | None = Field(
default=None,
Expand Down
9 changes: 5 additions & 4 deletions scripts/maintenance/migrate_project/src/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
from pathlib import Path
from typing import Optional
from uuid import UUID

from pydantic import BaseModel, Field
Expand Down Expand Up @@ -32,7 +31,7 @@ class SourceConfig(BaseModel):
db: DBConfig
s3: S3Config
project_uuid: UUID = Field(..., description="project to be moved from the source")
hidden_projects_for_user: Optional[int] = Field(
hidden_projects_for_user: int | None = Field(
None,
description="by default nothing is moved, must provide an user ID for which to move the hidden projects",
)
Expand All @@ -57,7 +56,7 @@ class Settings(BaseModel):

@classmethod
def load_from_file(cls, path: Path) -> "Settings":
return Settings.parse_obj(json.loads(path.read_text()))
return Settings.model_validate(json.loads(path.read_text()))

class Config:
schema_extra = {
Expand Down Expand Up @@ -92,4 +91,6 @@ class Config:

if __name__ == "__main__":
# produces an empty configuration to be saved as starting point
print(Settings.parse_obj(Settings.Config.schema_extra["example"]).json(indent=2))
print(
Settings.model_validate(Settings.Config.schema_extra["example"]).json(indent=2)
)
Loading

0 comments on commit f442e63

Please sign in to comment.