Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/github_actions/actions/checkout-4…
Browse files Browse the repository at this point in the history
….2.2
  • Loading branch information
lotif authored Nov 1, 2024
2 parents 0b05bce + 2046e5c commit 92a405e
Show file tree
Hide file tree
Showing 8 changed files with 581 additions and 228 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ on:
- .github/workflows/integration_tests.yaml
- '**.py'
- '**.ipynb'
- '**.tsx'
- poetry.lock
- pyproject.toml
- '**.rst'
Expand All @@ -29,6 +30,7 @@ on:
- .github/workflows/integration_tests.yaml
- '**.py'
- '**.ipynb'
- '**.tsx'
- poetry.lock
- pyproject.toml
- '**.rst'
Expand Down
43 changes: 28 additions & 15 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,36 +176,49 @@ def set_status_sync(self, status: JobStatus, database: Database[Dict[str, Any]])
update_result = job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}})
assert_updated_successfully(update_result)

def set_metrics(
def set_server_metrics(
self,
server_metrics: Dict[str, Any],
client_metrics: List[Dict[str, Any]],
database: Database[Dict[str, Any]],
) -> None:
"""
Sync function to save the server and clients' metrics in the database under the current job's id.
Sync function to save the server's metrics in the database under the current job's id.
:param server_metrics: (Dict[str, Any]) the server metrics to be saved.
:param client_metrics: (List[Dict[str, Any]]) the clients metrics to be saved.
:param database: (pymongo.database.Database) The database where the job collection is stored.
"""
assert self.clients_info is not None and len(self.clients_info) == len(client_metrics), (
"self.clients_info and client_metrics must have the same length "
f"({'None' if self.clients_info is None else len(self.clients_info)}!={len(client_metrics)})."
)

job_collection = database[JOB_COLLECTION_NAME]

self.server_metrics = json.dumps(server_metrics)
update_result = job_collection.update_one({"_id": self.id}, {"$set": {"server_metrics": self.server_metrics}})
assert_updated_successfully(update_result)

for i in range(len(client_metrics)):
self.clients_info[i].metrics = json.dumps(client_metrics[i])
update_result = job_collection.update_one(
{"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}}
)
assert_updated_successfully(update_result)
def set_client_metrics(
self,
client_uuid: str,
client_metrics: Dict[str, Any],
database: Database[Dict[str, Any]],
) -> None:
"""
Sync function to save a clients' metrics in the database under the current job's id.
:param client_uuid: (str) the client's uuid whose produced the metrics.
:param client_metrics: (Dict[str, Any]) the client's metrics to be saved.
:param database: (pymongo.database.Database) The database where the job collection is stored.
"""
assert (
self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info]
), f"client uuid {client_uuid} is not in clients_info ({[c.uuid for c in self.clients_info] if self.clients_info is not None else None})"

job_collection = database[JOB_COLLECTION_NAME]

for i in range(len(self.clients_info)):
if client_uuid == self.clients_info[i].uuid:
self.clients_info[i].metrics = json.dumps(client_metrics)
update_result = job_collection.update_one(
{"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}}
)
assert_updated_successfully(update_result)

class Config:
"""MongoDB config for the Job DB entity."""
Expand Down
109 changes: 57 additions & 52 deletions florist/api/routes/server/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi.responses import JSONResponse
from pymongo.database import Database

from florist.api.db.entities import Job, JobStatus
from florist.api.db.entities import ClientInfo, Job, JobStatus
from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser
Expand Down Expand Up @@ -106,8 +106,10 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks
await job.set_uuids(server_uuid, client_uuids, request.app.database)

# Start the server training listener as a background task to update
# the job's status once the training is done
# the job's metrics and status once the training is done
background_tasks.add_task(server_training_listener, job, request.app.synchronous_database)
for client_info in job.clients_info:
background_tasks.add_task(client_training_listener, job, client_info, request.app.synchronous_database)

# Return the UUIDs
return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})
Expand All @@ -124,6 +126,47 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks
return JSONResponse({"error": str(ex)}, status_code=500)


def client_training_listener(job: Job, client_info: ClientInfo, database: Database[Dict[str, Any]]) -> None:
"""
Listen to the Redis' channel that reports updates on the training process of a FL client.
Keeps consuming updates to the channel until it finds `shutdown` in the client metrics.
:param job: (Job) The job that has this client's metrics.
:param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to.
:param database: (pymongo.database.Database) An instance of the database to save the information
into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
because of limitations with FastAPI's BackgroundTasks.
"""
LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}")

assert client_info.uuid is not None, "client_info.uuid is None."

# check if training has already finished before start listening
client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
if client_metrics is not None:
LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
job.set_client_metrics(client_info.uuid, client_metrics, database)
LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.")
if "shutdown" in client_metrics:
return

subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port)
# TODO add a max retries mechanism, maybe?
for message in subscriber.listen(): # type: ignore[no-untyped-call]
if message["type"] == "message":
# The contents of the message do not matter, we just use it to get notified
client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
if client_metrics is not None:
LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
job.set_client_metrics(client_info.uuid, client_metrics, database)
LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.")
if "shutdown" in client_metrics:
return


def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> None:
"""
Listen to the Redis' channel that reports updates on the training process of a FL server.
Expand All @@ -147,9 +190,13 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No
server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
LOGGER.debug(f"Listener: Current metrics for job {job.id}: {server_metrics}")
if server_metrics is not None:
update_job_metrics(job, server_metrics, database)
LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
job.set_server_metrics(server_metrics, database)
LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.")
if "fit_end" in server_metrics:
close_job(job, database)
LOGGER.info(f"Listener: Training finished for job {job.id}")
job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
return

subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port)
Expand All @@ -161,53 +208,11 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No
LOGGER.debug(f"Listener: Message received for job {job.id}. Metrics: {server_metrics}")

if server_metrics is not None:
update_job_metrics(job, server_metrics, database)
LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
job.set_server_metrics(server_metrics, database)
LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.")
if "fit_end" in server_metrics:
close_job(job, database)
LOGGER.info(f"Listener: Training finished for job {job.id}")
job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
return


def update_job_metrics(job: Job, server_metrics: Dict[str, Any], database: Database[Dict[str, Any]]) -> None:
"""
Update the job with server and client metrics.
Collect the job's clients metrics, saving them and the server's metrics to the job.
:param job: (Job) The job to be updated.
:param server_metrics: (Dict[str, Any]) The server's metrics to be saved into the job.
:param database: (pymongo.database.Database) An instance of the database to save the information
into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
because of limitations with FastAPI's BackgroundTasks.
"""
LOGGER.info(f"Listener: Updating metrics for job {job.id}")

clients_metrics: List[Dict[str, Any]] = []
if job.clients_info is not None:
for client_info in job.clients_info:
response = requests.get(
url=f"http://{client_info.service_address}/{CHECK_CLIENT_STATUS_API}/{client_info.uuid}",
params={
"redis_host": client_info.redis_host,
"redis_port": client_info.redis_port,
},
)
client_metrics = response.json()
clients_metrics.append(client_metrics)

job.set_metrics(server_metrics, clients_metrics, database)

LOGGER.info(f"Listener: Job {job.id} has been updated.")


def close_job(job: Job, database: Database[Dict[str, Any]]) -> None:
"""
Close the job by marking its status as FINISHED_SUCCESSFULLY.
:param job: (Job) The job to be closed.
:param database: (pymongo.database.Database) An instance of the database to save the information
into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
because of limitations with FastAPI's BackgroundTasks.
"""
LOGGER.info(f"Listener: Training finished for job {job.id}")
job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
66 changes: 63 additions & 3 deletions florist/app/assets/css/florist.css
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@
color: white;
}

#job-progress .progress {
.job-progress-bar .progress {
padding: 0;
margin: 0 12px;
height: max-content;
}

#job-progress .progress-bar {
.job-progress-bar .progress-bar {
height: 25px;
}

#job-progress .progress-bar.bg-disabled {
.job-progress-bar .progress-bar.bg-disabled {
background-color: lightgray !important;
}

Expand All @@ -88,3 +88,63 @@
.job-round-details .col-sm-2 {
width: 20%;
}

.job-client-details {
border-bottom-style: hidden;
}

.job-client-progress-label div {
margin-top: -15px;
}

.job-client-progress-label.empty-cell {
padding: 0px;
}

.job-client-progress.empty-cell {
padding: 0px;
}

.job-client-progress .card {
flex-direction: row;
box-shadow: none;
margin: 0 !important;
}

.job-client-progress .card .row,
.job-client-progress .card .row .text-dark {
color: #7b809a !important;
font-weight: normal;
text-wrap: wrap;
}

.job-client-progress .card .card-body {
padding: 0;
}

.job-client-progress .card .card-header {
height: 0px;
width: 0px;
overflow: hidden;
padding: 0;
}

.job-client-progress .card .card-body .row .col-sm {
align-content: center;
}

.job-client-progress .job-progress-bar {
margin-left: -5px;
margin-bottom: 0px !important;
padding-left: 1rem !important;
}

.job-client-progress .job-progress-bar .progress {
width: 50%;
height: max-content;
margin: auto;
}

.job-client-progress .job-progress-bar .progress .progress-bar {
height: max-content;
}
Loading

0 comments on commit 92a405e

Please sign in to comment.