Skip to content

Commit

Permalink
Changing the nnunet server update_before_fit flow to minimize confusion and work
Browse files Browse the repository at this point in the history
  • Loading branch information
emersodb committed Nov 27, 2024
1 parent 79f8a5a commit 248ee7b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/static_code_checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ jobs:
virtual-environment: .venv/
# Ignoring vulnerability in cryptography
# Fix is 43.0.1 but flwr 1.9 depends on < 43
# PYSEC-2022-43145 seems like a bug in pip audit, we should probably try to remove the ignore at some point
ignore-vulns: |
GHSA-h4gh-qq45-vh27
GHSA-q34m-jh98-gwm2
GHSA-f9vj-2wh5-fj8j
PYSEC-2022-43145
29 changes: 18 additions & 11 deletions fl4health/servers/nnunet_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,27 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
information from a client. Defaults to None, which indicates indefinite timeout.
"""

# If no prior checkpoints exist, initialize server by sampling clients to get required properties to set
# If the per_round_checkpointer has been specified and a state checkpoint exists, we load state
# NOTE: This inherently assumes that if a checkpoint exists for the server, one also exists for the client.
if self.per_round_checkpointer is None or not self.per_round_checkpointer.checkpoint_exists():
# Sample properties from a random client to initialize plans
if self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists():
self._load_server_state()
# Otherwise, we're starting training from "scratch"
elif self.per_round_checkpointer is not None:
# If the state checkpointer is not None, then we want to do state checkpointing. So we need information
# from the clients in the form of get_properties.
log(INFO, "")
log(INFO, "[PRE-INIT]")
log(
INFO,
"Requesting initialization of global nnunet plans from one random client via get_properties",
)
log(INFO, "Requesting properties from one random client via get_properties")

if self.fl_config.get("nnunet_plans") is None:
# If the nnUnet plans are not specified, we also need those plans from the client.
log(INFO, "Initialization of global nnunet plans will be sourced from this client")
else:
log(
INFO,
"Properties from NnUnetTrainer will be sourced from this client to facilitate state preservation",
)

random_client = self._client_manager.sample(1)[0]
ins = GetPropertiesIns(config=self.fl_config | {"current_server_round": 0})
properties_res = random_client.get_properties(ins=ins, timeout=timeout, group_id=0)
Expand All @@ -195,7 +206,6 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
log(INFO, "Received global nnunet plans from one random client")
else:
raise Exception("Failed to receive properties from client to initialize nnunet plans")

properties = properties_res.properties

# Set attributes of server that are dependent on client properties.
Expand All @@ -213,9 +223,6 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
self.nnunet_config = NnunetConfig(self.fl_config["nnunet_config"])

self.initialize_server_model()
else:
# If a checkpoint exists, we load in previously checkpointed values for required properties
self._load_server_state()

# Wrap config functions so that nnunet_plans is included
new_fit_cfg_fn = add_items_to_config_fn(self.strategy.configure_fit, {"nnunet_plans": self.nnunet_plans_bytes})
Expand Down

0 comments on commit 248ee7b

Please sign in to comment.