Skip to content

Commit

Permalink
Use pattern-matching for some typing.py internals
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed May 11, 2024
1 parent b88889e commit ef5fa7d
Showing 1 changed file with 129 additions and 94 deletions.
223 changes: 129 additions & 94 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,13 @@

def _type_convert(arg, module=None, *, allow_special_forms=False):
"""For converting None to type(None), and strings to ForwardRef."""
if arg is None:
return type(None)
if isinstance(arg, str):
return ForwardRef(arg, module=module, is_class=allow_special_forms)
return arg
match arg:
case None:
return types.NoneType
case str():
return ForwardRef(arg, module=module, is_class=allow_special_forms)
case _:
return arg


def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=False):
Expand Down Expand Up @@ -242,18 +244,20 @@ def _type_repr(obj):
# When changing this function, don't forget about
# `_collections_abc._type_repr`, which does the same thing
# and must be consistent with this one.
if isinstance(obj, type):
if obj.__module__ == 'builtins':
match obj:
case type(__module__="builtins"):
return obj.__qualname__
return f'{obj.__module__}.{obj.__qualname__}'
if obj is ...:
return '...'
if isinstance(obj, types.FunctionType):
return obj.__name__
if isinstance(obj, tuple):
# Special case for `repr` of types with `ParamSpec`:
return '[' + ', '.join(_type_repr(t) for t in obj) + ']'
return repr(obj)
case type():
return f'{obj.__module__}.{obj.__qualname__}'
case types.EllipsisType():
return '...'
case types.FunctionType():
return obj.__name__
case tuple():
# Special case for `repr` of types with `ParamSpec`:
return '[' + ', '.join(_type_repr(elem) for elem in obj) + ']'
case _:
return repr(obj)


def _collect_type_parameters(args, *, enforce_default_ordering: bool = True):
Expand Down Expand Up @@ -340,13 +344,16 @@ def _check_generic_specialization(cls, arguments):
def _unpack_args(*args):
newargs = []
for arg in args:
subargs = getattr(arg, '__typing_unpacked_tuple_args__', None)
if subargs is not None and not (subargs and subargs[-1] is ...):
newargs.extend(subargs)
else:
newargs.append(arg)
match getattr(arg, '__typing_unpacked_tuple_args__', None):
case str() | [*_, types.EllipsisType()]:
newargs.append(arg)
case [*subargs]:
newargs.extend(subargs)
case _:
newargs.append(arg)
return newargs


def _deduplicate(params, *, unhashable_fallback=False):
# Weed out strict duplicates, preserving the first of each occurrence.
try:
Expand All @@ -357,13 +364,15 @@ def _deduplicate(params, *, unhashable_fallback=False):
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
return _deduplicate_unhashable(params)


def _deduplicate_unhashable(unhashable_params):
new_unhashable = []
for t in unhashable_params:
if t not in new_unhashable:
new_unhashable.append(t)
return new_unhashable


def _compare_args_orderless(first_args, second_args):
first_unhashable = _deduplicate_unhashable(first_args)
second_unhashable = _deduplicate_unhashable(second_args)
Expand All @@ -375,6 +384,7 @@ def _compare_args_orderless(first_args, second_args):
return False
return not t


def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
Expand Down Expand Up @@ -630,6 +640,7 @@ def stop() -> NoReturn:
"""
raise TypeError(f"{self} is not subscriptable")


# This is semantically identical to NoReturn, but it is implemented
# separately so that type checkers can distinguish between the two
# if they want.
Expand Down Expand Up @@ -727,6 +738,7 @@ class Starship:
item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True)
return _GenericAlias(self, (item,))


@_SpecialForm
def Final(self, parameters):
"""Special typing construct to indicate final names to type checkers.
Expand All @@ -749,6 +761,7 @@ class FastConnector(Connection):
item = _type_check(parameters, f'{self} accepts only single type.', allow_special_forms=True)
return _GenericAlias(self, (item,))


