diff --git a/fedeca/strategies/bootstraper.py b/fedeca/strategies/bootstraper.py index d93affa8..6eb1b746 100644 --- a/fedeca/strategies/bootstraper.py +++ b/fedeca/strategies/bootstraper.py @@ -186,10 +186,17 @@ def load_local_state(self, path: Path) -> "TorchAlgo": @remote_data def predict(self, datasamples, shared_state = None, predictions_path: os.PathLike = None): predictions = [] - for _, chckpt in enumerate(self.checkpoints_list): - self._update_from_checkpoint(chckpt) - pred_btst = super(strategy.algo.__class__, self).predict(datasamples=datasamples, shared_state=shared_state, predictions_path=predictions_path, _skip=True) - predictions.append(pred_btst) + with tempfile.TemporaryDirectory() as tmpdirname: + paths_to_preds = [] + for idx, chckpt in enumerate(self.checkpoints_list): + self._update_from_checkpoint(chckpt) + path_to_pred = Path(tmpdirname) / f"bootstrap_{idx}" + super(strategy.algo.__class__, self).predict(datasamples=datasamples, shared_state=shared_state, predictions_path=path_to_pred, _skip=True) + paths_to_preds.append(path_to_pred) + with zipfile.ZipFile(predictions_path, 'w') as f: + for pred in paths_to_preds: + f.write(pred, compress_type=zipfile.ZIP_DEFLATED) + return predictions btst_algo = BtstAlgo(**strategy.algo.kwargs) @@ -443,6 +450,10 @@ def __init__(self): data_path="./data", backend_type="subprocess", ) + + for node in test_data_nodes: + node.metric_functions = {accuracy.__name__: accuracy} + first_key = list(clients.keys())[0] aggregation_node = AggregationNode(clients[first_key].organization_info().organization_id)