
[Part 1] Display logs for client and server #128

Open · wants to merge 11 commits into main
29 changes: 24 additions & 5 deletions florist/api/client.py
@@ -43,9 +43,14 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
:param redis_port: (str) the port for the Redis instance for metrics reporting.
:return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the
format below, which can be used to pull metrics from Redis.
{"uuid": <client uuid>}
{
"uuid": (str) The client's uuid, which can be used to pull metrics from Redis,
"log_file_path": (str) The local path of the log file for this client,
}
If not successful, returns the appropriate error code with a JSON with the format below:
{"error": <error message>}
{
"error": (str) The error message,
}
"""
try:
if client not in Client.list():
@@ -65,10 +70,10 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
reporters=[metrics_reporter],
)

log_file_name = str(get_client_log_file_path(client_uuid))
launch_client(client_obj, server_address, log_file_name)
log_file_path = str(get_client_log_file_path(client_uuid))
launch_client(client_obj, server_address, log_file_path)

return JSONResponse({"uuid": client_uuid})
return JSONResponse({"uuid": client_uuid, "log_file_path": log_file_path})

except Exception as ex:
return JSONResponse({"error": str(ex)}, status_code=500)
@@ -98,3 +103,17 @@ def check_status(client_uuid: str, redis_host: str, redis_port: str) -> JSONResp
except Exception as ex:
LOGGER.exception(ex)
return JSONResponse({"error": str(ex)}, status_code=500)


@app.get("/api/client/get_log")
def get_log(log_file_path: str) -> JSONResponse:
"""
Return the contents of the log file under the given path.

:param log_file_path: (str) the path of the log file.

:return: (JSONResponse) Returns the contents of the file as a string.
"""
with open(log_file_path, "r") as f:
content = f.read()
return JSONResponse(content)
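
As a quick reference, here is a minimal sketch of how a consumer might call the new get_log endpoint once a client has been started. The route path and the log_file_path query parameter come from this diff; the service address and the example path are placeholders.

import requests

# Hypothetical address of a running florist client service.
CLIENT_SERVICE = "http://localhost:8001"

# The start endpoint now also returns "log_file_path"; once obtained, the new
# /api/client/get_log endpoint returns the raw contents of that file as a string.
log_file_path = "/Users/foo/client/logfile.log"  # e.g. taken from the start response
response = requests.get(
    f"{CLIENT_SERVICE}/api/client/get_log",
    params={"log_file_path": log_file_path},
)
print(response.json())  # the log contents as a string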
44 changes: 43 additions & 1 deletion florist/api/db/entities.py
@@ -48,6 +48,7 @@ class ClientInfo(BaseModel):
redis_port: str = Field(...)
uuid: Optional[Annotated[str, Field(...)]]
metrics: Optional[Annotated[str, Field(...)]]
log_file_path: Optional[Annotated[str, Field(...)]]

class Config:
"""MongoDB config for the ClientInfo DB entity."""
@@ -62,6 +63,7 @@ class Config:
"redis_port": "6380",
"uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f",
"metrics": '{"host_type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}',
"log_file_path": "/Users/foo/client/logfile.log",
},
}

@@ -77,6 +79,7 @@ class Job(BaseModel):
config_parser: Optional[Annotated[ConfigParser, Field(...)]]
server_uuid: Optional[Annotated[str, Field(...)]]
server_metrics: Optional[Annotated[str, Field(...)]]
server_log_file_path: Optional[Annotated[str, Field(...)]]
redis_host: Optional[Annotated[str, Field(...)]]
redis_port: Optional[Annotated[str, Field(...)]]
clients_info: Optional[Annotated[List[ClientInfo], Field(...)]]
@@ -129,7 +132,7 @@ async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: A
"""
Save the server and clients' UUIDs in the database under the current job's id.

:param server_uuid: [str] the server_uuid to be saved in the database.
:param server_uuid: (str) the server_uuid to be saved in the database.
:param client_uuids: List[str] the list of client_uuids to be saved in the database.
:param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
"""
@@ -211,6 +214,44 @@ async def set_client_metrics(
)
assert_updated_successfully(update_result)

async def set_server_log_file_path(self, log_file_path: str, database: AsyncIOMotorDatabase[Any]) -> None:
"""
Save the server's log file path in the database under the current job's id.

:param log_file_path: (str) the file path to be saved in the database.
:param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
"""
job_collection = database[JOB_COLLECTION_NAME]
self.server_log_file_path = log_file_path
update_result = await job_collection.update_one(
{"_id": self.id}, {"$set": {"server_log_file_path": log_file_path}}
)
assert_updated_successfully(update_result)

async def set_client_log_file_path(
self,
client_index: int,
log_file_path: str,
database: AsyncIOMotorDatabase[Any],
) -> None:
"""
Save the client's log file path in the database under the given client index and current job's id.

:param client_index: (int) the index of the client in the job.
:param log_file_path: (str) the path of the client's log file.
:param database: (motor.motor_asyncio.AsyncIOMotorDatabase) The database where the job collection is stored.
"""
assert self.clients_info is not None, "Job has no clients."
assert (
0 <= client_index < len(self.clients_info)
), f"Client index {client_index} is invalid (total: {len(self.clients_info)})"

job_collection = database[JOB_COLLECTION_NAME]
update_result = await job_collection.update_one(
{"_id": self.id}, {"$set": {f"clients_info.{client_index}.log_file_path": log_file_path}}
)
assert_updated_successfully(update_result)

class Config:
"""MongoDB config for the Job DB entity."""

@@ -224,6 +265,7 @@ class Config:
"server_config": '{"n_server_rounds": 3, "batch_size": 8, "local_epochs": 1}',
"server_uuid": "d73243cf-8b89-473b-9607-8cd0253a101d",
"server_metrics": '{"host_type": "server", "fit_start": "2024-04-23 15:33:12.865604", "rounds": {"1": {"fit_start": "2024-04-23 15:33:12.869001"}}}',
"server_log_file_path": "/Users/foo/server/logfile.log",
"redis_host": "localhost",
"redis_port": "6379",
"clients_info": [
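For context, here is a minimal sketch of how the two new Job methods might be called, mirroring the calls made in training.py further down. The database handle and the paths are placeholders, and the wrapper function is hypothetical.

from florist.api.db.entities import Job

# Sketch only: persist log file paths on an existing Job document.
# `database` is assumed to be the same AsyncIOMotorDatabase handle the routes use.
async def save_log_paths(job_id: str, database) -> None:
    job = await Job.find_by_id(job_id, database)
    assert job is not None, f"Job {job_id} not found"

    # Store the server's log file path under this job.
    await job.set_server_log_file_path("/Users/foo/server/logfile.log", database)

    # Store the log file path of the client at index 0 in clients_info.
    await job.set_client_log_file_path(0, "/Users/foo/client/logfile.log", database)
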
71 changes: 71 additions & 0 deletions florist/api/routes/server/job.py
@@ -2,6 +2,7 @@

from typing import List, Union

import requests
from fastapi import APIRouter, Body, Request, status
from fastapi.responses import JSONResponse

@@ -97,5 +98,75 @@
return JSONResponse(content={"status": "success"})
except AssertionError as assertion_e:
return JSONResponse(content={"error": str(assertion_e)}, status_code=400)
except Exception as general_e:

[Code scanning / CodeQL warning: Information exposure through an exception (Medium). Stack trace information flows to this location and may be exposed to an external user.]
return JSONResponse(content={"error": str(general_e)}, status_code=500)


@router.get("/get_server_log/{job_id}")
async def get_server_log(job_id: str, request: Request) -> JSONResponse:
"""
Return the contents of the server's log file for the given job id.

:param job_id: (str) the ID of the job to get the server logs for.
:param request: (fastapi.Request) the FastAPI request object.

:return: (JSONResponse) if successful, returns the contents of the file as a string.
If not successful, returns the appropriate error code with a JSON with the format below:
{"error": <error message>}
"""
try:
job = await Job.find_by_id(job_id, request.app.database)

assert job is not None, f"Job {job_id} not found"
assert (
job.server_log_file_path is not None and job.server_log_file_path != ""
), "Log file path is None or empty"

with open(job.server_log_file_path, "r") as f:
content = f.read()
return JSONResponse(content)

except AssertionError as assertion_e:
return JSONResponse(content={"error": str(assertion_e)}, status_code=400)
except Exception as general_e:
return JSONResponse(content={"error": str(general_e)}, status_code=500)


@router.get("/get_client_log/{job_id}/{client_index}")
async def get_client_log(job_id: str, client_index: int, request: Request) -> JSONResponse:
"""
Return the contents of the log file for the client with the given index under the given job id.

:param job_id: (str) the ID of the job to get the client logs for.
:param client_index: (int) the index of the client within the job.
:param request: (fastapi.Request) the FastAPI request object.

:return: (JSONResponse) if successful, returns the contents of the file as a string.
If not successful, returns the appropriate error code with a JSON with the format below:
{"error": <error message>}
"""
try:
job = await Job.find_by_id(job_id, request.app.database)

assert job is not None, f"Job {job_id} not found"
assert job.clients_info is not None, "Job has no clients."
assert (
0 <= client_index < len(job.clients_info)
), f"Client index {client_index} is invalid (total: {len(job.clients_info)})"

client_info = job.clients_info[client_index]
Review thread:

Collaborator: This is almost certainly a naive question, but is there a reason we are relying on the server to tell the clients where their log files are? In my head, rather than passing the location of the log to the client, the client will have stored that path already and you would just hit it with a get call to tell it that you want its log file. There's likely something more complicated there that I'm not seeing though.

Collaborator: I sort of see where the complexity is coming from. Basically you want the server to set the log location for each client. Perhaps that's something we can communicate on startup and then don't have to worry about the server managing that from there on out?

Collaborator (author): The thing here is we need to store that log path somewhere. The server already has a connection to a database, along with all the other information from the clients, so I thought it would be easier to store it there.

On startup, the client passes the location of its logs to the server, which in turn saves it in the database. The log location only lives in memory for the client and is lost after the Flower client starts up, so the server is responsible for permanently storing it for later retrieval.

Let me know if that makes sense.

assert (
client_info.log_file_path is not None and client_info.log_file_path != ""
), "Log file path is None or empty"

response = requests.get(
url=f"http://{client_info.service_address}/api/client/get_log",
params={"log_file_path": client_info.log_file_path},
)
json_response = response.json()
return JSONResponse(json_response)

except AssertionError as assertion_e:
return JSONResponse(content={"error": str(assertion_e)}, status_code=400)
except Exception as general_e:
return JSONResponse(content={"error": str(general_e)}, status_code=500)
8 changes: 6 additions & 2 deletions florist/api/routes/server/training.py
@@ -72,19 +72,21 @@ async def start(job_id: str, request: Request) -> JSONResponse:
model_class = Model.class_for_model(job.model)

# Start the server
server_uuid, _ = launch_local_server(
server_uuid, _, server_log_file_path = launch_local_server(
model=model_class(),
n_clients=len(job.clients_info),
server_address=job.server_address,
redis_host=job.redis_host,
redis_port=job.redis_port,
**server_config,
)
await job.set_server_log_file_path(server_log_file_path, request.app.database)
wait_for_metric(server_uuid, "fit_start", job.redis_host, job.redis_port, logger=LOGGER)

# Start the clients
client_uuids: List[str] = []
for client_info in job.clients_info:
for i in range(len(job.clients_info)):
client_info = job.clients_info[i]
parameters = {
"server_address": job.server_address,
"client": client_info.client.value,
@@ -104,6 +106,8 @@

client_uuids.append(json_response["uuid"])

await job.set_client_log_file_path(i, json_response["log_file_path"], request.app.database)

await job.set_uuids(server_uuid, client_uuids, request.app.database)

# Start the server training listener and client training listeners as threads to update
14 changes: 8 additions & 6 deletions florist/api/servers/launch.py
@@ -22,7 +22,7 @@ def launch_local_server(
local_epochs: int,
redis_host: str,
redis_port: str,
) -> Tuple[str, Process]:
) -> Tuple[str, Process, str]:
"""
Launch a FL server locally.

Expand All @@ -34,8 +34,10 @@ def launch_local_server(
:param local_epochs: (int) The number of epochs to run by the clients.
:param redis_host: (str) the host name for the Redis instance for metrics reporting.
:param redis_port: (str) the port for the Redis instance for metrics reporting.
:return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull
metrics from Redis, along with its local process object.
:return: (Tuple[str, multiprocessing.Process, str]) a tuple with
- The UUID of the server, which can be used to pull metrics from Redis
- The server's local process object
- The local path for the log file
"""
server_uuid = str(uuid.uuid4())

@@ -49,13 +51,13 @@
local_epochs=local_epochs,
)

log_file_name = str(get_server_log_file_path(server_uuid))
log_file_path = str(get_server_log_file_path(server_uuid))
server_process = launch_server(
server_constructor,
server_address,
n_server_rounds,
log_file_name,
log_file_path,
seconds_to_sleep=0,
)

return server_uuid, server_process
return server_uuid, server_process, log_file_path
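
Callers now unpack a 3-tuple instead of a 2-tuple, as training.py does above. Below is a minimal sketch of the updated call, assuming the remaining keyword arguments match the call in training.py and the server config example in entities.py; the model argument is left as a placeholder parameter and the wrapper function is hypothetical.

from florist.api.servers.launch import launch_local_server

# Sketch only: launch_local_server now also returns the path of the server's log file.
def start_server(model) -> str:
    server_uuid, server_process, log_file_path = launch_local_server(
        model=model,  # a model instance, as constructed in training.py
        n_clients=2,
        server_address="localhost:8080",
        n_server_rounds=3,
        batch_size=8,
        local_epochs=1,
        redis_host="localhost",
        redis_port="6379",
    )
    # server_uuid can be used to pull metrics from Redis; log_file_path points at the server log.
    return log_file_path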
32 changes: 32 additions & 0 deletions florist/tests/integration/api/db/test_entities.py
@@ -215,6 +215,38 @@ async def test_set_client_metrics_fail_update_result(mock_request) -> None:
)


async def test_set_server_log_file_path_success(mock_request) -> None:
test_job = get_test_job()
result_id = await test_job.create(mock_request.app.database)
test_job.id = result_id
test_job.clients_info[0].id = ANY
test_job.clients_info[1].id = ANY

test_log_file_path = "test/log/file/path.log"

await test_job.set_server_log_file_path(test_log_file_path, mock_request.app.database)

result_job = await Job.find_by_id(result_id, mock_request.app.database)
test_job.server_log_file_path = test_log_file_path
assert result_job == test_job


async def test_set_client_log_file_path_success(mock_request) -> None:
test_job = get_test_job()
result_id = await test_job.create(mock_request.app.database)
test_job.id = result_id
test_job.clients_info[0].id = ANY
test_job.clients_info[1].id = ANY

test_log_file_path = "test/log/file/path.log"

await test_job.set_client_log_file_path(1, test_log_file_path, mock_request.app.database)

result_job = await Job.find_by_id(result_id, mock_request.app.database)
test_job.clients_info[1].log_file_path = test_log_file_path
assert result_job == test_job


def get_test_job() -> Job:
test_server_config = {
"n_server_rounds": 2,
2 changes: 1 addition & 1 deletion florist/tests/integration/api/launchers/test_launch.py
@@ -11,7 +11,7 @@
from florist.api.servers.utils import get_server


def assert_string_in_file(file_path: str, search_string: str) -> bool:
def assert_string_in_file(file_path: str, search_string: str) -> None:
with open(file_path, "r") as f:
file_contents = f.read()
match = re.search(search_string, file_contents)