Skip to content

Commit

Permalink
Fix some bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra committed Jun 1, 2024
1 parent 44d890e commit 7c21c6a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
2 changes: 2 additions & 0 deletions Lib/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ def get_annotations(
obj_globals = getattr(obj, "__globals__", None)
obj_locals = None
unwrap = obj
elif (ann := getattr(obj, "__annotations__", None)) is not None:
obj_globals = obj_locals = unwrap = None
else:
raise TypeError(f"{obj!r} is not a module, class, or callable.")

Expand Down
7 changes: 7 additions & 0 deletions Lib/test/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,10 @@ class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self):
self.assertEqual(annotations.get_annotations(int), {})
self.assertEqual(annotations.get_annotations(object), {})

def test_custom_object_with_annotations(self):
class C:
def __init__(self, x: int, y: str):
self.__annotations__ = {"x": int, "y": str}

self.assertEqual(annotations.get_annotations(C()), {"x": int, "y": str})
8 changes: 4 additions & 4 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3098,7 +3098,7 @@ def __new__(cls, name, bases, ns, total=True):
if not hasattr(tp_dict, '__orig_bases__'):
tp_dict.__orig_bases__ = bases

annotations = {}
new_annotations = {}
if "__annotations__" in ns:
own_annotations = ns["__annotations__"]
elif "__annotate__" in ns:
Expand All @@ -3121,7 +3121,7 @@ def __new__(cls, name, bases, ns, total=True):
# keys have Required/NotRequired/ReadOnly qualifiers, and create
# a new __annotate__ function for the resulting TypedDict that
# combines the annotations from this class and its parents.
annotations.update(base.__annotations__)
new_annotations.update(base.__annotations__)

base_required = base.__dict__.get('__required_keys__', set())
required_keys |= base_required
Expand All @@ -3134,7 +3134,7 @@ def __new__(cls, name, bases, ns, total=True):
readonly_keys.update(base.__dict__.get('__readonly_keys__', ()))
mutable_keys.update(base.__dict__.get('__mutable_keys__', ()))

annotations.update(own_annotations)
new_annotations.update(own_annotations)
for annotation_key, annotation_type in own_annotations.items():
qualifiers = set(_get_typeddict_qualifiers(annotation_type))
if Required in qualifiers:
Expand Down Expand Up @@ -3166,7 +3166,7 @@ def __new__(cls, name, bases, ns, total=True):
f"Required keys overlap with optional keys in {name}:"
f" {required_keys=}, {optional_keys=}"
)
tp_dict.__annotations__ = annotations
tp_dict.__annotations__ = new_annotations
tp_dict.__required_keys__ = frozenset(required_keys)
tp_dict.__optional_keys__ = frozenset(optional_keys)
tp_dict.__readonly_keys__ = frozenset(readonly_keys)
Expand Down

0 comments on commit 7c21c6a

Please sign in to comment.