Skip to content

Commit

Permalink
fix: adjust copy_with to attempt to use __class_getitem__ as …
Browse files Browse the repository at this point in the history
…the fallback
  • Loading branch information
autumnjolitz committed Jul 30, 2024
1 parent 6f343b2 commit 2117a8d
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
4 changes: 4 additions & 0 deletions instruct/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ def copy_with(hint: TypeHint, args) -> TypeHint:
return Union[args]
if is_copywithable(hint):
return hint.copy_with(args)
type_cls = get_origin(hint)
with suppress(AttributeError):
if type_cls is not None:
return type_cls[args]
raise NotImplementedError(f"Unable to copy with new type args on {hint!r} ({type(hint)!r})")

else:
Expand Down
13 changes: 12 additions & 1 deletion tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,24 @@ def black(context: Context, check: bool = False):


@task
def test(context: Context, *, verbose: bool = False, fail_fast: bool = False):
def test(
context: Context,
*,
test_files: Optional[Union[str, List[str], Tuple[str, ...]]] = None,
verbose: bool = False,
fail_fast: bool = False,
):
python_bin = _.python_path(str, silent=True)
extra = ""
if verbose:
extra = f"{extra} -svvv"
if fail_fast:
extra = f"{extra} -x"
if test_files:
if isinstance(test_files, str):
test_files = tuple(x.strip() for x in test_files.split(","))
f = " ".join(test_files)
extra = f"{extra} {f}"
context.run(f"{python_bin} -m coverage run -m pytest {extra}")


Expand Down
17 changes: 17 additions & 0 deletions tests/test_atomic_310.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from instruct import Base, public_class


def test_subtraction():
class Foo(Base):
a: tuple[int, ...]
b: str | int
c: str

class FooGroup(Base):
foos: list[Foo]

cls = FooGroup - {"foos": {"a", "b"}}

ModifiedFoo = public_class(cls, "foos", preserve_subtraction=True)

assert tuple(ModifiedFoo) == ("c",)

0 comments on commit 2117a8d

Please sign in to comment.