Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implements a task scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathieu Velten committed Jul 7, 2023
1 parent b07b14b commit 1c980d4
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 0 deletions.
1 change: 1 addition & 0 deletions changelog.d/15891.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implements a task scheduler.
2 changes: 2 additions & 0 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
from synapse.storage.databases.main.transactions import TransactionWorkerStore
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
from synapse.storage.databases.main.user_directory import UserDirectoryStore
Expand Down Expand Up @@ -144,6 +145,7 @@ class GenericWorkerStore(
TransactionWorkerStore,
LockStore,
SessionStore,
TaskSchedulerWorkerStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.
Expand Down
111 changes: 111 additions & 0 deletions synapse/handlers/task_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set

from attrs import evolve

from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonMapping, ScheduledTask, TaskState
from synapse.util.stringutils import random_string

if TYPE_CHECKING:
from synapse.server import HomeServer


class TaskSchedulerHandler:
SCHEDULING_INTERVAL_MS = 10 * 60 * 1000 # 10mn

def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self._is_master = hs.config.worker.worker_app is None
self.running_tasks: Set[str] = set()
self.actions: Dict[
str, Callable[[ScheduledTask], Awaitable[Optional[ScheduledTask]]]
] = {}

if self._is_master:
self.clock.looping_call(
run_as_background_process,
TaskSchedulerHandler.SCHEDULING_INTERVAL_MS,
"scheduled_tasks_loop",
self._scheduled_tasks_loop,
)

def bind_action(
self,
fct: Callable[[ScheduledTask], Awaitable[Optional[ScheduledTask]]],
action_name: str,
) -> None:
self.actions[action_name] = fct

async def schedule_task(
self,
action: str,
*,
resource_id: Optional[str] = None,
timestamp: Optional[int] = None,
params: Optional[JsonMapping] = None,
) -> str:
if action not in self.actions:
# TODO
raise SynapseError(400, "Test")
task_id = random_string(16)
state = TaskState.SCHEDULED
if timestamp is None or timestamp < self.clock.time_msec():
state = TaskState.RUNNING
timestamp = self.clock.time_msec()

task = ScheduledTask(
task_id,
action,
state,
resource_id,
timestamp,
params,
None,
)
await self.store.upsert_scheduled_task(task)
return task_id

async def update_task_state(
self,
task: ScheduledTask,
# error: Optional[str],
) -> None:
await self.store.upsert_scheduled_task(task)

async def get_task(self, id: str) -> Optional[ScheduledTask]:
return await self.store.get_scheduled_task(id)

async def get_tasks(
self, action: str, resource_id: Optional[str]
) -> List[ScheduledTask]:
return await self.store.get_scheduled_tasks(action, resource_id)

async def _scheduled_tasks_loop(self) -> None:
for task in await self.store.get_scheduled_tasks():
if task.id not in self.running_tasks:
state = task.state
if (
state == TaskState.SCHEDULED
and task.timestamp is not None
and task.timestamp < self.clock.time_msec()
):
state = TaskState.RUNNING

if state == TaskState.RUNNING:
await self.store.upsert_scheduled_task(task)
self._run_task(task)

def _run_task(self, task: ScheduledTask) -> None:
if task.action in self.actions:
fct = self.actions[task.action]

async def wrapper() -> None:
updated_task = await fct(task)
if updated_task is None:
updated_task = evolve(task, state=TaskState.COMPLETE)
await self.update_task_state(updated_task)

run_as_background_process(task.action, wrapper)
self.running_tasks.add(task.id)
6 changes: 6 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.task_scheduler import TaskSchedulerHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.http.client import (
Expand Down Expand Up @@ -242,6 +243,7 @@ class HomeServer(metaclass=abc.ABCMeta):
"profile",
"room_forgetter",
"stats",
"task_scheduler",
]

# This is overridden in derived application classes
Expand Down Expand Up @@ -912,3 +914,7 @@ def get_request_ratelimiter(self) -> RequestRatelimiter:
def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
"""Usage metrics shared between phone home stats and the prometheus exporter."""
return CommonUsageMetricsManager(self)

@cache_in_self
def get_task_scheduler_handler(self) -> TaskSchedulerHandler:
return TaskSchedulerHandler(self)
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from .stats import StatsStore
from .stream import StreamWorkerStore
from .tags import TagsStore
from .task_scheduler import TaskSchedulerWorkerStore
from .transactions import TransactionWorkerStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
Expand Down Expand Up @@ -127,6 +128,7 @@ class DataStore(
CacheInvalidationWorkerStore,
LockStore,
SessionStore,
TaskSchedulerWorkerStore,
):
def __init__(
self,
Expand Down
100 changes: 100 additions & 0 deletions synapse/storage/databases/main/task_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import ScheduledTask, TaskState

if TYPE_CHECKING:
from synapse.server import HomeServer


class TaskSchedulerWorkerStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

@staticmethod
def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
row["state"] = TaskState(row["state"])
if row["params"] is not None:
row["params"] = json.loads(row["params"])
if row["result"] is not None:
row["result"] = json.loads(row["result"])
return ScheduledTask(**row)

async def get_scheduled_tasks(
self, action: Optional[str] = None, resource_id: Optional[str] = None
) -> List[ScheduledTask]:
keyvalues = {}
if action:
keyvalues["action"] = action
if resource_id:
keyvalues["resource_id"] = resource_id

rows = await self.db_pool.simple_select_list(
table="scheduled_tasks",
keyvalues=keyvalues,
retcols=(
"id",
"action",
"state",
"timestamp",
"resource_id",
"params",
"result",
# "error",
),
desc="get_scheduled_tasks",
)

return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows]

async def upsert_scheduled_task(self, task: ScheduledTask) -> None:
await self.db_pool.simple_upsert(
"scheduled_tasks",
{"id": task.id},
{
"action": task.action,
"state": task.state,
"resource_id": task.resource_id,
"timestamp": task.timestamp,
"params": None if task.params is None else json.dumps(task.params),
"result": None if task.result is None else json.dumps(task.result),
# "error": task.error,
},
desc="upsert_scheduled_task",
)

async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]:
row = await self.db_pool.simple_select_one(
table="scheduled_tasks",
keyvalues={"id": id},
retcols=(
"id",
"action",
"state",
"resource_id",
"timestamp",
"params",
"result",
# "error",
),
desc="get_scheduled_task",
)

