From ee512175428d4509417dcb1901e59e3fcfe4bc45 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 1 Nov 2024 12:10:09 -0400 Subject: [PATCH 1/2] Adding progress bars and progress details for clients (#104) * Changing current progress bar to make it more flexible to be used by different components * Adding progress bars to each one of the clients --- .github/workflows/unit_tests.yaml | 2 + florist/app/assets/css/florist.css | 66 +++++- florist/app/jobs/details/page.tsx | 221 ++++++++++++------ .../tests/unit/app/jobs/details/page.test.tsx | 131 +++++++++-- 4 files changed, 324 insertions(+), 96 deletions(-) diff --git a/.github/workflows/unit_tests.yaml b/.github/workflows/unit_tests.yaml index f8b317b..ed4af46 100644 --- a/.github/workflows/unit_tests.yaml +++ b/.github/workflows/unit_tests.yaml @@ -13,6 +13,7 @@ on: - .github/workflows/integration_tests.yaml - '**.py' - '**.ipynb' + - '**.tsx' - poetry.lock - pyproject.toml - '**.rst' @@ -29,6 +30,7 @@ on: - .github/workflows/integration_tests.yaml - '**.py' - '**.ipynb' + - '**.tsx' - poetry.lock - pyproject.toml - '**.rst' diff --git a/florist/app/assets/css/florist.css b/florist/app/assets/css/florist.css index 2e564f7..5a5e58d 100644 --- a/florist/app/assets/css/florist.css +++ b/florist/app/assets/css/florist.css @@ -62,17 +62,17 @@ color: white; } -#job-progress .progress { +.job-progress-bar .progress { padding: 0; margin: 0 12px; height: max-content; } -#job-progress .progress-bar { +.job-progress-bar .progress-bar { height: 25px; } -#job-progress .progress-bar.bg-disabled { +.job-progress-bar .progress-bar.bg-disabled { background-color: lightgray !important; } @@ -88,3 +88,63 @@ .job-round-details .col-sm-2 { width: 20%; } + +.job-client-details { + border-bottom-style: hidden; +} + +.job-client-progress-label div { + margin-top: -15px; +} + +.job-client-progress-label.empty-cell { + padding: 0px; +} + +.job-client-progress.empty-cell { + padding: 0px; +} + +.job-client-progress .card { + flex-direction: row; + box-shadow: none; + margin: 0 !important; +} + +.job-client-progress .card .row, +.job-client-progress .card .row .text-dark { + color: #7b809a !important; + font-weight: normal; + text-wrap: wrap; +} + +.job-client-progress .card .card-body { + padding: 0; +} + +.job-client-progress .card .card-header { + height: 0px; + width: 0px; + overflow: hidden; + padding: 0; +} + +.job-client-progress .card .card-body .row .col-sm { + align-content: center; +} + +.job-client-progress .job-progress-bar { + margin-left: -5px; + margin-bottom: 0px !important; + padding-left: 1rem !important; +} + +.job-client-progress .job-progress-bar .progress { + width: 50%; + height: max-content; + margin: auto; +} + +.job-client-progress .job-progress-bar .progress .progress-bar { + height: max-content; +} diff --git a/florist/app/jobs/details/page.tsx b/florist/app/jobs/details/page.tsx index af8d9da..0ad2dab 100644 --- a/florist/app/jobs/details/page.tsx +++ b/florist/app/jobs/details/page.tsx @@ -54,6 +54,14 @@ export function JobDetailsBody(): ReactElement { ); } + let totalEpochs = null; + let localEpochs = null; + if (job.server_config) { + const serverConfigJson = JSON.parse(job.server_config); + totalEpochs = serverConfigJson.n_server_rounds; + localEpochs = serverConfigJson.local_epochs; + } + return (
@@ -105,7 +113,7 @@ export function JobDetailsBody(): ReactElement {
- + ); @@ -164,33 +173,51 @@ export function JobDetailsStatus({ status }: { status: string }): ReactElement { ); } -export function JobProgress({ - serverMetrics, - serverConfig, +export function JobProgressBar({ + metrics, + totalEpochs, status, }: { - serverMetrics: string; - serverConfig: string; + metrics: string; + totalEpochs: number; status: status; }): ReactElement { const [collapsed, setCollapsed] = useState(true); - if (!serverMetrics || !serverConfig) { + if (!metrics || !totalEpochs) { return null; } - const serverMetricsJson = JSON.parse(serverMetrics); - const serverConfigJson = JSON.parse(serverConfig); + const metricsJson = JSON.parse(metrics); + + let endRoundKey; + if (metricsJson.type === "server") { + endRoundKey = "fit_end"; + } + if (metricsJson.type === "client") { + endRoundKey = "shutdown"; + } let progressPercent = 0; - if ("rounds" in serverMetricsJson && Object.keys(serverMetricsJson.rounds).length > 0) { - const totalServerRounds = serverConfigJson.n_server_rounds; - const lastRound = Math.max(...Object.keys(serverMetricsJson.rounds)); - const lastCompletedRound = "fit_end" in serverMetricsJson ? lastRound : lastRound - 1; - progressPercent = (lastCompletedRound * 100) / totalServerRounds; + if ("rounds" in metricsJson && Object.keys(metricsJson.rounds).length > 0) { + const lastRound = Math.max(...Object.keys(metricsJson.rounds)); + const lastCompletedRound = endRoundKey in metricsJson ? lastRound : lastRound - 1; + progressPercent = (lastCompletedRound * 100) / totalEpochs; } const progressWidth = progressPercent === 0 ? "100%" : `${progressPercent}%`; + // Clients will not have a status, so we need to set one based on the progress percent + if (!status) { + if (progressPercent === 0) { + status = "NOT_STARTED"; + } else if (progressPercent === 100) { + status = "FINISHED_SUCCESSFULLY"; + } else { + status = "IN_PROGRESS"; + } + // TODO: add error status + } + let progressBarClasses = "progress-bar progress-bar-striped"; switch (String(validStatuses[status])) { case validStatuses.IN_PROGRESS: @@ -211,7 +238,7 @@ export function JobProgress({ } return ( -
+
Progress: @@ -230,7 +257,7 @@ export function JobProgress({ {Math.floor(progressPercent)}%
- -
- {!collapsed ? : null} -
+
{!collapsed ? : null}
); } -export function JobProgressDetails({ serverMetrics }: { serverMetrics: Object }): ReactElement { - if (!serverMetrics) { +export function JobProgressDetails({ metrics }: { metrics: Object }): ReactElement { + if (!metrics) { return null; } + + let fitStartKey; + let fitEndKey; + if (metrics.type === "server") { + fitStartKey = "fit_start"; + fitEndKey = "fit_end"; + } + if (metrics.type === "client") { + fitStartKey = "initialized"; + fitEndKey = "shutdown"; + } + let elapsedTime = ""; - if ("fit_start" in serverMetrics) { - const startDate = Date.parse(serverMetrics.fit_start); - const endDate = "fit_end" in serverMetrics ? Date.parse(serverMetrics.fit_end) : Date.now(); + if (fitStartKey in metrics) { + const startDate = Date.parse(metrics[fitStartKey]); + const endDate = fitEndKey in metrics ? Date.parse(metrics[fitEndKey]) : Date.now(); elapsedTime = getTimeString(endDate - startDate); } - const roundMetricsArray = Array(serverMetrics.rounds.length); - for (const [round, roundMetrics] of Object.entries(serverMetrics.rounds)) { - roundMetricsArray[parseInt(round) - 1] = roundMetrics; + let roundMetricsArray = []; + if (metrics.rounds) { + roundMetricsArray = Array(metrics.rounds.length); + for (const [round, roundMetrics] of Object.entries(metrics.rounds)) { + roundMetricsArray[parseInt(round) - 1] = roundMetrics; + } } return ( -
+
Elapsed time: @@ -283,17 +323,17 @@ export function JobProgressDetails({ serverMetrics }: { serverMetrics: Object })
Start time:
-
{"fit_start" in serverMetrics ? serverMetrics.fit_start : null}
+
{fitStartKey in metrics ? metrics[fitStartKey] : null}
End time:
-
{"fit_end" in serverMetrics ? serverMetrics.fit_end : null}
+
{fitEndKey in metrics ? metrics[fitEndKey] : null}
- {Object.keys(serverMetrics).map((name, i) => ( - + {Object.keys(metrics).map((name, i) => ( + ))} {roundMetricsArray.map((roundMetrics, i) => ( @@ -316,7 +356,7 @@ export function JobProgressRound({ roundMetrics, index }: { roundMetrics: Object
Round {index + 1}
-
+
setCollapsed(!collapsed)}> {collapsed ? ( @@ -357,7 +397,7 @@ export function JobProgressRoundDetails({ roundMetrics, index }: { roundMetrics: } return ( -
+
Fit elapsed time: @@ -402,7 +442,18 @@ export function JobProgressRoundDetails({ roundMetrics, index }: { roundMetrics: } export function JobProgressProperty({ name, value }: { name: string; value: string }): ReactElement { - if (["fit_start", "fit_end", "evaluate_start", "evaluate_end", "rounds", "type"].includes(name)) { + if ( + [ + "fit_start", + "fit_end", + "evaluate_start", + "evaluate_end", + "rounds", + "type", + "initialized", + "shutdown", + ].includes(name) + ) { return null; } let renderedValue = value; @@ -424,7 +475,7 @@ export function JobProgressProperty({ name, value }: { name: string; value: stri ); } -export function JobDetailsTable({ Component, title, data }): ReactElement { +export function JobDetailsTable({ Component, title, data, properties }): ReactElement { return (
@@ -437,7 +488,7 @@ export function JobDetailsTable({ Component, title, data }): ReactElement {
- +
@@ -446,7 +497,7 @@ export function JobDetailsTable({ Component, title, data }): ReactElement { ); } -export function JobDetailsServerConfigTable({ data }: { data: string }): ReactElement { +export function JobDetailsServerConfigTable({ data, properties }: { data: string; properties: Object }): ReactElement { const emptyResponse = (
Empty. @@ -503,7 +554,15 @@ export function JobDetailsServerConfigTable({ data }: { data: string }): ReactEl ); } -export function JobDetailsClientsInfoTable({ data }: { data: Array }): ReactElement { +export function JobDetailsClientsInfoTable({ + data, + properties, +}: { + data: Array; + properties: Object; +}): ReactElement { + const [collapsed, setCollapsed] = useState(true); + return ( @@ -516,35 +575,61 @@ export function JobDetailsClientsInfoTable({ data }: { data: Array } - {data.map((clientInfo, i) => ( - - - - - - - - ))} + {data.map((clientInfo, i) => { + let additionalClasses = clientInfo.metrics ? "" : "empty-cell"; + return [ + + + + + + + , + + + + , + ]; + })}
-
- {clientInfo.client} -
-
-
- {clientInfo.service_address} -
-
-
- {clientInfo.data_path} -
-
-
- {clientInfo.redis_host} -
-
-
- {clientInfo.redis_port} -
-
+
+ {clientInfo.client} +
+
+
+ {clientInfo.service_address} +
+
+
+ {clientInfo.data_path} +
+
+
+ {clientInfo.redis_host} +
+
+
+ {clientInfo.redis_port} +
+
+ {clientInfo.metrics ? ( +
+ Progress: +
+ ) : null} +
+
+ + + +
+
); diff --git a/florist/tests/unit/app/jobs/details/page.test.tsx b/florist/tests/unit/app/jobs/details/page.test.tsx index 4a5d048..23226d2 100644 --- a/florist/tests/unit/app/jobs/details/page.test.tsx +++ b/florist/tests/unit/app/jobs/details/page.test.tsx @@ -39,8 +39,10 @@ function makeTestJob(): JobData { test_attribute_1: "test-value-1", test_attribute_2: "test-value-2", n_server_rounds: 4, + local_epochs: 2, }), server_metrics: JSON.stringify({ + type: "server", fit_start: "2020-01-01 12:07:07.0707", rounds: { "1": { @@ -76,6 +78,25 @@ function makeTestJob(): JobData { data_path: "test-data-path-1", redis_host: "test-redis-host-1", redis_port: "test-redis-port-1", + metrics: JSON.stringify({ + type: "client", + initialized: "2024-10-10 15:05:59.025693", + shutdown: "2024-10-10 15:12:34.888213", + rounds: { + "1": { + fit_start: "2024-10-10 15:05:34.888213", + fit_end: "2024-10-10 15:06:59.032618", + evaluate_start: "2024-10-10 15:07:59.032618", + evaluate_end: "2024-10-10 15:08:34.888213", + }, + "2": { + fit_start: "2024-10-10 15:06:59.032618", + fit_end: "2024-10-10 15:07:34.888213", + evaluate_start: "2024-10-10 15:08:34.888213", + evaluate_end: "2024-10-10 15:09:59.032618", + }, + }, + }), }, { client: "test-client-2", @@ -83,6 +104,21 @@ function makeTestJob(): JobData { data_path: "test-data-path-2", redis_host: "test-redis-host-2", redis_port: "test-redis-port-2", + metrics: JSON.stringify({ + type: "client", + initialized: "2024-10-10 15:05:59.025693", + rounds: { + "1": { + fit_start: "2024-10-10 15:05:34.888213", + fit_end: "2024-10-10 15:05:34.888213", + evaluate_start: "2024-10-10 15:08:34.888213", + evaluate_end: "2024-10-10 15:08:34.888213", + }, + "2": { + fit_start: "2024-10-10 15:06:59.032618", + }, + }, + }), }, ], }; @@ -201,7 +237,6 @@ describe("Job Details Page", () => { testJob.server_metrics = JSON.stringify({}); setupGetJobMock(testJob); const { container } = render(); - const jobProgressComponent = container.querySelector("#job-progress"); const progressBar = container.querySelector("div.progress-bar"); expect(progressBar.getAttribute("style")).toBe("width: 100%;"); expect(progressBar).toHaveTextContent("0%"); @@ -212,7 +247,6 @@ describe("Job Details Page", () => { testJob.server_metrics = JSON.stringify({ rounds: {} }); setupGetJobMock(testJob); const { container } = render(); - const jobProgressComponent = container.querySelector("#job-progress"); const progressBar = container.querySelector("div.progress-bar"); expect(progressBar.getAttribute("style")).toBe("width: 100%;"); expect(progressBar).toHaveTextContent("0%"); @@ -221,7 +255,6 @@ describe("Job Details Page", () => { it("Display progress bar at with correct progress percent", () => { setupGetJobMock(makeTestJob()); const { container } = render(); - const jobProgressComponent = container.querySelector("#job-progress"); const progressBar = container.querySelector("div.progress-bar"); expect(progressBar.getAttribute("style")).toBe("width: 50%;"); expect(progressBar).toHaveTextContent("50%"); @@ -231,7 +264,6 @@ describe("Job Details Page", () => { testJob.status = "NOT_STARTED"; setupGetJobMock(testJob); const { container } = render(); - const jobProgressComponent = container.querySelector("#job-progress"); const progressBar = container.querySelector("div.progress-bar"); expect(progressBar).toHaveClass("progress-bar-striped"); }); @@ -240,7 +272,6 @@ describe("Job Details Page", () => { testJob.status = "IN_PROGRESS"; setupGetJobMock(testJob); const { container } = render(); - const jobProgressComponent = container.querySelector("#job-progress"); const progressBar = container.querySelector("div.progress-bar"); expect(progressBar).toHaveClass("bg-warning"); }); @@ -249,7 +280,6 @@ describe("Job Details Page", () => { testJob.status = "FINISHED_SUCCESSFULLY"; setupGetJobMock(testJob); const { container } = render(); - const jobProgressComponent = container.querySelector("#job-progress"); const progressBar = container.querySelector("div.progress-bar"); expect(progressBar).toHaveClass("bg-success"); }); @@ -258,7 +288,6 @@ describe("Job Details Page", () => { testJob.status = "FINISHED_WITH_ERROR"; setupGetJobMock(testJob); const { container } = render(); - const jobProgressComponent = container.querySelector("#job-progress"); const progressBar = container.querySelector("div.progress-bar"); expect(progressBar).toHaveClass("bg-danger"); }); @@ -266,20 +295,20 @@ describe("Job Details Page", () => { it("Should be collapsed by default", () => { setupGetJobMock(makeTestJob()); const { container } = render(); - const jobProgressDetailsComponent = container.querySelector("#job-progress-detail"); + const jobProgressDetailsComponent = container.querySelector(".job-progress-detail"); expect(jobProgressDetailsComponent).toBeNull(); }); it("Should open when the toggle button is clicked", () => { setupGetJobMock(makeTestJob()); const { container } = render(); - const toggleButton = container.querySelector("#job-details-toggle a"); + const toggleButton = container.querySelector(".job-details-toggle a"); expect(toggleButton).toHaveTextContent("Expand"); act(() => toggleButton.click()); expect(toggleButton).toHaveTextContent("Collapse"); - const jobProgressDetailsComponent = container.querySelector("#job-progress-detail"); + const jobProgressDetailsComponent = container.querySelector(".job-progress-detail"); expect(jobProgressDetailsComponent).not.toBeNull(); }); it("Should render the contents correctly", () => { @@ -287,10 +316,10 @@ describe("Job Details Page", () => { const serverMetrics = JSON.parse(testJob.server_metrics); setupGetJobMock(testJob); const { container } = render(); - const toggleButton = container.querySelector("#job-details-toggle a"); + const toggleButton = container.querySelector(".job-details-toggle a"); act(() => toggleButton.click()); - const jobProgressDetailsComponent = container.querySelector("#job-progress-detail"); + const jobProgressDetailsComponent = container.querySelector(".job-progress-detail"); const elapsedTime = jobProgressDetailsComponent.children[0]; expect(elapsedTime.children[0]).toHaveTextContent("Elapsed time:"); expect(elapsedTime.children[1]).toHaveTextContent("05m 05s"); @@ -323,13 +352,13 @@ describe("Job Details Page", () => { // rounds const round1 = jobProgressDetailsComponent.children[6].children[0]; expect(round1.children[0]).toHaveTextContent("Round 1"); - expect(round1.children[1].getAttribute("id")).toBe("job-round-toggle-0"); + expect(round1.children[1]).toHaveClass("job-round-toggle-0"); const round2 = jobProgressDetailsComponent.children[7].children[0]; expect(round2.children[0]).toHaveTextContent("Round 2"); - expect(round2.children[1].getAttribute("id")).toBe("job-round-toggle-1"); + expect(round2.children[1]).toHaveClass("job-round-toggle-1"); const round3 = jobProgressDetailsComponent.children[8].children[0]; expect(round3.children[0]).toHaveTextContent("Round 3"); - expect(round3.children[1].getAttribute("id")).toBe("job-round-toggle-2"); + expect(round3.children[1]).toHaveClass("job-round-toggle-2"); }); describe("Rounds", () => { it("Should be collapsed by default", () => { @@ -337,7 +366,7 @@ describe("Job Details Page", () => { const serverMetrics = JSON.parse(testJob.server_metrics); setupGetJobMock(testJob); const { container } = render(); - const progressToggleButton = container.querySelector("#job-details-toggle a"); + const progressToggleButton = container.querySelector(".job-details-toggle a"); act(() => progressToggleButton.click()); for (let i = 0; i < Object.keys(serverMetrics.rounds).length; i++) { @@ -350,18 +379,18 @@ describe("Job Details Page", () => { const serverMetrics = JSON.parse(testJob.server_metrics); setupGetJobMock(testJob); const { container } = render(); - const progressToggleButton = container.querySelector("#job-details-toggle a"); + const progressToggleButton = container.querySelector(".job-details-toggle a"); act(() => progressToggleButton.click()); for (let i = 0; i < Object.keys(serverMetrics.rounds).length; i++) { - const toggleButton = container.querySelector(`#job-round-toggle-${i} a`); + const toggleButton = container.querySelector(`.job-round-toggle-${i} a`); expect(toggleButton).toHaveTextContent("Expand"); act(() => toggleButton.click()); expect(toggleButton).toHaveTextContent("Collapse"); - const jobRoundDetailsComponent = container.querySelector(`#job-round-details-${i}`); + const jobRoundDetailsComponent = container.querySelector(`.job-round-details-${i}`); expect(jobRoundDetailsComponent).not.toBeNull(); } }); @@ -370,7 +399,7 @@ describe("Job Details Page", () => { const serverMetrics = JSON.parse(testJob.server_metrics); setupGetJobMock(testJob); const { container } = render(); - const progressToggleButton = container.querySelector("#job-details-toggle a"); + const progressToggleButton = container.querySelector(".job-details-toggle a"); act(() => progressToggleButton.click()); const expectedTimes = { @@ -387,12 +416,10 @@ describe("Job Details Page", () => { }; for (let i = 0; i < Object.keys(expectedTimes.fit).length; i++) { - const toggleButton = container.querySelector(`#job-round-toggle-${i} a`); + const toggleButton = container.querySelector(`.job-round-toggle-${i} a`); act(() => toggleButton.click()); - console.log(i); - - const jobRoundDetailsComponent = container.querySelector(`#job-round-details-${i}`); + const jobRoundDetailsComponent = container.querySelector(`.job-round-details-${i}`); const fitElapsedTime = jobRoundDetailsComponent.children[0]; expect(fitElapsedTime.children[0]).toHaveTextContent("Fit elapsed time:"); expect(fitElapsedTime.children[1]).toHaveTextContent(expectedTimes.fit[i][0]); @@ -414,7 +441,7 @@ describe("Job Details Page", () => { } // custom properties - const jobRoundDetailsComponent = container.querySelector(`#job-round-details-${0}`); + const jobRoundDetailsComponent = container.querySelector(`.job-round-details-${0}`); const customPropertyValue = jobRoundDetailsComponent.children[6]; expect(customPropertyValue.children[0]).toHaveTextContent("custom_property_value"); expect(customPropertyValue.children[1]).toHaveTextContent(serverMetrics.custom_property_value); @@ -434,6 +461,60 @@ describe("Job Details Page", () => { ); }); }); + describe("Clients", () => { + it("Renders their progress bars correctly", () => { + const testJob = makeTestJob(); + setupGetJobMock(testJob); + const { container } = render(); + const clientsProgress = container.querySelectorAll(".job-client-progress"); + + let progressBar = clientsProgress[0].querySelector("div.progress-bar"); + expect(progressBar).toHaveClass("bg-success"); + expect(progressBar).toHaveTextContent("100%"); + + progressBar = clientsProgress[1].querySelector("div.progress-bar"); + expect(progressBar).toHaveClass("bg-warning"); + expect(progressBar).toHaveTextContent("50%"); + }); + it("Renders the progress details correctly", () => { + const testJob = makeTestJob(); + setupGetJobMock(testJob); + const { container } = render(); + + let toggleButton = container.querySelectorAll(".job-client-progress .job-details-toggle a")[0]; + act(() => toggleButton.click()); + + let clientMetrics = JSON.parse(testJob.clients_info[0].metrics); + let progressDetailsComponent = container.querySelectorAll( + ".job-client-progress .job-progress-detail", + )[0]; + let elapsedTime = progressDetailsComponent.children[0]; + expect(elapsedTime.children[0]).toHaveTextContent("Elapsed time:"); + expect(elapsedTime.children[1]).toHaveTextContent("06m 35s"); + let fitStart = progressDetailsComponent.children[1]; + expect(fitStart.children[0]).toHaveTextContent("Start time:"); + expect(fitStart.children[1]).toHaveTextContent(clientMetrics.initialized); + let fitEnd = progressDetailsComponent.children[2]; + expect(fitEnd.children[0]).toHaveTextContent("End time:"); + expect(fitEnd.children[1]).toHaveTextContent(clientMetrics.shutdown); + + toggleButton = container.querySelectorAll(".job-client-progress .job-details-toggle a")[1]; + act(() => toggleButton.click()); + clientMetrics = JSON.parse(testJob.clients_info[1].metrics); + progressDetailsComponent = container.querySelectorAll( + ".job-client-progress .job-progress-detail", + )[1]; + elapsedTime = progressDetailsComponent.children[0]; + expect(elapsedTime.children[0]).toHaveTextContent("Elapsed time:"); + expect(elapsedTime.children[1]).toHaveTextContent("06m 13s"); + fitStart = progressDetailsComponent.children[1]; + expect(fitStart.children[0]).toHaveTextContent("Start time:"); + expect(fitStart.children[1]).toHaveTextContent(clientMetrics.initialized); + fitEnd = progressDetailsComponent.children[2]; + expect(fitEnd.children[0]).toHaveTextContent("End time:"); + expect(fitEnd.children[1]).toHaveTextContent(""); + }); + }); }); }); describe("Server config", () => { From 2046e5c68a0a2facc73a4e2faeff2f04ddd6edf6 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 1 Nov 2024 12:10:37 -0400 Subject: [PATCH 2/2] Explicitly listen to client updates (#106) Replacing the old approach of updating client metrics together with server metrics update for an active listener to each one of the clients' redis, similar to the one for the server. This will ensure the metrics are always up to date for clients and also the client metrics updates are not dependent on server updates and can be ultimately visualized independently from server metrics updates in the UI. --- florist/api/db/entities.py | 43 ++-- florist/api/routes/server/training.py | 109 ++++++----- .../tests/integration/api/db/test_entities.py | 53 +++-- .../unit/api/routes/server/test_training.py | 184 +++++++++++++----- 4 files changed, 257 insertions(+), 132 deletions(-) diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 55c1fac..aeb9d64 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -176,36 +176,49 @@ def set_status_sync(self, status: JobStatus, database: Database[Dict[str, Any]]) update_result = job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}}) assert_updated_successfully(update_result) - def set_metrics( + def set_server_metrics( self, server_metrics: Dict[str, Any], - client_metrics: List[Dict[str, Any]], database: Database[Dict[str, Any]], ) -> None: """ - Sync function to save the server and clients' metrics in the database under the current job's id. + Sync function to save the server's metrics in the database under the current job's id. :param server_metrics: (Dict[str, Any]) the server metrics to be saved. - :param client_metrics: (List[Dict[str, Any]]) the clients metrics to be saved. :param database: (pymongo.database.Database) The database where the job collection is stored. """ - assert self.clients_info is not None and len(self.clients_info) == len(client_metrics), ( - "self.clients_info and client_metrics must have the same length " - f"({'None' if self.clients_info is None else len(self.clients_info)}!={len(client_metrics)})." - ) - job_collection = database[JOB_COLLECTION_NAME] self.server_metrics = json.dumps(server_metrics) update_result = job_collection.update_one({"_id": self.id}, {"$set": {"server_metrics": self.server_metrics}}) assert_updated_successfully(update_result) - for i in range(len(client_metrics)): - self.clients_info[i].metrics = json.dumps(client_metrics[i]) - update_result = job_collection.update_one( - {"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}} - ) - assert_updated_successfully(update_result) + def set_client_metrics( + self, + client_uuid: str, + client_metrics: Dict[str, Any], + database: Database[Dict[str, Any]], + ) -> None: + """ + Sync function to save a clients' metrics in the database under the current job's id. + + :param client_uuid: (str) the client's uuid whose produced the metrics. + :param client_metrics: (Dict[str, Any]) the client's metrics to be saved. + :param database: (pymongo.database.Database) The database where the job collection is stored. + """ + assert ( + self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info] + ), f"client uuid {client_uuid} is not in clients_info ({[c.uuid for c in self.clients_info] if self.clients_info is not None else None})" + + job_collection = database[JOB_COLLECTION_NAME] + + for i in range(len(self.clients_info)): + if client_uuid == self.clients_info[i].uuid: + self.clients_info[i].metrics = json.dumps(client_metrics) + update_result = job_collection.update_one( + {"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}} + ) + assert_updated_successfully(update_result) class Config: """MongoDB config for the Job DB entity.""" diff --git a/florist/api/routes/server/training.py b/florist/api/routes/server/training.py index cdb0d8e..f00778f 100644 --- a/florist/api/routes/server/training.py +++ b/florist/api/routes/server/training.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from pymongo.database import Database -from florist.api.db.entities import Job, JobStatus +from florist.api.db.entities import ClientInfo, Job, JobStatus from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric from florist.api.servers.common import Model from florist.api.servers.config_parsers import ConfigParser @@ -106,8 +106,10 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks await job.set_uuids(server_uuid, client_uuids, request.app.database) # Start the server training listener as a background task to update - # the job's status once the training is done + # the job's metrics and status once the training is done background_tasks.add_task(server_training_listener, job, request.app.synchronous_database) + for client_info in job.clients_info: + background_tasks.add_task(client_training_listener, job, client_info, request.app.synchronous_database) # Return the UUIDs return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids}) @@ -124,6 +126,47 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks return JSONResponse({"error": str(ex)}, status_code=500) +def client_training_listener(job: Job, client_info: ClientInfo, database: Database[Dict[str, Any]]) -> None: + """ + Listen to the Redis' channel that reports updates on the training process of a FL client. + + Keeps consuming updates to the channel until it finds `shutdown` in the client metrics. + + :param job: (Job) The job that has this client's metrics. + :param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to. + :param database: (pymongo.database.Database) An instance of the database to save the information + into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async + because of limitations with FastAPI's BackgroundTasks. + """ + LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}") + + assert client_info.uuid is not None, "client_info.uuid is None." + + # check if training has already finished before start listening + client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port) + LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}") + if client_metrics is not None: + LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}") + job.set_client_metrics(client_info.uuid, client_metrics, database) + LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.") + if "shutdown" in client_metrics: + return + + subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port) + # TODO add a max retries mechanism, maybe? + for message in subscriber.listen(): # type: ignore[no-untyped-call] + if message["type"] == "message": + # The contents of the message do not matter, we just use it to get notified + client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port) + LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}") + if client_metrics is not None: + LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}") + job.set_client_metrics(client_info.uuid, client_metrics, database) + LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.") + if "shutdown" in client_metrics: + return + + def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> None: """ Listen to the Redis' channel that reports updates on the training process of a FL server. @@ -147,9 +190,13 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port) LOGGER.debug(f"Listener: Current metrics for job {job.id}: {server_metrics}") if server_metrics is not None: - update_job_metrics(job, server_metrics, database) + LOGGER.info(f"Listener: Updating server metrics for job {job.id}") + job.set_server_metrics(server_metrics, database) + LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.") if "fit_end" in server_metrics: - close_job(job, database) + LOGGER.info(f"Listener: Training finished for job {job.id}") + job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database) + LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.") return subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port) @@ -161,53 +208,11 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No LOGGER.debug(f"Listener: Message received for job {job.id}. Metrics: {server_metrics}") if server_metrics is not None: - update_job_metrics(job, server_metrics, database) + LOGGER.info(f"Listener: Updating server metrics for job {job.id}") + job.set_server_metrics(server_metrics, database) + LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.") if "fit_end" in server_metrics: - close_job(job, database) + LOGGER.info(f"Listener: Training finished for job {job.id}") + job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database) + LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.") return - - -def update_job_metrics(job: Job, server_metrics: Dict[str, Any], database: Database[Dict[str, Any]]) -> None: - """ - Update the job with server and client metrics. - - Collect the job's clients metrics, saving them and the server's metrics to the job. - - :param job: (Job) The job to be updated. - :param server_metrics: (Dict[str, Any]) The server's metrics to be saved into the job. - :param database: (pymongo.database.Database) An instance of the database to save the information - into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async - because of limitations with FastAPI's BackgroundTasks. - """ - LOGGER.info(f"Listener: Updating metrics for job {job.id}") - - clients_metrics: List[Dict[str, Any]] = [] - if job.clients_info is not None: - for client_info in job.clients_info: - response = requests.get( - url=f"http://{client_info.service_address}/{CHECK_CLIENT_STATUS_API}/{client_info.uuid}", - params={ - "redis_host": client_info.redis_host, - "redis_port": client_info.redis_port, - }, - ) - client_metrics = response.json() - clients_metrics.append(client_metrics) - - job.set_metrics(server_metrics, clients_metrics, database) - - LOGGER.info(f"Listener: Job {job.id} has been updated.") - - -def close_job(job: Job, database: Database[Dict[str, Any]]) -> None: - """ - Close the job by marking its status as FINISHED_SUCCESSFULLY. - - :param job: (Job) The job to be closed. - :param database: (pymongo.database.Database) An instance of the database to save the information - into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async - because of limitations with FastAPI's BackgroundTasks. - """ - LOGGER.info(f"Listener: Training finished for job {job.id}") - job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database) - LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.") diff --git a/florist/tests/integration/api/db/test_entities.py b/florist/tests/integration/api/db/test_entities.py index a79e8e9..6b8234e 100644 --- a/florist/tests/integration/api/db/test_entities.py +++ b/florist/tests/integration/api/db/test_entities.py @@ -169,7 +169,7 @@ async def test_set_status_sync_fail_update_result(mock_request) -> None: test_job.set_status_sync(JobStatus.IN_PROGRESS, mock_request.app.synchronous_database) -async def test_set_metrics_success(mock_request) -> None: +async def test_set_server_metrics_success(mock_request) -> None: test_job = get_test_job() result_id = await test_job.create(mock_request.app.database) test_job.id = result_id @@ -177,56 +177,67 @@ async def test_set_metrics_success(mock_request) -> None: test_job.clients_info[1].id = ANY test_server_metrics = {"test-server": 123} - test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}] - test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database) + test_job.set_server_metrics(test_server_metrics, mock_request.app.synchronous_database) result_job = await Job.find_by_id(result_id, mock_request.app.database) test_job.server_metrics = json.dumps(test_server_metrics) - test_job.clients_info[0].metrics = json.dumps(test_client_metrics[0]) - test_job.clients_info[1].metrics = json.dumps(test_client_metrics[1]) assert result_job == test_job -async def test_set_metrics_fail_clients_info_is_none(mock_request) -> None: +async def test_set_server_metrics_fail_update_result(mock_request) -> None: test_job = get_test_job() - test_job.clients_info = None - result_id = await test_job.create(mock_request.app.database) - test_job.id = result_id + test_job.id = str(test_job.id) test_server_metrics = {"test-server": 123} - test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}] - error_msg = "self.clients_info and client_metrics must have the same length (None!=2)." + error_msg = "UpdateResult's 'n' is not 1" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database) + test_job.set_server_metrics(test_server_metrics, mock_request.app.synchronous_database) -async def test_set_metrics_fail_clients_info_is_not_same_length(mock_request) -> None: +async def test_set_client_metrics_success(mock_request) -> None: test_job = get_test_job() result_id = await test_job.create(mock_request.app.database) test_job.id = result_id test_job.clients_info[0].id = ANY test_job.clients_info[1].id = ANY - test_server_metrics = {"test-server": 123} - test_client_metrics = [{"test-client-1": 456}] + test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}] + + test_job.set_client_metrics(test_job.clients_info[1].uuid, test_client_metrics, mock_request.app.synchronous_database) - error_msg = "self.clients_info and client_metrics must have the same length (2!=1)." + result_job = await Job.find_by_id(result_id, mock_request.app.database) + test_job.clients_info[1].metrics = json.dumps(test_client_metrics) + assert result_job == test_job + + +async def test_set_metrics_fail_clients_info_is_none(mock_request) -> None: + test_job = get_test_job() + result_id = await test_job.create(mock_request.app.database) + test_job.id = result_id + + test_wrong_client_uuid = "client-id-that-does-not-exist" + test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}] + + error_msg = f"client uuid {test_wrong_client_uuid} is not in clients_info (['{test_job.clients_info[0].uuid}', '{test_job.clients_info[1].uuid}'])" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database) + test_job.set_client_metrics(test_wrong_client_uuid, test_client_metrics, mock_request.app.synchronous_database) -async def test_set_metrics_fail_update_result(mock_request) -> None: +async def test_set_client_metrics_fail_update_result(mock_request) -> None: test_job = get_test_job() test_job.id = str(test_job.id) - test_server_metrics = {"test-server": 123} - test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}] + test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}] error_msg = "UpdateResult's 'n' is not 1" with raises(AssertionError, match=re.escape(error_msg)): - test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database) + test_job.set_client_metrics( + test_job.clients_info[0].uuid, + test_client_metrics, + mock_request.app.synchronous_database, + ) def get_test_job() -> Job: diff --git a/florist/tests/unit/api/routes/server/test_training.py b/florist/tests/unit/api/routes/server/test_training.py index c076eef..17e5af7 100644 --- a/florist/tests/unit/api/routes/server/test_training.py +++ b/florist/tests/unit/api/routes/server/test_training.py @@ -6,7 +6,12 @@ from florist.api.db.entities import Job, JobStatus, JOB_COLLECTION_NAME from florist.api.models.mnist import MnistNet -from florist.api.routes.server.training import start, server_training_listener, CHECK_CLIENT_STATUS_API +from florist.api.routes.server.training import ( + client_training_listener, + start, + server_training_listener, + CHECK_CLIENT_STATUS_API, +) @patch("florist.api.routes.server.training.launch_local_server") @@ -96,11 +101,25 @@ async def test_start_success( expected_job.id = ANY expected_job.clients_info[0].id = ANY expected_job.clients_info[1].id = ANY - mock_background_tasks.add_task.assert_called_once_with( - server_training_listener, - expected_job, - mock_fastapi_request.app.synchronous_database, - ) + mock_background_tasks.add_task.assert_has_calls([ + call( + server_training_listener, + expected_job, + mock_fastapi_request.app.synchronous_database, + ), + call( + client_training_listener, + expected_job, + expected_job.clients_info[0], + mock_fastapi_request.app.synchronous_database, + ), + call( + client_training_listener, + expected_job, + expected_job.clients_info[1], + mock_fastapi_request.app.synchronous_database, + ), + ]) async def test_start_fail_unsupported_server_model() -> None: @@ -318,8 +337,7 @@ async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, @patch("florist.api.routes.server.training.get_from_redis") @patch("florist.api.routes.server.training.get_subscriber") -@patch("florist.api.routes.server.training.requests") -def test_server_training_listener(mock_requests: Mock(), mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: +def test_server_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: # Setup test_job = Job(**{ "server_uuid": "test-server-uuid", @@ -336,11 +354,10 @@ def test_server_training_listener(mock_requests: Mock(), mock_get_subscriber: Mo } ] }) - test_client_metrics = {"test": 123} test_server_metrics = [ {"fit_start": "2022-02-02 02:02:02"}, {"fit_start": "2022-02-02 02:02:02", "rounds": []}, - {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"} + {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"}, ] mock_get_from_redis.side_effect = test_server_metrics mock_subscriber = Mock() @@ -353,41 +370,28 @@ def test_server_training_listener(mock_requests: Mock(), mock_get_subscriber: Mo ] mock_get_subscriber.return_value = mock_subscriber mock_database = Mock() - mock_response = Mock() - mock_response.json.return_value = test_client_metrics - mock_requests.get.return_value = mock_response with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: - with patch.object(Job, "set_metrics", Mock()) as mock_set_metrics: + with patch.object(Job, "set_server_metrics", Mock()) as mock_set_server_metrics: # Act server_training_listener(test_job, mock_database) # Assert mock_set_status_sync.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, mock_database) - assert mock_set_metrics.call_count == 3 - mock_set_metrics.assert_has_calls([ - call(test_server_metrics[0], [test_client_metrics], mock_database), - call(test_server_metrics[1], [test_client_metrics], mock_database), - call(test_server_metrics[2], [test_client_metrics], mock_database), + assert mock_set_server_metrics.call_count == 3 + mock_set_server_metrics.assert_has_calls([ + call(test_server_metrics[0], mock_database), + call(test_server_metrics[1], mock_database), + call(test_server_metrics[2], mock_database), ]) assert mock_get_from_redis.call_count == 3 mock_get_subscriber.assert_called_once_with(test_job.server_uuid, test_job.redis_host, test_job.redis_port) - assert mock_requests.get.call_count == 3 - mock_requests_get_call = call( - url=f"http://{test_job.clients_info[0].service_address}/{CHECK_CLIENT_STATUS_API}/{test_job.clients_info[0].uuid}", - params={ - "redis_host": test_job.clients_info[0].redis_host, - "redis_port": test_job.clients_info[0].redis_port, - }, - ) - assert mock_requests.get.call_args_list == [mock_requests_get_call, mock_requests_get_call, mock_requests_get_call] @patch("florist.api.routes.server.training.get_from_redis") -@patch("florist.api.routes.server.training.requests") -def test_server_training_listener_already_finished(mock_requests: Mock, mock_get_from_redis: Mock) -> None: +def test_server_training_listener_already_finished(mock_get_from_redis: Mock) -> None: # Setup test_job = Job(**{ "server_uuid": "test-server-uuid", @@ -404,31 +408,20 @@ def test_server_training_listener_already_finished(mock_requests: Mock, mock_get } ] }) - test_client_metrics = {"test": 123} test_server_final_metrics = {"fit_start": "2022-02-02 02:02:02", "rounds": [], "fit_end": "2022-02-02 03:03:03"} mock_get_from_redis.side_effect = [test_server_final_metrics] mock_database = Mock() - mock_response = Mock() - mock_response.json.return_value = test_client_metrics - mock_requests.get.return_value = mock_response with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: - with patch.object(Job, "set_metrics", Mock()) as mock_set_metrics: + with patch.object(Job, "set_server_metrics", Mock()) as mock_set_server_metrics: # Act server_training_listener(test_job, mock_database) # Assert mock_set_status_sync.assert_called_once_with(JobStatus.FINISHED_SUCCESSFULLY, mock_database) - mock_set_metrics.assert_called_once_with(test_server_final_metrics, [test_client_metrics], - mock_database) + mock_set_server_metrics.assert_called_once_with(test_server_final_metrics, mock_database) + assert mock_get_from_redis.call_count == 1 - mock_requests.get.assert_called_once_with( - url=f"http://{test_job.clients_info[0].service_address}/{CHECK_CLIENT_STATUS_API}/{test_job.clients_info[0].uuid}", - params={ - "redis_host": test_job.clients_info[0].redis_host, - "redis_port": test_job.clients_info[0].redis_port, - }, - ) def test_server_training_listener_fail_no_server_uuid() -> None: @@ -461,6 +454,109 @@ def test_server_training_listener_fail_no_redis_port() -> None: server_training_listener(test_job, Mock()) +@patch("florist.api.routes.server.training.get_from_redis") +@patch("florist.api.routes.server.training.get_subscriber") +def test_client_training_listener(mock_get_subscriber: Mock, mock_get_from_redis: Mock) -> None: + # Setup + test_client_uuid = "test-client-uuid"; + test_job = Job(**{ + "clients_info": [ + { + "service_address": "test-service-address", + "uuid": test_client_uuid, + "redis_host": "test-client-redis-host", + "redis_port": "test-client-redis-port", + "client": "MNIST", + "data_path": "test-data-path", + } + ] + }) + test_client_metrics = [ + {"initialized": "2022-02-02 02:02:02"}, + {"initialized": "2022-02-02 02:02:02", "rounds": []}, + {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"}, + ] + mock_get_from_redis.side_effect = test_client_metrics + mock_subscriber = Mock() + mock_subscriber.listen.return_value = [ + {"type": "message"}, + {"type": "not message"}, + {"type": "message"}, + {"type": "message"}, + {"type": "message"}, + ] + mock_get_subscriber.return_value = mock_subscriber + mock_database = Mock() + + with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: + with patch.object(Job, "set_client_metrics", Mock()) as mock_set_client_metrics: + # Act + client_training_listener(test_job, test_job.clients_info[0], mock_database) + + # Assert + assert mock_set_client_metrics.call_count == 3 + mock_set_client_metrics.assert_has_calls([ + call(test_client_uuid, test_client_metrics[0], mock_database), + call(test_client_uuid, test_client_metrics[1], mock_database), + call(test_client_uuid, test_client_metrics[2], mock_database), + ]) + + assert mock_get_from_redis.call_count == 3 + mock_get_subscriber.assert_called_once_with( + test_job.clients_info[0].uuid, + test_job.clients_info[0].redis_host, + test_job.clients_info[0].redis_port, + ) + + +@patch("florist.api.routes.server.training.get_from_redis") +def test_client_training_listener_already_finished(mock_get_from_redis: Mock) -> None: + # Setup + test_client_uuid = "test-client-uuid"; + test_job = Job(**{ + "clients_info": [ + { + "service_address": "test-service-address", + "uuid": test_client_uuid, + "redis_host": "test-client-redis-host", + "redis_port": "test-client-redis-port", + "client": "MNIST", + "data_path": "test-data-path", + } + ] + }) + test_client_final_metrics = {"initialized": "2022-02-02 02:02:02", "rounds": [], "shutdown": "2022-02-02 03:03:03"} + mock_get_from_redis.side_effect = [test_client_final_metrics] + mock_database = Mock() + + with patch.object(Job, "set_status_sync", Mock()) as mock_set_status_sync: + with patch.object(Job, "set_client_metrics", Mock()) as mock_set_client_metrics: + # Act + client_training_listener(test_job, test_job.clients_info[0], mock_database) + + # Assert + mock_set_client_metrics.assert_called_once_with(test_client_uuid, test_client_final_metrics, mock_database) + + assert mock_get_from_redis.call_count == 1 + + +def test_client_training_listener_fail_no_uuid() -> None: + test_job = Job(**{ + "clients_info": [ + { + "redis_host": "test-redis-host", + "redis_port": "test-redis-port", + "service_address": "test-service-address", + "client": "MNIST", + "data_path": "test-data-path", + }, + ], + }) + + with raises(AssertionError, match="client_info.uuid is None."): + client_training_listener(test_job, test_job.clients_info[0], Mock()) + + def _setup_test_job_and_mocks() -> Tuple[Dict[str, Any], Dict[str, Any], Mock, Mock]: test_server_config = { "n_server_rounds": 2,