From 9bcb120ce806aa82c78dbe402b042b8eaa02ecbd Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Wed, 24 Jul 2024 14:26:40 -0500 Subject: [PATCH] Move enum schema gen to `_generate_schema.py` for consistency (#9963) --- pydantic/_internal/_generate_schema.py | 76 ++++++++++++++++++++++- pydantic/_internal/_std_types_schema.py | 81 +------------------------ 2 files changed, 74 insertions(+), 83 deletions(-) diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index ca95d887c2..712137d67f 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -399,6 +399,78 @@ def _set_schema(self, tp: Any, items_type: Any) -> CoreSchema: def _frozenset_schema(self, tp: Any, items_type: Any) -> CoreSchema: return core_schema.frozenset_schema(self.generate_schema(items_type)) + def _enum_schema(self, enum_type: type[Enum]) -> CoreSchema: + cases: list[Any] = list(enum_type.__members__.values()) + + enum_ref = get_type_ref(enum_type) + description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__) + if ( + description == 'An enumeration.' + ): # This is the default value provided by enum.EnumMeta.__new__; don't use it + description = None + js_updates = {'title': enum_type.__name__, 'description': description} + js_updates = {k: v for k, v in js_updates.items() if v is not None} + + sub_type: Literal['str', 'int', 'float'] | None = None + if issubclass(enum_type, int): + sub_type = 'int' + value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int') + elif issubclass(enum_type, str): + # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)` + sub_type = 'str' + value_ser_type = core_schema.simple_ser_schema('str') + elif issubclass(enum_type, float): + sub_type = 'float' + value_ser_type = core_schema.simple_ser_schema('float') + else: + # TODO this is an ugly hack, how do we trigger an Any schema for serialization? + value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x) + + if cases: + + def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + json_schema = handler(schema) + original_schema = handler.resolve_ref_schema(json_schema) + original_schema.update(js_updates) + return json_schema + + # we don't want to add the missing to the schema if it's the default one + default_missing = getattr(enum_type._missing_, '__func__', None) == Enum._missing_.__func__ # type: ignore + enum_schema = core_schema.enum_schema( + enum_type, + cases, + sub_type=sub_type, + missing=None if default_missing else enum_type._missing_, + ref=enum_ref, + metadata={'pydantic_js_functions': [get_json_schema]}, + ) + + if self._config_wrapper.use_enum_values: + enum_schema = core_schema.no_info_after_validator_function( + attrgetter('value'), enum_schema, serialization=value_ser_type + ) + + return enum_schema + + else: + + def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref)) + original_schema = handler.resolve_ref_schema(json_schema) + original_schema.update(js_updates) + return json_schema + + # Use an isinstance check for enums with no cases. + # The most important use case for this is creating TypeVar bounds for generics that should + # be restricted to enums. This is more consistent than it might seem at first, since you can only + # subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases. + # We use the get_json_schema function when an Enum subclass has been declared with no cases + # so that we can still generate a valid json schema. + return core_schema.is_instance_schema( + enum_type, + metadata={'pydantic_js_functions': [get_json_schema_no_cases]}, + ) + def _arbitrary_type_schema(self, tp: Any) -> CoreSchema: if not isinstance(tp, type): warn( @@ -855,9 +927,7 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901 elif isinstance(obj, (FunctionType, LambdaType, MethodType, partial)): return self._callable_schema(obj) elif inspect.isclass(obj) and issubclass(obj, Enum): - from ._std_types_schema import get_enum_core_schema - - return get_enum_core_schema(obj, self._config_wrapper.config_dict) + return self._enum_schema(obj) elif is_zoneinfo_type(obj): return self._zoneinfo_schema() diff --git a/pydantic/_internal/_std_types_schema.py b/pydantic/_internal/_std_types_schema.py index 603ab4cbd1..9a0c6a1e17 100644 --- a/pydantic/_internal/_std_types_schema.py +++ b/pydantic/_internal/_std_types_schema.py @@ -9,14 +9,11 @@ import collections.abc import dataclasses import decimal -import inspect import os import typing -from enum import Enum from functools import partial from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from operator import attrgetter -from typing import Any, Callable, Iterable, Literal, Tuple, TypeVar +from typing import Any, Callable, Iterable, Tuple, TypeVar import typing_extensions from pydantic_core import ( @@ -36,7 +33,6 @@ from ..config import ConfigDict from ..json_schema import JsonSchemaValue from . import _known_annotated_metadata, _typing_extra -from ._core_utils import get_type_ref from ._internal_dataclass import slots_true from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler @@ -58,77 +54,6 @@ def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchem return self.get_json_schema(schema, handler) -def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema: - cases: list[Any] = list(enum_type.__members__.values()) - - enum_ref = get_type_ref(enum_type) - description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__) - if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it - description = None - js_updates = {'title': enum_type.__name__, 'description': description} - js_updates = {k: v for k, v in js_updates.items() if v is not None} - - sub_type: Literal['str', 'int', 'float'] | None = None - if issubclass(enum_type, int): - sub_type = 'int' - value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int') - elif issubclass(enum_type, str): - # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)` - sub_type = 'str' - value_ser_type = core_schema.simple_ser_schema('str') - elif issubclass(enum_type, float): - sub_type = 'float' - value_ser_type = core_schema.simple_ser_schema('float') - else: - # TODO this is an ugly hack, how do we trigger an Any schema for serialization? - value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x) - - if cases: - - def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - json_schema = handler(schema) - original_schema = handler.resolve_ref_schema(json_schema) - original_schema.update(js_updates) - return json_schema - - # we don't want to add the missing to the schema if it's the default one - default_missing = getattr(enum_type._missing_, '__func__', None) == Enum._missing_.__func__ # type: ignore - enum_schema = core_schema.enum_schema( - enum_type, - cases, - sub_type=sub_type, - missing=None if default_missing else enum_type._missing_, - ref=enum_ref, - metadata={'pydantic_js_functions': [get_json_schema]}, - ) - - if config.get('use_enum_values', False): - enum_schema = core_schema.no_info_after_validator_function( - attrgetter('value'), enum_schema, serialization=value_ser_type - ) - - return enum_schema - - else: - - def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref)) - original_schema = handler.resolve_ref_schema(json_schema) - original_schema.update(js_updates) - return json_schema - - # Use an isinstance check for enums with no cases. - # The most important use case for this is creating TypeVar bounds for generics that should - # be restricted to enums. This is more consistent than it might seem at first, since you can only - # subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases. - # We use the get_json_schema function when an Enum subclass has been declared with no cases - # so that we can still generate a valid json schema. - return core_schema.is_instance_schema( - enum_type, - metadata={'pydantic_js_functions': [get_json_schema_no_cases]}, - ) - - @dataclasses.dataclass(**slots_true) class InnerSchemaValidator: """Use a fixed CoreSchema, avoiding interference from outward annotations.""" @@ -414,10 +339,6 @@ def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaH } -def identity(s: CoreSchema) -> CoreSchema: - return s - - def sequence_like_prepare_pydantic_annotations( source_type: Any, annotations: Iterable[Any], _config: ConfigDict ) -> tuple[Any, list[Any]] | None: