From be027446f1c671ca9aa33d6107217d1b710fde6c Mon Sep 17 00:00:00 2001
From: jeandut
Date: Thu, 11 Jan 2024 16:01:41 +0100
Subject: [PATCH] errors start to make more sense

---
 fedeca/strategies/bootstraper.py | 57 ++++++++++++++++++++------------
 1 file changed, 36 insertions(+), 21 deletions(-)

diff --git a/fedeca/strategies/bootstraper.py b/fedeca/strategies/bootstraper.py
index e2e118d9..e6d05399 100644
--- a/fedeca/strategies/bootstraper.py
+++ b/fedeca/strategies/bootstraper.py
@@ -8,6 +8,10 @@
 import copy
 from pathlib import Path
 from substrafl.algorithms.pytorch.torch_base_algo import TorchAlgo
+import tempfile
+import shutil
+import zipfile
+from glob import glob
 
 def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] = 10, bootstrap_seeds: Union[list[int], None] = None, inplace: bool = False):
     """Bootstrap a substrafl strategy without impacting the number of compute tasks.
@@ -141,7 +145,7 @@ def bootstrapped_algo_init(self, *args, **kwargs):
     methods_dict.update({"load_local_state_original": getattr(strategy.algo, "load_local_state")})
     methods_dict.update({"save_local_state_original": getattr(strategy.algo, "save_local_state")})
     methods_dict.update({"load_local_state": _load_all_bootstraps_states()})
-    methods_dict.update({"save_local_state": _save_all_bootstraps_states()})
+    methods_dict.update({"save_local_state": _save_all_bootstraps_states(bootstrap_seeds_list)})
     methods_dict.update({"strategies": strategy.algo.strategies})
     methods_dict.update({"__init__": bootstrapped_algo_init})
 
@@ -257,17 +261,16 @@ def local_computation(self, datasamples, shared_state=None) -> list:
         # dependent and we need to load the corresponding state as the main
         # state, so we need to have saved all states (i.e. n_bootstraps models).
         # We implicitly use the new method load_bootstrap_states to load all states in-RAM
-        if not hasattr(self, "n_bootstraps_bootstraper"):
-            self.n_bootstraps_bootstraper = len(bootstrap_seeds_list)
+
         if not hasattr(self, "checkpoints_list"):
-            self.checkpoints_list = [None] * self.n_bootstraps_bootstraper
+            self.checkpoints_list = [None] * len(bootstrap_seeds_list)
 
         for idx, seed in enumerate(bootstrap_seeds_list):
             rng = np.random.default_rng(seed)
             bootstrapped_data = datasamples.sample(datasamples.shape[0], replace=True, random_state=rng)
             # Loading the correct state into the current main algo
             if self.checkpoints_list[idx] is not None:
-                self.load_local_state_original(self.checkpoints_list[idx])
+                self._update_from_checkpoint(self.checkpoints_list[idx])
             if shared_state is None:
                 res = local_function(datasamples=bootstrapped_data, _skip=True)
             else:
@@ -356,33 +359,45 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
         # Note that at the end of this loop the main state is the one of the last
         # bootstrap
-        if (path.parent / "bootstrap_0") in path.parent.iterdir():
-            for idx in range(self.n_bootstraps_bootstraper):
-                self.load_local_state_original(path / f"bootstrap_{idx}")
-                self.checkpoints_list[idx] = self.get_state_to_save()
-        else:
-            # This first call is needed when no bootstrap has been done
-            # self.load_local_state_original(path)
-            self.load_local_state_original(path)
+        archive = zipfile.ZipFile(path, 'r')
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            archive.extractall(tmpdirname)
+            checkpoints_found = [p for p in Path(tmpdirname).glob("**/bootstrap_*")]
+            self.checkpoints_list = [None] * len(checkpoints_found)
+            for idx, file in enumerate(checkpoints_found):
+                self.load_local_state_original(file)
+                self.checkpoints_list[idx] = self._get_state_to_save()
         return self
 
     return load_local_state
 
-def _save_all_bootstraps_states():
+def _save_all_bootstraps_states(bootstrap_seeds_list):
     def save_local_state(self, path: Path) -> "TorchAlgo":
         # We save all bootstrapped states in different subfolders
         # It assumes at this point checkpoints_list has been populated
         # if it exists
-        # We need to save the checkponts list attribute
+
+        # The guard below is needed because initialization functions do not
+        # populate checkpoints_list
         if hasattr(self, "checkpoints_list"):
+            pass
+        else:
+            self.checkpoints_list = [
+                copy.deepcopy(self._get_state_to_save()) for _ in range(len(bootstrap_seeds_list))
+            ]
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            paths_to_checkpoints = []
             for idx, checkpt in enumerate(self.checkpoints_list):
                 # Get the model in the proper state
+                self._update_from_checkpoint(checkpt)
                 # TODO methods implicitly use the self attribute
-                self.load_local_state_original(checkpt)
-                self.save_local_state_original(path=path / f"bootstrap_{idx}")
-        else:
-            # First call of thiis function checkpoints_list doesn't exist
-            # we need to be able to load sthg
-            self.save_local_state_original(path=path)
+                path_to_checkpoint = Path(tmpdirname) / f"bootstrap_{idx}"
+                self.save_local_state_original(path_to_checkpoint)
+                paths_to_checkpoints.append(path_to_checkpoint)
+
+            with zipfile.ZipFile(path, 'w') as f:
+                for chkpt in paths_to_checkpoints:
+                    f.write(chkpt, compress_type=zipfile.ZIP_DEFLATED)
         return self
 
     return save_local_state
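
Note (not part of the commit): the sketch below illustrates the zip-based checkpoint round-trip that the new save_local_state / load_local_state implement. It is a minimal, standalone approximation: ToyAlgo and its dict checkpoints are hypothetical stand-ins for the substrafl TorchAlgo state and its save/load machinery, and only the standard library is used.

import json
import tempfile
import zipfile
from pathlib import Path


class ToyAlgo:
    """Hypothetical stand-in holding one in-RAM checkpoint per bootstrap."""

    def __init__(self, n_bootstraps: int):
        self.checkpoints_list = [{"weights": [float(i)]} for i in range(n_bootstraps)]

    def save_all_states(self, path: Path) -> "ToyAlgo":
        # As in the patch: dump one file per bootstrap into a temporary
        # directory, then bundle them all into a single zip archive at `path`.
        with tempfile.TemporaryDirectory() as tmpdirname:
            paths_to_checkpoints = []
            for idx, checkpt in enumerate(self.checkpoints_list):
                path_to_checkpoint = Path(tmpdirname) / f"bootstrap_{idx}"
                path_to_checkpoint.write_text(json.dumps(checkpt))
                paths_to_checkpoints.append(path_to_checkpoint)
            with zipfile.ZipFile(path, "w") as f:
                for chkpt in paths_to_checkpoints:
                    # arcname flattens the archive layout (a choice the patch
                    # does not make; see the note below the sketch).
                    f.write(chkpt, arcname=chkpt.name, compress_type=zipfile.ZIP_DEFLATED)
        return self

    def load_all_states(self, path: Path) -> "ToyAlgo":
        # As in the patch: extract the zip into a temporary directory and
        # reload every bootstrap_* file into the in-RAM checkpoints list.
        with zipfile.ZipFile(path, "r") as archive, tempfile.TemporaryDirectory() as tmpdirname:
            archive.extractall(tmpdirname)
            # Sorting by the trailing index keeps checkpoints aligned with
            # their bootstrap seeds; glob alone guarantees no order.
            checkpoints_found = sorted(
                Path(tmpdirname).glob("**/bootstrap_*"),
                key=lambda p: int(p.name.rsplit("_", 1)[-1]),
            )
            self.checkpoints_list = [json.loads(p.read_text()) for p in checkpoints_found]
        return self


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        archive_path = Path(d) / "bootstrapped_states.zip"
        saved = ToyAlgo(n_bootstraps=3).save_all_states(archive_path)
        restored = ToyAlgo(n_bootstraps=0).load_all_states(archive_path)
        assert restored.checkpoints_list == saved.checkpoints_list

One design detail: the patch calls f.write(chkpt, compress_type=...) without an arcname, so archive members keep their temporary-directory paths, which is why the patched load_local_state must glob recursively with "**/bootstrap_*". The sketch instead flattens the layout via arcname and sorts the recovered files by index, since glob does not promise any particular ordering.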