Skip to content

Commit

Permalink
Move enum schema gen to _generate_schema.py for consistency (pydant…
Browse files Browse the repository at this point in the history
  • Loading branch information
sydney-runkle authored Jul 24, 2024
1 parent f3dc053 commit 9bcb120
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 83 deletions.
76 changes: 73 additions & 3 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,78 @@ def _set_schema(self, tp: Any, items_type: Any) -> CoreSchema:
def _frozenset_schema(self, tp: Any, items_type: Any) -> CoreSchema:
return core_schema.frozenset_schema(self.generate_schema(items_type))

def _enum_schema(self, enum_type: type[Enum]) -> CoreSchema:
cases: list[Any] = list(enum_type.__members__.values())

enum_ref = get_type_ref(enum_type)
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
if (
description == 'An enumeration.'
): # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
js_updates = {'title': enum_type.__name__, 'description': description}
js_updates = {k: v for k, v in js_updates.items() if v is not None}

sub_type: Literal['str', 'int', 'float'] | None = None
if issubclass(enum_type, int):
sub_type = 'int'
value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int')
elif issubclass(enum_type, str):
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
sub_type = 'str'
value_ser_type = core_schema.simple_ser_schema('str')
elif issubclass(enum_type, float):
sub_type = 'float'
value_ser_type = core_schema.simple_ser_schema('float')
else:
# TODO this is an ugly hack, how do we trigger an Any schema for serialization?
value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x)

if cases:

def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(schema)
original_schema = handler.resolve_ref_schema(json_schema)
original_schema.update(js_updates)
return json_schema

# we don't want to add the missing to the schema if it's the default one
default_missing = getattr(enum_type._missing_, '__func__', None) == Enum._missing_.__func__ # type: ignore
enum_schema = core_schema.enum_schema(
enum_type,
cases,
sub_type=sub_type,
missing=None if default_missing else enum_type._missing_,
ref=enum_ref,
metadata={'pydantic_js_functions': [get_json_schema]},
)

if self._config_wrapper.use_enum_values:
enum_schema = core_schema.no_info_after_validator_function(
attrgetter('value'), enum_schema, serialization=value_ser_type
)

return enum_schema

else:

def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref))
original_schema = handler.resolve_ref_schema(json_schema)
original_schema.update(js_updates)
return json_schema

# Use an isinstance check for enums with no cases.
# The most important use case for this is creating TypeVar bounds for generics that should
# be restricted to enums. This is more consistent than it might seem at first, since you can only
# subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
# We use the get_json_schema function when an Enum subclass has been declared with no cases
# so that we can still generate a valid json schema.
return core_schema.is_instance_schema(
enum_type,
metadata={'pydantic_js_functions': [get_json_schema_no_cases]},
)

def _arbitrary_type_schema(self, tp: Any) -> CoreSchema:
if not isinstance(tp, type):
warn(
Expand Down Expand Up @@ -855,9 +927,7 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901
elif isinstance(obj, (FunctionType, LambdaType, MethodType, partial)):
return self._callable_schema(obj)
elif inspect.isclass(obj) and issubclass(obj, Enum):
from ._std_types_schema import get_enum_core_schema

return get_enum_core_schema(obj, self._config_wrapper.config_dict)
return self._enum_schema(obj)
elif is_zoneinfo_type(obj):
return self._zoneinfo_schema()

Expand Down
81 changes: 1 addition & 80 deletions pydantic/_internal/_std_types_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@
import collections.abc
import dataclasses
import decimal
import inspect
import os
import typing
from enum import Enum
from functools import partial
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from operator import attrgetter
from typing import Any, Callable, Iterable, Literal, Tuple, TypeVar
from typing import Any, Callable, Iterable, Tuple, TypeVar

import typing_extensions
from pydantic_core import (
Expand All @@ -36,7 +33,6 @@
from ..config import ConfigDict
from ..json_schema import JsonSchemaValue
from . import _known_annotated_metadata, _typing_extra
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler

Expand All @@ -58,77 +54,6 @@ def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchem
return self.get_json_schema(schema, handler)


def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
cases: list[Any] = list(enum_type.__members__.values())

enum_ref = get_type_ref(enum_type)
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
js_updates = {'title': enum_type.__name__, 'description': description}
js_updates = {k: v for k, v in js_updates.items() if v is not None}

sub_type: Literal['str', 'int', 'float'] | None = None
if issubclass(enum_type, int):
sub_type = 'int'
value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int')
elif issubclass(enum_type, str):
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
sub_type = 'str'
value_ser_type = core_schema.simple_ser_schema('str')
elif issubclass(enum_type, float):
sub_type = 'float'
value_ser_type = core_schema.simple_ser_schema('float')
else:
# TODO this is an ugly hack, how do we trigger an Any schema for serialization?
value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x)

if cases:

def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(schema)
original_schema = handler.resolve_ref_schema(json_schema)
original_schema.update(js_updates)
return json_schema

# we don't want to add the missing to the schema if it's the default one
default_missing = getattr(enum_type._missing_, '__func__', None) == Enum._missing_.__func__ # type: ignore
enum_schema = core_schema.enum_schema(
enum_type,
cases,
sub_type=sub_type,
missing=None if default_missing else enum_type._missing_,
ref=enum_ref,
metadata={'pydantic_js_functions': [get_json_schema]},
)

if config.get('use_enum_values', False):
enum_schema = core_schema.no_info_after_validator_function(
attrgetter('value'), enum_schema, serialization=value_ser_type
)

return enum_schema

else:

def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref))
original_schema = handler.resolve_ref_schema(json_schema)
original_schema.update(js_updates)
return json_schema

# Use an isinstance check for enums with no cases.
# The most important use case for this is creating TypeVar bounds for generics that should
# be restricted to enums. This is more consistent than it might seem at first, since you can only
# subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
# We use the get_json_schema function when an Enum subclass has been declared with no cases
# so that we can still generate a valid json schema.
return core_schema.is_instance_schema(
enum_type,
metadata={'pydantic_js_functions': [get_json_schema_no_cases]},
)


@dataclasses.dataclass(**slots_true)
class InnerSchemaValidator:
"""Use a fixed CoreSchema, avoiding interference from outward annotations."""
Expand Down Expand Up @@ -414,10 +339,6 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH
}


def identity(s: CoreSchema) -> CoreSchema:
return s


def sequence_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
Expand Down

0 comments on commit 9bcb120

Please sign in to comment.