From 1d251351ec5ecda838003acbf03513acb168d9ef Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 11 Oct 2024 16:12:39 -0400 Subject: [PATCH 1/7] adding listeners for clients as well --- florist/api/db/entities.py | 43 ++++++---- florist/api/routes/server/training.py | 111 ++++++++++++++------------ 2 files changed, 87 insertions(+), 67 deletions(-) diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 55c1fac..165cd60 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -176,36 +176,49 @@ def set_status_sync(self, status: JobStatus, database: Database[Dict[str, Any]]) update_result = job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}}) assert_updated_successfully(update_result) - def set_metrics( + def set_server_metrics( self, server_metrics: Dict[str, Any], - client_metrics: List[Dict[str, Any]], database: Database[Dict[str, Any]], ) -> None: """ - Sync function to save the server and clients' metrics in the database under the current job's id. + Sync function to save the server's metrics in the database under the current job's id. :param server_metrics: (Dict[str, Any]) the server metrics to be saved. - :param client_metrics: (List[Dict[str, Any]]) the clients metrics to be saved. :param database: (pymongo.database.Database) The database where the job collection is stored. """ - assert self.clients_info is not None and len(self.clients_info) == len(client_metrics), ( - "self.clients_info and client_metrics must have the same length " - f"({'None' if self.clients_info is None else len(self.clients_info)}!={len(client_metrics)})." - ) - job_collection = database[JOB_COLLECTION_NAME] self.server_metrics = json.dumps(server_metrics) update_result = job_collection.update_one({"_id": self.id}, {"$set": {"server_metrics": self.server_metrics}}) assert_updated_successfully(update_result) - for i in range(len(client_metrics)): - self.clients_info[i].metrics = json.dumps(client_metrics[i]) - update_result = job_collection.update_one( - {"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}} - ) - assert_updated_successfully(update_result) + def set_client_metrics( + self, + client_uuid: str, + client_metrics: Dict[str, Any], + database: Database[Dict[str, Any]], + ) -> None: + """ + Sync function to save a clients' metrics in the database under the current job's id. + + :param client_uuid: (str) the client's uuid whose produced the metrics. + :param client_metrics: (Dict[str, Any]) the client's metrics to be saved. + :param database: (pymongo.database.Database) The database where the job collection is stored. + """ + assert self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info], ( + f"client uuid {client_uuid} is not in clients_info ({[c.uuid for c in self.clients_info]})" + ) + + job_collection = database[JOB_COLLECTION_NAME] + + for i in range(len(self.clients_info)): + if client_uuid in self.clients_info[i].uuid: + self.clients_info[i].metrics = json.dumps(client_metrics) + update_result = job_collection.update_one( + {"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}} + ) + assert_updated_successfully(update_result) class Config: """MongoDB config for the Job DB entity.""" diff --git a/florist/api/routes/server/training.py b/florist/api/routes/server/training.py index cdb0d8e..79b78f4 100644 --- a/florist/api/routes/server/training.py +++ b/florist/api/routes/server/training.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from pymongo.database import Database -from florist.api.db.entities import Job, JobStatus +from florist.api.db.entities import ClientInfo, Job, JobStatus from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric from florist.api.servers.common import Model from florist.api.servers.config_parsers import ConfigParser @@ -106,8 +106,10 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks await job.set_uuids(server_uuid, client_uuids, request.app.database) # Start the server training listener as a background task to update - # the job's status once the training is done + # the job's metrics and status once the training is done background_tasks.add_task(server_training_listener, job, request.app.synchronous_database) + for client_info in job.clients_info: + background_tasks.add_task(client_training_listener, job, client_info, request.app.synchronous_database) # Return the UUIDs return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids}) @@ -124,6 +126,49 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks return JSONResponse({"error": str(ex)}, status_code=500) +def client_training_listener(job: Job, client_info: ClientInfo, database: Database[Dict[str, Any]]) -> None: + """ + Listen to the Redis' channel that reports updates on the training process of a FL client. + + Keeps consuming updates to the channel until it finds `shutdown` in the client metrics. + + :param job: (Job) The job that has this client's metrics. + :param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to. + :param database: (pymongo.database.Database) An instance of the database to save the information + into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async + because of limitations with FastAPI's BackgroundTasks. + """ + LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}") + + assert client_info.uuid is not None, "clientInfo.uuid is None." + assert client_info.redis_host is not None, "clientInfo.redis_host is None." + assert client_info.redis_port is not None, "clientInfo.redis_port is None." + + # check if training has already finished before start listening + client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port) + LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}") + if client_metrics is not None: + LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}") + job.set_client_metrics(client_info.uuid, client_metrics, database) + LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.") + if "shutdown" in client_metrics: + return + + subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port) + # TODO add a max retries mechanism, maybe? + for message in subscriber.listen(): # type: ignore[no-untyped-call] + if message["type"] == "message": + # The contents of the message do not matter, we just use it to get notified + client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port) + LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}") + if client_metrics is not None: + LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}") + job.set_client_metrics(client_info.uuid, client_metrics, database) + LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.") + if "shutdown" in client_metrics: + return + + def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> None: """ Listen to the Redis' channel that reports updates on the training process of a FL server. @@ -147,9 +192,13 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port) LOGGER.debug(f"Listener: Current metrics for job {job.id}: {server_metrics}") if server_metrics is not None: - update_job_metrics(job, server_metrics, database) + LOGGER.info(f"Listener: Updating server metrics for job {job.id}") + job.set_server_metrics(server_metrics, database) + LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.") if "fit_end" in server_metrics: - close_job(job, database) + LOGGER.info(f"Listener: Training finished for job {job.id}") + job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database) + LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.") return subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port) @@ -161,53 +210,11 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No LOGGER.debug(f"Listener: Message received for job {job.id}. Metrics: {server_metrics}") if server_metrics is not None: - update_job_metrics(job, server_metrics, database) + LOGGER.info(f"Listener: Updating server metrics for job {job.id}") + job.set_server_metrics(server_metrics, database) + LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.") if "fit_end" in server_metrics: - close_job(job, database) + LOGGER.info(f"Listener: Training finished for job {job.id}") + job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database) + LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.") return - - -def update_job_metrics(job: Job, server_metrics: Dict[str, Any], database: Database[Dict[str, Any]]) -> None: - """ - Update the job with server and client metrics. - - Collect the job's clients metrics, saving them and the server's metrics to the job. - - :param job: (Job) The job to be updated. - :param server_metrics: (Dict[str, Any]) The server's metrics to be saved into the job. - :param database: (pymongo.database.Database) An instance of the database to save the information - into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async - because of limitations with FastAPI's BackgroundTasks. - """ - LOGGER.info(f"Listener: Updating metrics for job {job.id}") - - clients_metrics: List[Dict[str, Any]] = [] - if job.clients_info is not None: - for client_info in job.clients_info: - response = requests.get( - url=f"http://{client_info.service_address}/{CHECK_CLIENT_STATUS_API}/{client_info.uuid}", - params={ - "redis_host": client_info.redis_host, - "redis_port": client_info.redis_port, - }, - ) - client_metrics = response.json() - clients_metrics.append(client_metrics) - - job.set_metrics(server_metrics, clients_metrics, database) - - LOGGER.info(f"Listener: Job {job.id} has been updated.") - - -def close_job(job: Job, database: Database[Dict[str, Any]]) -> None: - """ - Close the job by marking its status as FINISHED_SUCCESSFULLY. - - :param job: (Job) The job to be closed. - :param database: (pymongo.database.Database) An instance of the database to save the information - into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async - because of limitations with FastAPI's BackgroundTasks. - """ - LOGGER.info(f"Listener: Training finished for job {job.id}") - job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database) - LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.") From c502d0eff66aeba48083c97bbfd1dfd43a191449 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 11 Oct 2024 16:19:25 -0400 Subject: [PATCH 2/7] ruff and mypy changes, need tests --- florist/api/db/entities.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 165cd60..56bd5ce 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -207,13 +207,14 @@ def set_client_metrics( :param database: (pymongo.database.Database) The database where the job collection is stored. """ assert self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info], ( - f"client uuid {client_uuid} is not in clients_info ({[c.uuid for c in self.clients_info]})" + f"client uuid {client_uuid} is not in clients_info", + f"({[c.uuid for c in self.clients_info] if self.clients_info is not None else None})", ) job_collection = database[JOB_COLLECTION_NAME] for i in range(len(self.clients_info)): - if client_uuid in self.clients_info[i].uuid: + if client_uuid == self.clients_info[i].uuid: self.clients_info[i].metrics = json.dumps(client_metrics) update_result = job_collection.update_one( {"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}} From bd88d5c7d819d0c38a48a082cba0e5e59fa75c92 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 15 Oct 2024 17:16:04 -0400 Subject: [PATCH 3/7] Fixing training tests, adding more tests --- florist/api/routes/server/training.py | 4 +- .../unit/api/routes/server/test_training.py | 184 +++++++++++++----- 2 files changed, 141 insertions(+), 47 deletions(-) diff --git a/florist/api/routes/server/training.py b/florist/api/routes/server/training.py index 79b78f4..f00778f 100644 --- a/florist/api/routes/server/training.py +++ b/florist/api/routes/server/training.py @@ -140,9 +140,7 @@ def client_training_listener(job: Job, client_info: ClientInfo, database: Databa """ LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}") - assert client_info.uuid is not None, "clientInfo.uuid is None." - assert client_info.redis_host is not None, "clientInfo.redis_host is None." - assert client_info.redis_port is not None, "clientInfo.redis_port is None." + assert client_info.uuid is not None, "client_info.uuid is None." # check if training has already finished before start listening client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port) diff --git a/florist/tests/unit/api/routes/server/test_training.py b/florist/tests/unit/api/routes/server/test_training.py index c076eef..17e5af7 100644 --- a/florist/tests/unit/api/routes/server/test_training.py +++ b/florist/tests/unit/api/routes/server/test_training.py @@ -6,7 +6,12 @@ from florist.api.db.entities import Job, JobStatus, JOB_COLLECTION_NAME from florist.api.models.mnist import MnistNet -from florist.api.routes.server.training import start, server_training_listener, CHECK_CLIENT_STATUS_API +from florist.api.routes.server.training import ( + client_training_listener, + start, + server_training_listener, + CHECK_CLIENT_STATUS_API, +) @patch("florist.api.routes.server.training.launch_local_server") @@ -96,11 +101,25 @@ async def test_start_success( expected_job.id = ANY expected_job.clients_info[0].id = ANY expected_job.clients_info[1].id = ANY - mock_background_tasks.add_task.assert_called_once_with( - server_training_listener, - expected_job, - mock_fastapi_request.app.synchronous_database, - ) + mock_background_tasks.add_task.assert_has_calls([ + call( + server_training_listener, + expected_job, + mock_fastapi_request.app.synchronous_database, + ), + call( + client_training_listener, + expected_job, + expected_job.clients_info[0], + mock_fastapi_request.app.synchronous_database, + ), + call( + client_training_listener, + expected_job, + expected_job.clients_info[1], + mock_fastapi_request.app.synchronous_database, + ), + ]) async def test_start_fail_unsupported_server_model() -> None: @@ -318,8 +337,7 @@ async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, @patch("florist.api.routes.server.training.get_from_redis") @patch("florist.api.routes.server.training.get_subscriber") -@patch("florist.api.routes.server.training.requests") -def test_server_training_listener(mock_requests: Mock(), mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: +def test_server_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: # Setup test_job = Job(**{ "server_uuid": "test-server-uuid", @@ -336,11 +354,10 @@ def test_server_training_listener(mock_requests: Mock(), mock_get_subscriber: Mo } ] }) - test_client_metrics = {"test": 123} test_server_metrics = [ {"fit_start": "2022-02-02 02:02:02"}, {"fit_start": "2022-02-02 02:02:02", "rounds": []}, - {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"} + {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"}, ] mock_get_from_redis.side_effect = test_server_metrics mock_subscriber = Mock() @@ -353,41 +370,28 @@ def test_server_training_listener(mock_requests: Mock(), mock_get_subscriber: Mo ] mock_get_subscriber.return_value = mock_subscriber mock_database = Mock() - mock_response = Mock() - mock_response.json.return_value = test_client_metrics - mock_requests.get.return_value = mock_response with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: - with patch.object(Job, "set_metrics", Mock()) as mock_set_metrics: + with patch.object(Job, "set_server_metrics", Mock()) as mock_set_server_metrics: # Act server_training_listener(test_job, mock_database) # Assert mock_set_status_sync.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, mock_database) - assert mock_set_metrics.call_count == 3 - mock_set_metrics.assert_has_calls([ - call(test_server_metrics[0], [test_client_metrics], mock_database), - call(test_server_metrics[1], [test_client_metrics], mock_database), - call(test_server_metrics[2], [test_client_metrics], mock_database), + assert mock_set_server_metrics.call_count == 3 + mock_set_server_metrics.assert_has_calls([ + call(test_server_metrics[0], mock_database), + call(test_server_metrics[1], mock_database), + call(test_server_metrics[2], mock_database), ]) assert mock_get_from_redis.call_count == 3 mock_get_subscriber.assert_called_once_with(test_job.server_uuid, test_job.redis_host, test_job.redis_port) - assert mock_requests.get.call_count == 3 - mock_requests_get_call = call( - url=f"http://{test_job.clients_info[0].service_address}/{CHECK_CLIENT_STATUS_API}/{test_job.clients_info[0].uuid}", - params={ - "redis_host": test_job.clients_info[0].redis_host, - "redis_port": test_job.clients_info[0].redis_port, - }, - ) - assert mock_requests.get.call_args_list == [mock_requests_get_call, mock_requests_get_call, mock_requests_get_call] @patch("florist.api.routes.server.training.get_from_redis") -@patch("florist.api.routes.server.training.requests") -def test_server_training_listener_already_finished(mock_requests: Mock, mock_get_from_redis: Mock) -> None: +def test_server_training_listener_already_finished(mock_get_from_redis: Mock) -> None: # Setup test_job = Job(**{ "server_uuid": "test-server-uuid", @@ -404,31 +408,20 @@ def test_server_training_listener_already_finished(mock_requests: Mock, mock_get } ] }) - test_client_metrics = {"test": 123} test_server_final_metrics = {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"} mock_get_from_redis.side_effect = [test_server_final_metrics] mock_database = Mock() - mock_response = Mock() - mock_response.json.return_value = test_client_metrics - mock_requests.get.return_value = mock_response with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: - with patch.object(Job, "set_metrics", Mock()) as mock_set_metrics: + with patch.object(Job, "set_server_metrics", Mock()) as mock_set_server_metrics: # Act server_training_listener(test_job, mock_database) # Assert mock_set_status_sync.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, mock_database) - mock_set_metrics.assert_called_once_with(test_server_final_metrics, [test_client_metrics], - mock_database) + mock_set_server_metrics.assert_called_once_with(test_server_final_metrics, mock_database) + assert mock_get_from_redis.call_count == 1 - mock_requests.get.assert_called_once_with( - url=f"http://{test_job.clients_info[0].service_address}/{CHECK_CLIENT_STATUS_API}/{test_job.clients_info[0].uuid}", - params={ - "redis_host": test_job.clients_info[0].redis_host, - "redis_port": test_job.clients_info[0].redis_port, - }, - ) def test_server_training_listener_fail_no_server_uuid() -> None: @@ -461,6 +454,109 @@ def test_server_training_listener_fail_no_redis_port() -> None: server_training_listener(test_job, Mock()) +@patch("florist.api.routes.server.training.get_from_redis") +@patch("florist.api.routes.server.training.get_subscriber") +def test_client_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: + # Setup + test_client_uuid = "test-client-uuid"; + test_job = Job(**{ + "clients_info": [ + { + "service_address": "test-service-address", + "uuid": test_client_uuid, + "redis_host": "test-client-redis-host", + "redis_port": "test-client-redis-port", + "client": "MNIST", + "data_path": "test-data-path", + } + ] + }) + test_client_metrics = [ + {"initialized": "2022-02-02 02:02:02"}, + {"initialized": "2022-02-02 02:02:02", "rounds": []}, + {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"}, + ] + mock_get_from_redis.side_effect = test_client_metrics + mock_subscriber = Mock() + mock_subscriber.listen.return_value = [ + {"type": "message"}, + {"type": "not message"}, + {"type": "message"}, + {"type": "message"}, + {"type": "message"}, + ] + mock_get_subscriber.return_value = mock_subscriber + mock_database = Mock() + + with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: + with patch.object(Job, "set_client_metrics", Mock()) as mock_set_client_metrics: + # Act + client_training_listener(test_job, test_job.clients_info[0], mock_database) + + # Assert + assert mock_set_client_metrics.call_count == 3 + mock_set_client_metrics.assert_has_calls([ + call(test_client_uuid, test_client_metrics[0], mock_database), + call(test_client_uuid, test_client_metrics[1], mock_database), + call(test_client_uuid, test_client_metrics[2], mock_database), + ]) + + assert mock_get_from_redis.call_count == 3 + mock_get_subscriber.assert_called_once_with( + test_job.clients_info[0].uuid, + test_job.clients_info[0].redis_host, + test_job.clients_info[0].redis_port, + ) + + +@patch("florist.api.routes.server.training.get_from_redis") +def test_client_training_listener_already_finished(mock_get_from_redis: Mock) -> None: + # Setup + test_client_uuid = "test-client-uuid"; + test_job = Job(**{ + "clients_info": [ + { + "service_address": "test-service-address", + "uuid": test_client_uuid, + "redis_host": "test-client-redis-host", + "redis_port": "test-client-redis-port", + "client": "MNIST", + "data_path": "test-data-path", + } + ] + }) + test_client_final_metrics = {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"} + mock_get_from_redis.side_effect = [test_client_final_metrics] + mock_database = Mock() + + with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: + with patch.object(Job, "set_client_metrics", Mock()) as mock_set_client_metrics: + # Act + client_training_listener(test_job, test_job.clients_info[0], mock_database) + + # Assert + mock_set_client_metrics.assert_called_once_with(test_client_uuid, test_client_final_metrics, mock_database) + + assert mock_get_from_redis.call_count == 1 + + +def test_client_training_listener_fail_no_uuid() -> None: + test_job = Job(**{ + "clients_info": [ + { + "redis_host": "test-redis-host", + "redis_port": "test-redis-port", + "service_address": "test-service-address", + "client": "MNIST", + "data_path": "test-data-path", + }, + ], + }) + + with raises(AssertionError, match="client_info.uuid is None."): + client_training_listener(test_job, test_job.clients_info[0], Mock()) + + def _setup_test_job_and_mocks() -> Tuple[Dict[str, Any], Dict[str, Any], Mock, Mock]: test_server_config = { "n_server_rounds": 2, From 5d202fe52f73d914327d4d96d8753cee55fdcbce Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 16 Oct 2024 12:59:06 -0400 Subject: [PATCH 4/7] Adding integration tests for set client metrics --- florist/api/db/entities.py | 7 ++- .../tests/integration/api/db/test_entities.py | 53 +++++++++++-------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 56bd5ce..aeb9d64 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -206,10 +206,9 @@ def set_client_metrics( :param client_metrics: (Dict[str, Any]) the client's metrics to be saved. :param database: (pymongo.database.Database) The database where the job collection is stored. """ - assert self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info], ( - f"client uuid {client_uuid} is not in clients_info", - f"({[c.uuid for c in self.clients_info] if self.clients_info is not None else None})", - ) + assert ( + self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info] + ), f"client uuid {client_uuid} is not in clients_info ({[c.uuid for c in self.clients_info] if self.clients_info is not None else None})" job_collection = database[JOB_COLLECTION_NAME] diff --git a/florist/tests/integration/api/db/test_entities.py b/florist/tests/integration/api/db/test_entities.py index a79e8e9..6b8234e 100644 --- a/florist/tests/integration/api/db/test_entities.py +++ b/florist/tests/integration/api/db/test_entities.py @@ -169,7 +169,7 @@ async def test_set_status_sync_fail_update_result(mock_request) -> None: test_job.set_status_sync(JobStatus.IN_PROGRESS, mock_request.app.synchronous_database) -async def test_set_metrics_success(mock_request) -> None: +async def test_set_server_metrics_success(mock_request) -> None: test_job = get_test_job() result_id = await test_job.create(mock_request.app.database) test_job.id = result_id @@ -177,56 +177,67 @@ async def test_set_metrics_success(mock_request) -> None: test_job.clients_info[1].id = ANY test_server_metrics = {"test-server": 123} - test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}] - test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database) + test_job.set_server_metrics(test_server_metrics, mock_request.app.synchronous_database) result_job = await Job.find_by_id(result_id, mock_request.app.database) test_job.server_metrics = json.dumps(test_server_metrics) - test_job.clients_info[0].metrics = json.dumps(test_client_metrics[0]) - test_job.clients_info[1].metrics = json.dumps(test_client_metrics[1]) assert result_job == test_job -async def test_set_metrics_fail_clients_info_is_none(mock_request) -> None: +async def test_set_server_metrics_fail_update_result(mock_request) -> None: test_job = get_test_job() - test_job.clients_info = None - result_id = await test_job.create(mock_request.app.database) - test_job.id = result_id + test_job.id = str(test_job.id) test_server_metrics = {"test-server": 123} - test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}] - error_msg = "self.clients_info and client_metrics must have the same length (None!=2)." + error_msg = "UpdateResult's 'n' is not 1" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database) + test_job.set_server_metrics(test_server_metrics, mock_request.app.synchronous_database) -async def test_set_metrics_fail_clients_info_is_not_same_length(mock_request) -> None: +async def test_set_client_metrics_success(mock_request) -> None: test_job = get_test_job() result_id = await test_job.create(mock_request.app.database) test_job.id = result_id test_job.clients_info[0].id = ANY test_job.clients_info[1].id = ANY - test_server_metrics = {"test-server": 123} - test_client_metrics = [{"test-client-1": 456}] + test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}] + + test_job.set_client_metrics(test_job.clients_info[1].uuid, test_client_metrics, mock_request.app.synchronous_database) - error_msg = "self.clients_info and client_metrics must have the same length (2!=1)." + result_job = await Job.find_by_id(result_id, mock_request.app.database) + test_job.clients_info[1].metrics = json.dumps(test_client_metrics) + assert result_job == test_job + + +async def test_set_metrics_fail_clients_info_is_none(mock_request) -> None: + test_job = get_test_job() + result_id = await test_job.create(mock_request.app.database) + test_job.id = result_id + + test_wrong_client_uuid = "client-id-that-does-not-exist" + test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}] + + error_msg = f"client uuid {test_wrong_client_uuid} is not in clients_info (['{test_job.clients_info[0].uuid}', '{test_job.clients_info[1].uuid}'])" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database) + test_job.set_client_metrics(test_wrong_client_uuid, test_client_metrics, mock_request.app.synchronous_database) -async def test_set_metrics_fail_update_result(mock_request) -> None: +async def test_set_client_metrics_fail_update_result(mock_request) -> None: test_job = get_test_job() test_job.id = str(test_job.id) - test_server_metrics = {"test-server": 123} - test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}] + test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}] error_msg = "UpdateResult's 'n' is not 1" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database) + test_job.set_client_metrics( + test_job.clients_info[0].uuid, + test_client_metrics, + mock_request.app.synchronous_database, + ) def get_test_job() -> Job: From 1f8adc8e4880d96ae08cfae57ff1a4503a9e3cda Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 16 Oct 2024 14:47:57 -0400 Subject: [PATCH 5/7] Upgrading fastapi --- poetry.lock | 28 ++++++++++------------------ pyproject.toml | 2 +- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/poetry.lock b/poetry.lock index 62d3545..19a0fd2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1068,13 +1068,6 @@ files = [ {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, - {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, - {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, - {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, - {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, - {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, - {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, - {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, @@ -1193,22 +1186,23 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth [[package]] name = "fastapi" -version = "0.109.2" +version = "0.115.2" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.109.2-py3-none-any.whl", hash = "sha256:2c9bab24667293b501cad8dd388c05240c850b58ec5876ee3283c47d6e1e3a4d"}, - {file = "fastapi-0.109.2.tar.gz", hash = "sha256:f3817eac96fe4f65a2ebb4baa000f394e55f5fccdaf7f75250804bc58f354f73"}, + {file = "fastapi-0.115.2-py3-none-any.whl", hash = "sha256:61704c71286579cc5a598763905928f24ee98bfcc07aabe84cfefb98812bbc86"}, + {file = "fastapi-0.115.2.tar.gz", hash = "sha256:3995739e0b09fa12f984bce8fa9ae197b35d433750d3d312422d846e283697ee"}, ] [package.dependencies] pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.36.3,<0.37.0" +starlette = ">=0.37.2,<0.41.0" typing-extensions = ">=4.8.0" [package.extras] -all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] [[package]] name = "fastjsonschema" @@ -3173,7 +3167,6 @@ description = "Nvidia JIT LTO Library" optional = false python-versions = ">=3" files = [ - {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"}, {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"}, ] @@ -4310,7 +4303,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -5316,13 +5308,13 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] [[package]] name = "starlette" -version = "0.36.3" +version = "0.40.0" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.36.3-py3-none-any.whl", hash = "sha256:13d429aa93a61dc40bf503e8c801db1f1bca3dc706b10ef2434a36123568f044"}, - {file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"}, + {file = "starlette-0.40.0-py3-none-any.whl", hash = "sha256:c494a22fae73805376ea6bf88439783ecfba9aac88a43911b48c653437e784c4"}, + {file = "starlette-0.40.0.tar.gz", hash = "sha256:1a3139688fb298ce5e2d661d37046a66ad996ce94be4d4983be019a23a04ea35"}, ] [package.dependencies] @@ -6267,4 +6259,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.11" -content-hash = "b3e891c3498b1bfb7583d18758f1bb50b9c60598b07873e5444a715233bbebc5" +content-hash = "09a4191122fc887d9469ed793525e6148d74898357f0f7444d3d6474381ceb3f" diff --git a/pyproject.toml b/pyproject.toml index acef207..af21c4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.9.0,<3.11" -fastapi = "^0.109.1" +fastapi = "^0.115.2" uvicorn = {version = "^0.23.2", extras = ["standard"]} fl4health = "^0.1.15" wandb = "^0.16.3" From 5eff725a47b88d9734b159e60473f857387c753f Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 16 Oct 2024 15:01:08 -0400 Subject: [PATCH 6/7] Testing some other form of poetry install --- .github/workflows/docs_build.yml | 9 ++++++--- .github/workflows/docs_deploy.yml | 7 +++++-- .github/workflows/integration_tests.yaml | 9 ++++++--- .github/workflows/publish.yml | 7 +++++-- .github/workflows/unit_tests.yaml | 9 ++++++--- 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/.github/workflows/docs_build.yml b/.github/workflows/docs_build.yml index 5c3d453..d4e61a8 100644 --- a/.github/workflows/docs_build.yml +++ b/.github/workflows/docs_build.yml @@ -19,8 +19,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4.2.1 - - name: Install dependencies, build docs and coverage report - run: sudo apt install python3-poetry + - name: Install and configure Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true - uses: actions/setup-python@v5.2.0 with: python-version: '3.9' @@ -28,7 +31,7 @@ jobs: - run: | python3 -m pip install --upgrade pip && python3 -m pip install poetry poetry env use '3.9' - source $(poetry env info --path)/bin/activate + source .venv/bin/activate poetry install --with docs,test cd docs && rm -rf source/reference/api/_autosummary && make html cd .. && coverage run -m pytest florist/tests/unit && coverage xml && coverage report -m diff --git a/.github/workflows/docs_deploy.yml b/.github/workflows/docs_deploy.yml index 1f91b46..2848793 100644 --- a/.github/workflows/docs_deploy.yml +++ b/.github/workflows/docs_deploy.yml @@ -24,8 +24,11 @@ jobs: - uses: actions/checkout@v4.2.1 with: submodules: 'true' - - name: Install dependencies, build docs and coverage report - run: sudo apt install python3-poetry + - name: Install and configure Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true - uses: actions/setup-python@v5.2.0 with: python-version: '3.9' diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index 3a54671..76c254f 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -37,8 +37,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4.2.1 - - name: Install poetry - run: sudo apt install python3-poetry + - name: Install and configure Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true - uses: actions/setup-python@v5.2.0 with: python-version: '3.9' @@ -53,7 +56,7 @@ jobs: - name: Install dependencies and check code run: | poetry env use '3.9' - source $(poetry env info --path)/bin/activate + source .venv/bin/activate poetry install --with docs,test coverage run -m pytest florist/tests/integration && coverage xml && coverage report -m - name: Upload python coverage to Codecov diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ea26216..d5e5822 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -13,8 +13,11 @@ jobs: sudo apt-get update sudo apt-get install libcurl4-openssl-dev libssl-dev - uses: actions/checkout@v4.2.1 - - name: Install poetry - run: sudo apt install python3-poetry + - name: Install and configure Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true - uses: actions/setup-python@v5.2.0 with: python-version: '3.9' diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 73edde4..0b5b037 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -37,8 +37,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4.2.1 - - name: Install poetry - run: sudo apt install python3-poetry + - name: Install and configure Poetry + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true - uses: actions/setup-python@v5.2.0 with: python-version: '3.9' @@ -46,7 +49,7 @@ jobs: - name: Install python dependencies and check code run: | poetry env use '3.9' - source $(poetry env info --path)/bin/activate + source .venv/bin/activate poetry install --with docs,test coverage run -m pytest florist/tests/unit && coverage xml && coverage report -m - name: Upload python coverage to Codecov From d9b2504a0a6e2d3b85ced68bcfe68f5b6abb5474 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 16 Oct 2024 15:20:20 -0400 Subject: [PATCH 7/7] Removing poetry cache so it stops it from cleaning up --- .github/workflows/unit_tests.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index 0b5b037..29d01d6 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -45,7 +45,6 @@ jobs: - uses: actions/setup-python@v5.2.0 with: python-version: '3.9' - cache: 'poetry' - name: Install python dependencies and check code run: | poetry env use '3.9'