Skip to content

Commit

Permalink
tests don't pass hmm...
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jan 18, 2024
1 parent bd5989a commit 841108e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion fedeca/strategies/bootstraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Union

import numpy as np
import torch
from substrafl.algorithms.pytorch.torch_base_algo import TorchAlgo
from substrafl.remote import remote, remote_data
from substrafl.strategies.strategy import Strategy
Expand Down Expand Up @@ -491,7 +492,6 @@ def bootstraped_metric(datasamples, predictions_path):

if __name__ == "__main__":
import pandas as pd
import torch
from substrafl.algorithms.pytorch import TorchFedAvgAlgo # , TorchNewtonRaphsonAlgo
from substrafl.dependency import Dependency
from substrafl.evaluation_strategy import EvaluationStrategy
Expand Down
21 changes: 17 additions & 4 deletions fedeca/tests/test_bootstraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ def test_bootstrapping(strategy_params: dict, num_rounds: int):
**strategy_params["strategy"]["strategy_kwargs"]
)

btst_strategy = make_bootstrap_strategy(
btst_strategy, _ = make_bootstrap_strategy(
strategy, bootstrap_seeds=bootstrap_seeds_list
)

# inefficient bootstrap
bootstrapped_models = []
bootstrapped_models_gt = []
for idx, seed in enumerate(bootstrap_seeds_list):
rng = np.random.default_rng(seed)
bootstrapped_df = df.sample(df.shape[0], replace=True, random_state=rng)
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_bootstrapping(strategy_params: dict, num_rounds: int):
compute_plan_key=compute_plan.key,
round_idx=strategy_params["get_true_nb_rounds"](num_rounds),
)
bootstrapped_models.append(algo.model)
bootstrapped_models_gt.append(algo.model)
# Clean up
shutil.rmtree("./data")

Expand Down Expand Up @@ -211,4 +211,17 @@ def test_bootstrapping(strategy_params: dict, num_rounds: int):
compute_plan_key=compute_plan.key,
round_idx=strategy_params["get_true_nb_rounds"](num_rounds),
)
bootstrapped_models = algo.model
breakpoint()
bootstrapped_models_efficient = [
UnifLogReg(ndim=50) for _ in range(len(bootstrap_seeds_list))
]
[
m.load_state_dict(chkpt["model_state_dict"])
for m, chkpt in zip(bootstrapped_models_efficient, algo.checkpoints_list)
]
for model1, model2 in zip(bootstrapped_models_gt, bootstrapped_models_efficient):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
if p1.data.ne(p2.data).sum() > 0:
assert False
# Clean up
shutil.rmtree("./data")

0 comments on commit 841108e

Please sign in to comment.