diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index f9d386d7b..92b278e96 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -321,7 +321,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict "round_start": str(round_start_time), "round_end": str(datetime.datetime.now()), "fit_round_start": str(fit_start_time), - "fit_round_time_elapsed": str(fit_end_time - fit_start_time), + "fit_round_time_elapsed": round((fit_end_time - fit_start_time).total_seconds()), "fit_round_end": str(fit_end_time), "fit_step": self.total_steps, "fit_epoch": self.total_epochs, @@ -376,7 +376,7 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di "eval_round_metrics": metrics, "eval_round_loss": loss, "eval_round_start": str(start_time), - "eval_round_time_elapsed": str(elapsed), + "eval_round_time_elapsed": round(elapsed.total_seconds()), "eval_round_end": str(end_time), "fit_step": self.total_steps, "fit_epoch": self.total_epochs, diff --git a/fl4health/reporting/wandb_reporter.py b/fl4health/reporting/wandb_reporter.py index edc6f231c..5d717232d 100644 --- a/fl4health/reporting/wandb_reporter.py +++ b/fl4health/reporting/wandb_reporter.py @@ -30,6 +30,7 @@ def __init__( tags: list[str] | None = None, name: str | None = None, id: str | None = None, + resume: str = "allow", **kwargs: Any, ) -> None: """ @@ -56,7 +57,11 @@ def __init__( name (str | None, optional): A short display name for this run. Default generates a random two-word name. id (str | None, optional): A unique ID for this run. It must be unique in the project, and if you delete a run you can't reuse the ID. - kwargs (Any): Keyword arguments to wandb.init excluding the ones explicitly described above. + resume (str): Indicates how to handle the case when a run has the same entity, project and run id as + a previous run. 'must' enforces the run must resume from the run with same id and throws an error + if it does not exist. 'never' enforces that a run will not resume and throws an error if run id exists. + 'allow' resumes if the run id already exists. Defaults to 'allow'. + kwargs (Any): Keyword arguments to wandb.init excluding the ones explicitly described above. Documentation here: https://docs.wandb.ai/ref/python/init/ """ @@ -77,6 +82,7 @@ def __init__( self.tags = tags self.name = name self.id = id + self.resume = resume # Keep track of epoch and step. Initialize as 0. self.current_epoch = 0 @@ -110,12 +116,12 @@ def define_metrics(self) -> None: self.run.define_metric("round_end", summary="none", hidden=True) # A server round contains a fit_round and maybe also an evaluate round self.run.define_metric("fit_round_start", summary="none", hidden=True) - self.run.define_metric("fit_round_time_elapsed", summary="none", hidden=True) self.run.define_metric("fit_round_end", summary="none", hidden=True) self.run.define_metric("eval_round_start", summary="none", hidden=True) - self.run.define_metric("eval_round_time_elapsed", summary="none", hidden=True) self.run.define_metric("eval_round_end", summary="none", hidden=True) # The metrics computed on all the samples from the final epoch, or the entire round if training by steps + self.run.define_metric("fit_round_time_elapsed", summary="none") + self.run.define_metric("eval_round_time_elapsed", summary="none") self.run.define_metric("fit_round_metrics", step_metric="round", summary="best") self.run.define_metric("eval_round_metrics", step_metric="round", summary="best") # Average of the losses for each step in the final epoch, or the entire round if training by steps. @@ -157,6 +163,7 @@ def start_run(self, wandb_init_kwargs: dict[str, Any]) -> None: tags=self.tags, name=self.name, id=self.id, + resume=self.resume, **wandb_init_kwargs, # Other less commonly used kwargs ) self.run_id = self.run._run_id # If run_id was None, we need to reset run id diff --git a/fl4health/servers/base_server.py b/fl4health/servers/base_server.py index 4bd71d476..96df72de2 100644 --- a/fl4health/servers/base_server.py +++ b/fl4health/servers/base_server.py @@ -269,7 +269,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float end_time = datetime.datetime.now() self.reports_manager.report( { - "fit_elapsed_time": str(end_time - start_time), + "fit_elapsed_time": round((end_time - start_time).total_seconds()), "fit_end": str(end_time), "num_rounds": num_rounds, "host_type": "server", @@ -311,7 +311,11 @@ def fit_round( round_end = datetime.datetime.now() self.reports_manager.report( - {"fit_round_start": str(round_start), "fit_round_end": str(round_end)}, + { + "fit_round_start": str(round_start), + "fit_round_end": str(round_end), + "fit_round_time_elapsed": round((round_end - round_start).total_seconds()), + }, server_round, ) if fit_round_results is not None: @@ -388,6 +392,7 @@ def evaluate_round( "round": server_round, "eval_round_start": str(start_time), "eval_round_end": str(end_time), + "eval_round_time_elapsed": round((end_time - start_time).total_seconds()), } if self.fl_config.get("local_epochs", None) is not None: