Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scheduler extra models #3749

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mula/scheduler/server/handlers/health.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any

import fastapi
import structlog
from fastapi import status
Expand All @@ -22,7 +20,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext) -> None:
description="Health check endpoint",
)

def health(self, externals: bool = False) -> Any:
def health(self, externals: bool = False) -> models.ServiceHealth:
response = models.ServiceHealth(service="scheduler", healthy=True, version=version.__version__)

if externals:
Expand Down
26 changes: 13 additions & 13 deletions mula/scheduler/server/handlers/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from fastapi import status

from scheduler import context, models, queues, schedulers, storage
from scheduler.server import serializers
from scheduler.server.errors import BadRequestError, ConflictError, NotFoundError, TooManyRequestsError
from scheduler.server.models import Queue, Task, TaskCreate


class QueueAPI:
Expand All @@ -20,7 +20,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s
path="/queues",
endpoint=self.list,
methods=["GET"],
response_model=list[models.Queue],
response_model=list[Queue],
response_model_exclude_unset=True,
status_code=status.HTTP_200_OK,
description="List all queues",
Expand All @@ -30,7 +30,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s
path="/queues/{queue_id}",
endpoint=self.get,
methods=["GET"],
response_model=models.Queue,
response_model=Queue,
status_code=status.HTTP_200_OK,
description="Get a queue",
)
Expand All @@ -39,7 +39,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s
path="/queues/{queue_id}/pop",
endpoint=self.pop,
methods=["POST"],
response_model=models.Task | None,
response_model=Task | None,
status_code=status.HTTP_200_OK,
description="Pop an item from a queue",
)
Expand All @@ -48,22 +48,22 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s
path="/queues/{queue_id}/push",
endpoint=self.push,
methods=["POST"],
response_model=models.Task | None,
response_model=Task | None,
status_code=status.HTTP_201_CREATED,
description="Push an item to a queue",
)

def list(self) -> Any:
return [models.Queue(**s.queue.dict(include_pq=False)) for s in self.schedulers.copy().values()]
return [Queue(**s.queue.dict(include_pq=False)) for s in self.schedulers.copy().values()]

def get(self, queue_id: str) -> Any:
def get(self, queue_id: str) -> Queue:
s = self.schedulers.get(queue_id)
if s is None:
raise NotFoundError(f"queue not found, by queue_id: {queue_id}")

return models.Queue(**s.queue.dict())
return Queue(**s.queue.dict())

def pop(self, queue_id: str, filters: storage.filters.FilterRequest | None = None) -> Any:
def pop(self, queue_id: str, filters: storage.filters.FilterRequest | None = None) -> Task | None:
s = self.schedulers.get(queue_id)
if s is None:
raise NotFoundError(f"queue not found, by queue_id: {queue_id}")
Expand All @@ -76,15 +76,15 @@ def pop(self, queue_id: str, filters: storage.filters.FilterRequest | None = Non
if item is None:
raise NotFoundError("could not pop item from queue, check your filters")

return models.Task(**item.model_dump())
return Task(**item.model_dump())

def push(self, queue_id: str, item_in: serializers.Task) -> Any:
def push(self, queue_id: str, item: TaskCreate) -> Task | None:
s = self.schedulers.get(queue_id)
if s is None:
raise NotFoundError(f"queue not found, by queue_id: {queue_id}")

# Load default values
new_item = models.Task(**item_in.model_dump(exclude_unset=True))
new_item = models.Task(**item.model_dump(exclude_unset=True))

# Set values
if new_item.scheduler_id is None:
Expand All @@ -99,4 +99,4 @@ def push(self, queue_id: str, item_in: serializers.Task) -> Any:
except queues.errors.NotAllowedError:
raise ConflictError("queue is not allowed to push items")

return pushed_item
return Task(**pushed_item.model_dump())
21 changes: 10 additions & 11 deletions mula/scheduler/server/handlers/schedulers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any

import fastapi
import structlog
from fastapi import status

from scheduler import context, models, schedulers
from scheduler.server.errors import BadRequestError, NotFoundError
from scheduler.server.models import Scheduler


class SchedulerAPI:
Expand All @@ -19,7 +18,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s
path="/schedulers",
endpoint=self.list,
methods=["GET"],
response_model=list[models.Scheduler],
response_model=list[Scheduler],
status_code=status.HTTP_200_OK,
description="List all schedulers",
)
Expand All @@ -28,7 +27,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s
path="/schedulers/{scheduler_id}",
endpoint=self.get,
methods=["GET"],
response_model=models.Scheduler,
response_model=Scheduler,
status_code=status.HTTP_200_OK,
description="Get a scheduler",
)
Expand All @@ -37,22 +36,22 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext, s: dict[str, s
path="/schedulers/{scheduler_id}",
endpoint=self.patch,
methods=["PATCH"],
response_model=models.Scheduler,
response_model=Scheduler,
status_code=status.HTTP_200_OK,
description="Update a scheduler",
)

