Skip to content

Commit

Permalink
weird bugs still...
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jan 13, 2024
1 parent e59dff6 commit 9ed1e9d
Showing 1 changed file with 48 additions and 52 deletions.
100 changes: 48 additions & 52 deletions fedeca/strategies/bootstraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,48 +119,43 @@ def make_bootstrap_strategy(strategy: Strategy, n_bootstraps: Union[int, None] =
for agg, name in zip(orig_aggregations[key], aggregations_names[key])
]
# We have to overwrite the original methods at the class level
# obj_class = strategy.algo.__class__
obj = strategy.algo
for local_name in local_functions_names["algo"]:
# f = types.MethodType(_bootstrap_local_function(getattr(obj_class, local_name), local_name, bootstrap_seeds_list), obj_class)
# setattr(obj_class, local_name, f)
f = types.MethodType(_bootstrap_local_function(getattr(obj, local_name), local_name, bootstrap_seeds_list), obj)
setattr(obj, local_name, f)
for agg_name in aggregations_names["algo"]:
# f = types.MethodType(_aggregate_all_bootstraps(getattr(obj_class, agg_name), agg_name), obj_class)
# setattr(obj_class, agg_name, f)
f = types.MethodType(_aggregate_all_bootstraps(getattr(obj, agg_name), agg_name), obj)
setattr(obj, agg_name, f)

# f = types.MethodType(_save_all_bootstraps_states(getattr(obj_class, "save_local_state"), bootstrap_seeds_list), obj_class)
# setattr(obj_class, "save_local_state", f)

# f = types.MethodType(_load_all_bootstraps_states(getattr(obj_class, "load_local_state")), obj_class)
# setattr(obj_class, "load_local_state", f)

f = types.MethodType(_save_all_bootstraps_states(getattr(obj, "save_local_state"), bootstrap_seeds_list), obj)
setattr(obj, "save_local_state", f)

f = types.MethodType(_load_all_bootstraps_states(getattr(obj, "load_local_state")), obj)
setattr(obj, "load_local_state", f)



# obj_class = strategy.__class__
obj = strategy
for local_name in local_functions_names["strategy"]:
# f = types.MethodType(_bootstrap_local_function(getattr(obj_class, local_name), local_name, bootstrap_seeds_list), obj_class)
# setattr(obj_class, local_name, f)
f = types.MethodType(_bootstrap_local_function(getattr(obj, local_name), local_name, bootstrap_seeds_list), obj)
setattr(obj, local_name, f)
for agg_name in aggregations_names["strategy"]:
# f = types.MethodType(_aggregate_all_bootstraps(getattr(obj_class, agg_name), agg_name), obj_class)
# setattr(obj_class, agg_name, f)
f = types.MethodType(_aggregate_all_bootstraps(getattr(obj, agg_name), agg_name), obj)
setattr(obj, agg_name, f)


return strategy
class BtstAlgo(strategy.algo.__class__):
def __init__(self):
super().__init__(**strategy.algo.kwargs)
for local_name in local_functions_names["algo"]:
setattr(self, local_name + "_original", getattr(self, local_name))
f = types.MethodType(_bootstrap_local_function(getattr(self, local_name), local_name, bootstrap_seeds_list), self)
setattr(self, local_name, f)

for agg_name in aggregations_names["algo"]:
setattr(self, agg_name + "_original", getattr(self, agg_name))
f = types.MethodType(_aggregate_all_bootstraps(getattr(self, agg_name), agg_name), self)
setattr(self, agg_name, f)

setattr(self, "save_local_state_original", self.save_local_state)
f = types.MethodType(_save_all_bootstraps_states(getattr(self, "save_local_state"), bootstrap_seeds_list), self)
setattr(self, "save_local_state", f)

setattr(self, "load_local_state_original", self.load_local_state)
f = types.MethodType(_load_all_bootstraps_states(getattr(self, "load_local_state")), self)
setattr(self, "load_local_state", f)

btst_algo = BtstAlgo()
class BtstStrategy(strategy.__class__):
def __init__(self):
kwargs = strategy.kwargs
kwargs["algo"] = btst_algo
super().__init__(**kwargs)
for local_name in local_functions_names["strategy"]:
setattr(self, local_name + "_original", getattr(self, local_name))
f = types.MethodType(_bootstrap_local_function(getattr(self, local_name), local_name, bootstrap_seeds_list), self)
setattr(self, local_name, f)
for agg_name in aggregations_names["strategy"]:
setattr(self, agg_name + "_original", getattr(self, agg_name))
f = types.MethodType(_aggregate_all_bootstraps(getattr(self, agg_name), agg_name), self)
setattr(self, agg_name, f)

return BtstStrategy()
# # Very important we have to decorate AT THE CLASS LEVEL
# # here we decorate both at the instance and at the class level
# # but for actual deployments only class-level is important
Expand Down Expand Up @@ -256,14 +251,15 @@ def local_computation(self, datasamples, shared_state=None) -> list:
res = getattr(self, name_decorated_function)(
datasamples=bootstrapped_data, shared_state=shared_state[idx], _skip=True
)
self.checkpoints_list[idx] = self._get_state_to_save()
self.checkpoints_list[idx] = copy.deepcopy(self._get_state_to_save())

# We restore the algo to its old state
for att_name, att in vars(self).items():
if att_name == "checkpoints_list":
continue
if att != old_state.__getattribute__(att_name):
self.__setattr__(att_name, old_state.__getattribute__(att_name))
results.append(res)

return results

# We need to change its name before decorating it,
Expand Down Expand Up @@ -315,14 +311,15 @@ def aggregation(self, shared_states=None) -> list:
Global results to be shared with train nodes via shared_state.
"""
results = []
name_decorated_function = aggregation_function.__name__ + "_original"
if shared_states is not None:
# loop over the aggregation steps provided using _skip=True
for shared_state in shared_states:
res = aggregation_function(shared_states=shared_state, _skip=True)
res = getattr(self, name_decorated_function)(shared_states=shared_state, _skip=True)
results.append(res)
else:
# This is the case in initialize
results = aggregation_function(shared_states=None, _skip=True)
results = getattr(self, name_decorated_function)(shared_states=None, _skip=True)

return results
aggregation.__name__ = new_op_name
Expand All @@ -346,11 +343,11 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
archive = zipfile.ZipFile(path, 'r')
with tempfile.TemporaryDirectory() as tmpdirname:
archive.extractall(tmpdirname)
checkpoints_found = [p for p in Path(tmpdirname).glob("**/bootstrap_*")]
checkpoints_found = sorted([p for p in Path(tmpdirname).glob("**/bootstrap_*")])
self.checkpoints_list = [None] * len(checkpoints_found)
for idx, file in enumerate(checkpoints_found):
load_local_state(file)
self.checkpoints_list[idx] = self._get_state_to_save()
self.load_local_state_original(file)
self.checkpoints_list[idx] = copy.deepcopy(self._get_state_to_save())
return self

return load_local_state
Expand All @@ -368,15 +365,14 @@ def save_local_state(self, path: Path) -> "TorchAlgo":
pass
else:
self.checkpoints_list = [copy.deepcopy(self._get_state_to_save()) for _ in range(len(bootstrap_seeds_list))]

with tempfile.TemporaryDirectory() as tmpdirname:
paths_to_checkpoints = []
for idx, checkpt in enumerate(self.checkpoints_list):
# Get the model in the proper state
self._update_from_checkpoint(checkpt)
# TODO methods implictly use the self attribute
assert not checkpt
path_to_checkpoint = Path(tmpdirname) / f"bootstrap_{idx}"
save_local_state(path_to_checkpoint)
self.save_local_state_original(path_to_checkpoint)
paths_to_checkpoints.append(path_to_checkpoint)

with zipfile.ZipFile(path, 'w') as f:
Expand Down

0 comments on commit 9ed1e9d

Please sign in to comment.