Server storing FL configs and Consolidating Base Server Functionality #294
Conversation
…forcing the server to carry a config
…onality directly into the base server
@@ -61,43 +56,6 @@ def fit_config(
    )


-class CifarInstanceLevelDPServerWithCheckpointing(InstanceLevelDpServer):
All functionality implemented here is now neatly contained in the parent class, InstanceLevelDpServer
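For context, a minimal illustrative sketch of why the subclass can go away once the parent owns the behaviour. None of these names besides InstanceLevelDpServer come from the codebase, and the body is invented, not the real API.

class InstanceLevelDpServerSketch:
    """Hypothetical stand-in for InstanceLevelDpServer after the refactor."""

    def __init__(self, checkpoint_dir: str) -> None:
        self.checkpoint_dir = checkpoint_dir

    def save_checkpoint(self, server_round: int) -> str:
        # Checkpointing now lives in the parent class, not in a per-example subclass.
        return f"{self.checkpoint_dir}/round_{server_round}.ckpt"


# Before: the example defined CifarInstanceLevelDPServerWithCheckpointing just to add
# what the parent now provides. After: the example instantiates the parent directly.
server = InstanceLevelDpServerSketch(checkpoint_dir="examples/cifar/checkpoints")
print(server.save_checkpoint(server_round=1))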
@@ -353,14 +394,11 @@ def evaluate_round(
            "eval_round_end": str(end_time),
            "eval_round_time_elapsed": round((end_time - start_time).total_seconds()),
        }
-       dummy_params = Parameters([], "None")
The stored config means that we no longer have to use the strategy to configure evaluation, which was a work-around for not having access to the config in the server.
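To make the removed work-around concrete, here is a self-contained sketch. StrategySketch, ServerSketch, and their methods are invented for illustration; only the idea of a stored fl_config mirrors the PR. Before, the server had to round-trip through the strategy to rebuild the evaluation config; with the stored config it can read it directly.

from typing import Any, Dict


class StrategySketch:
    """Illustrative strategy that holds the evaluation config (the old work-around)."""

    def __init__(self, evaluate_config: Dict[str, Any]) -> None:
        self._evaluate_config = evaluate_config

    def configure_evaluate(self, server_round: int) -> Dict[str, Any]:
        return {**self._evaluate_config, "current_server_round": server_round}


class ServerSketch:
    def __init__(self, fl_config: Dict[str, Any], strategy: StrategySketch) -> None:
        self.fl_config = fl_config
        self.strategy = strategy

    def evaluate_config_via_strategy(self, server_round: int) -> Dict[str, Any]:
        # Old work-around: go through the strategy (with dummy parameters in the
        # real code) just to recover a config the server never held itself.
        return self.strategy.configure_evaluate(server_round)

    def evaluate_config_from_stored(self, server_round: int) -> Dict[str, Any]:
        # With the config stored on the server, it can simply be read and augmented.
        return self.fl_config | {"current_server_round": server_round}  # Python 3.9+ merge


server = ServerSketch({"local_steps": 10}, StrategySketch({"local_steps": 10}))
assert server.evaluate_config_via_strategy(2) == server.evaluate_config_from_stored(2)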
…solicited BEFORE any parameter initialization. So we can't do it in fit round.
        return FlServerWithCheckpointing.fit(self, num_rounds, timeout)

-    def initialize(self, server_round: int, timeout: Optional[float] = None) -> None:
+    def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
This function takes the place of the initialize functionality that was present in the FlServerWithInitializer class. It's called before any fitting starts (i.e. right at the start of FL).
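A rough, self-contained sketch of the hook placement described in this comment. The class names and bodies are hypothetical; only the update_before_fit name and signature follow the diff. The base server invokes the hook once, right at the start of FL, before parameter initialization or any fit rounds.

from typing import Optional


class BaseServerSketch:
    """Hypothetical base server illustrating the pre-fit hook."""

    def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
        # Default is a no-op; subclasses override it to run setup that must happen
        # before parameter initialization or any fit rounds.
        return None

    def fit(self, num_rounds: int, timeout: Optional[float]) -> None:
        # Called right at the start of FL, before anything else happens.
        self.update_before_fit(num_rounds, timeout)
        for server_round in range(1, num_rounds + 1):
            print(f"fit round {server_round}")


class NnunetStyleServerSketch(BaseServerSketch):
    def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
        # e.g. solicit global nnunet plans from a random client before any fitting.
        print("requesting global plans from one random client via get_properties")


NnunetStyleServerSketch().fit(num_rounds=2, timeout=None)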
@@ -187,11 +183,11 @@ def initialize(self, server_round: int, timeout: Optional[float] = None) -> None
            "Requesting initialization of global nnunet plans from one random client via get_properties",
        )
        random_client = self._client_manager.sample(1)[0]
-       ins = GetPropertiesIns(config=config)
+       ins = GetPropertiesIns(config=self.fl_config | {"current_server_round": 0})
        properties_res = random_client.get_properties(ins=ins, timeout=timeout, group_id=server_round)
The fl_config will have everything needed for this except for what server round we're currently on. So we just add it here. 0 is because no rounds have actually been done.
Cool! Didn't know about | operator for dicts!
Just be careful because {} will evaluate to True in this case. It's fine if that's not a concern.
I may be misreading your comment Marcelo, but in this case, the pipe isn't a binary operator but is rather a merge operator for the two dictionaries. So I think we're okay, but perhaps I'm missing something 🙂
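For reference, a quick illustration of the two points in this thread (the fl_config contents are made up): the pipe is the PEP 584 dict merge operator available in Python 3.9+, and the separate truthiness question only concerns an empty dict used as a boolean.

fl_config = {"n_server_rounds": 5, "batch_size": 32}

# PEP 584 merge operator (Python 3.9+): right-hand keys win on key collisions.
merged = fl_config | {"current_server_round": 0}
print(merged)  # {'n_server_rounds': 5, 'batch_size': 32, 'current_server_round': 0}

# Separate point from the thread above: an empty dict is falsy as a boolean,
# but merging with one is harmless and just copies the other operand.
print(bool({}))        # False
print(fl_config | {})  # {'n_server_rounds': 5, 'batch_size': 32}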
Looks good to me! Mostly had comments on stuff from previous PRs that was just brought to light in this one and might be worth addressing
-    def fit_with_per_epoch_checkpointing(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
+    def _save_server_state(self) -> None:
As a future PR, does it maybe make sense to conglomerate the state saving and checkpointing? The whole server state stuff is basically a per-round checkpointer that saves additional info. We could modify the checkpointer base class to accept an instance of self or an arbitrary dictionary.
Yeah, I think we have some tickets that speak to this theme. There are some design things that need to be sorted out, as the two classes are sort of trying to accomplish different things. I'm taking the first steps towards unifying these things into one module. Perhaps the next natural step is to merge them together in a meaningful way.
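A purely hypothetical design sketch of the idea floated above (nothing here exists in the codebase): a per-round checkpointer whose save method takes an arbitrary state dictionary, so model checkpointing and server-state saving could share one interface.

import json
from pathlib import Path
from typing import Any, Dict


class PerRoundStateCheckpointer:
    """Hypothetical unified checkpointer: persists an arbitrary state dict each round."""

    def __init__(self, checkpoint_dir: Path) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def save(self, server_round: int, state: Dict[str, Any]) -> Path:
        # Callers decide what goes into `state`: model weights, metrics, RNG state, etc.
        path = self.checkpoint_dir / f"server_state_round_{server_round}.json"
        path.write_text(json.dumps(state, default=str))
        return path


checkpointer = PerRoundStateCheckpointer(Path("checkpoints"))
checkpointer.save(server_round=3, state={"history": {"loss": [0.9, 0.7, 0.6]}, "current_round": 3})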
LGTM!
PR Type
Other: QoL improvements
Short Description
Clickup Ticket(s): Ticket 1, Ticket 2
There are two tickets folded into this PR. They were sort of interconnected, in that fixing one without the other would have left some code weirdness lying around that I didn't want.
The first component forces the server to store a configuration file. The goal here is for this configuration file to essentially be the representation of the config used to create the on_fit_config_fn and on_evaluation_config_fn variables that drive overall FL training and evaluation. The server sort of "forgets" about these in the current setup, because they go directly to the strategy and become inaccessible to the server.

NOTE: There are a lot of file changes in this PR. Almost all of them are due to the fact that I am forcing the config to be provided to the server, so I had to migrate all of our examples etc.
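As a rough sketch of the pattern being described (the class, function, and config keys here are illustrative, not the fl4health implementation): the server keeps the parsed config, and the strategy callbacks are derived from that same dictionary, so nothing goes to the strategy that the server can no longer see.

from functools import partial
from typing import Any, Callable, Dict


def fit_config(fl_config: Dict[str, Any], current_server_round: int) -> Dict[str, Any]:
    # The per-round config handed to clients is derived from the same dictionary
    # the server now stores, rather than one only the strategy ever sees.
    return {**fl_config, "current_server_round": current_server_round}


class ConfigCarryingServerSketch:
    """Illustrative server that is required to hold on to the FL config."""

    def __init__(self, fl_config: Dict[str, Any]) -> None:
        self.fl_config = fl_config
        # The strategy callback is built from the stored config.
        self.on_fit_config_fn: Callable[[int], Dict[str, Any]] = partial(fit_config, fl_config)


server = ConfigCarryingServerSketch({"n_server_rounds": 5, "local_epochs": 1})
print(server.on_fit_config_fn(1))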
The second component of this PR is merging the functionality of FlServerWithCheckpointing and FlServerWithInitializer into the FlServer class. For a while it made sense for these to be separate functionalities, but it's actually not all that much more complex to fold them into the base class, and it makes inheritance much easier.

Tests Added
A test was added to cover a checkpointing edge case, but, overall, this is a refactor. So the tests we have should be sufficient.