Skip to content

Commit

Permalink
CR by John
Browse files Browse the repository at this point in the history
  • Loading branch information
lotif committed Apr 9, 2024
1 parent 2368b08 commit e99ea7f
Show file tree
Hide file tree
Showing 15 changed files with 170 additions and 146 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,4 @@ next-env.d.ts

/metrics/
/logs/
/.ruff_cache/
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,10 @@ analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ru

Last but not least, we use type hints in our code, which are then checked using
[mypy](https://mypy.readthedocs.io/en/stable/).

## Documentation

Backend code API documentation can be found at https://vectorinstitute.github.io/FLorist/.

Backend REST API documentation can be found at http://localhost:8000/docs once the server
is running locally.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,8 @@ uvicorn florist.api.client:app --reload --port 8001
```

The service will be available at `http://localhost:8001`.

# Contributing
If you are interested in contributing to the library, please see [CONTRIBUTING.md](CONTRIBUTING.md).
This file contains many details around contributing to the code base, including development
practices, code checks, tests, and more.
4 changes: 2 additions & 2 deletions florist/api/monitoring/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def wait_for_metric(
Check metrics on Redis under the given UUID and wait until it appears.
If the metrics are not there yet, it will retry up to max_retries times,
sleeping and amount of seconds_to_sleep_between_retries between them.
sleeping an amount of `seconds_to_sleep_between_retries` between them.
:param uuid: (str) The UUID to pull the metrics from Redis.
:param metric: (str) The metric to look for.
Expand All @@ -90,7 +90,7 @@ def wait_for_metric(
:param max_retries: (int) The maximum number of retries. Optional, default is 20.
:param seconds_to_sleep_between_retries: (int) The amount of seconds to sleep between retries.
Optional, default is 1.
:raises Exception: If it retries MAX_RETRIES times and the right metrics have not been found.
:raises Exception: If it retries `max_retries` times and the right metrics have not been found.
"""
redis_connection = redis.Redis(host=redis_host, port=redis_port)

Expand Down
1 change: 1 addition & 0 deletions florist/api/routes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""FastAPI routes."""
1 change: 1 addition & 0 deletions florist/api/routes/server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""FastAPI server routes."""
119 changes: 119 additions & 0 deletions florist/api/routes/server/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""FastAPI routes for training."""
import logging
from typing import List

import requests
from fastapi import APIRouter, Form
from fastapi.responses import JSONResponse
from typing_extensions import Annotated

from florist.api.monitoring.metrics import wait_for_metric
from florist.api.servers.common import ClientInfo, ClientInfoParseError, Model
from florist.api.servers.launch import launch_local_server


# Router that collects the training endpoints defined in this module (e.g. the POST "/start" route below).
router = APIRouter()

# Logger named "uvicorn.error"; used below for debug output and exception traces.
LOGGER = logging.getLogger("uvicorn.error")

# Path, relative to each client's address, that is requested to start an FL client.
START_CLIENT_API = "api/client/start"


@router.post("/start")
def start(
    model: Annotated[str, Form()],
    server_address: Annotated[str, Form()],
    n_server_rounds: Annotated[int, Form()],
    batch_size: Annotated[int, Form()],
    local_epochs: Annotated[int, Form()],
    redis_host: Annotated[str, Form()],
    redis_port: Annotated[str, Form()],
    clients_info: Annotated[str, Form()],
) -> JSONResponse:
    """
    Start FL training by starting a FL server and its clients.

    Should be called with a POST request and the parameters should be contained in the request's form.

    :param model: (str) The name of the model to train. Should be one of the values in the enum
        florist.api.servers.common.Model
    :param server_address: (str) The address of the FL server to be started. It should be comprised of
        the host name and port separated by colon (e.g. "localhost:8080")
    :param n_server_rounds: (int) The number of rounds the FL server should run.
    :param batch_size: (int) The size of the batch for training.
    :param local_epochs: (int) The number of epochs to run by the clients.
    :param redis_host: (str) The host name for the Redis instance for metrics reporting.
    :param redis_port: (str) The port for the Redis instance for metrics reporting.
    :param clients_info: (str) A JSON string containing the client information. It will be parsed by
        florist.api.servers.common.ClientInfo and should be in the following format:
        [
            {
                "client": <client name as defined in florist.api.clients.common.Client>,
                "client_address": <Florist's client hostname and port, e.g. localhost:8081>,
                "data_path": <path where the data is located in the FL client's machine>,
                "redis_host": <hostname of the Redis instance the FL client will be reporting to>,
                "redis_port": <port of the Redis instance the FL client will be reporting to>,
            }
        ]
    :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and
        the clients in the format below. The UUIDs can be used to pull metrics from Redis.
            {
                "server_uuid": <server uuid>,
                "client_uuids": [<client_uuid_1>, <client_uuid_2>, ..., <client_uuid_n>],
            }
        If not successful, returns the appropriate error code with a JSON with the format below:
            {"error": <error message>}
    """
    try:
        # Parse input data
        if model not in Model.list():
            error_msg = f"Model '{model}' not supported. Supported models: {Model.list()}"
            return JSONResponse(content={"error": error_msg}, status_code=400)

        model_class = Model.class_for_model(Model[model])
        clients_info_list = ClientInfo.parse(clients_info)

        # Start the server
        server_uuid, _ = launch_local_server(
            model=model_class(),
            n_clients=len(clients_info_list),
            server_address=server_address,
            n_server_rounds=n_server_rounds,
            batch_size=batch_size,
            local_epochs=local_epochs,
            redis_host=redis_host,
            redis_port=redis_port,
        )
        # Block until the server reports the "fit_start" metric before contacting the clients.
        wait_for_metric(server_uuid, "fit_start", redis_host, redis_port, logger=LOGGER)

        # Start the clients
        client_uuids: List[str] = []
        for client_info in clients_info_list:
            parameters = {
                "server_address": server_address,
                "client": client_info.client.value,
                "data_path": client_info.data_path,
                "redis_host": client_info.redis_host,
                "redis_port": client_info.redis_port,
            }
            response = requests.get(url=f"http://{client_info.client_address}/{START_CLIENT_API}", params=parameters)

            try:
                json_response = response.json()
            except ValueError as ex:
                # A non-JSON body makes `response.json()` raise a ValueError subclass. Left alone, the
                # outer `except (ValueError, ...)` would report it as a 400 (caller's fault); re-raise
                # as a generic error so a misbehaving client is reported as a 500 instead.
                raise Exception(
                    f"Client response returned {response.status_code} with a non-JSON body. "
                    f"Response: {response.text}"
                ) from ex

            LOGGER.debug(f"Client response: {json_response}")

            if response.status_code != 200:
                raise Exception(f"Client response returned {response.status_code}. Response: {json_response}")

            if "uuid" not in json_response:
                raise Exception(f"Client response did not return a UUID. Response: {json_response}")

            client_uuids.append(json_response["uuid"])

        # Return the UUIDs
        return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})

    except (ValueError, ClientInfoParseError) as ex:
        # Invalid input from the caller.
        return JSONResponse(content={"error": str(ex)}, status_code=400)

    except Exception as ex:
        # Anything else is a server-side failure; log the traceback and report a 500.
        LOGGER.exception(ex)
        return JSONResponse({"error": str(ex)}, status_code=500)
118 changes: 4 additions & 114 deletions florist/api/server.py
Original file line number Diff line number Diff line change
@@ -1,118 +1,8 @@
"""FLorist server FastAPI endpoints."""
import logging
from typing import List
"""FLorist server FastAPI endpoints and routes."""
from fastapi import FastAPI

import requests
from fastapi import FastAPI, Form
from fastapi.responses import JSONResponse
from typing_extensions import Annotated

from florist.api.monitoring.metrics import wait_for_metric
from florist.api.servers.common import ClientInfo, ClientInfoParseError, Model
from florist.api.servers.launch import launch_local_server
from florist.api.routes.server.training import router as training_router


app = FastAPI()
LOGGER = logging.getLogger("uvicorn.error")

START_CLIENT_API = "api/client/start"


@app.post("/api/server/start_training")
def start_training(
    model: Annotated[str, Form()],
    server_address: Annotated[str, Form()],
    n_server_rounds: Annotated[int, Form()],
    batch_size: Annotated[int, Form()],
    local_epochs: Annotated[int, Form()],
    redis_host: Annotated[str, Form()],
    redis_port: Annotated[str, Form()],
    clients_info: Annotated[str, Form()],
) -> JSONResponse:
    """
    Start FL training by starting a FL server and its clients.

    Should be called with a POST request and the parameters should be contained in the request's form.

    :param model: (str) The name of the model to train. Should be one of the values in the enum
        florist.api.servers.common.Model
    :param server_address: (str) The address of the FL server to be started. It should be comprised of
        the host name and port separated by colon (e.g. "localhost:8080")
    :param n_server_rounds: (int) The number of rounds the FL server should run.
    :param batch_size: (int) The size of the batch for training.
    :param local_epochs: (int) The number of epochs to run by the clients.
    :param redis_host: (str) The host name for the Redis instance for metrics reporting.
    :param redis_port: (str) The port for the Redis instance for metrics reporting.
    :param clients_info: (str) A JSON string containing the client information. It will be parsed by
        florist.api.servers.common.ClientInfo and should be in the following format:
        [
            {
                "client": <client name as defined in florist.api.clients.common.Client>,
                "client_address": <Florist's client hostname and port, e.g. localhost:8081>,
                "data_path": <path where the data is located in the FL client's machine>,
                "redis_host": <hostname of the Redis instance the FL client will be reporting to>,
                "redis_port": <port of the Redis instance the FL client will be reporting to>,
            }
        ]
    :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and
        the clients in the format below. The UUIDs can be used to pull metrics from Redis.
            {
                "server_uuid": <server uuid>,
                "client_uuids": [<client_uuid_1>, <client_uuid_2>, ..., <client_uuid_n>],
            }
        If not successful, returns the appropriate error code with a JSON with the format below:
            {"error": <error message>}
    """
    try:
        # Parse input data
        if model not in Model.list():
            error_msg = f"Model '{model}' not supported. Supported models: {Model.list()}"
            return JSONResponse(content={"error": error_msg}, status_code=400)

        model_class = Model.class_for_model(Model[model])
        clients_info_list = ClientInfo.parse(clients_info)

        # Start the server
        server_uuid, _ = launch_local_server(
            model=model_class(),
            n_clients=len(clients_info_list),
            server_address=server_address,
            n_server_rounds=n_server_rounds,
            batch_size=batch_size,
            local_epochs=local_epochs,
            redis_host=redis_host,
            redis_port=redis_port,
        )
        # Block until the server reports the "fit_start" metric before contacting the clients.
        wait_for_metric(server_uuid, "fit_start", redis_host, redis_port, logger=LOGGER)

        # Start the clients
        client_uuids: List[str] = []
        for client_info in clients_info_list:
            parameters = {
                "server_address": server_address,
                "client": client_info.client.value,
                "data_path": client_info.data_path,
                "redis_host": client_info.redis_host,
                "redis_port": client_info.redis_port,
            }
            response = requests.get(url=f"http://{client_info.client_address}/{START_CLIENT_API}", params=parameters)

            try:
                json_response = response.json()
            except ValueError as ex:
                # A non-JSON body makes `response.json()` raise a ValueError subclass. Left alone, the
                # outer `except (ValueError, ...)` would report it as a 400 (caller's fault); re-raise
                # as a generic error so a misbehaving client is reported as a 500 instead.
                raise Exception(
                    f"Client response returned {response.status_code} with a non-JSON body. "
                    f"Response: {response.text}"
                ) from ex

            LOGGER.debug(f"Client response: {json_response}")

            if response.status_code != 200:
                raise Exception(f"Client response returned {response.status_code}. Response: {json_response}")

            if "uuid" not in json_response:
                raise Exception(f"Client response did not return a UUID. Response: {json_response}")

            client_uuids.append(json_response["uuid"])

        # Return the UUIDs
        return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})

    except (ValueError, ClientInfoParseError) as ex:
        # Invalid input from the caller.
        return JSONResponse(content={"error": str(ex)}, status_code=400)

    except Exception as ex:
        # Anything else is a server-side failure; log the traceback and report a 500.
        LOGGER.exception(ex)
        return JSONResponse({"error": str(ex)}, status_code=500)
