Skip to content

Commit

Permalink
[typing] prefect.serializers (#16331)
Browse files: browse the repository at this point in the history
  • Loading branch information
desertaxle authored Dec 11, 2024
1 parent 0e562b8 commit 3c65bca
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
7 changes: 4 additions & 3 deletions src/prefect/_internal/schemas/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

if TYPE_CHECKING:
from prefect.blocks.core import Block
from prefect.serializers import Serializer
from prefect.utilities.callables import ParameterSchema


Expand Down Expand Up @@ -578,15 +579,15 @@ def validate_picklelib_and_modules(values: dict) -> dict:
return values


def validate_dump_kwargs(value: dict) -> dict:
def validate_dump_kwargs(value: dict[str, Any]) -> dict[str, Any]:
# `default` is set by `object_encoder`. A user-provided callable would make this
# class unserializable anyway.
if "default" in value:
raise ValueError("`default` cannot be provided. Use `object_encoder` instead.")
return value


def validate_load_kwargs(value: dict) -> dict:
def validate_load_kwargs(value: dict[str, Any]) -> dict[str, Any]:
# `object_hook` is set by `object_decoder`. A user-provided callable would make
# this class unserializable anyway.
if "object_hook" in value:
Expand All @@ -596,7 +597,7 @@ def validate_load_kwargs(value: dict) -> dict:
return value


def cast_type_names_to_serializers(value):
def cast_type_names_to_serializers(value: Union[str, "Serializer"]) -> "Serializer":
from prefect.serializers import Serializer

if isinstance(value, str):
Expand Down
52 changes: 28 additions & 24 deletions src/prefect/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import abc
import base64
from typing import Any, Dict, Generic, Optional, Type
from typing import Any, Generic, Optional, Type, Union

from pydantic import (
BaseModel,
Expand All @@ -23,7 +23,7 @@
ValidationError,
field_validator,
)
from typing_extensions import Literal, Self, TypeVar
from typing_extensions import Self, TypeVar

from prefect._internal.schemas.validators import (
cast_type_names_to_serializers,
Expand Down Expand Up @@ -54,7 +54,7 @@ def prefect_json_object_encoder(obj: Any) -> Any:
}


def prefect_json_object_decoder(result: dict):
def prefect_json_object_decoder(result: dict[str, Any]):
"""
`JSONDecoder.object_hook` for decoding objects from JSON when previously encoded
with `prefect_json_object_encoder`
Expand All @@ -80,12 +80,16 @@ def __init__(self, **data: Any) -> None:
data.setdefault("type", type_string)
super().__init__(**data)

def __new__(cls: Type[Self], **kwargs) -> Self:
def __new__(cls: Type[Self], **kwargs: Any) -> Self:
if "type" in kwargs:
try:
subcls = lookup_type(cls, dispatch_key=kwargs["type"])
except KeyError as exc:
raise ValidationError(errors=[exc], model=cls)
raise ValidationError.from_exception_data(
title=cls.__name__,
line_errors=[{"type": str(exc), "input": kwargs["type"]}],
input_type="python",
)

return super().__new__(subcls)
else:
Expand All @@ -104,7 +108,7 @@ def loads(self, blob: bytes) -> D:
model_config = ConfigDict(extra="forbid")

@classmethod
def __dispatch_key__(cls) -> str:
def __dispatch_key__(cls) -> Optional[str]:
type_str = cls.model_fields["type"].default
return type_str if isinstance(type_str, str) else None

Expand All @@ -119,19 +123,15 @@ class PickleSerializer(Serializer):
- Wraps pickles in base64 for safe transmission.
"""

type: Literal["pickle"] = "pickle"
type: str = Field(default="pickle", frozen=True)

picklelib: str = "cloudpickle"
picklelib_version: Optional[str] = None

@field_validator("picklelib")
def check_picklelib(cls, value):
def check_picklelib(cls, value: str) -> str:
return validate_picklelib(value)

# @model_validator(mode="before")
# def check_picklelib_version(cls, values):
# return validate_picklelib_version(values)

def dumps(self, obj: Any) -> bytes:
pickler = from_qualified_name(self.picklelib)
blob = pickler.dumps(obj)
Expand All @@ -151,7 +151,7 @@ class JSONSerializer(Serializer):
Wraps the `json` library to serialize to UTF-8 bytes instead of string types.
"""

type: Literal["json"] = "json"
type: str = Field(default="json", frozen=True)

jsonlib: str = "json"
object_encoder: Optional[str] = Field(
Expand All @@ -171,23 +171,27 @@ class JSONSerializer(Serializer):
"by our default `object_encoder`."
),
)
dumps_kwargs: Dict[str, Any] = Field(default_factory=dict)
loads_kwargs: Dict[str, Any] = Field(default_factory=dict)
dumps_kwargs: dict[str, Any] = Field(default_factory=dict)
loads_kwargs: dict[str, Any] = Field(default_factory=dict)

@field_validator("dumps_kwargs")
def dumps_kwargs_cannot_contain_default(cls, value):
def dumps_kwargs_cannot_contain_default(
cls, value: dict[str, Any]
) -> dict[str, Any]:
return validate_dump_kwargs(value)

@field_validator("loads_kwargs")
def loads_kwargs_cannot_contain_object_hook(cls, value):
def loads_kwargs_cannot_contain_object_hook(
cls, value: dict[str, Any]
) -> dict[str, Any]:
return validate_load_kwargs(value)

def dumps(self, data: Any) -> bytes:
def dumps(self, obj: Any) -> bytes:
json = from_qualified_name(self.jsonlib)
kwargs = self.dumps_kwargs.copy()
if self.object_encoder:
kwargs["default"] = from_qualified_name(self.object_encoder)
result = json.dumps(data, **kwargs)
result = json.dumps(obj, **kwargs)
if isinstance(result, str):
# The standard library returns str but others may return bytes directly
result = result.encode()
Expand All @@ -213,17 +217,17 @@ class CompressedSerializer(Serializer):
level: If not null, the level of compression to pass to `compress`.
"""

type: Literal["compressed"] = "compressed"
type: str = Field(default="compressed", frozen=True)

serializer: Serializer
compressionlib: str = "lzma"

@field_validator("serializer", mode="before")
def validate_serializer(cls, value):
def validate_serializer(cls, value: Union[str, Serializer]) -> Serializer:
return cast_type_names_to_serializers(value)

@field_validator("compressionlib")
def check_compressionlib(cls, value):
def check_compressionlib(cls, value: str) -> str:
return validate_compressionlib(value)

def dumps(self, obj: Any) -> bytes:
Expand All @@ -242,7 +246,7 @@ class CompressedPickleSerializer(CompressedSerializer):
A compressed serializer preconfigured to use the pickle serializer.
"""

type: Literal["compressed/pickle"] = "compressed/pickle"
type: str = Field(default="compressed/pickle", frozen=True)

serializer: Serializer = Field(default_factory=PickleSerializer)

Expand All @@ -252,6 +256,6 @@ class CompressedJSONSerializer(CompressedSerializer):
A compressed serializer preconfigured to use the json serializer.
"""

type: Literal["compressed/json"] = "compressed/json"
type: str = Field(default="compressed/json", frozen=True)

serializer: Serializer = Field(default_factory=JSONSerializer)

0 comments on commit 3c65bca

Please sign in to comment.