From 611fa1991a2a707f366c7a1347491696c7300ee5 Mon Sep 17 00:00:00 2001
From: Marcelo Lotif
Date: Fri, 22 Mar 2024 14:20:42 -0400
Subject: [PATCH] Make a function to start a server (#13)

* Moving util functions for launching a server to florist/api/servers/utils.py and changing affected files

* Adding launch_local_server function that launches a local server with Redis metrics monitoring
---
 florist/api/client.py | 9 +-
 florist/api/monitoring/logs.py | 32 ++++++
 florist/api/servers/__init__.py | 1 +
 florist/api/servers/local.py | 49 +++++++++
 florist/api/servers/utils.py | 66 +++++++++++
 .../integration/api/launchers/test_launch.py | 8 +-
 florist/tests/integration/api/test_train.py | 35 ++++--
 florist/tests/unit/api/servers/test_local.py | 50 +++++++++
 florist/tests/unit/api/test_client.py | 3 +-
 florist/tests/utils/api/__init__.py | 0
 florist/tests/utils/api/fl4health_utils.py | 85 --------------
 florist/tests/utils/api/launch_utils.py | 26 -----
 poetry.lock | 104 +++++++++++++-----
 pyproject.toml | 2 +-
 14 files changed, 310 insertions(+), 160 deletions(-)
 create mode 100644 florist/api/monitoring/logs.py
 create mode 100644 florist/api/servers/__init__.py
 create mode 100644 florist/api/servers/local.py
 create mode 100644 florist/api/servers/utils.py
 create mode 100644 florist/tests/unit/api/servers/test_local.py
 delete mode 100644 florist/tests/utils/api/__init__.py
 delete mode 100644 florist/tests/utils/api/fl4health_utils.py
 delete mode 100644 florist/tests/utils/api/launch_utils.py

diff --git a/florist/api/client.py b/florist/api/client.py
index 8a145feb..720dee97 100644
--- a/florist/api/client.py
+++ b/florist/api/client.py
@@ -8,11 +8,10 @@
 from florist.api.clients.common import Clients
 from florist.api.launchers.local import launch_client
+from florist.api.monitoring.logs import get_client_log_file_path
 from florist.api.monitoring.metrics import RedisMetricsReporter

-LOG_FOLDER = Path("logs/client/")
-
 app = FastAPI()


@@ -63,10 +62,8 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
         metrics_reporter=metrics_reporter,
     )

-    LOG_FOLDER.mkdir(parents=True, exist_ok=True)
-    log_file_name = LOG_FOLDER / f"{client_uuid}.out"
-
-    launch_client(client_obj, server_address, str(log_file_name))
+    log_file_name = str(get_client_log_file_path(client_uuid))
+    launch_client(client_obj, server_address, log_file_name)

     return JSONResponse({"uuid": client_uuid})

