diff --git a/fedeca/strategies/bootstraper.py b/fedeca/strategies/bootstraper.py index 8c12a26f..dca6433a 100644 --- a/fedeca/strategies/bootstraper.py +++ b/fedeca/strategies/bootstraper.py @@ -118,73 +118,49 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] = _aggregate_all_bootstraps(agg, name) for agg, name in zip(orig_aggregations[key], aggregations_names[key]) ] - # define what will be the __init__ method of - # the to-be-dynamically-defined MergedClass - def bootstrapped_algo_init(self, *args, **kwargs): - """Initialize the merged strategy. - - Parameters - ---------- - self : BtstAlgo - The BtstAlgo instance. - algo : Strategy - List of the strategies used. - args: Any - extra arguments - kwargs: Any - extra keyword arguments - """ - super(self.__class__, self).__init__( - *args, **kwargs - ) - self.original_algo = strategy.algo - - # Dict holding the methods of the to-be-dynamically-defined MergedClass - methods_dict = dict(zip(local_functions_names["algo"], local_computations_fct["algo"])) - methods_dict.update(dict(zip(aggregations_names["algo"], aggregations_fct["algo"]))) - methods_dict.update({"load_local_state_original": getattr(strategy.algo, "load_local_state")}) - methods_dict.update({"save_local_state_original": getattr(strategy.algo, "save_local_state")}) - methods_dict.update({"load_local_state": _load_all_bootstraps_states()}) - methods_dict.update({"save_local_state": _save_all_bootstraps_states(bootstrap_seeds_list)}) - methods_dict.update({"strategies": strategy.algo.strategies}) - methods_dict.update({"__init__": bootstrapped_algo_init}) - - # dynamically define the BtstStrategy Class, which inherits from - # Strategy, and whose methods are defined by the method dict. - BtstAlgo = type("BtstAlgo", (strategy.algo.__class__,), methods_dict) - btst_algo = BtstAlgo() - - # define what will be the __init__ method of - # the to-be-dynamically-defined MergedClass - def bootstrapped_strategy_init(self, *args, **kwargs): - """Initialize the merged strategy. - - Parameters - ---------- - self : BtstStrategy - The BtstStrategy instance. - strategy : Strategy - List of the strategies used. - args: Any - extra arguments - kwargs: Any - extra keyword arguments - """ - super(self.__class__, self).__init__( - *args, **kwargs - ) - self.original_strategy = strategy - - # Dict holding the methods of the to-be-dynamically-defined MergedClass - methods_dict = dict(zip(local_functions_names["strategy"], local_computations_fct["strategy"])) - methods_dict.update(dict(zip(aggregations_names["strategy"], aggregations_fct["strategy"]))) - methods_dict.update({"__init__": bootstrapped_strategy_init}) - # dynamically define the BtstStrategy Class, which inherits from - # Strategy, and whose methods are defined by the method dict. - BtstStrategy = type("BtstStrategy", (strategy.__class__,), methods_dict) - # return an instance of this class. - strat = BtstStrategy(algo=btst_algo) - return strat + # We have to overwrite the original methods at the class level + # obj_class = strategy.algo.__class__ + obj = strategy.algo + for local_name in local_functions_names["algo"]: + # f = types.MethodType(_bootstrap_local_function(getattr(obj_class, local_name), local_name, bootstrap_seeds_list), obj_class) + # setattr(obj_class, local_name, f) + f = types.MethodType(_bootstrap_local_function(getattr(obj, local_name), local_name, bootstrap_seeds_list), obj) + setattr(obj, local_name, f) + for agg_name in aggregations_names["algo"]: + # f = types.MethodType(_aggregate_all_bootstraps(getattr(obj_class, agg_name), agg_name), obj_class) + # setattr(obj_class, agg_name, f) + f = types.MethodType(_aggregate_all_bootstraps(getattr(obj, agg_name), agg_name), obj) + setattr(obj, agg_name, f) + + # f = types.MethodType(_save_all_bootstraps_states(getattr(obj_class, "save_local_state"), bootstrap_seeds_list), obj_class) + # setattr(obj_class, "save_local_state", f) + + # f = types.MethodType(_load_all_bootstraps_states(getattr(obj_class, "load_local_state")), obj_class) + # setattr(obj_class, "load_local_state", f) + + f = types.MethodType(_save_all_bootstraps_states(getattr(obj, "save_local_state"), bootstrap_seeds_list), obj) + setattr(obj, "save_local_state", f) + + f = types.MethodType(_load_all_bootstraps_states(getattr(obj, "load_local_state")), obj) + setattr(obj, "load_local_state", f) + + + + # obj_class = strategy.__class__ + obj = strategy + for local_name in local_functions_names["strategy"]: + # f = types.MethodType(_bootstrap_local_function(getattr(obj_class, local_name), local_name, bootstrap_seeds_list), obj_class) + # setattr(obj_class, local_name, f) + f = types.MethodType(_bootstrap_local_function(getattr(obj, local_name), local_name, bootstrap_seeds_list), obj) + setattr(obj, local_name, f) + for agg_name in aggregations_names["strategy"]: + # f = types.MethodType(_aggregate_all_bootstraps(getattr(obj_class, agg_name), agg_name), obj_class) + # setattr(obj_class, agg_name, f) + f = types.MethodType(_aggregate_all_bootstraps(getattr(obj, agg_name), agg_name), obj) + setattr(obj, agg_name, f) + + + return strategy # # Very important we have to decorate AT THE CLASS LEVEL # # here we decorate both at the instance and at the class level # # but for actual deployments only class-level is important @@ -258,27 +234,26 @@ def local_computation(self, datasamples, shared_state=None) -> list: results = [] # loop over the provided local_computation steps using skip=True. # What is highly non-trivial is that algo has a state that is bootstrap - # dependent and we need to load the correspponding state as the main + # dependent and we need to load the corresponding state as the main # state, so we need to have saved all states (aka i.e. n_bootstraps models) # We use implicitly the new method load_bootstrap_states to load all states in-RAM - + name_decorated_function = local_computation.__name__ + "_original" if not hasattr(self, "checkpoints_list"): self.checkpoints_list = [None] * len(bootstrap_seeds_list) for idx, seed in enumerate(bootstrap_seeds_list): rng = np.random.default_rng(seed) bootstrapped_data = datasamples.sample(datasamples.shape[0], replace=True, random_state=rng) - # Loading the correct state into the current main algo if self.checkpoints_list[idx] is not None: self._update_from_checkpoint(self.checkpoints_list[idx]) - # We need this old state tto avoid side effects from the function + # We need this old state to avoid side effects from the function # on the instance old_state = copy.deepcopy(self) if shared_state is None: - res = local_function(datasamples=bootstrapped_data, _skip=True) + res = getattr(self, name_decorated_function)(datasamples=bootstrapped_data, _skip=True) else: - res = local_function( + res = getattr(self, name_decorated_function)( datasamples=bootstrapped_data, shared_state=shared_state[idx], _skip=True ) self.checkpoints_list[idx] = self._get_state_to_save() @@ -287,7 +262,6 @@ def local_computation(self, datasamples, shared_state=None) -> list: for att_name, att in vars(self).items(): if att != old_state.__getattribute__(att_name): self.__setattr__(att_name, old_state.__getattribute__(att_name)) - results.append(res) return results @@ -355,7 +329,7 @@ def aggregation(self, shared_states=None) -> list: return remote(aggregation) -def _load_all_bootstraps_states(): +def _load_all_bootstraps_states(load_local_state): def load_local_state(self, path: Path) -> "TorchAlgo": """Load the stateful arguments of this class. Child classes do not need to override that function. @@ -375,14 +349,14 @@ def load_local_state(self, path: Path) -> "TorchAlgo": checkpoints_found = [p for p in Path(tmpdirname).glob("**/bootstrap_*")] self.checkpoints_list = [None] * len(checkpoints_found) for idx, file in enumerate(checkpoints_found): - self.load_local_state_original(file) + load_local_state(file) self.checkpoints_list[idx] = self._get_state_to_save() return self return load_local_state -def _save_all_bootstraps_states(bootstrap_seeds_list): +def _save_all_bootstraps_states(save_local_state, bootstrap_seeds_list): def save_local_state(self, path: Path) -> "TorchAlgo": # We save all bootstrapped states in different subfolders # It assumes at this point checkpoints_list has been populated @@ -402,7 +376,7 @@ def save_local_state(self, path: Path) -> "TorchAlgo": self._update_from_checkpoint(checkpt) # TODO methods implictly use the self attribute path_to_checkpoint = Path(tmpdirname) / f"bootstrap_{idx}" - self.save_local_state_original(path_to_checkpoint) + save_local_state(path_to_checkpoint) paths_to_checkpoints.append(path_to_checkpoint) with zipfile.ZipFile(path, 'w') as f: @@ -480,13 +454,12 @@ def __init__(self): strategy = FedAvg(algo=TorchLogReg()) btst_strategy = make_bootstrap_strategy(strategy, n_bootstraps=10) - clients, train_data_nodes, test_data_nodes, _, _ = split_dataframe_across_clients( df, n_clients=2, - split_method= "split_control_over_centers", - split_method_kwargs={"treatment_info": "treatment"}, + split_method= "uniform", + split_method_kwargs=None, data_path="./data", backend_type="subprocess", )