Skip to content

Commit

Permalink
Fix task key computation (#14704)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Jul 22, 2024
1 parent 1299f1a commit c5f1970
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 68 deletions.
48 changes: 32 additions & 16 deletions src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import datetime
import inspect
import os
from copy import copy
from functools import partial, update_wrapper
from typing import (
Expand Down Expand Up @@ -188,6 +187,31 @@ def _infer_parent_task_runs(
return parents


def _generate_task_key(fn: Callable[..., Any]) -> str:
"""Generate a task key based on the function name and source code.
We may eventually want some sort of top-level namespace here to
disambiguate tasks with the same function name in different modules,
in a more human-readable way, while avoiding relative import problems (see #12337).
As long as the task implementations are unique (even if named the same), we should
not have any collisions.
Args:
fn: The function to generate a task key for.
"""
if not hasattr(fn, "__qualname__"):
return to_qualified_name(type(fn))

qualname = fn.__qualname__.split(".")[-1]

code_hash = (
h[:NUM_CHARS_DYNAMIC_KEY] if (h := hash_objects(fn.__code__)) else "unknown"
)

return f"{qualname}-{code_hash}"


class Task(Generic[P, R]):
"""
A Prefect task definition.
Expand Down Expand Up @@ -270,7 +294,7 @@ def __init__(
description: Optional[str] = None,
tags: Optional[Iterable[str]] = None,
version: Optional[str] = None,
cache_policy: Optional[CachePolicy] = NotSet,
cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet,
cache_key_fn: Optional[
Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
] = None,
Expand Down Expand Up @@ -369,17 +393,7 @@ def __init__(

self.tags = set(tags if tags else [])

if not hasattr(self.fn, "__qualname__"):
self.task_key = to_qualified_name(type(self.fn))
else:
try:
task_origin_hash = hash_objects(
self.name, os.path.abspath(inspect.getsourcefile(self.fn))
)
except TypeError:
task_origin_hash = "unknown-source-file"

self.task_key = f"{self.fn.__qualname__}-{task_origin_hash}"
self.task_key = _generate_task_key(self.fn)

if cache_policy is not NotSet and cache_key_fn is not None:
logger.warning(
Expand Down Expand Up @@ -1496,7 +1510,7 @@ async def serve(self) -> NoReturn:
Args:
task_runner: The task runner to use for serving the task. If not provided,
the default ConcurrentTaskRunner will be used.
the default task runner will be used.
Examples:
Serve a task using the default task runner
Expand All @@ -1523,7 +1537,7 @@ def task(
description: Optional[str] = None,
tags: Optional[Iterable[str]] = None,
version: Optional[str] = None,
cache_policy: CachePolicy = NotSet,
cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet,
cache_key_fn: Optional[
Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
] = None,
Expand Down Expand Up @@ -1561,7 +1575,9 @@ def task(
tags: Optional[Iterable[str]] = None,
version: Optional[str] = None,
cache_policy: Union[CachePolicy, Type[NotSet]] = NotSet,
cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] = None,
cache_key_fn: Union[
Callable[["TaskRunContext", Dict[str, Any]], Optional[str]], None
] = None,
cache_expiration: Optional[datetime.timedelta] = None,
task_run_name: Optional[Union[Callable[[], str], str]] = None,
retries: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,7 @@ def happy_path():
== task_run.expected_start_time
)
assert pending.payload["task_run"].pop("estimated_start_time_delta") > 0.0
assert (
pending.payload["task_run"]
.pop("task_key")
.startswith("test_task_state_change_happy_path.<locals>.happy_little_tree")
)
assert pending.payload["task_run"].pop("task_key").startswith("happy_little_tree")
assert pending.payload == {
"initial_state": None,
"intended": {"from": None, "to": "PENDING"},
Expand Down Expand Up @@ -112,11 +108,7 @@ def happy_path():
== task_run.expected_start_time
)
assert running.payload["task_run"].pop("estimated_start_time_delta") > 0.0
assert (
running.payload["task_run"]
.pop("task_key")
.startswith("test_task_state_change_happy_path.<locals>.happy_little_tree")
)
assert running.payload["task_run"].pop("task_key").startswith("happy_little_tree")
assert running.payload == {
"intended": {"from": "PENDING", "to": "RUNNING"},
"initial_state": {
Expand Down Expand Up @@ -169,11 +161,7 @@ def happy_path():
== task_run.expected_start_time
)
assert completed.payload["task_run"].pop("estimated_start_time_delta") > 0.0
assert (
completed.payload["task_run"]
.pop("task_key")
.startswith("test_task_state_change_happy_path.<locals>.happy_little_tree")
)
assert completed.payload["task_run"].pop("task_key").startswith("happy_little_tree")
assert completed.payload["task_run"].pop("estimated_run_time") > 0.0
assert (
pendulum.parse(completed.payload["task_run"].pop("start_time"))
Expand Down Expand Up @@ -262,11 +250,7 @@ def happy_path():
== task_run.expected_start_time
)
assert pending.payload["task_run"].pop("estimated_start_time_delta") > 0.0
assert (
pending.payload["task_run"]
.pop("task_key")
.startswith("test_task_state_change_task_failure.<locals>.happy_little_tree")
)
assert pending.payload["task_run"].pop("task_key").startswith("happy_little_tree")
assert pending.payload == {
"initial_state": None,
"intended": {"from": None, "to": "PENDING"},
Expand Down Expand Up @@ -314,11 +298,7 @@ def happy_path():
== task_run.expected_start_time
)
assert running.payload["task_run"].pop("estimated_start_time_delta") > 0.0
assert (
running.payload["task_run"]
.pop("task_key")
.startswith("test_task_state_change_task_failure.<locals>.happy_little_tree")
)
assert running.payload["task_run"].pop("task_key").startswith("happy_little_tree")
assert running.payload == {
"intended": {"from": "PENDING", "to": "RUNNING"},
"initial_state": {
Expand Down Expand Up @@ -374,11 +354,7 @@ def happy_path():
== task_run.expected_start_time
)
assert failed.payload["task_run"].pop("estimated_start_time_delta") > 0.0
assert (
failed.payload["task_run"]
.pop("task_key")
.startswith("test_task_state_change_task_failure.<locals>.happy_little_tree")
)
assert failed.payload["task_run"].pop("task_key").startswith("happy_little_tree")
assert failed.payload["task_run"].pop("estimated_run_time") > 0.0
assert (
pendulum.parse(failed.payload["task_run"].pop("start_time"))
Expand Down
22 changes: 0 additions & 22 deletions tests/test_background_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import asyncio
import inspect
import os
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator, Iterable, Tuple
Expand All @@ -26,7 +24,6 @@
temporary_settings,
)
from prefect.task_worker import TaskWorker
from prefect.utilities.hashing import hash_objects

if TYPE_CHECKING:
from prefect.client.orchestration import PrefectClient
Expand Down Expand Up @@ -447,22 +444,3 @@ async def bar(x: int, mappable: Iterable) -> Tuple[int, Iterable]:
"parameters": {"x": i + 1, "mappable": ["some", "iterable"]},
"context": mock.ANY,
}


class TestTaskKey:
    """Checks that a task key is the function qualname plus an origin hash."""

    def test_task_key_includes_qualname_and_source_file_hash(self):
        """The key suffix is the hash of the task name and absolute source path."""

        def some_fn():
            pass

        t = Task(fn=some_fn)
        expected_hash = hash_objects(
            t.name, os.path.abspath(inspect.getsourcefile(some_fn))
        )
        expected_key = f"{some_fn.__qualname__}-{expected_hash}"
        assert t.task_key == expected_key

    def test_task_key_handles_unknown_source_file(self, monkeypatch):
        """With no resolvable source file the key uses a fixed placeholder suffix."""

        def some_fn():
            pass

        def _no_source_file(x):
            # Stand-in for inspect.getsourcefile that never finds a file.
            return None

        monkeypatch.setattr(inspect, "getsourcefile", _no_source_file)
        t = Task(fn=some_fn)
        expected_key = f"{some_fn.__qualname__}-unknown-source-file"
        assert t.task_key == expected_key
14 changes: 14 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,20 @@ def my_task():
assert my_task.name == "another_name"


class TestTaskKey:
    """Checks that generated task keys begin with the bare function name."""

    def test_task_key_typical_case(self):
        """A task defined inside a test gets a key prefixed with its own name."""

        @task
        def my_task():
            pass

        key = my_task.task_key
        assert key.startswith("my_task-")

    def test_task_key_after_import(self):
        """A task imported from another module gets a key prefixed with its own name."""
        from tests.generic_tasks import noop

        key = noop.task_key
        assert key.startswith("noop-")


class TestTaskRunName:
def test_run_name_default(self):
@task
Expand Down

0 comments on commit c5f1970

Please sign in to comment.