Skip to content

Commit

Permalink
Incorporate PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
bunchesofdonald committed Jul 26, 2024
1 parent 87db377 commit 1fbba5c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 23 deletions.
21 changes: 6 additions & 15 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,6 @@ def client(self) -> SyncPrefectClient:
raise RuntimeError("Engine has not started.")
return self._client

def sleep(self, interval: float):
time.sleep(interval)

def call_hooks(self, state: Optional[State] = None):
if state is None:
state = self.state
Expand Down Expand Up @@ -374,7 +371,7 @@ def begin_run(self):
interval = clamped_poisson_interval(
average_interval=backoff_count, clamping_factor=0.3
)
self.sleep(interval)
time.sleep(interval)
state = self.set_state(new_state)

def set_state(self, state: State, force: bool = False) -> State:
Expand Down Expand Up @@ -485,7 +482,7 @@ def handle_retry(self, exc: Exception) -> bool:
"""Handle any task run retries.
- If the task has retries left, and the retry condition is met, set the task to retrying and return True.
- If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
- If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
- If the task has no retries left, or the retry condition is not met, return False.
"""
if self.retries < self.task.retries and self.can_retry(exc):
Expand Down Expand Up @@ -796,9 +793,6 @@ def client(self) -> PrefectClient:
raise RuntimeError("Engine has not started.")
return self._client

async def sleep(self, interval: float):
await anyio.sleep(interval)

async def call_hooks(self, state: Optional[State] = None):
if state is None:
state = self.state
Expand Down Expand Up @@ -883,7 +877,7 @@ async def begin_run(self):
interval = clamped_poisson_interval(
average_interval=backoff_count, clamping_factor=0.3
)
await self.sleep(interval)
await anyio.sleep(interval)
state = await self.set_state(new_state)

async def set_state(self, state: State, force: bool = False) -> State:
Expand Down Expand Up @@ -935,10 +929,7 @@ async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]"
if self._return_value is not NotSet:
# if the return value is a BaseResult, we need to fetch it
if isinstance(self._return_value, BaseResult):
_result = self._return_value.get()
if inspect.isawaitable(_result):
_result = await _result
return _result
return await self._return_value.get()

# otherwise, return the value as is
return self._return_value
Expand Down Expand Up @@ -992,10 +983,10 @@ async def handle_retry(self, exc: Exception) -> bool:
"""Handle any task run retries.
- If the task has retries left, and the retry condition is met, set the task to retrying and return True.
- If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
- If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
- If the task has no retries left, or the retry condition is not met, return False.
"""
if self.retries < self.task.retries and self.can_retry:
if self.retries < self.task.retries and self.can_retry(exc):
if self.task.retry_delay_seconds:
delay = (
self.task.retry_delay_seconds[
Expand Down
12 changes: 6 additions & 6 deletions tests/test_flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from textwrap import dedent
from typing import Optional
from unittest import mock
from unittest.mock import MagicMock
from uuid import UUID

Expand Down Expand Up @@ -36,7 +37,6 @@
from prefect.input.run_input import RunInput
from prefect.logging import get_run_logger
from prefect.server.schemas.core import FlowRun as ServerFlowRun
from prefect.task_engine import AsyncTaskRunEngine
from prefect.testing.utilities import AsyncMock
from prefect.utilities.callables import get_call_parameters
from prefect.utilities.filesystem import tmpchdir
Expand Down Expand Up @@ -1126,9 +1126,8 @@ async def flow_resumer():
)
assert schema is not None

async def test_paused_task_polling(self, monkeypatch, prefect_client):
async def test_paused_task_polling(self, prefect_client):
sleeper = AsyncMock(side_effect=[None, None, None, None, None])
monkeypatch.setattr(AsyncTaskRunEngine, "sleep", sleeper)

@task
async def doesnt_pause():
Expand All @@ -1153,9 +1152,10 @@ async def pausing_flow():

# execution isn't blocked, so this task should enter the engine, but not begin
# execution
with pytest.raises(RuntimeError):
# the sleeper mock will exhaust its side effects after 6 calls
await doesnt_run()
with mock.patch("prefect.task_engine.anyio.sleep", sleeper):
with pytest.raises(RuntimeError):
# the sleeper mock will exhaust its side effects after 6 calls
await doesnt_run()

await pausing_flow()

Expand Down
7 changes: 5 additions & 2 deletions tests/test_task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,7 +1204,10 @@ def my_task():
with pytest.raises(interrupt_type):
my_task()

task_run = await get_task_run(task_run_id=None)
await events_pipeline.process_events()
task_runs = await prefect_client.read_task_runs()
assert len(task_runs) == 1
task_run = task_runs[0]
assert task_run.state.is_crashed()
assert task_run.state.type == StateType.CRASHED
assert "Execution was aborted" in task_run.state.message
Expand All @@ -1213,7 +1216,7 @@ def my_task():

@pytest.mark.parametrize("interrupt_type", [KeyboardInterrupt, SystemExit])
async def test_interrupt_in_task_orchestration_crashes_task_and_flow_async(
self, interrupt_type, monkeypatch
self, prefect_client, events_pipeline, interrupt_type, monkeypatch
):
monkeypatch.setattr(
AsyncTaskRunEngine, "begin_run", MagicMock(side_effect=interrupt_type)
Expand Down

0 comments on commit 1fbba5c

Please sign in to comment.