Skip to content

Commit

Permalink
Merge branch 'master' into pr-osparc-connect-opentelemetry-to-missing…
Browse files Browse the repository at this point in the history
…-services
  • Loading branch information
GitHK authored Nov 14, 2024
2 parents 8c0f106 + 0781e63 commit b2e0b38
Show file tree
Hide file tree
Showing 11 changed files with 481 additions and 283 deletions.
36 changes: 30 additions & 6 deletions packages/service-library/src/servicelib/redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from collections.abc import Awaitable, Callable
from datetime import timedelta
from typing import Any
from typing import Any, ParamSpec, TypeVar

import arrow

Expand All @@ -12,10 +12,16 @@

_logger = logging.getLogger(__file__)

P = ParamSpec("P")
R = TypeVar("R")


def exclusive(
redis: RedisClientSDK, *, lock_key: str, lock_value: bytes | str | None = None
):
redis: RedisClientSDK | Callable[..., RedisClientSDK],
*,
lock_key: str | Callable[..., str],
lock_value: bytes | str | None = None,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
"""
Define a method to run exclusively across
processes by leveraging a Redis Lock.
Expand All @@ -24,12 +30,30 @@ def exclusive(
redis: the redis client SDK
lock_key: a string as the name of the lock (good practice: app_name:lock_name)
lock_value: some additional data that can be retrieved by another client
Raises:
- ValueError if used incorrectly
- CouldNotAcquireLockError if the lock could not be acquired
"""

def decorator(func):
if not lock_key:
msg = "lock_key cannot be empty string!"
raise ValueError(msg)

def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with redis.lock_context(lock_key=lock_key, lock_value=lock_value):
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
redis_lock_key = (
lock_key(*args, **kwargs) if callable(lock_key) else lock_key
)
assert isinstance(redis_lock_key, str) # nosec

redis_client = redis(*args, **kwargs) if callable(redis) else redis
assert isinstance(redis_client, RedisClientSDK) # nosec

async with redis_client.lock_context(
lock_key=redis_lock_key, lock_value=lock_value
):
return await func(*args, **kwargs)

return wrapper
Expand Down
127 changes: 104 additions & 23 deletions packages/service-library/tests/test_redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from itertools import chain
from typing import Awaitable
from unittest.mock import Mock

import arrow
Expand Down Expand Up @@ -32,39 +33,117 @@ async def _is_locked(redis_client_sdk: RedisClientSDK, lock_name: str) -> bool:

@pytest.fixture
def lock_name(faker: Faker) -> str:
return faker.uuid4() # type: ignore
return faker.pystr()


def _exclusive_sleeping_task(
redis_client_sdk: RedisClientSDK | Callable[..., RedisClientSDK],
lock_name: str | Callable[..., str],
sleep_duration: float,
) -> Callable[..., Awaitable[float]]:
@exclusive(redis_client_sdk, lock_key=lock_name)
async def _() -> float:
resolved_client = (
redis_client_sdk() if callable(redis_client_sdk) else redis_client_sdk
)
resolved_lock_name = lock_name() if callable(lock_name) else lock_name
assert await _is_locked(resolved_client, resolved_lock_name)
await asyncio.sleep(sleep_duration)
assert await _is_locked(resolved_client, resolved_lock_name)
return sleep_duration

return _


@pytest.fixture
def sleep_duration(faker: Faker) -> float:
return faker.pyfloat(positive=True, min_value=0.2, max_value=0.8)


async def _contained_client(
async def test_exclusive_decorator(
get_redis_client_sdk: Callable[
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
],
lock_name: str,
task_duration: float,
) -> None:
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk:
assert not await _is_locked(redis_client_sdk, lock_name)

@exclusive(redis_client_sdk, lock_key=lock_name)
async def _some_task() -> None:
assert await _is_locked(redis_client_sdk, lock_name)
await asyncio.sleep(task_duration)
assert await _is_locked(redis_client_sdk, lock_name)

await _some_task()
sleep_duration: float,
):

assert not await _is_locked(redis_client_sdk, lock_name)
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:
for _ in range(3):
assert (
await _exclusive_sleeping_task(
redis_client, lock_name, sleep_duration
)()
== sleep_duration
)


