Skip to content

Commit

Permalink
Fix worker capacity test
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald committed Jul 19, 2024
1 parent c5f4a91 commit f2f521a
Showing 1 changed file with 33 additions and 57 deletions.
90 changes: 33 additions & 57 deletions tests/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,80 +769,56 @@ async def mock_iter():
async def test_tasks_execute_when_capacity_frees_up(
self, mock_subscription, prefect_client
):
event = asyncio.Event()
execution_order = []

@task
async def slow_task():
await asyncio.sleep(1)
if event.is_set():
raise ValueError("Something went wrong! This event should not be set.")
event.set()
async def slow_task(task_id: str):
execution_order.append(f"{task_id} start")
await asyncio.sleep(0.1) # Simulating some work
execution_order.append(f"{task_id} end")

task_worker = TaskWorker(slow_task, limit=1)

task_run_future_1 = slow_task.apply_async()
task_run_future_1 = slow_task.apply_async(("task1",))
task_run_1 = await prefect_client.read_task_run(task_run_future_1.task_run_id)
task_run_future_2 = slow_task.apply_async()
task_run_future_2 = slow_task.apply_async(("task2",))
task_run_2 = await prefect_client.read_task_run(task_run_future_2.task_run_id)

async def mock_iter():
yield task_run_1
yield task_run_2
# sleep for a second to ensure that task execution starts
await asyncio.sleep(1)
while len(execution_order) < 4:
await asyncio.sleep(0.1)

mock_subscription.return_value = mock_iter()

server_task = asyncio.create_task(task_worker.start())
await event.wait()
updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id)
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_1.state.is_completed()
assert not updated_task_run_2.state.is_completed()

# clear the event to allow the second task to complete
event.clear()

await event.wait()
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_2.state.is_completed()

server_task.cancel()
await server_task

async def test_execute_task_run_respects_limit(self, prefect_client):
@task
def slow_task():
import time

time.sleep(1)

task_worker = TaskWorker(slow_task, limit=1)

task_run_future_1 = slow_task.apply_async()
task_run_1 = await prefect_client.read_task_run(task_run_future_1.task_run_id)
task_run_future_2 = slow_task.apply_async()
task_run_2 = await prefect_client.read_task_run(task_run_future_2.task_run_id)

try:
with anyio.move_on_after(1):
# start task worker first to avoid race condition between two execute_task_run calls
async with task_worker:
await asyncio.gather(
task_worker.execute_task_run(task_run_1),
task_worker.execute_task_run(task_run_2),
)
except asyncio.exceptions.CancelledError:
# We want to cancel the second task run, so this is expected
pass

updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id)
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_1.state.is_completed()
assert updated_task_run_2.state.is_scheduled()
# Wait for both tasks to complete
await asyncio.sleep(2)

# Verify the execution order
assert execution_order == [
"task1 start",
"task1 end",
"task2 start",
"task2 end",
], "Tasks should execute sequentially"

# Verify the states of both tasks
updated_task_run_1 = await prefect_client.read_task_run(task_run_1.id)
updated_task_run_2 = await prefect_client.read_task_run(task_run_2.id)

assert updated_task_run_1.state.is_completed()
assert updated_task_run_2.state.is_completed()

finally:
server_task.cancel()
try:
await server_task
except asyncio.CancelledError:
pass

async def test_serve_respects_limit(self, prefect_client, mock_subscription):
@task
Expand Down

0 comments on commit f2f521a

Please sign in to comment.