errors start to make more sense
jeandut committed Jan 11, 2024
1 parent d8266e6 commit be02744
Showing 1 changed file with 36 additions and 21 deletions.
57 changes: 36 additions & 21 deletions fedeca/strategies/bootstraper.py
@@ -8,6 +8,10 @@
import copy
from pathlib import Path
from substrafl.algorithms.pytorch.torch_base_algo import TorchAlgo
+import tempfile
+import shutil
+import zipfile
+from glob import glob

def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] = 10, bootstrap_seeds: Union[list[int], None] = None, inplace: bool = False):
    """Bootstrap a substrafl strategy without impacting the number of compute tasks.
@@ -141,7 +145,7 @@ def bootstrapped_algo_init(self, *args, **kwargs):
methods_dict.update({"load_local_state_original": getattr(strategy.algo, "load_local_state")})
methods_dict.update({"save_local_state_original": getattr(strategy.algo, "save_local_state")})
methods_dict.update({"load_local_state": _load_all_bootstraps_states()})
methods_dict.update({"save_local_state": _save_all_bootstraps_states()})
methods_dict.update({"save_local_state": _save_all_bootstraps_states(bootstrap_seeds_list)})
methods_dict.update({"strategies": strategy.algo.strategies})
methods_dict.update({"__init__": bootstrapped_algo_init})

@@ -257,17 +261,16 @@ def local_computation(self, datasamples, shared_state=None) -> list:
        # dependent and we need to load the corresponding state as the main
        # state, so we need to have saved all states (i.e. n_bootstraps models)
+       # We implicitly use the new method load_bootstrap_states to load all states in RAM
-       if not hasattr(self, "n_bootstraps_bootstraper"):
-           self.n_bootstraps_bootstraper = len(bootstrap_seeds_list)

        if not hasattr(self, "checkpoints_list"):
-           self.checkpoints_list = [None] * self.n_bootstraps_bootstraper
+           self.checkpoints_list = [None] * len(bootstrap_seeds_list)

        for idx, seed in enumerate(bootstrap_seeds_list):
            rng = np.random.default_rng(seed)
            bootstrapped_data = datasamples.sample(datasamples.shape[0], replace=True, random_state=rng)
            # Loading the correct state into the current main algo
            if self.checkpoints_list[idx] is not None:
-               self.load_local_state_original(self.checkpoints_list[idx])
+               self._update_from_checkpoint(self.checkpoints_list[idx])
            if shared_state is None:
                res = local_function(datasamples=bootstrapped_data, _skip=True)
            else:
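The resampling line above is plain pandas; here is a runnable illustration (with a toy DataFrame standing in for `datasamples`) of how one fresh generator per seed keeps every bootstrap draw reproducible:

```python
import numpy as np
import pandas as pd

datasamples = pd.DataFrame({"time": [5.0, 8.0, 2.0, 9.0], "event": [1, 0, 1, 1]})
bootstrap_seeds_list = [42, 43]

for seed in bootstrap_seeds_list:
    rng = np.random.default_rng(seed)
    # Same-size draw with replacement: a standard bootstrap resample.
    bootstrapped_data = datasamples.sample(
        datasamples.shape[0], replace=True, random_state=rng
    )
    print(seed, bootstrapped_data.index.tolist())  # identical on every rerun
```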
@@ -356,33 +359,45 @@ def load_local_state(self, path: Path) -> "TorchAlgo":

        # Note that at the end of this loop the main state is the one of the last
        # bootstrap
-       if (path.parent / "bootstrap_0") in path.parent.iterdir():
-           for idx in range(self.n_bootstraps_bootstraper):
-               self.load_local_state_original(path / f"bootstrap_{idx}")
-               self.checkpoints_list[idx] = self.get_state_to_save()
-       else:
-           # This first call is needed when no bootstrap has been done
-           # self.load_local_state_original(path)
-           self.load_local_state_original(path)
+       archive = zipfile.ZipFile(path, 'r')
+       with tempfile.TemporaryDirectory() as tmpdirname:
+           archive.extractall(tmpdirname)
+           checkpoints_found = [p for p in Path(tmpdirname).glob("**/bootstrap_*")]
+           self.checkpoints_list = [None] * len(checkpoints_found)
+           for idx, file in enumerate(checkpoints_found):
+               self.load_local_state_original(file)
+               self.checkpoints_list[idx] = self._get_state_to_save()
        return self

    return load_local_state
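A stripped-down sketch of the new load path: extract the archive into a temporary directory and visit every `bootstrap_*` entry, however deeply nested. Here `read_bytes` stands in for `load_local_state_original`, which is the only assumption beyond what the diff shows:

```python
import tempfile
import zipfile
from pathlib import Path

def read_bootstrap_states(archive_path: Path) -> list[bytes]:
    # Work inside the TemporaryDirectory context: the extracted files
    # disappear as soon as it closes.
    states = []
    with zipfile.ZipFile(archive_path, "r") as archive:
        with tempfile.TemporaryDirectory() as tmpdirname:
            archive.extractall(tmpdirname)
            # The recursive "**/" pattern tolerates the directory prefix
            # that zipfile.write() stores when given a full path.
            for file in sorted(Path(tmpdirname).glob("**/bootstrap_*")):
                states.append(file.read_bytes())
    return states
```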


-def _save_all_bootstraps_states():
+def _save_all_bootstraps_states(bootstrap_seeds_list):
    def save_local_state(self, path: Path) -> "TorchAlgo":
        # We save all bootstrapped states in different subfolders
        # It assumes that at this point checkpoints_list has been populated,
        # if it exists
+       # We need to save the checkpoints list attribute
+
+       # The reason for the if is that initialize functions don't
+       # populate checkpoints_list
        if hasattr(self, "checkpoints_list"):
+           pass
+       else:
+           self.checkpoints_list = [copy.deepcopy(self._get_state_to_save()) for _ in range(len(bootstrap_seeds_list))]
+
+       with tempfile.TemporaryDirectory() as tmpdirname:
+           paths_to_checkpoints = []
            for idx, checkpt in enumerate(self.checkpoints_list):
                # Get the model in the proper state
+               self._update_from_checkpoint(checkpt)
                # TODO: methods implicitly use the self attribute
-               self.load_local_state_original(checkpt)
-               self.save_local_state_original(path=path / f"bootstrap_{idx}")
-       else:
-           # On the first call of this function checkpoints_list doesn't exist
-           # so we need to be able to load something
-           self.save_local_state_original(path=path)
+               path_to_checkpoint = Path(tmpdirname) / f"bootstrap_{idx}"
+               self.save_local_state_original(path_to_checkpoint)
+               paths_to_checkpoints.append(path_to_checkpoint)
+
+           with zipfile.ZipFile(path, 'w') as f:
+               for chkpt in paths_to_checkpoints:
+                   f.write(chkpt, compress_type=zipfile.ZIP_DEFLATED)
        return self
    return save_local_state
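And the matching save side, reduced to its file-handling skeleton (raw `bytes` stand in for the real checkpoints; everything else follows the diff): write each state under `bootstrap_{idx}` in a temporary directory, then pack the lot into a single archive at `path`.

```python
import tempfile
import zipfile
from pathlib import Path

def write_bootstrap_states(archive_path: Path, states: list[bytes]) -> None:
    with tempfile.TemporaryDirectory() as tmpdirname:
        with zipfile.ZipFile(archive_path, "w") as f:
            for idx, state in enumerate(states):
                path_to_checkpoint = Path(tmpdirname) / f"bootstrap_{idx}"
                path_to_checkpoint.write_bytes(state)
                # With no explicit arcname, zipfile records the temp-dir
                # path inside the archive; the recursive glob on the load
                # side copes with that nesting either way.
                f.write(path_to_checkpoint, compress_type=zipfile.ZIP_DEFLATED)
```

Round trip with the load sketch above: `write_bootstrap_states(Path("states.zip"), [b"s0", b"s1"])` followed by `read_bootstrap_states(Path("states.zip"))` returns `[b"s0", b"s1"]`.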