@pytest.mark.parametrize("task_duration", [0.1, 1, 2])
async def test_exclusive_sequentially(
async def test_exclusive_decorator_with_key_builder(
get_redis_client_sdk: Callable[
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
],
lock_name: str,
task_duration: float,
sleep_duration: float,
):
await _contained_client(get_redis_client_sdk, lock_name, task_duration)
def _get_lock_name(*args, **kwargs) -> str:
assert args is not None
assert kwargs is not None
return lock_name

async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:
for _ in range(3):
assert (
await _exclusive_sleeping_task(
redis_client, _get_lock_name, sleep_duration
)()
== sleep_duration
)


async def test_exclusive_decorator_with_client_builder(
get_redis_client_sdk: Callable[
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
],
lock_name: str,
sleep_duration: float,
):
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client:

def _get_redis_client_builder(*args, **kwargs) -> RedisClientSDK:
assert args is not None
assert kwargs is not None
return redis_client

for _ in range(3):
assert (
await _exclusive_sleeping_task(
_get_redis_client_builder, lock_name, sleep_duration
)()
== sleep_duration
)


async def _acquire_lock_and_exclusively_sleep(
get_redis_client_sdk: Callable[
[RedisDatabase], AbstractAsyncContextManager[RedisClientSDK]
],
lock_name: str | Callable[..., str],
sleep_duration: float,
) -> None:
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk:
redis_lock_name = lock_name() if callable(lock_name) else lock_name
assert not await _is_locked(redis_client_sdk, redis_lock_name)

@exclusive(redis_client_sdk, lock_key=lock_name)
async def _() -> float:
assert await _is_locked(redis_client_sdk, redis_lock_name)
await asyncio.sleep(sleep_duration)
assert await _is_locked(redis_client_sdk, redis_lock_name)
return sleep_duration

assert await _() == sleep_duration

assert not await _is_locked(redis_client_sdk, redis_lock_name)


async def test_exclusive_parallel_lock_is_released_and_reacquired(
Expand All @@ -76,17 +155,19 @@ async def test_exclusive_parallel_lock_is_released_and_reacquired(
parallel_tasks = 10
results = await logged_gather(
*[
_contained_client(get_redis_client_sdk, lock_name, task_duration=0.1)
_acquire_lock_and_exclusively_sleep(
get_redis_client_sdk, lock_name, sleep_duration=0.1
)
for _ in range(parallel_tasks)
],
reraise=False
reraise=False,
)
assert results.count(None) == 1
assert [isinstance(x, CouldNotAcquireLockError) for x in results].count(
True
) == parallel_tasks - 1

# check lock is being released
# check lock is released
async with get_redis_client_sdk(RedisDatabase.RESOURCES) as redis_client_sdk:
assert not await _is_locked(redis_client_sdk, lock_name)

Expand Down Expand Up @@ -168,7 +249,7 @@ async def test_start_exclusive_periodic_task_parallel_all_finish(
_assert_task_completes_once(get_redis_client_sdk, stop_after=60)
for _ in range(parallel_tasks)
],
reraise=False
reraise=False,
)

# check no error occurred
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Annotated

from fastapi import Depends, FastAPI, Request

from ...core.settings import ComputationalBackendSettings
Expand All @@ -11,7 +13,7 @@ def get_scheduler(request: Request) -> BaseCompScheduler:


def get_scheduler_settings(
app: FastAPI = Depends(get_app),
app: Annotated[FastAPI, Depends(get_app)]
) -> ComputationalBackendSettings:
settings: ComputationalBackendSettings = (
app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,38 @@
import logging
from collections.abc import Callable, Coroutine
from typing import Any, cast

from fastapi import FastAPI
from servicelib.logging_utils import log_context

from . import _scheduler_factory
from ._base_scheduler import BaseCompScheduler
from ._task import on_app_shutdown, on_app_startup

_logger = logging.getLogger(__name__)


def on_app_startup(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]:
async def start_scheduler() -> None:
with log_context(
_logger, level=logging.INFO, msg="starting computational scheduler"
):
app.state.scheduler = scheduler = await _scheduler_factory.create_from_db(
app
)
scheduler.recover_scheduling()

return start_scheduler


def on_app_shutdown(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]:
async def stop_scheduler() -> None:
await get_scheduler(app).shutdown()

return stop_scheduler


def get_scheduler(app: FastAPI) -> BaseCompScheduler:
return cast(BaseCompScheduler, app.state.scheduler)


def setup(app: FastAPI):
Expand All @@ -12,4 +43,5 @@ def setup(app: FastAPI):
__all__: tuple[str, ...] = (
"setup",
"BaseCompScheduler",
"get_scheduler",
)
Loading

0 comments on commit b2e0b38

Please sign in to comment.