diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index 571925ec1..7f9bc4ee5 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -145,6 +145,7 @@ def __init__( reporters which the client should send data to. nnunet_trainer_class (Type[nnUNetTrainer]): A nnUNetTrainer constructor. Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. + Must match the nnunet_trainer_class passed to the NnunetServer. nnunet_trainer_class_kwargs (dict[str, Any]): Additonal kwargs to pass to nnunet_trainer_class. Defaults to empty dictionary. """ diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 9338855ca..8195b4edc 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -97,6 +97,7 @@ def __init__( and throw an exception. Defaults to True. nnunet_trainer_class (Type[nnUNetTrainer]): nnUNetTrainer class. Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. + Must match the nnunet_trainer_class passed to the NnunetClient. """ FlServerWithCheckpointing.__init__( self,