diff --git a/florist/api/monitoring/logs.py b/florist/api/monitoring/logs.py
new file mode 100644
index 00000000..ed8bdfb6
--- /dev/null
+++ b/florist/api/monitoring/logs.py
@@ -0,0 +1,32 @@
+"""General functions and definitions for monitoring."""
+from pathlib import Path
+
+
+CLIENT_LOG_FOLDER = Path("logs/client/")
+SERVER_LOG_FOLDER = Path("logs/server/")
+
+
+def get_client_log_file_path(client_uuid: str) -> Path:
+    """
+    Make the client log file path given its UUID.
+
+    Will use the default client log folder defined in this module.
+
+    :param client_uuid: (str) the UUID of the client, used to generate the log file name.
+    :return: (pathlib.Path) The client log file path in the format f"{CLIENT_LOG_FOLDER}/{client_uuid}.out".
+    """
+    CLIENT_LOG_FOLDER.mkdir(parents=True, exist_ok=True)
+    return CLIENT_LOG_FOLDER / f"{client_uuid}.out"
+
+
+def get_server_log_file_path(server_uuid: str) -> Path:
+    """
+    Make the default server log file path given its UUID.
+
+    Will use the default server log folder defined in this module.
+
+    :param server_uuid: (str) the UUID of the server, used to generate the log file name.
+    :return: (Path) The server log file path in the format f"{SERVER_LOG_FOLDER}/{server_uuid}.out".
+    """
+    SERVER_LOG_FOLDER.mkdir(parents=True, exist_ok=True)
+    return SERVER_LOG_FOLDER / f"{server_uuid}.out"
diff --git a/florist/api/servers/__init__.py b/florist/api/servers/__init__.py
new file mode 100644
index 00000000..bc65df80
--- /dev/null
+++ b/florist/api/servers/__init__.py
@@ -0,0 +1 @@
+"""Implementations for the servers."""
diff --git a/florist/api/servers/local.py b/florist/api/servers/local.py
new file mode 100644
index 00000000..957558c2
--- /dev/null
+++ b/florist/api/servers/local.py
@@ -0,0 +1,49 @@
+"""Functions and definitions to launch local servers."""
+import uuid
+from functools import partial
+from multiprocessing import Process
+from typing import Tuple
+
+from torch import nn
+
+from florist.api.launchers.local import launch_server
+from florist.api.monitoring.logs import get_server_log_file_path
+from florist.api.monitoring.metrics import RedisMetricsReporter
+from florist.api.servers.utils import get_server
+
+
+def launch_local_server(
+    model: nn.Module,
+    n_clients: int,
+    server_address: str,
+    n_server_rounds: int,
+    redis_host: str,
+    redis_port: str,
+) -> Tuple[str, Process]:
+    """
+    Launch an FL server locally.
+
+    :param model: (torch.nn.Module) The model to be used by the server. Should match the clients' model.
+    :param n_clients: (int) The number of clients that will report to this server.
+    :param server_address: (str) The address the server should start at.
+    :param n_server_rounds: (int) The number of rounds the training should run for.
+    :param redis_host: (str) the host name for the Redis instance for metrics reporting.
+    :param redis_port: (str) the port for the Redis instance for metrics reporting.
+    :return: (Tuple[str, multiprocessing.Process]) the UUID of the server, which can be used to pull
+        metrics from Redis, along with its local process object.
+    """
+    server_uuid = str(uuid.uuid4())
+
+    metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid)
+    server_constructor = partial(get_server, model=model, n_clients=n_clients, metrics_reporter=metrics_reporter)
+
+    log_file_name = str(get_server_log_file_path(server_uuid))
+    server_process = launch_server(
+        server_constructor,
+        server_address,
+        n_server_rounds,
+        log_file_name,
+        seconds_to_sleep=0,
+    )
+
+    return server_uuid, server_process
diff --git a/florist/api/servers/utils.py b/florist/api/servers/utils.py
new file mode 100644
index 00000000..d64f1618
--- /dev/null
+++ b/florist/api/servers/utils.py
@@ -0,0 +1,66 @@
+"""Utility functions and definitions for starting a server."""
+from functools import partial
+from typing import Callable, Dict, Union
+
+from fl4health.client_managers.base_sampling_manager import SimpleClientManager
+from fl4health.reporting.metrics import MetricsReporter
+from fl4health.server.base_server import FlServer
+from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
+from flwr.common.parameter import ndarrays_to_parameters
+from flwr.server.strategy import FedAvg
+from torch import nn
+
+
+FitConfigFn = Callable[[int], Dict[str, Union[bool, bytes, float, int, str]]]
+
+
+def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
+    """
+    Return a dictionary used to configure the server's fit function.
+
+    :param batch_size: (int) the size of the batch of samples.
+    :param local_epochs: (int) the number of local epochs the clients will run.
+    :param current_server_round: (int) the current server round.
+    :return: (Dict[str, int]) A dictionary to be used as the config for the fit function.
+    """
+    return {
+        "batch_size": batch_size,
+        "current_server_round": current_server_round,
+        "local_epochs": local_epochs,
+    }
+
+
+def get_server(
+    model: nn.Module,
+    fit_config: Callable[[int, int, int], Dict[str, int]] = fit_config,
+    n_clients: int = 2,
+    batch_size: int = 8,
+    local_epochs: int = 1,
+    metrics_reporter: MetricsReporter = None,
+) -> FlServer:
+    """
+    Return a server instance with FedAvg aggregation strategy.
+
+    :param model: (torch.nn.Module) the model the server and clients will be using.
+    :param fit_config: (Callable[[int, int, int], Dict[str, int]]) the function to configure the fit method.
+    :param n_clients: (int) the number of clients that will participate in training. Optional, default is 2.
+    :param batch_size: (int) the size of the batch of samples. Optional, default is 8.
+    :param local_epochs: (int) the number of local epochs the clients will run. Optional, default is 1.
+    :param metrics_reporter: (fl4health.reporting.metrics.MetricsReporter) An optional metrics reporter instance.
+        Default is None.
+    :return: (fl4health.server.base_server.FlServer) An instance of FlServer with FedAvg as strategy.
+    """
+    fit_config_fn: FitConfigFn = partial(fit_config, batch_size, local_epochs)  # type: ignore
+    initial_model_parameters = ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()])
+    strategy = FedAvg(
+        min_fit_clients=n_clients,
+        min_evaluate_clients=n_clients,
+        min_available_clients=n_clients,
+        on_fit_config_fn=fit_config_fn,
+        on_evaluate_config_fn=fit_config_fn,
+        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
+        initial_parameters=initial_model_parameters,
+    )
+    client_manager = SimpleClientManager()
+    return FlServer(strategy=strategy, client_manager=client_manager, metrics_reporter=metrics_reporter)
diff --git a/florist/tests/integration/api/launchers/test_launch.py b/florist/tests/integration/api/launchers/test_launch.py
index f4b07495..a157793c 100644
--- a/florist/tests/integration/api/launchers/test_launch.py
+++ b/florist/tests/integration/api/launchers/test_launch.py
@@ -1,13 +1,14 @@
 import os
 import re
 import tempfile
