diff --git a/mula/scheduler/alembic/versions/0008_create_events_table.py b/mula/scheduler/alembic/versions/0008_create_events_table.py new file mode 100644 index 00000000000..a88807b2cbe --- /dev/null +++ b/mula/scheduler/alembic/versions/0008_create_events_table.py @@ -0,0 +1,40 @@ +"""Create events table + +Revision ID: 0008 +Revises: 0007 +Create Date: 2023-11-14 15:00:00.000000 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +import scheduler + +# revision identifiers, used by Alembic. +revision = "0008" +down_revision = "0007" +branch_labels = None +depends_on = None + + +def upgrade(): + # Add events table + op.create_table( + "events", + sa.Column("id", sa.Integer(), nullable=False, autoincrement=True), + sa.Column("task_id", scheduler.utils.datastore.GUID(), nullable=True), + sa.Column("type", sa.String(), nullable=True), + sa.Column("context", sa.String(), nullable=True), + sa.Column("event", sa.String(), nullable=True), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("data", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + op.create_index(op.f("ix_events_task_id"), "events", ["task_id"], unique=False) + + +def downgrade(): + # Drop the events table + op.drop_table("events") diff --git a/mula/scheduler/alembic/versions/0009_add_task_trigger.py b/mula/scheduler/alembic/versions/0009_add_task_trigger.py new file mode 100644 index 00000000000..65e7ce5e821 --- /dev/null +++ b/mula/scheduler/alembic/versions/0009_add_task_trigger.py @@ -0,0 +1,59 @@ +"""Add tasks trigger + +Revision ID: 0009 +Revises: 0008 +Create Date: 2023-11-14 15:00:00.000000 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0009" +down_revision = "0008" +branch_labels = None +depends_on = None + + +def upgrade(): + # Create the record_event function + op.execute( + sa.DDL( + """ + CREATE OR REPLACE FUNCTION record_event() + RETURNS TRIGGER AS + $$ + BEGIN + IF TG_OP = 'INSERT' THEN + INSERT INTO events (task_id, type, context, event, data) + VALUES (NEW.id, 'events.db', 'task', 'insert', row_to_json(NEW)); + ELSIF TG_OP = 'UPDATE' THEN + INSERT INTO events (task_id, type, context, event, data) + VALUES (NEW.id, 'events.db', 'task', 'update', row_to_json(NEW)); + END IF; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + """ + ) + ) + + # Create the triggers + op.execute( + sa.DDL( + """ + CREATE TRIGGER tasks_insert_update_trigger + AFTER INSERT OR UPDATE ON tasks + FOR EACH ROW + EXECUTE FUNCTION record_event(); + """ + ) + ) + + +def downgrade(): + # Drop the record_event function + op.execute(sa.DDL("DROP FUNCTION IF EXISTS record_event()")) + + # Drop the trigger + op.execute(sa.DDL("DROP TRIGGER IF EXISTS tasks_insert_update_trigger ON tasks")) diff --git a/mula/scheduler/context/context.py b/mula/scheduler/context/context.py index a59688551d0..4f4d547a43b 100644 --- a/mula/scheduler/context/context.py +++ b/mula/scheduler/context/context.py @@ -9,6 +9,7 @@ from scheduler import storage from scheduler.config import settings from scheduler.connectors import services +from scheduler.models import TaskDB from scheduler.utils import remove_trailing_slash @@ -83,9 +84,12 @@ def __init__(self) -> None: **{ storage.TaskStore.name: storage.TaskStore(dbconn), storage.PriorityQueueStore.name: storage.PriorityQueueStore(dbconn), + storage.EventStore.name: storage.EventStore(dbconn), } ) + TaskDB.set_event_store(self.datastores.event_store) + # Metrics collector registry self.metrics_registry: CollectorRegistry = CollectorRegistry() diff --git a/mula/scheduler/models/__init__.py b/mula/scheduler/models/__init__.py index c083e9c7af4..97c364da5ac 100644 --- a/mula/scheduler/models/__init__.py +++ b/mula/scheduler/models/__init__.py @@ -1,6 +1,6 @@ from .base import Base from .boefje import Boefje, BoefjeMeta -from .events import RawData, RawDataReceivedEvent +from .events import Event, EventDB, RawData, RawDataReceivedEvent from .health import ServiceHealth from .normalizer import Normalizer from .ooi import OOI, MutationOperationType, ScanProfile, ScanProfileMutation diff --git a/mula/scheduler/models/events.py b/mula/scheduler/models/events.py index 60278567a97..452ea4597b4 100644 --- a/mula/scheduler/models/events.py +++ b/mula/scheduler/models/events.py @@ -1,7 +1,15 @@ -from datetime import datetime +import uuid +from datetime import datetime, timezone -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import Column, DateTime, Integer, String +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.schema import Index +from sqlalchemy.sql import func +from scheduler.utils import GUID + +from .base import Base from .raw_data import RawData @@ -9,3 +17,44 @@ class RawDataReceivedEvent(BaseModel): created_at: datetime organization: str raw_data: RawData + + +class Event(BaseModel): + model_config = ConfigDict(from_attributes=True) + + task_id: uuid.UUID + + type: str + + context: str + + event: str + + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + data: dict + + +class EventDB(Base): + __tablename__ = "events" + + id = Column(Integer, primary_key=True) + + task_id = Column(GUID) + + type = Column(String) + + context = Column(String) + + event = Column(String) + + timestamp = Column(DateTime(timezone=True), nullable=False, server_default=func.now()) + + data = Column(JSONB, nullable=False) + + __table_args__ = ( + Index( + "ix_events_task_id", + task_id, + ), + ) diff --git a/mula/scheduler/models/tasks.py b/mula/scheduler/models/tasks.py index 7c1a1b9ac13..d477cc7b7cf 100644 --- a/mula/scheduler/models/tasks.py +++ b/mula/scheduler/models/tasks.py @@ -5,8 +5,9 @@ import mmh3 from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import Column, DateTime, Enum, String +from sqlalchemy import DDL, Column, DateTime, Enum, String, event from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.schema import Index from sqlalchemy.sql import func from sqlalchemy.sql.expression import text @@ -44,27 +45,6 @@ class TaskStatus(str, enum.Enum): CANCELLED = "cancelled" -class Task(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - - scheduler_id: str - - type: str - - p_item: PrioritizedItem - - status: TaskStatus - - created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - - modified_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) - - def __repr__(self): - return f"Task(id={self.id}, scheduler_id={self.scheduler_id}, type={self.type}, status={self.status})" - - class TaskDB(Base): __tablename__ = "tasks" @@ -103,6 +83,99 @@ class TaskDB(Base): ), ) + _event_store = None + + @classmethod + def set_event_store(cls, event_store): + cls._event_store = event_store + + @hybrid_property + def duration(self) -> float: + if self._event_store is None: + raise ValueError("EventStore instance is not set. Use TaskDB.set_event_store to set it.") + + return self._event_store.get_task_duration(self.id) + + @hybrid_property + def queued(self) -> float: + if self._event_store is None: + raise ValueError("EventStore instance is not set. Use TaskDB.set_event_store to set it.") + + return self._event_store.get_task_queued(self.id) + + @hybrid_property + def runtime(self) -> float: + if self._event_store is None: + raise ValueError("EventStore instance is not set. Use TaskDB.set_event_store to set it.") + + return self._event_store.get_task_runtime(self.id) + + @hybrid_property + def cpu(self) -> float: + if self._event_store is None: + raise ValueError("EventStore instance is not set. Use TaskDB.set_event_store to set it.") + + return self._event_store.get_task_cpu(self.id) + + @hybrid_property + def memory(self) -> float: + if self._event_store is None: + raise ValueError("EventStore instance is not set. Use TaskDB.set_event_store to set it.") + + return self._event_store.get_task_memory(self.id) + + @hybrid_property + def disk(self) -> float: + if self._event_store is None: + raise ValueError("EventStore instance is not set. Use TaskDB.set_event_store to set it.") + + return self._event_store.get_task_disk(self.id) + + @hybrid_property + def network(self) -> float: + if self._event_store is None: + raise ValueError("EventStore instance is not set. Use TaskDB.set_event_store to set it.") + + return self._event_store.get_task_network(self.id) + + +class Task(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + + scheduler_id: str + + type: str + + p_item: PrioritizedItem + + status: TaskStatus + + duration: Optional[float] = Field(None, alias="duration", readonly=True) + + queued: Optional[float] = Field(None, alieas="queued", readonly=True) + + runtime: Optional[float] = Field(None, alias="runtime", readonly=True) + + cpu: Optional[float] = Field(None, alias="cpu", readonly=True) + + memory: Optional[float] = Field(None, alias="memory", readonly=True) + + disk: Optional[float] = Field(None, alias="disk", readonly=True) + + network: Optional[float] = Field(None, alias="network", readonly=True) + + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + modified_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + def __repr__(self): + return f"Task(id={self.id}, scheduler_id={self.scheduler_id}, type={self.type}, status={self.status})" + + def model_dump_db(self): + return self.model_dump(exclude={"duration", "queued", "runtime", "cpu", "memory", "disk", "network"}) + class NormalizerTask(BaseModel): """NormalizerTask represent data needed for a Normalizer to run.""" @@ -144,3 +217,35 @@ def hash(self) -> str: return mmh3.hash_bytes(f"{self.input_ooi}-{self.boefje.id}-{self.organization}").hex() return mmh3.hash_bytes(f"{self.boefje.id}-{self.organization}").hex() + + +func_record_event = DDL( + """ + CREATE OR REPLACE FUNCTION record_event() + RETURNS TRIGGER AS + $$ + BEGIN + IF TG_OP = 'INSERT' THEN + INSERT INTO events (task_id, type, context, event, data) + VALUES (NEW.id, 'events.db', 'task', 'insert', row_to_json(NEW)); + ELSIF TG_OP = 'UPDATE' THEN + INSERT INTO events (task_id, type, context, event, data) + VALUES (NEW.id, 'events.db', 'task', 'update', row_to_json(NEW)); + END IF; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; +""" +) + +trigger_tasks_insert_update = DDL( + """ + CREATE TRIGGER tasks_insert_update_trigger + AFTER INSERT OR UPDATE ON tasks + FOR EACH ROW + EXECUTE FUNCTION record_event(); +""" +) + +event.listen(TaskDB.__table__, "after_create", func_record_event.execute_if(dialect="postgresql")) +event.listen(TaskDB.__table__, "after_create", trigger_tasks_insert_update.execute_if(dialect="postgresql")) diff --git a/mula/scheduler/server/server.py b/mula/scheduler/server/server.py index b585b1060a5..1f80387f60d 100644 --- a/mula/scheduler/server/server.py +++ b/mula/scheduler/server/server.py @@ -1,6 +1,6 @@ import datetime import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import fastapi import prometheus_client @@ -207,6 +207,14 @@ def __init__( description="Push an item to a queue", ) + self.api.add_api_route( + path="/events", + endpoint=self.list_events, + methods=["GET", "POST"], + response_model=Union[PaginatedResponse, models.Event], + description="List all task events", + ) + def root(self) -> Any: return None @@ -294,7 +302,7 @@ def list_tasks( if (min_created_at is not None and max_created_at is not None) and min_created_at > max_created_at: raise fastapi.HTTPException( status_code=fastapi.status.HTTP_400_BAD_REQUEST, - detail="min_date must be less than max_date", + detail="min_created_at cannot be greater than max_created_at", ) # FIXME: deprecated; backwards compatibility for rocky that uses the @@ -567,6 +575,85 @@ def push_queue(self, queue_id: str, item: models.PrioritizedItem) -> Any: return models.PrioritizedItem(**p_item.model_dump()) + def list_events( + self, + request: fastapi.Request, + task_id: Optional[str] = None, + type: Optional[str] = None, # noqa: A002 + context: Optional[str] = None, # noqa: A002 + event: Optional[str] = None, + min_timestamp: Optional[datetime.datetime] = None, + max_timestamp: Optional[datetime.datetime] = None, + offset: int = 0, + limit: int = 10, + filters: Optional[storage.filters.FilterRequest] = None, + item: Optional[models.Event] = None, + ) -> Any: + if item is not None and request.method == "POST": + created_event = fastapi.encoders.jsonable_encoder(self.create_event(item=item)) + + return fastapi.responses.JSONResponse( + status_code=status.HTTP_201_CREATED, + content=created_event, + ) + + if (min_timestamp is not None and max_timestamp is not None) and min_timestamp > max_timestamp: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_400_BAD_REQUEST, + detail="min_timestamp cannot be greater than max_timestamp", + ) + + try: + results, count = self.ctx.datastores.event_store.get_events( + task_id=task_id, + type=type, + context=context, + event=event, + min_timestamp=min_timestamp, + max_timestamp=max_timestamp, + offset=offset, + limit=limit, + filters=filters, + ) + except storage.filters.errors.FilterError as exc: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_400_BAD_REQUEST, + detail=str(exc), + ) from exc + except ValueError as exc: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_400_BAD_REQUEST, + detail=str(exc), + ) from exc + except Exception as exc: + self.logger.exception(exc) + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="failed to get events", + ) from exc + + return paginate(request, results, count, offset, limit) + + def create_event(self, item: models.Event) -> Any: + try: + event = models.Event(**item.dict()) + except Exception as exc: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_400_BAD_REQUEST, + detail=str(exc), + ) from exc + + try: + self.ctx.datastores.event_store.create_event(event) + except Exception as exc: + self.logger.exception(exc) + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="failed to create event", + ) from exc + + return event + def run(self) -> None: uvicorn.run( self.api, diff --git a/mula/scheduler/storage/__init__.py b/mula/scheduler/storage/__init__.py index 4f6d7155872..7ff034074ad 100644 --- a/mula/scheduler/storage/__init__.py +++ b/mula/scheduler/storage/__init__.py @@ -1,3 +1,4 @@ +from .event_store import EventStore from .filters import apply_filter from .pq_store import PriorityQueueStore from .storage import DBConn, retry diff --git a/mula/scheduler/storage/event_store.py b/mula/scheduler/storage/event_store.py new file mode 100644 index 00000000000..a03f3df8471 --- /dev/null +++ b/mula/scheduler/storage/event_store.py @@ -0,0 +1,272 @@ +from datetime import datetime +from typing import List, Optional, Tuple + +from sqlalchemy import exc + +from scheduler.models import Event, EventDB, TaskStatus + +from .filters import FilterRequest, apply_filter +from .storage import DBConn, retry + + +class EventStore: + name: str = "event_store" + + def __init__(self, dbconn: DBConn) -> None: + self.dbconn = dbconn + + @retry() + def get_events( + self, + task_id: Optional[str] = None, + type: Optional[str] = None, # noqa: A002 + context: Optional[str] = None, + event: Optional[str] = None, + min_timestamp: Optional[datetime] = None, + max_timestamp: Optional[datetime] = None, + offset: int = 0, + limit: int = 100, + filters: Optional[FilterRequest] = None, + ) -> Tuple[List[Event], int]: + with self.dbconn.session.begin() as session: + query = session.query(EventDB) + + if task_id is not None: + query = query.filter(EventDB.task_id == task_id) + + if type is not None: + query = query.filter(EventDB.type == type) + + if context is not None: + query = query.filter(EventDB.context == context) + + if event is not None: + query = query.filter(EventDB.event == event) + + if min_timestamp is not None: + query = query.filter(EventDB.timestamp >= min_timestamp) + + if max_timestamp is not None: + query = query.filter(EventDB.timestamp <= max_timestamp) + + if filters is not None: + query = apply_filter(EventDB, query, filters) + + try: + count = query.count() + events_orm = query.order_by(EventDB.timestamp.desc()).offset(offset).limit(limit).all() + except exc.ProgrammingError as e: + raise ValueError(f"Invalid filter: {e}") from e + + events = [Event.model_validate(event_orm) for event_orm in events_orm] + + return events, count + + @retry() + def create_event(self, event: Event) -> None: + with self.dbconn.session.begin() as session: + event_orm = EventDB(**event.model_dump()) + session.add(event_orm) + + created_event = Event.model_validate(event_orm) + + return created_event + + @retry() + def get_task_queued(self, task_id: str) -> float: + """Get task queued (how long has task been on queue) time in seconds""" + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + + with self.dbconn.session.begin() as session: + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.db") + .filter(EventDB.context == "task") + .filter(EventDB.event == "insert") + .filter(EventDB.data["status"].as_string() == TaskStatus.QUEUED.upper()) + .order_by(EventDB.timestamp.asc()) + ) + + result_start = query.first() + if result_start is not None: + start_time = result_start.timestamp + + # Get task event end time when status is completed or failed + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.db") + .filter(EventDB.context == "task") + .filter(EventDB.event == "update") + .filter(EventDB.data["status"].as_string() == TaskStatus.DISPATCHED.upper()) + .order_by(EventDB.timestamp.desc()) + ) + + result_end = query.first() + if result_end is not None: + end_time = result_end.timestamp + + if start_time is not None and end_time is not None: + return (end_time - start_time).total_seconds() + + return 0 + + @retry() + def get_task_runtime(self, task_id: str) -> float: + """Get task runtime in seconds""" + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + + with self.dbconn.session.begin() as session: + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.db") + .filter(EventDB.context == "task") + .filter(EventDB.event == "update") + .filter(EventDB.data["status"].as_string() == TaskStatus.DISPATCHED.upper()) + .order_by(EventDB.timestamp.asc()) + ) + + result_start = query.first() + if result_start is not None: + start_time = result_start.timestamp + + # Get task event end time when status is completed or failed + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.db") + .filter(EventDB.context == "task") + .filter(EventDB.event == "update") + .filter( + EventDB.data["status"].as_string().in_([TaskStatus.COMPLETED.upper(), TaskStatus.FAILED.upper()]) + ) + .order_by(EventDB.timestamp.desc()) + ) + + result_end = query.first() + if result_end is not None: + end_time = result_end.timestamp + + if start_time is not None and end_time is not None: + return (end_time - start_time).total_seconds() + + return 0 + + @retry() + def get_task_duration(self, task_id: str) -> float: + """Total duration of a task from start to finish in seconds""" + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + + with self.dbconn.session.begin() as session: + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.db") + .filter(EventDB.context == "task") + .filter(EventDB.event == "insert") + .filter(EventDB.data["status"].as_string() == TaskStatus.QUEUED.upper()) + .order_by(EventDB.timestamp.asc()) + ) + + result_start = query.first() + if result_start is not None: + start_time = result_start.timestamp + + # Get task event end time when status is completed or failed + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.db") + .filter(EventDB.context == "task") + .filter(EventDB.event == "update") + .filter( + EventDB.data["status"].as_string().in_([TaskStatus.COMPLETED.upper(), TaskStatus.FAILED.upper()]) + ) + .order_by(EventDB.timestamp.desc()) + ) + + result_end = query.first() + if result_end is not None: + end_time = result_end.timestamp + + if start_time is not None and end_time is not None: + return (end_time - start_time).total_seconds() + + return 0 + + @retry() + def get_task_cpu(self, task_id: str) -> float: + with self.dbconn.session.begin() as session: + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.runner") + .filter(EventDB.context == "task") + .filter(EventDB.event == "cpu") + .order_by(EventDB.timestamp.desc()) + ) + + result = query.first() + if result is not None: + return result.data["cpu"].as_float() + + return 0 + + @retry() + def get_task_memory(self, task_id: str) -> float: + with self.dbconn.session.begin() as session: + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.runner") + .filter(EventDB.context == "task") + .filter(EventDB.event == "memory") + .order_by(EventDB.timestamp.desc()) + ) + + result = query.first() + if result is not None: + return result.data["memory"].as_float() + + return 0 + + @retry() + def get_task_disk(self, task_id: str) -> float: + with self.dbconn.session.begin() as session: + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.runner") + .filter(EventDB.context == "task") + .filter(EventDB.event == "disk") + .order_by(EventDB.timestamp.desc()) + ) + + result = query.first() + if result is not None: + return result.data["disk"].as_float() + + return 0 + + @retry() + def get_task_network(self, task_id: str) -> float: + with self.dbconn.session.begin() as session: + query = ( + session.query(EventDB) + .filter(EventDB.task_id == task_id) + .filter(EventDB.type == "events.runner") + .filter(EventDB.context == "task") + .filter(EventDB.event == "network") + .order_by(EventDB.timestamp.desc()) + ) + + result = query.first() + if result is not None: + return result.data["network"].as_float() + + return 0 diff --git a/mula/scheduler/storage/pq_store.py b/mula/scheduler/storage/pq_store.py index 5df8f1112e5..4d643f682f4 100644 --- a/mula/scheduler/storage/pq_store.py +++ b/mula/scheduler/storage/pq_store.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple from uuid import UUID -from scheduler import models +from scheduler.models import PrioritizedItem, PrioritizedItemDB from .filters import FilterRequest, apply_filter from .storage import DBConn, retry @@ -14,38 +14,36 @@ def __init__(self, dbconn: DBConn) -> None: self.dbconn = dbconn @retry() - def pop(self, scheduler_id: str, filters: Optional[FilterRequest] = None) -> Optional[models.PrioritizedItem]: + def pop(self, scheduler_id: str, filters: Optional[FilterRequest] = None) -> Optional[PrioritizedItem]: with self.dbconn.session.begin() as session: - query = session.query(models.PrioritizedItemDB).filter( - models.PrioritizedItemDB.scheduler_id == scheduler_id - ) + query = session.query(PrioritizedItemDB).filter(PrioritizedItemDB.scheduler_id == scheduler_id) if filters is not None: - query = apply_filter(models.PrioritizedItemDB, query, filters) + query = apply_filter(PrioritizedItemDB, query, filters) item_orm = query.first() if item_orm is None: return None - return models.PrioritizedItem.model_validate(item_orm) + return PrioritizedItem.model_validate(item_orm) @retry() - def push(self, scheduler_id: str, item: models.PrioritizedItem) -> Optional[models.PrioritizedItem]: + def push(self, scheduler_id: str, item: PrioritizedItem) -> Optional[PrioritizedItem]: with self.dbconn.session.begin() as session: - item_orm = models.PrioritizedItemDB(**item.model_dump()) + item_orm = PrioritizedItemDB(**item.model_dump()) session.add(item_orm) - return models.PrioritizedItem.model_validate(item_orm) + return PrioritizedItem.model_validate(item_orm) @retry() - def peek(self, scheduler_id: str, index: int) -> Optional[models.PrioritizedItem]: + def peek(self, scheduler_id: str, index: int) -> Optional[PrioritizedItem]: with self.dbconn.session.begin() as session: item_orm = ( - session.query(models.PrioritizedItemDB) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .order_by(models.PrioritizedItemDB.priority.asc()) - .order_by(models.PrioritizedItemDB.created_at.asc()) + session.query(PrioritizedItemDB) + .filter(PrioritizedItemDB.scheduler_id == scheduler_id) + .order_by(PrioritizedItemDB.priority.asc()) + .order_by(PrioritizedItemDB.created_at.asc()) .offset(index) .first() ) @@ -53,15 +51,15 @@ def peek(self, scheduler_id: str, index: int) -> Optional[models.PrioritizedItem if item_orm is None: return None - return models.PrioritizedItem.model_validate(item_orm) + return PrioritizedItem.model_validate(item_orm) @retry() - def update(self, scheduler_id: str, item: models.PrioritizedItem) -> None: + def update(self, scheduler_id: str, item: PrioritizedItem) -> None: with self.dbconn.session.begin() as session: ( - session.query(models.PrioritizedItemDB) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .filter(models.PrioritizedItemDB.id == item.id) + session.query(PrioritizedItemDB) + .filter(PrioritizedItemDB.scheduler_id == scheduler_id) + .filter(PrioritizedItemDB.id == item.id) .update(item.model_dump()) ) @@ -69,45 +67,37 @@ def update(self, scheduler_id: str, item: models.PrioritizedItem) -> None: def remove(self, scheduler_id: str, item_id: UUID) -> None: with self.dbconn.session.begin() as session: ( - session.query(models.PrioritizedItemDB) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .filter(models.PrioritizedItemDB.id == str(item_id)) + session.query(PrioritizedItemDB) + .filter(PrioritizedItemDB.scheduler_id == scheduler_id) + .filter(PrioritizedItemDB.id == str(item_id)) .delete() ) @retry() - def get(self, scheduler_id, item_id: UUID) -> Optional[models.PrioritizedItem]: + def get(self, scheduler_id, item_id: UUID) -> Optional[PrioritizedItem]: with self.dbconn.session.begin() as session: item_orm = ( - session.query(models.PrioritizedItemDB) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .filter(models.PrioritizedItemDB.id == str(item_id)) + session.query(PrioritizedItemDB) + .filter(PrioritizedItemDB.scheduler_id == scheduler_id) + .filter(PrioritizedItemDB.id == str(item_id)) .first() ) if item_orm is None: return None - return models.PrioritizedItem.model_validate(item_orm) + return PrioritizedItem.model_validate(item_orm) @retry() def empty(self, scheduler_id: str) -> bool: with self.dbconn.session.begin() as session: - count = ( - session.query(models.PrioritizedItemDB) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .count() - ) + count = session.query(PrioritizedItemDB).filter(PrioritizedItemDB.scheduler_id == scheduler_id).count() return count == 0 @retry() def qsize(self, scheduler_id: str) -> int: with self.dbconn.session.begin() as session: - count = ( - session.query(models.PrioritizedItemDB) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .count() - ) + count = session.query(PrioritizedItemDB).filter(PrioritizedItemDB.scheduler_id == scheduler_id).count() return count @@ -116,52 +106,42 @@ def get_items( self, scheduler_id: str, filters: Optional[FilterRequest], - ) -> Tuple[List[models.PrioritizedItem], int]: + ) -> Tuple[List[PrioritizedItem], int]: with self.dbconn.session.begin() as session: - query = session.query(models.PrioritizedItemDB).filter( - models.PrioritizedItemDB.scheduler_id == scheduler_id - ) + query = session.query(PrioritizedItemDB).filter(PrioritizedItemDB.scheduler_id == scheduler_id) if filters is not None: - query = apply_filter(models.PrioritizedItemDB, query, filters) + query = apply_filter(PrioritizedItemDB, query, filters) count = query.count() items_orm = query.all() - return ([models.PrioritizedItem.model_validate(item_orm) for item_orm in items_orm], count) + return ([PrioritizedItem.model_validate(item_orm) for item_orm in items_orm], count) @retry() - def get_item_by_hash(self, scheduler_id: str, item_hash: str) -> Optional[models.PrioritizedItem]: + def get_item_by_hash(self, scheduler_id: str, item_hash: str) -> Optional[PrioritizedItem]: with self.dbconn.session.begin() as session: item_orm = ( - session.query(models.PrioritizedItemDB) - .order_by(models.PrioritizedItemDB.created_at.desc()) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .filter(models.PrioritizedItemDB.hash == item_hash) + session.query(PrioritizedItemDB) + .order_by(PrioritizedItemDB.created_at.desc()) + .filter(PrioritizedItemDB.scheduler_id == scheduler_id) + .filter(PrioritizedItemDB.hash == item_hash) .first() ) if item_orm is None: return None - return models.PrioritizedItem.model_validate(item_orm) + return PrioritizedItem.model_validate(item_orm) @retry() - def get_items_by_scheduler_id(self, scheduler_id: str) -> List[models.PrioritizedItem]: + def get_items_by_scheduler_id(self, scheduler_id: str) -> List[PrioritizedItem]: with self.dbconn.session.begin() as session: - items_orm = ( - session.query(models.PrioritizedItemDB) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .all() - ) + items_orm = session.query(PrioritizedItemDB).filter(PrioritizedItemDB.scheduler_id == scheduler_id).all() - return [models.PrioritizedItem.model_validate(item_orm) for item_orm in items_orm] + return [PrioritizedItem.model_validate(item_orm) for item_orm in items_orm] @retry() def clear(self, scheduler_id: str) -> None: with self.dbconn.session.begin() as session: - ( - session.query(models.PrioritizedItemDB) - .filter(models.PrioritizedItemDB.scheduler_id == scheduler_id) - .delete() - ) + (session.query(PrioritizedItemDB).filter(PrioritizedItemDB.scheduler_id == scheduler_id).delete()) diff --git a/mula/scheduler/storage/task_store.py b/mula/scheduler/storage/task_store.py index d7256e699cf..61a7863d6be 100644 --- a/mula/scheduler/storage/task_store.py +++ b/mula/scheduler/storage/task_store.py @@ -3,7 +3,7 @@ from sqlalchemy import exc, func -from scheduler import models +from scheduler.models import Task, TaskDB, TaskStatus from .filters import FilterRequest, apply_filter from .storage import DBConn, retry @@ -26,104 +26,104 @@ def get_tasks( filters: Optional[FilterRequest] = None, offset: int = 0, limit: int = 100, - ) -> Tuple[List[models.Task], int]: + ) -> Tuple[List[Task], int]: with self.dbconn.session.begin() as session: - query = session.query(models.TaskDB) + query = session.query(TaskDB) if scheduler_id is not None: - query = query.filter(models.TaskDB.scheduler_id == scheduler_id) + query = query.filter(TaskDB.scheduler_id == scheduler_id) if task_type is not None: - query = query.filter(models.TaskDB.type == task_type) + query = query.filter(TaskDB.type == task_type) if status is not None: - query = query.filter(models.TaskDB.status == models.TaskStatus(status).name) + query = query.filter(TaskDB.status == TaskStatus(status).name) if min_created_at is not None: - query = query.filter(models.TaskDB.created_at >= min_created_at) + query = query.filter(TaskDB.created_at >= min_created_at) if max_created_at is not None: - query = query.filter(models.TaskDB.created_at <= max_created_at) + query = query.filter(TaskDB.created_at <= max_created_at) if filters is not None: - query = apply_filter(models.TaskDB, query, filters) + query = apply_filter(TaskDB, query, filters) try: count = query.count() - tasks_orm = query.order_by(models.TaskDB.created_at.desc()).offset(offset).limit(limit).all() + tasks_orm = query.order_by(TaskDB.created_at.desc()).offset(offset).limit(limit).all() except exc.ProgrammingError as e: raise ValueError(f"Invalid filter: {e}") from e - tasks = [models.Task.model_validate(task_orm) for task_orm in tasks_orm] + tasks = [Task.model_validate(task_orm) for task_orm in tasks_orm] return tasks, count @retry() - def get_task_by_id(self, task_id: str) -> Optional[models.Task]: + def get_task_by_id(self, task_id: str) -> Optional[Task]: with self.dbconn.session.begin() as session: - task_orm = session.query(models.TaskDB).filter(models.TaskDB.id == task_id).first() + task_orm = session.query(TaskDB).filter(TaskDB.id == task_id).first() if task_orm is None: return None - task = models.Task.model_validate(task_orm) + task = Task.model_validate(task_orm) return task @retry() - def get_tasks_by_hash(self, task_hash: str) -> Optional[List[models.Task]]: + def get_tasks_by_hash(self, task_hash: str) -> Optional[List[Task]]: with self.dbconn.session.begin() as session: tasks_orm = ( - session.query(models.TaskDB) - .filter(models.TaskDB.p_item["hash"].as_string() == task_hash) - .order_by(models.TaskDB.created_at.desc()) + session.query(TaskDB) + .filter(TaskDB.p_item["hash"].as_string() == task_hash) + .order_by(TaskDB.created_at.desc()) .all() ) if tasks_orm is None: return None - tasks = [models.Task.model_validate(task_orm) for task_orm in tasks_orm] + tasks = [Task.model_validate(task_orm) for task_orm in tasks_orm] return tasks @retry() - def get_latest_task_by_hash(self, task_hash: str) -> Optional[models.Task]: + def get_latest_task_by_hash(self, task_hash: str) -> Optional[Task]: with self.dbconn.session.begin() as session: task_orm = ( - session.query(models.TaskDB) - .filter(models.TaskDB.p_item["hash"].as_string() == task_hash) - .order_by(models.TaskDB.created_at.desc()) + session.query(TaskDB) + .filter(TaskDB.p_item["hash"].as_string() == task_hash) + .order_by(TaskDB.created_at.desc()) .first() ) if task_orm is None: return None - task = models.Task.model_validate(task_orm) + task = Task.model_validate(task_orm) return task @retry() - def create_task(self, task: models.Task) -> Optional[models.Task]: + def create_task(self, task: Task) -> Optional[Task]: with self.dbconn.session.begin() as session: - task_orm = models.TaskDB(**task.model_dump()) + task_orm = TaskDB(**task.model_dump_db()) session.add(task_orm) - created_task = models.Task.model_validate(task_orm) + created_task = Task.model_validate(task_orm) return created_task @retry() - def update_task(self, task: models.Task) -> None: + def update_task(self, task: Task) -> None: with self.dbconn.session.begin() as session: - (session.query(models.TaskDB).filter(models.TaskDB.id == task.id).update(task.model_dump())) + (session.query(TaskDB).filter(TaskDB.id == task.id).update(task.model_dump_db())) @retry() def cancel_tasks(self, scheduler_id: str, task_ids: List[str]) -> None: with self.dbconn.session.begin() as session: - session.query(models.TaskDB).filter( - models.TaskDB.scheduler_id == scheduler_id, models.TaskDB.id.in_(task_ids) - ).update({"status": models.TaskStatus.CANCELLED.name}) + session.query(TaskDB).filter(TaskDB.scheduler_id == scheduler_id, TaskDB.id.in_(task_ids)).update( + {"status": TaskStatus.CANCELLED.name} + ) @retry() def get_status_count_per_hour( @@ -133,26 +133,26 @@ def get_status_count_per_hour( with self.dbconn.session.begin() as session: query = ( session.query( - func.DATE_TRUNC("hour", models.TaskDB.modified_at).label("hour"), - models.TaskDB.status, - func.count(models.TaskDB.id).label("count"), + func.DATE_TRUNC("hour", TaskDB.modified_at).label("hour"), + TaskDB.status, + func.count(TaskDB.id).label("count"), ) .filter( - models.TaskDB.modified_at >= datetime.now(timezone.utc) - timedelta(hours=24), + TaskDB.modified_at >= datetime.now(timezone.utc) - timedelta(hours=24), ) - .group_by("hour", models.TaskDB.status) - .order_by("hour", models.TaskDB.status) + .group_by("hour", TaskDB.status) + .order_by("hour", TaskDB.status) ) if scheduler_id is not None: - query = query.filter(models.TaskDB.scheduler_id == scheduler_id) + query = query.filter(TaskDB.scheduler_id == scheduler_id) results = query.all() response: Dict[str, Dict[str, int]] = {} for row in results: date, status, task_count = row - response.setdefault(date.isoformat(), {k.value: 0 for k in models.TaskStatus}).update( + response.setdefault(date.isoformat(), {k.value: 0 for k in TaskStatus}).update( {status.value: task_count} ) response[date.isoformat()].update({"total": response[date.isoformat()].get("total", 0) + task_count}) @@ -163,17 +163,17 @@ def get_status_count_per_hour( def get_status_counts(self, scheduler_id: Optional[str] = None) -> Optional[Dict[str, int]]: with self.dbconn.session.begin() as session: query = ( - session.query(models.TaskDB.status, func.count(models.TaskDB.id).label("count")) - .group_by(models.TaskDB.status) - .order_by(models.TaskDB.status) + session.query(TaskDB.status, func.count(TaskDB.id).label("count")) + .group_by(TaskDB.status) + .order_by(TaskDB.status) ) if scheduler_id is not None: - query = query.filter(models.TaskDB.scheduler_id == scheduler_id) + query = query.filter(TaskDB.scheduler_id == scheduler_id) results = query.all() - response = {k.value: 0 for k in models.TaskStatus} + response = {k.value: 0 for k in TaskStatus} for row in results: status, task_count = row response[status.value] = task_count diff --git a/mula/scheduler/utils/__init__.py b/mula/scheduler/utils/__init__.py index 47047ddb189..8ca568238ef 100644 --- a/mula/scheduler/utils/__init__.py +++ b/mula/scheduler/utils/__init__.py @@ -1,4 +1,5 @@ from .datastore import GUID from .dict_utils import ExpiredError, ExpiringDict, deep_get from .functions import remove_trailing_slash +from .json import UUIDEncoder from .thread import ThreadRunner diff --git a/mula/scheduler/utils/json.py b/mula/scheduler/utils/json.py new file mode 100644 index 00000000000..bec91f1a326 --- /dev/null +++ b/mula/scheduler/utils/json.py @@ -0,0 +1,9 @@ +import json +from uuid import UUID + + +class UUIDEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, UUID): + return obj.hex + return json.JSONEncoder.default(self, obj) diff --git a/mula/tests/integration/test_api.py b/mula/tests/integration/test_api.py index 4f766da0506..73592100f75 100644 --- a/mula/tests/integration/test_api.py +++ b/mula/tests/integration/test_api.py @@ -1,4 +1,5 @@ import copy +import json import unittest import uuid from datetime import datetime, timedelta, timezone @@ -6,7 +7,7 @@ from unittest import mock from fastapi.testclient import TestClient -from scheduler import config, models, server, storage +from scheduler import config, models, server, storage, utils from tests.factories import OrganisationFactory from tests.mocks import queue as mock_queue @@ -28,9 +29,12 @@ def setUp(self): **{ storage.TaskStore.name: storage.TaskStore(self.dbconn), storage.PriorityQueueStore.name: storage.PriorityQueueStore(self.dbconn), + storage.EventStore.name: storage.EventStore(self.dbconn), } ) + models.TaskDB.set_event_store(self.mock_ctx.datastores.event_store) + # Organisation self.organisation = OrganisationFactory() @@ -556,7 +560,7 @@ def test_get_tasks_min_greater_than_max_created_at(self): } response = self.client.get("/tasks", params=params) self.assertEqual(400, response.status_code) - self.assertEqual("min_date must be less than max_date", response.json().get("detail")) + self.assertEqual("min_created_at cannot be greater than max_created_at", response.json().get("detail")) def test_get_tasks_min_created_at_future(self): # Get tasks based on datetime for something in the future, should return 0 items @@ -618,3 +622,137 @@ def test_get_tasks_stats(self): response = self.client.get(f"/tasks/stats/{self.first_item_api.get('scheduler_id')}") self.assertEqual(200, response.status_code) + + +class APIEventsEndpointTestCase(APITemplateTestCase): + def setUp(self): + super().setUp() + + # Arrange + first_event = { + "item": models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + data={"test": "test"}, + ).model_dump() + } + + first_event_json = json.dumps(first_event, cls=utils.UUIDEncoder, default=str) + self.first_event_api = self.client.post("/events", data=first_event_json).json() + + second_event = { + "item": models.Event( + task_id=uuid.uuid4(), + type="events.app", + context="user", + event="login", + data={"foo": "bar"}, + ).model_dump() + } + + second_event_json = json.dumps(second_event, cls=utils.UUIDEncoder, default=str) + self.second_event_api = self.client.post("/events", data=second_event_json).json() + + def test_create_event(self): + # Arrange + event = { + "item": models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + data={"test": "test"}, + ).model_dump() + } + + event_json = json.dumps(event, cls=utils.UUIDEncoder, default=str) + + # Act + response = self.client.post("/events", data=event_json) + + # Assert + self.assertEqual(201, response.status_code) + self.assertEqual(str(event["item"]["task_id"]), response.json().get("task_id")) + + def test_list_events(self): + response = self.client.get("/events") + self.assertEqual(200, response.status_code) + self.assertEqual(2, response.json()["count"]) + self.assertEqual(2, len(response.json()["results"])) + + def test_list_events_task_id(self): + response = self.client.get(f"/events?task_id={self.first_event_api.get('task_id')}") + self.assertEqual(200, response.status_code) + self.assertEqual(1, response.json()["count"]) + self.assertEqual(1, len(response.json()["results"])) + self.assertEqual(self.first_event_api.get("task_id"), response.json()["results"][0]["task_id"]) + + def test_list_events_type(self): + response = self.client.get(f"/events?type={self.first_event_api.get('type')}") + self.assertEqual(200, response.status_code) + self.assertEqual(1, response.json()["count"]) + self.assertEqual(1, len(response.json()["results"])) + self.assertEqual(self.first_event_api.get("type"), response.json()["results"][0]["type"]) + + def test_list_events_context(self): + response = self.client.get(f"/events?context={self.first_event_api.get('context')}") + self.assertEqual(200, response.status_code) + self.assertEqual(1, response.json()["count"]) + self.assertEqual(1, len(response.json()["results"])) + self.assertEqual(self.first_event_api.get("context"), response.json()["results"][0]["context"]) + + def test_list_events_event(self): + response = self.client.get(f"/events?event={self.first_event_api.get('event')}") + self.assertEqual(200, response.status_code) + self.assertEqual(1, response.json()["count"]) + self.assertEqual(1, len(response.json()["results"])) + self.assertEqual(self.first_event_api.get("event"), response.json()["results"][0]["event"]) + + def test_list_events_min_timestamp(self): + response = self.client.get(f"/events?min_timestamp={self.first_event_api.get('timestamp')}") + self.assertEqual(200, response.status_code) + self.assertEqual(2, response.json()["count"]) + self.assertEqual(2, len(response.json()["results"])) + + response = self.client.get(f"/events?min_timestamp={self.second_event_api.get('timestamp')}") + self.assertEqual(200, response.status_code) + self.assertEqual(1, response.json()["count"]) + self.assertEqual(1, len(response.json()["results"])) + + def test_list_events_max_timestamp(self): + response = self.client.get(f"/events?max_timestamp={self.first_event_api.get('timestamp')}") + self.assertEqual(200, response.status_code) + self.assertEqual(1, response.json()["count"]) + self.assertEqual(1, len(response.json()["results"])) + + response = self.client.get(f"/events?max_timestamp={self.second_event_api.get('timestamp')}") + self.assertEqual(200, response.status_code) + self.assertEqual(2, response.json()["count"]) + self.assertEqual(2, len(response.json()["results"])) + + def test_list_events_min_and_max_timestamp(self): + response = self.client.get( + f"/events?min_timestamp={self.first_event_api.get('timestamp')}&max_timestamp={self.second_event_api.get('timestamp')}" + ) + self.assertEqual(200, response.status_code) + self.assertEqual(2, response.json()["count"]) + self.assertEqual(2, len(response.json()["results"])) + + def test_list_events_min_timestamp_greater_than_max_timestamp(self): + response = self.client.get( + f"/events?min_timestamp={self.second_event_api.get('timestamp')}&max_timestamp={self.first_event_api.get('timestamp')}" + ) + self.assertEqual(400, response.status_code) + self.assertEqual("min_timestamp cannot be greater than max_timestamp", response.json()["detail"]) + + def test_list_events_filter(self): + response = self.client.post( + "/events", + json={"filters": {"filters": [{"column": "data", "field": "test", "operator": "eq", "value": "test"}]}}, + ) + self.assertEqual(200, response.status_code) + self.assertEqual(1, response.json()["count"]) + self.assertEqual(1, len(response.json()["results"])) + self.assertEqual("test", response.json()["results"][0]["data"]["test"]) diff --git a/mula/tests/integration/test_event_store.py b/mula/tests/integration/test_event_store.py new file mode 100644 index 00000000000..7b18764b167 --- /dev/null +++ b/mula/tests/integration/test_event_store.py @@ -0,0 +1,346 @@ +import unittest +import uuid +from datetime import datetime, timezone +from types import SimpleNamespace +from unittest import mock + +from scheduler import config, models, storage +from scheduler.storage import filters + +from tests.factories import OrganisationFactory +from tests.utils import functions + + +class EventStoreTestCase(unittest.TestCase): + def setUp(self): + # Application Context + self.mock_ctx = mock.patch("scheduler.context.AppContext").start() + self.mock_ctx.config = config.settings.Settings() + + # Database + self.dbconn = storage.DBConn(str(self.mock_ctx.config.db_uri)) + models.Base.metadata.create_all(self.dbconn.engine) + self.mock_ctx.datastores = SimpleNamespace( + **{ + storage.TaskStore.name: storage.TaskStore(self.dbconn), + storage.PriorityQueueStore.name: storage.PriorityQueueStore(self.dbconn), + storage.EventStore.name: storage.EventStore(self.dbconn), + } + ) + + models.TaskDB.set_event_store(self.mock_ctx.datastores.event_store) + + # Organisation + self.organisation = OrganisationFactory() + + def tearDown(self): + models.Base.metadata.drop_all(self.dbconn.engine) + self.dbconn.engine.dispose() + + def test_record_event_trigger(self): + # Arrange + p_item = functions.create_p_item(self.organisation.id, 1) + task = functions.create_task(p_item) + task_db = self.mock_ctx.datastores.task_store.create_task(task) + + # Act + task_db.status = models.TaskStatus.DISPATCHED + self.mock_ctx.datastores.task_store.update_task(task_db) + + task_db.status = models.TaskStatus.COMPLETED + self.mock_ctx.datastores.task_store.update_task(task_db) + + # Assert + events = self.mock_ctx.datastores.event_store.get_events() + self.assertGreater(len(events), 0) + + def test_get_task_duration(self): + # Arrange + p_item = functions.create_p_item(self.organisation.id, 1) + task = functions.create_task(p_item) + task_db = self.mock_ctx.datastores.task_store.create_task(task) + + # Act + task_db.status = models.TaskStatus.DISPATCHED + self.mock_ctx.datastores.task_store.update_task(task_db) + + task_db.status = models.TaskStatus.COMPLETED + self.mock_ctx.datastores.task_store.update_task(task_db) + + # Assert + duration = self.mock_ctx.datastores.event_store.get_task_duration(task.id) + self.assertGreater(duration, 0) + + def test_get_task_queued(self): + # Arrange + p_item = functions.create_p_item(self.organisation.id, 1) + task = functions.create_task(p_item) + task_db = self.mock_ctx.datastores.task_store.create_task(task) + + # Act + task_db.status = models.TaskStatus.DISPATCHED + self.mock_ctx.datastores.task_store.update_task(task_db) + + task_db.status = models.TaskStatus.COMPLETED + self.mock_ctx.datastores.task_store.update_task(task_db) + + # Assert + queued = self.mock_ctx.datastores.event_store.get_task_queued(task.id) + self.assertGreater(queued, 0) + + def test_get_task_runtime(self): + # Arrange + p_item = functions.create_p_item(self.organisation.id, 1) + task = functions.create_task(p_item) + task_db = self.mock_ctx.datastores.task_store.create_task(task) + + # Act + task_db.status = models.TaskStatus.DISPATCHED + self.mock_ctx.datastores.task_store.update_task(task_db) + + task_db.status = models.TaskStatus.COMPLETED + self.mock_ctx.datastores.task_store.update_task(task_db) + + # Assert + runtime = self.mock_ctx.datastores.event_store.get_task_runtime(task.id) + self.assertGreater(runtime, 0) + + def test_create_event(self): + # Arrange + event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + # Act + self.mock_ctx.datastores.event_store.create_event(event) + + # Assert + events, count = self.mock_ctx.datastores.event_store.get_events() + self.assertEqual(count, 1) + self.assertEqual(events[0].task_id, event.task_id) + + def test_get_events(self): + # Arrange + first_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + second_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + # Act + self.mock_ctx.datastores.event_store.create_event(first_event) + self.mock_ctx.datastores.event_store.create_event(second_event) + + # Assert + events, count = self.mock_ctx.datastores.event_store.get_events() + self.assertEqual(count, 2) + self.assertEqual(events[0].task_id, second_event.task_id) + self.assertEqual(events[1].task_id, first_event.task_id) + + def test_get_events_task_id(self): + # Arange + first_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + second_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + self.mock_ctx.datastores.event_store.create_event(first_event) + self.mock_ctx.datastores.event_store.create_event(second_event) + + # Act + events, count = self.mock_ctx.datastores.event_store.get_events(task_id=first_event.task_id) + + # Assert + self.assertEqual(count, 1) + self.assertEqual(events[0].task_id, first_event.task_id) + + def test_get_events_type(self): + # Arange + first_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + second_event = models.Event( + task_id=uuid.uuid4(), + type="events.app", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + self.mock_ctx.datastores.event_store.create_event(first_event) + self.mock_ctx.datastores.event_store.create_event(second_event) + + # Act + events, count = self.mock_ctx.datastores.event_store.get_events(type="events.db") + + # Assert + self.assertEqual(count, 1) + self.assertEqual(events[0].type, "events.db") + + def test_get_events_context(self): + # Arange + first_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + second_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task2", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + self.mock_ctx.datastores.event_store.create_event(first_event) + self.mock_ctx.datastores.event_store.create_event(second_event) + + # Act + events, count = self.mock_ctx.datastores.event_store.get_events(context="task") + + # Assert + self.assertEqual(count, 1) + self.assertEqual(events[0].context, "task") + + def test_get_events_min_timestamp(self): + # Arange + first_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + second_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + self.mock_ctx.datastores.event_store.create_event(first_event) + self.mock_ctx.datastores.event_store.create_event(second_event) + + # Act + events, count = self.mock_ctx.datastores.event_store.get_events(min_timestamp=first_event.timestamp) + + # Assert + self.assertEqual(count, 2) + self.assertEqual(events[0].task_id, second_event.task_id) + self.assertEqual(events[1].task_id, first_event.task_id) + + def test_get_events_max_timestamp(self): + # Arange + first_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + second_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + self.mock_ctx.datastores.event_store.create_event(first_event) + self.mock_ctx.datastores.event_store.create_event(second_event) + + # Act + events, count = self.mock_ctx.datastores.event_store.get_events(max_timestamp=first_event.timestamp) + + # Assert + self.assertEqual(count, 1) + self.assertEqual(events[0].task_id, first_event.task_id) + + def test_get_events_filter(self): + # Arrange + first_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + second_event = models.Event( + task_id=uuid.uuid4(), + type="events.db", + context="task", + event="insert", + timestamp=datetime.now(timezone.utc), + data={"test": "test"}, + ) + + # Act + first_event_db = self.mock_ctx.datastores.event_store.create_event(first_event) + self.mock_ctx.datastores.event_store.create_event(second_event) + + # Assert + f_req = filters.FilterRequest( + filters=[ + filters.Filter( + column="task_id", + field=None, + operator="eq", + value=first_event_db.task_id.hex, + ) + ], + ) + + events, count = self.mock_ctx.datastores.event_store.get_events(filters=f_req) + self.assertEqual(count, 1) + self.assertEqual(events[0].task_id, first_event.task_id) diff --git a/mula/tests/scripts/.gitignore b/mula/tests/scripts/.gitignore new file mode 100644 index 00000000000..203300891ee --- /dev/null +++ b/mula/tests/scripts/.gitignore @@ -0,0 +1 @@ +data.csv diff --git a/mula/tests/scripts/load.py b/mula/tests/scripts/load.py index ef08a91114c..dc03b2bc8de 100644 --- a/mula/tests/scripts/load.py +++ b/mula/tests/scripts/load.py @@ -14,7 +14,7 @@ def run(): # Create organisations orgs: List[Dict[str, Any]] = [] - for n in range(1, 10): + for n in range(0, 1): org = { "id": f"org-{n}", "name": f"Organisation {n}", @@ -62,10 +62,16 @@ def run(): print("Enabled boefje ", boefje_id) + count = 0 + limit = 10 + declarations: List[Dict[str, Any]] = [] with Path("data.csv").open(newline="") as csv_file: csv_reader = csv.DictReader(csv_file, delimiter=",", quotechar='"') for row in csv_reader: + if count >= limit: + break + name = row["name"] declaration = { "ooi": { @@ -85,6 +91,7 @@ def run(): "task_id": str(uuid.uuid4()), } declarations.append(declaration) + count += 1 for org in orgs: for declaration in declarations: