From 248ee7bdfda7773f2939a99925a3e2a3cdd38815 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 27 Nov 2024 14:52:52 -0500 Subject: [PATCH] Changing the nnunet server update_before_fit flow to minimize confusion and work --- .github/workflows/static_code_checks.yaml | 2 -- fl4health/servers/nnunet_server.py | 29 ++++++++++++++--------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/.github/workflows/static_code_checks.yaml b/.github/workflows/static_code_checks.yaml index 2a639dd5a..595a4cc19 100644 --- a/.github/workflows/static_code_checks.yaml +++ b/.github/workflows/static_code_checks.yaml @@ -43,9 +43,7 @@ jobs: virtual-environment: .venv/ # Ignoring vulnerability in cryptography # Fix is 43.0.1 but flwr 1.9 depends on < 43 - # PYSEC-2022-43145 seems like a bug in pip audit, we should probably try to remove the ignore at some point ignore-vulns: | GHSA-h4gh-qq45-vh27 GHSA-q34m-jh98-gwm2 GHSA-f9vj-2wh5-fj8j - PYSEC-2022-43145 diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 9310d2da7..15852e409 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -177,16 +177,27 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: information from a client. Defaults to None, which indicates indefinite timeout. """ - # If no prior checkpoints exist, initialize server by sampling clients to get required properties to set + # If the per_round_checkpointer has been specified and a state checkpoint exists, we load state # NOTE: Inherent assumption that if checkpoint exists for server that it also will exist for client. - if self.per_round_checkpointer is None or not self.per_round_checkpointer.checkpoint_exists(): - # Sample properties from a random client to initialize plans + if self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists(): + self._load_server_state() + # Otherwise, we're starting training from "scratch" + elif self.per_round_checkpointer is not None: + # If the state checkpointer is not None, then we want to do state checkpointing. So we need information + # from the clients in the form of get_properties. log(INFO, "") log(INFO, "[PRE-INIT]") - log( - INFO, - "Requesting initialization of global nnunet plans from one random client via get_properties", - ) + log(INFO, "Requesting properties from one random client via get_properties") + + if self.fl_config.get("nnunet_plans") is None: + # If the nnUnet plans are not specified, we also need those plans from the client. + log(INFO, "Initialization of global nnunet plans will be sourced from this client") + else: + log( + INFO, + "Properties from NnUnetTrainer will be sourced from this client to facilitate state preservation", + ) + random_client = self._client_manager.sample(1)[0] ins = GetPropertiesIns(config=self.fl_config | {"current_server_round": 0}) properties_res = random_client.get_properties(ins=ins, timeout=timeout, group_id=0) @@ -195,7 +206,6 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: log(INFO, "Received global nnunet plans from one random client") else: raise Exception("Failed to receive properties from client to initialize nnunet plans") - properties = properties_res.properties # Set attributes of server that are dependent on client properties. @@ -213,9 +223,6 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: self.nnunet_config = NnunetConfig(self.fl_config["nnunet_config"]) self.initialize_server_model() - else: - # If a checkpoint exists, we load in previously checkpointed values for required properties - self._load_server_state() # Wrap config functions so that nnunet_plans is included new_fit_cfg_fn = add_items_to_config_fn(self.strategy.configure_fit, {"nnunet_plans": self.nnunet_plans_bytes})