Skip to content

Commit

Permalink
[core][feat] Merge deferred edges via API (#2136)
Browse files Browse the repository at this point in the history
  • Loading branch information
aquamatthias authored Jul 16, 2024
1 parent 327e6a0 commit 526b549
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 60 deletions.
6 changes: 3 additions & 3 deletions fixcore/fixcore/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from urllib3.exceptions import HTTPWarning

from fixcore import version
from fixcore.action_handlers.merge_outer_edge_handler import MergeOuterEdgesHandler
from fixcore.action_handlers.merge_deferred_edge_handler import MergeDeferredEdgesHandler
from fixcore.analytics import CoreEvent, NoEventSender
from fixcore.analytics.posthog import PostHogEventSender
from fixcore.analytics.recurrent_events import emit_recurrent_events
Expand Down Expand Up @@ -228,8 +228,8 @@ async def direct_tenant(deps: TenantDependencies) -> None:
)
deps.add(ServiceNames.graph_manager, GraphManager(db, config, core_config_handler, task_handler))
deps.add(
ServiceNames.merge_outer_edges_handler,
MergeOuterEdgesHandler(message_bus, subscriptions, task_handler, db, model),
ServiceNames.merge_deferred_edges_handler,
MergeDeferredEdgesHandler(message_bus, subscriptions, task_handler, db, model),
)
deps.add(
ServiceNames.event_emitter_periodic,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from attr import frozen

from fixcore.db.model import QueryModel
from fixcore.message_bus import MessageBus, Action
import logging
Expand All @@ -21,10 +23,17 @@
log = logging.getLogger(__name__)

subscriber_id = SubscriberId("fixcore")
merge_outer_edges = "merge_outer_edges"
merge_deferred_edges = "merge_deferred_edges"


@frozen
class DeferredMergeResult:
    """Immutable summary returned by MergeDeferredEdgesHandler.merge_deferred_edges."""

    # number of deferred edge definitions that were inspected; an edge is counted
    # even when one of its nodes could not be resolved
    processed: int
    # number of edges reported as updated by the graph database
    updated: int
    # number of edges reported as deleted by the graph database
    deleted: int

class MergeOuterEdgesHandler(Service):

class MergeDeferredEdgesHandler(Service):
def __init__(
self,
message_bus: MessageBus,
Expand All @@ -35,18 +44,21 @@ def __init__(
):
super().__init__()
self.message_bus = message_bus
self.merge_outer_edges_listener: Optional[Task[None]] = None
self.merge_deferred_edges_listener: Optional[Task[None]] = None
self.subscription_handler = subscription_handler
self.subscriber: Optional[Subscriber] = None
self.task_handler_service = task_handler_service
self.db_access = db_access
self.model_handler = model_handler

async def merge_outer_edges(self, task_id: TaskId) -> Tuple[int, int]:
async def merge_deferred_edges(self, task_ids: List[TaskId]) -> DeferredMergeResult:
deferred_outer_edge_db = self.db_access.deferred_outer_edge_db
pending_edges = await deferred_outer_edge_db.all_for_task(task_id)
pending_edges = []
for task_id in task_ids:
pending_edges.extend(await deferred_outer_edge_db.all_for_task(task_id))
if pending_edges:
first = pending_edges[0]
processed = 0
first = min(pending_edges, key=lambda x: x.created_at)
graph_db = self.db_access.get_graph_db(first.graph)
model = await self.model_handler.load_model(first.graph)

Expand Down Expand Up @@ -76,49 +88,43 @@ async def find_node_id(selector: NodeSelector) -> Optional[NodeId]:
for edge in pending_edge.edges:
from_id = await find_node_id(edge.from_node)
to_id = await find_node_id(edge.to_node)
processed += 1
if from_id and to_id:
edges.append((from_id, to_id, edge.edge_type))

# apply edges in graph
updated, deleted = await graph_db.update_deferred_edges(edges, first.created_at)

log.info(
f"MergeOuterEdgesHandler: updated {updated}/{len(edges)},"
f" deleted {deleted} edges in task id {task_id}"
)

return updated, deleted
# delete processed edge definitions
for task_id in task_ids:
await deferred_outer_edge_db.delete_for_task(task_id)
log.info(f"DeferredEdges: {len(edges)} edges: {updated} updated, {deleted} deleted. ({task_ids})")
return DeferredMergeResult(processed, updated, deleted)
else:
log.info(f"MergeOuterEdgesHandler: no pending edges for task id {task_id} found.")

return 0, 0

async def mark_done(self, task_id: TaskId) -> None:
deferred_outer_edge_db = self.db_access.deferred_outer_edge_db
await deferred_outer_edge_db.delete_for_task(task_id)
log.info(f"MergeOuterEdgesHandler: no pending edges found. ({task_ids})")
return DeferredMergeResult(0, 0, 0)

async def __handle_events(self, subscription_done: Future[None]) -> None:
async with self.message_bus.subscribe(subscriber_id, [merge_outer_edges]) as events:
async with self.message_bus.subscribe(subscriber_id, [merge_deferred_edges]) as events:
subscription_done.set_result(None)
while True:
event = await events.get()
if isinstance(event, Action) and event.message_type == merge_outer_edges:
await self.merge_outer_edges(event.task_id)
await self.mark_done(event.task_id)
if isinstance(event, Action) and event.message_type == merge_deferred_edges:
await self.merge_deferred_edges([event.task_id])
await self.task_handler_service.handle_action_done(event.done(subscriber_id))

async def start(self) -> None:
subscription_done = asyncio.get_event_loop().create_future()
self.subscriber = await self.subscription_handler.add_subscription(
subscriber_id, merge_outer_edges, True, timedelta(seconds=30)
subscriber_id, merge_deferred_edges, True, timedelta(seconds=30)
)
self.merge_outer_edges_listener = asyncio.create_task(
self.merge_deferred_edges_listener = asyncio.create_task(
self.__handle_events(subscription_done), name=subscriber_id
)
await subscription_done

async def stop(self) -> None:
if self.merge_outer_edges_listener:
if self.merge_deferred_edges_listener:
with suppress(Exception):
self.merge_outer_edges_listener.cancel()
self.merge_deferred_edges_listener.cancel()
if self.subscriber:
await self.subscription_handler.remove_subscription(subscriber_id, merge_outer_edges)
await self.subscription_handler.remove_subscription(subscriber_id, merge_deferred_edges)
2 changes: 1 addition & 1 deletion fixcore/fixcore/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class ServiceNames:
infra_apps_runtime = "infra_apps_runtime"
inspector = "inspector"
jwt_signing_key_holder = "jwt_signing_key_holder"
merge_outer_edges_handler = "merge_outer_edges_handler"
merge_deferred_edges_handler = "merge_deferred_edges_handler"
message_bus = "message_bus"
model_handler = "model_handler"
scheduler = "scheduler"
Expand Down
36 changes: 36 additions & 0 deletions fixcore/fixcore/static/api-doc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,42 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/GraphUpdate"
/graph/{graph_id}/merge/deferred_edges:
post:
summary: "Merge deferred edges for the given task ids"
description: |
All existing deferred edges will be replaced by the definition of all deferred edges of the given task ids.
We might be able to track deferred edges more specifically in the future.
tags:
- graph_management
parameters:
- $ref: "#/components/parameters/graph_id"
requestBody:
description: "The task ids to merge."
required: true
content:
application/json:
schema:
type: array
items:
type: string
responses:
"200":
description: "Returns a summary of the actions that have been applied."
content:
application/json:
schema:
type: object
properties:
processed:
type: integer
description: Number of processed edges
updated:
type: integer
description: Number of updated edges
deleted:
type: integer
description: Number of deleted edges
/graph/{graph_id}/batch/merge:
post:
summary: "Merge a given graph with the existing graph under marked merge nodes as batch update."
Expand Down
13 changes: 9 additions & 4 deletions fixcore/fixcore/task/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,12 +610,17 @@ def workflow(name: TaskDescriptorId, steps: List[Step]) -> Workflow:
trigger.append(TimeTrigger(wf_config.schedule))
return Workflow(uid=name, name=name, steps=steps, triggers=trigger, on_surpass=TaskSurpassBehaviour.Wait)

collect_steps = [
pre_collect = [
Step("pre_collect", PerformAction("pre_collect"), timedelta(seconds=10)),
Step("collect", PerformAction("collect"), timedelta(seconds=10)),
Step("wait_for_graph_merged", WaitForCollectDone(), timedelta(minutes=10)),
Step("merge_outer_edges", PerformAction("merge_outer_edges"), timedelta(seconds=10)),
Step("post_collect", PerformAction("post_collect"), timedelta(seconds=10)),
]
post_collect = [Step("post_collect", PerformAction("post_collect"), timedelta(seconds=10))]
collect_steps = pre_collect + post_collect
collect_with_merge = [
*pre_collect,
Step("merge_deferred_edges", PerformAction("merge_deferred_edges"), timedelta(seconds=10)),
*post_collect, # deferred edges are merged before post_collect
]
cleanup_steps = [
Step("pre_cleanup_plan", PerformAction("pre_cleanup_plan"), timedelta(seconds=10)),
Expand All @@ -634,7 +639,7 @@ def workflow(name: TaskDescriptorId, steps: List[Step]) -> Workflow:
workflow(TaskDescriptorId("collect"), collect_steps + metrics_steps),
workflow(TaskDescriptorId("cleanup"), cleanup_steps + metrics_steps),
workflow(TaskDescriptorId("metrics"), metrics_steps),
workflow(TaskDescriptorId("collect_and_cleanup"), collect_steps + cleanup_steps + metrics_steps),
workflow(TaskDescriptorId("collect_and_cleanup"), collect_with_merge + cleanup_steps + metrics_steps),
]

# endregion
11 changes: 10 additions & 1 deletion fixcore/fixcore/web/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from multidict import MultiDict
from networkx.readwrite import cytoscape_data

from fixcore.action_handlers.merge_deferred_edge_handler import MergeDeferredEdgesHandler
from fixcore.analytics import AnalyticsEvent
from fixcore.cli.command import alias_names
from fixcore.cli.model import (
Expand All @@ -74,7 +75,7 @@
from fixcore.console_renderer import ConsoleColorSystem, ConsoleRenderer
from fixcore.db.graphdb import GraphDB, HistoryChange
from fixcore.db.model import QueryModel
from fixcore.dependencies import Dependencies, TenantDependencies
from fixcore.dependencies import Dependencies, TenantDependencies, ServiceNames
from fixcore.dependencies import TenantDependencyProvider
from fixcore.error import NotFoundError, NotEnoughPermissions
from fixcore.ids import (
Expand Down Expand Up @@ -251,6 +252,7 @@ def __add_routes(self, prefix: str) -> None:
# maintain the graph
web.patch(prefix + "/graph/{graph_id}/nodes", require(self.update_nodes, r, w)),
web.post(prefix + "/graph/{graph_id}/merge", require(self.merge_graph, r, w)),
web.post(prefix + "/graph/{graph_id}/merge/deferred_edges", require(self.merge_deferred_edges, r, w)),
web.post(prefix + "/graph/{graph_id}/batch/merge", require(self.update_merge_graph_batch, r, w)),
web.get(prefix + "/graph/{graph_id}/batch", require(self.list_batches, r, w)),
web.post(prefix + "/graph/{graph_id}/batch/{batch_id}", require(self.commit_batch, r, w)),
Expand Down Expand Up @@ -1074,6 +1076,13 @@ async def create_graph(self, request: Request, deps: TenantDependencies) -> Stre
root = await graph.get_node(model, NodeId("root"))
return web.json_response(root)

async def merge_deferred_edges(self, request: Request, deps: TenantDependencies) -> StreamResponse:
    """Merge the deferred edges that belong to the task ids given in the request body.

    The JSON request body must be a list of task id strings. The ids are handed to
    the tenant's MergeDeferredEdgesHandler, and the counters of the resulting
    DeferredMergeResult are returned as an object with the keys
    ``processed``, ``updated`` and ``deleted``.

    Raises:
        web.HTTPBadRequest: if the body is not a JSON list of strings.
    """
    task_ids = await request.json()
    # Validate explicitly instead of using `assert`: assertions are removed when Python
    # runs with -O, and a malformed body is a client error rather than an internal one.
    if not isinstance(task_ids, list) or not all(isinstance(task_id, str) for task_id in task_ids):
        raise web.HTTPBadRequest(text="Expected a list of task ids")
    deferred_edges_handler = deps.service(ServiceNames.merge_deferred_edges_handler, MergeDeferredEdgesHandler)
    r = await deferred_edges_handler.merge_deferred_edges(task_ids)
    return await single_result(request, {"processed": r.processed, "updated": r.updated, "deleted": r.deleted})

async def merge_graph(self, request: Request, deps: TenantDependencies) -> StreamResponse:
graph_id = GraphName(request.match_info.get("graph_id", "fix"))
wait_for_result = request.query.get("wait_for_result", "true").lower() == "true"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
from datetime import timedelta
from typing import List

import pytest

from fixcore.action_handlers.merge_outer_edge_handler import MergeOuterEdgesHandler
from fixcore.action_handlers.merge_deferred_edge_handler import MergeDeferredEdgesHandler, merge_deferred_edges
from fixcore.db.db_access import DbAccess
from fixcore.db.deferredouteredgedb import DeferredOuterEdges
from fixcore.db.graphdb import ArangoGraphDB
Expand All @@ -19,38 +20,36 @@
from fixcore.util import utc
from tests.fixcore.db.graphdb_test import Foo, Bla, BaseResource

merge_outer_edges = "merge_outer_edges"


@pytest.mark.asyncio
async def test_handler_invocation(
merge_handler: MergeOuterEdgesHandler,
merge_handler: MergeDeferredEdgesHandler,
subscription_handler: SubscriptionHandler,
message_bus: MessageBus,
) -> None:
merge_called: asyncio.Future[TaskId] = asyncio.get_event_loop().create_future()
merge_called: asyncio.Future[List[TaskId]] = asyncio.get_event_loop().create_future()

def mocked_merge(task_id: TaskId) -> None:
merge_called.set_result(task_id)
def mocked_merge(task_ids: List[TaskId]) -> None:
merge_called.set_result(task_ids)

# monkey patching the merge_outer_edges method
# monkey patching the merge_deferred_edges method
# use setattr here, since assignment does not work in mypy https://github.com/python/mypy/issues/2427
setattr(merge_handler, "merge_outer_edges", mocked_merge)
setattr(merge_handler, "merge_deferred_edges", mocked_merge)

subscribers = await subscription_handler.list_subscriber_for(merge_outer_edges)
subscribers = await subscription_handler.list_subscriber_for(merge_deferred_edges)

assert subscribers[0].id == "fixcore"

task_id = TaskId("test_task_1")

await message_bus.emit(Action(merge_outer_edges, task_id, merge_outer_edges))
await message_bus.emit(Action(merge_deferred_edges, task_id, merge_deferred_edges))

assert await merge_called == task_id
assert await merge_called == [task_id]


@pytest.mark.asyncio
async def test_merge_outer_edges(
merge_handler: MergeOuterEdgesHandler, graph_db: ArangoGraphDB, foo_model: Model, db_access: DbAccess
async def test_merge_deferred_edges(
merge_handler: MergeDeferredEdgesHandler, graph_db: ArangoGraphDB, foo_model: Model, db_access: DbAccess
) -> None:
now = utc()

Expand All @@ -68,7 +67,7 @@ async def test_merge_outer_edges(
await db_access.deferred_outer_edge_db.update(
DeferredOuterEdges("t0", "c0", TaskId("task123"), now, graph_db.name, [e1])
)
await merge_handler.merge_outer_edges(TaskId("task123"))
await merge_handler.merge_deferred_edges([TaskId("task123")])

graph = await graph_db.search_graph(QueryModel(parse_query("is(graph_root) -default[0:]->"), foo_model))
assert graph.has_edge("id1", "id2")
Expand All @@ -82,7 +81,7 @@ async def test_merge_outer_edges(
await db_access.deferred_outer_edge_db.update(
DeferredOuterEdges("t1", "c1", TaskId("task456"), new_now, graph_db.name, [e2])
)
await merge_handler.merge_outer_edges(TaskId("task456"))
await merge_handler.merge_deferred_edges([TaskId("task456")])

graph = await graph_db.search_graph(QueryModel(parse_query("is(graph_root) -default[0:]->"), foo_model))
assert not graph.has_edge("id1", "id2")
Expand All @@ -96,11 +95,12 @@ async def test_merge_outer_edges(
await db_access.deferred_outer_edge_db.update(
DeferredOuterEdges("t2", "c4", TaskId("task789"), new_now_2, graph_db.name, [e2])
)
updated, deleted = await merge_handler.merge_outer_edges(TaskId("task789"))
r = await merge_handler.merge_deferred_edges([TaskId("task789")])
assert r.processed == 1
# here we also implicitly test that the timestamp was updated, because otherwise the edge
# would have an old timestamp and would be deleted
assert updated == 1
assert deleted == 0
assert r.updated == 1
assert r.deleted == 0
graph = await graph_db.search_graph(QueryModel(parse_query("is(graph_root) -default[0:]->"), foo_model))
assert not graph.has_edge("id1", "id2")
assert graph.has_edge("id2", "id1")
Expand Down
6 changes: 3 additions & 3 deletions fixcore/tests/fixcore/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pytest import fixture
from rich.console import Console

from fixcore.action_handlers.merge_outer_edge_handler import MergeOuterEdgesHandler
from fixcore.action_handlers.merge_deferred_edge_handler import MergeDeferredEdgesHandler
from fixcore.analytics import AnalyticsEventSender, InMemoryEventSender, NoEventSender
from fixcore.cli.cli import CLIService
from fixcore.cli.command import (
Expand Down Expand Up @@ -843,9 +843,9 @@ async def merge_handler(
task_handler: TaskHandlerService,
db_access: DbAccess,
foo_model: Model,
) -> AsyncGenerator[MergeOuterEdgesHandler, None]:
) -> AsyncGenerator[MergeDeferredEdgesHandler, None]:
model_handler = ModelHandlerStatic(foo_model)
handler = MergeOuterEdgesHandler(message_bus, subscription_handler, task_handler, db_access, model_handler)
handler = MergeDeferredEdgesHandler(message_bus, subscription_handler, task_handler, db_access, model_handler)
await handler.start()
yield handler
await handler.stop()
Expand Down

0 comments on commit 526b549

Please sign in to comment.