From 841f83cf7dca59bc7e718382f0193c9870a5a95a Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 22 Mar 2024 17:13:48 -0400 Subject: [PATCH 01/24] WIP Making the start server api, needs tests --- CONTRIBUTING.md | 5 + README.md | 4 +- florist/api/client.py | 14 ++- florist/api/clients/common.py | 14 +-- florist/api/clients/mnist.py | 29 +---- florist/api/index.py | 15 --- florist/api/models/__init__.py | 1 + florist/api/models/mnist.py | 30 ++++++ florist/api/server.py | 108 +++++++++++++++++++ florist/api/servers/common.py | 106 ++++++++++++++++++ florist/api/servers/launch.py | 98 +++++++++++++++++ florist/api/servers/local.py | 49 --------- florist/tests/integration/api/test_train.py | 2 +- florist/tests/unit/api/servers/test_local.py | 2 +- package.json | 4 +- poetry.lock | 16 ++- pyproject.toml | 1 + 17 files changed, 386 insertions(+), 112 deletions(-) delete mode 100644 florist/api/index.py create mode 100644 florist/api/models/__init__.py create mode 100644 florist/api/models/mnist.py create mode 100644 florist/api/server.py create mode 100644 florist/api/servers/common.py create mode 100644 florist/api/servers/launch.py delete mode 100644 florist/api/servers/local.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a4ddbbef..09cc5a54 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,6 +48,11 @@ Then, run the server and client's Redis instance by following [these instructions](README.md#start-servers-redis-instance) and [these instructions](README.md#start-clients-redis-instance) respectively. 
+To start the server in development mode, run: +```shell +yarn dev +``` + ## Running the tests To run the unit tests, simply execute: diff --git a/README.md b/README.md index cddf1cbd..99a320fc 100644 --- a/README.md +++ b/README.md @@ -63,10 +63,10 @@ docker start redis-florist-server ### Start back-end and front-end servers -Use Yarn to run both the back-end and front-end on server mode: +Use Yarn to run both the back-end and front-end on production server mode: ```shell -yarn dev +yarn prod ``` The front-end will be available at `http://localhost:3000`. If you want to access diff --git a/florist/api/client.py b/florist/api/client.py index 720dee97..4d21b58b 100644 --- a/florist/api/client.py +++ b/florist/api/client.py @@ -6,7 +6,7 @@ from fastapi import FastAPI from fastapi.responses import JSONResponse -from florist.api.clients.common import Clients +from florist.api.clients.common import Client from florist.api.launchers.local import launch_client from florist.api.monitoring.logs import get_client_log_file_path from florist.api.monitoring.metrics import RedisMetricsReporter @@ -30,7 +30,7 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red """ Start a client. - :param server_address: (str) the address of the server this client should report to. + :param server_address: (str) the address of the FL server the FL client should report to. It should be comprised of the host name and port separated by colon (e.g. "localhost:8080"). :param client: (str) the name of the client. Should be one of the enum values of florist.api.client.Clients. :param data_path: (str) the path where the training data is located. @@ -43,18 +43,16 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red {"error": } """ try: - if client not in Clients.list(): - return JSONResponse( - content={"error": f"Client '{client}' not supported. 
Supported clients: {Clients.list()}"}, - status_code=400, - ) + if client not in Client.list(): + error_msg = f"Client '{client}' not supported. Supported clients: {Client.list()}" + return JSONResponse(content={"error": error_msg}, status_code=400) client_uuid = str(uuid.uuid4()) metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=client_uuid) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - client_class = Clients.class_for_client(Clients[client]) + client_class = Client.class_for_client(Client[client]) client_obj = client_class( data_path=Path(data_path), metrics=[], diff --git a/florist/api/clients/common.py b/florist/api/clients/common.py index 121f06ad..bdc77f16 100644 --- a/florist/api/clients/common.py +++ b/florist/api/clients/common.py @@ -7,21 +7,21 @@ from florist.api.clients.mnist import MnistClient -class Clients(Enum): +class Client(Enum): """Enumeration of supported clients.""" MNIST = "MNIST" @classmethod - def class_for_client(cls, client: "Clients") -> type[BasicClient]: + def class_for_client(cls, client: "Client") -> type[BasicClient]: """ Return the class for a given client. - :param client: The client enumeration object. - :return: A subclass of BasicClient corresponding to the given client. + :param client: (Client) The client enumeration object. + :return: (type[BasicClient]) A subclass of BasicClient corresponding to the given client. :raises ValueError: if the client is not supported. """ - if client == Clients.MNIST: + if client == Client.MNIST: return MnistClient raise ValueError(f"Client {client.value} not supported.") @@ -31,6 +31,6 @@ def list(cls) -> List[str]: """ List all the supported clients. - :return: a list of supported clients. + :return: (List[str]) a list of supported clients. 
""" - return [client.value for client in Clients] + return [client.value for client in Client] diff --git a/florist/api/clients/mnist.py b/florist/api/clients/mnist.py index 57464694..f05733c1 100644 --- a/florist/api/clients/mnist.py +++ b/florist/api/clients/mnist.py @@ -2,7 +2,6 @@ from typing import Tuple import torch -import torch.nn.functional as f from fl4health.clients.basic_client import BasicClient from fl4health.utils.dataset import MnistDataset from fl4health.utils.load_data import load_mnist_data @@ -12,6 +11,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader +from florist.api.models.mnist import MnistNet + class MnistClient(BasicClient): # type: ignore """Implementation of the MNIST client.""" @@ -54,29 +55,3 @@ def get_criterion(self, config: Config) -> _Loss: :return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss. """ return torch.nn.CrossEntropyLoss() - - -class MnistNet(nn.Module): - """Implementation of the Mnist model.""" - - def __init__(self) -> None: - """Initialize an instance of MnistNet.""" - super().__init__() - self.conv1 = nn.Conv2d(1, 8, 5) - self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(8, 16, 5) - self.fc1 = nn.Linear(16 * 4 * 4, 120) - self.fc2 = nn.Linear(120, 10) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Perform a forward pass for the given tensor. - - :param x: (torch.Tensor) the tensor to perform the forward pass on. - :return: (torch.Tensor) a result tensor after the forward pass. 
- """ - x = self.pool(f.relu(self.conv1(x))) - x = self.pool(f.relu(self.conv2(x))) - x = x.view(-1, 16 * 4 * 4) - x = f.relu(self.fc1(x)) - return f.relu(self.fc2(x)) diff --git a/florist/api/index.py b/florist/api/index.py deleted file mode 100644 index 656ad2fa..00000000 --- a/florist/api/index.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Initial sample FastAPI endpoints file.""" -from fastapi import FastAPI - - -app = FastAPI() - - -@app.get("/api/python") -def hello_world() -> str: - """ - Provide a simple hello world endpoint. - - :return: the string `hello world` - """ - return "hello world" diff --git a/florist/api/models/__init__.py b/florist/api/models/__init__.py new file mode 100644 index 00000000..b55b85ff --- /dev/null +++ b/florist/api/models/__init__.py @@ -0,0 +1 @@ +"""Contains the models definitions.""" diff --git a/florist/api/models/mnist.py b/florist/api/models/mnist.py new file mode 100644 index 00000000..dd4a71be --- /dev/null +++ b/florist/api/models/mnist.py @@ -0,0 +1,30 @@ +"""Definitions for the MNIST model.""" +import torch +import torch.nn.functional as f +from torch import nn + + +class MnistNet(nn.Module): + """Implementation of the Mnist model.""" + + def __init__(self) -> None: + """Initialize an instance of MnistNet.""" + super().__init__() + self.conv1 = nn.Conv2d(1, 8, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(8, 16, 5) + self.fc1 = nn.Linear(16 * 4 * 4, 120) + self.fc2 = nn.Linear(120, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform a forward pass for the given tensor. + + :param x: (torch.Tensor) the tensor to perform the forward pass on. + :return: (torch.Tensor) a result tensor after the forward pass. 
+ """ + x = self.pool(f.relu(self.conv1(x))) + x = self.pool(f.relu(self.conv2(x))) + x = x.view(-1, 16 * 4 * 4) + x = f.relu(self.fc1(x)) + return f.relu(self.fc2(x)) diff --git a/florist/api/server.py b/florist/api/server.py new file mode 100644 index 00000000..c6c8d512 --- /dev/null +++ b/florist/api/server.py @@ -0,0 +1,108 @@ +"""FLorist server FastAPI endpoints.""" +import logging +from typing import List + +import requests +from fastapi import FastAPI, Form +from fastapi.responses import JSONResponse +from typing_extensions import Annotated + +from florist.api.servers.common import ClientInfo, ClientInfoParseError, Model +from florist.api.servers.launch import launch_local_server, wait_until_server_is_started + + +app = FastAPI() +LOGGER = logging.getLogger("uvicorn.error") + +START_CLIENT_API = "api/client/start" + + +@app.post("/api/server/start_training") +def start_training( + model: Annotated[str, Form()], + server_address: Annotated[str, Form()], + n_server_rounds: Annotated[int, Form()], + redis_host: Annotated[str, Form()], + redis_port: Annotated[str, Form()], + clients_info: Annotated[str, Form()], +) -> JSONResponse: + """ + Start FL training by starting a FL server and its clients. + + Should be called with a POST request and the parameters should be contained in the request's form. + + :param model: (str) The name of the model to train. Should be one of the values in the enum + florist.api.servers.common.Model + :param server_address: (str) The address of the FL server to be started. It should be comprised of + the host name and port separated by colon (e.g. "localhost:8080") + :param n_server_rounds: (int) The number of rounds the FL server should run. + :param redis_host: (str) The host name for the Redis instance for metrics reporting. + :param redis_port: (str) The port for the Redis instance for metrics reporting. + :param clients_info: (str) A JSON string containing the client information. 
It will be parsed by + florist.api.servers.common.ClientInfo and should be in the following format: + [ + { + "client": , + "client_address": , + "data_path": , + "redis_host": , + "redis_port": , + } + ] + :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and + the clients in the format below. The UUIDs can be used to pull metrics from Redis. + { + "server_uuid": , + "client_uuids": [, , ..., ], + } + If not successful, returns the appropriate error code with a JSON with the format below: + {"error": } + """ + try: + # Parse input data + if model not in Model.list(): + error_msg = f"Model '{model}' not supported. Supported models: {Model.list()}" + return JSONResponse(content={"error": error_msg}, status_code=400) + + model_class = Model.class_for_model(Model[model]) + clients_info_list = ClientInfo.parse(clients_info) + + # Start the server + server_uuid, _ = launch_local_server( + model=model_class(), + n_clients=len(clients_info_list), + server_address=server_address, + n_server_rounds=n_server_rounds, + redis_host=redis_host, + redis_port=redis_port, + ) + wait_until_server_is_started(server_uuid, redis_host, redis_port, logger=LOGGER) + + # Start the clients + client_uuids: List[str] = [] + for client_info in clients_info_list: + parameters = { + "server_address": server_address, + "client": client_info.client.value, + "data_path": client_info.data_path, + "redis_host": client_info.redis_host, + "redis_port": client_info.redis_port, + } + response = requests.get(url=f"http://{client_info.client_address}/{START_CLIENT_API}", params=parameters) + json_response = response.json() + LOGGER.debug(f"Client response: {json_response}") + + if "uuid" not in json_response: + raise Exception(f"Client response did not return a UUID. 
Response: {json_response}") + + client_uuids.append(json_response["uuid"]) + + # Return the UUIDs + return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids}) + + except (ValueError, ClientInfoParseError) as ex: + return JSONResponse(content={"error": str(ex)}, status_code=400) + + except Exception as ex: + LOGGER.exception(ex) + return JSONResponse({"error": str(ex)}, status_code=500) diff --git a/florist/api/servers/common.py b/florist/api/servers/common.py new file mode 100644 index 00000000..a8c784f6 --- /dev/null +++ b/florist/api/servers/common.py @@ -0,0 +1,106 @@ +"""Common functions and definitions for servers.""" +import json +from enum import Enum +from typing import List + +from torch import nn + +from florist.api.clients.common import Client +from florist.api.models.mnist import MnistNet + + +class ClientInfo: + """Define the input information necessary to start a client.""" + + def __init__(self, client: Client, client_address: str, data_path: str, redis_host: str, redis_port: str): + self.client = client + self.client_address = client_address + self.data_path = data_path + self.redis_host = redis_host + self.redis_port = redis_port + + @classmethod + def parse(cls, clients_info: str) -> List["ClientInfo"]: + """ + Parse the client information JSON string into a list of ClientInfo instances. + + :param clients_info: (str) A JSON string containing the client information. + Should be in the following format: + [ + { + "client": , + "client_address": , + "data_path": , + "redis_host": , + "redis_port": , + } + ] + :return: (List[ClientInfo]) a list of ClientInfo instances containing the information given. + :raises ClientInfoParseError: If any of the required information is missing or has the + wrong type.
+ """ + client_info_list: List[ClientInfo] = [] + + json_clients_info = json.loads(clients_info) + for client_info in json_clients_info: + if "client" not in client_info or not isinstance(client_info["client"], str): + raise ClientInfoParseError("clients_info does not contain key 'client'") + if client_info["client"] not in Client.list(): + error_msg = f"Client '{client_info['client']}' not supported. Supported clients: {Client.list()}" + raise ClientInfoParseError(error_msg) + client = Client[client_info["client"]] + + if "client_address" not in client_info or not isinstance(client_info["client_address"], str): + raise ClientInfoParseError("clients_info does not contain key 'client_address'") + client_address = client_info["client_address"] + + if "data_path" not in client_info or not isinstance(client_info["data_path"], str): + raise ClientInfoParseError("clients_info does not contain key 'data_path'") + data_path = client_info["data_path"] + + if "redis_host" not in client_info or not isinstance(client_info["redis_host"], str): + raise ClientInfoParseError("clients_info does not contain key 'redis_host'") + redis_host = client_info["redis_host"] + + if "redis_port" not in client_info or not isinstance(client_info["redis_port"], str): + raise ClientInfoParseError("clients_info does not contain key 'redis_port'") + redis_port = client_info["redis_port"] + + client_info_list.append(ClientInfo(client, client_address, data_path, redis_host, redis_port)) + + return client_info_list + + +class ClientInfoParseError(Exception): + """Defines errors in parsing client info.""" + + pass + + +class Model(Enum): + """Enumeration of supported models.""" + + MNIST = "MNIST" + + @classmethod + def class_for_model(cls, model: "Model") -> type[nn.Module]: + """ + Return the class for a given model. + + :param model: (Model) The model enumeration object. + :return: (type[torch.nn.Module]) A torch.nn.Module class corresponding to the given model. 
+ :raises ValueError: if the model is not supported. + """ + if model == Model.MNIST: + return MnistNet + + raise ValueError(f"Model {model.value} not supported.") + + @classmethod + def list(cls) -> List[str]: + """ + List all the supported models. + + :return: (List[str]) a list of supported models. + """ + return [model.value for model in Model] diff --git a/florist/api/servers/launch.py b/florist/api/servers/launch.py new file mode 100644 index 00000000..94267738 --- /dev/null +++ b/florist/api/servers/launch.py @@ -0,0 +1,98 @@ +"""Functions and definitions to launch local servers.""" +import json +import time +import uuid +from functools import partial +from logging import Logger +from multiprocessing import Process +from typing import Tuple + +from redis import Redis +from torch import nn + +from florist.api.launchers.local import launch_server +from florist.api.monitoring.logs import get_server_log_file_path +from florist.api.monitoring.metrics import RedisMetricsReporter +from florist.api.servers.utils import get_server + + +def launch_local_server( + model: nn.Module, + n_clients: int, + server_address: str, + n_server_rounds: int, + redis_host: str, + redis_port: str, +) -> Tuple[str, Process]: + """ + Launch a FL server locally. + + :param model: (torch.nn.Module) The model to be used by the server. Should match the clients' model. + :param n_clients: (int) The number of clients that will report to this server. + :param server_address: (str) The address the server should start at. + :param n_server_rounds: (int) The number of rounds the training should run for. + :param redis_host: (str) the host name for the Redis instance for metrics reporting. + :param redis_port: (str) the port for the Redis instance for metrics reporting. + :return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull + metrics from Redis, along with its local process object.
+ """ + server_uuid = str(uuid.uuid4()) + + metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid) + server_constructor = partial(get_server, model=model, n_clients=n_clients, metrics_reporter=metrics_reporter) + + log_file_name = str(get_server_log_file_path(server_uuid)) + server_process = launch_server( + server_constructor, + server_address, + n_server_rounds, + log_file_name, + seconds_to_sleep=0, + ) + + return server_uuid, server_process + + +MAX_RETRIES = 20 +SECONDS_TO_SLEEP_BETWEEN_RETRIES = 1 + + +def wait_until_server_is_started(server_uuid: str, redis_host: str, redis_port: str, logger: Logger) -> None: + """ + Check server's metrics on Redis and wait until it has been started. + + If the right metrics are not there yet, it will retry up to MAX_RETRIES times, + sleeping and amount of SECONDS_TO_SLEEP_BETWEEN_RETRIES between them. + + :param server_uuid: (str) The UUID of the server in order to pull its metrics from Redis. + :param redis_host: (str) The hostname of the Redis instance this server is reporting to. + :param redis_port: (str) The port of the Redis instance this server is reporting to. + :param logger: (logging.Logger) A logger instance to write logs to. + :raises Exception: If it retries MAX_RETRIES times and the right metrics have not been found. + """ + redis_connection = Redis(host=redis_host, port=redis_port) + + retry = 0 + while retry < MAX_RETRIES: + result = redis_connection.get(server_uuid) + + if result is not None: + assert isinstance(result, bytes) + json_result = json.loads(result.decode("utf8")) + if "fit_start" in json_result: + logger.debug(f"Server has started. Result: {json_result}") + return + + logger.debug( + f"Server is not started yet, sleeping for {SECONDS_TO_SLEEP_BETWEEN_RETRIES}. " + f"Retry: {retry}. Result: {json_result}" + ) + else: + logger.debug( + f"Server is not started yet, sleeping for {SECONDS_TO_SLEEP_BETWEEN_RETRIES}. " + f"Retry: {retry}. Result is None." 
+ ) + time.sleep(SECONDS_TO_SLEEP_BETWEEN_RETRIES) + retry += 1 + + raise Exception(f"Server failed to start after {MAX_RETRIES} retries.") diff --git a/florist/api/servers/local.py b/florist/api/servers/local.py deleted file mode 100644 index 957558c2..00000000 --- a/florist/api/servers/local.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Functions and definitions to launch local servers.""" -import uuid -from functools import partial -from multiprocessing import Process -from typing import Tuple - -from torch import nn - -from florist.api.launchers.local import launch_server -from florist.api.monitoring.logs import get_server_log_file_path -from florist.api.monitoring.metrics import RedisMetricsReporter -from florist.api.servers.utils import get_server - - -def launch_local_server( - model: nn.Module, - n_clients: int, - server_address: str, - n_server_rounds: int, - redis_host: str, - redis_port: str, -) -> Tuple[str, Process]: - """ - Launch a FL server locally. - - :param model: (torch.nn.Module) The model to be used by the server. Should match the clients' model. - :param n_clients: (int) The number of clients that will report to this server. - :param server_address: (str) The address the server should start at. - :param n_server_rounds: (int) The number of rounds the training should run for. - :param redis_host: (str) the host name for the Redis instance for metrics reporting. - :param redis_port: (str) the port for the Redis instance for metrics reporting. - :return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull - metrics from Redis, along with its local process object. 
- """ - server_uuid = str(uuid.uuid4()) - - metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid) - server_constructor = partial(get_server, model=model, n_clients=n_clients, metrics_reporter=metrics_reporter) - - log_file_name = str(get_server_log_file_path(server_uuid)) - server_process = launch_server( - server_constructor, - server_address, - n_server_rounds, - log_file_name, - seconds_to_sleep=0, - ) - - return server_uuid, server_process diff --git a/florist/tests/integration/api/test_train.py b/florist/tests/integration/api/test_train.py index 34c61b3b..ed159f79 100644 --- a/florist/tests/integration/api/test_train.py +++ b/florist/tests/integration/api/test_train.py @@ -8,7 +8,7 @@ from florist.api import client from florist.api.clients.mnist import MnistNet from florist.api.monitoring.logs import get_server_log_file_path -from florist.api.servers.local import launch_local_server +from florist.api.servers.launch import launch_local_server def test_train(): diff --git a/florist/tests/unit/api/servers/test_local.py b/florist/tests/unit/api/servers/test_local.py index 69e06d1a..be0aa836 100644 --- a/florist/tests/unit/api/servers/test_local.py +++ b/florist/tests/unit/api/servers/test_local.py @@ -3,7 +3,7 @@ from florist.api.clients.mnist import MnistNet from florist.api.monitoring.logs import get_server_log_file_path from florist.api.monitoring.metrics import RedisMetricsReporter -from florist.api.servers.local import launch_local_server +from florist.api.servers.launch import launch_local_server from florist.api.servers.utils import get_server diff --git a/package.json b/package.json index 9f68d869..26da5082 100644 --- a/package.json +++ b/package.json @@ -3,9 +3,11 @@ "version": "0.1.0", "private": true, "scripts": { - "fastapi-dev": "poetry install --with dev && python -m uvicorn florist.api.index:app --reload", + "fastapi-dev": "poetry install --with test && python -m uvicorn florist.api.server:app --reload 
--log-level debug", + "fastapi-prod": "poetry install --with test && python -m uvicorn florist.api.server:app --reload", "next-dev": "next dev florist", "dev": "concurrently \"npm run next-dev\" \"npm run fastapi-dev\"", + "prod": "concurrently \"npm run next-dev\" \"npm run fastapi-prod\"", "build": "next build", "start": "next start", "lint": "next lint" diff --git a/poetry.lock b/poetry.lock index cc725c82..0c5d8ad3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3530,6 +3530,20 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-multipart" +version = "0.0.9" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"}, + {file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"}, +] + +[package.extras] +dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"] + [[package]] name = "pytz" version = "2024.1" @@ -5513,4 +5527,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.11" -content-hash = "17c0b90826fdf0ba50663f1fb54861e1164919cbd4b30abbc5e22dc7bc77642f" +content-hash = "7cbd2fbe52c7deb241805f1dc3fa1044512bd069ba8482d930a01c28680b454f" diff --git a/pyproject.toml b/pyproject.toml index a45ec8e4..bd7f7d35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ fl4health = "^0.1.13" wandb = "^0.16.3" torchvision = "0.14.1" redis = "^5.0.1" +python-multipart = "^0.0.9" [tool.poetry.group.test] optional = true From 
c31766a5307d71c0dd28c64d1a5db05c181bdd86 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 25 Mar 2024 11:37:14 -0400 Subject: [PATCH 02/24] Modifying integration test --- florist/api/monitoring/metrics.py | 56 ++++++++++- florist/api/server.py | 5 +- florist/api/servers/launch.py | 49 --------- florist/tests/integration/api/test_train.py | 104 +++++++++++++------- florist/tests/integration/api/utils.py | 21 ++++ 5 files changed, 147 insertions(+), 88 deletions(-) create mode 100644 florist/tests/integration/api/utils.py diff --git a/florist/api/monitoring/metrics.py b/florist/api/monitoring/metrics.py index ed8461fc..a04e7b01 100644 --- a/florist/api/monitoring/metrics.py +++ b/florist/api/monitoring/metrics.py @@ -1,6 +1,7 @@ """Classes for the instrumentation of metrics reporting from clients and servers.""" import json -from logging import DEBUG +import time +from logging import DEBUG, Logger from typing import Any, Dict, Optional import redis @@ -64,3 +65,56 @@ def dump(self) -> None: encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder) log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}") self.redis_connection.set(self.run_id, encoded_metrics) + + +def wait_for_metric( + uuid: str, + metric: str, + redis_host: str, + redis_port: str, + logger: Logger, + max_retries: int = 20, + seconds_to_sleep_between_retries: int = 1, +) -> None: + """ + Check metrics on Redis under the given UUID and wait until the given metric appears. + + If the metrics are not there yet, it will retry up to max_retries times, + sleeping an amount of seconds_to_sleep_between_retries between them. + + :param uuid: (str) The UUID to pull the metrics from Redis. + :param metric: (str) The metric to look for. + :param redis_host: (str) The hostname of the Redis instance the metrics are being reported to. + :param redis_port: (str) The port of the Redis instance the metrics are being reported to.
+ :param logger: (logging.Logger) A logger instance to write logs to. + :param max_retries: (int) The maximum number of retries. Optional, default is 20. + :param seconds_to_sleep_between_retries: (int) The amount of seconds to sleep between retries. + Optional, default is 1. + :raises Exception: If it retries max_retries times and the right metrics have not been found. + """ + redis_connection = redis.Redis(host=redis_host, port=redis_port) + + retry = 0 + while retry < max_retries: + result = redis_connection.get(uuid) + + if result is not None: + assert isinstance(result, bytes) + json_result = json.loads(result.decode("utf8")) + if metric in json_result: + logger.debug(f"Metric '{metric}' has been found. Result: {json_result}") + return + + logger.debug( + f"Metric '{metric}' has not been found yet, sleeping for {seconds_to_sleep_between_retries}s. " + f"Retry: {retry}. Result: {json_result}" + ) + else: + logger.debug( + f"Metric '{metric}' has not been found yet, sleeping for {seconds_to_sleep_between_retries}s. " + f"Retry: {retry}. Result is None."
+ ) + time.sleep(seconds_to_sleep_between_retries) + retry += 1 + + raise Exception(f"Metric '{metric}' not been found after {max_retries} retries.") diff --git a/florist/api/server.py b/florist/api/server.py index c6c8d512..3f84252d 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -7,8 +7,9 @@ from fastapi.responses import JSONResponse from typing_extensions import Annotated +from florist.api.monitoring.metrics import wait_for_metric from florist.api.servers.common import ClientInfo, ClientInfoParseError, Model -from florist.api.servers.launch import launch_local_server, wait_until_server_is_started +from florist.api.servers.launch import launch_local_server app = FastAPI() @@ -76,7 +77,7 @@ def start_training( redis_host=redis_host, redis_port=redis_port, ) - wait_until_server_is_started(server_uuid, redis_host, redis_port, logger=LOGGER) + wait_for_metric(server_uuid, "fit_start", redis_host, redis_port, logger=LOGGER) # Start the clients client_uuids: List[str] = [] diff --git a/florist/api/servers/launch.py b/florist/api/servers/launch.py index 94267738..957558c2 100644 --- a/florist/api/servers/launch.py +++ b/florist/api/servers/launch.py @@ -1,13 +1,9 @@ """Functions and definitions to launch local servers.""" -import json -import time import uuid from functools import partial -from logging import Logger from multiprocessing import Process from typing import Tuple -from redis import Redis from torch import nn from florist.api.launchers.local import launch_server @@ -51,48 +47,3 @@ def launch_local_server( ) return server_uuid, server_process - - -MAX_RETRIES = 20 -SECONDS_TO_SLEEP_BETWEEN_RETRIES = 1 - - -def wait_until_server_is_started(server_uuid: str, redis_host: str, redis_port: str, logger: Logger) -> None: - """ - Check server's metrics on Redis and wait until it has been started. - - If the right metrics are not there yet, it will retry up to MAX_RETRIES times, - sleeping and amount of SECONDS_TO_SLEEP_BETWEEN_RETRIES between them. 
- - :param server_uuid: (str) The UUID of the server in order to pull its metrics from Redis. - :param redis_host: (str) The hostname of the Redis instance this server is reporting to. - :param redis_port: (str) The port of the Redis instance this server is reporting to. - :param logger: (logging.Logger) A logger instance to write logs to. - :raises Exception: If it retries MAX_RETRIES times and the right metrics have not been found. - """ - redis_connection = Redis(host=redis_host, port=redis_port) - - retry = 0 - while retry < MAX_RETRIES: - result = redis_connection.get(server_uuid) - - if result is not None: - assert isinstance(result, bytes) - json_result = json.loads(result.decode("utf8")) - if "fit_start" in json_result: - logger.debug(f"Server has started. Result: {json_result}") - return - - logger.debug( - f"Server is not started yet, sleeping for {SECONDS_TO_SLEEP_BETWEEN_RETRIES}. " - f"Retry: {retry}. Result: {json_result}" - ) - else: - logger.debug( - f"Server is not started yet, sleeping for {SECONDS_TO_SLEEP_BETWEEN_RETRIES}. " - f"Retry: {retry}. Result is None." 
- ) - time.sleep(SECONDS_TO_SLEEP_BETWEEN_RETRIES) - retry += 1 - - raise Exception(f"Server failed to start after {MAX_RETRIES} retries.") diff --git a/florist/tests/integration/api/test_train.py b/florist/tests/integration/api/test_train.py index ed159f79..7404be73 100644 --- a/florist/tests/integration/api/test_train.py +++ b/florist/tests/integration/api/test_train.py @@ -1,45 +1,77 @@ import json +import logging +import requests import tempfile -import time from unittest.mock import ANY import redis +import uvicorn -from florist.api import client -from florist.api.clients.mnist import MnistNet -from florist.api.monitoring.logs import get_server_log_file_path -from florist.api.servers.launch import launch_local_server +from florist.api.monitoring.metrics import wait_for_metric +from florist.api.server import LOGGER +from florist.tests.integration.api.utils import Server def test_train(): - with tempfile.TemporaryDirectory() as temp_dir: - test_server_address = "0.0.0.0:8080" - test_client = "MNIST" - test_data_path = f"{temp_dir}/data" - test_redis_host = "localhost" - test_redis_port = "6379" - - server_uuid, server_process = launch_local_server( - MnistNet(), - 1, - test_server_address, - 2, - test_redis_host, - test_redis_port, - ) - time.sleep(10) # giving time to start the server - - response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port) - json_body = json.loads(response.body.decode()) - - assert json_body == {"uuid": ANY} - - server_process.join() - - redis_conn = redis.Redis(host=test_redis_host, port=test_redis_port) - assert redis_conn.get(json_body["uuid"]) is not None - assert redis_conn.get(server_uuid) is not None - - with open(get_server_log_file_path(server_uuid), "r") as f: - file_contents = f.read() - assert "FL finished in" in file_contents + # start services + server_config = uvicorn.Config("florist.api.server:app", host="localhost", port=8000, log_level="debug") + server_service = 
Server(config=server_config) + client_config = uvicorn.Config("florist.api.client:app", host="localhost", port=8001, log_level="debug") + client_service = Server(config=client_config) + + with server_service.run_in_thread(): + with client_service.run_in_thread(): + with tempfile.TemporaryDirectory() as temp_dir: + test_redis_host = "localhost" + test_redis_port = "6379" + + data = { + "model": (None, "MNIST"), + "server_address": (None, "localhost:8080"), + "n_server_rounds": (None, 2), + "redis_host": (None, test_redis_host), + "redis_port": (None, test_redis_port), + "clients_info": (None, json.dumps( + [ + { + "client": "MNIST", + "client_address": "localhost:8001", + "data_path": f"{temp_dir}/data", + "redis_host": test_redis_host, + "redis_port": test_redis_port, + }, + ], + )), + } + request = requests.Request( + method="POST", + url=f"http://localhost:8000/api/server/start_training", + files=data, + ).prepare() + session = requests.Session() + response = session.send(request) + + assert response.status_code == 200 + assert response.json() == {"server_uuid": ANY, "client_uuids": [ANY]} + + redis_conn = redis.Redis(host=test_redis_host, port=test_redis_port) + server_uuid = response.json()["server_uuid"] + client_uuid = response.json()["client_uuids"][0] + + wait_for_metric(server_uuid, "fit_end", test_redis_host, test_redis_port, LOGGER, max_retries=60) + + server_metrics_result = redis_conn.get(server_uuid) + assert server_metrics_result is not None and isinstance(server_metrics_result, bytes) + server_metrics = json.loads(server_metrics_result.decode("utf8")) + assert server_metrics["type"] == "server" + assert "fit_start" in server_metrics + assert "fit_end" in server_metrics + assert len(server_metrics["rounds"]) == 2 + + client_metrics_result = redis_conn.get(client_uuid) + assert client_metrics_result is not None and isinstance(client_metrics_result, bytes) + client_metrics = json.loads(client_metrics_result.decode("utf8")) + assert 
client_metrics["type"] == "client" + assert "initialized" in client_metrics + assert "shutdown" in client_metrics + assert len(client_metrics["rounds"]) == 2 diff --git a/florist/tests/integration/api/utils.py b/florist/tests/integration/api/utils.py new file mode 100644 index 00000000..3b2c7a77 --- /dev/null +++ b/florist/tests/integration/api/utils.py @@ -0,0 +1,21 @@ +import contextlib +import time +import threading +import uvicorn + + +class Server(uvicorn.Server): + def install_signal_handlers(self): + pass + + @contextlib.contextmanager + def run_in_thread(self): + thread = threading.Thread(target=self.run) + thread.start() + try: + while not self.started: + time.sleep(1e-3) + yield + finally: + self.should_exit = True + thread.join() From 6eff335e40111971636bd5d81c0d7faa03fe9e41 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 25 Mar 2024 11:48:03 -0400 Subject: [PATCH 03/24] Fixing test, adding code comments to integration test --- florist/tests/integration/api/test_train.py | 16 +++++++++++----- .../servers/{test_local.py => test_launch.py} | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) rename florist/tests/unit/api/servers/{test_local.py => test_launch.py} (97%) diff --git a/florist/tests/integration/api/test_train.py b/florist/tests/integration/api/test_train.py index 7404be73..6d747ad4 100644 --- a/florist/tests/integration/api/test_train.py +++ b/florist/tests/integration/api/test_train.py @@ -1,5 +1,4 @@ import json -import logging import requests import tempfile from unittest.mock import ANY @@ -13,22 +12,25 @@ def test_train(): - # start services + # Define services server_config = uvicorn.Config("florist.api.server:app", host="localhost", port=8000, log_level="debug") server_service = Server(config=server_config) client_config = uvicorn.Config("florist.api.client:app", host="localhost", port=8001, log_level="debug") client_service = Server(config=client_config) + # Start services with server_service.run_in_thread(): with 
client_service.run_in_thread(): with tempfile.TemporaryDirectory() as temp_dir: test_redis_host = "localhost" test_redis_port = "6379" + test_n_server_rounds = 2 + # Send the POST request to start training data = { "model": (None, "MNIST"), "server_address": (None, "localhost:8080"), - "n_server_rounds": (None, 2), + "n_server_rounds": (None, test_n_server_rounds), "redis_host": (None, test_redis_host), "redis_port": (None, test_redis_port), "clients_info": (None, json.dumps( @@ -51,6 +53,7 @@ def test_train(): session = requests.Session() response = session.send(request) + # Check response assert response.status_code == 200 assert response.json() == {"server_uuid": ANY, "client_uuids": [ANY]} @@ -58,20 +61,23 @@ def test_train(): server_uuid = response.json()["server_uuid"] client_uuid = response.json()["client_uuids"][0] + # Wait for training to finish wait_for_metric(server_uuid, "fit_end", test_redis_host, test_redis_port, LOGGER, max_retries=60) + # Check server metrics server_metrics_result = redis_conn.get(server_uuid) assert server_metrics_result is not None and isinstance(server_metrics_result, bytes) server_metrics = json.loads(server_metrics_result.decode("utf8")) assert server_metrics["type"] == "server" assert "fit_start" in server_metrics assert "fit_end" in server_metrics - assert len(server_metrics["rounds"]) == 2 + assert len(server_metrics["rounds"]) == test_n_server_rounds + # Check client metrics client_metrics_result = redis_conn.get(client_uuid) assert client_metrics_result is not None and isinstance(client_metrics_result, bytes) client_metrics = json.loads(client_metrics_result.decode("utf8")) assert client_metrics["type"] == "client" assert "initialized" in client_metrics assert "shutdown" in client_metrics - assert len(client_metrics["rounds"]) == 2 + assert len(client_metrics["rounds"]) == test_n_server_rounds diff --git a/florist/tests/unit/api/servers/test_local.py b/florist/tests/unit/api/servers/test_launch.py similarity index 97% 
rename from florist/tests/unit/api/servers/test_local.py rename to florist/tests/unit/api/servers/test_launch.py index be0aa836..73fed658 100644 --- a/florist/tests/unit/api/servers/test_local.py +++ b/florist/tests/unit/api/servers/test_launch.py @@ -7,7 +7,7 @@ from florist.api.servers.utils import get_server -@patch("florist.api.servers.local.launch_server") +@patch("florist.api.servers.launch.launch_server") def test_launch_local_server(mock_launch_server: Mock) -> None: test_model = MnistNet() test_n_clients = 2 From 6124d847be494e09ee4670ff5628b47372a4ddcf Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 25 Mar 2024 13:54:10 -0400 Subject: [PATCH 04/24] Happy path test --- florist/tests/unit/api/test_server.py | 88 +++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 florist/tests/unit/api/test_server.py diff --git a/florist/tests/unit/api/test_server.py b/florist/tests/unit/api/test_server.py new file mode 100644 index 00000000..a9b64535 --- /dev/null +++ b/florist/tests/unit/api/test_server.py @@ -0,0 +1,88 @@ +import json +from unittest.mock import Mock, patch, ANY + +from florist.api.models.mnist import MnistNet +from florist.api.server import start_training + + +@patch("florist.api.server.launch_local_server") +@patch("florist.api.monitoring.metrics.redis") +@patch("florist.api.server.requests") +def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: + test_model = "MNIST" + test_server_address = "test-server-address" + test_n_server_rounds = 2 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + 
"redis_port": "test-redis-port-2", + } + ] + test_server_uuid = "test-server-uuid" + mock_launch_local_server.return_value = (test_server_uuid, None) + + mock_redis_connection = Mock() + mock_redis_connection.get.return_value = b"{\"fit_start\": null}" + mock_redis.Redis.return_value = mock_redis_connection + + mock_response = Mock() + test_client_1_uuid = "test-client-1-uuid" + test_client_2_uuid = "test-client-2-uuid" + mock_response.json.side_effect = [{"uuid": test_client_1_uuid},{"uuid": test_client_2_uuid}] + mock_requests.get.return_value = mock_response + + response = start_training( + test_model, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + json.dumps(test_clients_info), + ) + + assert response.status_code == 200 + json_body = json.loads(response.body.decode()) + assert json_body == {"server_uuid": test_server_uuid, "client_uuids": [test_client_1_uuid, test_client_2_uuid]} + + assert isinstance(mock_launch_local_server.call_args_list[0][1]["model"], MnistNet) + mock_launch_local_server.assert_called_once_with( + model=ANY, + n_clients=len(test_clients_info), + server_address=test_server_address, + n_server_rounds=test_n_server_rounds, + redis_host=test_redis_host, + redis_port=test_redis_port, + ) + mock_redis.Redis.assert_called_once_with(host=test_redis_host, port=test_redis_port) + mock_redis_connection.get.assert_called_once_with(test_server_uuid) + mock_requests.get.assert_any_call( + url=f"http://{test_clients_info[0]['client_address']}/api/client/start", + params={ + "server_address": test_server_address, + "client": test_clients_info[0]["client"], + "data_path": test_clients_info[0]["data_path"], + "redis_host": test_clients_info[0]["redis_host"], + "redis_port": test_clients_info[0]["redis_port"], + }, + ) + mock_requests.get.assert_any_call( + url=f"http://{test_clients_info[1]['client_address']}/api/client/start", + params={ + "server_address": test_server_address, + "client": 
test_clients_info[1]["client"], + "data_path": test_clients_info[1]["data_path"], + "redis_host": test_clients_info[1]["redis_host"], + "redis_port": test_clients_info[1]["redis_port"], + }, + ) From 79d8a180952a6f1e7df407e3a7171979c1c5bf31 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 25 Mar 2024 15:26:40 -0400 Subject: [PATCH 05/24] Finished tests for server --- florist/api/server.py | 3 + florist/tests/unit/api/test_server.py | 326 ++++++++++++++++++++++++++ 2 files changed, 329 insertions(+) diff --git a/florist/api/server.py b/florist/api/server.py index 3f84252d..6f8c186f 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -93,6 +93,9 @@ def start_training( json_response = response.json() LOGGER.debug(f"Client response: {json_response}") + if response.status_code != 200: + raise Exception(f"Client response returned {response.status_code}. Response: {json_response}") + if "uuid" not in json_response: raise Exception(f"Client response did not return a UUID. Response: {json_response}") diff --git a/florist/tests/unit/api/test_server.py b/florist/tests/unit/api/test_server.py index a9b64535..bf2ab7df 100644 --- a/florist/tests/unit/api/test_server.py +++ b/florist/tests/unit/api/test_server.py @@ -9,6 +9,7 @@ @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.server.requests") def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: + # Arrange test_model = "MNIST" test_server_address = "test-server-address" test_n_server_rounds = 2 @@ -37,11 +38,13 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun mock_redis.Redis.return_value = mock_redis_connection mock_response = Mock() + mock_response.status_code = 200 test_client_1_uuid = "test-client-1-uuid" test_client_2_uuid = "test-client-2-uuid" mock_response.json.side_effect = [{"uuid": test_client_1_uuid},{"uuid": test_client_2_uuid}] mock_requests.get.return_value = mock_response + # Act 
response = start_training( test_model, test_server_address, @@ -51,6 +54,7 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun json.dumps(test_clients_info), ) + # Assert assert response.status_code == 200 json_body = json.loads(response.body.decode()) assert json_body == {"server_uuid": test_server_uuid, "client_uuids": [test_client_1_uuid, test_client_2_uuid]} @@ -86,3 +90,325 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun "redis_port": test_clients_info[1]["redis_port"], }, ) + + +def test_start_fail_unsupported_server_model() -> None: + # Arrange + test_model = "WRONG MODEL" + test_server_address = "test-server-address" + test_n_server_rounds = 2 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + + # Act + response = start_training( + test_model, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + json.dumps(test_clients_info), + ) + + # Assert + assert response.status_code == 400 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": ANY} + assert "Model 'WRONG MODEL' not supported." 
in json_body["error"] + + +def test_start_fail_unsupported_client() -> None: + # Arrange + test_model = "MNIST" + test_server_address = "test-server-address" + test_n_server_rounds = 2 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "WRONG CLIENT", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + + # Act + response = start_training( + test_model, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + json.dumps(test_clients_info), + ) + + # Assert + assert response.status_code == 400 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": ANY} + assert "Client 'WRONG CLIENT' not supported." 
in json_body["error"] + + +@patch("florist.api.server.launch_local_server") +def test_start_training_launch_server_exception(mock_launch_local_server: Mock) -> None: + # Arrange + test_model = "MNIST" + test_server_address = "test-server-address" + test_n_server_rounds = 2 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + }, + ] + test_exception = Exception("test exception") + mock_launch_local_server.side_effect = test_exception + + # Act + response = start_training( + test_model, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + json.dumps(test_clients_info), + ) + + # Assert + assert response.status_code == 500 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": str(test_exception)} + + +@patch("florist.api.server.launch_local_server") +@patch("florist.api.monitoring.metrics.redis") +def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_server: Mock) -> None: + # Arrange + test_model = "MNIST" + test_server_address = "test-server-address" + test_n_server_rounds = 2 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + test_server_uuid = "test-server-uuid" + 
mock_launch_local_server.return_value = (test_server_uuid, None) + + test_exception = Exception("test exception") + mock_redis.Redis.side_effect = test_exception + + # Act + response = start_training( + test_model, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + json.dumps(test_clients_info), + ) + + # Assert + assert response.status_code == 500 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": str(test_exception)} + + +@patch("florist.api.server.launch_local_server") +@patch("florist.api.monitoring.metrics.redis") +@patch("florist.api.monitoring.metrics.time") +def test_start_wait_for_metric_timeout(mock_time: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: + # Arrange + test_model = "MNIST" + test_server_address = "test-server-address" + test_n_server_rounds = 2 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + test_server_uuid = "test-server-uuid" + mock_launch_local_server.return_value = (test_server_uuid, None) + + mock_redis_connection = Mock() + mock_redis_connection.get.return_value = b"{\"foo\": null}" + mock_redis.Redis.return_value = mock_redis_connection + + # Act + response = start_training( + test_model, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + json.dumps(test_clients_info), + ) + + # Assert + assert response.status_code == 500 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": "Metric 'fit_start' not been found after 20 retries."} + + 
+@patch("florist.api.server.launch_local_server") +@patch("florist.api.monitoring.metrics.redis") +@patch("florist.api.server.requests") +def test_start_training_fail_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: + # Arrange + test_model = "MNIST" + test_server_address = "test-server-address" + test_n_server_rounds = 2 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + test_server_uuid = "test-server-uuid" + mock_launch_local_server.return_value = (test_server_uuid, None) + + mock_redis_connection = Mock() + mock_redis_connection.get.return_value = b"{\"fit_start\": null}" + mock_redis.Redis.return_value = mock_redis_connection + + mock_response = Mock() + mock_response.status_code = 403 + mock_response.json.return_value = "error" + mock_requests.get.return_value = mock_response + + # Act + response = start_training( + test_model, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + json.dumps(test_clients_info), + ) + + # Assert + assert response.status_code == 500 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": f"Client response returned 403. 
Response: error"} + + +@patch("florist.api.server.launch_local_server") +@patch("florist.api.monitoring.metrics.redis") +@patch("florist.api.server.requests") +def test_start_training_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: + # Arrange + test_model = "MNIST" + test_server_address = "test-server-address" + test_n_server_rounds = 2 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + test_server_uuid = "test-server-uuid" + mock_launch_local_server.return_value = (test_server_uuid, None) + + mock_redis_connection = Mock() + mock_redis_connection.get.return_value = b"{\"fit_start\": null}" + mock_redis.Redis.return_value = mock_redis_connection + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"foo": "bar"} + mock_requests.get.return_value = mock_response + + # Act + response = start_training( + test_model, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + json.dumps(test_clients_info), + ) + + # Assert + assert response.status_code == 500 + json_body = json.loads(response.body.decode()) + assert json_body == {"error": "Client response did not return a UUID. 
Response: {'foo': 'bar'}"} From a65aa12dee620bed487744e9c636a00ef5230924 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 25 Mar 2024 15:44:25 -0400 Subject: [PATCH 06/24] Finished tests for client info --- florist/tests/unit/api/servers/test_common.py | 178 ++++++++++++++++++ florist/tests/unit/api/test_server.py | 4 +- 2 files changed, 180 insertions(+), 2 deletions(-) create mode 100644 florist/tests/unit/api/servers/test_common.py diff --git a/florist/tests/unit/api/servers/test_common.py b/florist/tests/unit/api/servers/test_common.py new file mode 100644 index 00000000..3910dde2 --- /dev/null +++ b/florist/tests/unit/api/servers/test_common.py @@ -0,0 +1,178 @@ +import json +from copy import deepcopy +from pytest import raises + +from florist.api.clients.common import Client +from florist.api.servers.common import ClientInfo, ClientInfoParseError + + +def test_client_info_parse_success() -> None: + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + + result = ClientInfo.parse(json.dumps(test_clients_info)) + + for i in range(len(test_clients_info)): + assert result[i].client == Client.MNIST + assert result[i].client_address == test_clients_info[i]["client_address"] + assert result[i].data_path == test_clients_info[i]["data_path"] + assert result[i].redis_host == test_clients_info[i]["redis_host"] + assert result[i].redis_port == test_clients_info[i]["redis_port"] + + +def test_client_info_parse_fail_client() -> None: + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", 
+ }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + test_data[1]["client"] = "WRONG CLIENT" + ClientInfo.parse(json.dumps(test_data)) + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + test_data[1]["client"] = 2 + ClientInfo.parse(json.dumps(test_data)) + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + del test_data[1]["client"] + ClientInfo.parse(json.dumps(test_data)) + + +def test_client_info_parse_fail_client_address() -> None: + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + test_data[1]["client_address"] = 2 + ClientInfo.parse(json.dumps(test_data)) + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + del test_data[1]["client_address"] + ClientInfo.parse(json.dumps(test_data)) + + +def test_client_info_parse_fail_data_path() -> None: + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + test_data[1]["data_path"] = 2 
+ ClientInfo.parse(json.dumps(test_data)) + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + del test_data[1]["data_path"] + ClientInfo.parse(json.dumps(test_data)) + + +def test_client_info_parse_fail_redis_host() -> None: + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + test_data[1]["redis_host"] = 2 + ClientInfo.parse(json.dumps(test_data)) + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + del test_data[1]["redis_host"] + ClientInfo.parse(json.dumps(test_data)) + + +def test_client_info_parse_fail_redis_port() -> None: + test_clients_info = [ + { + "client": "MNIST", + "client_address": "test-client-address-1", + "data_path": "test-data-path-1", + "redis_host": "test-redis-host-1", + "redis_port": "test-redis-port-1", + }, { + "client": "MNIST", + "client_address": "test-client-address-2", + "data_path": "test-data-path-2", + "redis_host": "test-redis-host-2", + "redis_port": "test-redis-port-2", + } + ] + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + test_data[1]["redis_port"] = 2 + ClientInfo.parse(json.dumps(test_data)) + + with raises(ClientInfoParseError): + test_data = deepcopy(test_clients_info) + del test_data[1]["redis_port"] + ClientInfo.parse(json.dumps(test_data)) diff --git a/florist/tests/unit/api/test_server.py b/florist/tests/unit/api/test_server.py index bf2ab7df..9c0f8d2b 100644 --- a/florist/tests/unit/api/test_server.py +++ b/florist/tests/unit/api/test_server.py @@ -262,8 +262,8 @@ def 
test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_ser @patch("florist.api.server.launch_local_server") @patch("florist.api.monitoring.metrics.redis") -@patch("florist.api.monitoring.metrics.time") -def test_start_wait_for_metric_timeout(mock_time: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: +@patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep +def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: # Arrange test_model = "MNIST" test_server_address = "test-server-address" From 5dc97082141c468ade16a5cfca3e1e227a7a9eea Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 25 Mar 2024 16:00:48 -0400 Subject: [PATCH 07/24] Finished tests for metrics --- .../tests/unit/api/monitoring/test_metrics.py | 65 ++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/florist/tests/unit/api/monitoring/test_metrics.py b/florist/tests/unit/api/monitoring/test_metrics.py index 518d20c5..85ba1ecb 100644 --- a/florist/tests/unit/api/monitoring/test_metrics.py +++ b/florist/tests/unit/api/monitoring/test_metrics.py @@ -1,11 +1,13 @@ import datetime import json -from unittest.mock import Mock, patch +import logging +from pytest import raises +from unittest.mock import Mock, call, patch from fl4health.reporting.metrics import DateTimeEncoder from freezegun import freeze_time -from florist.api.monitoring.metrics import RedisMetricsReporter +from florist.api.monitoring.metrics import RedisMetricsReporter, wait_for_metric @freeze_time("2012-12-11 10:09:08") @@ -94,3 +96,62 @@ def test_dump_with_existing_connection(mock_redis: Mock) -> None: mock_redis.assert_not_called() assert mock_redis_connection.set.call_args_list[0][0][0] == test_run_id assert mock_redis_connection.set.call_args_list[0][0][1] == json.dumps(test_data, cls=DateTimeEncoder) + + +@patch("florist.api.monitoring.metrics.redis") 
+@patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep +def test_wait_for_metric_success(_: Mock, mock_redis: Mock) -> None: + test_uuid = "uuid" + test_metric = "test-metric" + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + + mock_redis_connection = Mock() + mock_redis_connection.get.return_value = b"{\"test-metric\": null}" + mock_redis.Redis.return_value = mock_redis_connection + + wait_for_metric(test_uuid, test_metric, test_redis_host, test_redis_port, logging.getLogger(__name__)) + + mock_redis.Redis.assert_called_once_with(host=test_redis_host, port=test_redis_port) + mock_redis_connection.get.assert_called_once_with(test_uuid) + + +@patch("florist.api.monitoring.metrics.redis") +@patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep +def test_wait_for_metric_success_with_retry(_: Mock, mock_redis: Mock) -> None: + test_uuid = "uuid" + test_metric = "test-metric" + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + + mock_redis_connection = Mock() + mock_redis_connection.get.side_effect = [ + None, + None, + b"{\"foo\": \"bar\"}", + b"{\"test-metric\": null}", + b"{\"foo\": \"bar\"}", + ] + mock_redis.Redis.return_value = mock_redis_connection + + wait_for_metric(test_uuid, test_metric, test_redis_host, test_redis_port, logging.getLogger(__name__)) + + mock_redis.Redis.assert_called_once_with(host=test_redis_host, port=test_redis_port) + assert mock_redis_connection.get.call_count == 4 + mock_redis_connection.get.assert_has_calls([call(test_uuid)] * 4) + + +@patch("florist.api.monitoring.metrics.redis") +@patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep +def test_wait_for_metric_fail_max_retries(_: Mock, mock_redis: Mock) -> None: + test_uuid = "uuid" + test_metric = "test-metric" + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + + mock_redis_connection = Mock() + 
mock_redis_connection.get.return_value = b"{\"foo\": \"bar\"}" + mock_redis.Redis.return_value = mock_redis_connection + + with raises(Exception): + wait_for_metric(test_uuid, test_metric, test_redis_host, test_redis_port, logging.getLogger(__name__)) From a6aa46ceb32234592027514a5ad36a470684fc17 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 25 Mar 2024 16:05:16 -0400 Subject: [PATCH 08/24] Adding inits --- florist/tests/integration/api/launchers/__init__.py | 0 florist/tests/unit/api/__init__.py | 0 florist/tests/unit/api/monitoring/__init__.py | 0 florist/tests/unit/api/servers/__init__.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 florist/tests/integration/api/launchers/__init__.py create mode 100644 florist/tests/unit/api/__init__.py create mode 100644 florist/tests/unit/api/monitoring/__init__.py create mode 100644 florist/tests/unit/api/servers/__init__.py diff --git a/florist/tests/integration/api/launchers/__init__.py b/florist/tests/integration/api/launchers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/florist/tests/unit/api/__init__.py b/florist/tests/unit/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/florist/tests/unit/api/monitoring/__init__.py b/florist/tests/unit/api/monitoring/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/florist/tests/unit/api/servers/__init__.py b/florist/tests/unit/api/servers/__init__.py new file mode 100644 index 00000000..e69de29b From 349bfe94663f38938b8b1c96821fe5d1dce9d685 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 27 Mar 2024 11:56:25 -0400 Subject: [PATCH 09/24] Adding batch_size and local_epochs to server params --- florist/api/server.py | 6 ++++ florist/api/servers/launch.py | 13 +++++++- florist/tests/integration/api/test_train.py | 2 ++ florist/tests/unit/api/test_server.py | 34 +++++++++++++++++++++ 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/florist/api/server.py 
b/florist/api/server.py index 6f8c186f..5982e0d2 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -23,6 +23,8 @@ def start_training( model: Annotated[str, Form()], server_address: Annotated[str, Form()], n_server_rounds: Annotated[int, Form()], + batch_size: Annotated[int, Form()], + local_epochs: Annotated[int, Form()], redis_host: Annotated[str, Form()], redis_port: Annotated[str, Form()], clients_info: Annotated[str, Form()], @@ -37,6 +39,8 @@ def start_training( :param server_address: (str) The address of the FL server to be started. It should be comprised of the host name and port separated by colon (e.g. "localhost:8080") :param n_server_rounds: (int) The number of rounds the FL server should run. + :param batch_size: (int) The size of the batch for training + :param local_epochs: (int) The number of epochs to run by the clients :param redis_host: (str) The host name for the Redis instance for metrics reporting. :param redis_port: (str) The port for the Redis instance for metrics reporting. :param clients_info: (str) A JSON string containing the client information. It will be parsed by @@ -74,6 +78,8 @@ def start_training( n_clients=len(clients_info_list), server_address=server_address, n_server_rounds=n_server_rounds, + batch_size=batch_size, + local_epochs=local_epochs, redis_host=redis_host, redis_port=redis_port, ) diff --git a/florist/api/servers/launch.py b/florist/api/servers/launch.py index 957558c2..19f74752 100644 --- a/florist/api/servers/launch.py +++ b/florist/api/servers/launch.py @@ -17,6 +17,8 @@ def launch_local_server( n_clients: int, server_address: str, n_server_rounds: int, + batch_size: int, + local_epochs: int, redis_host: str, redis_port: str, ) -> Tuple[str, Process]: @@ -27,6 +29,8 @@ def launch_local_server( :param n_clients: (int) The number of clients that will report to this server. :param server_address: (str) The address the server should start at. 
:param n_server_rounds: (int) The number of rounds the training should run for. + :param batch_size: (int) The size of the batch for training + :param local_epochs: (int) The number of epochs to run by the clients :param redis_host: (str) the host name for the Redis instance for metrics reporting. :param redis_port: (str) the port for the Redis instance for metrics reporting. :return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull @@ -35,7 +39,14 @@ def launch_local_server( server_uuid = str(uuid.uuid4()) metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid) - server_constructor = partial(get_server, model=model, n_clients=n_clients, metrics_reporter=metrics_reporter) + server_constructor = partial( + get_server, + model=model, + n_clients=n_clients, + batch_size=batch_size, + local_epochs=local_epochs, + metrics_reporter=metrics_reporter, + ) log_file_name = str(get_server_log_file_path(server_uuid)) server_process = launch_server( diff --git a/florist/tests/integration/api/test_train.py b/florist/tests/integration/api/test_train.py index 6d747ad4..4223060f 100644 --- a/florist/tests/integration/api/test_train.py +++ b/florist/tests/integration/api/test_train.py @@ -31,6 +31,8 @@ def test_train(): "model": (None, "MNIST"), "server_address": (None, "localhost:8080"), "n_server_rounds": (None, test_n_server_rounds), + "batch_size": (None, 8), + "local_epochs": (None, 1), "redis_host": (None, test_redis_host), "redis_port": (None, test_redis_port), "clients_info": (None, json.dumps( diff --git a/florist/tests/unit/api/test_server.py b/florist/tests/unit/api/test_server.py index 9c0f8d2b..52537e45 100644 --- a/florist/tests/unit/api/test_server.py +++ b/florist/tests/unit/api/test_server.py @@ -13,6 +13,8 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun test_model = "MNIST" test_server_address = "test-server-address" test_n_server_rounds = 2 + 
test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_clients_info = [ @@ -49,6 +51,8 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun test_model, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, json.dumps(test_clients_info), @@ -65,6 +69,8 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun n_clients=len(test_clients_info), server_address=test_server_address, n_server_rounds=test_n_server_rounds, + batch_size=test_batch_size, + local_epochs=test_local_epochs, redis_host=test_redis_host, redis_port=test_redis_port, ) @@ -97,6 +103,8 @@ def test_start_fail_unsupported_server_model() -> None: test_model = "WRONG MODEL" test_server_address = "test-server-address" test_n_server_rounds = 2 + test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_clients_info = [ @@ -120,6 +128,8 @@ def test_start_fail_unsupported_server_model() -> None: test_model, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, json.dumps(test_clients_info), @@ -137,6 +147,8 @@ def test_start_fail_unsupported_client() -> None: test_model = "MNIST" test_server_address = "test-server-address" test_n_server_rounds = 2 + test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_clients_info = [ @@ -160,6 +172,8 @@ def test_start_fail_unsupported_client() -> None: test_model, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, json.dumps(test_clients_info), @@ -178,6 +192,8 @@ def test_start_training_launch_server_exception(mock_launch_local_server: Mock) test_model = "MNIST" test_server_address = "test-server-address" test_n_server_rounds = 2 + 
test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_clients_info = [ @@ -203,6 +219,8 @@ def test_start_training_launch_server_exception(mock_launch_local_server: Mock) test_model, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, json.dumps(test_clients_info), @@ -221,6 +239,8 @@ def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_ser test_model = "MNIST" test_server_address = "test-server-address" test_n_server_rounds = 2 + test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_clients_info = [ @@ -249,6 +269,8 @@ def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_ser test_model, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, json.dumps(test_clients_info), @@ -268,6 +290,8 @@ def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_lo test_model = "MNIST" test_server_address = "test-server-address" test_n_server_rounds = 2 + test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_clients_info = [ @@ -297,6 +321,8 @@ def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_lo test_model, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, json.dumps(test_clients_info), @@ -316,6 +342,8 @@ def test_start_training_fail_response(mock_requests: Mock, mock_redis: Mock, moc test_model = "MNIST" test_server_address = "test-server-address" test_n_server_rounds = 2 + test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_clients_info = [ @@ -350,6 +378,8 @@ def test_start_training_fail_response(mock_requests: Mock, 
mock_redis: Mock, moc test_model, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, json.dumps(test_clients_info), @@ -369,6 +399,8 @@ def test_start_training_no_uuid_in_response(mock_requests: Mock, mock_redis: Moc test_model = "MNIST" test_server_address = "test-server-address" test_n_server_rounds = 2 + test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_clients_info = [ @@ -403,6 +435,8 @@ def test_start_training_no_uuid_in_response(mock_requests: Mock, mock_redis: Moc test_model, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, json.dumps(test_clients_info), From 2368b088822da6505aec0bbda12078770c69c985 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 27 Mar 2024 12:11:52 -0400 Subject: [PATCH 10/24] Fixing additional test --- florist/tests/unit/api/servers/test_launch.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/florist/tests/unit/api/servers/test_launch.py b/florist/tests/unit/api/servers/test_launch.py index 73fed658..17b5a3f5 100644 --- a/florist/tests/unit/api/servers/test_launch.py +++ b/florist/tests/unit/api/servers/test_launch.py @@ -13,6 +13,8 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: test_n_clients = 2 test_server_address = "test-server-address" test_n_server_rounds = 5 + test_batch_size = 8 + test_local_epochs = 1 test_redis_host = "test-redis-host" test_redis_port = "test-redis-port" test_server_process = "test-server-process" @@ -23,6 +25,8 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: test_n_clients, test_server_address, test_n_server_rounds, + test_batch_size, + test_local_epochs, test_redis_host, test_redis_port, ) @@ -41,7 +45,13 @@ def test_launch_local_server(mock_launch_server: Mock) -> None: ) assert call_kwargs == {"seconds_to_sleep": 0} assert 
call_args[0].func == get_server - assert call_args[0].keywords == {"model": test_model, "n_clients": test_n_clients, "metrics_reporter": ANY} + assert call_args[0].keywords == { + "model": test_model, + "n_clients": test_n_clients, + "batch_size": test_batch_size, + "local_epochs": test_local_epochs, + "metrics_reporter": ANY, + } metrics_reporter = call_args[0].keywords["metrics_reporter"] assert isinstance(metrics_reporter, RedisMetricsReporter) From c7df7af0cc8f8edfb0045809dcea4d301b99b649 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 5 Apr 2024 16:59:17 -0400 Subject: [PATCH 11/24] WIP adding mongodb and create job route --- .gitignore | 1 + CONTRIBUTING.md | 10 +- README.md | 25 ++++- florist/api/clients/mnist.py | 2 +- florist/api/db/entities.py | 33 ++++++ florist/api/routes/job.py | 35 ++++++ florist/api/server.py | 33 ++++-- poetry.lock | 199 ++++++++++++++++++++++++++++------- pyproject.toml | 2 + 9 files changed, 292 insertions(+), 48 deletions(-) create mode 100644 florist/api/db/entities.py create mode 100644 florist/api/routes/job.py diff --git a/.gitignore b/.gitignore index 7d2c7f53..74604df4 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,4 @@ next-env.d.ts /metrics/ /logs/ +/.ruff_cache/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 09cc5a54..95d4d3b6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -73,7 +73,15 @@ For code style, we recommend the [PEP 8 style guide](https://peps.python.org/pep For docstrings we use [numpy format](https://numpydoc.readthedocs.io/en/latest/format.html). We use [ruff](https://docs.astral.sh/ruff/) for code formatting and static code -analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ruff/faq/#how-does-ruff-compare-to-flake8). The pre-commit hooks show errors which you need to fix before submitting a PR. +analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ruff/faq/#how-does-ruff-compare-to-flake8). 
The pre-commit hooks show +errors which you need to fix before submitting a PR. Last but not the least, we use type hints in our code which is then checked using [mypy](https://mypy.readthedocs.io/en/stable/). + +## Documentation + +Backend code API documentation can be found at https://vectorinstitute.github.io/FLorist/. + +Backend REST API documentation can be found at http://localhost:8000/docs once the server +is running. diff --git a/README.md b/README.md index 99a320fc..50b75253 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ yarn ### Pulling Redis' Docker -Redis is used to fetch the metrics reported by servers and clients during their runs. +[Redis](https://redis.io/) is used to fetch the metrics reported by servers and clients during their runs. If you don't have Docker installed, follow [these instructions](https://docs.docker.com/desktop/) @@ -47,8 +47,31 @@ to install it. Then, pull [Redis' official docker image](https://hub.docker.com/ docker pull redis:7.2.4 ``` +### Pulling MongoDB's Docker + +[MongoDB](https://www.mongodb.com) is used to store information about the training jobs. + +If you don't have Docker installed, follow [these instructions](https://docs.docker.com/desktop/) +to install it.
Then, pull [MongoDB's official docker image](https://hub.docker.com/_/mongo) +(we currently use version 7.0.8): +```shell +docker pull mongo:7.0.8 +``` + ## Running the server +### Start MongoDB's instance + +If it's your first time running it, create a container and run it with the command below: +```shell +docker run --name mongodb-florist -d -p 27017:27017 mongo:7.0.8 +``` + +From the second time on, you can just start it: +```shell +docker start mongodb-florist +``` + ### Start server's Redis instance If it's your first time running it, create a container and run it with the command below: diff --git a/florist/api/clients/mnist.py b/florist/api/clients/mnist.py index f05733c1..575fc27a 100644 --- a/florist/api/clients/mnist.py +++ b/florist/api/clients/mnist.py @@ -14,7 +14,7 @@ from florist.api.models.mnist import MnistNet -class MnistClient(BasicClient): # type: ignore +class MnistClient(BasicClient): # type: ignore[misc] """Implementation of the MNIST client.""" def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]: diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py new file mode 100644 index 00000000..9d0ade5a --- /dev/null +++ b/florist/api/db/entities.py @@ -0,0 +1,33 @@ +"""Definitions for the MongoDB database entities.""" +import uuid +from typing import Optional + +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from florist.api.servers.common import Model + + +JOB_DATABASE_NAME = "job" + + +class Job(BaseModel): + """Define the Job DB entity.""" + + id: str = Field(default_factory=uuid.uuid4, alias="_id") + model: Optional[Annotated[Model, Field(...)]] + redis_host: Optional[Annotated[str, Field(...)]] + redis_port: Optional[Annotated[str, Field(...)]] + + class Config: + """MongoDB config for the Job DB entity.""" + + allow_population_by_field_name = True + schema_extra = { + "example": { + "_id": "066de609-b04a-4b30-b46c-32537c7f1f6e", + "model": 
"MNIST", + "redis_host": "locahost", + "redis_port": "6879", + }, + } diff --git a/florist/api/routes/job.py b/florist/api/routes/job.py new file mode 100644 index 00000000..5d2be865 --- /dev/null +++ b/florist/api/routes/job.py @@ -0,0 +1,35 @@ +"""The /job FastAPI routes.""" +from typing import Any, Dict + +from fastapi import APIRouter, Body, Request, status +from fastapi.encoders import jsonable_encoder + +from florist.api.db.entities import JOB_DATABASE_NAME, Job + + +router = APIRouter() + + +@router.post( + path="/", + response_description="Create a new job", + status_code=status.HTTP_201_CREATED, + response_model=Job, +) +def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: B008 + """ + Create a new training job. + + If calling from the REST API, it will receive the job attributes as the Request Body in raw/JSON format. + See `florist.api.db.entities.Job` to check the list of attributes and their requirements. + + :param request: (fastapi.Request) the FastAPI request object. + :param job: (Job) The Job instance to be saved in the database. + :return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. 
+ """ + job = jsonable_encoder(job) + new_job = request.app.database[JOB_DATABASE_NAME].insert_one(job) + created_job = request.app.database[JOB_DATABASE_NAME].find_one({"_id": new_job.inserted_id}) + + assert isinstance(created_job, dict) + return created_job diff --git a/florist/api/server.py b/florist/api/server.py index 5982e0d2..440b2590 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -5,19 +5,38 @@ import requests from fastapi import FastAPI, Form from fastapi.responses import JSONResponse +from pymongo import MongoClient from typing_extensions import Annotated from florist.api.monitoring.metrics import wait_for_metric +from florist.api.routes.job import router as job_router from florist.api.servers.common import ClientInfo, ClientInfoParseError, Model from florist.api.servers.launch import launch_local_server app = FastAPI() +app.include_router(job_router, tags=["job"], prefix="/job") + LOGGER = logging.getLogger("uvicorn.error") +MONGODB_URI = "mongodb://localhost:27017/" +DATABASE_NAME = "florist-server" START_CLIENT_API = "api/client/start" +@app.on_event("startup") +def startup_db_client() -> None: + """Start up the MongoDB client.""" + app.mongodb_client = MongoClient(MONGODB_URI) # type: ignore[attr-defined] + app.database = app.mongodb_client[DATABASE_NAME] # type: ignore[attr-defined] + + +@app.on_event("shutdown") +def shutdown_db_client() -> None: + """Shut down the MongoDB client.""" + app.mongodb_client.close() # type: ignore[attr-defined] + + @app.post("/api/server/start_training") def start_training( model: Annotated[str, Form()], @@ -35,20 +54,20 @@ def start_training( Should be called with a POST request and the parameters should be contained in the request's form. :param model: (str) The name of the model to train. Should be one of the values in the enum - florist.api.servers.common.Model + `florist.api.servers.common.Model` :param server_address: (str) The address of the FL server to be started. 
It should be comprised of - the host name and port separated by colon (e.g. "localhost:8080") + the host name and port separated by colon (e.g. `localhost:8080`) :param n_server_rounds: (int) The number of rounds the FL server should run. :param batch_size: (int) The size of the batch for training :param local_epochs: (int) The number of epochs to run by the clients :param redis_host: (str) The host name for the Redis instance for metrics reporting. :param redis_port: (str) The port for the Redis instance for metrics reporting. :param clients_info: (str) A JSON string containing the client information. It will be parsed by - florist.api.servers.common.ClientInfo and should be in the following format: + `florist.api.servers.common.ClientInfo` and should be in the following format: [ { - "client": , - "client_address": , + "client": , + "client_address": , "data_path": , "redis_host": , "redis_port": , @@ -60,8 +79,8 @@ def start_training( "server_uuid": , "client_uuids": [, , ..., ], } - If not successful, returns the appropriate error code with a JSON with the format below: - {"error": } + If not successful, returns the appropriate error code with a JSON with the format: + `{"error": }` """ try: # Parse input data diff --git a/poetry.lock b/poetry.lock index 0c5d8ad3..bab05b8d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -954,6 +954,26 @@ files = [ {file = "dm_tree-0.1.8-cp39-cp39-win_amd64.whl", hash = "sha256:8ed3564abed97c806db122c2d3e1a2b64c74a63debe9903aad795167cc301368"}, ] +[[package]] +name = "dnspython" +version = "2.6.1" +description = "DNS toolkit" +optional = false +python-versions = ">=3.8" +files = [ + {file = "dnspython-2.6.1-py3-none-any.whl", hash = "sha256:5ef3b9680161f6fa89daf8ad451b5f1a33b18ae8a1c6778cdf4b43f08c0a6e50"}, + {file = "dnspython-2.6.1.tar.gz", hash = "sha256:e8f0f9c23a7b7cb99ded64e6c3a6f3e701d78f50c55e002b839dea7225cff7cc"}, +] + +[package.extras] +dev = ["black (>=23.1.0)", "coverage (>=7.0)", "flake8 (>=7)", "mypy (>=1.8)", 
"pylint (>=3)", "pytest (>=7.4)", "pytest-cov (>=4.1.0)", "sphinx (>=7.2.0)", "twine (>=4.0.0)", "wheel (>=0.42.0)"] +dnssec = ["cryptography (>=41)"] +doh = ["h2 (>=4.1.0)", "httpcore (>=1.0.0)", "httpx (>=0.26.0)"] +doq = ["aioquic (>=0.9.25)"] +idna = ["idna (>=3.6)"] +trio = ["trio (>=0.23)"] +wmi = ["wmi (>=1.5.1)"] + [[package]] name = "docker-pycreds" version = "0.4.0" @@ -3346,47 +3366,47 @@ xgboost = ["xgboost (>=1.5.2,<2.0.0)"] [[package]] name = "pydantic" -version = "1.10.14" +version = "1.10.15" description = "Data validation and settings management using python type hints" optional = false python-versions = ">=3.7" files = [ - {file = "pydantic-1.10.14-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7f4fcec873f90537c382840f330b90f4715eebc2bc9925f04cb92de593eae054"}, - {file = "pydantic-1.10.14-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e3a76f571970fcd3c43ad982daf936ae39b3e90b8a2e96c04113a369869dc87"}, - {file = "pydantic-1.10.14-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82d886bd3c3fbeaa963692ef6b643159ccb4b4cefaf7ff1617720cbead04fd1d"}, - {file = "pydantic-1.10.14-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:798a3d05ee3b71967844a1164fd5bdb8c22c6d674f26274e78b9f29d81770c4e"}, - {file = "pydantic-1.10.14-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:23d47a4b57a38e8652bcab15a658fdb13c785b9ce217cc3a729504ab4e1d6bc9"}, - {file = "pydantic-1.10.14-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f9f674b5c3bebc2eba401de64f29948ae1e646ba2735f884d1594c5f675d6f2a"}, - {file = "pydantic-1.10.14-cp310-cp310-win_amd64.whl", hash = "sha256:24a7679fab2e0eeedb5a8924fc4a694b3bcaac7d305aeeac72dd7d4e05ecbebf"}, - {file = "pydantic-1.10.14-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9d578ac4bf7fdf10ce14caba6f734c178379bd35c486c6deb6f49006e1ba78a7"}, - {file = "pydantic-1.10.14-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:fa7790e94c60f809c95602a26d906eba01a0abee9cc24150e4ce2189352deb1b"}, - {file = "pydantic-1.10.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aad4e10efa5474ed1a611b6d7f0d130f4aafadceb73c11d9e72823e8f508e663"}, - {file = "pydantic-1.10.14-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1245f4f61f467cb3dfeced2b119afef3db386aec3d24a22a1de08c65038b255f"}, - {file = "pydantic-1.10.14-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:21efacc678a11114c765eb52ec0db62edffa89e9a562a94cbf8fa10b5db5c046"}, - {file = "pydantic-1.10.14-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:412ab4a3f6dbd2bf18aefa9f79c7cca23744846b31f1d6555c2ee2b05a2e14ca"}, - {file = "pydantic-1.10.14-cp311-cp311-win_amd64.whl", hash = "sha256:e897c9f35281f7889873a3e6d6b69aa1447ceb024e8495a5f0d02ecd17742a7f"}, - {file = "pydantic-1.10.14-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d604be0f0b44d473e54fdcb12302495fe0467c56509a2f80483476f3ba92b33c"}, - {file = "pydantic-1.10.14-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a42c7d17706911199798d4c464b352e640cab4351efe69c2267823d619a937e5"}, - {file = "pydantic-1.10.14-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:596f12a1085e38dbda5cbb874d0973303e34227b400b6414782bf205cc14940c"}, - {file = "pydantic-1.10.14-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:bfb113860e9288d0886e3b9e49d9cf4a9d48b441f52ded7d96db7819028514cc"}, - {file = "pydantic-1.10.14-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:bc3ed06ab13660b565eed80887fcfbc0070f0aa0691fbb351657041d3e874efe"}, - {file = "pydantic-1.10.14-cp37-cp37m-win_amd64.whl", hash = "sha256:ad8c2bc677ae5f6dbd3cf92f2c7dc613507eafe8f71719727cbc0a7dec9a8c01"}, - {file = "pydantic-1.10.14-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c37c28449752bb1f47975d22ef2882d70513c546f8f37201e0fec3a97b816eee"}, - {file = 
"pydantic-1.10.14-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49a46a0994dd551ec051986806122767cf144b9702e31d47f6d493c336462597"}, - {file = "pydantic-1.10.14-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53e3819bd20a42470d6dd0fe7fc1c121c92247bca104ce608e609b59bc7a77ee"}, - {file = "pydantic-1.10.14-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0fbb503bbbbab0c588ed3cd21975a1d0d4163b87e360fec17a792f7d8c4ff29f"}, - {file = "pydantic-1.10.14-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:336709883c15c050b9c55a63d6c7ff09be883dbc17805d2b063395dd9d9d0022"}, - {file = "pydantic-1.10.14-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4ae57b4d8e3312d486e2498d42aed3ece7b51848336964e43abbf9671584e67f"}, - {file = "pydantic-1.10.14-cp38-cp38-win_amd64.whl", hash = "sha256:dba49d52500c35cfec0b28aa8b3ea5c37c9df183ffc7210b10ff2a415c125c4a"}, - {file = "pydantic-1.10.14-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c66609e138c31cba607d8e2a7b6a5dc38979a06c900815495b2d90ce6ded35b4"}, - {file = "pydantic-1.10.14-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d986e115e0b39604b9eee3507987368ff8148222da213cd38c359f6f57b3b347"}, - {file = "pydantic-1.10.14-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:646b2b12df4295b4c3148850c85bff29ef6d0d9621a8d091e98094871a62e5c7"}, - {file = "pydantic-1.10.14-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:282613a5969c47c83a8710cc8bfd1e70c9223feb76566f74683af889faadc0ea"}, - {file = "pydantic-1.10.14-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:466669501d08ad8eb3c4fecd991c5e793c4e0bbd62299d05111d4f827cded64f"}, - {file = "pydantic-1.10.14-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:13e86a19dca96373dcf3190fcb8797d40a6f12f154a244a8d1e8e03b8f280593"}, - {file = "pydantic-1.10.14-cp39-cp39-win_amd64.whl", hash = 
"sha256:08b6ec0917c30861e3fe71a93be1648a2aa4f62f866142ba21670b24444d7fd8"}, - {file = "pydantic-1.10.14-py3-none-any.whl", hash = "sha256:8ee853cd12ac2ddbf0ecbac1c289f95882b2d4482258048079d13be700aa114c"}, - {file = "pydantic-1.10.14.tar.gz", hash = "sha256:46f17b832fe27de7850896f3afee50ea682220dd218f7e9c88d436788419dca6"}, + {file = "pydantic-1.10.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:22ed12ee588b1df028a2aa5d66f07bf8f8b4c8579c2e96d5a9c1f96b77f3bb55"}, + {file = "pydantic-1.10.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:75279d3cac98186b6ebc2597b06bcbc7244744f6b0b44a23e4ef01e5683cc0d2"}, + {file = "pydantic-1.10.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50f1666a9940d3d68683c9d96e39640f709d7a72ff8702987dab1761036206bb"}, + {file = "pydantic-1.10.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82790d4753ee5d00739d6cb5cf56bceb186d9d6ce134aca3ba7befb1eedbc2c8"}, + {file = "pydantic-1.10.15-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:d207d5b87f6cbefbdb1198154292faee8017d7495a54ae58db06762004500d00"}, + {file = "pydantic-1.10.15-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e49db944fad339b2ccb80128ffd3f8af076f9f287197a480bf1e4ca053a866f0"}, + {file = "pydantic-1.10.15-cp310-cp310-win_amd64.whl", hash = "sha256:d3b5c4cbd0c9cb61bbbb19ce335e1f8ab87a811f6d589ed52b0254cf585d709c"}, + {file = "pydantic-1.10.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c3d5731a120752248844676bf92f25a12f6e45425e63ce22e0849297a093b5b0"}, + {file = "pydantic-1.10.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c365ad9c394f9eeffcb30a82f4246c0006417f03a7c0f8315d6211f25f7cb654"}, + {file = "pydantic-1.10.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3287e1614393119c67bd4404f46e33ae3be3ed4cd10360b48d0a4459f420c6a3"}, + {file = 
"pydantic-1.10.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be51dd2c8596b25fe43c0a4a59c2bee4f18d88efb8031188f9e7ddc6b469cf44"}, + {file = "pydantic-1.10.15-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6a51a1dd4aa7b3f1317f65493a182d3cff708385327c1c82c81e4a9d6d65b2e4"}, + {file = "pydantic-1.10.15-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4e316e54b5775d1eb59187f9290aeb38acf620e10f7fd2f776d97bb788199e53"}, + {file = "pydantic-1.10.15-cp311-cp311-win_amd64.whl", hash = "sha256:0d142fa1b8f2f0ae11ddd5e3e317dcac060b951d605fda26ca9b234b92214986"}, + {file = "pydantic-1.10.15-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7ea210336b891f5ea334f8fc9f8f862b87acd5d4a0cbc9e3e208e7aa1775dabf"}, + {file = "pydantic-1.10.15-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3453685ccd7140715e05f2193d64030101eaad26076fad4e246c1cc97e1bb30d"}, + {file = "pydantic-1.10.15-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bea1f03b8d4e8e86702c918ccfd5d947ac268f0f0cc6ed71782e4b09353b26f"}, + {file = "pydantic-1.10.15-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:005655cabc29081de8243126e036f2065bd7ea5b9dff95fde6d2c642d39755de"}, + {file = "pydantic-1.10.15-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:af9850d98fc21e5bc24ea9e35dd80a29faf6462c608728a110c0a30b595e58b7"}, + {file = "pydantic-1.10.15-cp37-cp37m-win_amd64.whl", hash = "sha256:d31ee5b14a82c9afe2bd26aaa405293d4237d0591527d9129ce36e58f19f95c1"}, + {file = "pydantic-1.10.15-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5e09c19df304b8123938dc3c53d3d3be6ec74b9d7d0d80f4f4b5432ae16c2022"}, + {file = "pydantic-1.10.15-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7ac9237cd62947db00a0d16acf2f3e00d1ae9d3bd602b9c415f93e7a9fc10528"}, + {file = "pydantic-1.10.15-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:584f2d4c98ffec420e02305cf675857bae03c9d617fcfdc34946b1160213a948"}, + {file = "pydantic-1.10.15-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbc6989fad0c030bd70a0b6f626f98a862224bc2b1e36bfc531ea2facc0a340c"}, + {file = "pydantic-1.10.15-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d573082c6ef99336f2cb5b667b781d2f776d4af311574fb53d908517ba523c22"}, + {file = "pydantic-1.10.15-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6bd7030c9abc80134087d8b6e7aa957e43d35714daa116aced57269a445b8f7b"}, + {file = "pydantic-1.10.15-cp38-cp38-win_amd64.whl", hash = "sha256:3350f527bb04138f8aff932dc828f154847fbdc7a1a44c240fbfff1b57f49a12"}, + {file = "pydantic-1.10.15-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:51d405b42f1b86703555797270e4970a9f9bd7953f3990142e69d1037f9d9e51"}, + {file = "pydantic-1.10.15-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a980a77c52723b0dc56640ced396b73a024d4b74f02bcb2d21dbbac1debbe9d0"}, + {file = "pydantic-1.10.15-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67f1a1fb467d3f49e1708a3f632b11c69fccb4e748a325d5a491ddc7b5d22383"}, + {file = "pydantic-1.10.15-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:676ed48f2c5bbad835f1a8ed8a6d44c1cd5a21121116d2ac40bd1cd3619746ed"}, + {file = "pydantic-1.10.15-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:92229f73400b80c13afcd050687f4d7e88de9234d74b27e6728aa689abcf58cc"}, + {file = "pydantic-1.10.15-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2746189100c646682eff0bce95efa7d2e203420d8e1c613dc0c6b4c1d9c1fde4"}, + {file = "pydantic-1.10.15-cp39-cp39-win_amd64.whl", hash = "sha256:394f08750bd8eaad714718812e7fab615f873b3cdd0b9d84e76e51ef3b50b6b7"}, + {file = "pydantic-1.10.15-py3-none-any.whl", hash = "sha256:28e552a060ba2740d0d2aabe35162652c1459a0b9069fe0db7f4ee0e18e74d58"}, + {file = "pydantic-1.10.15.tar.gz", hash = 
"sha256:ca832e124eda231a60a041da4f013e3ff24949d94a01154b137fc2f2a43c3ffb"}, ] [package.dependencies] @@ -3448,6 +3468,109 @@ typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\"" spelling = ["pyenchant (>=3.2,<4.0)"] testutils = ["gitpython (>3)"] +[[package]] +name = "pymongo" +version = "4.6.3" +description = "Python driver for MongoDB " +optional = false +python-versions = ">=3.7" +files = [ + {file = "pymongo-4.6.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e344d0afdd7c06c1f1e66a4736593293f432defc2191e6b411fc9c82fa8c5adc"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux1_i686.whl", hash = "sha256:731a92dfc4022db763bfa835c6bd160f2d2cba6ada75749c2ed500e13983414b"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:c4726e36a2f7e92f09f5b8e92ba4db7525daffe31a0dcbcf0533edc0ade8c7d8"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux2014_i686.whl", hash = "sha256:00e6cfce111883ca63a3c12878286e0b89871f4b840290e61fb6f88ee0e687be"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux2014_ppc64le.whl", hash = "sha256:cc7a26edf79015c58eea46feb5b262cece55bc1d4929a8a9e0cbe7e6d6a9b0eb"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux2014_s390x.whl", hash = "sha256:4955be64d943b30f2a7ff98d818ca530f7cb37450bc6b32c37e0e74821907ef8"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:af039afc6d787502c02089759778b550cb2f25dbe2780f5b050a2e37031c3fbf"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ccc15a7c7a99aed7d0831eaf78a607f1db0c7a255f96e3d18984231acd72f70c"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e97c138d811e9367723fcd07c4402a9211caae20479fdd6301d57762778a69f"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ebcc145c74d06296ce0cad35992185064e5cb2aadef719586778c144f0cd4d37"}, + {file = 
"pymongo-4.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:664c64b6bdb31aceb80f0556951e5e2bf50d359270732268b4e7af00a1cf5d6c"}, + {file = "pymongo-4.6.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4056bc421d4df2c61db4e584415f2b0f1eebb92cbf9222f7f38303467c37117"}, + {file = "pymongo-4.6.3-cp310-cp310-win32.whl", hash = "sha256:cdbea2aac1a4caa66ee912af3601557d2bda2f9f69feec83601c78c7e53ece64"}, + {file = "pymongo-4.6.3-cp310-cp310-win_amd64.whl", hash = "sha256:6cec7279e5a1b74b257d0270a8c97943d745811066630a6bc6beb413c68c6a33"}, + {file = "pymongo-4.6.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:138b9fa18d40401c217bc038a48bcde4160b02d36d8632015b1804971a2eaa2f"}, + {file = "pymongo-4.6.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60931b0e07448afe8866ffff764cd5bf4b1a855dc84c7dcb3974c6aa6a377a59"}, + {file = "pymongo-4.6.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9b35f8bded43ff91475305445fedf0613f880ff7e25c75ae1028e1260a9b7a86"}, + {file = "pymongo-4.6.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:872bad5c83f7eec9da11e1fef5f858c6a4c79fe4a83c7780e7b0fe95d560ae3f"}, + {file = "pymongo-4.6.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2ad3e5bfcd345c0bfe9af69a82d720860b5b043c1657ffb513c18a0dee19c19"}, + {file = "pymongo-4.6.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0e208f2ab7b495eff8fd175022abfb0abce6307ac5aee3f4de51fc1a459b71c9"}, + {file = "pymongo-4.6.3-cp311-cp311-win32.whl", hash = "sha256:4670edbb5ddd71a4d555668ef99b032a5f81b59e4145d66123aa0d831eac7883"}, + {file = "pymongo-4.6.3-cp311-cp311-win_amd64.whl", hash = "sha256:1c2761302b6cbfd12e239ce1b8061d4cf424a361d199dcb32da534985cae9350"}, + {file = "pymongo-4.6.3-cp312-cp312-macosx_10_9_universal2.whl", hash = 
"sha256:722f2b709b63311c0efda4fa4c603661faa4bec6bad24a6cc41a3bc6d841bf09"}, + {file = "pymongo-4.6.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:994386a4d6ad39e18bcede6dc8d1d693ec3ed897b88f86b1841fbc37227406da"}, + {file = "pymongo-4.6.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:391aea047bba928006114282f175bc8d09c53fe1b7d8920bf888325e229302fe"}, + {file = "pymongo-4.6.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4330c022024e7994b630199cdae909123e4b0e9cf15335de71b146c0f6a2435"}, + {file = "pymongo-4.6.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01277a7e183c59081368e4efbde2b8f577014431b257959ca98d3a4e8682dd51"}, + {file = "pymongo-4.6.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d30d5d7963453b478016bf7b0d87d7089ca24d93dbdecfbc9aa32f1b4772160a"}, + {file = "pymongo-4.6.3-cp312-cp312-win32.whl", hash = "sha256:a023804a3ac0f85d4510265b60978522368b5815772262e61e3a2222a8b315c9"}, + {file = "pymongo-4.6.3-cp312-cp312-win_amd64.whl", hash = "sha256:2a6ae9a600bbc2dbff719c98bf5da584fb8a4f2bb23729a09be2e9c3dbc61c8a"}, + {file = "pymongo-4.6.3-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:3b909e5b1864de01510079b39bbdc480720c37747be5552b354bc73f02c24a3c"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:48c60bd32ec141c0d45d8471179430003d9fb4490da181b8165fb1dce9cc255c"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:36d7049fc183fe4edda3eae7f66ea14c660921429e082fe90b4b7f4dc6664a70"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:18e5c161b18660f1c9d1f78236de45520a436be65e42b7bb51f25f74ad22bdde"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:e458e6fc2b7dd40d15cda04898bd2d8c9ff7ae086c516bc261628d54eb4e3158"}, + {file = 
"pymongo-4.6.3-cp37-cp37m-manylinux2014_ppc64le.whl", hash = "sha256:e420e74c6db4594a6d09f39b58c0772679006cb0b4fc40901ba608794d87dad2"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux2014_s390x.whl", hash = "sha256:9c9340c7161e112e36ebb97fbba1cdbe7db3dfacb694d2918b1f155a01f3d859"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:26d036e0f5de09d0b21d0fc30314fcf2ae6359e4d43ae109aa6cf27b4ce02d30"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7cf28d9c90e40d4e385b858e4095739829f466f23e08674085161d86bb4bb10"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9066dff9dc0a182478ca5885d0b8a2b820b462e19459ada109df7a3ced31b272"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e1e1586ebdebe0447a24842480defac17c496430a218486c96e2da3f164c0f05"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b3853fb66bf34ce1b6e573e1bbb3cb28763be9d1f57758535757faf1ab2f24a"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:462684a6f5ce6f2661c30eab4d1d459231e0eed280f338e716e31a24fc09ccb3"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0a4ea44e5a913bdb7c9abd34c69e9fcfac10dfaf49765463e0dc1ea922dd2a9d"}, + {file = "pymongo-4.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:098d420a8214ad25f872de7e8b309441995d12ece0376218a04d9ed5d2222cf3"}, + {file = "pymongo-4.6.3-cp37-cp37m-win32.whl", hash = "sha256:7330245253fbe2e09845069d2f4d35dd27f63e377034c94cb0ddac18bc8b0d82"}, + {file = "pymongo-4.6.3-cp37-cp37m-win_amd64.whl", hash = "sha256:151361c101600a85cb1c1e0db4e4b28318b521fcafa9b62d389f7342faaaee80"}, + {file = "pymongo-4.6.3-cp38-cp38-macosx_11_0_universal2.whl", hash = 
"sha256:4d167d546352869125dc86f6fda6dffc627d8a9c8963eaee665825f2520d542b"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux1_i686.whl", hash = "sha256:eaf3d594ebfd5e1f3503d81e06a5d78e33cda27418b36c2491c3d4ad4fca5972"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:7ee79e02a7c5ed34706ecb5dad19e6c7d267cf86d28c075ef3127c58f3081279"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:af5c5112db04cf62a5d9d224a24f289aaecb47d152c08a457cca81cee061d5bd"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:6b5aec78aa4840e8d6c3881900259892ab5733a366696ca10d99d68c3d73eaaf"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux2014_ppc64le.whl", hash = "sha256:9757602fb45c8ecc1883fe6db7c59c19d87eb3c645ec9342d28a6026837da931"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux2014_s390x.whl", hash = "sha256:dde9fb6e105ce054339256a8b7a9775212ebb29596ef4e402d7bbc63b354d202"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:7df8b166d3db6cfead4cf55b481408d8f0935d8bd8d6dbf64507c49ef82c7200"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53451190b8628e1ce7d1fe105dc376c3f10705127bd3b51fe3e107b9ff1851e6"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:75107a386d4ccf5291e75cce8ca3898430e7907f4cc1208a17c9efad33a1ea84"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4a0660ce32d8459b7f12dc3ca0141528fead62d3cce31b548f96f30902074cc0"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa310096450e9c461b7dfd66cbc1c41771fe36c06200440bb3e062b1d4a06b6e"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5f465cca9b178e7bb782f952dd58e9e92f8ba056e585959465f2bb50feddef5f"}, + {file = 
"pymongo-4.6.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c67c19f653053ef2ebd7f1837c2978400058d6d7f66ec5760373a21eaf660158"}, + {file = "pymongo-4.6.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c701de8e483fb5e53874aab642235361aac6de698146b02c644389eaa8c137b6"}, + {file = "pymongo-4.6.3-cp38-cp38-win32.whl", hash = "sha256:90525454546536544307e6da9c81f331a71a1b144e2d038fec587cc9f9250285"}, + {file = "pymongo-4.6.3-cp38-cp38-win_amd64.whl", hash = "sha256:3e1ba5a037c526a3f4060c28f8d45d71ed9626e2bf954b0cd9a8dcc3b45172ee"}, + {file = "pymongo-4.6.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:14a82593528cddc93cfea5ee78fac95ae763a3a4e124ca79ee0b24fbbc6da1c9"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux1_i686.whl", hash = "sha256:cd6c15242d9306ff1748681c3235284cbe9f807aeaa86cd17d85e72af626e9a7"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:6de33f1b2eed91b802ec7abeb92ffb981d052f3604b45588309aae9e0f6e3c02"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:0182899aafe830f25cf96c5976d724efeaaf7b6646c15424ad8dd25422b2efe1"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux2014_i686.whl", hash = "sha256:8d0ea740a2faa56f930dc82c5976d96c017ece26b29a1cddafb58721c7aab960"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux2014_ppc64le.whl", hash = "sha256:5c8a4982f5eb767c6fbfb8fb378683d09bcab7c3251ba64357eef600d43f6c23"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux2014_s390x.whl", hash = "sha256:becfa816545a48c8e740ac2fd624c1c121e1362072d68ffcf37a6b1be8ea187e"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:ff7d1f449fcad23d9bc8e8dc2b9972be38bcd76d99ea5f7d29b2efa929c2a7ff"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e097f877de4d6af13a33ef938bf2a2350f424be5deabf8b857da95f5b080487a"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:705a9bfd619301ee7e985d6f91f68b15dfcb2f6f36b8cc225cc82d4260d2bce5"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2ef1b4992ee1cb8bb16745e70afa0c02c5360220a7a8bb4775888721f052d0a6"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3d10bdd46cbc35a2109737d36ffbef32e7420569a87904738ad444ccb7ac2c5"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:17c1c143ba77d6e21fc8b48e93f0a5ed982a23447434e9ee4fbb6d633402506b"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9e51e30d67b468a2a634ade928b30cb3e420127f148a9aec60de33f39087bdc4"}, + {file = "pymongo-4.6.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:bec8e4e88984be157408f1923d25869e1b575c07711cdbdde596f66931800934"}, + {file = "pymongo-4.6.3-cp39-cp39-win32.whl", hash = "sha256:98877a9c4ad42df8253a12d8d17a3265781d1feb5c91c767bd153f88feb0b670"}, + {file = "pymongo-4.6.3-cp39-cp39-win_amd64.whl", hash = "sha256:6d5b35da9e16cda630baed790ffc3d0d01029d269523a7cec34d2ec7e6823e75"}, + {file = "pymongo-4.6.3.tar.gz", hash = "sha256:400074090b9a631f120b42c61b222fd743490c133a5d2f99c0208cefcccc964e"}, +] + +[package.dependencies] +dnspython = ">=1.16.0,<3.0.0" + +[package.extras] +aws = ["pymongo-auth-aws (<2.0.0)"] +encryption = ["certifi", "pymongo[aws]", "pymongocrypt (>=1.6.0,<2.0.0)"] +gssapi = ["pykerberos", "winkerberos (>=0.5.0)"] +ocsp = ["certifi", "cryptography (>=2.5)", "pyopenssl (>=17.2.0)", "requests (<3.0.0)", "service-identity (>=18.1.0)"] +snappy = ["python-snappy"] +test = ["pytest (>=7)"] +zstd = ["zstandard"] + [[package]] name = "pyparsing" version = "3.1.1" @@ -5527,4 +5650,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.11" -content-hash = 
"7cbd2fbe52c7deb241805f1dc3fa1044512bd069ba8482d930a01c28680b454f" +content-hash = "101ed669ff9ee827aee599343ce6fd6470c42d5d5af2f2e6dbc4ed7893a5b5e3" diff --git a/pyproject.toml b/pyproject.toml index bd7f7d35..b039e97c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ wandb = "^0.16.3" torchvision = "0.14.1" redis = "^5.0.1" python-multipart = "^0.0.9" +pymongo = {extras = ["srv"], version = "^4.6.3"} +pydantic = "^1.10.15" [tool.poetry.group.test] optional = true From eae71a67c9100cb264c5b1b4b8b11f6679523676 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Mon, 8 Apr 2024 13:14:06 -0400 Subject: [PATCH 12/24] WIP adding a test, need more setup --- .github/workflows/integration_tests.yaml | 4 ++ CONTRIBUTING.md | 13 +++-- florist/api/server.py | 4 +- florist/tests/integration/api/test_routes.py | 58 ++++++++++++++++++++ 4 files changed, 72 insertions(+), 7 deletions(-) create mode 100644 florist/tests/integration/api/test_routes.py diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index 354f1df3..181274a3 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -46,6 +46,10 @@ jobs: uses: supercharge/redis-github-action@1.2.0 with: redis-version: 7.2.4 + - name: Setup MongoDB + uses: supercharge/mongodb-github-action@1.10.0 + with: + mongodb-version: 7.0.8 - name: Install dependencies and check code run: | poetry env use '3.9' diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 95d4d3b6..5e629821 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -60,8 +60,11 @@ To run the unit tests, simply execute: pytest florist/tests/unit ``` -To run the integration tests, first make sure you have a Redis server running on your -local machine on port 6379, then execute: +To run the integration tests, first make sure you: +- Have a Redis server running on your local machine on port 6379 by following [these instructions](README.md#start-servers-redis-instance). 
+- Have a MongoDB server running on your local machine on port 27017 by following [these instructions](README.md#start-mongodbs-instance). + +Then execute: ```shell pytest florist/tests/integration ``` @@ -73,8 +76,8 @@ For code style, we recommend the [PEP 8 style guide](https://peps.python.org/pep For docstrings we use [numpy format](https://numpydoc.readthedocs.io/en/latest/format.html). We use [ruff](https://docs.astral.sh/ruff/) for code formatting and static code -analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ruff/faq/#how-does-ruff-compare-to-flake8). The pre-commit hooks show -errors which you need to fix before submitting a PR. +analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ruff/faq/#how-does-ruff-compare-to-flake8). The pre-commit hooks +show errors which you need to fix before submitting a PR. Last but not the least, we use type hints in our code which is then checked using [mypy](https://mypy.readthedocs.io/en/stable/). @@ -84,4 +87,4 @@ Last but not the least, we use type hints in our code which is then checked usin Backend code API documentation can be found at https://vectorinstitute.github.io/FLorist/. Backend REST API documentation can be found at https://localhost:8000/docs once the server -is running. +is running locally. 
diff --git a/florist/api/server.py b/florist/api/server.py index 440b2590..8669caaa 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -26,14 +26,14 @@ @app.on_event("startup") def startup_db_client() -> None: - """Start up the MongoDB client.""" + """Start up the MongoDB client on app startup.""" app.mongodb_client = MongoClient(MONGODB_URI) # type: ignore[attr-defined] app.database = app.mongodb_client[DATABASE_NAME] # type: ignore[attr-defined] @app.on_event("shutdown") def shutdown_db_client() -> None: - """Shut down the MongoDB client.""" + """Shut down the MongoDB client on app shutdown.""" app.mongodb_client.close() # type: ignore[attr-defined] diff --git a/florist/tests/integration/api/test_routes.py b/florist/tests/integration/api/test_routes.py new file mode 100644 index 00000000..6c505e4b --- /dev/null +++ b/florist/tests/integration/api/test_routes.py @@ -0,0 +1,58 @@ +from unittest.mock import ANY + +from pymongo import MongoClient +from starlette.requests import Request + +from florist.api.db.entities import Job +from florist.api.routes.job import new_job +from florist.api.server import MONGODB_URI +from florist.api.servers.common import Model + + +DATABASE_NAME = "test-database" + + +def test_new_job() -> None: + test_empty_job = Job() + result = new_job(MockRequest(), test_empty_job) + + assert result == { + "_id": ANY, + "model": None, + "redis_host": None, + "redis_port": None, + } + assert isinstance(result["_id"], str) + + test_job = Job(id="test-id", model=Model.MNIST, redis_host="test-redis-host", redis_port="test-redis-port") + result = new_job(MockRequest(), test_job) + + assert result == { + "_id": test_job.id, + "model": test_job.model.value, + "redis_host": test_job.redis_host, + "redis_port": test_job.redis_port, + } + assert isinstance(result["_id"], str) + + +# TODO delete database at the end + +class MockApp(): + def __init__(self): + mongo_client = MongoClient(MONGODB_URI) + self.database = mongo_client[DATABASE_NAME] 
+ + +class MockRequest(Request): + def __init__(self): + super().__init__({"type": "http"}) + self._app = MockApp() + + @property + def app(self): + return self._app + + @app.setter + def app(self, value): + self._app = value From 0bfc585070da71cb78385fd735d959c23c05b1c1 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 9 Apr 2024 10:55:19 -0400 Subject: [PATCH 13/24] Small change --- florist/tests/integration/api/test_routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/florist/tests/integration/api/test_routes.py b/florist/tests/integration/api/test_routes.py index 6c505e4b..ada2adbf 100644 --- a/florist/tests/integration/api/test_routes.py +++ b/florist/tests/integration/api/test_routes.py @@ -38,7 +38,7 @@ def test_new_job() -> None: # TODO delete database at the end -class MockApp(): +class MockApp: def __init__(self): mongo_client = MongoClient(MONGODB_URI) self.database = mongo_client[DATABASE_NAME] From e99ea7f8db477556a6112233552ba91884fff43c Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 9 Apr 2024 11:32:34 -0400 Subject: [PATCH 14/24] CR by John --- .gitignore | 1 + CONTRIBUTING.md | 7 ++ README.md | 5 + florist/api/monitoring/metrics.py | 4 +- florist/api/routes/__init__.py | 1 + florist/api/routes/server/__init__.py | 1 + florist/api/routes/server/training.py | 119 ++++++++++++++++++ florist/api/server.py | 118 +---------------- florist/api/servers/launch.py | 4 +- florist/tests/__init__.py | 0 florist/tests/integration/api/test_train.py | 10 +- florist/tests/integration/api/utils.py | 2 +- florist/tests/unit/api/routes/__init__.py | 0 .../tests/unit/api/routes/server/__init__.py | 0 .../server/test_training.py} | 44 +++---- 15 files changed, 170 insertions(+), 146 deletions(-) create mode 100644 florist/api/routes/__init__.py create mode 100644 florist/api/routes/server/__init__.py create mode 100644 florist/api/routes/server/training.py create mode 100644 florist/tests/__init__.py create mode 100644 
florist/tests/unit/api/routes/__init__.py create mode 100644 florist/tests/unit/api/routes/server/__init__.py rename florist/tests/unit/api/{test_server.py => routes/server/test_training.py} (92%) diff --git a/.gitignore b/.gitignore index 7d2c7f53..74604df4 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,4 @@ next-env.d.ts /metrics/ /logs/ +/.ruff_cache/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 09cc5a54..f4d876b2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -77,3 +77,10 @@ analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ru Last but not the least, we use type hints in our code which is then checked using [mypy](https://mypy.readthedocs.io/en/stable/). + +## Documentation + +Backend code API documentation can be found at https://vectorinstitute.github.io/FLorist/. + +Backend REST API documentation can be found at https://localhost:8000/docs once the server +is running locally. diff --git a/README.md b/README.md index 99a320fc..76201033 100644 --- a/README.md +++ b/README.md @@ -95,3 +95,8 @@ uvicorn florist.api.client:app --reload --port 8001 ``` The service will be available at `http://localhost:8001`. + +# Contributing +If you are interested in contributing to the library, please see [CONTRIBUTING.MD](CONTRIBUTING.md). +This file contains many details around contributing to the code base, including development +practices, code checks, tests, and more. diff --git a/florist/api/monitoring/metrics.py b/florist/api/monitoring/metrics.py index a04e7b01..f03afbf8 100644 --- a/florist/api/monitoring/metrics.py +++ b/florist/api/monitoring/metrics.py @@ -80,7 +80,7 @@ def wait_for_metric( Check metrics on Redis under the given UUID and wait until it appears. If the metrics are not there yet, it will retry up to max_retries times, - sleeping and amount of seconds_to_sleep_between_retries between them. + sleeping an amount of `seconds_to_sleep_between_retries` between them. 
:param uuid: (str) The UUID to pull the metrics from Redis. :param metric: (str) The metric to look for. @@ -90,7 +90,7 @@ def wait_for_metric( :param max_retries: (int) The maximum number of retries. Optional, default is 20. :param seconds_to_sleep_between_retries: (int) The amount of seconds to sleep between retries. Optional, default is 1. - :raises Exception: If it retries MAX_RETRIES times and the right metrics have not been found. + :raises Exception: If it retries `max_retries` times and the right metrics have not been found. """ redis_connection = redis.Redis(host=redis_host, port=redis_port) diff --git a/florist/api/routes/__init__.py b/florist/api/routes/__init__.py new file mode 100644 index 00000000..95b12ef8 --- /dev/null +++ b/florist/api/routes/__init__.py @@ -0,0 +1 @@ +"""FastAPI routes.""" diff --git a/florist/api/routes/server/__init__.py b/florist/api/routes/server/__init__.py new file mode 100644 index 00000000..08b764d2 --- /dev/null +++ b/florist/api/routes/server/__init__.py @@ -0,0 +1 @@ +"""FastAPI server routes.""" diff --git a/florist/api/routes/server/training.py b/florist/api/routes/server/training.py new file mode 100644 index 00000000..75101f8e --- /dev/null +++ b/florist/api/routes/server/training.py @@ -0,0 +1,119 @@ +"""FastAPI routes for training.""" +import logging +from typing import List + +import requests +from fastapi import APIRouter, Form +from fastapi.responses import JSONResponse +from typing_extensions import Annotated + +from florist.api.monitoring.metrics import wait_for_metric +from florist.api.servers.common import ClientInfo, ClientInfoParseError, Model +from florist.api.servers.launch import launch_local_server + + +router = APIRouter() + +LOGGER = logging.getLogger("uvicorn.error") + +START_CLIENT_API = "api/client/start" + + +@router.post("/start") +def start( + model: Annotated[str, Form()], + server_address: Annotated[str, Form()], + n_server_rounds: Annotated[int, Form()], + batch_size: Annotated[int, 
Form()], + local_epochs: Annotated[int, Form()], + redis_host: Annotated[str, Form()], + redis_port: Annotated[str, Form()], + clients_info: Annotated[str, Form()], +) -> JSONResponse: + """ + Start FL training by starting a FL server and its clients. + + Should be called with a POST request and the parameters should be contained in the request's form. + + :param model: (str) The name of the model to train. Should be one of the values in the enum + florist.api.servers.common.Model + :param server_address: (str) The address of the FL server to be started. It should be comprised of + the host name and port separated by colon (e.g. "localhost:8080") + :param n_server_rounds: (int) The number of rounds the FL server should run. + :param batch_size: (int) The size of the batch for training. + :param local_epochs: (int) The number of epochs to run by the clients. + :param redis_host: (str) The host name for the Redis instance for metrics reporting. + :param redis_port: (str) The port for the Redis instance for metrics reporting. + :param clients_info: (str) A JSON string containing the client information. It will be parsed by + florist.api.servers.common.ClientInfo and should be in the following format: + [ + { + "client": , + "client_address": , + "data_path": , + "redis_host": , + "redis_port": , + } + ] + :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and + the clients in the format below. The UUIDs can be used to pull metrics from Redis. + { + "server_uuid": , + "client_uuids": [, , ..., ], + } + If not successful, returns the appropriate error code with a JSON with the format below: + {"error": } + """ + try: + # Parse input data + if model not in Model.list(): + error_msg = f"Model '{model}' not supported. 
Supported models: {Model.list()}" + return JSONResponse(content={"error": error_msg}, status_code=400) + + model_class = Model.class_for_model(Model[model]) + clients_info_list = ClientInfo.parse(clients_info) + + # Start the server + server_uuid, _ = launch_local_server( + model=model_class(), + n_clients=len(clients_info_list), + server_address=server_address, + n_server_rounds=n_server_rounds, + batch_size=batch_size, + local_epochs=local_epochs, + redis_host=redis_host, + redis_port=redis_port, + ) + wait_for_metric(server_uuid, "fit_start", redis_host, redis_port, logger=LOGGER) + + # Start the clients + client_uuids: List[str] = [] + for client_info in clients_info_list: + parameters = { + "server_address": server_address, + "client": client_info.client.value, + "data_path": client_info.data_path, + "redis_host": client_info.redis_host, + "redis_port": client_info.redis_port, + } + response = requests.get(url=f"http://{client_info.client_address}/{START_CLIENT_API}", params=parameters) + json_response = response.json() + LOGGER.debug(f"Client response: {json_response}") + + if response.status_code != 200: + raise Exception(f"Client response returned {response.status_code}. Response: {json_response}") + + if "uuid" not in json_response: + raise Exception(f"Client response did not return a UUID. 
Response: {json_response}") + + client_uuids.append(json_response["uuid"]) + + # Return the UUIDs + return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids}) + + except (ValueError, ClientInfoParseError) as ex: + return JSONResponse(content={"error": str(ex)}, status_code=400) + + except Exception as ex: + LOGGER.exception(ex) + return JSONResponse({"error": str(ex)}, status_code=500) diff --git a/florist/api/server.py b/florist/api/server.py index 5982e0d2..6b616cd3 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -1,118 +1,8 @@ -"""FLorist server FastAPI endpoints.""" -import logging -from typing import List +"""FLorist server FastAPI endpoints and routes.""" +from fastapi import FastAPI -import requests -from fastapi import FastAPI, Form -from fastapi.responses import JSONResponse -from typing_extensions import Annotated - -from florist.api.monitoring.metrics import wait_for_metric -from florist.api.servers.common import ClientInfo, ClientInfoParseError, Model -from florist.api.servers.launch import launch_local_server +from florist.api.routes.server.training import router as training_router app = FastAPI() -LOGGER = logging.getLogger("uvicorn.error") - -START_CLIENT_API = "api/client/start" - - -@app.post("/api/server/start_training") -def start_training( - model: Annotated[str, Form()], - server_address: Annotated[str, Form()], - n_server_rounds: Annotated[int, Form()], - batch_size: Annotated[int, Form()], - local_epochs: Annotated[int, Form()], - redis_host: Annotated[str, Form()], - redis_port: Annotated[str, Form()], - clients_info: Annotated[str, Form()], -) -> JSONResponse: - """ - Start FL training by starting a FL server and its clients. - - Should be called with a POST request and the parameters should be contained in the request's form. - - :param model: (str) The name of the model to train. 
Should be one of the values in the enum - florist.api.servers.common.Model - :param server_address: (str) The address of the FL server to be started. It should be comprised of - the host name and port separated by colon (e.g. "localhost:8080") - :param n_server_rounds: (int) The number of rounds the FL server should run. - :param batch_size: (int) The size of the batch for training - :param local_epochs: (int) The number of epochs to run by the clients - :param redis_host: (str) The host name for the Redis instance for metrics reporting. - :param redis_port: (str) The port for the Redis instance for metrics reporting. - :param clients_info: (str) A JSON string containing the client information. It will be parsed by - florist.api.servers.common.ClientInfo and should be in the following format: - [ - { - "client": , - "client_address": , - "data_path": , - "redis_host": , - "redis_port": , - } - ] - :return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the server and - the clients in the format below. The UUIDs can be used to pull metrics from Redis. - { - "server_uuid": , - "client_uuids": [, , ..., ], - } - If not successful, returns the appropriate error code with a JSON with the format below: - {"error": } - """ - try: - # Parse input data - if model not in Model.list(): - error_msg = f"Model '{model}' not supported. 
Supported models: {Model.list()}" - return JSONResponse(content={"error": error_msg}, status_code=400) - - model_class = Model.class_for_model(Model[model]) - clients_info_list = ClientInfo.parse(clients_info) - - # Start the server - server_uuid, _ = launch_local_server( - model=model_class(), - n_clients=len(clients_info_list), - server_address=server_address, - n_server_rounds=n_server_rounds, - batch_size=batch_size, - local_epochs=local_epochs, - redis_host=redis_host, - redis_port=redis_port, - ) - wait_for_metric(server_uuid, "fit_start", redis_host, redis_port, logger=LOGGER) - - # Start the clients - client_uuids: List[str] = [] - for client_info in clients_info_list: - parameters = { - "server_address": server_address, - "client": client_info.client.value, - "data_path": client_info.data_path, - "redis_host": client_info.redis_host, - "redis_port": client_info.redis_port, - } - response = requests.get(url=f"http://{client_info.client_address}/{START_CLIENT_API}", params=parameters) - json_response = response.json() - LOGGER.debug(f"Client response: {json_response}") - - if response.status_code != 200: - raise Exception(f"Client response returned {response.status_code}. Response: {json_response}") - - if "uuid" not in json_response: - raise Exception(f"Client response did not return a UUID. 
Response: {json_response}") - - client_uuids.append(json_response["uuid"]) - - # Return the UUIDs - return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids}) - - except (ValueError, ClientInfoParseError) as ex: - return JSONResponse(content={"error": str(ex)}, status_code=400) - - except Exception as ex: - LOGGER.exception(ex) - return JSONResponse({"error": str(ex)}, status_code=500) +app.include_router(training_router, tags=["training"], prefix="/api/server/training") diff --git a/florist/api/servers/launch.py b/florist/api/servers/launch.py index 19f74752..2971f916 100644 --- a/florist/api/servers/launch.py +++ b/florist/api/servers/launch.py @@ -29,8 +29,8 @@ def launch_local_server( :param n_clients: (int) The number of clients that will report to this server. :param server_address: (str) The address the server should start at. :param n_server_rounds: (int) The number of rounds the training should run for. - :param batch_size: (int) The size of the batch for training - :param local_epochs: (int) The number of epochs to run by the clients + :param batch_size: (int) The size of the batch for training. + :param local_epochs: (int) The number of epochs to run by the clients. :param redis_host: (str) the host name for the Redis instance for metrics reporting. :param redis_port: (str) the port for the Redis instance for metrics reporting. 
:return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull diff --git a/florist/tests/__init__.py b/florist/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/florist/tests/integration/api/test_train.py b/florist/tests/integration/api/test_train.py index 4223060f..2c8f5df9 100644 --- a/florist/tests/integration/api/test_train.py +++ b/florist/tests/integration/api/test_train.py @@ -7,16 +7,16 @@ import uvicorn from florist.api.monitoring.metrics import wait_for_metric -from florist.api.server import LOGGER -from florist.tests.integration.api.utils import Server +from florist.api.routes.server.training import LOGGER +from florist.tests.integration.api.utils import TestUvicornServer def test_train(): # Define services server_config = uvicorn.Config("florist.api.server:app", host="localhost", port=8000, log_level="debug") - server_service = Server(config=server_config) + server_service = TestUvicornServer(config=server_config) client_config = uvicorn.Config("florist.api.client:app", host="localhost", port=8001, log_level="debug") - client_service = Server(config=client_config) + client_service = TestUvicornServer(config=client_config) # Start services with server_service.run_in_thread(): @@ -49,7 +49,7 @@ def test_train(): } request = requests.Request( method="POST", - url=f"http://localhost:8000/api/server/start_training", + url=f"http://localhost:8000/api/server/training/start", files=data, ).prepare() session = requests.Session() diff --git a/florist/tests/integration/api/utils.py b/florist/tests/integration/api/utils.py index 3b2c7a77..488e8ce5 100644 --- a/florist/tests/integration/api/utils.py +++ b/florist/tests/integration/api/utils.py @@ -4,7 +4,7 @@ import uvicorn -class Server(uvicorn.Server): +class TestUvicornServer(uvicorn.Server): def install_signal_handlers(self): pass diff --git a/florist/tests/unit/api/routes/__init__.py b/florist/tests/unit/api/routes/__init__.py new file mode 100644 
index 00000000..e69de29b diff --git a/florist/tests/unit/api/routes/server/__init__.py b/florist/tests/unit/api/routes/server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/florist/tests/unit/api/test_server.py b/florist/tests/unit/api/routes/server/test_training.py similarity index 92% rename from florist/tests/unit/api/test_server.py rename to florist/tests/unit/api/routes/server/test_training.py index 52537e45..d296e393 100644 --- a/florist/tests/unit/api/test_server.py +++ b/florist/tests/unit/api/routes/server/test_training.py @@ -2,13 +2,13 @@ from unittest.mock import Mock, patch, ANY from florist.api.models.mnist import MnistNet -from florist.api.server import start_training +from florist.api.routes.server.training import start -@patch("florist.api.server.launch_local_server") +@patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") -@patch("florist.api.server.requests") -def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: +@patch("florist.api.routes.server.training.requests") +def test_start_success(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: # Arrange test_model = "MNIST" test_server_address = "test-server-address" @@ -47,7 +47,7 @@ def test_start_training_success(mock_requests: Mock, mock_redis: Mock, mock_laun mock_requests.get.return_value = mock_response # Act - response = start_training( + response = start( test_model, test_server_address, test_n_server_rounds, @@ -124,7 +124,7 @@ def test_start_fail_unsupported_server_model() -> None: ] # Act - response = start_training( + response = start( test_model, test_server_address, test_n_server_rounds, @@ -168,7 +168,7 @@ def test_start_fail_unsupported_client() -> None: ] # Act - response = start_training( + response = start( test_model, test_server_address, test_n_server_rounds, @@ -186,8 +186,8 @@ def 
test_start_fail_unsupported_client() -> None: assert "Client 'WRONG CLIENT' not supported." in json_body["error"] -@patch("florist.api.server.launch_local_server") -def test_start_training_launch_server_exception(mock_launch_local_server: Mock) -> None: +@patch("florist.api.routes.server.training.launch_local_server") +def test_start_launch_server_exception(mock_launch_local_server: Mock) -> None: # Arrange test_model = "MNIST" test_server_address = "test-server-address" @@ -215,7 +215,7 @@ def test_start_training_launch_server_exception(mock_launch_local_server: Mock) mock_launch_local_server.side_effect = test_exception # Act - response = start_training( + response = start( test_model, test_server_address, test_n_server_rounds, @@ -232,7 +232,7 @@ def test_start_training_launch_server_exception(mock_launch_local_server: Mock) assert json_body == {"error": str(test_exception)} -@patch("florist.api.server.launch_local_server") +@patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_server: Mock) -> None: # Arrange @@ -265,7 +265,7 @@ def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_ser mock_redis.Redis.side_effect = test_exception # Act - response = start_training( + response = start( test_model, test_server_address, test_n_server_rounds, @@ -282,7 +282,7 @@ def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_ser assert json_body == {"error": str(test_exception)} -@patch("florist.api.server.launch_local_server") +@patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") @patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: @@ -317,7 +317,7 @@ def test_start_wait_for_metric_timeout(_: 
Mock, mock_redis: Mock, mock_launch_lo mock_redis.Redis.return_value = mock_redis_connection # Act - response = start_training( + response = start( test_model, test_server_address, test_n_server_rounds, @@ -334,10 +334,10 @@ def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_lo assert json_body == {"error": "Metric 'fit_start' not been found after 20 retries."} -@patch("florist.api.server.launch_local_server") +@patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") -@patch("florist.api.server.requests") -def test_start_training_fail_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: +@patch("florist.api.routes.server.training.requests") +def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: # Arrange test_model = "MNIST" test_server_address = "test-server-address" @@ -374,7 +374,7 @@ def test_start_training_fail_response(mock_requests: Mock, mock_redis: Mock, moc mock_requests.get.return_value = mock_response # Act - response = start_training( + response = start( test_model, test_server_address, test_n_server_rounds, @@ -391,10 +391,10 @@ def test_start_training_fail_response(mock_requests: Mock, mock_redis: Mock, moc assert json_body == {"error": f"Client response returned 403. 
Response: error"} -@patch("florist.api.server.launch_local_server") +@patch("florist.api.routes.server.training.launch_local_server") @patch("florist.api.monitoring.metrics.redis") -@patch("florist.api.server.requests") -def test_start_training_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: +@patch("florist.api.routes.server.training.requests") +def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None: # Arrange test_model = "MNIST" test_server_address = "test-server-address" @@ -431,7 +431,7 @@ def test_start_training_no_uuid_in_response(mock_requests: Mock, mock_redis: Moc mock_requests.get.return_value = mock_response # Act - response = start_training( + response = start( test_model, test_server_address, test_n_server_rounds, From 5134eb773646718632ad2fd47c82a4f832f76382 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 9 Apr 2024 11:40:31 -0400 Subject: [PATCH 15/24] Skipping one more security vulnerability with pillow --- .github/workflows/static_code_checks.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/static_code_checks.yaml b/.github/workflows/static_code_checks.yaml index abcbdaa7..a556550a 100644 --- a/.github/workflows/static_code_checks.yaml +++ b/.github/workflows/static_code_checks.yaml @@ -47,12 +47,12 @@ jobs: with: virtual-environment: .venv/ # Ignoring security vulnerabilities in Pillow because pycyclops cannot update it to the - # version that fixes them (>10.0.1). - # Remove those when the issue below is fixed and pycyclops changes its requirements: - # https://github.com/SeldonIO/alibi/issues/991 + # version that fixes them (>10.3.0). 
+ # Remove those when FL4Health is released with the update to pillow > 10 ignore-vulns: | PYSEC-2023-175 PYSEC-2023-227 GHSA-j7hp-h8jx-5ppr GHSA-56pw-mpj4-fxww GHSA-3f63-hfp8-52jq + GHSA-44wm-f244-xhp3 From 67dc1c82b0691e57dfd0671ed04dfc8eacaf38f8 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 9 Apr 2024 13:23:47 -0400 Subject: [PATCH 16/24] Moving test util classes to the right place, implementing fixture --- florist/api/routes/{ => server}/job.py | 4 +- florist/api/server.py | 21 ++++++- florist/tests/integration/__init__.py | 0 .../tests/integration/api/routes/__init__.py | 0 .../integration/api/routes/server/__init__.py | 0 .../integration/api/routes/server/test_job.py | 29 ++++++++++ florist/tests/integration/api/test_routes.py | 58 ------------------- florist/tests/integration/api/utils.py | 41 +++++++++++++ 8 files changed, 91 insertions(+), 62 deletions(-) rename florist/api/routes/{ => server}/job.py (91%) create mode 100644 florist/tests/integration/__init__.py create mode 100644 florist/tests/integration/api/routes/__init__.py create mode 100644 florist/tests/integration/api/routes/server/__init__.py create mode 100644 florist/tests/integration/api/routes/server/test_job.py delete mode 100644 florist/tests/integration/api/test_routes.py diff --git a/florist/api/routes/job.py b/florist/api/routes/server/job.py similarity index 91% rename from florist/api/routes/job.py rename to florist/api/routes/server/job.py index 5d2be865..053e5ff4 100644 --- a/florist/api/routes/job.py +++ b/florist/api/routes/server/job.py @@ -28,8 +28,8 @@ def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: :return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. 
""" job = jsonable_encoder(job) - new_job = request.app.database[JOB_DATABASE_NAME].insert_one(job) - created_job = request.app.database[JOB_DATABASE_NAME].find_one({"_id": new_job.inserted_id}) + result = request.app.database[JOB_DATABASE_NAME].insert_one(job) + created_job = request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id}) assert isinstance(created_job, dict) return created_job diff --git a/florist/api/server.py b/florist/api/server.py index eae059e0..07593c08 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -1,10 +1,27 @@ """FLorist server FastAPI endpoints and routes.""" from fastapi import FastAPI +from pymongo import MongoClient -from florist.api.routes.job import router as job_router +from florist.api.routes.server.job import router as job_router from florist.api.routes.server.training import router as training_router app = FastAPI() app.include_router(training_router, tags=["training"], prefix="/api/server/training") -app.include_router(job_router, tags=["job"], prefix="/job") +app.include_router(job_router, tags=["job"], prefix="/api/server/job") + +MONGODB_URI = "mongodb://localhost:27017/" +DATABASE_NAME = "florist-server" + + +@app.on_event("startup") +def startup_db_client() -> None: + """Start up the MongoDB client on app startup.""" + app.mongodb_client = MongoClient(MONGODB_URI) # type: ignore[attr-defined] + app.database = app.mongodb_client[DATABASE_NAME] # type: ignore[attr-defined] + + +@app.on_event("shutdown") +def shutdown_db_client() -> None: + """Shut down the MongoDB client on app shutdown.""" + app.mongodb_client.close() # type: ignore[attr-defined] diff --git a/florist/tests/integration/__init__.py b/florist/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/florist/tests/integration/api/routes/__init__.py b/florist/tests/integration/api/routes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/florist/tests/integration/api/routes/server/__init__.py b/florist/tests/integration/api/routes/server/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/florist/tests/integration/api/routes/server/test_job.py b/florist/tests/integration/api/routes/server/test_job.py new file mode 100644 index 00000000..f4b13939 --- /dev/null +++ b/florist/tests/integration/api/routes/server/test_job.py @@ -0,0 +1,29 @@ +from unittest.mock import ANY + +from florist.api.db.entities import Job +from florist.api.routes.server.job import new_job +from florist.api.servers.common import Model +from florist.tests.integration.api.utils import mock_request + + +def test_new_job(mock_request) -> None: + test_empty_job = Job() + result = new_job(mock_request, test_empty_job) + + assert result == { + "_id": ANY, + "model": None, + "redis_host": None, + "redis_port": None, + } + assert isinstance(result["_id"], str) + + test_job = Job(id="test-id", model=Model.MNIST, redis_host="test-redis-host", redis_port="test-redis-port") + result = new_job(mock_request, test_job) + + assert result == { + "_id": test_job.id, + "model": test_job.model.value, + "redis_host": test_job.redis_host, + "redis_port": test_job.redis_port, + } diff --git a/florist/tests/integration/api/test_routes.py b/florist/tests/integration/api/test_routes.py deleted file mode 100644 index ada2adbf..00000000 --- a/florist/tests/integration/api/test_routes.py +++ /dev/null @@ -1,58 +0,0 @@ -from unittest.mock import ANY - -from pymongo import MongoClient -from starlette.requests import Request - -from florist.api.db.entities import Job -from florist.api.routes.job import new_job -from florist.api.server import MONGODB_URI -from florist.api.servers.common import Model - - -DATABASE_NAME = "test-database" - - -def test_new_job() -> None: - test_empty_job = Job() - result = new_job(MockRequest(), test_empty_job) - - assert result == { - "_id": ANY, - "model": None, - "redis_host": None, - "redis_port": None, - } 
- assert isinstance(result["_id"], str) - - test_job = Job(id="test-id", model=Model.MNIST, redis_host="test-redis-host", redis_port="test-redis-port") - result = new_job(MockRequest(), test_job) - - assert result == { - "_id": test_job.id, - "model": test_job.model.value, - "redis_host": test_job.redis_host, - "redis_port": test_job.redis_port, - } - assert isinstance(result["_id"], str) - - -# TODO delete database at the end - -class MockApp: - def __init__(self): - mongo_client = MongoClient(MONGODB_URI) - self.database = mongo_client[DATABASE_NAME] - - -class MockRequest(Request): - def __init__(self): - super().__init__({"type": "http"}) - self._app = MockApp() - - @property - def app(self): - return self._app - - @app.setter - def app(self, value): - self._app = value diff --git a/florist/tests/integration/api/utils.py b/florist/tests/integration/api/utils.py index 488e8ce5..bddbc051 100644 --- a/florist/tests/integration/api/utils.py +++ b/florist/tests/integration/api/utils.py @@ -1,8 +1,14 @@ import contextlib +import pytest import time import threading import uvicorn +from pymongo import MongoClient +from starlette.requests import Request + +from florist.api.server import MONGODB_URI + class TestUvicornServer(uvicorn.Server): def install_signal_handlers(self): @@ -19,3 +25,38 @@ def run_in_thread(self): finally: self.should_exit = True thread.join() + + +class MockApp: + def __init__(self, database_name: str): + self.mongo_client = MongoClient(MONGODB_URI) + self.database = self.mongo_client[database_name] + + +class MockRequest(Request): + def __init__(self, app: MockApp): + super().__init__({"type": "http"}) + self._app = app + + @property + def app(self): + return self._app + + @app.setter + def app(self, value): + self._app = value + + +TEST_DATABASE_NAME = "test-database" + + +@pytest.fixture +def mock_request() -> MockRequest: + print(f"Creating test detabase '{TEST_DATABASE_NAME}'") + app = MockApp(TEST_DATABASE_NAME) + request = MockRequest(app) + 
+ yield request + + print(f"Deleting test detabase '{TEST_DATABASE_NAME}'") + app.mongo_client.drop_database(TEST_DATABASE_NAME) From d7b430826448872f9fc878a54285407681fd0d58 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 9 Apr 2024 17:08:31 -0400 Subject: [PATCH 17/24] Adding more information to the Job --- florist/api/db/entities.py | 56 ++++++++++++++- florist/api/routes/server/job.py | 20 ++++-- .../integration/api/routes/server/test_job.py | 69 ++++++++++++++++++- 3 files changed, 136 insertions(+), 9 deletions(-) diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 9d0ade5a..30bfa8a1 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -1,23 +1,63 @@ """Definitions for the MongoDB database entities.""" import uuid -from typing import Optional +from enum import Enum +from typing import List, Optional from pydantic import BaseModel, Field from typing_extensions import Annotated +from florist.api.clients.common import Client from florist.api.servers.common import Model JOB_DATABASE_NAME = "job" +class JobStatus(Enum): + """Enumeration of all possible statuses of a Job.""" + + NOT_STARTED = "NOT_STARTED" + IN_PROGRESS = "IN_PROGRESS" + FINISHED_WITH_ERROR = "FINISHED_WITH_ERROR" + FINISHED_SUCCESSFULLY = "FINISHED_SUCCESSFULLY" + + +class ClientInfo(BaseModel): + """Define the information of an FL client.""" + + id: str = Field(default_factory=uuid.uuid4, alias="_id") + client: Client = Field(...) + service_address: str = Field(...) + data_path: str = Field(...) + redis_host: str = Field(...) + redis_port: str = Field(...) 
+ + class Config: + """MongoDB config for the ClientInfo DB entity.""" + + allow_population_by_field_name = True + schema_extra = { + "example": { + "client": "MNIST", + "service_address": "locahost:8081", + "data_path": "path/to/data", + "redis_host": "localhost", + "redis_port": "6880", + }, + } + + class Job(BaseModel): """Define the Job DB entity.""" id: str = Field(default_factory=uuid.uuid4, alias="_id") + status: JobStatus = Field(default=JobStatus.NOT_STARTED) model: Optional[Annotated[Model, Field(...)]] + server_address: Optional[Annotated[str, Field(...)]] + server_info: Optional[Annotated[str, Field(...)]] redis_host: Optional[Annotated[str, Field(...)]] redis_port: Optional[Annotated[str, Field(...)]] + clients_info: Optional[Annotated[List[ClientInfo], Field(...)]] class Config: """MongoDB config for the Job DB entity.""" @@ -26,8 +66,20 @@ class Config: schema_extra = { "example": { "_id": "066de609-b04a-4b30-b46c-32537c7f1f6e", + "status": "NOT_STARTED", "model": "MNIST", - "redis_host": "locahost", + "server_address": "localhost:8080", + "server_info": '{"n_server_rounds": 3, "batch_size": 8}', + "redis_host": "localhost", "redis_port": "6879", + "client_info": [ + { + "client": "MNIST", + "service_address": "locahost:8081", + "data_path": "path/to/data", + "redis_host": "localhost", + "redis_port": "6880", + }, + ], }, } diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index 053e5ff4..18ac53b1 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -1,7 +1,8 @@ -"""The /job FastAPI routes.""" +"""FastAPI routes for the job.""" +import json from typing import Any, Dict -from fastapi import APIRouter, Body, Request, status +from fastapi import APIRouter, Body, HTTPException, Request, status from fastapi.encoders import jsonable_encoder from florist.api.db.entities import JOB_DATABASE_NAME, Job @@ -26,9 +27,20 @@ def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # 
noqa: :param request: (fastapi.Request) the FastAPI request object. :param job: (Job) The Job instance to be saved in the database. :return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. + :raises: (HTTPException) a 400 if job.server_info is not None and cannot be parsed into JSON. """ - job = jsonable_encoder(job) - result = request.app.database[JOB_DATABASE_NAME].insert_one(job) + if job.server_info is not None: + try: + json.loads(job.server_info) + except json.JSONDecodeError as e: + error_message = ( + "job.server_info could not be parsed into JSON. " f"job.server_info: {job.server_info}. Error: {e}" + ) + raise HTTPException(status_code=400, detail=error_message) from e + + json_job = jsonable_encoder(job) + + result = request.app.database[JOB_DATABASE_NAME].insert_one(json_job) created_job = request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id}) assert isinstance(created_job, dict) diff --git a/florist/tests/integration/api/routes/server/test_job.py b/florist/tests/integration/api/routes/server/test_job.py index f4b13939..d3ecdd38 100644 --- a/florist/tests/integration/api/routes/server/test_job.py +++ b/florist/tests/integration/api/routes/server/test_job.py @@ -1,29 +1,92 @@ from unittest.mock import ANY +from pytest import raises -from florist.api.db.entities import Job +from fastapi import HTTPException + +from florist.api.clients.common import Client +from florist.api.db.entities import ClientInfo, Job, JobStatus from florist.api.routes.server.job import new_job from florist.api.servers.common import Model from florist.tests.integration.api.utils import mock_request -def test_new_job(mock_request) -> None: +def test_new_job_success(mock_request) -> None: test_empty_job = Job() result = new_job(mock_request, test_empty_job) assert result == { "_id": ANY, + "status": JobStatus.NOT_STARTED.value, "model": None, + "server_address": None, + "server_info": None, "redis_host": None, 
"redis_port": None, + "clients_info": None, } assert isinstance(result["_id"], str) - test_job = Job(id="test-id", model=Model.MNIST, redis_host="test-redis-host", redis_port="test-redis-port") + test_job = Job( + id="test-id", + status=JobStatus.IN_PROGRESS, + model=Model.MNIST, + server_address="test-server-address", + server_info="{\"test-server-info\": 123}", + redis_host="test-redis-host", + redis_port="test-redis-port", + clients_info=[ + ClientInfo( + client=Client.MNIST, + service_address="test-addr-1", + data_path="test/data/path-1", + redis_host="test-redis-host-1", + redis_port="test-redis-port-1", + ), + ClientInfo( + client=Client.MNIST, + service_address="test-addr-2", + data_path="test/data/path-2", + redis_host="test-redis-host-2", + redis_port="test-redis-port-2", + ), + ] + ) result = new_job(mock_request, test_job) assert result == { "_id": test_job.id, + "status": test_job.status.value, "model": test_job.model.value, + "server_address": "test-server-address", + "server_info": "{\"test-server-info\": 123}", "redis_host": test_job.redis_host, "redis_port": test_job.redis_port, + "clients_info": [ + { + "_id": ANY, + "client": test_job.clients_info[0].client.value, + "service_address": test_job.clients_info[0].service_address, + "data_path": test_job.clients_info[0].data_path, + "redis_host": test_job.clients_info[0].redis_host, + "redis_port": test_job.clients_info[0].redis_port, + }, { + "_id": ANY, + "client": test_job.clients_info[1].client.value, + "service_address": test_job.clients_info[1].service_address, + "data_path": test_job.clients_info[1].data_path, + "redis_host": test_job.clients_info[1].redis_host, + "redis_port": test_job.clients_info[1].redis_port, + }, + ], } + assert isinstance(result["clients_info"][0]["_id"], str) + assert isinstance(result["clients_info"][1]["_id"], str) + + +def test_new_job_fail_bad_server_info(mock_request) -> None: + test_job = Job(server_info="not json") + with raises(HTTPException) as exception_info: + 
new_job(mock_request, test_job) + + assert exception_info.value.status_code == 400 + assert "job.server_info could not be parsed into JSON" in exception_info.value.detail From b247a9f6da5e26b7382edbbe7173b1a01f13aea8 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 9 Apr 2024 17:10:02 -0400 Subject: [PATCH 18/24] Small code cleanup --- florist/api/routes/server/job.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index 053e5ff4..ec55a820 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -27,9 +27,10 @@ def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: :param job: (Job) The Job instance to be saved in the database. :return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. """ - job = jsonable_encoder(job) - result = request.app.database[JOB_DATABASE_NAME].insert_one(job) - created_job = request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id}) + json_job = jsonable_encoder(job) + result = request.app.database[JOB_DATABASE_NAME].insert_one(json_job) + created_job = request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id}) assert isinstance(created_job, dict) + return created_job From f49042be5c387031a55cf1f7097609cbcc1e407c Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Tue, 9 Apr 2024 17:11:41 -0400 Subject: [PATCH 19/24] Small code cleanup [2] --- florist/api/routes/server/job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index ec55a820..b1f001a8 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -1,4 +1,4 @@ -"""The /job FastAPI routes.""" +"""FastAPI routes for the job.""" from typing import Any, Dict from fastapi import APIRouter, Body, Request, status From 
d6fc246e7cdf52274acaf7ca04a0a6bb407f2fb4 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 10 Apr 2024 12:01:58 -0400 Subject: [PATCH 20/24] Better startup and shutdown --- florist/api/db/__init__.py | 1 + florist/api/server.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) create mode 100644 florist/api/db/__init__.py diff --git a/florist/api/db/__init__.py b/florist/api/db/__init__.py new file mode 100644 index 00000000..dc5ff910 --- /dev/null +++ b/florist/api/db/__init__.py @@ -0,0 +1 @@ +"""Classes and definitions for the database.""" diff --git a/florist/api/server.py b/florist/api/server.py index 07593c08..b4de58c3 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -1,4 +1,7 @@ """FLorist server FastAPI endpoints and routes.""" +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator + from fastapi import FastAPI from pymongo import MongoClient @@ -6,22 +9,23 @@ from florist.api.routes.server.training import router as training_router -app = FastAPI() -app.include_router(training_router, tags=["training"], prefix="/api/server/training") -app.include_router(job_router, tags=["job"], prefix="/api/server/job") - MONGODB_URI = "mongodb://localhost:27017/" DATABASE_NAME = "florist-server" -@app.on_event("startup") -def startup_db_client() -> None: - """Start up the MongoDB client on app startup.""" +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[Any, Any]: + """Set up function for app startup and shutdown.""" + # Set up mongodb app.mongodb_client = MongoClient(MONGODB_URI) # type: ignore[attr-defined] app.database = app.mongodb_client[DATABASE_NAME] # type: ignore[attr-defined] + yield -@app.on_event("shutdown") -def shutdown_db_client() -> None: - """Shut down the MongoDB client on app shutdown.""" + # Shut down mongodb app.mongodb_client.close() # type: ignore[attr-defined] + + +app = FastAPI(lifespan=lifespan) +app.include_router(training_router, 
tags=["training"], prefix="/api/server/training") +app.include_router(job_router, tags=["job"], prefix="/api/server/job") From c0d01eb23c383fc540d9b5e5f9ee290f64161429 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 10 Apr 2024 14:31:58 -0400 Subject: [PATCH 21/24] Better validation --- florist/api/db/entities.py | 14 ++++++++++++++ florist/api/routes/server/job.py | 18 +++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 30bfa8a1..a87459fe 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -1,4 +1,5 @@ """Definitions for the MongoDB database entities.""" +import json import uuid from enum import Enum from typing import List, Optional @@ -59,6 +60,19 @@ class Job(BaseModel): redis_port: Optional[Annotated[str, Field(...)]] clients_info: Optional[Annotated[List[ClientInfo], Field(...)]] + @classmethod + def is_valid_server_info(cls, server_info: Optional[str]) -> bool: + """ + Validate if server info is a json string. + + :param server_info: (str) the json string with the server info. + :return: True if server_info is None or a valid JSON string, False otherwise. 
+ :raises: (json.JSONDecodeError) if there is an error decoding the server info into json + """ + if server_info is not None: + json.loads(server_info) + return True + class Config: """MongoDB config for the Job DB entity.""" diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index 013de869..eab54de3 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -1,5 +1,5 @@ """FastAPI routes for the job.""" -import json +from json import JSONDecodeError from typing import Any, Dict from fastapi import APIRouter, Body, HTTPException, Request, status @@ -29,14 +29,14 @@ def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: :return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. :raises: (HTTPException) a 400 if job.server_info is not None and cannot be parsed into JSON. """ - if job.server_info is not None: - try: - json.loads(job.server_info) - except json.JSONDecodeError as e: - error_message = ( - "job.server_info could not be parsed into JSON. " f"job.server_info: {job.server_info}. Error: {e}" - ) - raise HTTPException(status_code=400, detail=error_message) from e + try: + is_valid = Job.is_valid_server_info(job.server_info) + if not is_valid: + msg = f"job.server_info is not valid. job.server_info: {job.server_info}." + raise HTTPException(status_code=400, detail=msg) + except JSONDecodeError as e: + msg = f"job.server_info could not be parsed into JSON. job.server_info: {job.server_info}. Error: {e}" + raise HTTPException(status_code=400, detail=msg) from e json_job = jsonable_encoder(job) result = request.app.database[JOB_DATABASE_NAME].insert_one(json_job) From 336c92b8102bce5b16f62f724fc7abdd468476b4 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 10 Apr 2024 14:34:59 -0400 Subject: [PATCH 22/24] Small change in docstring. 
--- florist/api/routes/server/job.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index eab54de3..868c3a72 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -27,7 +27,7 @@ def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: :param request: (fastapi.Request) the FastAPI request object. :param job: (Job) The Job instance to be saved in the database. :return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. - :raises: (HTTPException) a 400 if job.server_info is not None and cannot be parsed into JSON. + :raises: (HTTPException) status 400 if job.server_info is not None and cannot be parsed into JSON. """ try: is_valid = Job.is_valid_server_info(job.server_info) From 84d2dee9e10d79eb37706dd8b30cb76649a31d0a Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 12 Apr 2024 14:31:40 -0400 Subject: [PATCH 23/24] CR by John --- .pytest.ini | 2 + florist/api/db/entities.py | 3 +- florist/api/routes/server/job.py | 6 +-- florist/api/server.py | 8 ++-- .../integration/api/routes/server/test_job.py | 6 +-- florist/tests/integration/api/utils.py | 10 ++--- poetry.lock | 44 ++++++++++++++++++- pyproject.toml | 3 +- 8 files changed, 63 insertions(+), 19 deletions(-) create mode 100644 .pytest.ini diff --git a/.pytest.ini b/.pytest.ini new file mode 100644 index 00000000..2f4c80e3 --- /dev/null +++ b/.pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/florist/api/db/entities.py b/florist/api/db/entities.py index 9d0ade5a..e3615373 100644 --- a/florist/api/db/entities.py +++ b/florist/api/db/entities.py @@ -1,9 +1,8 @@ """Definitions for the MongoDB database entities.""" import uuid -from typing import Optional +from typing import Annotated, Optional from pydantic import BaseModel, Field -from typing_extensions import Annotated from florist.api.servers.common 
import Model diff --git a/florist/api/routes/server/job.py b/florist/api/routes/server/job.py index b1f001a8..6f43e1a7 100644 --- a/florist/api/routes/server/job.py +++ b/florist/api/routes/server/job.py @@ -16,7 +16,7 @@ status_code=status.HTTP_201_CREATED, response_model=Job, ) -def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: B008 +async def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: B008 """ Create a new training job. @@ -28,9 +28,9 @@ def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: :return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. """ json_job = jsonable_encoder(job) - result = request.app.database[JOB_DATABASE_NAME].insert_one(json_job) + result = await request.app.database[JOB_DATABASE_NAME].insert_one(json_job) - created_job = request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id}) + created_job = await request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id}) assert isinstance(created_job, dict) return created_job diff --git a/florist/api/server.py b/florist/api/server.py index b4de58c3..f962c805 100644 --- a/florist/api/server.py +++ b/florist/api/server.py @@ -3,7 +3,7 @@ from typing import Any, AsyncGenerator from fastapi import FastAPI -from pymongo import MongoClient +from motor.motor_asyncio import AsyncIOMotorClient from florist.api.routes.server.job import router as job_router from florist.api.routes.server.training import router as training_router @@ -17,13 +17,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[Any, Any]: """Set up function for app startup and shutdown.""" # Set up mongodb - app.mongodb_client = MongoClient(MONGODB_URI) # type: ignore[attr-defined] - app.database = app.mongodb_client[DATABASE_NAME] # type: ignore[attr-defined] + app.db_client = AsyncIOMotorClient(MONGODB_URI) # type: ignore[attr-defined] + app.database = 
app.db_client[DATABASE_NAME] # type: ignore[attr-defined] yield # Shut down mongodb - app.mongodb_client.close() # type: ignore[attr-defined] + app.db_client.close() # type: ignore[attr-defined] app = FastAPI(lifespan=lifespan) diff --git a/florist/tests/integration/api/routes/server/test_job.py b/florist/tests/integration/api/routes/server/test_job.py index f4b13939..8bea9d1a 100644 --- a/florist/tests/integration/api/routes/server/test_job.py +++ b/florist/tests/integration/api/routes/server/test_job.py @@ -6,9 +6,9 @@ from florist.tests.integration.api.utils import mock_request -def test_new_job(mock_request) -> None: +async def test_new_job(mock_request) -> None: test_empty_job = Job() - result = new_job(mock_request, test_empty_job) + result = await new_job(mock_request, test_empty_job) assert result == { "_id": ANY, @@ -19,7 +19,7 @@ def test_new_job(mock_request) -> None: assert isinstance(result["_id"], str) test_job = Job(id="test-id", model=Model.MNIST, redis_host="test-redis-host", redis_port="test-redis-port") - result = new_job(mock_request, test_job) + result = await new_job(mock_request, test_job) assert result == { "_id": test_job.id, diff --git a/florist/tests/integration/api/utils.py b/florist/tests/integration/api/utils.py index bddbc051..29279d94 100644 --- a/florist/tests/integration/api/utils.py +++ b/florist/tests/integration/api/utils.py @@ -4,7 +4,7 @@ import threading import uvicorn -from pymongo import MongoClient +from motor.motor_asyncio import AsyncIOMotorClient from starlette.requests import Request from florist.api.server import MONGODB_URI @@ -29,8 +29,8 @@ def run_in_thread(self): class MockApp: def __init__(self, database_name: str): - self.mongo_client = MongoClient(MONGODB_URI) - self.database = self.mongo_client[database_name] + self.db_client = AsyncIOMotorClient(MONGODB_URI) + self.database = self.db_client[database_name] class MockRequest(Request): @@ -51,7 +51,7 @@ def app(self, value): @pytest.fixture -def mock_request() 
-> MockRequest: +async def mock_request() -> MockRequest: print(f"Creating test detabase '{TEST_DATABASE_NAME}'") app = MockApp(TEST_DATABASE_NAME) request = MockRequest(app) @@ -59,4 +59,4 @@ def mock_request() -> MockRequest: yield request print(f"Deleting test detabase '{TEST_DATABASE_NAME}'") - app.mongo_client.drop_database(TEST_DATABASE_NAME) + await app.db_client.drop_database(TEST_DATABASE_NAME) diff --git a/poetry.lock b/poetry.lock index bab05b8d..5f4f2481 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2103,6 +2103,30 @@ files = [ {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"}, ] +[[package]] +name = "motor" +version = "3.4.0" +description = "Non-blocking MongoDB driver for Tornado or asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "motor-3.4.0-py3-none-any.whl", hash = "sha256:4b1e1a0cc5116ff73be2c080a72da078f2bb719b53bc7a6bb9e9a2f7dcd421ed"}, + {file = "motor-3.4.0.tar.gz", hash = "sha256:c89b4e4eb2e711345e91c7c9b122cb68cce0e5e869ed0387dd0acb10775e3131"}, +] + +[package.dependencies] +pymongo = ">=4.5,<5" + +[package.extras] +aws = ["pymongo[aws] (>=4.5,<5)"] +encryption = ["pymongo[encryption] (>=4.5,<5)"] +gssapi = ["pymongo[gssapi] (>=4.5,<5)"] +ocsp = ["pymongo[ocsp] (>=4.5,<5)"] +snappy = ["pymongo[snappy] (>=4.5,<5)"] +srv = ["pymongo[srv] (>=4.5,<5)"] +test = ["aiohttp (!=3.8.6)", "mockupdb", "motor[encryption]", "pytest (>=7)", "tornado (>=5)"] +zstd = ["pymongo[zstd] (>=4.5,<5)"] + [[package]] name = "mpmath" version = "1.3.0" @@ -3607,6 +3631,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.6" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = 
"pytest-asyncio-0.23.6.tar.gz", hash = "sha256:ffe523a89c1c222598c76856e76852b787504ddb72dd5d9b6617ffa8aa2cde5f"}, + {file = "pytest_asyncio-0.23.6-py3-none-any.whl", hash = "sha256:68516fdd1018ac57b846c9846b954f0393b26f094764a28c955eabb0536a4e8a"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "3.0.0" @@ -5650,4 +5692,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.11" -content-hash = "101ed669ff9ee827aee599343ce6fd6470c42d5d5af2f2e6dbc4ed7893a5b5e3" +content-hash = "5ae177c779b9aaff044248a4e58c14c03bf6b3b9748508c754907d8243680347" diff --git a/pyproject.toml b/pyproject.toml index b039e97c..dbf971ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ wandb = "^0.16.3" torchvision = "0.14.1" redis = "^5.0.1" python-multipart = "^0.0.9" -pymongo = {extras = ["srv"], version = "^4.6.3"} pydantic = "^1.10.15" +motor = "^3.4.0" [tool.poetry.group.test] optional = true @@ -38,6 +38,7 @@ ruff = "^0.2.0" pip-audit = "^2.7.1" nbqa = {extras = ["toolchain"], version = "^1.7.1"} freezegun = "^1.4.0" +pytest-asyncio = "^0.23.6" [tool.poetry.group.docs] optional = true From 168557b1c4559161e24a78adc91e3a6142a9e94a Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Fri, 12 Apr 2024 14:43:12 -0400 Subject: [PATCH 24/24] Fixing pip-audit error --- poetry.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 5f4f2481..e6c34d34 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1552,13 +1552,13 @@ license = ["ukkonen"] [[package]] name = "idna" -version = "3.6" +version = "3.7" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.5" files = [ - {file = "idna-3.6-py3-none-any.whl", 
hash = "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f"}, - {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] [[package]]