From 3ae8a3d6ab71fd561a272577439f7a7cd3610da7 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 11 May 2024 16:33:03 +0100 Subject: [PATCH] Use pattern-matching for some `typing.py` internals --- Lib/_collections_abc.py | 20 ++-- Lib/typing.py | 228 +++++++++++++++++++++++----------------- 2 files changed, 146 insertions(+), 102 deletions(-) diff --git a/Lib/_collections_abc.py b/Lib/_collections_abc.py index 1135e17e3790590..d909d62f97fe57c 100644 --- a/Lib/_collections_abc.py +++ b/Lib/_collections_abc.py @@ -508,6 +508,7 @@ def __getitem__(self, item): new_args = (t_args, t_result) return _CallableGenericAlias(Callable, tuple(new_args)) + def _is_param_expr(obj): """Checks if obj matches either a list of types, ``...``, ``ParamSpec`` or ``_ConcatenateGenericAlias`` from typing.py @@ -520,6 +521,7 @@ def _is_param_expr(obj): names = ('ParamSpec', '_ConcatenateGenericAlias') return obj.__module__ == 'typing' and any(obj.__name__ == name for name in names) + def _type_repr(obj): """Return the repr() of an object, special-casing types (internal helper). @@ -527,15 +529,17 @@ def _type_repr(obj): shouldn't depend on that module. (Keep this roughly in sync with the typing version.) """ - if isinstance(obj, type): - if obj.__module__ == 'builtins': + match obj: + case type(__module__="builtins"): return obj.__qualname__ - return f'{obj.__module__}.{obj.__qualname__}' - if obj is Ellipsis: - return '...' - if isinstance(obj, FunctionType): - return obj.__name__ - return repr(obj) + case type(): + return f'{obj.__module__}.{obj.__qualname__}' + case EllipsisType(): + return '...' + case FunctionType(): + return obj.__name__ + case _: + return repr(obj) class Callable(metaclass=ABCMeta): diff --git a/Lib/typing.py b/Lib/typing.py index 434574559e04fcb..36cd0341ecbc5fc 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -162,11 +162,13 @@ def _type_convert(arg, module=None, *, allow_special_forms=False): """For converting None to type(None), and strings to ForwardRef.""" - if arg is None: - return type(None) - if isinstance(arg, str): - return ForwardRef(arg, module=module, is_class=allow_special_forms) - return arg + match arg: + case None: + return types.NoneType + case str(): + return ForwardRef(arg, module=module, is_class=allow_special_forms) + case _: + return arg def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=False): @@ -242,18 +244,20 @@ def _type_repr(obj): # When changing this function, don't forget about # `_collections_abc._type_repr`, which does the same thing # and must be consistent with this one. - if isinstance(obj, type): - if obj.__module__ == 'builtins': + match obj: + case type(__module__="builtins"): return obj.__qualname__ - return f'{obj.__module__}.{obj.__qualname__}' - if obj is ...: - return '...' - if isinstance(obj, types.FunctionType): - return obj.__name__ - if isinstance(obj, tuple): - # Special case for `repr` of types with `ParamSpec`: - return '[' + ', '.join(_type_repr(t) for t in obj) + ']' - return repr(obj) + case type(): + return f'{obj.__module__}.{obj.__qualname__}' + case types.EllipsisType(): + return '...' + case types.FunctionType(): + return obj.__name__ + case tuple(): + # Special case for `repr` of types with `ParamSpec`: + return '[' + ', '.join(_type_repr(elem) for elem in obj) + ']' + case _: + return repr(obj) def _collect_type_parameters(args, *, enforce_default_ordering: bool = True): @@ -340,13 +344,16 @@ def _check_generic_specialization(cls, arguments): def _unpack_args(*args): newargs = [] for arg in args: - subargs = getattr(arg, '__typing_unpacked_tuple_args__', None) - if subargs is not None and not (subargs and subargs[-1] is ...): - newargs.extend(subargs) - else: - newargs.append(arg) + match getattr(arg, '__typing_unpacked_tuple_args__', None): + case str() | [*_, types.EllipsisType()]: + newargs.append(arg) + case [*subargs]: + newargs.extend(subargs) + case _: + newargs.append(arg) return newargs + def _deduplicate(params, *, unhashable_fallback=False): # Weed out strict duplicates, preserving the first of each occurrence. try: @@ -357,6 +364,7 @@ def _deduplicate(params, *, unhashable_fallback=False): # Happens for cases like `Annotated[dict, {'x': IntValidator()}]` return _deduplicate_unhashable(params) + def _deduplicate_unhashable(unhashable_params): new_unhashable = [] for t in unhashable_params: @@ -364,6 +372,7 @@ def _deduplicate_unhashable(unhashable_params): new_unhashable.append(t) return new_unhashable + def _compare_args_orderless(first_args, second_args): first_unhashable = _deduplicate_unhashable(first_args) second_unhashable = _deduplicate_unhashable(second_args) @@ -375,6 +384,7 @@ def _compare_args_orderless(first_args, second_args): return False return not t + def _remove_dups_flatten(parameters): """Internal helper for Union creation and substitution. @@ -630,6 +640,7 @@ def stop() -> NoReturn: """ raise TypeError(f"{self} is not subscriptable") + # This is semantically identical to NoReturn, but it is implemented # separately so that type checkers can distinguish between the two # if they want. @@ -727,6 +738,7 @@ class Starship: item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True) return _GenericAlias(self, (item,)) + @_SpecialForm def Final(self, parameters): """Special typing construct to indicate final names to type checkers. @@ -749,6 +761,7 @@ class FastConnector(Connection): item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True) return _GenericAlias(self, (item,)) + @_SpecialForm def Union(self, parameters): """Union type; Union[X, Y] means either X or Y. @@ -793,6 +806,7 @@ def Union(self, parameters): return _UnionGenericAlias(self, parameters, name="Optional") return _UnionGenericAlias(self, parameters) + def _make_union(left, right): """Used from the C implementation of TypeVar. @@ -802,12 +816,14 @@ def _make_union(left, right): """ return Union[left, right] + @_SpecialForm def Optional(self, parameters): """Optional[X] is equivalent to Union[X, None].""" arg = _type_check(parameters, f"{self} requires a single type.") return Union[arg, type(None)] + @_TypedCacheSpecialForm @_tp_cache(typed=True) def Literal(self, *parameters): @@ -1284,6 +1300,7 @@ def _generic_init_subclass(cls, *args, **kwargs): def _is_dunder(attr): return attr.startswith('__') and attr.endswith('__') + class _BaseGenericAlias(_Final, _root=True): """The central part of the internal API. @@ -1952,6 +1969,7 @@ def _caller(depth=1, default='__main__'): pass return None + def _allow_reckless_class_checks(depth=2): """Allow instance and class checks for special stdlib modules. @@ -2192,6 +2210,7 @@ class _AnnotatedAlias(_NotIterable, _GenericAlias, _root=True): The metadata itself is stored in a '__metadata__' attribute as a tuple. """ + __match_args__ = ("__origin__", "__metadata__") def __init__(self, origin, metadata): if isinstance(origin, _AnnotatedAlias): @@ -2470,32 +2489,33 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} -def _strip_annotations(t): +def _strip_annotations(typ): """Strip the annotations from a given type.""" - if isinstance(t, _AnnotatedAlias): - return _strip_annotations(t.__origin__) - if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired, ReadOnly): - return _strip_annotations(t.__args__[0]) - if isinstance(t, _GenericAlias): - stripped_args = tuple(_strip_annotations(a) for a in t.__args__) - if stripped_args == t.__args__: - return t - return t.copy_with(stripped_args) - if isinstance(t, GenericAlias): - stripped_args = tuple(_strip_annotations(a) for a in t.__args__) - if stripped_args == t.__args__: - return t - return GenericAlias(t.__origin__, stripped_args) - if isinstance(t, types.UnionType): - stripped_args = tuple(_strip_annotations(a) for a in t.__args__) - if stripped_args == t.__args__: - return t - return functools.reduce(operator.or_, stripped_args) - - return t - - -def get_origin(tp): + match typ: + case _AnnotatedAlias(origin): + return _strip_annotations(origin) + case object(__origin__=origin) if origin in {Required, NotRequired, ReadOnly}: + return _strip_annotations(typ.__args__[0]) + case _GenericAlias(__args__=args): + stripped_args = tuple(_strip_annotations(arg) for arg in args) + if stripped_args == args: + return typ + return typ.copy_with(stripped_args) + case GenericAlias(__args__=args): + stripped_args = tuple(_strip_annotations(arg) for arg in args) + if stripped_args == args: + return typ + return GenericAlias(typ.__origin__, stripped_args) + case types.UnionType(__args__=args): + stripped_args = tuple(_strip_annotations(arg) for arg in args) + if stripped_args == args: + return typ + return functools.reduce(operator.or_, stripped_args) + case _: + return typ + + +def get_origin(typ): """Get the unsubscripted version of a type. This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar, @@ -2513,19 +2533,20 @@ def get_origin(tp): >>> assert get_origin(List[Tuple[T, T]][int]) is list >>> assert get_origin(P.args) is P """ - if isinstance(tp, _AnnotatedAlias): - return Annotated - if isinstance(tp, (_BaseGenericAlias, GenericAlias, - ParamSpecArgs, ParamSpecKwargs)): - return tp.__origin__ - if tp is Generic: - return Generic - if isinstance(tp, types.UnionType): - return types.UnionType - return None - - -def get_args(tp): + match typ: + case _AnnotatedAlias(): + return Annotated + case _BaseGenericAlias() | GenericAlias() | ParamSpecArgs() | ParamSpecKwargs(): + return typ.__origin__ + case types.UnionType(): + return types.UnionType + case typ if typ is Generic: + return Generic + case _: + return None + + +def get_args(typ): """Get type arguments with all substitutions performed. For unions, basic simplifications used by Union constructor are performed. @@ -2539,16 +2560,18 @@ def get_args(tp): >>> assert get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) >>> assert get_args(Callable[[], T][int]) == ([], int) """ - if isinstance(tp, _AnnotatedAlias): - return (tp.__origin__,) + tp.__metadata__ - if isinstance(tp, (_GenericAlias, GenericAlias)): - res = tp.__args__ - if _should_unflatten_callable_args(tp, res): - res = (list(res[:-1]), res[-1]) - return res - if isinstance(tp, types.UnionType): - return tp.__args__ - return () + match typ: + case _AnnotatedAlias(origin, metadata): + return (origin,) + metadata + case _GenericAlias() | GenericAlias(): + res = typ.__args__ + if _should_unflatten_callable_args(typ, res): + res = (list(res[:-1]), res[-1]) + return res + case types.UnionType(__args__=args): + return args + case _: + return () def is_typeddict(tp): @@ -2618,14 +2641,19 @@ def no_type_check(arg): # If classes / methods are nested in multiple layers, # we will modify them when processing their direct holders. continue + # Instance, class, and static methods: - if isinstance(obj, types.FunctionType): - obj.__no_type_check__ = True - if isinstance(obj, types.MethodType): - obj.__func__.__no_type_check__ = True - # Nested types: - if isinstance(obj, type): - no_type_check(obj) + match obj: + case types.FunctionType(): + obj.__no_type_check__ = True + case types.MethodType(): + obj.__func__.__no_type_check__ = True + case type(): + # Nested types: + no_type_check(obj) + case _: + pass + try: arg.__no_type_check__ = True except TypeError: # built-in classes @@ -3074,12 +3102,15 @@ class Employee(NamedTuple): nt.__orig_bases__ = (NamedTuple,) return nt + _NamedTuple = type.__new__(NamedTupleMeta, 'NamedTuple', (), {}) + def _namedtuple_mro_entries(bases): assert NamedTuple in bases return (_NamedTuple,) + NamedTuple.__mro_entries__ = _namedtuple_mro_entries @@ -3554,7 +3585,7 @@ def encoding(self) -> str: @property @abstractmethod - def errors(self) -> Optional[str]: + def errors(self) -> str | None: pass @property @@ -3764,23 +3795,32 @@ def __getattr__(attr): Soft-deprecated objects which are costly to create are only created on-demand here. """ - if attr in {"Pattern", "Match"}: - import re - obj = _alias(getattr(re, attr), 1) - elif attr in {"ContextManager", "AsyncContextManager"}: - import contextlib - obj = _alias(getattr(contextlib, f"Abstract{attr}"), 2, name=attr, defaults=(bool | None,)) - elif attr == "_collect_parameters": - import warnings + match attr: + case "Pattern" | "Match": + import re + obj = _alias(getattr(re, attr), 1) + + case "ContextManager" | "AsyncContextManager": + import contextlib + obj = _alias( + getattr(contextlib, f"Abstract{attr}"), + 2, + name=attr, + defaults=(bool | None,) + ) + + case "_collect_parameters": + import warnings + depr_message = ( + "The private _collect_parameters function is deprecated and will be" + " removed in a future version of Python. Any use of private functions" + " is discouraged and may break in the future." + ) + warnings.warn(depr_message, category=DeprecationWarning, stacklevel=2) + obj = _collect_type_parameters + + case _: + raise AttributeError(f"module {__name__!r} has no attribute {attr!r}") - depr_message = ( - "The private _collect_parameters function is deprecated and will be" - " removed in a future version of Python. Any use of private functions" - " is discouraged and may break in the future." - ) - warnings.warn(depr_message, category=DeprecationWarning, stacklevel=2) - obj = _collect_type_parameters - else: - raise AttributeError(f"module {__name__!r} has no attribute {attr!r}") globals()[attr] = obj return obj