Skip to content

Commit

Permalink
[typedef] support type | type in 3.10+ and __init_subclass__
Browse files Browse the repository at this point in the history
  • Loading branch information
autumnjolitz committed Feb 29, 2024
1 parent f069f7e commit 78c1a85
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 13 deletions.
57 changes: 47 additions & 10 deletions instruct/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import inspect
import logging
import os
Expand Down Expand Up @@ -134,7 +135,7 @@ def public_class(
return public_atomic_classes[0]
return public_atomic_classes
else:
next_cls, = atomic_classes
(next_cls,) = atomic_classes
return public_class(next_cls, *rest, preserve_subtraction=preserve_subtraction)
cls = cls.__public_class__()
if preserve_subtraction and any((cls._skipped_fields, cls._modified_fields)):
Expand Down Expand Up @@ -204,7 +205,7 @@ def keys(
return cls._all_accessible_fields
return KeysView(tuple(cls._slots))
if len(property_path) == 1:
key, = property_path
(key,) = property_path
if key not in cls._nested_atomic_collection_keys:
return keys(cls._slots[key])
if len(cls._nested_atomic_collection_keys[key]) == 1:
Expand Down Expand Up @@ -471,7 +472,6 @@ def key_func(item: Type) -> int:
def make_class_cell():
return CellType(None)


else:

def make_class_cell() -> CellType:
Expand All @@ -489,7 +489,7 @@ def bar():
return bar

fake_function = closure_maker()
class_cell, = fake_function.__closure__
(class_cell,) = fake_function.__closure__
del fake_function
return class_cell

Expand Down Expand Up @@ -1029,9 +1029,10 @@ def apply_skip_keys(
current_coerce = None
else:
while hasattr(current_coerce_cast_function, "__union_subtypes__"):
current_coerce_types, current_coerce_cast_function = (
current_coerce_cast_function.__union_subtypes__
)
(
current_coerce_types,
current_coerce_cast_function,
) = current_coerce_cast_function.__union_subtypes__
current_coerce = (current_coerce_types, current_coerce_cast_function)
del current_coerce_types, current_coerce_cast_function

Expand Down Expand Up @@ -1119,6 +1120,16 @@ def is_defined_coerce(cls, key):
return None


def wrap_init_subclass(func):
@functools.wraps(func)
def __init_subclass__(cls, **kwargs):
if cls._is_data_class:
return
return func(cls, **kwargs)

return __init_subclass__


class Atomic(type):
__slots__ = ()
REGISTRY = ReadOnly(set())
Expand Down Expand Up @@ -1286,6 +1297,7 @@ def __new__(
**mixins,
):
if concrete_class:
attrs["_is_data_class"] = ReadOnly(True)
cls = super().__new__(klass, class_name, bases, attrs)
if not getattr(cls, "__hash__", None):
cls.__hash__ = object.__hash__
Expand Down Expand Up @@ -1387,10 +1399,26 @@ def __new__(
nested_atomic_collections: Dict[str, Atomic] = {}
# Mapping of public name -> custom type vector for `isinstance(...)` checks!
column_types: Dict[str, Union[Type, Tuple[Type, ...]]] = {}
base_class_has_subclass_init = False

for mixin_name in mixins:
for cls in bases:
if cls is object:
break
base_class_has_subclass_init = hasattr(cls, "__init_subclass__")
if base_class_has_subclass_init:
break

init_subclass_kwargs = {}

for mixin_name in tuple(mixins):
if mixins[mixin_name]:
mixin_cls = klass.MIXINS[mixin_name]
try:
mixin_cls = klass.MIXINS[mixin_name]
except KeyError:
if base_class_has_subclass_init:
init_subclass_kwargs[mixin_name] = mixins[mixin_name]
continue
raise ValueError(f"{mixin_name!r} is not a registered Mixin on Atomic!")
if isinstance(mixins[mixin_name], type):
mixin_cls = mixins[mixin_name]
bases = (mixin_cls,) + bases
Expand Down Expand Up @@ -1645,6 +1673,10 @@ def __new__(

ns_globals = {"NoneType": NoneType, "Flags": Flags, "typing": typing}
ns_globals[class_name] = ReadOnly(None)
init_subclass = None

if "__init_subclass__" in support_cls_attrs:
init_subclass = support_cls_attrs.pop("__init_subclass__")

if combined_columns:
exec(
Expand Down Expand Up @@ -1789,7 +1821,10 @@ def __new__(

support_cls_attrs["_data_class"] = support_cls_attrs[f"_{class_name}"] = dc = ReadOnly(None)
support_cls_attrs["_parent"] = parent_cell = ReadOnly(None)
support_cls = super().__new__(klass, class_name, bases, support_cls_attrs)
support_cls_attrs["_is_data_class"] = ReadOnly(False)
support_cls = super().__new__(
klass, class_name, bases, support_cls_attrs, **init_subclass_kwargs
)

for prop_name, value in support_cls_attrs.items():
if isinstance(value, property):
Expand Down Expand Up @@ -1823,6 +1858,8 @@ def __new__(
data_class.__qualname__ = f"{support_cls.__qualname__}.{data_class.__name__}"
parent_cell.value = support_cls
klass.REGISTRY.add(support_cls)
if init_subclass is not None:
support_cls.__init_subclass__ = classmethod(wrap_init_subclass(init_subclass))
return support_cls

def from_json(cls: Type[T], data: Dict[str, Any]) -> T:
Expand Down
19 changes: 16 additions & 3 deletions instruct/typedef.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations
import collections.abc
from functools import wraps
import types
import sys
from collections.abc import Mapping as AbstractMapping
from typing import Union, Any, AnyStr, List, Tuple, cast, Optional, Callable, Type

Expand All @@ -13,18 +15,27 @@
except ImportError:
from typing_extensions import Annotated

from typing_extensions import get_origin
from typing_extensions import get_origin as _get_origin
from typing_extensions import get_args

from .utils import flatten_restrict as flatten
from .typing import ICustomTypeCheck
from .constants import Range
from .exceptions import RangeError

if sys.version_info < (3, 10):
get_origin = _get_origin
else:

def get_origin(cls):
t = _get_origin(cls)
if isinstance(t, type) and issubclass(t, types.UnionType):
return Union[cls.__args__]
return t


def make_custom_typecheck(func) -> Type[ICustomTypeCheck]:
"""Create a custom type that will turn `isinstance(item, klass)` into `func(item)`
"""
"""Create a custom type that will turn `isinstance(item, klass)` into `func(item)`"""
typename = "WrappedType<{}>"

class WrappedType(type):
Expand Down Expand Up @@ -435,6 +446,8 @@ def is_typing_definition(item):
origin = get_origin(item)
if origin is not None:
return is_typing_definition(origin)
if isinstance(item, (types.UnionType)):
return True
return False


Expand Down
41 changes: 41 additions & 0 deletions tests/test_atomic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import pprint
import sys
from typing import Union, List, Tuple, Optional, Dict, Any, Type

try:
Expand Down Expand Up @@ -1489,3 +1490,43 @@ class Foo(SimpleBase):
pass

assert list(Foo()) == []


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.8 or higher")
def test_using_builtin_unions():
class TestUnion(SimpleBase):
field: str | int

TestUnion("foo")
TestUnion(1)
with pytest.raises(TypeError):
TestUnion(1.5)


def test_with_init_subclass():
Registry = {}

class Foo(SimpleBase):
def __init_subclass__(cls, swallow: str, **kwargs):
Registry[cls] = swallow
super().__init_subclass__()

f = Foo()

class Bar(Foo, swallow="Barn!"):
...

assert Bar in Registry
assert Registry[Bar] == "Barn!"
assert len(Registry) == 1

class BarBar(Bar, swallow="Farter"):
def __init_subclass__(cls, **kwargs):
return

assert len(Registry) == 2

class BreakChainBar(BarBar):
...

assert len(Registry) == 2

0 comments on commit 78c1a85

Please sign in to comment.