@_SpecialForm
def Union(self, parameters):
"""Union type; Union[X, Y] means either X or Y.
Expand Down Expand Up @@ -793,6 +806,7 @@ def Union(self, parameters):
return _UnionGenericAlias(self, parameters, name="Optional")
return _UnionGenericAlias(self, parameters)


def _make_union(left, right):
"""Used from the C implementation of TypeVar.
Expand All @@ -802,12 +816,14 @@ def _make_union(left, right):
"""
return Union[left, right]


@_SpecialForm
def Optional(self, parameters):
"""Optional[X] is equivalent to Union[X, None]."""
arg = _type_check(parameters, f"{self} requires a single type.")
return Union[arg, type(None)]


@_TypedCacheSpecialForm
@_tp_cache(typed=True)
def Literal(self, *parameters):
Expand Down Expand Up @@ -1284,6 +1300,7 @@ def _generic_init_subclass(cls, *args, **kwargs):
def _is_dunder(attr):
return attr.startswith('__') and attr.endswith('__')


class _BaseGenericAlias(_Final, _root=True):
"""The central part of the internal API.
Expand Down Expand Up @@ -1952,6 +1969,7 @@ def _caller(depth=1, default='__main__'):
pass
return None


def _allow_reckless_class_checks(depth=2):
"""Allow instance and class checks for special stdlib modules.
Expand Down Expand Up @@ -2192,6 +2210,7 @@ class _AnnotatedAlias(_NotIterable, _GenericAlias, _root=True):
The metadata itself is stored in a '__metadata__' attribute as a tuple.
"""
__match_args__ = ("__origin__", "__metadata__")

def __init__(self, origin, metadata):
if isinstance(origin, _AnnotatedAlias):
Expand Down Expand Up @@ -2470,32 +2489,33 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()}


def _strip_annotations(t):
def _strip_annotations(typ):
"""Strip the annotations from a given type."""
if isinstance(t, _AnnotatedAlias):
return _strip_annotations(t.__origin__)
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired, ReadOnly):
return _strip_annotations(t.__args__[0])
if isinstance(t, _GenericAlias):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
return t.copy_with(stripped_args)
if isinstance(t, GenericAlias):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
return GenericAlias(t.__origin__, stripped_args)
if isinstance(t, types.UnionType):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
return functools.reduce(operator.or_, stripped_args)

return t


def get_origin(tp):
match typ:
case _AnnotatedAlias(origin):
return _strip_annotations(origin)
case object(__origin__=origin) if origin in {Required, NotRequired, ReadOnly}:
return _strip_annotations(typ.__args__[0])
case _GenericAlias(__args__=args):
stripped_args = tuple(_strip_annotations(arg) for arg in args)
if stripped_args == args:
return typ
return typ.copy_with(stripped_args)
case GenericAlias(__args__=args):
stripped_args = tuple(_strip_annotations(arg) for arg in args)
if stripped_args == args:
return typ
return GenericAlias(typ.__origin__, stripped_args)
case types.UnionType(__args__=args):
stripped_args = tuple(_strip_annotations(arg) for arg in args)
if stripped_args == args:
return typ
return functools.reduce(operator.or_, stripped_args)
case _:
return typ


def get_origin(typ):
"""Get the unsubscripted version of a type.
This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar,
Expand All @@ -2513,19 +2533,20 @@ def get_origin(tp):
>>> assert get_origin(List[Tuple[T, T]][int]) is list
>>> assert get_origin(P.args) is P
"""
if isinstance(tp, _AnnotatedAlias):
return Annotated
if isinstance(tp, (_BaseGenericAlias, GenericAlias,
ParamSpecArgs, ParamSpecKwargs)):
return tp.__origin__
if tp is Generic:
return Generic
if isinstance(tp, types.UnionType):
return types.UnionType
return None


def get_args(tp):
match typ:
case _AnnotatedAlias():
return Annotated
case _BaseGenericAlias() | GenericAlias() | ParamSpecArgs() | ParamSpecKwargs():
return typ.__origin__
case types.UnionType():
return types.UnionType
case typ if typ is Generic:
return Generic
case _:
return None


