Skip to content

Commit

Permalink
feat: Add context.resetting() (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Aug 26, 2023
1 parent fc73e3f commit b0744bd
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/62.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `aiotools.context.resetting()` as a sync/async context manager to auto-reset the given context variable
56 changes: 49 additions & 7 deletions src/aiotools/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,19 @@

import asyncio
import contextlib
from typing import Iterable, List, Optional
from contextvars import ContextVar
from typing import (
Generic,
Iterable,
List,
Optional,
TypeVar,
)

from .types import AsyncClosable

__all__ = [
"resetting",
"AsyncContextManager",
"async_ctx_manager",
"actxmgr",
Expand All @@ -24,29 +34,61 @@
]


T = TypeVar("T")
T_AsyncClosable = TypeVar("T_AsyncClosable", bound=AsyncClosable)

AbstractAsyncContextManager = contextlib.AbstractAsyncContextManager
AsyncContextManager = contextlib._AsyncGeneratorContextManager
AsyncExitStack = contextlib.AsyncExitStack
async_ctx_manager = contextlib.asynccontextmanager
aclosing = contextlib.aclosing


class closing_async:
class resetting(Generic[T]):
"""
An extra context manager to auto-reset the given context variable.
It supports both the standard contextmanager protocol and the
async-contextmanager protocol.
.. versionadded:: 1.8.0
"""
An analogy to :func:`contextlib.closing` for objects with ``close()``
methods as async functions.

def __init__(self, ctxvar: ContextVar[T], value: T) -> None:
self._ctxvar = ctxvar
self._value = value

def __enter__(self) -> None:
self._token = self._ctxvar.set(self._value)

async def __aenter__(self) -> None:
self._token = self._ctxvar.set(self._value)

def __exit__(self, *exc_info) -> Optional[bool]:
self._ctxvar.reset(self._token)
return None

async def __aexit__(self, *exc_info) -> Optional[bool]:
self._ctxvar.reset(self._token)
return None


class closing_async(Generic[T_AsyncClosable]):
"""
An analogy to :func:`contextlib.closing` for objects defining the ``close()``
method as an async function.
.. versionadded:: 1.5.6
"""

def __init__(self, thing):
def __init__(self, thing: T_AsyncClosable) -> None:
self.thing = thing

async def __aenter__(self):
async def __aenter__(self) -> T_AsyncClosable:
return self.thing

async def __aexit__(self, *args):
async def __aexit__(self, *exc_info) -> Optional[bool]:
await self.thing.close()
return None


class AsyncContextGroup:
Expand Down
9 changes: 9 additions & 0 deletions src/aiotools/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,21 @@
Awaitable,
Coroutine,
Generator,
Protocol,
TypeAlias,
TypeVar,
runtime_checkable,
)

_T = TypeVar("_T")


@runtime_checkable
class AsyncClosable(Protocol):
async def close(self) -> None:
...


# taken from the typeshed
if sys.version_info >= (3, 12):
AwaitableLike: TypeAlias = Awaitable[_T] # noqa: Y047
Expand Down
47 changes: 46 additions & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,56 @@
import asyncio
import sys
import warnings
from contextlib import suppress
from contextvars import ContextVar

import pytest

import aiotools
from aiotools.context import AbstractAsyncContextManager
from aiotools.context import AbstractAsyncContextManager, resetting

my_variable: ContextVar[int] = ContextVar("my_variable")


def test_resetting_ctxvar():
with pytest.raises(LookupError):
my_variable.get()
with resetting(my_variable, 1):
assert my_variable.get() == 1
with resetting(my_variable, 2):
assert my_variable.get() == 2
assert my_variable.get() == 1
with pytest.raises(LookupError):
my_variable.get()

# should behave the same way even when an exception occurs
with suppress(RuntimeError):
with resetting(my_variable, 10):
assert my_variable.get() == 10
raise RuntimeError("oops")
with pytest.raises(LookupError):
my_variable.get()


@pytest.mark.asyncio
async def test_resetting_ctxvar_async():
with pytest.raises(LookupError):
my_variable.get()
async with resetting(my_variable, 1):
assert my_variable.get() == 1
async with resetting(my_variable, 2):
assert my_variable.get() == 2
assert my_variable.get() == 1
with pytest.raises(LookupError):
my_variable.get()

# should behave the same way even when an exception occurs
with suppress(RuntimeError):
async with resetting(my_variable, 10):
assert my_variable.get() == 10
raise RuntimeError("oops")
with pytest.raises(LookupError):
my_variable.get()


def test_actxmgr_types():
Expand Down

0 comments on commit b0744bd

Please sign in to comment.