From 7c21c6a3f82ee409581ea2bd2db34bd4055be508 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Fri, 31 May 2024 17:56:45 -0700 Subject: [PATCH] Fix some bugs --- Lib/annotations.py | 2 ++ Lib/test/test_annotations.py | 7 +++++++ Lib/typing.py | 8 ++++---- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/Lib/annotations.py b/Lib/annotations.py index 72d9b6d1b7b52f..34ba99ffbdbf60 100644 --- a/Lib/annotations.py +++ b/Lib/annotations.py @@ -369,6 +369,8 @@ def get_annotations( obj_globals = getattr(obj, "__globals__", None) obj_locals = None unwrap = obj + elif (ann := getattr(obj, "__annotations__", None)) is not None: + obj_globals = obj_locals = unwrap = None else: raise TypeError(f"{obj!r} is not a module, class, or callable.") diff --git a/Lib/test/test_annotations.py b/Lib/test/test_annotations.py index f366afbfc8135a..c920720d4a3469 100644 --- a/Lib/test/test_annotations.py +++ b/Lib/test/test_annotations.py @@ -69,3 +69,10 @@ class TestGetAnnotations(unittest.TestCase): def test_builtin_type(self): self.assertEqual(annotations.get_annotations(int), {}) self.assertEqual(annotations.get_annotations(object), {}) + + def test_custom_object_with_annotations(self): + class C: + def __init__(self, x: int, y: str): + self.__annotations__ = {"x": int, "y": str} + + self.assertEqual(annotations.get_annotations(C()), {"x": int, "y": str}) diff --git a/Lib/typing.py b/Lib/typing.py index 291161457d3fad..654648464b2ce3 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -3098,7 +3098,7 @@ def __new__(cls, name, bases, ns, total=True): if not hasattr(tp_dict, '__orig_bases__'): tp_dict.__orig_bases__ = bases - annotations = {} + new_annotations = {} if "__annotations__" in ns: own_annotations = ns["__annotations__"] elif "__annotate__" in ns: @@ -3121,7 +3121,7 @@ def __new__(cls, name, bases, ns, total=True): # keys have Required/NotRequired/ReadOnly qualifiers, and create # a new __annotate__ function for the resulting TypedDict that # combines the annotations from this class and its parents. - annotations.update(base.__annotations__) + new_annotations.update(base.__annotations__) base_required = base.__dict__.get('__required_keys__', set()) required_keys |= base_required @@ -3134,7 +3134,7 @@ def __new__(cls, name, bases, ns, total=True): readonly_keys.update(base.__dict__.get('__readonly_keys__', ())) mutable_keys.update(base.__dict__.get('__mutable_keys__', ())) - annotations.update(own_annotations) + new_annotations.update(own_annotations) for annotation_key, annotation_type in own_annotations.items(): qualifiers = set(_get_typeddict_qualifiers(annotation_type)) if Required in qualifiers: @@ -3166,7 +3166,7 @@ def __new__(cls, name, bases, ns, total=True): f"Required keys overlap with optional keys in {name}:" f" {required_keys=}, {optional_keys=}" ) - tp_dict.__annotations__ = annotations + tp_dict.__annotations__ = new_annotations tp_dict.__required_keys__ = frozenset(required_keys) tp_dict.__optional_keys__ = frozenset(optional_keys) tp_dict.__readonly_keys__ = frozenset(readonly_keys)