Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly listen to client updates #106

Merged
merged 9 commits into from
Nov 1, 2024
43 changes: 28 additions & 15 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,36 +176,49 @@ def set_status_sync(self, status: JobStatus, database: Database[Dict[str, Any]])
update_result = job_collection.update_one({"_id": self.id}, {"$set": {"status": status.value}})
assert_updated_successfully(update_result)

def set_metrics(
def set_server_metrics(
self,
server_metrics: Dict[str, Any],
client_metrics: List[Dict[str, Any]],
database: Database[Dict[str, Any]],
) -> None:
"""
Sync function to save the server and clients' metrics in the database under the current job's id.
Sync function to save the server's metrics in the database under the current job's id.

:param server_metrics: (Dict[str, Any]) the server metrics to be saved.
:param client_metrics: (List[Dict[str, Any]]) the clients metrics to be saved.
:param database: (pymongo.database.Database) The database where the job collection is stored.
"""
assert self.clients_info is not None and len(self.clients_info) == len(client_metrics), (
"self.clients_info and client_metrics must have the same length "
f"({'None' if self.clients_info is None else len(self.clients_info)}!={len(client_metrics)})."
)

job_collection = database[JOB_COLLECTION_NAME]

self.server_metrics = json.dumps(server_metrics)
update_result = job_collection.update_one({"_id": self.id}, {"$set": {"server_metrics": self.server_metrics}})
assert_updated_successfully(update_result)

for i in range(len(client_metrics)):
self.clients_info[i].metrics = json.dumps(client_metrics[i])
update_result = job_collection.update_one(
{"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}}
)
assert_updated_successfully(update_result)
def set_client_metrics(
self,
client_uuid: str,
client_metrics: Dict[str, Any],
database: Database[Dict[str, Any]],
) -> None:
"""
Sync function to save a clients' metrics in the database under the current job's id.

:param client_uuid: (str) the client's uuid whose produced the metrics.
:param client_metrics: (Dict[str, Any]) the client's metrics to be saved.
:param database: (pymongo.database.Database) The database where the job collection is stored.
"""
assert (
self.clients_info is not None and client_uuid in [c.uuid for c in self.clients_info]
), f"client uuid {client_uuid} is not in clients_info ({[c.uuid for c in self.clients_info] if self.clients_info is not None else None})"

job_collection = database[JOB_COLLECTION_NAME]

for i in range(len(self.clients_info)):
if client_uuid == self.clients_info[i].uuid:
self.clients_info[i].metrics = json.dumps(client_metrics)
update_result = job_collection.update_one(
{"_id": self.id}, {"$set": {f"clients_info.{i}.metrics": self.clients_info[i].metrics}}
)
assert_updated_successfully(update_result)

class Config:
"""MongoDB config for the Job DB entity."""
Expand Down
109 changes: 57 additions & 52 deletions florist/api/routes/server/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi.responses import JSONResponse
from pymongo.database import Database

from florist.api.db.entities import Job, JobStatus
from florist.api.db.entities import ClientInfo, Job, JobStatus
from florist.api.monitoring.metrics import get_from_redis, get_subscriber, wait_for_metric
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser
Expand Down Expand Up @@ -106,8 +106,10 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks
await job.set_uuids(server_uuid, client_uuids, request.app.database)

# Start the server training listener as a background task to update
# the job's status once the training is done
# the job's metrics and status once the training is done
background_tasks.add_task(server_training_listener, job, request.app.synchronous_database)
for client_info in job.clients_info:
background_tasks.add_task(client_training_listener, job, client_info, request.app.synchronous_database)

# Return the UUIDs
return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})
Expand All @@ -124,6 +126,47 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks
return JSONResponse({"error": str(ex)}, status_code=500)


def client_training_listener(job: Job, client_info: ClientInfo, database: Database[Dict[str, Any]]) -> None:
"""
Listen to the Redis' channel that reports updates on the training process of a FL client.

Keeps consuming updates to the channel until it finds `shutdown` in the client metrics.

:param job: (Job) The job that has this client's metrics.
:param client_info: (ClientInfo) The ClientInfo with the client_uuid to listen to.
:param database: (pymongo.database.Database) An instance of the database to save the information
into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
because of limitations with FastAPI's BackgroundTasks.
"""
LOGGER.info(f"Starting listener for client messages from job {job.id} at channel {client_info.uuid}")

