Skip to content

Commit

Permalink
Merge pull request #296 from VectorInstitute/custom-nnunet-trainer
Browse files Browse the repository at this point in the history
Custom nnunet trainer
  • Loading branch information
jewelltaylor authored Nov 25, 2024
2 parents 2509952 + 426821e commit d73bae0
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 25 deletions.
14 changes: 12 additions & 2 deletions fl4health/clients/nnunet_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from logging import DEBUG, ERROR, INFO, WARNING
from os.path import exists, join
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
Expand Down Expand Up @@ -78,6 +78,8 @@ def __init__(
checkpointer: Optional[ClientCheckpointModule] = None,
reporters: Sequence[BaseReporter] | None = None,
client_name: Optional[str] = None,
nnunet_trainer_class: Type[nnUNetTrainer] = nnUNetTrainer,
nnunet_trainer_class_kwargs: Optional[dict[str, Any]] = {},
) -> None:
"""
A client for training nnunet models. Requires the nnunet environment variables
Expand Down Expand Up @@ -141,6 +143,11 @@ def __init__(
provided. Defaults to None.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health
reporters which the client should send data to.
nnunet_trainer_class (Type[nnUNetTrainer]): A nnUNetTrainer constructor.
Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class.
Must match the nnunet_trainer_class passed to the NnunetServer.
nnunet_trainer_class_kwargs (dict[str, Any]): Additional kwargs to pass to nnunet_trainer_class.
Defaults to empty dictionary.
"""
metrics = metrics if metrics else []
# Parent method sets up several class attributes
Expand Down Expand Up @@ -182,6 +189,8 @@ def __init__(
self.stream2debug = StreamToLogger(FLOWER_LOGGER, DEBUG)

# nnunet specific attributes to be initialized in setup_client
self.nnunet_trainer_class = nnunet_trainer_class
self.nnunet_trainer_class_kwargs = nnunet_trainer_class_kwargs
self.nnunet_trainer: nnUNetTrainer
self.nnunet_config: NnunetConfig
self.plans: dict[str, Any]
Expand Down Expand Up @@ -465,12 +474,13 @@ def setup_client(self, config: Config) -> None:
# Unless log level is DEBUG or lower hide nnunet output
with redirect_stdout(self.stream2debug):
# Create the nnunet trainer
self.nnunet_trainer = nnUNetTrainer(
self.nnunet_trainer = self.nnunet_trainer_class(
plans=self.plans,
configuration=self.nnunet_config.value,
fold=self.fold,
dataset_json=self.dataset_json,
device=self.device,
**self.nnunet_trainer_class_kwargs,
)
# nnunet_trainer initialization
self.nnunet_trainer.initialize()
Expand Down
9 changes: 7 additions & 2 deletions fl4health/servers/nnunet_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Callable, Sequence
from logging import INFO
from pathlib import Path
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Type, Union

import torch.nn as nn
from flwr.common import Parameters
Expand Down Expand Up @@ -69,6 +69,7 @@ def __init__(
intermediate_server_state_dir: Path | None = None,
server_name: str | None = None,
accept_failures: bool = True,
nnunet_trainer_class: Type[nnUNetTrainer] = nnUNetTrainer,
) -> None:
"""
A Basic FlServer with added functionality to ask a client to initialize the global nnunet plans if one was not
Expand All @@ -94,6 +95,9 @@ def __init__(
accept_failures (bool, optional): Determines whether the server should accept failures during training or
evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
and throw an exception. Defaults to True.
nnunet_trainer_class (Type[nnUNetTrainer]): nnUNetTrainer class.
Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class.
Must match the nnunet_trainer_class passed to the NnunetClient.
"""
FlServerWithCheckpointing.__init__(
self,
Expand All @@ -107,6 +111,7 @@ def __init__(
server_name=server_name,
)
self.initialized = False
self.nnunet_trainer_class = nnunet_trainer_class

self.nnunet_plans_bytes: bytes
self.num_input_channels: int
Expand All @@ -127,7 +132,7 @@ def initialize_server_model(self) -> None:
plans = pickle.loads(self.nnunet_plans_bytes)
plans_manager = PlansManager(plans)
configuration_manager = plans_manager.get_configuration(self.nnunet_config.value)
model = nnUNetTrainer.build_network_architecture(
model = self.nnunet_trainer_class.build_network_architecture(
configuration_manager.network_arch_class_name,
configuration_manager.network_arch_init_kwargs,
configuration_manager.network_arch_init_kwargs_req_import,
Expand Down
34 changes: 13 additions & 21 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ scikit-learn = "1.5.0" # Pin as it was causing issues with nnunet
# Problematic grpcio versions cause issues, should be fixed in next flwr update
# See https://github.com/adap/flower/pull/3853
# https://github.com/grpc/grpc/issues/37162
tornado = ">=6.4.2"

[tool.poetry.group.dev.dependencies]
# locked the 2.13 version because of restrictions with tensorflow-io
Expand Down

0 comments on commit d73bae0

Please sign in to comment.