Skip to content

Commit

Permalink
Apply Gemini's ideas for improving the types.
Browse files Browse the repository at this point in the history
Closes #2
  • Loading branch information
jaraco committed Jul 26, 2024
1 parent d892f53 commit e051242
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions jaraco/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
import subprocess
import sys
import tempfile
import types
import urllib.request
from typing import Iterator
from typing import Iterator, TypeVar, Union, Optional, Callable, Type, Tuple


if sys.version_info < (3, 12):
from backports import tarfile
else:
import tarfile

PathLike = Union[str, os.PathLike]

This comment has been minimized.

Copy link
@Avasam

Avasam Jul 29, 2024

Contributor

StrPath = Union[str, os.PathLike[str]] would be more accurate, but that would lead to an error in Python 3.8
However this alias already exists and is used everywhere in typeshed.

if TYPE_CHECKING:
    from _typeshed import StrPath

If you want a runtime alias that is complete and works on all supported Python version, here's how I did it in setuptools: https://github.com/pypa/setuptools/blob/5e1b3c414779317bc3e105d9bae82ce70c22dbf9/setuptools/_path.py#L9-L12

(There's also a str | os.PathLike leftover below)

T = TypeVar('T')


@contextlib.contextmanager
def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]:
def pushd(dir: PathLike) -> Iterator[PathLike]:
"""
>>> tmp_path = getfixture('tmp_path')
>>> with pushd(tmp_path):
Expand All @@ -37,8 +41,8 @@ def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]:

@contextlib.contextmanager
def tarball(
url, target_dir: str | os.PathLike | None = None
) -> Iterator[str | os.PathLike]:
url: str, target_dir: str | os.PathLike | None = None
) -> Iterator[PathLike]:
"""
Get a URL to a tarball, download, extract, yield, then clean up.
Expand Down Expand Up @@ -89,7 +93,11 @@ def strip_first_component(
return member


def _compose(*cmgrs):
CM = TypeVar('CM', bound=contextlib.AbstractContextManager)
"""Type var for context managers."""


def _compose(*cmgrs: Callable[..., CM]) -> Callable[..., CM]:
"""
Compose any number of dependent context managers into a single one.
Expand Down Expand Up @@ -126,7 +134,7 @@ def composed(*args, **kwargs):


@contextlib.contextmanager
def temp_dir(remover=shutil.rmtree):
def temp_dir(remover: Callable[[str], None] = shutil.rmtree) -> Iterator[str]:
"""
Create a temporary directory context. Pass a custom remover
to override the removal behavior.
Expand All @@ -145,7 +153,12 @@ def temp_dir(remover=shutil.rmtree):


@contextlib.contextmanager
def repo_context(url, branch: str | None = None, quiet: bool = True, dest_ctx=temp_dir):
def repo_context(
url,
branch: str | None = None,
quiet: bool = True,
dest_ctx: Callable[[], contextlib.AbstractContextManager[str]] = temp_dir,
):
"""
Check out the repo indicated by url.
Expand All @@ -167,7 +180,7 @@ def repo_context(url, branch: str | None = None, quiet: bool = True, dest_ctx=te
yield repo_dir


class ExceptionTrap:
class ExceptionTrap(contextlib.AbstractContextManager):
"""
A context manager that will catch certain exceptions and provide an
indication they occurred.
Expand Down Expand Up @@ -201,9 +214,13 @@ class ExceptionTrap:
False
"""

exc_info = None, None, None
exc_info: Tuple[
Optional[Type[BaseException]],
Optional[BaseException],
Optional[types.TracebackType],
] = (None, None, None) # Explicitly type the tuple

def __init__(self, exceptions=(Exception,)):
def __init__(self, exceptions: Tuple[Type[BaseException], ...] = (Exception,)):
self.exceptions = exceptions

def __enter__(self):
Expand Down Expand Up @@ -231,7 +248,9 @@ def __exit__(self, *exc_info):
def __bool__(self):
return bool(self.type)

def raises(self, func, *, _test=bool):
def raises(
self, func: Callable[..., T], *, _test: Callable[[ExceptionTrap], bool] = bool
):
"""
Wrap func and replace the result with the truth
value of the trap (True if an exception occurred).
Expand All @@ -258,7 +277,7 @@ def wrapper(*args, **kwargs):

return wrapper

def passes(self, func):
def passes(self, func: Callable[..., T]) -> Callable[..., bool]:
"""
Wrap func and replace the result with the truth
value of the trap (True if no exception).
Expand Down

0 comments on commit e051242

Please sign in to comment.