assert client_info.uuid is not None, "client_info.uuid is None."

# check if training has already finished before start listening
client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
if client_metrics is not None:
LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
job.set_client_metrics(client_info.uuid, client_metrics, database)
LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.")
if "shutdown" in client_metrics:
return

subscriber = get_subscriber(client_info.uuid, client_info.redis_host, client_info.redis_port)
# TODO add a max retries mechanism, maybe?
for message in subscriber.listen(): # type: ignore[no-untyped-call]
if message["type"] == "message":
# The contents of the message do not matter, we just use it to get notified
client_metrics = get_from_redis(client_info.uuid, client_info.redis_host, client_info.redis_port)
LOGGER.debug(f"Listener: Current metrics for client {client_info.uuid}: {client_metrics}")
if client_metrics is not None:
LOGGER.info(f"Listener: Updating client metrics for client {client_info.uuid} on job {job.id}")
job.set_client_metrics(client_info.uuid, client_metrics, database)
LOGGER.info(f"Listener: Client metrics for client {client_info.uuid} on {job.id} has been updated.")
if "shutdown" in client_metrics:
return


def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> None:
"""
Listen to the Redis' channel that reports updates on the training process of a FL server.
Expand All @@ -147,9 +190,13 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No
server_metrics = get_from_redis(job.server_uuid, job.redis_host, job.redis_port)
LOGGER.debug(f"Listener: Current metrics for job {job.id}: {server_metrics}")
if server_metrics is not None:
update_job_metrics(job, server_metrics, database)
LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
job.set_server_metrics(server_metrics, database)
LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.")
if "fit_end" in server_metrics:
close_job(job, database)
LOGGER.info(f"Listener: Training finished for job {job.id}")
job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
return

subscriber = get_subscriber(job.server_uuid, job.redis_host, job.redis_port)
Expand All @@ -161,53 +208,11 @@ def server_training_listener(job: Job, database: Database[Dict[str, Any]]) -> No
LOGGER.debug(f"Listener: Message received for job {job.id}. Metrics: {server_metrics}")

if server_metrics is not None:
update_job_metrics(job, server_metrics, database)
LOGGER.info(f"Listener: Updating server metrics for job {job.id}")
job.set_server_metrics(server_metrics, database)
LOGGER.info(f"Listener: Server metrics for {job.id} has been updated.")
if "fit_end" in server_metrics:
close_job(job, database)
LOGGER.info(f"Listener: Training finished for job {job.id}")
job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
return


def update_job_metrics(job: Job, server_metrics: Dict[str, Any], database: Database[Dict[str, Any]]) -> None:
"""
Update the job with server and client metrics.

Collect the job's clients metrics, saving them and the server's metrics to the job.

:param job: (Job) The job to be updated.
:param server_metrics: (Dict[str, Any]) The server's metrics to be saved into the job.
:param database: (pymongo.database.Database) An instance of the database to save the information
into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
because of limitations with FastAPI's BackgroundTasks.
"""
LOGGER.info(f"Listener: Updating metrics for job {job.id}")

clients_metrics: List[Dict[str, Any]] = []
if job.clients_info is not None:
for client_info in job.clients_info:
response = requests.get(
url=f"http://{client_info.service_address}/{CHECK_CLIENT_STATUS_API}/{client_info.uuid}",
params={
"redis_host": client_info.redis_host,
"redis_port": client_info.redis_port,
},
)
client_metrics = response.json()
clients_metrics.append(client_metrics)

job.set_metrics(server_metrics, clients_metrics, database)

LOGGER.info(f"Listener: Job {job.id} has been updated.")


def close_job(job: Job, database: Database[Dict[str, Any]]) -> None:
"""
Close the job by marking its status as FINISHED_SUCCESSFULLY.

