diff --git a/instruct/__init__.py b/instruct/__init__.py index fbacb40..f9d7f81 100644 --- a/instruct/__init__.py +++ b/instruct/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import inspect import logging import os @@ -134,7 +135,7 @@ def public_class( return public_atomic_classes[0] return public_atomic_classes else: - next_cls, = atomic_classes + (next_cls,) = atomic_classes return public_class(next_cls, *rest, preserve_subtraction=preserve_subtraction) cls = cls.__public_class__() if preserve_subtraction and any((cls._skipped_fields, cls._modified_fields)): @@ -204,7 +205,7 @@ def keys( return cls._all_accessible_fields return KeysView(tuple(cls._slots)) if len(property_path) == 1: - key, = property_path + (key,) = property_path if key not in cls._nested_atomic_collection_keys: return keys(cls._slots[key]) if len(cls._nested_atomic_collection_keys[key]) == 1: @@ -471,7 +472,6 @@ def key_func(item: Type) -> int: def make_class_cell(): return CellType(None) - else: def make_class_cell() -> CellType: @@ -489,7 +489,7 @@ def bar(): return bar fake_function = closure_maker() - class_cell, = fake_function.__closure__ + (class_cell,) = fake_function.__closure__ del fake_function return class_cell @@ -1029,9 +1029,10 @@ def apply_skip_keys( current_coerce = None else: while hasattr(current_coerce_cast_function, "__union_subtypes__"): - current_coerce_types, current_coerce_cast_function = ( - current_coerce_cast_function.__union_subtypes__ - ) + ( + current_coerce_types, + current_coerce_cast_function, + ) = current_coerce_cast_function.__union_subtypes__ current_coerce = (current_coerce_types, current_coerce_cast_function) del current_coerce_types, current_coerce_cast_function @@ -1119,6 +1120,16 @@ def is_defined_coerce(cls, key): return None +def wrap_init_subclass(func): + @functools.wraps(func) + def __init_subclass__(cls, **kwargs): + if cls._is_data_class: + return + return func(cls, **kwargs) + + return __init_subclass__ + + class Atomic(type): __slots__ = () REGISTRY = ReadOnly(set()) @@ -1286,6 +1297,7 @@ def __new__( **mixins, ): if concrete_class: + attrs["_is_data_class"] = ReadOnly(True) cls = super().__new__(klass, class_name, bases, attrs) if not getattr(cls, "__hash__", None): cls.__hash__ = object.__hash__ @@ -1387,10 +1399,26 @@ def __new__( nested_atomic_collections: Dict[str, Atomic] = {} # Mapping of public name -> custom type vector for `isinstance(...)` checks! column_types: Dict[str, Union[Type, Tuple[Type, ...]]] = {} + base_class_has_subclass_init = False - for mixin_name in mixins: + for cls in bases: + if cls is object: + break + base_class_has_subclass_init = hasattr(cls, "__init_subclass__") + if base_class_has_subclass_init: + break + + init_subclass_kwargs = {} + + for mixin_name in tuple(mixins): if mixins[mixin_name]: - mixin_cls = klass.MIXINS[mixin_name] + try: + mixin_cls = klass.MIXINS[mixin_name] + except KeyError: + if base_class_has_subclass_init: + init_subclass_kwargs[mixin_name] = mixins[mixin_name] + continue + raise ValueError(f"{mixin_name!r} is not a registered Mixin on Atomic!") if isinstance(mixins[mixin_name], type): mixin_cls = mixins[mixin_name] bases = (mixin_cls,) + bases @@ -1645,6 +1673,10 @@ def __new__( ns_globals = {"NoneType": NoneType, "Flags": Flags, "typing": typing} ns_globals[class_name] = ReadOnly(None) + init_subclass = None + + if "__init_subclass__" in support_cls_attrs: + init_subclass = support_cls_attrs.pop("__init_subclass__") if combined_columns: exec( @@ -1789,7 +1821,10 @@ def __new__( support_cls_attrs["_data_class"] = support_cls_attrs[f"_{class_name}"] = dc = ReadOnly(None) support_cls_attrs["_parent"] = parent_cell = ReadOnly(None) - support_cls = super().__new__(klass, class_name, bases, support_cls_attrs) + support_cls_attrs["_is_data_class"] = ReadOnly(False) + support_cls = super().__new__( + klass, class_name, bases, support_cls_attrs, **init_subclass_kwargs + ) for prop_name, value in support_cls_attrs.items(): if isinstance(value, property): @@ -1823,6 +1858,8 @@ def __new__( data_class.__qualname__ = f"{support_cls.__qualname__}.{data_class.__name__}" parent_cell.value = support_cls klass.REGISTRY.add(support_cls) + if init_subclass is not None: + support_cls.__init_subclass__ = classmethod(wrap_init_subclass(init_subclass)) return support_cls def from_json(cls: Type[T], data: Dict[str, Any]) -> T: diff --git a/instruct/typedef.py b/instruct/typedef.py index 1223599..7ceab60 100644 --- a/instruct/typedef.py +++ b/instruct/typedef.py @@ -1,6 +1,8 @@ from __future__ import annotations import collections.abc from functools import wraps +import types +import sys from collections.abc import Mapping as AbstractMapping from typing import Union, Any, AnyStr, List, Tuple, cast, Optional, Callable, Type @@ -13,7 +15,7 @@ except ImportError: from typing_extensions import Annotated -from typing_extensions import get_origin +from typing_extensions import get_origin as _get_origin from typing_extensions import get_args from .utils import flatten_restrict as flatten @@ -21,10 +23,19 @@ from .constants import Range from .exceptions import RangeError +if sys.version_info < (3, 10): + get_origin = _get_origin +else: + + def get_origin(cls): + t = _get_origin(cls) + if isinstance(t, type) and issubclass(t, types.UnionType): + return Union[cls.__args__] + return t + def make_custom_typecheck(func) -> Type[ICustomTypeCheck]: - """Create a custom type that will turn `isinstance(item, klass)` into `func(item)` - """ + """Create a custom type that will turn `isinstance(item, klass)` into `func(item)`""" typename = "WrappedType<{}>" class WrappedType(type): @@ -435,6 +446,8 @@ def is_typing_definition(item): origin = get_origin(item) if origin is not None: return is_typing_definition(origin) + if isinstance(item, (types.UnionType)): + return True return False diff --git a/tests/test_atomic.py b/tests/test_atomic.py index c2717f7..c4d3906 100644 --- a/tests/test_atomic.py +++ b/tests/test_atomic.py @@ -1,5 +1,6 @@ import json import pprint +import sys from typing import Union, List, Tuple, Optional, Dict, Any, Type try: @@ -1489,3 +1490,43 @@ class Foo(SimpleBase): pass assert list(Foo()) == [] + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.8 or higher") +def test_using_builtin_unions(): + class TestUnion(SimpleBase): + field: str | int + + TestUnion("foo") + TestUnion(1) + with pytest.raises(TypeError): + TestUnion(1.5) + + +def test_with_init_subclass(): + Registry = {} + + class Foo(SimpleBase): + def __init_subclass__(cls, swallow: str, **kwargs): + Registry[cls] = swallow + super().__init_subclass__() + + f = Foo() + + class Bar(Foo, swallow="Barn!"): + ... + + assert Bar in Registry + assert Registry[Bar] == "Barn!" + assert len(Registry) == 1 + + class BarBar(Bar, swallow="Farter"): + def __init_subclass__(cls, **kwargs): + return + + assert len(Registry) == 2 + + class BreakChainBar(BarBar): + ... + + assert len(Registry) == 2