Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

let TaskRunRecorder process events into task runs/task run states #14729

Merged
merged 3 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/prefect/server/events/schemas/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def id(self) -> str:
def name(self) -> Optional[str]:
return self.get("prefect.resource.name")

def prefect_object_id(self, kind: str) -> UUID:
"""Extracts the UUID from an event's resource ID if it's the expected kind
of prefect resource"""
prefix = f"{kind}." if not kind.endswith(".") else kind

if not self.id.startswith(prefix):
raise ValueError(f"Resource ID {self.id} does not start with {prefix}")

return UUID(self.id[len(prefix) :])


class RelatedResource(Resource):
"""A Resource with a specific role in an Event"""
Expand Down
179 changes: 177 additions & 2 deletions src/prefect/server/services/task_run_recorder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,181 @@
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional
from uuid import UUID

import pendulum
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession

from prefect.logging import get_logger
from prefect.server.database.dependencies import db_injector, provide_database_interface
from prefect.server.database.interface import PrefectDBInterface
from prefect.server.events.ordering import CausalOrdering, EventArrivedEarly
from prefect.server.events.schemas.events import ReceivedEvent
from prefect.server.schemas.core import TaskRun
from prefect.server.schemas.states import State
from prefect.server.utilities.messaging import Message, MessageHandler, create_consumer

logger = get_logger(__name__)


def causal_ordering():
return CausalOrdering(
"task-run-recorder",
)


@db_injector
async def _insert_task_run(
db: PrefectDBInterface,
session: AsyncSession,
task_run: TaskRun,
task_run_attributes: Dict[str, Any],
):
await session.execute(
db.insert(db.TaskRun)
.values(
created=pendulum.now("UTC"),
**task_run_attributes,
)
.on_conflict_do_update(
index_elements=[
"id",
],
set_={
"updated": pendulum.now("UTC"),
**task_run_attributes,
},
where=db.TaskRun.state_timestamp < task_run.state.timestamp,
)
)


@db_injector
async def _insert_task_run_state(
db: PrefectDBInterface, session: AsyncSession, task_run: TaskRun
):
await session.execute(
db.insert(db.TaskRunState)
.values(
created=pendulum.now("UTC"),
task_run_id=task_run.id,
**task_run.state.model_dump(),
)
.on_conflict_do_nothing(
index_elements=[
"id",
]
)
)


@db_injector
async def _update_task_run_with_state(
db: PrefectDBInterface,
session: AsyncSession,
task_run: TaskRun,
denormalized_state_attributes: Dict[str, Any],
):
await session.execute(
sa.update(db.TaskRun)
.where(
db.TaskRun.id == task_run.id,
sa.or_(
db.TaskRun.state_timestamp.is_(None),
db.TaskRun.state_timestamp < task_run.state.timestamp,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this fits our pattern so that we'll only do this if things are in order

),
)
.values(**denormalized_state_attributes)
)


def task_run_from_event(event: ReceivedEvent) -> TaskRun:
task_run_id = event.resource.prefect_object_id("prefect.task-run")

flow_run_id: Optional[UUID] = None
if flow_run_resource := event.resource_in_role.get("flow-run"):
flow_run_id = flow_run_resource.prefect_object_id("prefect.flow-run")

state: State = State.model_validate(
{
"id": event.id,
"timestamp": event.occurred,
**event.payload["validated_state"],
}
)
state.state_details.task_run_id = task_run_id
state.state_details.flow_run_id = flow_run_id

return TaskRun.model_validate(
{
"id": task_run_id,
"flow_run_id": flow_run_id,
"state_id": state.id,
"state": state,
**event.payload["task_run"],
}
)


async def record_task_run_event(event: ReceivedEvent, depth: int = 0):
db = provide_database_interface()

async with AsyncExitStack() as stack:
await stack.enter_async_context(
(
causal_ordering().preceding_event_confirmed(
record_task_run_event, event, depth=depth
)
)
)

task_run = task_run_from_event(event)

task_run_attributes = task_run.model_dump_for_orm(
exclude={
"state_id",
"state",
"created",
"estimated_run_time",
"estimated_start_time_delta",
},
exclude_unset=True,
)

assert task_run.state

denormalized_state_attributes = {
"state_id": task_run.state.id,
"state_type": task_run.state.type,
"state_name": task_run.state.name,
"state_timestamp": task_run.state.timestamp,
}
session = await stack.enter_async_context(
db.session_context(begin_transaction=True)
)

await _insert_task_run(session, task_run, task_run_attributes)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately due to FK's we need to insert the task run, then insert the task run state, then go back and update the task run with the denormalized state attributes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, sometimes this trips me up with the db.session_context(): begin_transaction defaults to False, so these aren't in a transaction, right? It feels like they should be

Copy link
Contributor Author

@jakekaplan jakekaplan Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

‼️ 100% yes, great catch. Should I set db.session_context(..., with_for_update=True) as well here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, that might make the lock a little more aggressive? I'm thinking about the SQLite implementation hitting database is locked. Since this should be the only thing writing task runs, I think we're good with just the transaction

await _insert_task_run_state(session, task_run)
await _update_task_run_with_state(
session, task_run, denormalized_state_attributes
)

logger.info(
"Recorded task run state change",
extra={
"task_run_id": task_run.id,
"flow_run_id": task_run.flow_run_id,
"event_id": event.id,
"event_follows": event.follows,
"event": event.event,
"occurred": event.occurred,
"current_state_type": task_run.state_type,
"current_state_name": task_run.state_name,
},
)


@asynccontextmanager
async def consumer() -> AsyncGenerator[MessageHandler, None]:
async def message_handler(message: Message):
Expand All @@ -24,6 +191,14 @@ async def message_handler(message: Message):
f"Received event: {event.event} with id: {event.id} for resource: {event.resource.get('prefect.resource.id')}"
)

try:
await record_task_run_event(event)
except EventArrivedEarly:
# We're safe to ACK this message because it has been parked by the
# causal ordering mechanism and will be reprocessed when the preceding
# event arrives.
pass

yield message_handler


Expand Down
Loading
Loading