diff --git a/florist/api/client.py b/florist/api/client.py index 30fcf8fe..1a051fbb 100644 --- a/florist/api/client.py +++ b/florist/api/client.py @@ -43,9 +43,14 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red :param redis_port: (str) the port for the Redis instance for metrics reporting. :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the format below, which can be used to pull metrics from Redis. - {"uuid": } + { + "uuid": (str) The client's uuid, which can be used to pull metrics from Redis, + "log_file_path": (str) The local path of the log file for this client, + } If not successful, returns the appropriate error code with a JSON with the format below: - {"error": } + { + "error": (str) The error message, + } """ try: if client not in Client.list(): @@ -65,10 +70,10 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red reporters=[metrics_reporter], ) - log_file_name = str(get_client_log_file_path(client_uuid)) - launch_client(client_obj, server_address, log_file_name) + log_file_path = str(get_client_log_file_path(client_uuid)) + launch_client(client_obj, server_address, log_file_path) - return JSONResponse({"uuid": client_uuid}) + return JSONResponse({"uuid": client_uuid, "log_file_path": log_file_path}) except Exception as ex: return JSONResponse({"error": str(ex)}, status_code=500) @@ -98,3 +103,17 @@ def check_status(client_uuid: str, redis_host: str, redis_port: str) -> JSONResp except Exception as ex: LOGGER.exception(ex) return JSONResponse({"error": str(ex)}, status_code=500) + + +@app.get("/api/client/get_log") +def get_log(log_file_path: str) -> JSONResponse: + """ + Return the contents of the log file under the given path. + + :param log_file_path: (str) the path of the log file. + + :return: (JSONResponse) Returns the contents of the file as a string. 
+ """ + with open(log_file_path, "r") as f: + content = f.read() + return JSONResponse(content) diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index fba0bb94..2492b0af 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -48,6 +48,7 @@ class ClientInfo(BaseModel): redis_port: str = Field(...) uuid: Optional[Annotated[str, Field(...)]] metrics: Optional[Annotated[str, Field(...)]] + log_file_path: Optional[Annotated[str, Field(...)]] class Config: """MongoDB config for the ClientInfo DB entity.""" @@ -62,6 +63,7 @@ class Config: "redis_port": "6380", "uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f", "metrics": '{"host_type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}', + "log_file_path": "/Users/foo/client/logfile.log", }, } @@ -77,6 +79,7 @@ class Job(BaseModel): config_parser: Optional[Annotated[ConfigParser, Field(...)]] server_uuid: Optional[Annotated[str, Field(...)]] server_metrics: Optional[Annotated[str, Field(...)]] + server_log_file_path: Optional[Annotated[str, Field(...)]] redis_host: Optional[Annotated[str, Field(...)]] redis_port: Optional[Annotated[str, Field(...)]] clients_info: Optional[Annotated[List[ClientInfo], Field(...)]] @@ -129,7 +132,7 @@ async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: A """ Save the server and clients' UUIDs in the database under the current job's id. - :param server_uuid: [str] the server_uuid to be saved in the database. + :param server_uuid: (str) the server_uuid to be saved in the database. :param client_uuids: List[str] the list of client_uuids to be saved in the database. :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored. 
""" @@ -211,6 +214,44 @@ async def set_client_metrics( ) assert_updated_successfully(update_result) + async def set_server_log_file_path(self, log_file_path: str, database: AsyncIOMotorDatabase[Any]) -> None: + """ + Save the server's log file path in the database under the current job's id. + + :param log_file_path: (str) the file path to be saved in the database. + :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored. + """ + job_collection = database[JOB_COLLECTION_NAME] + self.server_log_file_path = log_file_path + update_result = await job_collection.update_one( + {"_id": self.id}, {"$set": {"server_log_file_path": log_file_path}} + ) + assert_updated_successfully(update_result) + + async def set_client_log_file_path( + self, + client_index: int, + log_file_path: str, + database: AsyncIOMotorDatabase[Any], + ) -> None: + """ + Save the client's log file path in the database under the given client index and current job's id. + + :param client_index: (int) the index of the client in the job. + :param log_file_path: (str) the path of the client's log file. + :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored. + """ + assert self.clients_info is not None, "Job has no clients." 
+ assert ( + 0 <= client_index < len(self.clients_info) + ), f"Client index {client_index} is invalid (total: {len(self.clients_info)})" + + job_collection = database[JOB_COLLECTION_NAME] + update_result = await job_collection.update_one( + {"_id": self.id}, {"$set": {f"clients_info.{client_index}.log_file_path": log_file_path}} + ) + assert_updated_successfully(update_result) + class Config: """MongoDB config for the Job DB entity.""" @@ -224,6 +265,7 @@ class Config: "server_config": '{"n_server_rounds": 3, "batch_size": 8, "local_epochs": 1}', "server_uuid": "d73243cf-8b89-473b-9607-8cd0253a101d", "server_metrics": '{"host_type": "server", "fit_start": "2024-04-23 15:33:12.865604", "rounds": {"1": {"fit_start": "2024-04-23 15:33:12.869001"}}}', + "server_log_file_path": "/Users/foo/server/logfile.log", "redis_host": "localhost", "redis_port": "6379", "clients_info": [ diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index be9dcdad..8efe0a06 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -2,6 +2,7 @@ from typing import List, Union +import requests from fastapi import APIRouter, Body, Request, status from fastapi.responses import JSONResponse @@ -99,3 +100,73 @@ async def change_job_status(job_id: str, status: JobStatus, request: Request) -> return JSONResponse(content={"error": str(assertion_e)}, status_code=400) except Exception as general_e: return JSONResponse(content={"error": str(general_e)}, status_code=500) + + +@router.get("/get_server_log/{job_id}") +async def get_server_log(job_id: str, request: Request) -> JSONResponse: + """ + Return the contents of the server's log file for the given job id. + + :param job_id: (str) the ID of the job to get the server logs for. + :param request: (fastapi.Request) the FastAPI request object. + + :return: (JSONResponse) if successful, returns the contents of the file as a string. 
+ If not successful, returns the appropriate error code with a JSON with the format below: + {"error": <error message>} + """ + try: + job = await Job.find_by_id(job_id, request.app.database) + + assert job is not None, f"Job {job_id} not found" + assert ( + job.server_log_file_path is not None and job.server_log_file_path != "" + ), "Log file path is None or empty" + + with open(job.server_log_file_path, "r") as f: + content = f.read() + return JSONResponse(content) + + except AssertionError as assertion_e: + return JSONResponse(content={"error": str(assertion_e)}, status_code=400) + except Exception as general_e: + return JSONResponse(content={"error": str(general_e)}, status_code=500) + + +@router.get("/get_client_log/{job_id}/{client_index}") +async def get_client_log(job_id: str, client_index: int, request: Request) -> JSONResponse: + """ + Return the contents of the log file for the client with given index under given job id. + + :param job_id: (str) the ID of the job to get the client logs for. + :param client_index: (int) the index of the client within the job. + :param request: (fastapi.Request) the FastAPI request object. + + :return: (JSONResponse) if successful, returns the contents of the file as a string. + If not successful, returns the appropriate error code with a JSON with the format below: + {"error": <error message>} + """ + try: + job = await Job.find_by_id(job_id, request.app.database) + + assert job is not None, f"Job {job_id} not found" + assert job.clients_info is not None, "Job has no clients." 
+ assert ( + 0 <= client_index < len(job.clients_info) + ), f"Client index {client_index} is invalid (total: {len(job.clients_info)})" + + client_info = job.clients_info[client_index] + assert ( + client_info.log_file_path is not None and client_info.log_file_path != "" + ), "Log file path is None or empty" + + response = requests.get( + url=f"http://{client_info.service_address}/api/client/get_log", + params={"log_file_path": client_info.log_file_path}, + ) + json_response = response.json() + return JSONResponse(json_response) + + except AssertionError as assertion_e: + return JSONResponse(content={"error": str(assertion_e)}, status_code=400) + except Exception as general_e: + return JSONResponse(content={"error": str(general_e)}, status_code=500) diff --git a/florist/api/routes/server/training.py b/florist/api/routes/server/training.py index bcc2f965..5cf31988 100644 --- a/florist/api/routes/server/training.py +++ b/florist/api/routes/server/training.py @@ -72,7 +72,7 @@ async def start(job_id: str, request: Request) -> JSONResponse: model_class = Model.class_for_model(job.model) # Start the server - server_uuid, _ = launch_local_server( + server_uuid, _, server_log_file_path = launch_local_server( model=model_class(), n_clients=len(job.clients_info), server_address=job.server_address, @@ -80,11 +80,13 @@ async def start(job_id: str, request: Request) -> JSONResponse: redis_port=job.redis_port, **server_config, ) + await job.set_server_log_file_path(server_log_file_path, request.app.database) wait_for_metric(server_uuid, "fit_start", job.redis_host, job.redis_port, logger=LOGGER) # Start the clients client_uuids: List[str] = [] - for client_info in job.clients_info: + for i in range(len(job.clients_info)): + client_info = job.clients_info[i] parameters = { "server_address": job.server_address, "client": client_info.client.value, @@ -104,6 +106,8 @@ async def start(job_id: str, request: Request) -> JSONResponse: client_uuids.append(json_response["uuid"]) + await 
job.set_client_log_file_path(i, json_response["log_file_path"], request.app.database) + await job.set_uuids(server_uuid, client_uuids, request.app.database) # Start the server training listener and client training listeners as threads to update diff --git a/florist/api/servers/launch.py b/florist/api/servers/launch.py index 298993cc..b8ddda7b 100644 --- a/florist/api/servers/launch.py +++ b/florist/api/servers/launch.py @@ -22,7 +22,7 @@ def launch_local_server( local_epochs: int, redis_host: str, redis_port: str, -) -> Tuple[str, Process]: +) -> Tuple[str, Process, str]: """ Launch a FL server locally. @@ -34,8 +34,10 @@ def launch_local_server( :param local_epochs: (int) The number of epochs to run by the clients. :param redis_host: (str) the host name for the Redis instance for metrics reporting. :param redis_port: (str) the port for the Redis instance for metrics reporting. - :return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull - metrics from Redis, along with its local process object. 
+ :return: (Tuple[str, multiprocessing.Process, str]) a tuple with + - The UUID of the server, which can be used to pull metrics from Redis + - The server's local process object + - The local path for the log file """ server_uuid = str(uuid.uuid4()) @@ -49,13 +51,13 @@ def launch_local_server( local_epochs=local_epochs, ) - log_file_name = str(get_server_log_file_path(server_uuid)) + log_file_path = str(get_server_log_file_path(server_uuid)) server_process = launch_server( server_constructor, server_address, n_server_rounds, - log_file_name, + log_file_path, seconds_to_sleep=0, ) - return server_uuid, server_process + return server_uuid, server_process, log_file_path diff --git a/florist/tests/integration/api/db/test_entities.py b/florist/tests/integration/api/db/test_entities.py index 6d1ddcb5..89f4dcd3 100644 --- a/florist/tests/integration/api/db/test_entities.py +++ b/florist/tests/integration/api/db/test_entities.py @@ -215,6 +215,38 @@ async def test_set_client_metrics_fail_update_result(mock_request) -> None: ) +async def test_set_server_log_file_path_success(mock_request) -> None: + test_job = get_test_job() + result_id = await test_job.create(mock_request.app.database) + test_job.id = result_id + test_job.clients_info[0].id = ANY + test_job.clients_info[1].id = ANY + + test_log_file_path = "test/log/file/path.log" + + await test_job.set_server_log_file_path(test_log_file_path, mock_request.app.database) + + result_job = await Job.find_by_id(result_id, mock_request.app.database) + test_job.server_log_file_path = test_log_file_path + assert result_job == test_job + + +async def test_set_client_log_file_path_success(mock_request) -> None: + test_job = get_test_job() + result_id = await test_job.create(mock_request.app.database) + test_job.id = result_id + test_job.clients_info[0].id = ANY + test_job.clients_info[1].id = ANY + + test_log_file_path = "test/log/file/path.log" + + await test_job.set_client_log_file_path(1, test_log_file_path, 
mock_request.app.database) + + result_job = await Job.find_by_id(result_id, mock_request.app.database) + test_job.clients_info[1].log_file_path = test_log_file_path + assert result_job == test_job + + def get_test_job() -> Job: test_server_config = { "n_server_rounds": 2, diff --git a/florist/tests/integration/api/launchers/test_launch.py b/florist/tests/integration/api/launchers/test_launch.py index f8f14b77..ef5bf18f 100644 --- a/florist/tests/integration/api/launchers/test_launch.py +++ b/florist/tests/integration/api/launchers/test_launch.py @@ -11,7 +11,7 @@ from florist.api.servers.utils import get_server -def assert_string_in_file(file_path: str, search_string: str) -> bool: +def assert_string_in_file(file_path: str, search_string: str) -> None: with open(file_path, "r") as f: file_contents = f.read() match = re.search(search_string, file_contents) diff --git a/florist/tests/integration/api/routes/server/test_job.py b/florist/tests/integration/api/routes/server/test_job.py index 297d046a..a1e55370 100644 --- a/florist/tests/integration/api/routes/server/test_job.py +++ b/florist/tests/integration/api/routes/server/test_job.py @@ -1,12 +1,16 @@ +import json +import os from unittest.mock import ANY +import uvicorn from fastapi.encoders import jsonable_encoder from florist.api.clients.common import Client from florist.api.db.entities import ClientInfo, Job, JobStatus -from florist.api.routes.server.job import list_jobs_with_status, new_job +from florist.api.monitoring.logs import get_server_log_file_path, get_client_log_file_path +from florist.api.routes.server.job import list_jobs_with_status, new_job, get_server_log, get_client_log from florist.api.servers.common import Model -from florist.tests.integration.api.utils import mock_request +from florist.tests.integration.api.utils import mock_request, TestUvicornServer from florist.api.servers.config_parsers import ConfigParser @@ -26,6 +30,7 @@ async def test_new_job(mock_request) -> None: "clients_info": None, 
"server_metrics": None, "server_uuid": None, + "server_log_file_path": None, } test_job = Job( @@ -39,6 +44,7 @@ async def test_new_job(mock_request) -> None: redis_port="test-redis-port", server_metrics="test-server-metrics", server_uuid="test-server-uuid", + server_log_file_path="test-server-log-file-path", clients_info=[ ClientInfo( client=Client.MNIST, @@ -48,6 +54,7 @@ async def test_new_job(mock_request) -> None: redis_port="test-redis-port-1", metrics="test-client-metrics-1", uuid="test-client-uuid-1", + log_file_path="test-log-file-path-1", ), ClientInfo( client=Client.MNIST, @@ -55,8 +62,9 @@ async def test_new_job(mock_request) -> None: data_path="test/data/path-2", redis_host="test-redis-host-2", redis_port="test-redis-port-2", - metrics="test-client-metrics-1", - uuid="test-client-uuid-1", + metrics="test-client-metrics-2", + uuid="test-client-uuid-2", + log_file_path="test-log-file-path-2", ), ] ) @@ -73,6 +81,7 @@ async def test_new_job(mock_request) -> None: "redis_port": test_job.redis_port, "server_uuid": test_job.server_uuid, "server_metrics": test_job.server_metrics, + "server_log_file_path": test_job.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -83,6 +92,7 @@ async def test_new_job(mock_request) -> None: "redis_port": test_job.clients_info[0].redis_port, "uuid": test_job.clients_info[0].uuid, "metrics": test_job.clients_info[0].metrics, + "log_file_path": test_job.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job.clients_info[1].client.value, @@ -92,6 +102,7 @@ async def test_new_job(mock_request) -> None: "redis_port": test_job.clients_info[1].redis_port, "uuid": test_job.clients_info[1].uuid, "metrics": test_job.clients_info[1].metrics, + "log_file_path": test_job.clients_info[1].log_file_path, }, ], } @@ -109,6 +120,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port1", server_metrics="test-server-metrics1", server_uuid="test-server-uuid1", + 
server_log_file_path="test-server-log-file-path1", clients_info=[ ClientInfo( client=Client.MNIST, @@ -118,6 +130,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-1-1", metrics="test-client-metrics-1-1", uuid="test-client-uuid-1-1", + log_file_path="test-log-file-path-1-1", ), ClientInfo( client=Client.MNIST, @@ -127,6 +140,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-2-1", metrics="test-client-metrics-2-1", uuid="test-client-uuid-2-1", + log_file_path="test-log-file-path-2-1", ), ] ) @@ -142,6 +156,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port2", server_metrics="test-server-metrics2", server_uuid="test-server-uuid2", + server_log_file_path="test-server-log-file-path2", clients_info=[ ClientInfo( client=Client.MNIST, @@ -151,6 +166,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-1-2", metrics="test-client-metrics-1-2", uuid="test-client-uuid-1-2", + log_file_path="test-log-file-path-1-2", ), ClientInfo( client=Client.MNIST, @@ -160,6 +176,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-2-2", metrics="test-client-metrics-2-2", uuid="test-client-uuid-2-2", + log_file_path="test-log-file-path-2-2", ), ] ) @@ -175,6 +192,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port3", server_metrics="test-server-metrics3", server_uuid="test-server-uuid3", + server_log_file_path="test-server-log-file-path3", clients_info=[ ClientInfo( client=Client.MNIST, @@ -184,6 +202,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-1-3", metrics="test-client-metrics-1-3", uuid="test-client-uuid-1-3", + log_file_path="test-log-file-path-1-3", ), ClientInfo( client=Client.MNIST, @@ -193,6 +212,7 @@ async def test_list_jobs_with_status(mock_request) -> None: 
redis_port="test-redis-port-2-3", metrics="test-client-metrics-2-3", uuid="test-client-uuid-2-3", + log_file_path="test-log-file-path-2-3", ), ] ) @@ -208,6 +228,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port4", server_metrics="test-server-metrics4", server_uuid="test-server-uuid4", + server_log_file_path="test-server-log-file-path4", clients_info=[ ClientInfo( client=Client.MNIST, @@ -217,6 +238,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-1-4", metrics="test-client-metrics-1-4", uuid="test-client-uuid-1-4", + log_file_path="test-log-file-path-1-4", ), ClientInfo( client=Client.MNIST, @@ -226,6 +248,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-2-4", metrics="test-client-metrics-2-4", uuid="test-client-uuid-2-4", + log_file_path="test-log-file-path-2-4", ), ] ) @@ -255,6 +278,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job1.redis_port, "server_metrics": test_job1.server_metrics, "server_uuid": test_job1.server_uuid, + "server_log_file_path": test_job1.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -265,6 +289,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job1.clients_info[0].redis_port, "metrics": test_job1.clients_info[0].metrics, "uuid": test_job1.clients_info[0].uuid, + "log_file_path": test_job1.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job1.clients_info[1].client.value, @@ -274,6 +299,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job1.clients_info[1].redis_port, "metrics": test_job1.clients_info[1].metrics, "uuid": test_job1.clients_info[1].uuid, + "log_file_path": test_job1.clients_info[1].log_file_path, }, ], } @@ -289,6 +315,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job2.redis_port, "server_metrics": test_job2.server_metrics, 
"server_uuid": test_job2.server_uuid, + "server_log_file_path": test_job2.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -299,6 +326,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job2.clients_info[0].redis_port, "metrics": test_job2.clients_info[0].metrics, "uuid": test_job2.clients_info[0].uuid, + "log_file_path": test_job2.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job2.clients_info[1].client.value, @@ -308,6 +336,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job2.clients_info[1].redis_port, "metrics": test_job2.clients_info[1].metrics, "uuid": test_job2.clients_info[1].uuid, + "log_file_path": test_job2.clients_info[1].log_file_path, }, ], } @@ -323,6 +352,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job3.redis_port, "server_metrics": test_job3.server_metrics, "server_uuid": test_job3.server_uuid, + "server_log_file_path": test_job3.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -333,6 +363,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job3.clients_info[0].redis_port, "metrics": test_job3.clients_info[0].metrics, "uuid": test_job3.clients_info[0].uuid, + "log_file_path": test_job3.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job3.clients_info[1].client.value, @@ -342,6 +373,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job3.clients_info[1].redis_port, "metrics": test_job3.clients_info[1].metrics, "uuid": test_job3.clients_info[1].uuid, + "log_file_path": test_job3.clients_info[1].log_file_path, }, ], } @@ -357,6 +389,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job4.redis_port, "server_metrics": test_job4.server_metrics, "server_uuid": test_job4.server_uuid, + "server_log_file_path": test_job4.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -367,6 +400,7 @@ async def 
test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job4.clients_info[0].redis_port, "metrics": test_job4.clients_info[0].metrics, "uuid": test_job4.clients_info[0].uuid, + "log_file_path": test_job4.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job4.clients_info[1].client.value, @@ -376,6 +410,141 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job4.clients_info[1].redis_port, "metrics": test_job4.clients_info[1].metrics, "uuid": test_job4.clients_info[1].uuid, + "log_file_path": test_job4.clients_info[1].log_file_path, }, ], } + + +async def test_get_server_log_success(mock_request): + test_log_file_name = "test-log-file-name" + test_log_file_content = "this is a test log file content" + test_log_file_path = str(get_server_log_file_path(test_log_file_name)) + + with open(test_log_file_path, "w") as f: + f.write(test_log_file_content) + + result_job = await new_job(mock_request, Job(server_log_file_path=test_log_file_path)) + + result = await get_server_log(result_job.id, mock_request) + + assert result.status_code == 200 + assert result.body.decode() == f"\"{test_log_file_content}\"" + + os.remove(test_log_file_path) + + +async def test_get_server_log_error_no_job(mock_request): + test_job_id = "inexistent-job-id" + + result = await get_server_log(test_job_id, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Job {test_job_id} not found"} + + +async def test_get_server_log_error_no_log_path(mock_request): + result_job = await new_job(mock_request, Job()) + + result = await get_server_log(result_job.id, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Log file path is None or empty"} + + +async def test_get_client_log_success(mock_request): + test_log_file_name = "test-log-file-name" + test_log_file_content = "this is a test log file content" + test_log_file_path = 
str(get_client_log_file_path(test_log_file_name)) + + with open(test_log_file_path, "w") as f: + f.write(test_log_file_content) + + test_client_host = "localhost" + test_client_port = 8001 + + result_job = await new_job(mock_request, Job( + clients_info=[ + ClientInfo( + client=Client.MNIST, + service_address=f"{test_client_host}:{test_client_port}", + data_path="test/data/path-1", + redis_host="test-redis-host-1", + redis_port="test-redis-port-1", + ), + ClientInfo( + client=Client.MNIST, + service_address=f"{test_client_host}:{test_client_port}", + data_path="test/data/path-2", + redis_host="test-redis-host-2", + redis_port="test-redis-port-2", + log_file_path=test_log_file_path, + ), + ], + )) + + client_config = uvicorn.Config("florist.api.client:app", host=test_client_host, port=test_client_port, log_level="debug") + client_service = TestUvicornServer(config=client_config) + with client_service.run_in_thread(): + result = await get_client_log(result_job.id, 1, mock_request) + + assert result.status_code == 200 + assert result.body.decode() == f"\"{test_log_file_content}\"" + + os.remove(test_log_file_path) + + +async def test_get_client_log_error_no_job(mock_request): + test_job_id = "inexistent-job-id" + + result = await get_client_log(test_job_id, 0, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Job {test_job_id} not found"} + + +async def test_get_client_log_error_no_clients(mock_request): + result_job = await new_job(mock_request, Job()) + + result = await get_client_log(result_job.id, 0, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Job has no clients."} + + +async def test_get_client_log_error_invalid_client_index(mock_request): + result_job = await new_job(mock_request, Job( + clients_info=[ + ClientInfo( + client=Client.MNIST, + service_address=f"test-address", + data_path="test/data/path-1", + redis_host="test-redis-host-1", + 
redis_port="test-redis-port-1", + ), + ], + )) + + result = await get_client_log(result_job.id, 1, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Client index 1 is invalid (total: 1)"} + + +async def test_get_client_log_error_log_file_path_is_none(mock_request): + result_job = await new_job(mock_request, Job( + clients_info=[ + ClientInfo( + client=Client.MNIST, + service_address=f"test-address", + data_path="test/data/path-1", + redis_host="test-redis-host-1", + redis_port="test-redis-port-1", + ), + ], + )) + + result = await get_client_log(result_job.id, 0, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Log file path is None or empty"} diff --git a/florist/tests/unit/api/routes/server/test_training.py b/florist/tests/unit/api/routes/server/test_training.py index 2a900f29..9f7d3bc9 100644 --- a/florist/tests/unit/api/routes/server/test_training.py +++ b/florist/tests/unit/api/routes/server/test_training.py @@ -13,23 +13,6 @@ server_training_listener ) - -async def test_start_fail_unsupported_server_model() -> None: - # Arrange - test_job_id = "test-job-id" - _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks() - test_job["model"] = "WRONG MODEL" - - # Act - response = await start(test_job_id, mock_fastapi_request) - - # Assert - assert response.status_code == 500 - json_body = json.loads(response.body.decode()) - assert json_body == {"error": ANY} - assert "value is not a valid enumeration member" in json_body["error"] - - @patch("florist.api.routes.server.training.client_training_listener") @patch("florist.api.routes.server.training.server_training_listener") @patch("florist.api.routes.server.training.launch_local_server") @@ -37,7 +20,11 @@ async def test_start_fail_unsupported_server_model() -> None: @patch("florist.api.routes.server.training.requests") @patch("florist.api.db.entities.Job.set_status") 
@patch("florist.api.db.entities.Job.set_uuids") +@patch("florist.api.db.entities.Job.set_server_log_file_path") +@patch("florist.api.db.entities.Job.set_client_log_file_path") async def test_start_success( + mock_set_client_log_file_path: Mock, + mock_server_log_file_path: Mock, mock_set_uuids: Mock, mock_set_status: Mock, mock_requests: Mock, @@ -51,7 +38,8 @@ async def test_start_success( test_server_config, test_job, mock_job_collection, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + test_server_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_server_log_file_path) mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"fit_start\": null}" @@ -60,8 +48,13 @@ async def test_start_success( mock_response = Mock() mock_response.status_code = 200 test_client_1_uuid = "test-client-1-uuid" + test_client_1_log_file_path = "test-client-1-log-file-path" test_client_2_uuid = "test-client-2-uuid" - mock_response.json.side_effect = [{"uuid": test_client_1_uuid}, {"uuid": test_client_2_uuid}] + test_client_2_log_file_path = "test-client-2-log-file-path" + mock_response.json.side_effect = [ + {"uuid": test_client_1_uuid, "log_file_path": test_client_1_log_file_path}, + {"uuid": test_client_2_uuid, "log_file_path": test_client_2_log_file_path}, + ] mock_requests.get.return_value = mock_response mock_client_training_listener.return_value = AsyncMock() @@ -76,6 +69,7 @@ async def test_start_success( assert json_body == {"server_uuid": test_server_uuid, "client_uuids": [test_client_1_uuid, test_client_2_uuid]} mock_job_collection.find_one.assert_called_with({"_id": test_job_id}) + mock_set_status.assert_called_once_with(JobStatus.IN_PROGRESS, mock_fastapi_request.app.database) assert isinstance(mock_launch_local_server.call_args_list[0][1]["model"], MnistNet) 
mock_launch_local_server.assert_called_once_with( @@ -90,6 +84,8 @@ async def test_start_success( ) mock_redis.Redis.assert_called_once_with(host=test_job["redis_host"], port=test_job["redis_port"]) mock_redis_connection.get.assert_called_once_with(test_server_uuid) + mock_server_log_file_path.assert_called_once_with(test_server_log_file_path, mock_fastapi_request.app.database) + mock_requests.get.assert_any_call( url=f"http://{test_job['clients_info'][0]['service_address']}/api/client/start", params={ @@ -111,7 +107,11 @@ async def test_start_success( }, ) - mock_set_status.assert_called_once_with(JobStatus.IN_PROGRESS, mock_fastapi_request.app.database) + mock_set_client_log_file_path.assert_has_calls([ + call(0, test_client_1_log_file_path, mock_fastapi_request.app.database), + call(1, test_client_2_log_file_path, mock_fastapi_request.app.database), + ]) + mock_set_uuids.assert_called_once_with( test_server_uuid, [test_client_1_uuid, test_client_2_uuid], @@ -130,6 +130,22 @@ async def test_start_success( ]) +async def test_start_fail_unsupported_server_model() -> None: + # Arrange + test_job_id = "test-job-id" + _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks() + test_job["model"] = "WRONG MODEL" + + # Act + response = await start(test_job_id, mock_fastapi_request) + + # Assert + assert response.status_code == 500 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": ANY} + assert "value is not a valid enumeration member" in json_body["error"] + + async def test_start_fail_unsupported_client() -> None: # Arrange test_job_id = "test-job-id" @@ -222,13 +238,20 @@ async def test_start_launch_server_exception(mock_launch_local_server: Mock, _: @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") -async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None: 
+@patch("florist.api.db.entities.Job.set_server_log_file_path") +async def test_start_wait_for_metric_exception( + mock_set_server_log_file_path: Mock, + mock_redis: Mock, + mock_launch_local_server: Mock, + _: Mock, +) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + test_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path) test_exception = Exception("test exception") mock_redis.Redis.side_effect = test_exception @@ -241,18 +264,28 @@ async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_loc json_body = json.loads(response.body.decode()) assert json_body == {"error": str(test_exception)} + mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database) + @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep -async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_local_server: Mock, __: Mock) -> None: +@patch("florist.api.db.entities.Job.set_server_log_file_path") +async def test_start_wait_for_metric_timeout( + mock_set_server_log_file_path: Mock, + _: Mock, + mock_redis: Mock, + mock_launch_local_server: Mock, + __: Mock, +) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + test_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path) mock_redis_connection = Mock() mock_redis_connection.get.return_value = 
b"{\"foo\": null}" @@ -266,18 +299,28 @@ async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_lau json_body = json.loads(response.body.decode()) assert json_body == {"error": "Metric 'fit_start' not been found after 20 retries."} + mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database) + @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.routes.server.training.requests") -async def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None: +@patch("florist.api.db.entities.Job.set_server_log_file_path") +async def test_start_fail_response( + mock_set_server_log_file_path: Mock, + mock_requests: Mock, + mock_redis: Mock, + mock_launch_local_server: Mock, + _: Mock, +) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + test_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path) mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"fit_start\": null}" @@ -296,18 +339,28 @@ async def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_l json_body = json.loads(response.body.decode()) assert json_body == {"error": f"Client response returned 403. 
Response: error"} + mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database) + @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.routes.server.training.requests") -async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None: +@patch("florist.api.db.entities.Job.set_server_log_file_path") +async def test_start_no_uuid_in_response( + mock_set_server_log_file_path: Mock, + mock_requests: Mock, + mock_redis: Mock, + mock_launch_local_server: Mock, + _: Mock, +) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + test_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path) mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"fit_start\": null}" @@ -326,6 +379,8 @@ async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, json_body = json.loads(response.body.decode()) assert json_body == {"error": "Client response did not return a UUID. 
Response: {'foo': 'bar'}"} + mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database) + @patch("florist.api.routes.server.training.AsyncIOMotorClient") @patch("florist.api.routes.server.training.get_from_redis") @@ -615,6 +670,7 @@ def _setup_test_job_and_mocks() -> Tuple[Dict[str, Any], Dict[str, Any], Mock, M mock_job_collection.find_one.return_value = mock_find_one mock_fastapi_request = Mock() mock_fastapi_request.app.database = {JOB_COLLECTION_NAME: mock_job_collection} + mock_fastapi_request.app.synchronous_database = {JOB_COLLECTION_NAME: mock_job_collection} return test_server_config, test_job, mock_job_collection, mock_fastapi_request diff --git a/florist/tests/unit/api/servers/test_launch.py b/florist/tests/unit/api/servers/test_launch.py index 06e818a8..49c1c2e2 100644 --- a/florist/tests/unit/api/servers/test_launch.py +++ b/florist/tests/unit/api/servers/test_launch.py @@ -1,7 +1,6 @@ from unittest.mock import ANY, Mock, patch from florist.api.clients.mnist import MnistNet -from florist.api.monitoring.logs import get_server_log_file_path from florist.api.monitoring.metrics import RedisMetricsReporter from florist.api.servers.launch import launch_local_server from florist.api.servers.utils import get_server @@ -20,7 +19,7 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: test_server_process = "test-server-process" mock_launch_server.return_value = test_server_process - server_uuid, server_process = launch_local_server( + server_uuid, server_process, log_file_path = launch_local_server( test_model, test_n_clients, test_server_address, @@ -41,7 +40,7 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: ANY, test_server_address, test_n_server_rounds, - str(get_server_log_file_path(server_uuid)), + log_file_path, ) assert call_kwargs == {"seconds_to_sleep": 0} assert call_args[0].func == get_server diff --git a/florist/tests/unit/api/test_client.py 
b/florist/tests/unit/api/test_client.py index 6b497bf0..1a5790ae 100644 --- a/florist/tests/unit/api/test_client.py +++ b/florist/tests/unit/api/test_client.py @@ -1,5 +1,6 @@ """Tests for FLorist's client FastAPI endpoints.""" import json +import os from unittest.mock import ANY, Mock, patch from florist.api import client @@ -29,10 +30,11 @@ def test_start_success(mock_launch_client: Mock) -> None: assert response.status_code == 200 json_body = json.loads(response.body.decode()) - assert json_body == {"uuid": ANY} + log_file_path = str(get_client_log_file_path(json_body["uuid"])) - log_file_name = str(get_client_log_file_path(json_body["uuid"])) - mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_name) + assert json_body == {"uuid": ANY, "log_file_path": log_file_path} + + mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_path) client_obj = mock_launch_client.call_args_list[0][0][0] assert isinstance(client_obj, MnistClient) @@ -61,7 +63,7 @@ def test_start_fail_unsupported_client() -> None: @patch("florist.api.client.launch_client", side_effect=Exception("test exception")) -def test_start_fail_exception(mock_launch_client: Mock) -> None: +def test_start_fail_exception(_: Mock) -> None: test_server_address = "test-server-address" test_client = "MNIST" test_data_path = "test/data/path" @@ -91,6 +93,7 @@ def test_check_status(mock_redis: Mock) -> None: mock_redis.Redis.assert_called_with(host=test_redis_host, port=test_redis_port) assert json.loads(response.body.decode()) == {"info": "test"} + @patch("florist.api.monitoring.metrics.redis") def test_check_status_not_found(mock_redis: Mock) -> None: mock_redis_connection = Mock() @@ -108,8 +111,9 @@ def test_check_status_not_found(mock_redis: Mock) -> None: assert response.status_code == 404 assert json.loads(response.body.decode()) == {"error": f"Client {test_uuid} Not Found"} + @patch("florist.api.monitoring.metrics.redis.Redis", side_effect=Exception("test 
exception")) -def test_check_status_fail_exception(mock_redis: Mock) -> None: +def test_check_status_fail_exception(_: Mock) -> None: test_uuid = "test_uuid" test_redis_host = "localhost" @@ -119,3 +123,19 @@ def test_check_status_fail_exception(mock_redis: Mock) -> None: assert response.status_code == 500 assert json.loads(response.body.decode()) == {"error": "test exception"} + + +def test_get_log() -> None: + test_client_uuid = "test-client-uuid" + test_log_file_content = "this is a test log file content" + test_log_file_path = str(get_client_log_file_path(test_client_uuid)) + + with open(test_log_file_path, "w") as f: + f.write(test_log_file_content) + + response = client.get_log(test_log_file_path) + + assert response.status_code == 200 + assert response.body.decode() == f"\"{test_log_file_content}\"" + + os.remove(test_log_file_path)