diff --git a/examples/dry_run/customized_client.py b/examples/dry_run/customized_client.py
index 4ead6fa6..f2f240c6 100644
--- a/examples/dry_run/customized_client.py
+++ b/examples/dry_run/customized_client.py
@@ -12,14 +12,14 @@ class Customized_Client(TorchClient):
     """Basic client component in Federated Learning"""
-
     def train(self, client_data, model, conf):
         """We flip the label of the malicious client"""
+        device = conf.cuda_device if conf.use_cuda else torch.device(
+            'cpu')
+
         client_id = conf.client_id
         logging.info(f"Start to train (CLIENT: {client_id}) ...")
-        device = conf.device
-
         model = model.to(device=device)
         model.train()

diff --git a/fedscale/cloud/internal/torch_model_adapter.py b/fedscale/cloud/internal/torch_model_adapter.py
index 813d0869..0d258ec9 100644
--- a/fedscale/cloud/internal/torch_model_adapter.py
+++ b/fedscale/cloud/internal/torch_model_adapter.py
@@ -25,7 +25,7 @@ def set_weights(self, weights: List[np.ndarray]):
         Set the model's weights to the numpy weights array.
         :param weights: numpy weights array
         """
-        current_grad_weights = [param.data.clone() for param in self.model.state_dict().values()]
+        last_grad_weights = [param.data.clone() for param in self.model.state_dict().values()]
         new_state_dict = {
             name: torch.from_numpy(np.asarray(weights[i], dtype=np.float32))
             for i, name in enumerate(self.model.state_dict().keys())
         }
@@ -34,7 +34,7 @@ def set_weights(self, weights: List[np.ndarray]):
         if self.optimizer:
             weights_origin = copy.deepcopy(weights)
             weights = [torch.tensor(x) for x in weights_origin]
-            self.optimizer.update_round_gradient(weights, current_grad_weights, self.model)
+            self.optimizer.update_round_gradient(last_grad_weights, weights, self.model)

     def get_weights(self) -> List[np.ndarray]:
         """
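
For reference, a minimal sketch of the device-selection idiom introduced in the first hunk. The `Conf` class below is a hypothetical stand-in for the client's `conf` object; only the `use_cuda` and `cuda_device` attributes that the patched client reads are assumed:

```python
import torch


class Conf:
    """Hypothetical stand-in for the client's conf object (illustration only)."""
    use_cuda = torch.cuda.is_available()
    cuda_device = torch.device("cuda:0")


conf = Conf()
# Same expression as the patched train(): use the configured CUDA device,
# otherwise fall back to CPU.
device = conf.cuda_device if conf.use_cuda else torch.device('cpu')

model = torch.nn.Linear(4, 2).to(device=device)
model.train()
```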
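And a runnable sketch of what the `set_weights` fix changes: the snapshot taken before `load_state_dict` holds the previous round's weights, so it is now passed as the first argument to `update_round_gradient`. The `_StubOptimizer` below is an illustrative assumption, not FedScale's optimizer; only the call order `update_round_gradient(last_grad_weights, weights, model)` comes from the diff above:

```python
import copy
from typing import List

import numpy as np
import torch
import torch.nn as nn


class _StubOptimizer:
    """Illustrative optimizer stub; FedScale's real optimizer is not reproduced here."""

    def update_round_gradient(self, last_weights, new_weights, model):
        # The round "gradient" is the incoming global weights minus the weights
        # the model held before this round, which is why the call order matters.
        round_grad = [n - l for n, l in zip(new_weights, last_weights)]
        print("per-tensor round-gradient norms:", [float(g.norm()) for g in round_grad])


def set_weights(model: nn.Module, weights: List[np.ndarray], optimizer=None):
    # Snapshot the weights the model currently holds (last round's state).
    last_grad_weights = [p.data.clone() for p in model.state_dict().values()]
    new_state_dict = {
        name: torch.from_numpy(np.asarray(weights[i], dtype=np.float32))
        for i, name in enumerate(model.state_dict().keys())
    }
    model.load_state_dict(new_state_dict)
    if optimizer:
        new_weights = [torch.tensor(x) for x in copy.deepcopy(weights)]
        # Previous weights first, incoming weights second, as in the patch.
        optimizer.update_round_gradient(last_grad_weights, new_weights, model)


if __name__ == "__main__":
    net = nn.Linear(4, 2)
    fresh = [v.numpy() + 0.1 for v in net.state_dict().values()]
    set_weights(net, fresh, optimizer=_StubOptimizer())
```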