def list(self) -> Any:
return [models.Scheduler(**s.dict()) for s in self.schedulers.values()]
def list(self) -> list[Scheduler]:
return [Scheduler(**s.dict()) for s in self.schedulers.values()]

def get(self, scheduler_id: str) -> Any:
def get(self, scheduler_id: str) -> Scheduler:
s = self.schedulers.get(scheduler_id)
if s is None:
raise NotFoundError(f"Scheduler {scheduler_id} not found")

return models.Scheduler(**s.dict())
return Scheduler(**s.dict())

def patch(self, scheduler_id: str, item: models.Scheduler) -> Any:
def patch(self, scheduler_id: str, item: models.Scheduler) -> Scheduler:
s = self.schedulers.get(scheduler_id)
if s is None:
raise NotFoundError(f"Scheduler {scheduler_id} not found")
Expand All @@ -75,4 +74,4 @@ def patch(self, scheduler_id: str, item: models.Scheduler) -> Any:
elif not updated_scheduler.enabled:
s.disable()

return updated_scheduler
return Scheduler(**updated_scheduler.dict())
28 changes: 15 additions & 13 deletions mula/scheduler/server/handlers/schedules.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import datetime
import uuid
from typing import Any

import fastapi
import structlog
from fastapi import Body

from scheduler import context, models, schedulers, storage
from scheduler.server import serializers, utils
from scheduler.server import utils
from scheduler.server.errors import BadRequestError, ConflictError, NotFoundError, ValidationError
from scheduler.server.models import Schedule, ScheduleCreate, ScheduleUpdate


class ScheduleAPI:
Expand All @@ -33,7 +33,7 @@ def __init__(
path="/schedules",
endpoint=self.create,
methods=["POST"],
response_model=models.Schedule,
response_model=Schedule,
status_code=201,
description="Create a schedule",
)
Expand All @@ -42,7 +42,7 @@ def __init__(
path="/schedules/{schedule_id}",
endpoint=self.get,
methods=["GET"],
response_model=models.Schedule,
response_model=Schedule,
status_code=200,
description="Get a schedule",
)
Expand All @@ -51,7 +51,7 @@ def __init__(
path="/schedules/{schedule_id}",
endpoint=self.patch,
methods=["PATCH"],
response_model=models.Schedule,
response_model=Schedule,
response_model_exclude_unset=True,
status_code=200,
description="Update a schedule",
Expand Down Expand Up @@ -86,7 +86,7 @@ def list(
max_deadline_at: datetime.datetime | None = None,
min_created_at: datetime.datetime | None = None,
max_created_at: datetime.datetime | None = None,
) -> Any:
) -> utils.PaginatedResponse:
if (min_created_at is not None and max_created_at is not None) and min_created_at > max_created_at:
raise BadRequestError("min_created_at must be less than max_created_at")

Expand All @@ -104,10 +104,11 @@ def list(
offset=offset,
limit=limit,
)
results = [Schedule(**s.model_dump()) for s in results]

return utils.paginate(request, results, count, offset, limit)

def create(self, schedule: serializers.ScheduleCreate) -> Any:
def create(self, schedule: ScheduleCreate) -> Schedule:
try:
new_schedule = models.Schedule(**schedule.model_dump())
except ValueError:
Expand All @@ -131,17 +132,17 @@ def create(self, schedule: serializers.ScheduleCreate) -> Any:
if schedule is not None:
raise ConflictError(f"schedule with the same hash already exists: {new_schedule.hash}")

self.ctx.datastores.schedule_store.create_schedule(new_schedule)
return new_schedule
created_schedule = self.ctx.datastores.schedule_store.create_schedule(new_schedule)
return Schedule(**created_schedule.model_dump())

def get(self, schedule_id: uuid.UUID) -> Any:
def get(self, schedule_id: uuid.UUID) -> Schedule:
schedule = self.ctx.datastores.schedule_store.get_schedule(schedule_id)
if schedule is None:
raise NotFoundError(f"schedule not found, by schedule_id: {schedule_id}")

return schedule
return Schedule(**schedule.model_dump())