:param job: (Job) The job to be closed.
:param database: (pymongo.database.Database) An instance of the database to save the information
into the Job. MUST BE A SYNCHRONOUS DATABASE since this function cannot be marked as async
because of limitations with FastAPI's BackgroundTasks.
"""
LOGGER.info(f"Listener: Training finished for job {job.id}")
job.set_status_sync(JobStatus.FINISHED_SUCCESSFULLY, database)
LOGGER.info(f"Listener: Job {job.id} status has been set to {job.status.value}.")
53 changes: 32 additions & 21 deletions florist/tests/integration/api/db/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,64 +169,75 @@ async def test_set_status_sync_fail_update_result(mock_request) -> None:
test_job.set_status_sync(JobStatus.IN_PROGRESS, mock_request.app.synchronous_database)


async def test_set_metrics_success(mock_request) -> None:
async def test_set_server_metrics_success(mock_request) -> None:
test_job = get_test_job()
result_id = await test_job.create(mock_request.app.database)
test_job.id = result_id
test_job.clients_info[0].id = ANY
test_job.clients_info[1].id = ANY

test_server_metrics = {"test-server": 123}
test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}]

test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database)
test_job.set_server_metrics(test_server_metrics, mock_request.app.synchronous_database)

result_job = await Job.find_by_id(result_id, mock_request.app.database)
test_job.server_metrics = json.dumps(test_server_metrics)
test_job.clients_info[0].metrics = json.dumps(test_client_metrics[0])
test_job.clients_info[1].metrics = json.dumps(test_client_metrics[1])
assert result_job == test_job


async def test_set_metrics_fail_clients_info_is_none(mock_request) -> None:
async def test_set_server_metrics_fail_update_result(mock_request) -> None:
test_job = get_test_job()
test_job.clients_info = None
result_id = await test_job.create(mock_request.app.database)
test_job.id = result_id
test_job.id = str(test_job.id)

test_server_metrics = {"test-server": 123}
test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}]

error_msg = "self.clients_info and client_metrics must have the same length (None!=2)."
error_msg = "UpdateResult's 'n' is not 1"
with raises(AssertionError, match=re.escape(error_msg)):
test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database)
test_job.set_server_metrics(test_server_metrics, mock_request.app.synchronous_database)


async def test_set_metrics_fail_clients_info_is_not_same_length(mock_request) -> None:
async def test_set_client_metrics_success(mock_request) -> None:
test_job = get_test_job()
result_id = await test_job.create(mock_request.app.database)
test_job.id = result_id
test_job.clients_info[0].id = ANY
test_job.clients_info[1].id = ANY

test_server_metrics = {"test-server": 123}
test_client_metrics = [{"test-client-1": 456}]
test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}]

test_job.set_client_metrics(test_job.clients_info[1].uuid, test_client_metrics, mock_request.app.synchronous_database)

error_msg = "self.clients_info and client_metrics must have the same length (2!=1)."
result_job = await Job.find_by_id(result_id, mock_request.app.database)
test_job.clients_info[1].metrics = json.dumps(test_client_metrics)
assert result_job == test_job


async def test_set_metrics_fail_clients_info_is_none(mock_request) -> None:
test_job = get_test_job()
result_id = await test_job.create(mock_request.app.database)
test_job.id = result_id

test_wrong_client_uuid = "client-id-that-does-not-exist"
test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}]

error_msg = f"client uuid {test_wrong_client_uuid} is not in clients_info (['{test_job.clients_info[0].uuid}', '{test_job.clients_info[1].uuid}'])"
with raises(AssertionError, match=re.escape(error_msg)):
test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database)
test_job.set_client_metrics(test_wrong_client_uuid, test_client_metrics, mock_request.app.synchronous_database)


async def test_set_metrics_fail_update_result(mock_request) -> None:
async def test_set_client_metrics_fail_update_result(mock_request) -> None:
test_job = get_test_job()
test_job.id = str(test_job.id)

test_server_metrics = {"test-server": 123}
test_client_metrics = [{"test-client-1": 456}, {"test-client-2": 789}]
test_client_metrics = [{"test-metric-1": 456}, {"test-metric-2": 789}]

error_msg = "UpdateResult's 'n' is not 1"
with raises(AssertionError, match=re.escape(error_msg)):
test_job.set_metrics(test_server_metrics, test_client_metrics, mock_request.app.synchronous_database)
test_job.set_client_metrics(
test_job.clients_info[0].uuid,
test_client_metrics,
mock_request.app.synchronous_database,
)


def get_test_job() -> Job:
Expand Down
Loading