[typing] prefect.server.utilities.database
This is a complete refactor of this module.

- Move functions into the `sqlalchemy.func` namespace so they don't need to be imported everywhere
- Re-use SQLAlchemy's PostgreSQL JSONB operators by providing SQLite equivalents
- Provide a new function that calculates the difference between timestamps in seconds

This removes the need for many separate PostgreSQL vs SQLite queries.
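
For context, the mechanism this refactor relies on: SQLAlchemy registers any `GenericFunction` subclass under the `sa.func` namespace by name, and `@compiles` hooks attach per-dialect SQL. A minimal sketch of how a `date_diff_seconds` function could be wired up this way — illustrative only; the real implementation lives in `prefect/server/utilities/database.py`:

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction


class date_diff_seconds(GenericFunction):
    """Difference between two timestamps, in seconds (illustrative sketch)."""

    type = sa.Float()
    inherit_cache = True


@compiles(date_diff_seconds, "postgresql")
def _date_diff_seconds_postgresql(element, compiler, **kwargs):
    # Postgres can subtract epoch values directly
    t1, t2 = list(element.clauses)
    return "(EXTRACT(EPOCH FROM {}) - EXTRACT(EPOCH FROM {}))".format(
        compiler.process(t1, **kwargs), compiler.process(t2, **kwargs)
    )


@compiles(date_diff_seconds, "sqlite")
def _date_diff_seconds_sqlite(element, compiler, **kwargs):
    # julianday() measures days as a float; scale to seconds
    t1, t2 = list(element.clauses)
    return "((julianday({}) - julianday({})) * 86400.0)".format(
        compiler.process(t1, **kwargs), compiler.process(t2, **kwargs)
    )

Once the subclass is imported anywhere, call sites can simply write `sa.func.date_diff_seconds(a, b)` with no dialect branching, which is exactly the shape the diffs below take.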
mjpieters committed Dec 12, 2024
1 parent f4f5963 commit 8abba49
Showing 13 changed files with 705 additions and 699 deletions.
6 changes: 4 additions & 2 deletions src/prefect/server/api/run_history.py
@@ -102,12 +102,14 @@ async def run_history(
             # estimated run times only includes positive run times (to avoid any unexpected corner cases)
             "sum_estimated_run_time",
             sa.func.sum(
-                db.greatest(0, sa.extract("epoch", runs.c.estimated_run_time))
+                sa.func.greatest(
+                    0, sa.extract("epoch", runs.c.estimated_run_time)
+                )
             ),
             # estimated lateness is the sum of any positive start time deltas
             "sum_estimated_lateness",
             sa.func.sum(
-                db.greatest(
+                sa.func.greatest(
                     0, sa.extract("epoch", runs.c.estimated_start_time_delta)
                 )
             ),
43 changes: 5 additions & 38 deletions src/prefect/server/api/ui/task_runs.py
@@ -1,5 +1,4 @@
-import sys
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import List, Optional, cast
 
 import pendulum
@@ -37,37 +36,6 @@ def ser_model(self) -> dict:
         }
 
 
-def _postgres_bucket_expression(
-    db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime
-):
-    # asyncpg under Python 3.7 doesn't support timezone-aware datetimes for the EXTRACT
-    # function, so we will send it as a naive datetime in UTC
-    if sys.version_info < (3, 8):
-        start_datetime = start_datetime.astimezone(timezone.utc).replace(tzinfo=None)
-
-    return sa.func.floor(
-        (
-            sa.func.extract("epoch", db.TaskRun.start_time)
-            - sa.func.extract("epoch", start_datetime)
-        )
-        / delta.total_seconds()
-    ).label("bucket")
-
-
-def _sqlite_bucket_expression(
-    db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime
-):
-    return sa.func.floor(
-        (
-            (
-                sa.func.strftime("%s", db.TaskRun.start_time)
-                - sa.func.strftime("%s", start_datetime)
-            )
-            / delta.total_seconds()
-        )
-    ).label("bucket")
-
-
 @router.post("/dashboard/counts")
 async def read_dashboard_task_run_counts(
     task_runs: schemas.filters.TaskRunFilter,
@@ -121,11 +89,10 @@ async def read_dashboard_task_run_counts(
         start_time.microsecond,
         start_time.timezone,
     )
-    bucket_expression = (
-        _sqlite_bucket_expression(db, delta, start_datetime)
-        if db.dialect.name == "sqlite"
-        else _postgres_bucket_expression(db, delta, start_datetime)
-    )
+    bucket_expression = sa.func.floor(
+        sa.func.date_diff_seconds(db.TaskRun.start_time, start_datetime)
+        / delta.total_seconds()
+    ).label("bucket")
 
     raw_counts = (
         (
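
A plain-Python sanity check of the arithmetic that both the old dialect-specific expressions and the new one compute: a task run lands in bucket `floor((start_time - start_datetime) / delta)`. Names mirror the diff above; this is illustration, not server code.

from datetime import datetime, timedelta, timezone
from math import floor

start_datetime = datetime(2024, 12, 1, tzinfo=timezone.utc)
delta = timedelta(hours=1)


def bucket(start_time: datetime) -> int:
    # floor() of the elapsed-seconds ratio, as in the SQL expression
    return floor(
        (start_time - start_datetime).total_seconds() / delta.total_seconds()
    )


assert bucket(datetime(2024, 12, 1, 0, 30, tzinfo=timezone.utc)) == 0
assert bucket(datetime(2024, 12, 1, 2, 59, tzinfo=timezone.utc)) == 2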
18 changes: 8 additions & 10 deletions src/prefect/server/database/configurations.py
@@ -8,16 +8,14 @@
 from typing import Dict, Hashable, Optional, Tuple
 
 import sqlalchemy as sa
-
-try:
-    from sqlalchemy import AdaptedConnection
-    from sqlalchemy.pool import ConnectionPoolEntry
-except ImportError:
-    # SQLAlchemy 1.4 equivalents
-    from sqlalchemy.pool import _ConnectionFairy as AdaptedConnection
-    from sqlalchemy.pool.base import _ConnectionRecord as ConnectionPoolEntry
-
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
+from sqlalchemy import AdaptedConnection
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    AsyncSessionTransaction,
+    create_async_engine,
+)
+from sqlalchemy.pool import ConnectionPoolEntry
 from typing_extensions import Literal
 
 from prefect.settings import (
3 changes: 0 additions & 3 deletions src/prefect/server/database/interface.py
@@ -359,9 +359,6 @@ def insert(self, model):
         """INSERTs a model into the database"""
         return self.queries.insert(model)
 
-    def greatest(self, *values):
-        return self.queries.greatest(*values)
-
     def make_timestamp_intervals(
         self,
         start_time: datetime.datetime,
41 changes: 18 additions & 23 deletions src/prefect/server/database/orm_models.py
@@ -1,17 +1,15 @@
import datetime
import uuid
from abc import ABC, abstractmethod
from collections.abc import Hashable, Iterable
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Hashable,
Iterable,
Optional,
Union,
cast,
)

import pendulum
@@ -46,15 +44,12 @@
     WorkQueueStatus,
 )
 from prefect.server.utilities.database import (
+    CAMEL_TO_SNAKE,
     JSON,
     UUID,
     GenerateUUID,
     Pydantic,
     Timestamp,
-    camel_to_snake,
-    date_diff,
-    interval_add,
-    now,
 )
 from prefect.server.utilities.encryption import decrypt_fernet, encrypt_fernet
 from prefect.utilities.names import generate_slug
@@ -117,7 +112,7 @@ def __tablename__(cls) -> str:
         into a snake-case table name. Override by providing
         an explicit `__tablename__` class property.
         """
-        return camel_to_snake.sub("_", cls.__name__).lower()
+        return CAMEL_TO_SNAKE.sub("_", cls.__name__).lower()
 
     id: Mapped[uuid.UUID] = mapped_column(
         primary_key=True,
@@ -126,17 +121,17 @@ def __tablename__(cls) -> str:
     )
 
     created: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
 
     # onupdate is only called when statements are actually issued
     # against the database. until COMMIT is issued, this column
     # will not be updated
     updated: Mapped[pendulum.DateTime] = mapped_column(
         index=True,
-        server_default=now(),
+        server_default=sa.func.now(),
         default=lambda: pendulum.now("UTC"),
-        onupdate=now(),
+        onupdate=sa.func.now(),
         server_onupdate=FetchedValue(),
     )
 
@@ -175,7 +170,7 @@ class FlowRunState(Base):
         sa.Enum(schemas.states.StateType, name="state_type"), index=True
     )
     timestamp: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
     name: Mapped[str] = mapped_column(index=True)
     message: Mapped[Optional[str]]
@@ -240,7 +235,7 @@ class TaskRunState(Base):
         sa.Enum(schemas.states.StateType, name="state_type"), index=True
     )
     timestamp: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
    )
     name: Mapped[str] = mapped_column(index=True)
     message: Mapped[Optional[str]]
@@ -419,9 +414,9 @@ def _estimated_run_time_expression(cls) -> sa.Label[datetime.timedelta]:
             sa.case(
                 (
                     cls.state_type == schemas.states.StateType.RUNNING,
-                    interval_add(
+                    sa.func.interval_add(
                         cls.total_run_time,
-                        date_diff(now(), cls.state_timestamp),
+                        sa.func.date_diff(sa.func.now(), cls.state_timestamp),
                     ),
                 ),
                 else_=cls.total_run_time,
@@ -464,15 +459,15 @@ def _estimated_start_time_delta_expression(
         return sa.case(
             (
                 cls.start_time > cls.expected_start_time,
-                date_diff(cls.start_time, cls.expected_start_time),
+                sa.func.date_diff(cls.start_time, cls.expected_start_time),
             ),
             (
                 sa.and_(
                     cls.start_time.is_(None),
                     cls.state_type.not_in(schemas.states.TERMINAL_STATES),
-                    cls.expected_start_time < now(),
+                    cls.expected_start_time < sa.func.now(),
                 ),
-                date_diff(now(), cls.expected_start_time),
+                sa.func.date_diff(sa.func.now(), cls.expected_start_time),
             ),
             else_=datetime.timedelta(0),
         )
@@ -1165,7 +1160,7 @@ class Worker(Base):
 
     name: Mapped[str]
     last_heartbeat_time: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
     heartbeat_interval_seconds: Mapped[Optional[int]]
 
@@ -1195,7 +1190,7 @@ class Agent(Base):
     )
 
     last_activity_time: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
 
     __table_args__: Any = (sa.UniqueConstraint("name"),)
@@ -1277,11 +1272,11 @@ class Automation(Base):
     @classmethod
     def sort_expression(cls, value: AutomationSort) -> sa.ColumnExpressionArgument[Any]:
         """Return an expression used to sort Automations"""
-        sort_mapping = {
+        sort_mapping: dict[AutomationSort, sa.ColumnExpressionArgument[Any]] = {
             AutomationSort.CREATED_DESC: cls.created.desc(),
             AutomationSort.UPDATED_DESC: cls.updated.desc(),
-            AutomationSort.NAME_ASC: cast(sa.Column, cls.name).asc(),
-            AutomationSort.NAME_DESC: cast(sa.Column, cls.name).desc(),
+            AutomationSort.NAME_ASC: cls.name.asc(),
+            AutomationSort.NAME_DESC: cls.name.desc(),
         }
         return sort_mapping[value]
 
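
The `now()` calls above now resolve through `sa.func` as well. A hedged sketch of how a dialect-aware `now` can be registered — assuming, as Prefect's utilities module historically did, that SQLite needs an explicit UTC timestamp with fractional seconds rather than its default CURRENT_TIMESTAMP rendering; this is not necessarily the commit's exact code:

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import functions


class now(functions.now):
    """Current UTC timestamp; subclassing re-registers the name in sa.func."""

    inherit_cache = True


@compiles(now, "sqlite")
def _sqlite_now(element, compiler, **kwargs):
    # 'now' is already UTC in SQLite; %f adds fractional seconds
    return "strftime('%Y-%m-%d %H:%M:%f', 'now')"

On PostgreSQL the default rendering (the native now() function) is kept, since no override is registered for that dialect.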
22 changes: 1 addition & 21 deletions src/prefect/server/database/query_components.py
@@ -61,14 +61,6 @@ def _unique_key(self) -> Tuple[Hashable, ...]:
     def insert(self, obj) -> Union[postgresql.Insert, sqlite.Insert]:
         """dialect-specific insert statement"""
 
-    @abstractmethod
-    def greatest(self, *values):
-        """dialect-specific SqlAlchemy binding"""
-
-    @abstractmethod
-    def least(self, *values):
-        """dialect-specific SqlAlchemy binding"""
-
     # --- dialect-specific JSON handling
 
     @abstractproperty
@@ -179,7 +171,7 @@ def get_scheduled_flow_runs_from_work_queues(
         concurrency_queues = (
             sa.select(
                 orm_models.WorkQueue.id,
-                self.greatest(
+                sa.func.greatest(
                     0,
                     orm_models.WorkQueue.concurrency_limit
                     - sa.func.count(orm_models.FlowRun.id),
@@ -628,12 +620,6 @@ class AsyncPostgresQueryComponents(BaseQueryComponents):
     def insert(self, obj) -> postgresql.Insert:
         return postgresql.insert(obj)
 
-    def greatest(self, *values):
-        return sa.func.greatest(*values)
-
-    def least(self, *values):
-        return sa.func.least(*values)
-
     # --- Postgres-specific JSON handling
 
     @property
@@ -984,12 +970,6 @@ class AioSqliteQueryComponents(BaseQueryComponents):
     def insert(self, obj) -> sqlite.Insert:
         return sqlite.insert(obj)
 
-    def greatest(self, *values):
-        return sa.func.max(*values)
-
-    def least(self, *values):
-        return sa.func.min(*values)
-
     # --- Sqlite-specific JSON handling
 
     @property
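
The removed `greatest`/`least` shims existed because those functions don't exist under these names in SQLite, where the multi-argument scalar max()/min() functions are the direct counterparts. A sketch of how `sa.func.greatest` itself can be taught to render as max() on SQLite, so callers no longer go through the query components — illustrative, not necessarily the commit's exact code:

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import ReturnTypeFromArgs


class greatest(ReturnTypeFromArgs):
    """Registered in sa.func; renders as greatest() unless overridden."""

    inherit_cache = True


@compiles(greatest, "sqlite")
def _greatest_sqlite(element, compiler, **kwargs):
    # SQLite's scalar max(X, Y, ...) is the equivalent of greatest(X, Y, ...)
    return "max({})".format(compiler.process(element.clauses, **kwargs))

One caveat: PostgreSQL's greatest() skips NULL arguments, while SQLite's scalar max() returns NULL as soon as any argument is NULL, so the two are only equivalent for non-NULL operands — something callers still need to keep in mind.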
13 changes: 2 additions & 11 deletions src/prefect/server/events/counting.py
@@ -8,7 +8,6 @@
 
 from prefect.server.database.dependencies import provide_database_interface
 from prefect.server.database.interface import PrefectDBInterface
-from prefect.server.utilities.database import json_extract
 from prefect.types import DateTime
 from prefect.utilities.collections import AutoEnum
 
@@ -290,16 +289,8 @@ def _database_label_expression(
             return db.Event.event
         elif self == self.resource:
             return sa.func.coalesce(
-                json_extract(
-                    db.Event.resource,
-                    "prefect.resource.name",
-                    wrap_quotes=db.dialect.name == "sqlite",
-                ),
-                json_extract(
-                    db.Event.resource,
-                    "prefect.name",
-                    wrap_quotes=db.dialect.name == "sqlite",
-                ),
+                db.Event.resource["prefect.resource.name"].astext,
+                db.Event.resource["prefect.name"].astext,
                 db.Event.resource_id,
             )
         else:
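
The `resource[...]` plus `.astext` spelling is SQLAlchemy's PostgreSQL JSONB comparator API; per the commit message, SQLite equivalents are provided so the same expression compiles on both backends. One possible shape for that, leaning on SQLAlchemy's generic JSON machinery, which already renders string-typed index access per dialect — `PortableJSON` is a hypothetical name for illustration, not the type Prefect actually uses:

import sqlalchemy as sa


class PortableJSON(sa.JSON):
    """JSON type exposing the Postgres-style .astext accessor on any dialect."""

    class Comparator(sa.JSON.Comparator):
        @property
        def astext(self):
            # as_string() is the dialect-aware "extract as text" accessor
            # for JSON index expressions in SQLAlchemy's generic JSON type
            return self.expr.as_string()

    comparator_factory = Comparator


events = sa.table("events", sa.column("resource", PortableJSON()))
label = events.c.resource["prefect.resource.name"].astext
# SQLite renders JSON_EXTRACT(resource, '$."prefect.resource.name"');
# PostgreSQL renders a string-typed extraction of resource -> 'prefect.resource.name'.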
13 changes: 3 additions & 10 deletions src/prefect/server/events/filters.py
@@ -15,7 +15,6 @@
     PrefectFilterBaseModel,
     PrefectOperatorFilterBaseModel,
 )
-from prefect.server.utilities.database import json_extract
 from prefect.types import DateTime
 from prefect.utilities.collections import AutoEnum
 
@@ -309,9 +308,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]:
             for _, (label, values) in enumerate(labels.items()):
                 label_ops = LabelOperations(values)
 
-                label_column = json_extract(
-                    orm_models.EventResource.resource, label
-                )
+                label_column = orm_models.EventResource.resource[label].astext
 
                 # With negative labels, the resource _must_ have the label
                 if label_ops.negative.simple or label_ops.negative.prefixes:
@@ -404,9 +401,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]:
             for _, (label, values) in enumerate(labels.items()):
                 label_ops = LabelOperations(values)
 
-                label_column = json_extract(
-                    orm_models.EventResource.resource, label
-                )
+                label_column = orm_models.EventResource.resource[label].astext
 
                 if label_ops.negative.simple or label_ops.negative.prefixes:
                     label_filters.append(label_column.is_not(None))
@@ -518,9 +513,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]:
             for _, (label, values) in enumerate(labels.items()):
                 label_ops = LabelOperations(values)
 
-                label_column = json_extract(
-                    orm_models.EventResource.resource, label
-                )
+                label_column = orm_models.EventResource.resource[label].astext
 
                 if label_ops.negative.simple or label_ops.negative.prefixes:
                     label_filters.append(label_column.is_not(None))