Skip to content

Commit

Permalink
fix: correct subtype generation for 3.10+ types.UnionTypes
Browse files Browse the repository at this point in the history
  • Loading branch information
autumnjolitz committed Jul 30, 2024
1 parent 10095e7 commit 50a7439
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 32 deletions.
3 changes: 2 additions & 1 deletion instruct/subtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import itertools
from copy import copy
from typing import Type, Any, Callable, TypeVar, Iterable, Mapping, Union, cast, overload
from typing_extensions import get_origin, get_args
from typing_extensions import get_args

from .typedef import is_typing_definition, parse_typedef, ismetasubclass

Expand All @@ -16,6 +16,7 @@
TypeHint,
Atomic,
isabstractcollectiontype,
get_origin,
)

T = TypeVar("T")
Expand Down
40 changes: 9 additions & 31 deletions instruct/typedef.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
TYPE_CHECKING,
)
from typing_extensions import (
get_origin as _get_origin,
get_original_bases,
is_protocol,
get_protocol_members,
Expand All @@ -52,6 +51,8 @@
TypeHint,
CustomTypeCheck,
Never,
get_origin,
copy_with,
)
from .utils import flatten_restrict as flatten
from .exceptions import RangeError, TypeError as InstructTypeError
Expand All @@ -62,7 +63,7 @@

_has_typealiastype: bool = False

if sys.version_info >= (3, 12):
if sys.version_info[:2] >= (3, 12):
from typing import TypeAliasType

_has_typealiastype = True
Expand All @@ -76,31 +77,11 @@ def __new__(cls):
raise NotImplementedError


if sys.version_info >= (3, 11):
if sys.version_info[:2] >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack

if sys.version_info >= (3, 10):
from typing import ParamSpec
from types import UnionType

UnionTypes = (Union, UnionType)

# patch get_origin to always return a Union over a 'a | b'
def get_origin(cls):
t = _get_origin(cls)
if isinstance(t, type) and issubclass(t, UnionType):
return Union
return t

else:
from typing_extensions import ParamSpec

UnionTypes = (Union,)
get_origin = _get_origin


if typing.TYPE_CHECKING:
from weakref import WeakKeyDictionary as _WeakKeyDictionary

Expand All @@ -110,10 +91,7 @@ class WeakKeyDictionary(_WeakKeyDictionary, Generic[T, U]):
else:
from weakref import WeakKeyDictionary


def is_union_typedef(t) -> bool:
return _get_origin(t) in UnionTypes

Unpack

_abstract_custom_types: WeakKeyDictionary[
CustomTypeCheck, Tuple[Callable, Callable]
Expand Down Expand Up @@ -471,7 +449,7 @@ def find_class_in_definition(
test_func = lambda child: isinstance(child, TypeVar)

if is_typing_definition(type_hints):
type_cls: Type = cast(Type, type_hints)
type_cls: TypeHint = type_hints
origin_cls = get_origin(type_cls)
args = get_args(type_cls)
if origin_cls is Annotated:
Expand All @@ -492,7 +470,7 @@ def find_class_in_definition(
args = (*args[:index], replacement, *args[index + 1 :])
# args = args[:index] + (replacement,) + args[index + 1 :]
if args != get_args(type_cls):
type_cls = type_cls.copy_with(args)
type_cls = copy_with(type_cls, args)
type_cls_copied = True

elif isinstance(origin_cls, type) and (
Expand All @@ -510,7 +488,7 @@ def find_class_in_definition(
if replacement is not None:
args = (key_type, replacement)
if args != get_args(type_cls):
type_cls = type_cls.copy_with(args)
type_cls = copy_with(type_cls, args)
type_cls_copied = True
else:
for index, child in enumerate(args):
Expand All @@ -524,7 +502,7 @@ def find_class_in_definition(
args = (*args[:index], replacement, *args[index + 1 :])
# args = args[:index] + (replacement,) + args[index + 1 :]
if args != get_args(type_cls):
type_cls = type_cls.copy_with(args)
type_cls = copy_with(type_cls, args)
type_cls_copied = True
elif test_func(type_cls):
replacement = yield type_cls
Expand Down
39 changes: 39 additions & 0 deletions instruct/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from contextlib import suppress
import sys
from collections.abc import Collection as AbstractCollection
from typing import (
Collection,
Expand All @@ -14,6 +15,8 @@
List,
)

from typing_extensions import get_origin as _get_origin

from .compat import *
from .types import BaseAtomic

Expand Down Expand Up @@ -210,3 +213,39 @@ class ExceptionHasMetadata(Protocol):

def exception_is_jsonable(e: Exception) -> TypeGuard[Union[HasJSONMagicMethod, HasToJSON]]:
return callable(getattr(e, "__json__", None)) or callable(getattr(e, "to_json", None))


class ICopyWithable(Protocol[T_co]):
def copy_with(self: T_co, args) -> T_co:
...


def is_copywithable(t: Union[Type[Any], TypeHint]) -> TypeGuard[ICopyWithable[TypeHint]]:
return callable(getattr(t, "copy_with", None))


if sys.version_info[:2] >= (3, 10):
from types import UnionType

UnionTypes = (Union, UnionType)

# patch get_origin to always return a Union over a 'a | b'
def get_origin(cls): # type:ignore[no-redef]
t = _get_origin(cls)
if isinstance(t, type) and issubclass(t, UnionType):
return Union
return t

def copy_with(hint: TypeHint, args) -> TypeHint:
if isinstance(hint, UnionType):
return Union[args]
if is_copywithable(hint):
return hint.copy_with(args)
raise NotImplementedError(f"Unable to copy with new type args on {hint!r} ({type(hint)!r})")

else:
UnionTypes = (Union,)
get_origin = _get_origin

def copy_with(hint, args):
return hint.copy_with(args)
8 changes: 8 additions & 0 deletions tests/test_subtype_310.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from instruct.subtype import (
wrapper_for_type,
)
from instruct import AtomicMeta


def test_union():
wrapper_for_type(int | str, {}, AtomicMeta)

0 comments on commit 50a7439

Please sign in to comment.