From 950b95e7d22e708b7587ab003ee14cc006ed3400 Mon Sep 17 00:00:00 2001 From: jewelltaylor Date: Mon, 18 Nov 2024 19:19:00 -0500 Subject: [PATCH 1/3] Add resume argument for wandb reporter --- fl4health/reporting/wandb_reporter.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fl4health/reporting/wandb_reporter.py b/fl4health/reporting/wandb_reporter.py index edc6f231c..2b8b5f073 100644 --- a/fl4health/reporting/wandb_reporter.py +++ b/fl4health/reporting/wandb_reporter.py @@ -30,6 +30,7 @@ def __init__( tags: list[str] | None = None, name: str | None = None, id: str | None = None, + resume: str = "allow", **kwargs: Any, ) -> None: """ @@ -56,7 +57,11 @@ def __init__( name (str | None, optional): A short display name for this run. Default generates a random two-word name. id (str | None, optional): A unique ID for this run. It must be unique in the project, and if you delete a run you can't reuse the ID. - kwargs (Any): Keyword arguments to wandb.init excluding the ones explicitly described above. + resume (str): Indicates how to handle the case when a run has the same entity, project and run id as + a previous run. 'must' enforces the run must resume from the run with same id and throws an error + if it does not exist. 'never' enforces that a run will not resume and throws an error if run id exists. + 'allow' resumes if the run id already exists. Defaults to 'allow'. + kwargs (Any): Keyword arguments to wandb.init excluding the ones explicitly described above. Documentation here: https://docs.wandb.ai/ref/python/init/ """ @@ -77,6 +82,7 @@ def __init__( self.tags = tags self.name = name self.id = id + self.resume = resume # Keep track of epoch and step. Initialize as 0. self.current_epoch = 0 @@ -157,6 +163,7 @@ def start_run(self, wandb_init_kwargs: dict[str, Any]) -> None: tags=self.tags, name=self.name, id=self.id, + resume=self.resume, **wandb_init_kwargs, # Other less commonly used kwargs ) self.run_id = self.run._run_id # If run_id was None, we need to reset run id From f11d69428cf7f75bfdf13af87ad81fe9fa119e5b Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 20 Nov 2024 09:08:18 -0500 Subject: [PATCH 2/3] Remove fit_round_time_elapsed and eval_round_time_elapsed --- fl4health/reporting/wandb_reporter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fl4health/reporting/wandb_reporter.py b/fl4health/reporting/wandb_reporter.py index 2b8b5f073..5d717232d 100644 --- a/fl4health/reporting/wandb_reporter.py +++ b/fl4health/reporting/wandb_reporter.py @@ -116,12 +116,12 @@ def define_metrics(self) -> None: self.run.define_metric("round_end", summary="none", hidden=True) # A server round contains a fit_round and maybe also an evaluate round self.run.define_metric("fit_round_start", summary="none", hidden=True) - self.run.define_metric("fit_round_time_elapsed", summary="none", hidden=True) self.run.define_metric("fit_round_end", summary="none", hidden=True) self.run.define_metric("eval_round_start", summary="none", hidden=True) - self.run.define_metric("eval_round_time_elapsed", summary="none", hidden=True) self.run.define_metric("eval_round_end", summary="none", hidden=True) # The metrics computed on all the samples from the final epoch, or the entire round if training by steps + self.run.define_metric("fit_round_time_elapsed", summary="none") + self.run.define_metric("eval_round_time_elapsed", summary="none") self.run.define_metric("fit_round_metrics", step_metric="round", summary="best") self.run.define_metric("eval_round_metrics", step_metric="round", summary="best") # Average of the losses for each step in the final epoch, or the entire round if training by steps. From ae6c131ee94cc6a74d02bca7eef12ac4969edc0a Mon Sep 17 00:00:00 2001 From: John Jewell Date: Wed, 20 Nov 2024 12:32:09 -0500 Subject: [PATCH 3/3] Convert elapsed time to int (seconds) from string --- fl4health/clients/basic_client.py | 4 ++-- fl4health/servers/base_server.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 9908b4597..a7bff747c 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -322,7 +322,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict "round_start": str(round_start_time), "round_end": str(datetime.datetime.now()), "fit_round_start": str(fit_start_time), - "fit_round_time_elapsed": str(fit_end_time - fit_start_time), + "fit_round_time_elapsed": round((fit_end_time - fit_start_time).total_seconds()), "fit_round_end": str(fit_end_time), "fit_step": self.total_steps, "fit_epoch": self.total_epochs, @@ -377,7 +377,7 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di "eval_round_metrics": metrics, "eval_round_loss": loss, "eval_round_start": str(start_time), - "eval_round_time_elapsed": str(elapsed), + "eval_round_time_elapsed": round(elapsed.total_seconds()), "eval_round_end": str(end_time), "fit_step": self.total_steps, "fit_epoch": self.total_epochs, diff --git a/fl4health/servers/base_server.py b/fl4health/servers/base_server.py index e5d0c50ad..204e8371d 100644 --- a/fl4health/servers/base_server.py +++ b/fl4health/servers/base_server.py @@ -109,7 +109,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float end_time = datetime.datetime.now() self.reports_manager.report( { - "fit_elapsed_time": str(end_time - start_time), + "fit_elapsed_time": round((end_time - start_time).total_seconds()), "fit_end": str(end_time), "num_rounds": num_rounds, "host_type": "server", @@ -131,7 +131,11 @@ def fit_round( round_end = datetime.datetime.now() self.reports_manager.report( - {"fit_round_start": str(round_start), "fit_round_end": str(round_end)}, + { + "fit_round_start": str(round_start), + "fit_round_end": str(round_end), + "fit_round_time_elapsed": round((round_end - round_start).total_seconds()), + }, server_round, ) if fit_round_results is not None: @@ -347,6 +351,7 @@ def evaluate_round( "round": server_round, "eval_round_start": str(start_time), "eval_round_end": str(end_time), + "eval_round_time_elapsed": round((end_time - start_time).total_seconds()), } dummy_params = Parameters([], "None") config = self.strategy.configure_evaluate(server_round, dummy_params, self._client_manager)[0][