app.include_router(training_router, tags=["training"], prefix="/api/server/training")
4 changes: 2 additions & 2 deletions florist/api/servers/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def launch_local_server(
:param n_clients: (int) The number of clients that will report to this server.
:param server_address: (str) The address the server should start at.
:param n_server_rounds: (int) The number of rounds the training should run for.
:param batch_size: (int) The size of the batch for training
:param local_epochs: (int) The number of epochs to run by the clients
:param batch_size: (int) The size of the batch for training.
:param local_epochs: (int) The number of epochs to run by the clients.
:param redis_host: (str) the host name for the Redis instance for metrics reporting.
:param redis_port: (str) the port for the Redis instance for metrics reporting.
:return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull
Expand Down
Empty file added florist/tests/__init__.py
Empty file.
10 changes: 5 additions & 5 deletions florist/tests/integration/api/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
import uvicorn

from florist.api.monitoring.metrics import wait_for_metric
from florist.api.server import LOGGER
from florist.tests.integration.api.utils import Server
from florist.api.routes.server.training import LOGGER
from florist.tests.integration.api.utils import TestUvicornServer


def test_train():
# Define services
server_config = uvicorn.Config("florist.api.server:app", host="localhost", port=8000, log_level="debug")
server_service = Server(config=server_config)
server_service = TestUvicornServer(config=server_config)
client_config = uvicorn.Config("florist.api.client:app", host="localhost", port=8001, log_level="debug")
client_service = Server(config=client_config)
client_service = TestUvicornServer(config=client_config)

# Start services
with server_service.run_in_thread():
Expand Down Expand Up @@ -49,7 +49,7 @@ def test_train():
}
request = requests.Request(
method="POST",
url=f"http://localhost:8000/api/server/start_training",
url=f"http://localhost:8000/api/server/training/start",
files=data,
).prepare()
session = requests.Session()
Expand Down
2 changes: 1 addition & 1 deletion florist/tests/integration/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uvicorn


class Server(uvicorn.Server):
class TestUvicornServer(uvicorn.Server):
def install_signal_handlers(self):
pass

Expand Down
Empty file.
Empty file.
Loading

0 comments on commit e99ea7f

Please sign in to comment.