Skip to content

Commit

Permalink
Return early from concurrency() and rate_limit() without limit names (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
abrookins authored Jul 30, 2024
1 parent efd02de commit 0335c76
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/prefect/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ async def main():
await resource_heavy()
```
"""
if not names:
yield
return

names = names if isinstance(names, list) else [names]

limits = await _acquire_concurrency_slots(
names,
occupy,
Expand Down Expand Up @@ -111,7 +116,11 @@ async def rate_limit(
raising a `TimeoutError`. A timeout of `None` will wait indefinitely.
create_if_missing: Whether to create the concurrency limits if they do not exist.
"""
if not names:
return

names = names if isinstance(names, list) else [names]

limits = await _acquire_concurrency_slots(
names,
occupy,
Expand Down
8 changes: 8 additions & 0 deletions src/prefect/concurrency/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def main():
resource_heavy()
```
"""
if not names:
yield
return

names = names if isinstance(names, list) else [names]

limits: List[MinimalConcurrencyLimitResponse] = _call_async_function_from_sync(
Expand Down Expand Up @@ -110,7 +114,11 @@ def rate_limit(
raising a `TimeoutError`. A timeout of `None` will wait indefinitely.
create_if_missing: Whether to create the concurrency limits if they do not exist.
"""
if not names:
return

names = names if isinstance(names, list) else [names]

limits = _call_async_function_from_sync(
_acquire_concurrency_slots,
names,
Expand Down
54 changes: 54 additions & 0 deletions tests/concurrency/test_concurrency_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,33 @@ async def resource_heavy():
}


@pytest.mark.parametrize("names", [[], None])
async def test_rate_limit_without_limit_names(names):
executed = False

async def resource_heavy():
nonlocal executed
await rate_limit(names=names, occupy=1)
executed = True

assert not executed

with mock.patch(
"prefect.concurrency.asyncio._acquire_concurrency_slots",
wraps=lambda *args, **kwargs: None,
) as acquire_spy:
with mock.patch(
"prefect.concurrency.asyncio._release_concurrency_slots",
wraps=lambda *args, **kwargs: None,
) as release_spy:
await resource_heavy()

acquire_spy.assert_not_called()
release_spy.assert_not_called()

assert executed


async def test_concurrency_creates_new_limits_if_requested(
concurrency_limit: ConcurrencyLimitV2,
):
Expand Down Expand Up @@ -401,3 +428,30 @@ async def resource_heavy():
assert occupy_seconds > 0

assert executed


@pytest.mark.parametrize("names", [[], None])
async def test_concurrency_without_limit_names(names):
executed = False

async def resource_heavy():
nonlocal executed
async with concurrency(names=names, occupy=1):
executed = True

assert not executed

with mock.patch(
"prefect.concurrency.asyncio._acquire_concurrency_slots",
wraps=lambda *args, **kwargs: None,
) as acquire_spy:
with mock.patch(
"prefect.concurrency.asyncio._release_concurrency_slots",
wraps=lambda *args, **kwargs: None,
) as release_spy:
await resource_heavy()

acquire_spy.assert_not_called()
release_spy.assert_not_called()

assert executed
54 changes: 54 additions & 0 deletions tests/concurrency/test_concurrency_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,33 @@ def my_flow():
assert executed


@pytest.mark.parametrize("names", [[], None])
def test_rate_limit_without_limit_names_sync(names):
executed = False

def resource_heavy():
nonlocal executed
rate_limit(names=names, occupy=1)
executed = True

assert not executed

with mock.patch(
"prefect.concurrency.sync._acquire_concurrency_slots",
wraps=lambda *args, **kwargs: None,
) as acquire_spy:
with mock.patch(
"prefect.concurrency.sync._release_concurrency_slots",
wraps=lambda *args, **kwargs: None,
) as release_spy:
resource_heavy()

acquire_spy.assert_not_called()
release_spy.assert_not_called()

assert executed


async def test_concurrency_can_be_used_while_event_loop_is_running(
concurrency_limit: ConcurrencyLimitV2,
):
Expand Down Expand Up @@ -350,3 +377,30 @@ def resource_heavy():
),
"prefect.resource.role": "concurrency-limit",
}


@pytest.mark.parametrize("names", [[], None])
def test_concurrency_without_limit_names_sync(names):
executed = False

def resource_heavy():
nonlocal executed
with concurrency(names=names, occupy=1):
executed = True

assert not executed

with mock.patch(
"prefect.concurrency.sync._acquire_concurrency_slots",
wraps=lambda *args, **kwargs: None,
) as acquire_spy:
with mock.patch(
"prefect.concurrency.sync._release_concurrency_slots",
wraps=lambda *args, **kwargs: None,
) as release_spy:
resource_heavy()

acquire_spy.assert_not_called()
release_spy.assert_not_called()

assert executed

0 comments on commit 0335c76

Please sign in to comment.