From b198b950836e107eb0be4b943d5278262da81e7d Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 18 Nov 2024 17:40:42 -0500 Subject: [PATCH 1/7] Adding endpoints to get log contents for servers and clients, needs to fix unit tests --- florist/api/client.py | 29 +++++++++-- florist/api/db/entities.py | 44 ++++++++++++++++- florist/api/routes/server/job.py | 71 +++++++++++++++++++++++++++ florist/api/routes/server/training.py | 8 ++- florist/api/servers/launch.py | 14 +++--- 5 files changed, 152 insertions(+), 14 deletions(-) diff --git a/florist/api/client.py b/florist/api/client.py index 30fcf8f..1a051fb 100644 --- a/florist/api/client.py +++ b/florist/api/client.py @@ -43,9 +43,14 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red :param redis_port: (str) the port for the Redis instance for metrics reporting. :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the format below, which can be used to pull metrics from Redis. - {"uuid": } + { + "uuid": (str) The client's uuid, which can be used to pull metrics from Redis, + "log_file_path": (str) The local path of the log file for this client, + } If not successful, returns the appropriate error code with a JSON with the format below: - {"error": } + { + "error": (str) The error message, + } """ try: if client not in Client.list(): @@ -65,10 +70,10 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red reporters=[metrics_reporter], ) - log_file_name = str(get_client_log_file_path(client_uuid)) - launch_client(client_obj, server_address, log_file_name) + log_file_path = str(get_client_log_file_path(client_uuid)) + launch_client(client_obj, server_address, log_file_path) - return JSONResponse({"uuid": client_uuid}) + return JSONResponse({"uuid": client_uuid, "log_file_path": log_file_path}) except Exception as ex: return JSONResponse({"error": str(ex)}, status_code=500) @@ -98,3 +103,17 @@ def check_status(client_uuid: str, redis_host: str, redis_port: str) -> JSONResp except Exception as ex: LOGGER.exception(ex) return JSONResponse({"error": str(ex)}, status_code=500) + + +@app.get("/api/client/get_log") +def get_log(log_file_path: str) -> JSONResponse: + """ + Return the contents of the log file under the given path. + + :param log_file_path: (str) the path of the logt file. + + :return: (JSONResponse) Returns the contents of the file as a string. + """ + with open(log_file_path, "r") as f: + content = f.read() + return JSONResponse(content) diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 6bcc951..bbf594d 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -49,6 +49,7 @@ class ClientInfo(BaseModel): redis_port: str = Field(...) uuid: Optional[Annotated[str, Field(...)]] metrics: Optional[Annotated[str, Field(...)]] + log_file_path: Optional[Annotated[str, Field(...)]] class Config: """MongoDB config for the ClientInfo DB entity.""" @@ -63,6 +64,7 @@ class Config: "redis_port": "6380", "uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f", "metrics": '{"host_type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}', + "log_file_path": "/Users/foo/client/logfile.log", }, } @@ -78,6 +80,7 @@ class Job(BaseModel): config_parser: Optional[Annotated[ConfigParser, Field(...)]] server_uuid: Optional[Annotated[str, Field(...)]] server_metrics: Optional[Annotated[str, Field(...)]] + server_log_file_path: Optional[Annotated[str, Field(...)]] redis_host: Optional[Annotated[str, Field(...)]] redis_port: Optional[Annotated[str, Field(...)]] clients_info: Optional[Annotated[List[ClientInfo], Field(...)]] @@ -130,7 +133,7 @@ async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: A """ Save the server and clients' UUIDs in the database under the current job's id. - :param server_uuid: [str] the server_uuid to be saved in the database. + :param server_uuid: (str) the server_uuid to be saved in the database. :param client_uuids: List[str] the list of client_uuids to be saved in the database. :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored. """ @@ -220,6 +223,44 @@ def set_client_metrics( ) assert_updated_successfully(update_result) + async def set_server_log_file_path(self, log_file_path: str, database: AsyncIOMotorDatabase[Any]) -> None: + """ + Save the server's log file path in the database under the current job's id. + + :param log_file_path: (str) the file path to be saved in the database. + :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored. + """ + job_collection = database[JOB_COLLECTION_NAME] + self.server_log_file_path = log_file_path + update_result = await job_collection.update_one( + {"_id": self.id}, {"$set": {"server_log_file_path": log_file_path}} + ) + assert_updated_successfully(update_result) + + async def set_client_log_file_path( + self, + client_index: int, + log_file_path: str, + database: AsyncIOMotorDatabase[Any], + ) -> None: + """ + Save the clients' log file path in the database under the given client index and current job's id. + + :param client_index: (str) the index of the client in the job. + :param log_file_path: (str) the path oof the client's log file. + :param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored. + """ + assert self.clients_info is not None, "Job has no clients." + assert ( + 0 <= client_index < len(self.clients_info) + ), f"Client index {client_index} is invalid (total: {len(self.clients_info)})" + + job_collection = database[JOB_COLLECTION_NAME] + update_result = await job_collection.update_one( + {"_id": self.id}, {"$set": {f"clients_info.{client_index}.log_file_path": log_file_path}} + ) + assert_updated_successfully(update_result) + class Config: """MongoDB config for the Job DB entity.""" @@ -233,6 +274,7 @@ class Config: "server_config": '{"n_server_rounds": 3, "batch_size": 8, "local_epochs": 1}', "server_uuid": "d73243cf-8b89-473b-9607-8cd0253a101d", "server_metrics": '{"host_type": "server", "fit_start": "2024-04-23 15:33:12.865604", "rounds": {"1": {"fit_start": "2024-04-23 15:33:12.869001"}}}', + "server_log_file_path": "/Users/foo/server/logfile.log", "redis_host": "localhost", "redis_port": "6379", "clients_info": [ diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index be9dcda..f450ee4 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -2,6 +2,7 @@ from typing import List, Union +import requests from fastapi import APIRouter, Body, Request, status from fastapi.responses import JSONResponse @@ -99,3 +100,73 @@ async def change_job_status(job_id: str, status: JobStatus, request: Request) -> return JSONResponse(content={"error": str(assertion_e)}, status_code=400) except Exception as general_e: return JSONResponse(content={"error": str(general_e)}, status_code=500) + + +@router.get("/get_server_log/{job_id}") +async def get_server_log(job_id: str, request: Request) -> JSONResponse: + """ + Return the contents of the server's log file for the given job id. + + :param job_id: (str) the ID of the job to get the server logs for. + :param request: (fastapi.Request) the FastAPI request object. + + :return: (JSONResponse) if successful, returns the contents of the file as a string. + If not successful, returns the appropriate error code with a JSON with the format below: + {"error": } + """ + try: + job = await Job.find_by_id(job_id, request.app.database) + + assert job is not None, f"Job {job_id} not found" + assert ( + job.server_log_file_path is not None and job.server_log_file_path != "" + ), "log file path is None or empty" + + with open(job.server_log_file_path, "r") as f: + content = f.read() + return JSONResponse(content) + + except AssertionError as assertion_e: + return JSONResponse(content={"error": str(assertion_e)}, status_code=400) + except Exception as general_e: + return JSONResponse(content={"error": str(general_e)}, status_code=500) + + +@router.get("/get_client_log/{job_id}/{client_index}") +async def get_client_log(job_id: str, client_index: int, request: Request) -> JSONResponse: + """ + Return the contents of the log file for the client with given index under given job id. + + :param job_id: (str) the ID of the job to get the client logs for. + :param client_index: (int) the index of the client within the job. + :param request: (fastapi.Request) the FastAPI request object. + + :return: (JSONResponse) if successful, returns the contents of the file as a string. + If not successful, returns the appropriate error code with a JSON with the format below: + {"error": } + """ + try: + job = await Job.find_by_id(job_id, request.app.database) + + assert job is not None, f"Job {job_id} not found" + assert job.clients_info is not None, "Job has no clients." + assert ( + 0 <= client_index < len(job.clients_info) + ), f"Client index {client_index} is invalid (total: {len(job.clients_info)})" + + client_info = job.clients_info[client_index] + assert ( + client_info.log_file_path is not None and client_info.log_file_path != "" + ), "log file path is None or empty" + + response = requests.get( + url=f"http://{client_info.service_address}/api/client/get_log", + params={"log_file_path": client_info.log_file_path}, + ) + json_response = response.json() + return JSONResponse(json_response) + + except AssertionError as assertion_e: + return JSONResponse(content={"error": str(assertion_e)}, status_code=400) + except Exception as general_e: + return JSONResponse(content={"error": str(general_e)}, status_code=500) diff --git a/florist/api/routes/server/training.py b/florist/api/routes/server/training.py index f00778f..f6e03d7 100644 --- a/florist/api/routes/server/training.py +++ b/florist/api/routes/server/training.py @@ -71,7 +71,7 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks model_class = Model.class_for_model(job.model) # Start the server - server_uuid, _ = launch_local_server( + server_uuid, _, server_log_file_path = launch_local_server( model=model_class(), n_clients=len(job.clients_info), server_address=job.server_address, @@ -79,11 +79,13 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks redis_port=job.redis_port, **server_config, ) + await job.set_server_log_file_path(server_log_file_path, request.app.database) wait_for_metric(server_uuid, "fit_start", job.redis_host, job.redis_port, logger=LOGGER) # Start the clients client_uuids: List[str] = [] - for client_info in job.clients_info: + for i in range(len(job.clients_info)): + client_info = job.clients_info[i] parameters = { "server_address": job.server_address, "client": client_info.client.value, @@ -103,6 +105,8 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks client_uuids.append(json_response["uuid"]) + await job.set_client_log_file_path(i, json_response["log_file_path"], request.app.database) + await job.set_uuids(server_uuid, client_uuids, request.app.database) # Start the server training listener as a background task to update diff --git a/florist/api/servers/launch.py b/florist/api/servers/launch.py index 298993c..b8ddda7 100644 --- a/florist/api/servers/launch.py +++ b/florist/api/servers/launch.py @@ -22,7 +22,7 @@ def launch_local_server( local_epochs: int, redis_host: str, redis_port: str, -) -> Tuple[str, Process]: +) -> Tuple[str, Process, str]: """ Launch a FL server locally. @@ -34,8 +34,10 @@ def launch_local_server( :param local_epochs: (int) The number of epochs to run by the clients. :param redis_host: (str) the host name for the Redis instance for metrics reporting. :param redis_port: (str) the port for the Redis instance for metrics reporting. - :return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull - metrics from Redis, along with its local process object. + :return: (Tuple[str, multiprocessing.Process, str]) a tuple with + - The UUID of the server, which can be used to pull metrics from Redis + - The server's local process object + - The local path for the log file """ server_uuid = str(uuid.uuid4()) @@ -49,13 +51,13 @@ def launch_local_server( local_epochs=local_epochs, ) - log_file_name = str(get_server_log_file_path(server_uuid)) + log_file_path = str(get_server_log_file_path(server_uuid)) server_process = launch_server( server_constructor, server_address, n_server_rounds, - log_file_name, + log_file_path, seconds_to_sleep=0, ) - return server_uuid, server_process + return server_uuid, server_process, log_file_path From e4e5536f5fb4dc3d38348305d1d5bbd017b7c114 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 19 Nov 2024 16:44:23 -0500 Subject: [PATCH 2/7] Fixing unit tests, need test coverage --- .../integration/api/launchers/test_launch.py | 2 +- .../unit/api/routes/server/test_training.py | 62 +++++++++++++++---- florist/tests/unit/api/servers/test_launch.py | 5 +- florist/tests/unit/api/test_client.py | 2 +- 4 files changed, 53 insertions(+), 18 deletions(-) diff --git a/florist/tests/integration/api/launchers/test_launch.py b/florist/tests/integration/api/launchers/test_launch.py index f8f14b7..ef5bf18 100644 --- a/florist/tests/integration/api/launchers/test_launch.py +++ b/florist/tests/integration/api/launchers/test_launch.py @@ -11,7 +11,7 @@ from florist.api.servers.utils import get_server -def assert_string_in_file(file_path: str, search_string: str) -> bool: +def assert_string_in_file(file_path: str, search_string: str) -> None: with open(file_path, "r") as f: file_contents = f.read() match = re.search(search_string, file_contents) diff --git a/florist/tests/unit/api/routes/server/test_training.py b/florist/tests/unit/api/routes/server/test_training.py index 17e5af7..94614f5 100644 --- a/florist/tests/unit/api/routes/server/test_training.py +++ b/florist/tests/unit/api/routes/server/test_training.py @@ -10,7 +10,6 @@ client_training_listener, start, server_training_listener, - CHECK_CLIENT_STATUS_API, ) @@ -19,7 +18,11 @@ @patch("florist.api.routes.server.training.requests") @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.db.entities.Job.set_uuids") +@patch("florist.api.db.entities.Job.set_server_log_file_path") +@patch("florist.api.db.entities.Job.set_client_log_file_path") async def test_start_success( + mock_set_client_log_file_path: Mock, + mock_server_log_file_path: Mock, mock_set_uuids: Mock, mock_set_status: Mock, mock_requests: Mock, @@ -31,7 +34,8 @@ async def test_start_success( test_server_config, test_job, mock_job_collection, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + test_server_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_server_log_file_path) mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"fit_start\": null}" @@ -40,8 +44,13 @@ async def test_start_success( mock_response = Mock() mock_response.status_code = 200 test_client_1_uuid = "test-client-1-uuid" + test_client_1_log_file_path = "test-client-1-log-file-path" test_client_2_uuid = "test-client-2-uuid" - mock_response.json.side_effect = [{"uuid": test_client_1_uuid}, {"uuid": test_client_2_uuid}] + test_client_2_log_file_path = "test-client-2-log-file-path" + mock_response.json.side_effect = [ + {"uuid": test_client_1_uuid, "log_file_path": test_client_1_log_file_path}, + {"uuid": test_client_2_uuid, "log_file_path": test_client_2_log_file_path}, + ] mock_requests.get.return_value = mock_response mock_background_tasks = Mock() @@ -230,13 +239,19 @@ async def test_start_launch_server_exception(mock_launch_local_server: Mock, _: @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") -async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None: +@patch("florist.api.db.entities.Job.set_server_log_file_path") +async def test_start_wait_for_metric_exception( + mock_set_server_log_file_path: Mock, + mock_redis: Mock, + mock_launch_local_server: Mock, + _: Mock, +) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + mock_launch_local_server.return_value = (test_server_uuid, None, "test-log-file-path") test_exception = Exception("test exception") mock_redis.Redis.side_effect = test_exception @@ -254,13 +269,20 @@ async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_loc @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep -async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_local_server: Mock, mock_set_status: Mock) -> None: +@patch("florist.api.db.entities.Job.set_server_log_file_path") +async def test_start_wait_for_metric_timeout( + mock_set_server_log_file_path: Mock, + _: Mock, + mock_redis: Mock, + mock_launch_local_server: Mock, + __: Mock, +) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + mock_launch_local_server.return_value = (test_server_uuid, None, "test-log-file-path") mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"foo\": null}" @@ -279,13 +301,20 @@ async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_lau @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.routes.server.training.requests") -async def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None: +@patch("florist.api.db.entities.Job.set_server_log_file_path") +async def test_start_fail_response( + mock_set_server_log_file_path: Mock, + mock_requests: Mock, + mock_redis: Mock, + mock_launch_local_server: Mock, + _: Mock, +) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + mock_launch_local_server.return_value = (test_server_uuid, None, "test-log-file-path") mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"fit_start\": null}" @@ -309,13 +338,20 @@ async def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_l @patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.routes.server.training.requests") -async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None: +@patch("florist.api.db.entities.Job.set_server_log_file_path") +async def test_start_no_uuid_in_response( + mock_set_server_log_file_path: Mock, + mock_requests: Mock, + mock_redis: Mock, + mock_launch_local_server: Mock, + _: Mock, +) -> None: # Arrange test_job_id = "test-job-id" _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None) + mock_launch_local_server.return_value = (test_server_uuid, None, "test-log-file-path") mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"fit_start\": null}" @@ -458,7 +494,7 @@ def test_server_training_listener_fail_no_redis_port() -> None: @patch("florist.api.routes.server.training.get_subscriber") def test_client_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: # Setup - test_client_uuid = "test-client-uuid"; + test_client_uuid = "test-client-uuid" test_job = Job(**{ "clients_info": [ { @@ -512,7 +548,7 @@ def test_client_training_listener(mock_get_subscriber: Mock, mock_get_from_redis @patch("florist.api.routes.server.training.get_from_redis") def test_client_training_listener_already_finished(mock_get_from_redis: Mock) -> None: # Setup - test_client_uuid = "test-client-uuid"; + test_client_uuid = "test-client-uuid" test_job = Job(**{ "clients_info": [ { diff --git a/florist/tests/unit/api/servers/test_launch.py b/florist/tests/unit/api/servers/test_launch.py index 06e818a..49c1c2e 100644 --- a/florist/tests/unit/api/servers/test_launch.py +++ b/florist/tests/unit/api/servers/test_launch.py @@ -1,7 +1,6 @@ from unittest.mock import ANY, Mock, patch from florist.api.clients.mnist import MnistNet -from florist.api.monitoring.logs import get_server_log_file_path from florist.api.monitoring.metrics import RedisMetricsReporter from florist.api.servers.launch import launch_local_server from florist.api.servers.utils import get_server @@ -20,7 +19,7 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: test_server_process = "test-server-process" mock_launch_server.return_value = test_server_process - server_uuid, server_process = launch_local_server( + server_uuid, server_process, log_file_path = launch_local_server( test_model, test_n_clients, test_server_address, @@ -41,7 +40,7 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: ANY, test_server_address, test_n_server_rounds, - str(get_server_log_file_path(server_uuid)), + log_file_path, ) assert call_kwargs == {"seconds_to_sleep": 0} assert call_args[0].func == get_server diff --git a/florist/tests/unit/api/test_client.py b/florist/tests/unit/api/test_client.py index 6b497bf..c2e825e 100644 --- a/florist/tests/unit/api/test_client.py +++ b/florist/tests/unit/api/test_client.py @@ -29,7 +29,7 @@ def test_start_success(mock_launch_client: Mock) -> None: assert response.status_code == 200 json_body = json.loads(response.body.decode()) - assert json_body == {"uuid": ANY} + assert json_body == {"uuid": ANY, "log_file_path": str(get_client_log_file_path(json_body["uuid"]))} log_file_name = str(get_client_log_file_path(json_body["uuid"])) mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_name) From 437fec6ada7baf1643812fd781f45ed9a1362743 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 19 Nov 2024 17:02:48 -0500 Subject: [PATCH 3/7] Adding extra assertions on existing tests --- .../unit/api/routes/server/test_training.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/florist/tests/unit/api/routes/server/test_training.py b/florist/tests/unit/api/routes/server/test_training.py index 94614f5..fcaa9a2 100644 --- a/florist/tests/unit/api/routes/server/test_training.py +++ b/florist/tests/unit/api/routes/server/test_training.py @@ -64,6 +64,7 @@ async def test_start_success( assert json_body == {"server_uuid": test_server_uuid, "client_uuids": [test_client_1_uuid, test_client_2_uuid]} mock_job_collection.find_one.assert_called_with({"_id": test_job_id}) + mock_set_status.assert_called_once_with(JobStatus.IN_PROGRESS, mock_fastapi_request.app.database) assert isinstance(mock_launch_local_server.call_args_list[0][1]["model"], MnistNet) mock_launch_local_server.assert_called_once_with( @@ -78,6 +79,8 @@ async def test_start_success( ) mock_redis.Redis.assert_called_once_with(host=test_job["redis_host"], port=test_job["redis_port"]) mock_redis_connection.get.assert_called_once_with(test_server_uuid) + mock_server_log_file_path.assert_called_once_with(test_server_log_file_path, mock_fastapi_request.app.database) + mock_requests.get.assert_any_call( url=f"http://{test_job['clients_info'][0]['service_address']}/api/client/start", params={ @@ -99,7 +102,11 @@ async def test_start_success( }, ) - mock_set_status.assert_called_once_with(JobStatus.IN_PROGRESS, mock_fastapi_request.app.database) + mock_set_client_log_file_path.assert_has_calls([ + call(0, test_client_1_log_file_path, mock_fastapi_request.app.database), + call(1, test_client_2_log_file_path, mock_fastapi_request.app.database), + ]) + mock_set_uuids.assert_called_once_with( test_server_uuid, [test_client_1_uuid, test_client_2_uuid], @@ -164,7 +171,7 @@ async def test_start_fail_unsupported_client() -> None: @patch("florist.api.db.entities.Job.set_status") -async def test_start_fail_missing_info(mock_set_status: Mock) -> None: +async def test_start_fail_missing_info(_: Mock) -> None: fields_to_be_removed = ["model", "server_config", "clients_info", "server_address", "redis_host", "redis_port"] for field_to_be_removed in fields_to_be_removed: @@ -184,7 +191,7 @@ async def test_start_fail_missing_info(mock_set_status: Mock) -> None: @patch("florist.api.db.entities.Job.set_status") -async def test_start_fail_invalid_server_config(mock_set_status: Mock) -> None: +async def test_start_fail_invalid_server_config(_: Mock) -> None: # Arrange test_job_id = "test-job-id" _, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks() @@ -251,7 +258,8 @@ async def test_start_wait_for_metric_exception( _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None, "test-log-file-path") + test_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path) test_exception = Exception("test exception") mock_redis.Redis.side_effect = test_exception @@ -264,6 +272,8 @@ async def test_start_wait_for_metric_exception( json_body = json.loads(response.body.decode()) assert json_body == {"error": str(test_exception)} + mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database) + @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.routes.server.training.launch_local_server") @@ -282,7 +292,8 @@ async def test_start_wait_for_metric_timeout( _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None, "test-log-file-path") + test_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path) mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"foo\": null}" @@ -296,6 +307,8 @@ async def test_start_wait_for_metric_timeout( json_body = json.loads(response.body.decode()) assert json_body == {"error": "Metric 'fit_start' not been found after 20 retries."} + mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database) + @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.routes.server.training.launch_local_server") @@ -314,7 +327,8 @@ async def test_start_fail_response( _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None, "test-log-file-path") + test_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path) mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"fit_start\": null}" @@ -333,6 +347,8 @@ async def test_start_fail_response( json_body = json.loads(response.body.decode()) assert json_body == {"error": f"Client response returned 403. Response: error"} + mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database) + @patch("florist.api.db.entities.Job.set_status") @patch("florist.api.routes.server.training.launch_local_server") @@ -351,7 +367,8 @@ async def test_start_no_uuid_in_response( _, _, _, mock_fastapi_request = _setup_test_job_and_mocks() test_server_uuid = "test-server-uuid" - mock_launch_local_server.return_value = (test_server_uuid, None, "test-log-file-path") + test_log_file_path = "test-log-file-path" + mock_launch_local_server.return_value = (test_server_uuid, None, test_log_file_path) mock_redis_connection = Mock() mock_redis_connection.get.return_value = b"{\"fit_start\": null}" @@ -370,6 +387,8 @@ async def test_start_no_uuid_in_response( json_body = json.loads(response.body.decode()) assert json_body == {"error": "Client response did not return a UUID. Response: {'foo': 'bar'}"} + mock_set_server_log_file_path.assert_called_once_with(test_log_file_path, mock_fastapi_request.app.database) + @patch("florist.api.routes.server.training.get_from_redis") @patch("florist.api.routes.server.training.get_subscriber") From 3c8a2dc370f690581ff377e771741a0fe77147b9 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 19 Nov 2024 18:03:21 -0500 Subject: [PATCH 4/7] Adding tests for log file path --- .../tests/integration/api/db/test_entities.py | 32 +++++++++++++++++++ florist/tests/unit/api/test_client.py | 30 ++++++++++++++--- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/florist/tests/integration/api/db/test_entities.py b/florist/tests/integration/api/db/test_entities.py index 6b8234e..f5939fc 100644 --- a/florist/tests/integration/api/db/test_entities.py +++ b/florist/tests/integration/api/db/test_entities.py @@ -240,6 +240,38 @@ async def test_set_client_metrics_fail_update_result(mock_request) -> None: ) +async def test_set_server_log_file_path_success(mock_request) -> None: + test_job = get_test_job() + result_id = await test_job.create(mock_request.app.database) + test_job.id = result_id + test_job.clients_info[0].id = ANY + test_job.clients_info[1].id = ANY + + test_log_file_path = "test/log/file/path.log" + + await test_job.set_server_log_file_path(test_log_file_path, mock_request.app.database) + + result_job = await Job.find_by_id(result_id, mock_request.app.database) + test_job.server_log_file_path = test_log_file_path + assert result_job == test_job + + +async def test_set_client_log_file_path_success(mock_request) -> None: + test_job = get_test_job() + result_id = await test_job.create(mock_request.app.database) + test_job.id = result_id + test_job.clients_info[0].id = ANY + test_job.clients_info[1].id = ANY + + test_log_file_path = "test/log/file/path.log" + + await test_job.set_client_log_file_path(1, test_log_file_path, mock_request.app.database) + + result_job = await Job.find_by_id(result_id, mock_request.app.database) + test_job.clients_info[1].log_file_path = test_log_file_path + assert result_job == test_job + + def get_test_job() -> Job: test_server_config = { "n_server_rounds": 2, diff --git a/florist/tests/unit/api/test_client.py b/florist/tests/unit/api/test_client.py index c2e825e..1a5790a 100644 --- a/florist/tests/unit/api/test_client.py +++ b/florist/tests/unit/api/test_client.py @@ -1,5 +1,6 @@ """Tests for FLorist's client FastAPI endpoints.""" import json +import os from unittest.mock import ANY, Mock, patch from florist.api import client @@ -29,10 +30,11 @@ def test_start_success(mock_launch_client: Mock) -> None: assert response.status_code == 200 json_body = json.loads(response.body.decode()) - assert json_body == {"uuid": ANY, "log_file_path": str(get_client_log_file_path(json_body["uuid"]))} + log_file_path = str(get_client_log_file_path(json_body["uuid"])) - log_file_name = str(get_client_log_file_path(json_body["uuid"])) - mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_name) + assert json_body == {"uuid": ANY, "log_file_path": log_file_path} + + mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_path) client_obj = mock_launch_client.call_args_list[0][0][0] assert isinstance(client_obj, MnistClient) @@ -61,7 +63,7 @@ def test_start_fail_unsupported_client() -> None: @patch("florist.api.client.launch_client", side_effect=Exception("test exception")) -def test_start_fail_exception(mock_launch_client: Mock) -> None: +def test_start_fail_exception(_: Mock) -> None: test_server_address = "test-server-address" test_client = "MNIST" test_data_path = "test/data/path" @@ -91,6 +93,7 @@ def test_check_status(mock_redis: Mock) -> None: mock_redis.Redis.assert_called_with(host=test_redis_host, port=test_redis_port) assert json.loads(response.body.decode()) == {"info": "test"} + @patch("florist.api.monitoring.metrics.redis") def test_check_status_not_found(mock_redis: Mock) -> None: mock_redis_connection = Mock() @@ -108,8 +111,9 @@ def test_check_status_not_found(mock_redis: Mock) -> None: assert response.status_code == 404 assert json.loads(response.body.decode()) == {"error": f"Client {test_uuid} Not Found"} + @patch("florist.api.monitoring.metrics.redis.Redis", side_effect=Exception("test exception")) -def test_check_status_fail_exception(mock_redis: Mock) -> None: +def test_check_status_fail_exception(_: Mock) -> None: test_uuid = "test_uuid" test_redis_host = "localhost" @@ -119,3 +123,19 @@ def test_check_status_fail_exception(mock_redis: Mock) -> None: assert response.status_code == 500 assert json.loads(response.body.decode()) == {"error": "test exception"} + + +def test_get_log() -> None: + test_client_uuid = "test-client-uuid" + test_log_file_content = "this is a test log file content" + test_log_file_path = str(get_client_log_file_path(test_client_uuid)) + + with open(test_log_file_path, "w") as f: + f.write(test_log_file_content) + + response = client.get_log(test_log_file_path) + + assert response.status_code == 200 + assert response.body.decode() == f"\"{test_log_file_content}\"" + + os.remove(test_log_file_path) From 0e136f3fc9db0356755d3fbde16244e71702d251 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 19 Nov 2024 18:29:31 -0500 Subject: [PATCH 5/7] Fixing job tests, adding first get server log test --- .../integration/api/routes/server/test_job.py | 59 ++++++++++++++++++- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/florist/tests/integration/api/routes/server/test_job.py b/florist/tests/integration/api/routes/server/test_job.py index 297d046..56574d6 100644 --- a/florist/tests/integration/api/routes/server/test_job.py +++ b/florist/tests/integration/api/routes/server/test_job.py @@ -1,10 +1,12 @@ +import os from unittest.mock import ANY from fastapi.encoders import jsonable_encoder from florist.api.clients.common import Client from florist.api.db.entities import ClientInfo, Job, JobStatus -from florist.api.routes.server.job import list_jobs_with_status, new_job +from florist.api.monitoring.logs import get_server_log_file_path, get_client_log_file_path +from florist.api.routes.server.job import list_jobs_with_status, new_job, get_server_log from florist.api.servers.common import Model from florist.tests.integration.api.utils import mock_request from florist.api.servers.config_parsers import ConfigParser @@ -26,6 +28,7 @@ async def test_new_job(mock_request) -> None: "clients_info": None, "server_metrics": None, "server_uuid": None, + "server_log_file_path": None, } test_job = Job( @@ -39,6 +42,7 @@ async def test_new_job(mock_request) -> None: redis_port="test-redis-port", server_metrics="test-server-metrics", server_uuid="test-server-uuid", + server_log_file_path="test-server-log-file-path", clients_info=[ ClientInfo( client=Client.MNIST, @@ -48,6 +52,7 @@ async def test_new_job(mock_request) -> None: redis_port="test-redis-port-1", metrics="test-client-metrics-1", uuid="test-client-uuid-1", + log_file_path="test-log-file-path-1", ), ClientInfo( client=Client.MNIST, @@ -55,8 +60,9 @@ async def test_new_job(mock_request) -> None: data_path="test/data/path-2", redis_host="test-redis-host-2", redis_port="test-redis-port-2", - metrics="test-client-metrics-1", - uuid="test-client-uuid-1", + metrics="test-client-metrics-2", + uuid="test-client-uuid-2", + log_file_path="test-log-file-path-2", ), ] ) @@ -73,6 +79,7 @@ async def test_new_job(mock_request) -> None: "redis_port": test_job.redis_port, "server_uuid": test_job.server_uuid, "server_metrics": test_job.server_metrics, + "server_log_file_path": test_job.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -83,6 +90,7 @@ async def test_new_job(mock_request) -> None: "redis_port": test_job.clients_info[0].redis_port, "uuid": test_job.clients_info[0].uuid, "metrics": test_job.clients_info[0].metrics, + "log_file_path": test_job.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job.clients_info[1].client.value, @@ -92,6 +100,7 @@ async def test_new_job(mock_request) -> None: "redis_port": test_job.clients_info[1].redis_port, "uuid": test_job.clients_info[1].uuid, "metrics": test_job.clients_info[1].metrics, + "log_file_path": test_job.clients_info[1].log_file_path, }, ], } @@ -109,6 +118,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port1", server_metrics="test-server-metrics1", server_uuid="test-server-uuid1", + server_log_file_path="test-server-log-file-path1", clients_info=[ ClientInfo( client=Client.MNIST, @@ -118,6 +128,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-1-1", metrics="test-client-metrics-1-1", uuid="test-client-uuid-1-1", + log_file_path="test-log-file-path-1-1", ), ClientInfo( client=Client.MNIST, @@ -127,6 +138,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-2-1", metrics="test-client-metrics-2-1", uuid="test-client-uuid-2-1", + log_file_path="test-log-file-path-2-1", ), ] ) @@ -142,6 +154,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port2", server_metrics="test-server-metrics2", server_uuid="test-server-uuid2", + server_log_file_path="test-server-log-file-path2", clients_info=[ ClientInfo( client=Client.MNIST, @@ -151,6 +164,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-1-2", metrics="test-client-metrics-1-2", uuid="test-client-uuid-1-2", + log_file_path="test-log-file-path-1-2", ), ClientInfo( client=Client.MNIST, @@ -160,6 +174,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-2-2", metrics="test-client-metrics-2-2", uuid="test-client-uuid-2-2", + log_file_path="test-log-file-path-2-2", ), ] ) @@ -175,6 +190,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port3", server_metrics="test-server-metrics3", server_uuid="test-server-uuid3", + server_log_file_path="test-server-log-file-path3", clients_info=[ ClientInfo( client=Client.MNIST, @@ -184,6 +200,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-1-3", metrics="test-client-metrics-1-3", uuid="test-client-uuid-1-3", + log_file_path="test-log-file-path-1-3", ), ClientInfo( client=Client.MNIST, @@ -193,6 +210,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-2-3", metrics="test-client-metrics-2-3", uuid="test-client-uuid-2-3", + log_file_path="test-log-file-path-2-3", ), ] ) @@ -208,6 +226,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port4", server_metrics="test-server-metrics4", server_uuid="test-server-uuid4", + server_log_file_path="test-server-log-file-path4", clients_info=[ ClientInfo( client=Client.MNIST, @@ -217,6 +236,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-1-4", metrics="test-client-metrics-1-4", uuid="test-client-uuid-1-4", + log_file_path="test-log-file-path-1-4", ), ClientInfo( client=Client.MNIST, @@ -226,6 +246,7 @@ async def test_list_jobs_with_status(mock_request) -> None: redis_port="test-redis-port-2-4", metrics="test-client-metrics-2-4", uuid="test-client-uuid-2-4", + log_file_path="test-log-file-path-2-4", ), ] ) @@ -255,6 +276,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job1.redis_port, "server_metrics": test_job1.server_metrics, "server_uuid": test_job1.server_uuid, + "server_log_file_path": test_job1.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -265,6 +287,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job1.clients_info[0].redis_port, "metrics": test_job1.clients_info[0].metrics, "uuid": test_job1.clients_info[0].uuid, + "log_file_path": test_job1.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job1.clients_info[1].client.value, @@ -274,6 +297,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job1.clients_info[1].redis_port, "metrics": test_job1.clients_info[1].metrics, "uuid": test_job1.clients_info[1].uuid, + "log_file_path": test_job1.clients_info[1].log_file_path, }, ], } @@ -289,6 +313,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job2.redis_port, "server_metrics": test_job2.server_metrics, "server_uuid": test_job2.server_uuid, + "server_log_file_path": test_job2.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -299,6 +324,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job2.clients_info[0].redis_port, "metrics": test_job2.clients_info[0].metrics, "uuid": test_job2.clients_info[0].uuid, + "log_file_path": test_job2.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job2.clients_info[1].client.value, @@ -308,6 +334,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job2.clients_info[1].redis_port, "metrics": test_job2.clients_info[1].metrics, "uuid": test_job2.clients_info[1].uuid, + "log_file_path": test_job2.clients_info[1].log_file_path, }, ], } @@ -323,6 +350,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job3.redis_port, "server_metrics": test_job3.server_metrics, "server_uuid": test_job3.server_uuid, + "server_log_file_path": test_job3.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -333,6 +361,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job3.clients_info[0].redis_port, "metrics": test_job3.clients_info[0].metrics, "uuid": test_job3.clients_info[0].uuid, + "log_file_path": test_job3.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job3.clients_info[1].client.value, @@ -342,6 +371,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job3.clients_info[1].redis_port, "metrics": test_job3.clients_info[1].metrics, "uuid": test_job3.clients_info[1].uuid, + "log_file_path": test_job3.clients_info[1].log_file_path, }, ], } @@ -357,6 +387,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job4.redis_port, "server_metrics": test_job4.server_metrics, "server_uuid": test_job4.server_uuid, + "server_log_file_path": test_job4.server_log_file_path, "clients_info": [ { "_id": ANY, @@ -367,6 +398,7 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job4.clients_info[0].redis_port, "metrics": test_job4.clients_info[0].metrics, "uuid": test_job4.clients_info[0].uuid, + "log_file_path": test_job4.clients_info[0].log_file_path, }, { "_id": ANY, "client": test_job4.clients_info[1].client.value, @@ -376,6 +408,27 @@ async def test_list_jobs_with_status(mock_request) -> None: "redis_port": test_job4.clients_info[1].redis_port, "metrics": test_job4.clients_info[1].metrics, "uuid": test_job4.clients_info[1].uuid, + "log_file_path": test_job4.clients_info[1].log_file_path, }, ], } + + +async def test_get_server_log_success(mock_request): + test_log_file_name = "test-log-file-name" + test_log_file_content = "this is a test log file content" + test_log_file_path = str(get_server_log_file_path(test_log_file_name)) + + with open(test_log_file_path, "w") as f: + f.write(test_log_file_content) + + result_job = await new_job(mock_request, Job(server_log_file_path=test_log_file_path)) + + result = await get_server_log(result_job.id, mock_request) + + assert result.status_code == 200 + assert result.body.decode() == f"\"{test_log_file_content}\"" + + os.remove(test_log_file_path) + +# TODO test assertion errors for get server log From 29214d72995b2663c8adae7c428918507b134284 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 20 Nov 2024 16:59:17 -0500 Subject: [PATCH 6/7] Adding more tests --- florist/api/routes/server/job.py | 4 ++-- .../integration/api/routes/server/test_job.py | 22 ++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index f450ee4..8efe0a0 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -120,7 +120,7 @@ async def get_server_log(job_id: str, request: Request) -> JSONResponse: assert job is not None, f"Job {job_id} not found" assert ( job.server_log_file_path is not None and job.server_log_file_path != "" - ), "log file path is None or empty" + ), "Log file path is None or empty" with open(job.server_log_file_path, "r") as f: content = f.read() @@ -157,7 +157,7 @@ async def get_client_log(job_id: str, client_index: int, request: Request) -> JS client_info = job.clients_info[client_index] assert ( client_info.log_file_path is not None and client_info.log_file_path != "" - ), "log file path is None or empty" + ), "Log file path is None or empty" response = requests.get( url=f"http://{client_info.service_address}/api/client/get_log", diff --git a/florist/tests/integration/api/routes/server/test_job.py b/florist/tests/integration/api/routes/server/test_job.py index 56574d6..5f15cd9 100644 --- a/florist/tests/integration/api/routes/server/test_job.py +++ b/florist/tests/integration/api/routes/server/test_job.py @@ -1,3 +1,4 @@ +import json import os from unittest.mock import ANY @@ -431,4 +432,23 @@ async def test_get_server_log_success(mock_request): os.remove(test_log_file_path) -# TODO test assertion errors for get server log + +async def test_get_server_log_error_no_job(mock_request): + test_job_id = "inexistent-job-id" + + result = await get_server_log(test_job_id, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Job {test_job_id} not found"} + + +async def test_get_server_log_error_no_log_path(mock_request): + result_job = await new_job(mock_request, Job()) + + result = await get_server_log(result_job.id, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Log file path is None or empty"} + + +# TODO add tests for get client log From 32af3bca965eb7fea9fe96616c36c31195926c53 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Thu, 21 Nov 2024 12:06:09 -0500 Subject: [PATCH 7/7] Add tests for get client logs --- .../integration/api/routes/server/test_job.py | 102 +++++++++++++++++- 1 file changed, 99 insertions(+), 3 deletions(-) diff --git a/florist/tests/integration/api/routes/server/test_job.py b/florist/tests/integration/api/routes/server/test_job.py index 5f15cd9..a1e5537 100644 --- a/florist/tests/integration/api/routes/server/test_job.py +++ b/florist/tests/integration/api/routes/server/test_job.py @@ -2,14 +2,15 @@ import os from unittest.mock import ANY +import uvicorn from fastapi.encoders import jsonable_encoder from florist.api.clients.common import Client from florist.api.db.entities import ClientInfo, Job, JobStatus from florist.api.monitoring.logs import get_server_log_file_path, get_client_log_file_path -from florist.api.routes.server.job import list_jobs_with_status, new_job, get_server_log +from florist.api.routes.server.job import list_jobs_with_status, new_job, get_server_log, get_client_log from florist.api.servers.common import Model -from florist.tests.integration.api.utils import mock_request +from florist.tests.integration.api.utils import mock_request, TestUvicornServer from florist.api.servers.config_parsers import ConfigParser @@ -451,4 +452,99 @@ async def test_get_server_log_error_no_log_path(mock_request): assert json.loads(result.body.decode()) == {"error": f"Log file path is None or empty"} -# TODO add tests for get client log +async def test_get_client_log_success(mock_request): + test_log_file_name = "test-log-file-name" + test_log_file_content = "this is a test log file content" + test_log_file_path = str(get_client_log_file_path(test_log_file_name)) + + with open(test_log_file_path, "w") as f: + f.write(test_log_file_content) + + test_client_host = "localhost" + test_client_port = 8001 + + result_job = await new_job(mock_request, Job( + clients_info=[ + ClientInfo( + client=Client.MNIST, + service_address=f"{test_client_host}:{test_client_port}", + data_path="test/data/path-1", + redis_host="test-redis-host-1", + redis_port="test-redis-port-1", + ), + ClientInfo( + client=Client.MNIST, + service_address=f"{test_client_host}:{test_client_port}", + data_path="test/data/path-2", + redis_host="test-redis-host-2", + redis_port="test-redis-port-2", + log_file_path=test_log_file_path, + ), + ], + )) + + client_config = uvicorn.Config("florist.api.client:app", host=test_client_host, port=test_client_port, log_level="debug") + client_service = TestUvicornServer(config=client_config) + with client_service.run_in_thread(): + result = await get_client_log(result_job.id, 1, mock_request) + + assert result.status_code == 200 + assert result.body.decode() == f"\"{test_log_file_content}\"" + + os.remove(test_log_file_path) + + +async def test_get_client_log_error_no_job(mock_request): + test_job_id = "inexistent-job-id" + + result = await get_client_log(test_job_id, 0, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Job {test_job_id} not found"} + + +async def test_get_client_log_error_no_clients(mock_request): + result_job = await new_job(mock_request, Job()) + + result = await get_client_log(result_job.id, 0, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Job has no clients."} + + +async def test_get_client_log_error_invalid_client_index(mock_request): + result_job = await new_job(mock_request, Job( + clients_info=[ + ClientInfo( + client=Client.MNIST, + service_address=f"test-address", + data_path="test/data/path-1", + redis_host="test-redis-host-1", + redis_port="test-redis-port-1", + ), + ], + )) + + result = await get_client_log(result_job.id, 1, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Client index 1 is invalid (total: 1)"} + + +async def test_get_client_log_error_log_file_path_is_none(mock_request): + result_job = await new_job(mock_request, Job( + clients_info=[ + ClientInfo( + client=Client.MNIST, + service_address=f"test-address", + data_path="test/data/path-1", + redis_host="test-redis-host-1", + redis_port="test-redis-port-1", + ), + ], + )) + + result = await get_client_log(result_job.id, 0, mock_request) + + assert result.status_code == 400 + assert json.loads(result.body.decode()) == {"error": f"Log file path is None or empty"}