diff --git a/fedeca/algorithms/torch_dp_fed_avg_algo.py b/fedeca/algorithms/torch_dp_fed_avg_algo.py index 9cce2b6c..5a232ef4 100644 --- a/fedeca/algorithms/torch_dp_fed_avg_algo.py +++ b/fedeca/algorithms/torch_dp_fed_avg_algo.py @@ -1,5 +1,6 @@ """Differentially private algorithm to be used with FedAvg strategy.""" import logging +import random from typing import Any, Optional import numpy as np @@ -427,6 +428,9 @@ def _update_from_checkpoint(self, checkpoint: dict) -> None: self._index_generator = checkpoint.pop("index_generator") + random.setstate(checkpoint.pop("random_rng_state")) + np.random.set_state(checkpoint.pop("numpy_rng_state")) + if self._device == torch.device("cpu"): torch.set_rng_state(checkpoint.pop("torch_rng_state").to(self._device)) else: