Server storing FL configs and Consolidating Base Server Functionality #294

Merged: 15 commits merged into main from dbe/server_stores_config on Dec 13, 2024

Conversation

emersodb (Collaborator):

PR Type

Other: QoL improvements

Short Description

Clickup Ticket(s): Ticket 1, Ticket 2

There are two tickets folded into this PR. They were interconnected enough that fixing one alone would have left some code weirdness lying around that I didn't want.

The first component forces the server to store a configuration file. The goal is for this configuration to essentially be the representation of the config used to create the on_fit_config_fn and on_evaluation_config_fn variables that drive overall FL training and evaluation. In the current setup, the server sort of "forgets" about these, because they go directly to the strategy and become inaccessible to the server.
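A minimal sketch of the idea, with illustrative names (only fl_config appears in the actual diffs below; the real FlServer API differs):

```python
from typing import Any, Callable

Config = dict[str, Any]

class FlServer:
    def __init__(self, fl_config: Config) -> None:
        # The server retains the config rather than "forgetting" it once it
        # has been handed off to the strategy.
        self.fl_config = fl_config

    def make_on_fit_config_fn(self) -> Callable[[int], Config]:
        # The strategy's on_fit_config_fn can be derived from the stored
        # config, keeping the server and the strategy in sync.
        def fit_config(current_server_round: int) -> Config:
            return {**self.fl_config, "current_server_round": current_server_round}

        return fit_config
```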

NOTE: There are a lot of file changes in this PR. Almost all of them stem from the fact that I am forcing the config to be provided to the server, so I had to migrate all of our examples, etc.

The second component of this PR merges the functionality of FlServerWithCheckpointing and FlServerWithInitializer into the FlServer class. For a while it made sense for these to be separate, but it's actually not much more complex to fold those functionalities into the base class, and doing so makes inheritance much easier.
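A rough before/after sketch of the consolidation; the class bodies are placeholders, though _save_server_state and update_before_fit both appear in the diffs below:

```python
from typing import Optional

# Before: orthogonal features lived in separate subclasses, so a server that
# wanted more than one of them had to juggle the inheritance hierarchy.
class FlServer:
    ...

class FlServerWithCheckpointing(FlServer):  # added checkpoint/state saving
    ...

class FlServerWithInitializer(FlServer):  # added pre-fit initialization
    ...

# After: both behaviors live on the base class as overridable hooks, so every
# FlServer subclass inherits them for free.
class FlServer:  # noqa: F811 (redefinition is intentional in this sketch)
    def _save_server_state(self) -> None:
        ...  # checkpoint/state-saving hook

    def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
        ...  # one-time setup before any FL rounds run
```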

Tests Added

A test was added to cover a checkpointing edge case, but, overall, this is a refactor, so the existing tests should be sufficient.

@@ -61,43 +56,6 @@ def fit_config(
)


class CifarInstanceLevelDPServerWithCheckpointing(InstanceLevelDpServer):
emersodb (Collaborator, Author):

All functionality implemented here is now neatly contained in the parent class, InstanceLevelDpServer.

@@ -353,14 +394,11 @@ def evaluate_round(
"eval_round_end": str(end_time),
"eval_round_time_elapsed": round((end_time - start_time).total_seconds()),
}
dummy_params = Parameters([], "None")
emersodb (Collaborator, Author):

The stored config means that we no longer have to use the strategy to configure evaluation, which was a workaround for the server not having access to the config.

return FlServerWithCheckpointing.fit(self, num_rounds, timeout)

def initialize(self, server_round: int, timeout: Optional[float] = None) -> None:
def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
emersodb (Collaborator, Author):

This function takes the place of the initialize functionality that was present in the FlServerWithInitializer class. It's called before any fitting starts (i.e., right at the start of FL).
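A hedged sketch of where such a hook sits in the fit flow; only update_before_fit and its signature come from the diff above, the rest is illustrative:

```python
from typing import Optional

class FlServer:
    def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None:
        # No-op by default; subclasses (e.g. the nnUNet server below) override
        # this to perform one-time setup before federated training begins.
        pass

    def fit(self, num_rounds: int, timeout: Optional[float]) -> None:
        # The hook runs exactly once, before any fit rounds execute.
        self.update_before_fit(num_rounds, timeout)
        # ...the usual per-round fit loop would follow here...
```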

@@ -187,11 +183,11 @@ def initialize(self, server_round: int, timeout: Optional[float] = None) -> None
"Requesting initialization of global nnunet plans from one random client via get_properties",
)
random_client = self._client_manager.sample(1)[0]
ins = GetPropertiesIns(config=config)
properties_res = random_client.get_properties(ins=ins, timeout=timeout, group_id=server_round)
ins = GetPropertiesIns(config=self.fl_config | {"current_server_round": 0})
emersodb (Collaborator, Author):

The fl_config will have everything needed for this except the current server round, so we just add it here. It's 0 because no rounds have actually been run yet.

Collaborator:

Cool! Didn't know about the | operator for dicts!

Collaborator (Marcelo):

Just be careful because {} will evaluate to True in this case. It's fine if that's not a concern.

emersodb (Collaborator, Author):

I may be misreading your comment, Marcelo, but in this case the pipe isn't a boolean operator; rather, it's a merge operator for the two dictionaries. So I think we're okay, but perhaps I'm missing something 🙂
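For reference, a quick illustration of the operator in question (PEP 584 dict union, available since Python 3.9); the config keys here are made up:

```python
base = {"n_server_rounds": 5, "local_epochs": 1}

# `|` on dicts is a merge that returns a new dict, with right-hand keys
# winning on collisions; it is not a bitwise/boolean OR.
merged = base | {"current_server_round": 0}
assert merged == {"n_server_rounds": 5, "local_epochs": 1, "current_server_round": 0}

# Merging with an empty dict is a no-op. The truthiness of {} never comes
# into play; that would only matter for an expression like `base or {}`.
assert (base | {}) == base
```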

emersodb marked this pull request as ready for review on November 21, 2024 at 23:17
scarere (Collaborator) left a review:

Looks good to me! I mostly had comments on things from previous PRs that were just brought to light in this one and might be worth addressing.

fl4health/servers/base_server.py (resolved comment thread)

def fit_with_per_epoch_checkpointing(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
def _save_server_state(self) -> None:
scarere (Collaborator):

As a future PR, does it maybe make sense to consolidate the state saving and the checkpointing? The whole server-state mechanism is basically a per-round checkpointer that saves additional info. We could modify the checkpointer base class to accept an instance of self or an arbitrary dictionary.

emersodb (Collaborator, Author):

Yeah, I think we have some tickets that speak to this theme. There are some design questions to sort out, as the two classes are trying to accomplish somewhat different things. This PR takes the first steps toward unifying them into one module; perhaps the next natural step is to merge them together in a meaningful way.
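A very rough sketch of the direction being discussed; every name here is hypothetical rather than an existing fl4health class:

```python
from typing import Any

class PerRoundCheckpointer:
    # Hypothetical unified interface: server-state saving becomes just a
    # per-round checkpoint whose dict carries extra entries (round number,
    # timers, ...) alongside the model parameters.
    def __init__(self, checkpoint_dir: str) -> None:
        self.checkpoint_dir = checkpoint_dir

    def save(self, state: dict[str, Any], server_round: int) -> None:
        ...
```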

fl4health/servers/nnunet_server.py (outdated; resolved comment thread)
jewelltaylor (Collaborator) left a review:

LGTM!

fl4health/servers/base_server.py (resolved comment thread)

emersodb merged commit a2fd930 into main on Dec 13, 2024 (6 checks passed)
emersodb deleted the dbe/server_stores_config branch on December 13, 2024 at 13:54