Skip to content

Commit

Permalink
trying to make linting better
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jan 18, 2024
1 parent 82bef30 commit 480709d
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions fedeca/strategies/bootstraper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Bootstrap substra strategy in an efficient fashion."""
import copy
import inspect
import os
Expand Down Expand Up @@ -46,6 +47,7 @@ def make_bootstrap_strategy(
of both allways use bootstrap_seeds.
inplace : bool, optional
Whether to modify the strategy inplace or not, by default False.
Returns
-------
Strategy
Expand Down Expand Up @@ -206,16 +208,21 @@ def save_local_state(self, path: Path) -> "TorchAlgo":
return self

def load_local_state(self, path: Path) -> "TorchAlgo":
"""Load the stateful arguments of this class. Child classes do not need to
"""Load the stateful arguments of this class.
Child classes do not need to
override that function.
Args:
path (pathlib.Path): The path where the class has been saved.
Parameters
----------
path : pathlib.Path
The path where the class has been saved.
Returns:
TorchAlgo: The class with the loaded elements.
Returns
-------
TorchAlgo
The class with the loaded elements.
"""

# Note that at the end of this loop the main state is the one of the last
# bootstrap
archive = zipfile.ZipFile(path, "r")
Expand Down Expand Up @@ -280,13 +287,6 @@ def __init__(self, **kwargs):
return BtstStrategy(algo=btst_algo, **strategy.kwargs), bootstrap_seeds_list


def _bootstrap_predict(predict):
def new_predict(self, predictions_path):
return self

return new_predict


def _bootstrap_local_function(local_function, new_op_name, bootstrap_seeds_list):
"""Bootstrap the local functiion given.
Expand Down Expand Up @@ -433,6 +433,7 @@ def aggregation(self, shared_states=None) -> list:
----------
self : MergedStrategy
The mergedStrategy instance.
shared_states : List
List of lists of results returned by local_computation ran at
previous step.
Expand Down Expand Up @@ -470,6 +471,14 @@ def aggregation(self, shared_states=None) -> list:


def make_bootstrap_metric_function(metric_function):
"""Averages metric on each bootstrapped versions of the models.
Parameters
----------
metric_function : list
The metric function to hook.
"""

def bootstraped_metric(datasamples, predictions_path):
list_of_metrics = []
if isinstance(predictions_path, str) or isinstance(predictions_path, Path):
Expand Down

0 comments on commit 480709d

Please sign in to comment.