def get_args(typ):
"""Get type arguments with all substitutions performed.
For unions, basic simplifications used by Union constructor are performed.
Expand All @@ -2539,16 +2560,18 @@ def get_args(tp):
>>> assert get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
>>> assert get_args(Callable[[], T][int]) == ([], int)
"""
if isinstance(tp, _AnnotatedAlias):
return (tp.__origin__,) + tp.__metadata__
if isinstance(tp, (_GenericAlias, GenericAlias)):
res = tp.__args__
if _should_unflatten_callable_args(tp, res):
res = (list(res[:-1]), res[-1])
return res
if isinstance(tp, types.UnionType):
return tp.__args__
return ()
match typ:
case _AnnotatedAlias(origin, metadata):
return (origin,) + metadata
case _GenericAlias() | GenericAlias():
res = typ.__args__
if _should_unflatten_callable_args(typ, res):
res = (list(res[:-1]), res[-1])
return res
case types.UnionType(__args__=args):
return args
case _:
return ()


def is_typeddict(tp):
Expand Down Expand Up @@ -2618,14 +2641,19 @@ def no_type_check(arg):
# If classes / methods are nested in multiple layers,
# we will modify them when processing their direct holders.
continue

# Instance, class, and static methods:
if isinstance(obj, types.FunctionType):
obj.__no_type_check__ = True
if isinstance(obj, types.MethodType):
obj.__func__.__no_type_check__ = True
# Nested types:
if isinstance(obj, type):
no_type_check(obj)
match obj:
case types.FunctionType():
obj.__no_type_check__ = True
case types.MethodType():
obj.__func__.__no_type_check__ = True
case type():
# Nested types:
no_type_check(obj)
case _:
pass

try:
arg.__no_type_check__ = True
except TypeError: # built-in classes
Expand Down Expand Up @@ -3074,12 +3102,15 @@ class Employee(NamedTuple):
nt.__orig_bases__ = (NamedTuple,)
return nt


_NamedTuple = type.__new__(NamedTupleMeta, 'NamedTuple', (), {})


def _namedtuple_mro_entries(bases):
assert NamedTuple in bases
return (_NamedTuple,)


NamedTuple.__mro_entries__ = _namedtuple_mro_entries


Expand Down Expand Up @@ -3554,7 +3585,7 @@ def encoding(self) -> str:

@property
@abstractmethod
def errors(self) -> Optional[str]:
def errors(self) -> str | None:
pass

@property
Expand Down Expand Up @@ -3764,23 +3795,27 @@ def __getattr__(attr):
Soft-deprecated objects which are costly to create
are only created on-demand here.
"""
if attr in {"Pattern", "Match"}:
import re
obj = _alias(getattr(re, attr), 1)
elif attr in {"ContextManager", "AsyncContextManager"}:
import contextlib
obj = _alias(getattr(contextlib, f"Abstract{attr}"), 2, name=attr, defaults=(bool | None,))
elif attr == "_collect_parameters":
import warnings
match attr:
case "Pattern" | "Match":
import re
obj = _alias(getattr(re, attr), 1)

case "ContextManager" | "AsyncContextManager":
import contextlib
obj = _alias(getattr(contextlib, f"Abstract{attr}"), 2, name=attr, defaults=(bool | None,))

case "_collect_parameters":
import warnings
depr_message = (
"The private _collect_parameters function is deprecated and will be"
" removed in a future version of Python. Any use of private functions"
" is discouraged and may break in the future."
)
warnings.warn(depr_message, category=DeprecationWarning, stacklevel=2)
obj = _collect_type_parameters

case _:
raise AttributeError(f"module {__name__!r} has no attribute {attr!r}")

depr_message = (
"The private _collect_parameters function is deprecated and will be"
" removed in a future version of Python. Any use of private functions"
" is discouraged and may break in the future."
)
warnings.warn(depr_message, category=DeprecationWarning, stacklevel=2)
obj = _collect_type_parameters
else:
raise AttributeError(f"module {__name__!r} has no attribute {attr!r}")
globals()[attr] = obj
return obj

0 comments on commit ef5fa7d

Please sign in to comment.