return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None

async def delete_scheduled_task(self, id: str) -> bool:
return (
await self.db_pool.simple_delete(
"scheduled_tasks",
keyvalues={"id": id},
desc="delete_scheduled_task",
)
> 0
)
26 changes: 26 additions & 0 deletions synapse/storage/schema/main/delta/78/05_scheduled_tasks.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/* Copyright 2023 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

-- cf ScheduledTask docstring for the meaning of the fields.
CREATE TABLE IF NOT EXISTS scheduled_tasks(
id text PRIMARY KEY,
action text NOT NULL,
state text NOT NULL,
resource_id text,
timestamp bigint,
params text,
result text
-- error text
);
21 changes: 21 additions & 0 deletions synapse/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import abc
import re
import string
from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand Down Expand Up @@ -979,3 +980,23 @@ class UserProfile(TypedDict):
class RetentionPolicy:
min_lifetime: Optional[int] = None
max_lifetime: Optional[int] = None


class TaskState(str, Enum):
SCHEDULED = "scheduled"
RUNNING = "running"
COMPLETE = "complete"
FAILED = "failed"
ABORTED = "aborted"


@attr.s(auto_attribs=True, frozen=True, slots=True)
class ScheduledTask:
id: str
action: str
state: TaskState
resource_id: Optional[str]
timestamp: Optional[int]
params: Optional[JsonMapping]
result: Optional[JsonDict]
# error: Optional[str]
47 changes: 47 additions & 0 deletions tests/handlers/test_task_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Optional

from attrs import evolve

from twisted.test.proto_helpers import MemoryReactor

from synapse.server import HomeServer
from synapse.types import ScheduledTask, TaskState
from synapse.util import Clock

from tests import unittest


class TestTaskScheduler(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_task_scheduler_handler()
self.handler.bind_action(self._test_task, "test_action")

async def _test_task(self, task: ScheduledTask) -> Optional[ScheduledTask]:
if task.params:
val = task.params.get("val")
task = evolve(task, state=TaskState.COMPLETE, result={"val": val})
return task
return None

def test_schedule_task(self) -> None:
timestamp = self.clock.time_msec() + 5 * 60 * 1000
task_id = self.get_success(
self.handler.schedule_task(
"test_action",
timestamp=timestamp,
params={"val": 1},
)
)

running_task = self.get_success(self.handler.get_task(task_id))
assert running_task is not None
self.assertEqual(running_task.state, TaskState.SCHEDULED)
self.assertIsNone(running_task.result)

self.reactor.advance(20 * 60)

running_task = self.get_success(self.handler.get_task(task_id))
assert running_task is not None
self.assertEqual(running_task.state, TaskState.COMPLETE)
assert running_task.result is not None
self.assertTrue(running_task.result.get("val") == 1)

0 comments on commit 1c980d4

Please sign in to comment.