From 67f0c9a141e9e9cf0d930029080aa49029529eab Mon Sep 17 00:00:00 2001 From: jewelltaylor Date: Sat, 23 Nov 2024 19:29:43 -0500 Subject: [PATCH 1/4] Add ability to specify custom nnunet trainer --- fl4health/clients/nnunet_client.py | 13 +++++++++++-- fl4health/servers/nnunet_server.py | 8 ++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index 734f6c913..571925ec1 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -8,7 +8,7 @@ from logging import DEBUG, ERROR, INFO, WARNING from os.path import exists, join from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union import numpy as np import torch @@ -78,6 +78,8 @@ def __init__( checkpointer: Optional[ClientCheckpointModule] = None, reporters: Sequence[BaseReporter] | None = None, client_name: Optional[str] = None, + nnunet_trainer_class: Type[nnUNetTrainer] = nnUNetTrainer, + nnunet_trainer_class_kwargs: Optional[dict[str, Any]] = {}, ) -> None: """ A client for training nnunet models. Requires the nnunet environment variables @@ -141,6 +143,10 @@ def __init__( provided. Defaults to None. reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the client should send data to. + nnunet_trainer_class (Type[nnUNetTrainer]): A nnUNetTrainer constructor. + Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. + nnunet_trainer_class_kwargs (dict[str, Any]): Additonal kwargs to pass to nnunet_trainer_class. + Defaults to empty dictionary. """ metrics = metrics if metrics else [] # Parent method sets up several class attributes @@ -182,6 +188,8 @@ def __init__( self.stream2debug = StreamToLogger(FLOWER_LOGGER, DEBUG) # nnunet specific attributes to be initialized in setup_client + self.nnunet_trainer_class = nnunet_trainer_class + self.nnunet_trainer_class_kwargs = nnunet_trainer_class_kwargs self.nnunet_trainer: nnUNetTrainer self.nnunet_config: NnunetConfig self.plans: dict[str, Any] @@ -465,12 +473,13 @@ def setup_client(self, config: Config) -> None: # Unless log level is DEBUG or lower hide nnunet output with redirect_stdout(self.stream2debug): # Create the nnunet trainer - self.nnunet_trainer = nnUNetTrainer( + self.nnunet_trainer = self.nnunet_trainer_class( plans=self.plans, configuration=self.nnunet_config.value, fold=self.fold, dataset_json=self.dataset_json, device=self.device, + **self.nnunet_trainer_class_kwargs, ) # nnunet_trainer initialization self.nnunet_trainer.initialize() diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 6a73a0a45..54f60b596 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -3,7 +3,7 @@ from collections.abc import Callable, Sequence from logging import INFO from pathlib import Path -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Type, Union import torch.nn as nn from flwr.common import Parameters @@ -69,6 +69,7 @@ def __init__( intermediate_server_state_dir: Path | None = None, server_name: str | None = None, accept_failures: bool = True, + nnunet_trainer_class: Type[nnUNetTrainer] = nnUNetTrainer, ) -> None: """ A Basic FlServer with added functionality to ask a client to initialize the global nnunet plans if one was not @@ -94,6 +95,8 @@ def __init__( accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. + nnunet_trainer_class (Type[nnUNetTrainer], optional): nnUNetTrainer class. + Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. """ FlServerWithCheckpointing.__init__( self, @@ -107,6 +110,7 @@ def __init__( server_name=server_name, ) self.initialized = False + self.nnunet_trainer_class = nnunet_trainer_class self.nnunet_plans_bytes: bytes self.num_input_channels: int @@ -127,7 +131,7 @@ def initialize_server_model(self) -> None: plans = pickle.loads(self.nnunet_plans_bytes) plans_manager = PlansManager(plans) configuration_manager = plans_manager.get_configuration(self.nnunet_config.value) - model = nnUNetTrainer.build_network_architecture( + model = self.nnunet_trainer_class.build_network_architecture( configuration_manager.network_arch_class_name, configuration_manager.network_arch_init_kwargs, configuration_manager.network_arch_init_kwargs_req_import, From ea4cd71dec91f874f9670cad0f520ed50d5dc277 Mon Sep 17 00:00:00 2001 From: jewelltaylor Date: Sat, 23 Nov 2024 19:38:18 -0500 Subject: [PATCH 2/4] Fix small issue in documentation --- fl4health/servers/nnunet_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 54f60b596..9338855ca 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -95,7 +95,7 @@ def __init__( accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. - nnunet_trainer_class (Type[nnUNetTrainer], optional): nnUNetTrainer class. + nnunet_trainer_class (Type[nnUNetTrainer]): nnUNetTrainer class. Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. """ FlServerWithCheckpointing.__init__( From 660ad208592522d69bc30941ca56879c5815bec8 Mon Sep 17 00:00:00 2001 From: jewelltaylor Date: Mon, 25 Nov 2024 00:41:07 -0500 Subject: [PATCH 3/4] Upgrade tornado to avoid pip audit issue --- poetry.lock | 34 +++++++++++++--------------------- pyproject.toml | 1 + 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/poetry.lock b/poetry.lock index 75505cf36..21ba153a8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3040,7 +3040,6 @@ description = "Clang Python Bindings, mirrored from the official LLVM repo: http optional = false python-versions = "*" files = [ - {file = "libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a"}, {file = "libclang-18.1.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5"}, {file = "libclang-18.1.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8"}, {file = "libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl", hash = "sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b"}, @@ -5557,13 +5556,6 @@ files = [ {file = "python_gdcm-3.0.24.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e537b9c3c582e0a19cd89791634da5ff48f1d61eeee633bf6e806c7bed10aba"}, {file = "python_gdcm-3.0.24.1-cp312-cp312-win32.whl", hash = "sha256:0fe3684df3be2abcf4ec6931e45f4caa8bd2aa60a84e65ddd612428f0fa39bcc"}, {file = "python_gdcm-3.0.24.1-cp312-cp312-win_amd64.whl", hash = "sha256:530e6b3f3904fd87c7e69ad0aee383f7a87213a8bf339314741ca64e3b6a3e94"}, - {file = "python_gdcm-3.0.24.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:45c5927af717f06f7ff8e0d6124746ef15e314954ae105d3a98410b6e327fb15"}, - {file = "python_gdcm-3.0.24.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ee88f9cdcd4f5e98da0e608d9692e96173ec8832d5aba1f0234db8af0835d9bd"}, - {file = "python_gdcm-3.0.24.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a89a8b1b666f2c3ebe6afbac3fcc3e256566ac8e55080ef03dd1ef7c98cd6b1"}, - {file = "python_gdcm-3.0.24.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5bc32309aeba3d3675ae0e5641aae03a8b9dc66ace058979debd2fea849dda7"}, - {file = "python_gdcm-3.0.24.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8ba54c9908ce117734b32340b2e5bbf96d5544109b1af468e2b99f2c6341a15"}, - {file = "python_gdcm-3.0.24.1-cp313-cp313-win32.whl", hash = "sha256:5920e63ac12b9a430108cd804ee2709fcefb9781bda6b6cb2f7d311a8dc61a04"}, - {file = "python_gdcm-3.0.24.1-cp313-cp313-win_amd64.whl", hash = "sha256:c586099268f0baf3cfda5851fa6115dc93394930da148fe3c350081b59e7551a"}, {file = "python_gdcm-3.0.24.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:de850cedc1dc58b8b5ee16c72cf67c5ee1021963a0bcbc0de58e162824afd6ff"}, {file = "python_gdcm-3.0.24.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8eddf3dc5d7793f3af407972f5185ec5d7edc4989ccaeafbf0d3e5e74f5ba88e"}, {file = "python_gdcm-3.0.24.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b9f5075cbd39fd4448ef83a298e5dbf0886dea8805d1de90df7de56d7930839"}, @@ -7586,22 +7578,22 @@ scipy = ["scipy"] [[package]] name = "tornado" -version = "6.4.1" +version = "6.4.2" description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed." optional = false python-versions = ">=3.8" files = [ - {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:163b0aafc8e23d8cdc3c9dfb24c5368af84a81e3364745ccb4427669bf84aec8"}, - {file = "tornado-6.4.1-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:6d5ce3437e18a2b66fbadb183c1d3364fb03f2be71299e7d10dbeeb69f4b2a14"}, - {file = "tornado-6.4.1-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e20b9113cd7293f164dc46fffb13535266e713cdb87bd2d15ddb336e96cfc4"}, - {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ae50a504a740365267b2a8d1a90c9fbc86b780a39170feca9bcc1787ff80842"}, - {file = "tornado-6.4.1-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:613bf4ddf5c7a95509218b149b555621497a6cc0d46ac341b30bd9ec19eac7f3"}, - {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:25486eb223babe3eed4b8aecbac33b37e3dd6d776bc730ca14e1bf93888b979f"}, - {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:454db8a7ecfcf2ff6042dde58404164d969b6f5d58b926da15e6b23817950fc4"}, - {file = "tornado-6.4.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a02a08cc7a9314b006f653ce40483b9b3c12cda222d6a46d4ac63bb6c9057698"}, - {file = "tornado-6.4.1-cp38-abi3-win32.whl", hash = "sha256:d9a566c40b89757c9aa8e6f032bcdb8ca8795d7c1a9762910c722b1635c9de4d"}, - {file = "tornado-6.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:b24b8982ed444378d7f21d563f4180a2de31ced9d8d84443907a0a64da2072e7"}, - {file = "tornado-6.4.1.tar.gz", hash = "sha256:92d3ab53183d8c50f8204a51e6f91d18a15d5ef261e84d452800d4ff6fc504e9"}, + {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1"}, + {file = "tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803"}, + {file = "tornado-6.4.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a017d239bd1bb0919f72af256a970624241f070496635784d9bf0db640d3fec"}, + {file = "tornado-6.4.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c36e62ce8f63409301537222faffcef7dfc5284f27eec227389f2ad11b09d946"}, + {file = "tornado-6.4.2-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca9eb02196e789c9cb5c3c7c0f04fb447dc2adffd95265b2c7223a8a615ccbf"}, + {file = "tornado-6.4.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:304463bd0772442ff4d0f5149c6f1c2135a1fae045adf070821c6cdc76980634"}, + {file = "tornado-6.4.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:c82c46813ba483a385ab2a99caeaedf92585a1f90defb5693351fa7e4ea0bf73"}, + {file = "tornado-6.4.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:932d195ca9015956fa502c6b56af9eb06106140d844a335590c1ec7f5277d10c"}, + {file = "tornado-6.4.2-cp38-abi3-win32.whl", hash = "sha256:2876cef82e6c5978fde1e0d5b1f919d756968d5b4282418f3146b79b58556482"}, + {file = "tornado-6.4.2-cp38-abi3-win_amd64.whl", hash = "sha256:908b71bf3ff37d81073356a5fadcc660eb10c1476ee6e2725588626ce7e5ca38"}, + {file = "tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b"}, ] [[package]] @@ -8490,4 +8482,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10.0,<3.11" -content-hash = "db9e81b0f4468b60f81c1cfe6dd412fc41b3be0bcf4e3ca3828811cacc4b2ab6" +content-hash = "a3e15cd8f7b3bc543cc625919b109933adf6056c50990c07e3754a337f10a475" diff --git a/pyproject.toml b/pyproject.toml index e2b2b1c4b..be154d642 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ scikit-learn = "1.5.0" # Pin as it was causing issues with nnunet # Problematic grpcio versions cause issues, should be fixed in next flwr update # See https://github.com/adap/flower/pull/3853 # https://github.com/grpc/grpc/issues/37162 +tornado = ">=6.4.2" [tool.poetry.group.dev.dependencies] # locked the 2.13 version because of restrictions with tensorflow-io From 426821e3e04b7785ef1a3f24610dd5b28b17c238 Mon Sep 17 00:00:00 2001 From: John Jewell Date: Mon, 25 Nov 2024 10:24:39 -0500 Subject: [PATCH 4/4] Add documentation request by David --- fl4health/clients/nnunet_client.py | 1 + fl4health/servers/nnunet_server.py | 1 + 2 files changed, 2 insertions(+) diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index 571925ec1..7f9bc4ee5 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -145,6 +145,7 @@ def __init__( reporters which the client should send data to. nnunet_trainer_class (Type[nnUNetTrainer]): A nnUNetTrainer constructor. Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. + Must match the nnunet_trainer_class passed to the NnunetServer. nnunet_trainer_class_kwargs (dict[str, Any]): Additonal kwargs to pass to nnunet_trainer_class. Defaults to empty dictionary. """ diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 9338855ca..8195b4edc 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -97,6 +97,7 @@ def __init__( and throw an exception. Defaults to True. nnunet_trainer_class (Type[nnUNetTrainer]): nnUNetTrainer class. Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. + Must match the nnunet_trainer_class passed to the NnunetClient. """ FlServerWithCheckpointing.__init__( self,