Skip to content

Commit

Permalink
fixing metric
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jan 16, 2024
1 parent 8a6b748 commit 17aa1ff
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions fedeca/strategies/bootstraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def __init__(self, **kwargs):
setattr(self, agg_name, f)

strategy.kwargs.pop("algo")
return BtstStrategy(algo=btst_algo, **strategy.kwargs)
return BtstStrategy(algo=btst_algo, **strategy.kwargs), bootstrap_seeds_list

def _bootstrap_predict(predict):
def new_predict(self, predictions_path):
Expand Down Expand Up @@ -365,6 +365,24 @@ def aggregation(self, shared_states=None) -> list:
aggregation.__name__ = new_op_name
return remote(aggregation)

def make_bootstrap_metric_function(metric_function):
def bootstraped_metric(datasamples, predictions_path):
list_of_metrics = []
if isinstance(predictions_path, str) or isinstance(predictions_path, Path):
archive = zipfile.ZipFile(predictions_path, 'r')
with tempfile.TemporaryDirectory() as tmpdirname:
archive.extractall(tmpdirname)
preds_found = sorted([p for p in Path(tmpdirname).glob("**/bootstrap_*")])
for pred_found in preds_found:
list_of_metrics.append(metric_function(datasamples, pred_found))
else:
y_preds = predictions_path
for y_pred in y_preds:
list_of_metrics.append(metric_function(datasamples, y_pred))
return np.array(list_of_metrics).mean()

return bootstraped_metric


if __name__ == "__main__":
from substrafl.algorithms.pytorch import TorchFedAvgAlgo
Expand All @@ -386,6 +404,7 @@ def aggregation(self, shared_states=None) -> list:
from fedeca import LogisticRegressionTorch
from fedeca.utils.survival_utils import CoxData
import os
import pandas as pd

seed = 42
torch.manual_seed(seed)
Expand Down Expand Up @@ -424,6 +443,7 @@ def __init__(self, *args, **kwargs):
return_torch_tensors=True,
)
accuracy = make_accuracy_function("treatment")
accuracy_btst = make_bootstrap_metric_function(accuracy)


class TorchLogReg(TorchFedAvgAlgo):
Expand All @@ -440,7 +460,7 @@ def __init__(self):

strategy = FedAvg(algo=TorchLogReg())

btst_strategy = make_bootstrap_strategy(strategy, n_bootstraps=10)
btst_strategy, _ = make_bootstrap_strategy(strategy, n_bootstraps=10)

clients, train_data_nodes, test_data_nodes, _, _ = split_dataframe_across_clients(
df,
Expand All @@ -452,7 +472,7 @@ def __init__(self):
)

for node in test_data_nodes:
node.metric_functions = {accuracy.__name__: accuracy}
node.metric_functions = {accuracy_btst.__name__: accuracy_btst}

first_key = list(clients.keys())[0]

Expand All @@ -475,4 +495,5 @@ def __init__(self):
dependencies=dependencies,
clean_models=False,
name="FedECA",
)
)
print(pd.DataFrame(clients[first_key].get_performances(compute_plan.key).dict()))

0 comments on commit 17aa1ff

Please sign in to comment.