Skip to content

Commit

Permalink
Merge branch 'main' into dbe/server_stores_config
Browse files Browse the repository at this point in the history
  • Loading branch information
emersodb committed Nov 21, 2024
2 parents 0a2e2da + c8e967f commit 7dd924e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
4 changes: 2 additions & 2 deletions fl4health/clients/basic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict
"round_start": str(round_start_time),
"round_end": str(datetime.datetime.now()),
"fit_round_start": str(fit_start_time),
"fit_round_time_elapsed": str(fit_end_time - fit_start_time),
"fit_round_time_elapsed": round((fit_end_time - fit_start_time).total_seconds()),
"fit_round_end": str(fit_end_time),
"fit_step": self.total_steps,
"fit_epoch": self.total_epochs,
Expand Down Expand Up @@ -376,7 +376,7 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di
"eval_round_metrics": metrics,
"eval_round_loss": loss,
"eval_round_start": str(start_time),
"eval_round_time_elapsed": str(elapsed),
"eval_round_time_elapsed": round(elapsed.total_seconds()),
"eval_round_end": str(end_time),
"fit_step": self.total_steps,
"fit_epoch": self.total_epochs,
Expand Down
13 changes: 10 additions & 3 deletions fl4health/reporting/wandb_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
tags: list[str] | None = None,
name: str | None = None,
id: str | None = None,
resume: str = "allow",
**kwargs: Any,
) -> None:
"""
Expand All @@ -56,7 +57,11 @@ def __init__(
name (str | None, optional): A short display name for this run. Default generates a random two-word name.
id (str | None, optional): A unique ID for this run. It must be unique in the project, and if you delete a
run you can't reuse the ID.
kwargs (Any): Keyword arguments to wandb.init excluding the ones explicitly described above.
resume (str): Indicates how to handle the case when a run has the same entity, project and run id as
a previous run. 'must' enforces the run must resume from the run with same id and throws an error
if it does not exist. 'never' enforces that a run will not resume and throws an error if run id exists.
'allow' resumes if the run id already exists. Defaults to 'allow'.
kwargs (Any): Keyword arguments to wandb.init excluding the ones explicitly described above.
Documentation here: https://docs.wandb.ai/ref/python/init/
"""

Expand All @@ -77,6 +82,7 @@ def __init__(
self.tags = tags
self.name = name
self.id = id
self.resume = resume

# Keep track of epoch and step. Initialize as 0.
self.current_epoch = 0
Expand Down Expand Up @@ -110,12 +116,12 @@ def define_metrics(self) -> None:
self.run.define_metric("round_end", summary="none", hidden=True)
# A server round contains a fit_round and maybe also an evaluate round
self.run.define_metric("fit_round_start", summary="none", hidden=True)
self.run.define_metric("fit_round_time_elapsed", summary="none", hidden=True)
self.run.define_metric("fit_round_end", summary="none", hidden=True)
self.run.define_metric("eval_round_start", summary="none", hidden=True)
self.run.define_metric("eval_round_time_elapsed", summary="none", hidden=True)
self.run.define_metric("eval_round_end", summary="none", hidden=True)
# The metrics computed on all the samples from the final epoch, or the entire round if training by steps
self.run.define_metric("fit_round_time_elapsed", summary="none")
self.run.define_metric("eval_round_time_elapsed", summary="none")
self.run.define_metric("fit_round_metrics", step_metric="round", summary="best")
self.run.define_metric("eval_round_metrics", step_metric="round", summary="best")
# Average of the losses for each step in the final epoch, or the entire round if training by steps.
Expand Down Expand Up @@ -157,6 +163,7 @@ def start_run(self, wandb_init_kwargs: dict[str, Any]) -> None:
tags=self.tags,
name=self.name,
id=self.id,
resume=self.resume,
**wandb_init_kwargs, # Other less commonly used kwargs
)
self.run_id = self.run._run_id # If run_id was None, we need to reset run id
Expand Down
9 changes: 7 additions & 2 deletions fl4health/servers/base_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float
end_time = datetime.datetime.now()
self.reports_manager.report(
{
"fit_elapsed_time": str(end_time - start_time),
"fit_elapsed_time": round((end_time - start_time).total_seconds()),
"fit_end": str(end_time),
"num_rounds": num_rounds,
"host_type": "server",
Expand Down Expand Up @@ -311,7 +311,11 @@ def fit_round(
round_end = datetime.datetime.now()

self.reports_manager.report(
{"fit_round_start": str(round_start), "fit_round_end": str(round_end)},
{
"fit_round_start": str(round_start),
"fit_round_end": str(round_end),
"fit_round_time_elapsed": round((round_end - round_start).total_seconds()),
},
server_round,
)
if fit_round_results is not None:
Expand Down Expand Up @@ -388,6 +392,7 @@ def evaluate_round(
"round": server_round,
"eval_round_start": str(start_time),
"eval_round_end": str(end_time),
"eval_round_time_elapsed": round((end_time - start_time).total_seconds()),
}

if self.fl_config.get("local_epochs", None) is not None:
Expand Down

0 comments on commit 7dd924e

Please sign in to comment.