Pinning a source of randomness in server-side aggregation. #278

Merged: 13 commits, Nov 11, 2024
Changes from 5 commits
2 changes: 2 additions & 0 deletions fl4health/strategies/aggregate_utils.py
@@ -47,6 +47,8 @@ def aggregate_losses(results: List[Tuple[int, float]], weighted: bool = True) ->
Returns:
float: the weighted or unweighted average of the loss values in the results list.
"""
# Sort the results by loss value so the losses are summed in a deterministic order, reducing numerical fluctuation in the aggregate.
results = sorted(results, key=lambda x: x[1])
if weighted:
# uses flwr implementation of weighted loss averaging
return weighted_loss_avg(results)
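
For context on why the new sort matters: floating-point addition is not associative, so summing the same client losses in a different order can change the result in the last bits. A small self-contained demonstration (not part of the PR):

```python
import numpy as np

# The same multiset of float32 values, summed in two different orders, can
# disagree in the final bits because floating-point addition is not associative.
rng = np.random.default_rng(0)
values = rng.normal(size=10_000).astype(np.float32)
shuffled = rng.permutation(values)

print(float(np.sum(values)), float(np.sum(shuffled)))  # may differ slightly

# Fixing a canonical order (here, sorting) makes the accumulation reproducible
# no matter the order in which client results happened to arrive.
assert float(np.sum(np.sort(values))) == float(np.sum(np.sort(shuffled)))
```

The same reasoning motivates the `decode_and_pseudo_sort_results` helper used for the weight arrays in the strategies below.
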
13 changes: 8 additions & 5 deletions fl4health/strategies/basic_fedavg.py
@@ -12,7 +12,6 @@
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
@@ -23,6 +22,7 @@
from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager
from fl4health.strategies.aggregate_utils import aggregate_losses, aggregate_results
from fl4health.strategies.strategy_with_poll import StrategyWithPolling
from fl4health.utils.functions import decode_and_pseudo_sort_results
from fl4health.utils.parameter_extraction import get_all_model_parameters


@@ -248,12 +248,15 @@ def aggregate_fit(
if not self.accept_failures and failures:
return None, {}

# Convert results
weights_results = [
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results
# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = [
(weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results)
]

# Aggregate them in a weighted or unweighted fashion based on settings.
aggregated_arrays = aggregate_results(weights_results, self.weighted_aggregation)
aggregated_arrays = aggregate_results(decoded_and_sorted_results, self.weighted_aggregation)
# Convert back to parameters
parameters_aggregated = ndarrays_to_parameters(aggregated_arrays)

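
Every strategy in this PR routes its `FitRes` payloads through the new `decode_and_pseudo_sort_results` helper. Its implementation is not shown in this diff; below is a minimal sketch of what such a helper could look like, assuming it decodes each payload once and sorts on a deterministic key. The sort key used here is an illustrative assumption and may not match the library's actual key.

```python
from typing import List, Tuple

from flwr.common import NDArrays, parameters_to_ndarrays
from flwr.common.typing import FitRes
from flwr.server.client_proxy import ClientProxy


def decode_and_pseudo_sort_results(
    results: List[Tuple[ClientProxy, FitRes]]
) -> List[Tuple[ClientProxy, NDArrays, int]]:
    # Decode each client's Parameters payload into numpy arrays exactly once, then
    # sort the tuples by a deterministic key so that downstream summation always
    # happens in the same order. The key used here (total of the array elements,
    # then the sample count) is an assumption for illustration only.
    decoded = [
        (proxy, parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
        for proxy, fit_res in results
    ]
    return sorted(decoded, key=lambda entry: (sum(float(arr.sum()) for arr in entry[1]), entry[2]))
```
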
15 changes: 10 additions & 5 deletions fl4health/strategies/client_dp_fedavgm.py
@@ -27,6 +27,7 @@
gaussian_noisy_unweighted_aggregate,
gaussian_noisy_weighted_aggregate,
)
from fl4health.utils.functions import decode_and_pseudo_sort_results


class ClientLevelDPFedAvgM(BasicFedAvg):
@@ -195,13 +196,17 @@ def split_model_weights_and_clipping_bits(
Tuple[List[Tuple[NDArrays, int]], NDArrays]: The first tuple is the set of (weights, training counts) per
client. The second is a set of clipping bits, one for each client.
"""
# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = [
(weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results)
]

weights_and_counts: List[Tuple[NDArrays, int]] = []
clipping_bits: NDArrays = []
for _, fit_res in results:
sample_count = fit_res.num_examples
updated_weights, clipping_bit = self.parameter_packer.unpack_parameters(
parameters_to_ndarrays(fit_res.parameters)
)
for weights, sample_count in decoded_and_sorted_results:
updated_weights, clipping_bit = self.parameter_packer.unpack_parameters(weights)
weights_and_counts.append((updated_weights, sample_count))
clipping_bits.append(np.array(clipping_bit))

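
The refactor above relies on the packer to recover the clipping bit from each client's already-decoded weights. A toy illustration of one possible pack/unpack convention follows; it is a hypothetical packer for exposition, not necessarily how the strategy's `parameter_packer` organizes the payload.

```python
from typing import List, Tuple

import numpy as np

NDArrays = List[np.ndarray]


def pack_weights_and_clipping_bit(weights: NDArrays, clipping_bit: float) -> NDArrays:
    # Append the scalar clipping bit as a trailing 1-element array.
    return weights + [np.array([clipping_bit])]


def unpack_weights_and_clipping_bit(packed: NDArrays) -> Tuple[NDArrays, float]:
    # Split the payload back into the model weights and the trailing clipping bit.
    return packed[:-1], float(packed[-1][0])


weights = [np.ones((2, 2)), np.zeros(3)]
packed = pack_weights_and_clipping_bit(weights, clipping_bit=1.0)
unpacked_weights, bit = unpack_weights_and_clipping_bit(packed)
assert bit == 1.0 and len(unpacked_weights) == 2
```
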
22 changes: 9 additions & 13 deletions fl4health/strategies/fedavg_dynamic_layer.py
@@ -4,20 +4,14 @@
from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union

import numpy as np
from flwr.common import (
MetricsAggregationFn,
NDArray,
NDArrays,
Parameters,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.common.typing import FitRes, Scalar
from flwr.server.client_proxy import ClientProxy

from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithLayerNames
from fl4health.strategies.basic_fedavg import BasicFedAvg
from fl4health.utils.functions import decode_and_pseudo_sort_results


class FedAvgDynamicLayer(BasicFedAvg):
@@ -124,13 +118,15 @@ def aggregate_fit(
if not self.accept_failures and failures:
return None, {}

# Convert client layer weights and names into ndarrays
weights_results = [
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results
# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = [
(weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results)
]

# For each layer of the model, perform weighted average of all received weights from clients
aggregated_params = self.aggregate(weights_results)
# Aggregate them in a weighted or unweighted fashion based on settings.
aggregated_params = self.aggregate(decoded_and_sorted_results)

weights_names = []
weights = []
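
Because each client in `FedAvgDynamicLayer` may exchange a different subset of layers, `self.aggregate` must average per layer name rather than per position. A rough sketch of that idea, under the assumption that each decoded payload can be mapped to a `{layer_name: array}` dict (the strategy's real unpacking via `ParameterPackerWithLayerNames` is not shown here):

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np


def aggregate_by_layer_name(
    results: List[Tuple[Dict[str, np.ndarray], int]]
) -> Dict[str, np.ndarray]:
    # results: one ({layer_name: weights}, sample_count) entry per client, where
    # clients may each contribute a different subset of layers.
    sums: Dict[str, np.ndarray] = {}
    counts: Dict[str, int] = defaultdict(int)
    for named_layers, sample_count in results:
        for name, layer in named_layers.items():
            contribution = layer * sample_count
            sums[name] = contribution if name not in sums else sums[name] + contribution
            counts[name] += sample_count
    # Normalize each layer by the total sample count seen for that layer only.
    return {name: sums[name] / counts[name] for name in sums}
```
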
13 changes: 8 additions & 5 deletions fl4health/strategies/fedavg_sparse_coo_tensor.py
@@ -4,14 +4,15 @@
from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union

import torch
from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.common.typing import FitRes, Scalar
from flwr.server.client_proxy import ClientProxy
from torch import Tensor

from fl4health.parameter_exchange.parameter_packer import SparseCooParameterPacker
from fl4health.strategies.basic_fedavg import BasicFedAvg
from fl4health.utils.functions import decode_and_pseudo_sort_results


class FedAvgSparseCooTensor(BasicFedAvg):
@@ -137,13 +138,15 @@ def aggregate_fit(
if not self.accept_failures and failures:
return None, {}

# Convert client tensor weights and names into ndarrays
weights_results = [
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results
# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = [
(weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results)
]

# For each tensor of the model, perform weighted average of all received weights from clients
aggregated_tensors = self.aggregate(weights_results)
aggregated_tensors = self.aggregate(decoded_and_sorted_results)

tensor_names = []
selected_parameters_all_tensors = []
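
For readers unfamiliar with the COO exchange format used by this strategy: each tensor is (roughly) shipped as its non-zero values, their coordinates, and the dense shape. A small example of rebuilding a dense tensor from such a triple with PyTorch's sparse API; the exact packing performed by `SparseCooParameterPacker` may differ.

```python
import torch

# Hypothetical COO payload for a single parameter tensor: non-zero values,
# their coordinates, and the dense shape of the original tensor.
values = torch.tensor([0.5, -1.25])
indices = torch.tensor([[0, 2], [1, 3]])  # one column of coordinates per value
shape = (3, 4)

# torch.sparse_coo_tensor rebuilds the sparse tensor, and to_dense() restores
# the full parameter tensor with zeros everywhere else.
dense = torch.sparse_coo_tensor(indices, values, shape).to_dense()
print(dense)
```
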
15 changes: 10 additions & 5 deletions fl4health/strategies/fedavg_with_adaptive_constraint.py
@@ -10,6 +10,7 @@
from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint
from fl4health.strategies.aggregate_utils import aggregate_losses, aggregate_results
from fl4health.strategies.basic_fedavg import BasicFedAvg
from fl4health.utils.functions import decode_and_pseudo_sort_results


class FedAvgWithAdaptiveConstraint(BasicFedAvg):
@@ -157,14 +158,18 @@ def aggregate_fit(
if not self.accept_failures and failures:
return None, {}

# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = [
(weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results)
]

# Convert results with packed params of model weights and training loss
weights_and_counts: List[Tuple[NDArrays, int]] = []
train_losses_and_counts: List[Tuple[int, float]] = []
for _, fit_res in results:
sample_count = fit_res.num_examples
updated_weights, train_loss = self.parameter_packer.unpack_parameters(
parameters_to_ndarrays(fit_res.parameters)
)
for weights, sample_count in decoded_and_sorted_results:
updated_weights, train_loss = self.parameter_packer.unpack_parameters(weights)
weights_and_counts.append((updated_weights, sample_count))
train_losses_and_counts.append((sample_count, train_loss))

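
The per-client training losses unpacked above feed the adaptive constraint. As a reference for the arithmetic, here is a sketch of a sample-weighted loss average together with one plausible adaptation rule for the constraint weight; the rule is an assumption for illustration and not necessarily the update `FedAvgWithAdaptiveConstraint` actually applies.

```python
from typing import List, Tuple


def weighted_loss_average(train_losses_and_counts: List[Tuple[int, float]]) -> float:
    # Weighted average of per-client training losses: sum(n_i * loss_i) / sum(n_i).
    total_samples = sum(count for count, _ in train_losses_and_counts)
    return sum(count * loss for count, loss in train_losses_and_counts) / total_samples


def adapt_constraint_weight(current_mu: float, previous_loss: float, current_loss: float) -> float:
    # One common heuristic (an assumption, not the library's exact rule): relax the
    # constraint when the aggregated loss improves, tighten it otherwise.
    return max(0.0, current_mu - 0.1) if current_loss < previous_loss else current_mu + 0.1
```
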
20 changes: 10 additions & 10 deletions fl4health/strategies/feddg_ga.py
@@ -3,21 +3,15 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
from flwr.common import (
EvaluateIns,
MetricsAggregationFn,
NDArrays,
Parameters,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common import EvaluateIns, MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.common.typing import EvaluateRes, FitIns, FitRes, Scalar
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import FedAvg

from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager
from fl4health.utils.functions import decode_and_pseudo_sort_results


class SignalForTypeException(Exception):
@@ -323,6 +317,7 @@ def aggregate_evaluate(
(Tuple[Optional[float], Dict[str, Scalar]]) A tuple containing the aggregated evaluation loss
and the aggregated evaluation metrics.
"""

loss_aggregated, metrics_aggregated = super().aggregate_evaluate(server_round, results, failures)

self.evaluation_metrics = {}
@@ -349,8 +344,13 @@ def weight_and_aggregate_results(self, results: List[Tuple[ClientProxy, FitRes]]
(NDArrays) the weighted and aggregated results.
"""

# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = decode_and_pseudo_sort_results(results)

aggregated_results: Optional[NDArrays] = None
for client_proxy, fit_res in results:
for client_proxy, weights, _ in decoded_and_sorted_results:
cid = client_proxy.cid

# initializing adjustment weights for this client if they don't exist yet
@@ -359,7 +359,7 @@
self.adjustment_weights[cid] = self.initial_adjustment_weight

# apply adjustment weights
weighted_client_parameters = parameters_to_ndarrays(fit_res.parameters)
weighted_client_parameters = weights
for i in range(len(weighted_client_parameters)):
weighted_client_parameters[i] = weighted_client_parameters[i] * self.adjustment_weights[cid]

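
The loop in `weight_and_aggregate_results` scales each client's decoded arrays by that client's adjustment weight before accumulating them. A standalone sketch of that weighting-and-summing step, written for illustration rather than as the strategy's exact code:

```python
from typing import Dict, List, Optional

import numpy as np

NDArrays = List[np.ndarray]


def weight_and_sum(client_weights: Dict[str, NDArrays], adjustment_weights: Dict[str, float]) -> NDArrays:
    # Scale each client's parameter arrays by that client's adjustment weight and
    # accumulate the scaled arrays elementwise, mirroring the loop above.
    aggregated: Optional[NDArrays] = None
    for cid, weights in client_weights.items():
        scaled = [layer * adjustment_weights[cid] for layer in weights]
        aggregated = scaled if aggregated is None else [agg + layer for agg, layer in zip(aggregated, scaled)]
    assert aggregated is not None
    return aggregated
```
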
20 changes: 10 additions & 10 deletions fl4health/strategies/fedpca.py
@@ -2,19 +2,13 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
from flwr.common import (
MetricsAggregationFn,
NDArray,
NDArrays,
Parameters,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters, ndarrays_to_parameters
from flwr.common.logger import log
from flwr.common.typing import FitRes, Scalar
from flwr.server.client_proxy import ClientProxy

from fl4health.strategies.basic_fedavg import BasicFedAvg
from fl4health.utils.functions import decode_and_pseudo_sort_results


class FedPCA(BasicFedAvg):
@@ -122,10 +116,16 @@ def aggregate_fit(
if not self.accept_failures and failures:
return None, {}

# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = [
(weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results)
]

client_singular_values = []
client_singular_vectors = []
for _, fit_res in results:
A = parameters_to_ndarrays(fit_res.parameters)
for A, _ in decoded_and_sorted_results:
singular_vectors, singular_values = A[0], A[1]
client_singular_vectors.append(singular_vectors)
client_singular_values.append(singular_values)
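
After collecting each client's singular vectors and values, the server needs to merge the local principal subspaces. One standard way to do this, shown purely as an illustration and not necessarily the exact merge performed later in `FedPCA.aggregate_fit`, is to weight each client's directions by its singular values, stack them, and take an SVD of the stack:

```python
from typing import List, Tuple

import numpy as np


def merge_client_subspaces(
    client_components: List[Tuple[np.ndarray, np.ndarray]], num_components: int
) -> Tuple[np.ndarray, np.ndarray]:
    # client_components: (singular_vectors, singular_values) per client, where the
    # rows of singular_vectors are that client's principal directions.
    # Weight each client's directions by its singular values, stack everything,
    # and re-run an SVD to obtain a merged set of principal components.
    weighted_stack = np.vstack(
        [np.diag(values) @ vectors for vectors, values in client_components]
    )
    _, merged_values, merged_vectors = np.linalg.svd(weighted_stack, full_matrices=False)
    return merged_vectors[:num_components], merged_values[:num_components]
```
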
2 changes: 2 additions & 0 deletions fl4health/strategies/flash.py
@@ -149,9 +149,11 @@ def aggregate_fit(
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using the Flash method."""

fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit(
server_round=server_round, results=results, failures=failures
)

if fedavg_parameters_aggregated is None:
return None, {}

14 changes: 9 additions & 5 deletions fl4health/strategies/model_merge_strategy.py
@@ -20,6 +20,7 @@

from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager
from fl4health.strategies.aggregate_utils import aggregate_results
from fl4health.utils.functions import decode_and_pseudo_sort_results


class ModelMergeStrategy(Strategy):
@@ -188,12 +189,15 @@ def aggregate_fit(
if not self.accept_failures and failures:
return None, {}

# Convert results
weights_results = [
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results
# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = [
(weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results)
]

# Aggregate them in a weighted or unweighted fashion based on self.weighted_aggregation.
aggregated_arrays = aggregate_results(weights_results, self.weighted_aggregation)
aggregated_arrays = aggregate_results(decoded_and_sorted_results, self.weighted_aggregation)
# Convert back to parameters
parameters_aggregated = ndarrays_to_parameters(aggregated_arrays)

@@ -246,7 +250,7 @@

def evaluate(self, server_round: int, parameters: Parameters) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""
Evaluate the model parameters after the merging has occured. This function can be used to perform centralized
Evaluate the model parameters after the merging has occurred. This function can be used to perform centralized
(i.e., server-side) evaluation of model parameters.

Args:
9 changes: 6 additions & 3 deletions fl4health/strategies/scaffold.py
@@ -21,6 +21,7 @@
from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager
from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates
from fl4health.strategies.basic_fedavg import BasicFedAvg
from fl4health.utils.functions import decode_and_pseudo_sort_results
from fl4health.utils.parameter_extraction import get_all_model_parameters


@@ -179,12 +180,14 @@ def aggregate_fit(
if not self.accept_failures and failures:
return None, {}

# Convert results with packed params of model weights and client control variate updates
updated_params = [parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results]
# Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in
# summing the numpy arrays during aggregation. This ensures that addition will occur in the same order,
# reducing numerical fluctuation.
decoded_and_sorted_results = [weights for _, weights, _ in decode_and_pseudo_sort_results(results)]

# x = 1 / |S| * sum(x_i) and c = 1 / |S| * sum(delta_c_i)
# Aggregation operation over packed params (includes both weights and control variate updates)
aggregated_params = self.aggregate(updated_params)
aggregated_params = self.aggregate(decoded_and_sorted_results)

weights, control_variates_update = self.parameter_packer.unpack_parameters(aggregated_params)

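
The comment above spells out the SCAFFOLD server update: a plain (unweighted) mean over the sampled clients, applied to the packed payloads of weights and control-variate updates alike. A minimal sketch of such a position-wise mean, assuming every client sends arrays of matching shapes; it is not necessarily the strategy's exact `aggregate`:

```python
from typing import List

import numpy as np

NDArrays = List[np.ndarray]


def unweighted_aggregate(packed_client_params: List[NDArrays]) -> NDArrays:
    # Unweighted mean over clients, applied position-wise to the packed payloads
    # (model weights followed by control-variate updates), matching
    # x = 1/|S| * sum(x_i) and c = 1/|S| * sum(delta_c_i) from the comment above.
    num_clients = len(packed_client_params)
    return [
        np.sum(np.stack(layer_group), axis=0) / num_clients
        for layer_group in zip(*packed_client_params)
    ]
```
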