Skip to content

Commit

Permalink
fixing predict
Browse files Browse the repository at this point in the history
  • Loading branch information
jeandut committed Jan 16, 2024
1 parent 7557813 commit 8a6b748
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions fedeca/strategies/bootstraper.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,17 @@ def load_local_state(self, path: Path) -> "TorchAlgo":
@remote_data
def predict(self, datasamples, shared_state = None, predictions_path: os.PathLike = None):
predictions = []
for _, chckpt in enumerate(self.checkpoints_list):
self._update_from_checkpoint(chckpt)
pred_btst = super(strategy.algo.__class__, self).predict(datasamples=datasamples, shared_state=shared_state, predictions_path=predictions_path, _skip=True)
predictions.append(pred_btst)
with tempfile.TemporaryDirectory() as tmpdirname:
paths_to_preds = []
for idx, chckpt in enumerate(self.checkpoints_list):
self._update_from_checkpoint(chckpt)
path_to_pred = Path(tmpdirname) / f"bootstrap_{idx}"
super(strategy.algo.__class__, self).predict(datasamples=datasamples, shared_state=shared_state, predictions_path=path_to_pred, _skip=True)
paths_to_preds.append(path_to_pred)
with zipfile.ZipFile(predictions_path, 'w') as f:
for pred in paths_to_preds:
f.write(pred, compress_type=zipfile.ZIP_DEFLATED)

return predictions

btst_algo = BtstAlgo(**strategy.algo.kwargs)
Expand Down Expand Up @@ -443,6 +450,10 @@ def __init__(self):
data_path="./data",
backend_type="subprocess",
)

for node in test_data_nodes:
node.metric_functions = {accuracy.__name__: accuracy}

first_key = list(clients.keys())[0]

aggregation_node = AggregationNode(clients[first_key].organization_info().organization_id)
Expand Down

0 comments on commit 8a6b748

Please sign in to comment.