From 2117a8d0ca154c86ceedff2a546b5942c56b0301 Mon Sep 17 00:00:00 2001 From: Autumn Date: Tue, 30 Jul 2024 16:04:24 -0700 Subject: [PATCH] fix: adjust ``copy_with`` to attempt to use ``__class_getitem__`` as the fallback --- instruct/typing.py | 4 ++++ tasks.py | 13 ++++++++++++- tests/test_atomic_310.py | 17 +++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 tests/test_atomic_310.py diff --git a/instruct/typing.py b/instruct/typing.py index d4a7a5e..0d462f0 100644 --- a/instruct/typing.py +++ b/instruct/typing.py @@ -241,6 +241,10 @@ def copy_with(hint: TypeHint, args) -> TypeHint: return Union[args] if is_copywithable(hint): return hint.copy_with(args) + type_cls = get_origin(hint) + with suppress(AttributeError): + if type_cls is not None: + return type_cls[args] raise NotImplementedError(f"Unable to copy with new type args on {hint!r} ({type(hint)!r})") else: diff --git a/tasks.py b/tasks.py index 529e4fa..eceb945 100644 --- a/tasks.py +++ b/tasks.py @@ -288,13 +288,24 @@ def black(context: Context, check: bool = False): @task -def test(context: Context, *, verbose: bool = False, fail_fast: bool = False): +def test( + context: Context, + *, + test_files: Optional[Union[str, List[str], Tuple[str, ...]]] = None, + verbose: bool = False, + fail_fast: bool = False, +): python_bin = _.python_path(str, silent=True) extra = "" if verbose: extra = f"{extra} -svvv" if fail_fast: extra = f"{extra} -x" + if test_files: + if isinstance(test_files, str): + test_files = tuple(x.strip() for x in test_files.split(",")) + f = " ".join(test_files) + extra = f"{extra} {f}" context.run(f"{python_bin} -m coverage run -m pytest {extra}") diff --git a/tests/test_atomic_310.py b/tests/test_atomic_310.py new file mode 100644 index 0000000..3f8dfd6 --- /dev/null +++ b/tests/test_atomic_310.py @@ -0,0 +1,17 @@ +from instruct import Base, public_class + + +def test_subtraction(): + class Foo(Base): + a: tuple[int, ...] + b: str | int + c: str + + class FooGroup(Base): + foos: list[Foo] + + cls = FooGroup - {"foos": {"a", "b"}} + + ModifiedFoo = public_class(cls, "foos", preserve_subtraction=True) + + assert tuple(ModifiedFoo) == ("c",)