def patch(self, schedule_id: uuid.UUID, schedule: serializers.SchedulePatch) -> Any:
def patch(self, schedule_id: uuid.UUID, schedule: ScheduleUpdate) -> Schedule:
schedule_db = self.ctx.datastores.schedule_store.get_schedule(schedule_id)
if schedule_db is None:
raise NotFoundError(f"schedule not found, by schedule_id: {schedule_id}")
Expand All @@ -162,7 +163,7 @@ def patch(self, schedule_id: uuid.UUID, schedule: serializers.SchedulePatch) ->
# Update schedule in database
self.ctx.datastores.schedule_store.update_schedule(updated_schedule)

return updated_schedule
return Schedule(**updated_schedule.model_dump())

def search(
self,
Expand All @@ -180,6 +181,7 @@ def search(
results, count = self.ctx.datastores.schedule_store.get_schedules(
offset=offset, limit=limit, filters=filters
)
results = [Schedule(**s.model_dump()) for s in results]
except storage.filters.errors.FilterError as exc:
raise fastapi.HTTPException(
status_code=fastapi.status.HTTP_400_BAD_REQUEST, detail=f"invalid filter(s) [exception: {exc}]"
Expand Down
22 changes: 11 additions & 11 deletions mula/scheduler/server/handlers/tasks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import datetime
import uuid
from typing import Any

import fastapi
import structlog
from fastapi import status

from scheduler import context, models, storage
from scheduler.server import serializers, utils
from scheduler import context, storage
from scheduler.server import utils
from scheduler.server.errors import BadRequestError, NotFoundError
from scheduler.server.models import Task, TaskUpdate


class TaskAPI:
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext) -> None:
path="/tasks/{task_id}",
endpoint=self.get,
methods=["GET"],
response_model=models.Task,
response_model=Task,
status_code=status.HTTP_200_OK,
description="Get a task",
)
Expand All @@ -55,7 +55,7 @@ def __init__(self, api: fastapi.FastAPI, ctx: context.AppContext) -> None:
path="/tasks/{task_id}",
endpoint=self.patch,
methods=["PATCH"],
response_model=models.Task,
response_model=TaskUpdate,
response_model_exclude_unset=True,
status_code=status.HTTP_200_OK,
description="Update a task",
Expand All @@ -74,7 +74,7 @@ def list(
input_ooi: str | None = None, # FIXME: deprecated
plugin_id: str | None = None, # FIXME: deprecated
filters: storage.filters.FilterRequest | None = None,
) -> Any:
) -> utils.PaginatedResponse:
if (min_created_at is not None and max_created_at is not None) and min_created_at > max_created_at:
raise BadRequestError("min_created_at must be less than max_created_at")

Expand Down Expand Up @@ -137,16 +137,16 @@ def list(
max_created_at=max_created_at,
filters=f_req,
)

results = [Task(**t.model_dump()) for t in results]
return utils.paginate(request, results, count, offset, limit)

def get(self, task_id: uuid.UUID) -> Any:
def get(self, task_id: uuid.UUID) -> Task:
task = self.ctx.datastores.task_store.get_task(task_id)
if task is None:
raise NotFoundError(f"task not found, by task_id: {task_id}")
return task
return Task(**task.model_dump())

def patch(self, task_id: uuid.UUID, item: serializers.Task) -> Any:
def patch(self, task_id: uuid.UUID, item: TaskUpdate) -> TaskUpdate:
task_db = self.ctx.datastores.task_store.get_task(task_id)

if task_db is None:
Expand All @@ -161,7 +161,7 @@ def patch(self, task_id: uuid.UUID, item: serializers.Task) -> Any:

self.ctx.datastores.task_store.update_task(updated_task)

return updated_task
return TaskUpdate(**updated_task.model_dump())

def stats(self, scheduler_id: str | None = None) -> dict[str, dict[str, int]] | None:
return self.ctx.datastores.task_store.get_status_count_per_hour(scheduler_id)
4 changes: 4 additions & 0 deletions mula/scheduler/server/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .queue import Queue
from .schedule import Schedule, ScheduleCreate, ScheduleUpdate
from .scheduler import Scheduler
from .task import Task, TaskCreate, TaskStatus, TaskUpdate
14 changes: 14 additions & 0 deletions mula/scheduler/server/models/queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pydantic import BaseModel

from scheduler.models import Task


class Queue(BaseModel):
id: str
size: int
maxsize: int
item_type: str
allow_replace: bool
allow_updates: bool
allow_priority_updates: bool
pq: list[Task] | None = None
Loading
Loading