From 968df3b8e32b9053843bdd1e526d8a77f11035f1 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 9 Dec 2024 16:27:11 -0500 Subject: [PATCH] Switching background tasks for threads on client and server updates (#122) While I was testing the app this week, I realized FastAPI's Background Tasks, which we use for client and server metric updates, are not really suitable for long-running tasks as they are executed serially instead of in parallel. Here I am switching from that to plain Python threads, and also making a few other changes: Adding a database connection in the listeners as FastAPI's db instance cannot be shared between threads Changing the listener functions to be async Converting the sync database functions to async in the Job Getting rid of the sync database connection, we don't need it anymore For metrics reporter, avoid saving the metrics if they're exactly the same as what's already stored. This will eliminate an issue with updates that are too frequent into Redis with no new information, which makes the app do a lot of unnecessary work. 
Fixing the metrics in the UI progress section for the updated metric names --- README.md | 2 +- florist/api/db/config.py | 4 + florist/api/db/entities.py | 39 ++-- florist/api/monitoring/metrics.py | 11 + florist/api/routes/server/training.py | 92 ++++---- florist/api/server.py | 10 +- florist/app/jobs/details/page.tsx | 62 +++-- .../tests/integration/api/db/test_entities.py | 37 +-- florist/tests/integration/api/utils.py | 2 - .../tests/unit/api/monitoring/test_metrics.py | 25 ++ .../unit/api/routes/server/test_training.py | 220 ++++++++++-------- .../tests/unit/app/jobs/details/page.test.tsx | 58 +++-- package.json | 4 +- 13 files changed, 320 insertions(+), 246 deletions(-) create mode 100644 florist/api/db/config.py diff --git a/README.md b/README.md index 9c588be4..ede6719a 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ docker start redis-florist-client To start the client back-end service: ```shell -uvicorn florist.api.client:app --reload --port 8001 +python -m uvicorn florist.api.client:app --reload --port 8001 ``` The service will be available at `http://localhost:8001`. 
diff --git a/florist/api/db/config.py b/florist/api/db/config.py new file mode 100644 index 00000000..76b674c9 --- /dev/null +++ b/florist/api/db/config.py @@ -0,0 +1,4 @@ +"""Database configuration parameters.""" + +MONGODB_URI = "mongodb://localhost:27017/" +DATABASE_NAME = "florist-server" diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 6bcc9519..fba0bb94 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -8,7 +8,6 @@ from fastapi.encoders import jsonable_encoder from motor.motor_asyncio import AsyncIOMotorDatabase from pydantic import BaseModel, Field -from pymongo.database import Database from pymongo.results import UpdateResult from florist.api.clients.common import Client @@ -164,47 +163,38 @@ async def set_status(self, status: JobStatus, database: AsyncIOMotorDatabase[Any update_result = await job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}}) assert_updated_successfully(update_result) - def set_status_sync(self, status: JobStatus, database: Database[Dict[str, Any]]) -> None: - """ - Sync function to save the status in the database under the current job's id. - - :param status: (JobStatus) the status to be saved in the database. - :param database: (pymongo.database.Database) The database where the job collection is stored. - """ - job_collection = database[JOB_COLLECTION_NAME] - self.status = status - update_result = job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}}) - assert_updated_successfully(update_result) - - def set_server_metrics( + async def set_server_metrics( self, server_metrics: Dict[str, Any], - database: Database[Dict[str, Any]], + database: AsyncIOMotorDatabase[Any], ) -> None: """ - Sync function to save the server's metrics in the database under the current job's id. + Save the server's metrics in the database under the current job's id. :param server_metrics: (Dict[str, Any]) the server metrics to be saved. 
- :param database: (pymongo.database.Database) The database where the job collection is stored. + :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored. """ job_collection = database[JOB_COLLECTION_NAME] self.server_metrics = json.dumps(server_metrics) - update_result = job_collection.update_one({"_id": self.id}, {"$set": {"server_metrics": self.server_metrics}}) + update_result = await job_collection.update_one( + {"_id": self.id}, + {"$set": {"server_metrics": self.server_metrics}}, + ) assert_updated_successfully(update_result) - def set_client_metrics( + async def set_client_metrics( self, client_uuid: str, client_metrics: Dict[str, Any], - database: Database[Dict[str, Any]], + database: AsyncIOMotorDatabase[Any], ) -> None: """ - Sync function to save a clients' metrics in the database under the current job's id. + Save a client's metrics in the database under the current job's id. :param client_uuid: (str) the client's uuid whose produced the metrics. :param client_metrics: (Dict[str, Any]) the client's metrics to be saved. - :param database: (pymongo.database.Database) The database where the job collection is stored. + :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored. 
""" assert ( self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info] @@ -215,8 +205,9 @@ def set_client_metrics( for i in range(len(self.clients_info)): if client_uuid == self.clients_info[i].uuid: self.clients_info[i].metrics = json.dumps(client_metrics) - update_result = job_collection.update_one( - {"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}} + update_result = await job_collection.update_one( + {"_id": self.id}, + {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}}, ) assert_updated_successfully(update_result) diff --git a/florist/api/monitoring/metrics.py b/florist/api/monitoring/metrics.py index f1d4502d..9c3f83bb 100644 --- a/florist/api/monitoring/metrics.py +++ b/florist/api/monitoring/metrics.py @@ -116,6 +116,17 @@ def dump(self) -> None: assert self.run_id is not None, "Run ID is None, ensure reporter is initialized prior to dumping metrics." encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder) + + previous_metrics_blob = self.redis_connection.get(self.run_id) + if previous_metrics_blob is not None and isinstance(previous_metrics_blob, bytes): + previous_metrics = json.loads(previous_metrics_blob) + current_metrics = json.loads(encoded_metrics) + if current_metrics == previous_metrics: + log( + DEBUG, f"Skipping dumping: previous metrics are the same as current metrics at key '{self.run_id}'" + ) + return + log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}") self.redis_connection.set(self.run_id, encoded_metrics) log(DEBUG, f"Notifying redis channel '{self.run_id}'") diff --git a/florist/api/routes/server/training.py b/florist/api/routes/server/training.py index f00778f3..bcc2f965 100644 --- a/florist/api/routes/server/training.py +++ b/florist/api/routes/server/training.py @@ -1,14 +1,17 @@ """FastAPI routes for training.""" +import asyncio import logging from json import JSONDecodeError -from typing import Any, Dict, List 
+from threading import Thread +from typing import Any, List import requests -from fastapi import APIRouter, BackgroundTasks, Request +from fastapi import APIRouter, Request from fastapi.responses import JSONResponse -from pymongo.database import Database +from motor.motor_asyncio import AsyncIOMotorClient +from florist.api.db.config import DATABASE_NAME, MONGODB_URI from florist.api.db.entities import ClientInfo, Job, JobStatus from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric from florist.api.servers.common import Model @@ -25,15 +28,13 @@ @router.post("/start") -async def start(job_id: str, request: Request, background_tasks: BackgroundTasks) -> JSONResponse: +async def start(job_id: str, request: Request) -> JSONResponse: """ Start FL training for a job id by starting a FL server and its clients. :param job_id: (str) The id of the Job record in the DB which contains the information necessary to start training. :param request: (fastapi.Request) the FastAPI request object. - :param background_tasks: (BackgroundTasks) A BackgroundTasks instance to launch the training listener, - which will update the progress of the training job. :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and the clients in the format below. The UUIDs can be used to pull metrics from Redis. 
{ @@ -105,11 +106,13 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks await job.set_uuids(server_uuid, client_uuids, request.app.database) - # Start the server training listener as a background task to update + # Start the server training listener and client training listeners as threads to update # the job's metrics and status once the training is done - background_tasks.add_task(server_training_listener, job, request.app.synchronous_database) + server_listener_thread = Thread(target=asyncio.run, args=(server_training_listener(job),)) + server_listener_thread.start() for client_info in job.clients_info: - background_tasks.add_task(client_training_listener, job, client_info, request.app.synchronous_database) + client_listener_thread = Thread(target=asyncio.run, args=(client_training_listener(job, client_info),)) + client_listener_thread.start() # Return the UUIDs return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids}) @@ -126,7 +129,7 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks return JSONResponse({"error": str(ex)}, status_code=500) -def client_training_listener(job: Job, client_info: ClientInfo, database: Database[Dict[str, Any]]) -> None: +async def client_training_listener(job: Job, client_info: ClientInfo) -> None: """ Listen to the Redis' channel that reports updates on the training process of a FL client. @@ -134,22 +137,23 @@ def client_training_listener(job: Job, client_info: ClientInfo, database: Databa :param job: (Job) The job that has this client's metrics. :param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to. - :param database: (pymongo.database.Database) An instance of the database to save the information - into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async - because of limitations with FastAPI's BackgroundTasks. 
""" LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}") assert client_info.uuid is not None, "client_info.uuid is None." + db_client: AsyncIOMotorClient[Any] = AsyncIOMotorClient(MONGODB_URI) + database = db_client[DATABASE_NAME] + # check if training has already finished before start listening client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port) - LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}") + LOGGER.debug(f"Client listener: Current metrics for client {client_info.uuid}: {client_metrics}") if client_metrics is not None: - LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}") - job.set_client_metrics(client_info.uuid, client_metrics, database) - LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.") + LOGGER.info(f"Client listener: Updating client metrics for client {client_info.uuid} on job {job.id}") + await job.set_client_metrics(client_info.uuid, client_metrics, database) + LOGGER.info(f"Client listener: Client metrics for client {client_info.uuid} on {job.id} have been updated.") if "shutdown" in client_metrics: + db_client.close() return subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port) @@ -158,16 +162,22 @@ def client_training_listener(job: Job, client_info: ClientInfo, database: Databa if message["type"] == "message": # The contents of the message do not matter, we just use it to get notified client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port) - LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}") + LOGGER.debug(f"Client listener: Current metrics for client {client_info.uuid}: {client_metrics}") + if client_metrics is not None: - LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} 
on job {job.id}") - job.set_client_metrics(client_info.uuid, client_metrics, database) - LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.") + LOGGER.info(f"Client listener: Updating client metrics for client {client_info.uuid} on job {job.id}") + await job.set_client_metrics(client_info.uuid, client_metrics, database) + LOGGER.info( + f"Client listener: Client metrics for client {client_info.uuid} on {job.id} have been updated." + ) if "shutdown" in client_metrics: + db_client.close() return + db_client.close() + -def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> None: +async def server_training_listener(job: Job) -> None: """ Listen to the Redis' channel that reports updates on the training process of a FL server. @@ -176,9 +186,6 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No to the job in the database. :param job: (Job) The job with the server_uuid to listen to. - :param database: (pymongo.database.Database) An instance of the database to save the information - into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async - because of limitations with FastAPI's BackgroundTasks. """ LOGGER.info(f"Starting listener for server messages from job {job.id} at channel {job.server_uuid}") @@ -186,17 +193,21 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No assert job.redis_host is not None, "job.redis_host is None." assert job.redis_port is not None, "job.redis_port is None." 
+ db_client: AsyncIOMotorClient[Any] = AsyncIOMotorClient(MONGODB_URI) + database = db_client[DATABASE_NAME] + # check if training has already finished before start listening server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port) - LOGGER.debug(f"Listener: Current metrics for job {job.id}: {server_metrics}") + LOGGER.debug(f"Server listener: Current metrics for job {job.id}: {server_metrics}") if server_metrics is not None: - LOGGER.info(f"Listener: Updating server metrics for job {job.id}") - job.set_server_metrics(server_metrics, database) - LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.") + LOGGER.info(f"Server listener: Updating server metrics for job {job.id}") + await job.set_server_metrics(server_metrics, database) + LOGGER.info(f"Server listener: Server metrics for {job.id} have been updated.") if "fit_end" in server_metrics: - LOGGER.info(f"Listener: Training finished for job {job.id}") - job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database) - LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.") + LOGGER.info(f"Server listener: Training finished for job {job.id}") + await job.set_status(JobStatus.FINISHED_SUCCESSFULLY, database) + LOGGER.info(f"Server listener: Job {job.id} status have been set to {job.status.value}.") + db_client.close() return subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port) @@ -205,14 +216,17 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No if message["type"] == "message": # The contents of the message do not matter, we just use it to get notified server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port) - LOGGER.debug(f"Listener: Message received for job {job.id}. Metrics: {server_metrics}") + LOGGER.debug(f"Server listener: Message received for job {job.id}. 
Metrics: {server_metrics}") if server_metrics is not None: - LOGGER.info(f"Listener: Updating server metrics for job {job.id}") - job.set_server_metrics(server_metrics, database) - LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.") + LOGGER.info(f"Server listener: Updating server metrics for job {job.id}") + await job.set_server_metrics(server_metrics, database) + LOGGER.info(f"Server listener: Server metrics for {job.id} have been updated.") if "fit_end" in server_metrics: - LOGGER.info(f"Listener: Training finished for job {job.id}") - job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database) - LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.") + LOGGER.info(f"Server listener: Training finished for job {job.id}") + await job.set_status(JobStatus.FINISHED_SUCCESSFULLY, database) + LOGGER.info(f"Server listener: Job {job.id} status have been set to {job.status.value}.") + db_client.close() return + + db_client.close() diff --git a/florist/api/server.py b/florist/api/server.py index b38dd416..423d99ef 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -6,34 +6,26 @@ from fastapi import FastAPI from fastapi.responses import JSONResponse from motor.motor_asyncio import AsyncIOMotorClient -from pymongo import MongoClient from florist.api.clients.common import Client +from florist.api.db.config import DATABASE_NAME, MONGODB_URI from florist.api.routes.server.job import router as job_router from florist.api.routes.server.status import router as status_router from florist.api.routes.server.training import router as training_router from florist.api.servers.common import Model -MONGODB_URI = "mongodb://localhost:27017/" -DATABASE_NAME = "florist-server" - - @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[Any, Any]: """Set up function for app startup and shutdown.""" # Set up mongodb app.db_client = AsyncIOMotorClient(MONGODB_URI) # type: ignore[attr-defined] app.database = 
app.db_client[DATABASE_NAME] # type: ignore[attr-defined] - # Setting up a synchronous database connection for background tasks - app.synchronous_db_client = MongoClient(MONGODB_URI) # type: ignore[attr-defined] - app.synchronous_database = app.synchronous_db_client[DATABASE_NAME] # type: ignore[attr-defined] yield # Shut down mongodb app.db_client.close() # type: ignore[attr-defined] - app.synchronous_db_client.close() # type: ignore[attr-defined] app = FastAPI(lifespan=lifespan) diff --git a/florist/app/jobs/details/page.tsx b/florist/app/jobs/details/page.tsx index 63d36f81..3b82fa26 100644 --- a/florist/app/jobs/details/page.tsx +++ b/florist/app/jobs/details/page.tsx @@ -55,11 +55,9 @@ export function JobDetailsBody(): ReactElement { } let totalEpochs = null; - let localEpochs = null; if (job.server_config) { const serverConfigJson = JSON.parse(job.server_config); totalEpochs = serverConfigJson.n_server_rounds; - localEpochs = serverConfigJson.local_epochs; } return ( @@ -125,7 +123,7 @@ export function JobDetailsBody(): ReactElement { Component={JobDetailsClientsInfoTable} title="Clients Configuration" data={job.clients_info} - properties={{ localEpochs }} + properties={{ totalEpochs }} /> ); @@ -194,16 +192,16 @@ export function JobProgressBar({ let endRoundKey; if (metricsJson.host_type === "server") { - endRoundKey = "fit_end"; + endRoundKey = "eval_round_end"; } if (metricsJson.host_type === "client") { - endRoundKey = "shutdown"; + endRoundKey = "round_end"; } let progressPercent = 0; if ("rounds" in metricsJson && Object.keys(metricsJson.rounds).length > 0) { const lastRound = Math.max(...Object.keys(metricsJson.rounds)); - const lastCompletedRound = endRoundKey in metricsJson ? lastRound : lastRound - 1; + const lastCompletedRound = endRoundKey in metricsJson.rounds[lastRound] ? lastRound : lastRound - 1; progressPercent = (lastCompletedRound * 100) / totalEpochs; } const progressWidth = progressPercent === 0 ? 
"100%" : `${progressPercent}%`; @@ -403,17 +401,39 @@ export function JobProgressRoundDetails({ roundMetrics, index }: { roundMetrics: return null; } - let fitElapsedTime = ""; + let fitStart = null; + let fitEnd = null; if ("fit_start" in roundMetrics) { - const startDate = Date.parse(roundMetrics.fit_start); - const endDate = "fit_end" in roundMetrics ? Date.parse(roundMetrics.fit_end) : Date.now(); + fitStart = roundMetrics.fit_start; + fitEnd = roundMetrics.fit_end; + } + if ("fit_round_start" in roundMetrics) { + fitStart = roundMetrics.fit_round_start; + fitEnd = roundMetrics.fit_round_end; + } + + let fitElapsedTime = ""; + if (fitStart) { + const startDate = Date.parse(fitStart); + const endDate = fitEnd ? Date.parse(fitEnd) : Date.now(); fitElapsedTime = getTimeString(endDate - startDate); } + let evalStart = null; + let evalEnd = null; + if ("eval_start" in roundMetrics) { + evalStart = roundMetrics.eval_start; + evalEnd = roundMetrics.eval_end; + } + if ("eval_round_start" in roundMetrics) { + evalStart = roundMetrics.eval_round_start; + evalEnd = roundMetrics.eval_round_end; + } + let evaluateElapsedTime = ""; - if ("evaluate_start" in roundMetrics) { - const startDate = Date.parse(roundMetrics.evaluate_start); - const endDate = "evaluate_end" in roundMetrics ? Date.parse(roundMetrics.evaluate_end) : Date.now(); + if (evalStart !== null) { + const startDate = Date.parse(evalStart); + const endDate = evalEnd ? Date.parse(evalEnd) : Date.now(); evaluateElapsedTime = getTimeString(endDate - startDate); } @@ -429,13 +449,13 @@ export function JobProgressRoundDetails({ roundMetrics, index }: { roundMetrics:
Fit start time:
-
{"fit_start" in roundMetrics ? roundMetrics.fit_start : null}
+
{fitStart}
Fit end time:
-
{"fit_end" in roundMetrics ? roundMetrics.fit_end : null}
+
{fitEnd}
@@ -447,13 +467,13 @@ export function JobProgressRoundDetails({ roundMetrics, index }: { roundMetrics:
Evaluate start time:
-
{"evaluate_start" in roundMetrics ? roundMetrics.evaluate_start : null}
+
{evalStart}
Evaluate end time:
-
{"evaluate_end" in roundMetrics ? roundMetrics.evaluate_end : null}
+
{evalEnd}
{Object.keys(roundMetrics).map((name, i) => ( @@ -467,8 +487,12 @@ export function JobProgressProperty({ name, value }: { name: string; value: stri [ "fit_start", "fit_end", - "evaluate_start", - "evaluate_end", + "fit_round_start", + "fit_round_end", + "eval_start", + "eval_end", + "eval_round_start", + "eval_round_end", "rounds", "host_type", "initialized", @@ -643,7 +667,7 @@ export function JobDetailsClientsInfoTable({ diff --git a/florist/tests/integration/api/db/test_entities.py b/florist/tests/integration/api/db/test_entities.py index 6b8234e6..6d1ddcb5 100644 --- a/florist/tests/integration/api/db/test_entities.py +++ b/florist/tests/integration/api/db/test_entities.py @@ -144,31 +144,6 @@ async def test_set_status_fail_update_result(mock_request) -> None: await test_job.set_status(JobStatus.IN_PROGRESS, mock_request.app.database) -async def test_set_status_sync_success(mock_request) -> None: - test_job = get_test_job() - result_id = await test_job.create(mock_request.app.database) - test_job.id = result_id - test_job.clients_info[0].id = ANY - test_job.clients_info[1].id = ANY - - test_status = JobStatus.IN_PROGRESS - - test_job.set_status_sync(test_status, mock_request.app.synchronous_database) - - result_job = await Job.find_by_id(result_id, mock_request.app.database) - test_job.status = test_status - assert result_job == test_job - - -async def test_set_status_sync_fail_update_result(mock_request) -> None: - test_job = get_test_job() - test_job.id = str(test_job.id) - - error_msg = "UpdateResult's 'n' is not 1" - with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_status_sync(JobStatus.IN_PROGRESS, mock_request.app.synchronous_database) - - async def test_set_server_metrics_success(mock_request) -> None: test_job = get_test_job() result_id = await test_job.create(mock_request.app.database) @@ -178,7 +153,7 @@ async def test_set_server_metrics_success(mock_request) -> None: test_server_metrics = {"test-server": 123} - 
test_job.set_server_metrics(test_server_metrics, mock_request.app.synchronous_database) + await test_job.set_server_metrics(test_server_metrics, mock_request.app.database) result_job = await Job.find_by_id(result_id, mock_request.app.database) test_job.server_metrics = json.dumps(test_server_metrics) @@ -193,7 +168,7 @@ async def test_set_server_metrics_fail_update_result(mock_request) -> None: error_msg = "UpdateResult's 'n' is not 1" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_server_metrics(test_server_metrics, mock_request.app.synchronous_database) + await test_job.set_server_metrics(test_server_metrics, mock_request.app.database) async def test_set_client_metrics_success(mock_request) -> None: @@ -205,7 +180,7 @@ async def test_set_client_metrics_success(mock_request) -> None: test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}] - test_job.set_client_metrics(test_job.clients_info[1].uuid, test_client_metrics, mock_request.app.synchronous_database) + await test_job.set_client_metrics(test_job.clients_info[1].uuid, test_client_metrics, mock_request.app.database) result_job = await Job.find_by_id(result_id, mock_request.app.database) test_job.clients_info[1].metrics = json.dumps(test_client_metrics) @@ -222,7 +197,7 @@ async def test_set_metrics_fail_clients_info_is_none(mock_request) -> None: error_msg = f"client uuid {test_wrong_client_uuid} is not in clients_info (['{test_job.clients_info[0].uuid}', '{test_job.clients_info[1].uuid}'])" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_client_metrics(test_wrong_client_uuid, test_client_metrics, mock_request.app.synchronous_database) + await test_job.set_client_metrics(test_wrong_client_uuid, test_client_metrics, mock_request.app.database) async def test_set_client_metrics_fail_update_result(mock_request) -> None: @@ -233,10 +208,10 @@ async def test_set_client_metrics_fail_update_result(mock_request) -> None: error_msg = "UpdateResult's 'n' is 
not 1" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_client_metrics( + await test_job.set_client_metrics( test_job.clients_info[0].uuid, test_client_metrics, - mock_request.app.synchronous_database, + mock_request.app.database, ) diff --git a/florist/tests/integration/api/utils.py b/florist/tests/integration/api/utils.py index e0b6f652..2ac80b2e 100644 --- a/florist/tests/integration/api/utils.py +++ b/florist/tests/integration/api/utils.py @@ -32,8 +32,6 @@ class MockApp: def __init__(self, database_name: str): self.db_client = AsyncIOMotorClient(MONGODB_URI) self.database = self.db_client[database_name] - self.synchronous_db_client = MongoClient(MONGODB_URI) - self.synchronous_database = self.synchronous_db_client[database_name] class MockRequest(Request): diff --git a/florist/tests/unit/api/monitoring/test_metrics.py b/florist/tests/unit/api/monitoring/test_metrics.py index 7163438a..64a4636d 100644 --- a/florist/tests/unit/api/monitoring/test_metrics.py +++ b/florist/tests/unit/api/monitoring/test_metrics.py @@ -79,6 +79,31 @@ def test_dump_without_existing_connection(mock_redis: Mock) -> None: assert mock_redis_connection.set.call_args_list[2][0][0] == test_run_id assert mock_redis_connection.set.call_args_list[2][0][1] == json.dumps(expected_data, cls=DateTimeEncoder) +@freeze_time("2012-12-11 10:09:08") +@patch("florist.api.monitoring.metrics.redis.Redis") +def test_dump_does_not_save_duplicate(mock_redis: Mock) -> None: + mock_redis_connection = Mock() + mock_redis.return_value = mock_redis_connection + + test_host = "test host" + test_port = "test port" + test_run_id = "123" + test_data = {"test": "data", "date": datetime.datetime.now()} + test_round = 2 + + redis_metric_reporter = RedisMetricsReporter(test_host, test_port, test_run_id) + redis_metric_reporter.report(test_data, test_round) + + saved_data = json.dumps(redis_metric_reporter.metrics, cls=DateTimeEncoder) + mock_redis_connection.get.return_value = 
saved_data.encode("utf-8") + + redis_metric_reporter.dump() + + # assert this set has been called by the report only once and not called again by dump + assert mock_redis_connection.set.call_count == 1 + assert mock_redis_connection.set.call_args_list[0][0][0] == test_run_id + assert mock_redis_connection.set.call_args_list[0][0][1] == saved_data + @freeze_time("2012-12-11 10:09:08") @patch("florist.api.monitoring.metrics.redis.Redis") diff --git a/florist/tests/unit/api/routes/server/test_training.py b/florist/tests/unit/api/routes/server/test_training.py index 17e5af7b..2a900f29 100644 --- a/florist/tests/unit/api/routes/server/test_training.py +++ b/florist/tests/unit/api/routes/server/test_training.py @@ -2,18 +2,36 @@ import json from pytest import raises from typing import Dict, Any, Tuple -from unittest.mock import Mock, patch, ANY, call +from unittest.mock import Mock, AsyncMock, patch, ANY, call +from florist.api.db.config import DATABASE_NAME from florist.api.db.entities import Job, JobStatus, JOB_COLLECTION_NAME from florist.api.models.mnist import MnistNet from florist.api.routes.server.training import ( client_training_listener, start, - server_training_listener, - CHECK_CLIENT_STATUS_API, + server_training_listener ) +async def test_start_fail_unsupported_server_model() -> None: + # Arrange + test_job_id = "test-job-id" + _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks() + test_job["model"] = "WRONG MODEL" + + # Act + response = await start(test_job_id, mock_fastapi_request) + + # Assert + assert response.status_code == 500 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": ANY} + assert "value is not a valid enumeration member" in json_body["error"] + + +@patch("florist.api.routes.server.training.client_training_listener") +@patch("florist.api.routes.server.training.server_training_listener") @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") 
@patch("florist.api.routes.server.training.requests") @@ -25,6 +43,8 @@ async def test_start_success( mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock, + mock_server_training_listener: Mock, + mock_client_training_listener: Mock, ) -> None: # Arrange test_job_id = "test-job-id" @@ -44,10 +64,11 @@ async def test_start_success( mock_response.json.side_effect = [{"uuid": test_client_1_uuid}, {"uuid": test_client_2_uuid}] mock_requests.get.return_value = mock_response - mock_background_tasks = Mock() + mock_client_training_listener.return_value = AsyncMock() + mock_server_training_listener.return_value = AsyncMock() # Act - response = await start(test_job_id, mock_fastapi_request, mock_background_tasks) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 200 @@ -101,41 +122,12 @@ async def test_start_success( expected_job.id = ANY expected_job.clients_info[0].id = ANY expected_job.clients_info[1].id = ANY - mock_background_tasks.add_task.assert_has_calls([ - call( - server_training_listener, - expected_job, - mock_fastapi_request.app.synchronous_database, - ), - call( - client_training_listener, - expected_job, - expected_job.clients_info[0], - mock_fastapi_request.app.synchronous_database, - ), - call( - client_training_listener, - expected_job, - expected_job.clients_info[1], - mock_fastapi_request.app.synchronous_database, - ), - ]) - - -async def test_start_fail_unsupported_server_model() -> None: - # Arrange - test_job_id = "test-job-id" - _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks() - test_job["model"] = "WRONG MODEL" - - # Act - response = await start(test_job_id, mock_fastapi_request, Mock()) - # Assert - assert response.status_code == 500 - json_body = json.loads(response.body.decode()) - assert json_body == {"error": ANY} - assert "value is not a valid enumeration member" in json_body["error"] + mock_server_training_listener.assert_called_with(expected_job) + 
mock_client_training_listener.assert_has_calls([ + call(expected_job, expected_job.clients_info[0]), + call(expected_job, expected_job.clients_info[1]), + ]) async def test_start_fail_unsupported_client() -> None: @@ -145,7 +137,7 @@ async def test_start_fail_unsupported_client() -> None: test_job["clients_info"][1]["client"] = "WRONG CLIENT" # Act - response = await start(test_job_id, mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 500 @@ -155,7 +147,7 @@ async def test_start_fail_unsupported_client() -> None: @patch("florist.api.db.entities.Job.set_status") -async def test_start_fail_missing_info(mock_set_status: Mock) -> None: +async def test_start_fail_missing_info(_: Mock) -> None: fields_to_be_removed = ["model", "server_config", "clients_info", "server_address", "redis_host", "redis_port"] for field_to_be_removed in fields_to_be_removed: @@ -165,7 +157,7 @@ async def test_start_fail_missing_info(mock_set_status: Mock) -> None: del test_job[field_to_be_removed] # Act - response = await start(test_job_id, mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 400 @@ -175,14 +167,14 @@ async def test_start_fail_missing_info(mock_set_status: Mock) -> None: @patch("florist.api.db.entities.Job.set_status") -async def test_start_fail_invalid_server_config(mock_set_status: Mock) -> None: +async def test_start_fail_invalid_server_config(_: Mock) -> None: # Arrange test_job_id = "test-job-id" _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks() test_job["server_config"] = "not json" # Act - response = await start(test_job_id, mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 400 @@ -199,7 +191,7 @@ async def test_start_fail_empty_clients_info(_: Mock) -> None: test_job["clients_info"] = [] # Act - response = await 
start(test_job_id, mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 400 @@ -219,7 +211,7 @@ async def test_start_launch_server_exception(mock_launch_local_server: Mock, _: mock_launch_local_server.side_effect = test_exception # Act - response = await start(test_job_id, mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 500 @@ -242,7 +234,7 @@ async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_loc mock_redis.Redis.side_effect = test_exception # Act - response = await start(test_job_id, mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 500 @@ -254,7 +246,7 @@ async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_loc @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep -async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_local_server: Mock, mock_set_status: Mock) -> None: +async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_local_server: Mock, __: Mock) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() @@ -267,7 +259,7 @@ async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_lau mock_redis.Redis.return_value = mock_redis_connection # Act - response = await start(test_job_id, mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 500 @@ -297,7 +289,7 @@ async def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_l mock_requests.get.return_value = mock_response # Act - response = await start(test_job_id, 
mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 500 @@ -327,7 +319,7 @@ async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, mock_requests.get.return_value = mock_response # Act - response = await start(test_job_id, mock_fastapi_request, Mock()) + response = await start(test_job_id, mock_fastapi_request) # Assert assert response.status_code == 500 @@ -335,9 +327,14 @@ async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, assert json_body == {"error": "Client response did not return a UUID. Response: {'foo': 'bar'}"} +@patch("florist.api.routes.server.training.AsyncIOMotorClient") @patch("florist.api.routes.server.training.get_from_redis") @patch("florist.api.routes.server.training.get_subscriber") -def test_server_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: +async def test_server_training_listener( + mock_get_subscriber: Mock, + mock_get_from_redis: Mock, + mock_motor_client: Mock, +) -> None: # Setup test_job = Job(**{ "server_uuid": "test-server-uuid", @@ -369,29 +366,32 @@ def test_server_training_listener(mock_get_subscriber: Mock, mock_get_from_redis {"type": "message"}, ] mock_get_subscriber.return_value = mock_subscriber - mock_database = Mock() + mock_db_client = make_mock_db_client() + mock_motor_client.return_value = mock_db_client - with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: - with patch.object(Job, "set_server_metrics", Mock()) as mock_set_server_metrics: + with patch.object(Job, "set_status", AsyncMock()) as mock_set_status: + with patch.object(Job, "set_server_metrics", AsyncMock()) as mock_set_server_metrics: # Act - server_training_listener(test_job, mock_database) + await server_training_listener(test_job) # Assert - mock_set_status_sync.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, mock_database) + 
mock_set_status.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, mock_db_client[DATABASE_NAME]) assert mock_set_server_metrics.call_count == 3 mock_set_server_metrics.assert_has_calls([ - call(test_server_metrics[0], mock_database), - call(test_server_metrics[1], mock_database), - call(test_server_metrics[2], mock_database), + call(test_server_metrics[0], mock_db_client[DATABASE_NAME]), + call(test_server_metrics[1], mock_db_client[DATABASE_NAME]), + call(test_server_metrics[2], mock_db_client[DATABASE_NAME]), ]) assert mock_get_from_redis.call_count == 3 mock_get_subscriber.assert_called_once_with(test_job.server_uuid, test_job.redis_host, test_job.redis_port) + mock_db_client.close.assert_called() +@patch("florist.api.routes.server.training.AsyncIOMotorClient") @patch("florist.api.routes.server.training.get_from_redis") -def test_server_training_listener_already_finished(mock_get_from_redis: Mock) -> None: +async def test_server_training_listener_already_finished(mock_get_from_redis: Mock, mock_motor_client: Mock) -> None: # Setup test_job = Job(**{ "server_uuid": "test-server-uuid", @@ -410,55 +410,62 @@ def test_server_training_listener_already_finished(mock_get_from_redis: Mock) -> }) test_server_final_metrics = {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"} mock_get_from_redis.side_effect = [test_server_final_metrics] - mock_database = Mock() + mock_db_client = make_mock_db_client() + mock_motor_client.return_value = mock_db_client - with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: - with patch.object(Job, "set_server_metrics", Mock()) as mock_set_server_metrics: + with patch.object(Job, "set_status", AsyncMock()) as mock_set_status: + with patch.object(Job, "set_server_metrics", AsyncMock()) as mock_set_server_metrics: # Act - server_training_listener(test_job, mock_database) + await server_training_listener(test_job) # Assert - 
mock_set_status_sync.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, mock_database) - mock_set_server_metrics.assert_called_once_with(test_server_final_metrics, mock_database) + mock_set_status.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, mock_db_client[DATABASE_NAME]) + mock_set_server_metrics.assert_called_once_with(test_server_final_metrics, mock_db_client[DATABASE_NAME]) assert mock_get_from_redis.call_count == 1 + mock_db_client.close.assert_called() -def test_server_training_listener_fail_no_server_uuid() -> None: +async def test_server_training_listener_fail_no_server_uuid() -> None: test_job = Job(**{ "redis_host": "test-redis-host", "redis_port": "test-redis-port", }) with raises(AssertionError, match="job.server_uuid is None."): - server_training_listener(test_job, Mock()) + await server_training_listener(test_job) -def test_server_training_listener_fail_no_redis_host() -> None: +async def test_server_training_listener_fail_no_redis_host() -> None: test_job = Job(**{ "server_uuid": "test-server-uuid", "redis_port": "test-redis-port", }) with raises(AssertionError, match="job.redis_host is None."): - server_training_listener(test_job, Mock()) + await server_training_listener(test_job) -def test_server_training_listener_fail_no_redis_port() -> None: +async def test_server_training_listener_fail_no_redis_port() -> None: test_job = Job(**{ "server_uuid": "test-server-uuid", "redis_host": "test-redis-host", }) with raises(AssertionError, match="job.redis_port is None."): - server_training_listener(test_job, Mock()) + await server_training_listener(test_job) +@patch("florist.api.routes.server.training.AsyncIOMotorClient") @patch("florist.api.routes.server.training.get_from_redis") @patch("florist.api.routes.server.training.get_subscriber") -def test_client_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: +async def test_client_training_listener( + mock_get_subscriber: Mock, + mock_get_from_redis: Mock, + 
mock_motor_client: Mock, +) -> None: # Setup - test_client_uuid = "test-client-uuid"; + test_client_uuid = "test-client-uuid" test_job = Job(**{ "clients_info": [ { @@ -486,20 +493,20 @@ def test_client_training_listener(mock_get_subscriber: Mock, mock_get_from_redis {"type": "message"}, ] mock_get_subscriber.return_value = mock_subscriber - mock_database = Mock() + mock_db_client = make_mock_db_client() + mock_motor_client.return_value = mock_db_client - with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: - with patch.object(Job, "set_client_metrics", Mock()) as mock_set_client_metrics: - # Act - client_training_listener(test_job, test_job.clients_info[0], mock_database) + with patch.object(Job, "set_client_metrics", AsyncMock()) as mock_set_client_metrics: + # Act + await client_training_listener(test_job, test_job.clients_info[0]) - # Assert - assert mock_set_client_metrics.call_count == 3 - mock_set_client_metrics.assert_has_calls([ - call(test_client_uuid, test_client_metrics[0], mock_database), - call(test_client_uuid, test_client_metrics[1], mock_database), - call(test_client_uuid, test_client_metrics[2], mock_database), - ]) + # Assert + assert mock_set_client_metrics.call_count == 3 + mock_set_client_metrics.assert_has_calls([ + call(test_client_uuid, test_client_metrics[0], mock_db_client[DATABASE_NAME]), + call(test_client_uuid, test_client_metrics[1], mock_db_client[DATABASE_NAME]), + call(test_client_uuid, test_client_metrics[2], mock_db_client[DATABASE_NAME]), + ]) assert mock_get_from_redis.call_count == 3 mock_get_subscriber.assert_called_once_with( @@ -507,12 +514,14 @@ def test_client_training_listener(mock_get_subscriber: Mock, mock_get_from_redis test_job.clients_info[0].redis_host, test_job.clients_info[0].redis_port, ) + mock_db_client.close.assert_called() +@patch("florist.api.routes.server.training.AsyncIOMotorClient") @patch("florist.api.routes.server.training.get_from_redis") -def 
test_client_training_listener_already_finished(mock_get_from_redis: Mock) -> None: +async def test_client_training_listener_already_finished(mock_get_from_redis: Mock, mock_motor_client: Mock) -> None: # Setup - test_client_uuid = "test-client-uuid"; + test_client_uuid = "test-client-uuid" test_job = Job(**{ "clients_info": [ { @@ -527,20 +536,25 @@ def test_client_training_listener_already_finished(mock_get_from_redis: Mock) -> }) test_client_final_metrics = {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"} mock_get_from_redis.side_effect = [test_client_final_metrics] - mock_database = Mock() + mock_db_client = make_mock_db_client() + mock_motor_client.return_value = mock_db_client - with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: - with patch.object(Job, "set_client_metrics", Mock()) as mock_set_client_metrics: - # Act - client_training_listener(test_job, test_job.clients_info[0], mock_database) + with patch.object(Job, "set_client_metrics", AsyncMock()) as mock_set_client_metrics: + # Act + await client_training_listener(test_job, test_job.clients_info[0]) - # Assert - mock_set_client_metrics.assert_called_once_with(test_client_uuid, test_client_final_metrics, mock_database) + # Assert + mock_set_client_metrics.assert_called_once_with( + test_client_uuid, + test_client_final_metrics, + mock_db_client[DATABASE_NAME], + ) assert mock_get_from_redis.call_count == 1 + mock_db_client.close.assert_called() -def test_client_training_listener_fail_no_uuid() -> None: +async def test_client_training_listener_fail_no_uuid() -> None: test_job = Job(**{ "clients_info": [ { @@ -554,7 +568,7 @@ def test_client_training_listener_fail_no_uuid() -> None: }) with raises(AssertionError, match="client_info.uuid is None."): - client_training_listener(test_job, test_job.clients_info[0], Mock()) + await client_training_listener(test_job, test_job.clients_info[0]) def _setup_test_job_and_mocks() -> Tuple[Dict[str, Any], 
Dict[str, Any], Mock, Mock]: @@ -601,6 +615,14 @@ def _setup_test_job_and_mocks() -> Tuple[Dict[str, Any], Dict[str, Any], Mock, M mock_job_collection.find_one.return_value = mock_find_one mock_fastapi_request = Mock() mock_fastapi_request.app.database = {JOB_COLLECTION_NAME: mock_job_collection} - mock_fastapi_request.app.synchronous_database = {JOB_COLLECTION_NAME: mock_job_collection} return test_server_config, test_job, mock_job_collection, mock_fastapi_request + + +def make_mock_db_client() -> Mock: + mock_database = Mock() + mock_db_client = Mock() + mock_db_client.__getitem__ = Mock( + side_effect=lambda database_name: mock_database if database_name == DATABASE_NAME else None + ) + return mock_db_client diff --git a/florist/tests/unit/app/jobs/details/page.test.tsx b/florist/tests/unit/app/jobs/details/page.test.tsx index 2d920fb4..c5ae739b 100644 --- a/florist/tests/unit/app/jobs/details/page.test.tsx +++ b/florist/tests/unit/app/jobs/details/page.test.tsx @@ -52,10 +52,10 @@ function makeTestJob(): JobData { fit_start: "2020-01-01 12:07:07.0707", rounds: { "1": { - fit_start: "2020-01-01 12:08:08.0808", - fit_end: "2020-01-01 12:09:09.0909", - evaluate_start: "2020-01-01 12:08:09.0808", - evaluate_end: "2020-01-01 12:08:10.0808", + fit_round_start: "2020-01-01 12:08:08.0808", + fit_round_end: "2020-01-01 12:09:09.0909", + eval_round_start: "2020-01-01 12:08:09.0808", + eval_round_end: "2020-01-01 12:08:10.0808", custom_property_value: "133.7", custom_property_array: [1337, 1338], custom_property_object: { @@ -63,12 +63,12 @@ function makeTestJob(): JobData { }, }, "2": { - fit_start: "2020-01-01 12:10:10.1010", - fit_end: "2020-01-01 12:11:11.1111", - evaluate_start: "2020-01-01 12:11:09.0808", + fit_round_start: "2020-01-01 12:10:10.1010", + fit_round_end: "2020-01-01 12:11:11.1111", + eval_round_start: "2020-01-01 12:11:09.0808", }, "3": { - fit_start: "2020-01-01 12:12:00.1212", + fit_round_start: "2020-01-01 12:12:00.1212", }, }, custom_property_value: 
"133.7", @@ -92,14 +92,30 @@ function makeTestJob(): JobData { "1": { fit_start: "2024-10-10 15:05:34.888213", fit_end: "2024-10-10 15:06:59.032618", - evaluate_start: "2024-10-10 15:07:59.032618", - evaluate_end: "2024-10-10 15:08:34.888213", + eval_start: "2024-10-10 15:07:59.032618", + eval_end: "2024-10-10 15:08:34.888213", + round_end: "2024-10-10 15:08:34.888213", }, "2": { fit_start: "2024-10-10 15:06:59.032618", fit_end: "2024-10-10 15:07:34.888213", - evaluate_start: "2024-10-10 15:08:34.888213", - evaluate_end: "2024-10-10 15:09:59.032618", + eval_start: "2024-10-10 15:08:34.888213", + eval_end: "2024-10-10 15:09:59.032618", + round_end: "2024-10-10 15:09:59.032618", + }, + "3": { + fit_start: "2024-10-10 15:10:59.032618", + fit_end: "2024-10-10 15:11:34.888213", + eval_start: "2024-10-10 15:12:34.888213", + eval_end: "2024-10-10 15:13:59.032618", + round_end: "2024-10-10 15:14:59.032618", + }, + "4": { + fit_start: "2024-10-10 15:15:59.032618", + fit_end: "2024-10-10 15:16:34.888213", + eval_start: "2024-10-10 15:17:34.888213", + eval_end: "2024-10-10 15:18:59.032618", + round_end: "2024-10-10 15:19:59.032618", }, }, }), @@ -117,8 +133,9 @@ function makeTestJob(): JobData { "1": { fit_start: "2024-10-10 15:05:34.888213", fit_end: "2024-10-10 15:05:34.888213", - evaluate_start: "2024-10-10 15:08:34.888213", - evaluate_end: "2024-10-10 15:08:34.888213", + eval_start: "2024-10-10 15:08:34.888213", + eval_end: "2024-10-10 15:08:34.888213", + round_end: "2024-10-10 15:08:34.888213", }, "2": { fit_start: "2024-10-10 15:06:59.032618", @@ -421,15 +438,16 @@ describe("Job Details Page", () => { const progressToggleButton = container.querySelector(".job-details-toggle a"); act(() => progressToggleButton.click()); + const serverRounds = serverMetrics.rounds; const expectedTimes = { fit: [ - ["01m 01s", serverMetrics.rounds["1"].fit_start, serverMetrics.rounds["1"].fit_end], - ["01m 01s", serverMetrics.rounds["2"].fit_start, serverMetrics.rounds["2"].fit_end], - 
["12s", serverMetrics.rounds["3"].fit_start, ""], + ["01m 01s", serverRounds["1"].fit_round_start, serverRounds["1"].fit_round_end], + ["01m 01s", serverRounds["2"].fit_round_start, serverRounds["2"].fit_round_end], + ["12s", serverRounds["3"].fit_round_start, ""], ], evaluate: [ - ["01s", serverMetrics.rounds["1"].evaluate_start, serverMetrics.rounds["1"].evaluate_end], - ["01m 03s", serverMetrics.rounds["2"].evaluate_start, ""], + ["01s", serverRounds["1"].eval_round_start, serverRounds["1"].eval_round_end], + ["01m 03s", serverRounds["2"].eval_round_start, ""], ["", "", ""], ], }; @@ -545,7 +563,7 @@ describe("Job Details Page", () => { progressBar = clientsProgress[1].querySelector("div.progress-bar"); expect(progressBar).toHaveClass("bg-warning"); - expect(progressBar).toHaveTextContent("50%"); + expect(progressBar).toHaveTextContent("25%"); }); it("Renders the progress details correctly", () => { const testJob = makeTestJob(); diff --git a/package.json b/package.json index 828bf983..9883816a 100644 --- a/package.json +++ b/package.json @@ -3,8 +3,8 @@ "version": "0.0.1", "private": true, "scripts": { - "fastapi-dev": "poetry install --with test && python -m uvicorn florist.api.server:app --reload --log-level debug", - "fastapi-prod": "poetry install --with test && python -m uvicorn florist.api.server:app --reload", + "fastapi-dev": "python -m poetry install --with test && python -m uvicorn florist.api.server:app --reload --log-level debug", + "fastapi-prod": "python -m poetry install && python -m uvicorn florist.api.server:app --workers 4", "next-dev": "next dev florist", "dev": "concurrently \"npm run next-dev\" \"npm run fastapi-dev\"", "prod": "concurrently \"npm run next-dev\" \"npm run fastapi-prod\"",