+from functools import partial
 from pathlib import Path

 import torch

 from florist.api.launchers.local import launch
-from florist.api.clients.mnist import MnistClient
-from florist.tests.utils.api.launch_utils import get_server
+from florist.api.clients.mnist import MnistClient, MnistNet
+from florist.api.servers.utils import get_server


 def assert_string_in_file(file_path: str, search_string: str) -> bool:
@@ -28,10 +29,11 @@ def test_launch() -> None:
             os.mkdir(client_data_path)
         clients = [MnistClient(client_data_path, [], torch.device("cpu")) for client_data_path in client_data_paths]

+        server_constructor = partial(get_server, model=MnistNet())
         server_path = os.path.join(temp_dir, "server")
         client_base_path = f"{temp_dir}/client"
         launch(
-            get_server,
+            server_constructor,
             server_address,
             n_server_rounds,
             clients,
diff --git a/florist/tests/integration/api/test_train.py b/florist/tests/integration/api/test_train.py
index 4f84a29f..34c61b3b 100644
--- a/florist/tests/integration/api/test_train.py
+++ b/florist/tests/integration/api/test_train.py
@@ -1,32 +1,45 @@
 import json
 import tempfile
-from 
functools import partial +import time from unittest.mock import ANY +import redis + from florist.api import client -from florist.api.launchers.local import launch_server -from florist.tests.utils.api.launch_utils import get_server +from florist.api.clients.mnist import MnistNet +from florist.api.monitoring.logs import get_server_log_file_path +from florist.api.servers.local import launch_local_server def test_train(): - test_server_address = "0.0.0.0:8080" - with tempfile.TemporaryDirectory() as temp_dir: - server_constructor = partial(get_server, n_clients=1) - server_log_file = f"{temp_dir}/server.out" - server_process = launch_server(server_constructor, test_server_address, 2, server_log_file) - + test_server_address = "0.0.0.0:8080" test_client = "MNIST" test_data_path = f"{temp_dir}/data" test_redis_host = "localhost" test_redis_port = "6379" + server_uuid, server_process = launch_local_server( + MnistNet(), + 1, + test_server_address, + 2, + test_redis_host, + test_redis_port, + ) + time.sleep(10) # giving time to start the server + response = client.start(test_server_address, test_client, test_data_path, test_redis_host, test_redis_port) + json_body = json.loads(response.body.decode()) - assert json.loads(response.body.decode()) == {"uuid": ANY} + assert json_body == {"uuid": ANY} server_process.join() - with open(server_log_file, "r") as f: + redis_conn = redis.Redis(host=test_redis_host, port=test_redis_port) + assert redis_conn.get(json_body["uuid"]) is not None + assert redis_conn.get(server_uuid) is not None + + with open(get_server_log_file_path(server_uuid), "r") as f: file_contents = f.read() assert "FL finished in" in file_contents diff --git a/florist/tests/unit/api/servers/test_local.py b/florist/tests/unit/api/servers/test_local.py new file mode 100644 index 00000000..69e06d1a --- /dev/null +++ b/florist/tests/unit/api/servers/test_local.py @@ -0,0 +1,50 @@ +from unittest.mock import ANY, Mock, patch + +from florist.api.clients.mnist import MnistNet +from florist.api.monitoring.logs import get_server_log_file_path +from florist.api.monitoring.metrics import RedisMetricsReporter +from florist.api.servers.local import launch_local_server +from florist.api.servers.utils import get_server + + +@patch("florist.api.servers.local.launch_server") +def test_launch_local_server(mock_launch_server: Mock) -> None: + test_model = MnistNet() + test_n_clients = 2 + test_server_address = "test-server-address" + test_n_server_rounds = 5 + test_redis_host = "test-redis-host" + test_redis_port = "test-redis-port" + test_server_process = "test-server-process" + mock_launch_server.return_value = test_server_process + + server_uuid, server_process = launch_local_server( + test_model, + test_n_clients, + test_server_address, + test_n_server_rounds, + test_redis_host, + test_redis_port, + ) + + assert server_uuid is not None + assert server_process == test_server_process + + mock_launch_server.assert_called_once() + call_args = mock_launch_server.call_args_list[0][0] + call_kwargs = mock_launch_server.call_args_list[0][1] + assert call_args == ( + ANY, + test_server_address, + test_n_server_rounds, + str(get_server_log_file_path(server_uuid)), + ) + assert call_kwargs == {"seconds_to_sleep": 0} + assert call_args[0].func == get_server + assert call_args[0].keywords == {"model": test_model, "n_clients": test_n_clients, "metrics_reporter": ANY} + + metrics_reporter = call_args[0].keywords["metrics_reporter"] + assert isinstance(metrics_reporter, RedisMetricsReporter) + assert 
metrics_reporter.host == test_redis_host + assert metrics_reporter.port == test_redis_port + assert metrics_reporter.run_id == server_uuid diff --git a/florist/tests/unit/api/test_client.py b/florist/tests/unit/api/test_client.py index 91298e47..94f06e1d 100644 --- a/florist/tests/unit/api/test_client.py +++ b/florist/tests/unit/api/test_client.py @@ -4,6 +4,7 @@ from florist.api import client from florist.api.clients.mnist import MnistClient +from florist.api.monitoring.logs import get_client_log_file_path from florist.api.monitoring.metrics import RedisMetricsReporter @@ -30,7 +31,7 @@ def test_start_success(mock_launch_client: Mock) -> None: json_body = json.loads(response.body.decode()) assert json_body == {"uuid": ANY} - log_file_name = str(client.LOG_FOLDER / f"{json_body['uuid']}.out") + log_file_name = str(get_client_log_file_path(json_body["uuid"])) mock_launch_client.assert_called_once_with(ANY, test_server_address, log_file_name) client_obj = mock_launch_client.call_args_list[0][0][0] diff --git a/florist/tests/utils/api/__init__.py b/florist/tests/utils/api/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/florist/tests/utils/api/fl4health_utils.py b/florist/tests/utils/api/fl4health_utils.py deleted file mode 100644 index 9b0e6bd2..00000000 --- a/florist/tests/utils/api/fl4health_utils.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Callable, List, Tuple - -from fl4health.client_managers.base_sampling_manager import SimpleClientManager -from fl4health.server.base_server import FlServer -from flwr.common.parameter import ndarrays_to_parameters -from flwr.common.typing import Metrics, Parameters -from flwr.server.strategy import FedAvg -from torch import nn - - -def metric_aggregation( - all_client_metrics: List[Tuple[int, Metrics]], -) -> Tuple[int, Metrics]: - aggregated_metrics: Metrics = {} - total_examples = 0 - # Run through all of the metrics - for num_examples_on_client, client_metrics in all_client_metrics: - total_examples += num_examples_on_client - for metric_name, metric_value in client_metrics.items(): - # Here we assume each metric is normalized by the number of examples on the client. So we scale up to - # get the "raw" value - if isinstance(metric_value, float): - current_metric_value = aggregated_metrics.get(metric_name, 0.0) - assert isinstance(current_metric_value, float) - aggregated_metrics[metric_name] = current_metric_value + num_examples_on_client * metric_value - elif isinstance(metric_value, int): - current_metric_value = aggregated_metrics.get(metric_name, 0) - assert isinstance(current_metric_value, int) - aggregated_metrics[metric_name] = current_metric_value + num_examples_on_client * metric_value - else: - raise ValueError("Metric type is not supported") - return total_examples, aggregated_metrics - - -def normalize_metrics(total_examples: int, aggregated_metrics: Metrics) -> Metrics: - # Normalize all metric values by the total count of examples seen. 
- normalized_metrics: Metrics = {} - for metric_name, metric_value in aggregated_metrics.items(): - if isinstance(metric_value, (float, int)): - normalized_metrics[metric_name] = metric_value / total_examples - return normalized_metrics - - -def fit_metrics_aggregation_fn( - all_client_metrics: List[Tuple[int, Metrics]], -) -> Metrics: - # This function is run by the server to aggregate metrics returned by each clients fit function - # NOTE: The first value of the tuple is number of examples for FedAvg - total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) - return normalize_metrics(total_examples, aggregated_metrics) - - -def evaluate_metrics_aggregation_fn( - all_client_metrics: List[Tuple[int, Metrics]], -) -> Metrics: - # This function is run by the server to aggregate metrics returned by each clients evaluate function - # NOTE: The first value of the tuple is number of examples for FedAvg - total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) - return normalize_metrics(total_examples, aggregated_metrics) - - -def get_initial_model_parameters(model: nn.Module) -> Parameters: - # Initializing the model parameters on the server side. - return ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()]) - - -def get_fedavg_strategy(model: nn.Module, n_clients: int, fit_config_fn: Callable) -> FedAvg: - strategy = FedAvg( - min_fit_clients=n_clients, - min_evaluate_clients=n_clients, - min_available_clients=n_clients, - on_fit_config_fn=fit_config_fn, - on_evaluate_config_fn=fit_config_fn, - fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, - evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, - initial_parameters=get_initial_model_parameters(model), - ) - return strategy - - -def get_server_fedavg(model: nn.Module, n_clients: int, fit_config_fn: Callable) -> FlServer: - strategy = get_fedavg_strategy(model, n_clients, fit_config_fn) - client_manager = SimpleClientManager() - server = FlServer(strategy=strategy, client_manager=client_manager) - return server diff --git a/florist/tests/utils/api/launch_utils.py b/florist/tests/utils/api/launch_utils.py deleted file mode 100644 index dfc31128..00000000 --- a/florist/tests/utils/api/launch_utils.py +++ /dev/null @@ -1,26 +0,0 @@ -from functools import partial -from typing import Callable, Dict - -from fl4health.server.base_server import FlServer - -from florist.tests.utils.api.fl4health_utils import get_server_fedavg -from florist.api.clients.mnist import MnistNet - - -def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]: - return { - "batch_size": batch_size, - "current_server_round": current_server_round, - "local_epochs": local_epochs, - } - - -def get_server( - fit_config: Callable[..., Dict[str, int]] = fit_config, - n_clients: int = 2, - batch_size: int = 8, - local_epochs: int = 1, -) -> FlServer: - fit_config_fn = partial(fit_config, batch_size, local_epochs) - server = get_server_fedavg(model=MnistNet(), n_clients=n_clients, fit_config_fn=fit_config_fn) - return server diff --git a/poetry.lock b/poetry.lock index cb1a3609..cc725c82 100644 --- a/poetry.lock +++ b/poetry.lock @@ -344,33 +344,33 @@ lxml = ["lxml"] [[package]] name = "black" -version = "24.1.1" +version = "24.3.0" description = "The uncompromising code formatter." 
optional = false python-versions = ">=3.8" files = [ - {file = "black-24.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2588021038bd5ada078de606f2a804cadd0a3cc6a79cb3e9bb3a8bf581325a4c"}, - {file = "black-24.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a95915c98d6e32ca43809d46d932e2abc5f1f7d582ffbe65a5b4d1588af7445"}, - {file = "black-24.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fa6a0e965779c8f2afb286f9ef798df770ba2b6cee063c650b96adec22c056a"}, - {file = "black-24.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:5242ecd9e990aeb995b6d03dc3b2d112d4a78f2083e5a8e86d566340ae80fec4"}, - {file = "black-24.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:fc1ec9aa6f4d98d022101e015261c056ddebe3da6a8ccfc2c792cbe0349d48b7"}, - {file = "black-24.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0269dfdea12442022e88043d2910429bed717b2d04523867a85dacce535916b8"}, - {file = "black-24.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3d64db762eae4a5ce04b6e3dd745dcca0fb9560eb931a5be97472e38652a161"}, - {file = "black-24.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:5d7b06ea8816cbd4becfe5f70accae953c53c0e53aa98730ceccb0395520ee5d"}, - {file = "black-24.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e2c8dfa14677f90d976f68e0c923947ae68fa3961d61ee30976c388adc0b02c8"}, - {file = "black-24.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a21725862d0e855ae05da1dd25e3825ed712eaaccef6b03017fe0853a01aa45e"}, - {file = "black-24.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07204d078e25327aad9ed2c64790d681238686bce254c910de640c7cc4fc3aa6"}, - {file = "black-24.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:a83fe522d9698d8f9a101b860b1ee154c1d25f8a82ceb807d319f085b2627c5b"}, - {file = "black-24.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08b34e85170d368c37ca7bf81cf67ac863c9d1963b2c1780c39102187ec8dd62"}, - {file = "black-24.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7258c27115c1e3b5de9ac6c4f9957e3ee2c02c0b39222a24dc7aa03ba0e986f5"}, - {file = "black-24.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40657e1b78212d582a0edecafef133cf1dd02e6677f539b669db4746150d38f6"}, - {file = "black-24.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e298d588744efda02379521a19639ebcd314fba7a49be22136204d7ed1782717"}, - {file = "black-24.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:34afe9da5056aa123b8bfda1664bfe6fb4e9c6f311d8e4a6eb089da9a9173bf9"}, - {file = "black-24.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:854c06fb86fd854140f37fb24dbf10621f5dab9e3b0c29a690ba595e3d543024"}, - {file = "black-24.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3897ae5a21ca132efa219c029cce5e6bfc9c3d34ed7e892113d199c0b1b444a2"}, - {file = "black-24.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:ecba2a15dfb2d97105be74bbfe5128bc5e9fa8477d8c46766505c1dda5883aac"}, - {file = "black-24.1.1-py3-none-any.whl", hash = "sha256:5cdc2e2195212208fbcae579b931407c1fa9997584f0a415421748aeafff1168"}, - {file = "black-24.1.1.tar.gz", hash = "sha256:48b5760dcbfe5cf97fd4fba23946681f3a81514c6ab8a45b50da67ac8fbc6c7b"}, + {file = "black-24.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395"}, + {file = "black-24.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995"}, + {file = 
"black-24.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7"}, + {file = "black-24.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0"}, + {file = "black-24.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9"}, + {file = "black-24.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597"}, + {file = "black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d"}, + {file = "black-24.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5"}, + {file = "black-24.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f"}, + {file = "black-24.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11"}, + {file = "black-24.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4"}, + {file = "black-24.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5"}, + {file = "black-24.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:79dcf34b33e38ed1b17434693763301d7ccbd1c5860674a8f871bd15139e7837"}, + {file = "black-24.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e19cb1c6365fd6dc38a6eae2dcb691d7d83935c10215aef8e6c38edee3f77abd"}, + {file = "black-24.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b76c275e4c1c5ce6e9870911384bff5ca31ab63d19c76811cb1fb162678213"}, + {file = "black-24.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:b5991d523eee14756f3c8d5df5231550ae8993e2286b8014e2fdea7156ed0959"}, + {file = "black-24.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c45f8dff244b3c431b36e3224b6be4a127c6aca780853574c00faf99258041eb"}, + {file = "black-24.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6905238a754ceb7788a73f02b45637d820b2f5478b20fec82ea865e4f5d4d9f7"}, + {file = "black-24.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7de8d330763c66663661a1ffd432274a2f92f07feeddd89ffd085b5744f85e7"}, + {file = "black-24.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:7bb041dca0d784697af4646d3b62ba4a6b028276ae878e53f6b4f74ddd6db99f"}, + {file = "black-24.3.0-py3-none-any.whl", hash = "sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93"}, + {file = "black-24.3.0.tar.gz", hash = "sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f"}, ] [package.dependencies] @@ -1077,13 +1077,13 @@ typing = ["typing-extensions (>=4.8)"] [[package]] name = "fl4health" -version = "0.1.12" +version = "0.1.13" description = "Federated Learning for Health" optional = false python-versions = ">=3.9.0,<3.11" files = [ - {file = "fl4health-0.1.12-py3-none-any.whl", hash = "sha256:fa8ea2c73990b7d54b057e3fd90345cdec06df794777c86c629bc590dd551f33"}, - {file = "fl4health-0.1.12.tar.gz", hash = "sha256:6f777a48fd95b33a4e26409e9337d71294298a4a874872e98937aac5c88690a5"}, + {file = "fl4health-0.1.13-py3-none-any.whl", hash = "sha256:a6a947d82207e4734267b9c4839a8e7af50394d3db675d3f553e7a9768972efb"}, + {file = 
"fl4health-0.1.13.tar.gz", hash = "sha256:7de5bd3929229ffea7e8328b3acc7f9540ec86e90b1c1e94e832f2986e91ca1d"}, ] [package.dependencies] @@ -1095,6 +1095,7 @@ opacus = ">=1.3.0,<2.0.0" pandas = ">=2.0,<3.0" pycyclops = ">=0.2.2,<0.3.0" torch = ">=1.12.1,<2.0.0" +torchmetrics = ">=1.3.0,<2.0.0" [[package]] name = "flake8" @@ -1902,6 +1903,27 @@ files = [ docs = ["Sphinx (>=5.0.2)", "doc8 (>=0.11.2)", "sphinx-autobuild", "sphinx-copybutton", "sphinx-reredirects (>=0.1.2)", "sphinx-rtd-dark-mode (>=1.3.0)", "sphinx-rtd-theme (>=1.0.0)", "sphinxcontrib-apidoc (>=0.4.0)"] testing = ["black", "isort", "pytest (>=6,!=7.0.0)", "pytest-xdist (>=2)", "twine"] +[[package]] +name = "lightning-utilities" +version = "0.11.0" +description = "Lightning toolbox for across the our ecosystem." +optional = false +python-versions = ">=3.8" +files = [ + {file = "lightning-utilities-0.11.0.tar.gz", hash = "sha256:dd704795785ceba1e0cd60ba3a9b0553c7902ec9efc1578a74e893a291416e62"}, + {file = "lightning_utilities-0.11.0-py3-none-any.whl", hash = "sha256:bf576a421027fdbaf48e80cbc2fdf900a3316a469748a953c33a8ca2b2718a20"}, +] + +[package.dependencies] +packaging = ">=17.1" +setuptools = "*" +typing-extensions = "*" + +[package.extras] +cli = ["fire"] +docs = ["requests (>=2.0.0)"] +typing = ["mypy (>=1.0.0)", "types-setuptools"] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -4736,6 +4758,34 @@ typing-extensions = "*" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] +[[package]] +name = "torchmetrics" +version = "1.3.2" +description = "PyTorch native Metrics" +optional = false +python-versions = ">=3.8" +files = [ + {file = "torchmetrics-1.3.2-py3-none-any.whl", hash = "sha256:44ca3a9f86dc050cb3f554836ef291698ea797778457195b4f685fce8e2e64a3"}, + {file = "torchmetrics-1.3.2.tar.gz", hash = "sha256:0a67694a4c4265eeb54cda741eaf5cb1f3a71da74b7e7e6215ad156c9f2379f6"}, +] + +[package.dependencies] +lightning-utilities = ">=0.8.0" +numpy = ">1.20.0" +packaging = ">17.1" +torch = ">=1.10.0" + +[package.extras] +all = ["SciencePlots (>=2.0.0)", "ipadic (>=1.0.0)", "matplotlib (>=3.3.0)", "mecab-ko (>=1.0.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.8.0)", "nltk (>=3.6)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.3.0)", "regex (>=2021.9.24)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.2.1)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>4.4.0)", "transformers (>=4.10.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +audio = ["pystoi (>=0.3.0)", "torchaudio (>=0.10.0)"] +detection = ["pycocotools (>2.0.0)", "torchvision (>=0.8)"] +dev = ["SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (<=0.7.5)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.3.3)", "huggingface-hub (<0.22)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "lpips (<=0.1.4)", "matplotlib (>=3.3.0)", "mecab-ko (>=1.0.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.0)", "mypy (==1.8.0)", "netcal (>1.0.0)", "nltk (>=3.6)", "numpy (<1.25.0)", "pandas (>1.0.0)", "pandas (>=1.4.0)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.3.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.2.1)", "torch-complex 
(<=0.4.3)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>4.4.0)", "transformers (>=4.10.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.8)"] +multimodal = ["piq (<=0.8.0)", "transformers (>=4.10.0)"] +text = ["ipadic (>=1.0.0)", "mecab-ko (>=1.0.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>=3.6)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (>=4.41.0)", "transformers (>4.4.0)"] +typing = ["mypy (==1.8.0)", "torch (==2.2.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] +visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.3.0)"] + [[package]] name = "torchvision" version = "0.14.1" @@ -5463,4 +5513,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.11" -content-hash = "ba282349a592722e4def85a255fa8c96022ccba402438168f5c3c74c6d8b4ea5" +content-hash = "17c0b90826fdf0ba50663f1fb54861e1164919cbd4b30abbc5e22dc7bc77642f" diff --git a/pyproject.toml b/pyproject.toml index 3e95a57c..a45ec8e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ python = ">=3.9.0,<3.11" fastapi = "^0.109.1" uvicorn = {version = "^0.23.2", extras = ["standard"]} -fl4health = "^0.1.11" +fl4health = "^0.1.13" wandb = "^0.16.3" torchvision = "0.14.1" redis = "^5.0.1"