diff --git a/fl4health/strategies/aggregate_utils.py b/fl4health/strategies/aggregate_utils.py index bbe866cbc..ecef5eaeb 100644 --- a/fl4health/strategies/aggregate_utils.py +++ b/fl4health/strategies/aggregate_utils.py @@ -47,6 +47,8 @@ def aggregate_losses(results: List[Tuple[int, float]], weighted: bool = True) -> Returns: float: the weighted or unweighted average of the loss values in the results list. """ + # Sorting the results by the loss values for numerical fluctuation determinism of the sum + results = sorted(results, key=lambda x: x[1]) if weighted: # uses flwr implementation of weighted loss averaging return weighted_loss_avg(results) diff --git a/fl4health/strategies/basic_fedavg.py b/fl4health/strategies/basic_fedavg.py index c45a03320..461de952c 100644 --- a/fl4health/strategies/basic_fedavg.py +++ b/fl4health/strategies/basic_fedavg.py @@ -12,7 +12,6 @@ Parameters, Scalar, ndarrays_to_parameters, - parameters_to_ndarrays, ) from flwr.common.logger import log from flwr.server.client_manager import ClientManager @@ -23,6 +22,7 @@ from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager from fl4health.strategies.aggregate_utils import aggregate_losses, aggregate_results from fl4health.strategies.strategy_with_poll import StrategyWithPolling +from fl4health.utils.functions import decode_and_pseudo_sort_results from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -248,12 +248,15 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] + # Aggregate them in a weighted or unweighted fashion based on settings. - aggregated_arrays = aggregate_results(weights_results, self.weighted_aggregation) + aggregated_arrays = aggregate_results(decoded_and_sorted_results, self.weighted_aggregation) # Convert back to parameters parameters_aggregated = ndarrays_to_parameters(aggregated_arrays) diff --git a/fl4health/strategies/client_dp_fedavgm.py b/fl4health/strategies/client_dp_fedavgm.py index 48276c160..ae1066a5f 100644 --- a/fl4health/strategies/client_dp_fedavgm.py +++ b/fl4health/strategies/client_dp_fedavgm.py @@ -27,6 +27,7 @@ gaussian_noisy_unweighted_aggregate, gaussian_noisy_weighted_aggregate, ) +from fl4health.utils.functions import decode_and_pseudo_sort_results class ClientLevelDPFedAvgM(BasicFedAvg): @@ -195,13 +196,17 @@ def split_model_weights_and_clipping_bits( Tuple[List[Tuple[NDArrays, int]], NDArrays]: The first tuple is the set of (weights, training counts) per client. The second is a set of clipping bits, one for each client. """ + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) + ] + weights_and_counts: List[Tuple[NDArrays, int]] = [] clipping_bits: NDArrays = [] - for _, fit_res in results: - sample_count = fit_res.num_examples - updated_weights, clipping_bit = self.parameter_packer.unpack_parameters( - parameters_to_ndarrays(fit_res.parameters) - ) + for weights, sample_count in decoded_and_sorted_results: + updated_weights, clipping_bit = self.parameter_packer.unpack_parameters(weights) weights_and_counts.append((updated_weights, sample_count)) clipping_bits.append(np.array(clipping_bit)) diff --git a/fl4health/strategies/fedavg_dynamic_layer.py b/fl4health/strategies/fedavg_dynamic_layer.py index 08f0999bf..c6ab977b7 100644 --- a/fl4health/strategies/fedavg_dynamic_layer.py +++ b/fl4health/strategies/fedavg_dynamic_layer.py @@ -4,20 +4,14 @@ from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union import numpy as np -from flwr.common import ( - MetricsAggregationFn, - NDArray, - NDArrays, - Parameters, - ndarrays_to_parameters, - parameters_to_ndarrays, -) +from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters, ndarrays_to_parameters from flwr.common.logger import log from flwr.common.typing import FitRes, Scalar from flwr.server.client_proxy import ClientProxy from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithLayerNames from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results class FedAvgDynamicLayer(BasicFedAvg): @@ -124,13 +118,17 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + # Convert client layer weights and names into ndarrays - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] # For each layer of the model, perform weighted average of all received weights from clients - aggregated_params = self.aggregate(weights_results) + aggregated_params = self.aggregate(decoded_and_sorted_results) weights_names = [] weights = [] diff --git a/fl4health/strategies/fedavg_sparse_coo_tensor.py b/fl4health/strategies/fedavg_sparse_coo_tensor.py index d11d9d295..be8c5d922 100644 --- a/fl4health/strategies/fedavg_sparse_coo_tensor.py +++ b/fl4health/strategies/fedavg_sparse_coo_tensor.py @@ -4,7 +4,7 @@ from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union import torch -from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters from flwr.common.logger import log from flwr.common.typing import FitRes, Scalar from flwr.server.client_proxy import ClientProxy @@ -12,6 +12,7 @@ from fl4health.parameter_exchange.parameter_packer import SparseCooParameterPacker from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results class FedAvgSparseCooTensor(BasicFedAvg): @@ -137,13 +138,17 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + # Convert client tensor weights and names into ndarrays - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] # For each tensor of the model, perform weighted average of all received weights from clients - aggregated_tensors = self.aggregate(weights_results) + aggregated_tensors = self.aggregate(decoded_and_sorted_results) tensor_names = [] selected_parameters_all_tensors = [] diff --git a/fl4health/strategies/fedavg_with_adaptive_constraint.py b/fl4health/strategies/fedavg_with_adaptive_constraint.py index 735fcdd62..b438f97e4 100644 --- a/fl4health/strategies/fedavg_with_adaptive_constraint.py +++ b/fl4health/strategies/fedavg_with_adaptive_constraint.py @@ -10,6 +10,7 @@ from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint from fl4health.strategies.aggregate_utils import aggregate_losses, aggregate_results from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results class FedAvgWithAdaptiveConstraint(BasicFedAvg): @@ -157,14 +158,18 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) + ] + # Convert results with packed params of model weights and training loss weights_and_counts: List[Tuple[NDArrays, int]] = [] train_losses_and_counts: List[Tuple[int, float]] = [] - for _, fit_res in results: - sample_count = fit_res.num_examples - updated_weights, train_loss = self.parameter_packer.unpack_parameters( - parameters_to_ndarrays(fit_res.parameters) - ) + for weights, sample_count in decoded_and_sorted_results: + updated_weights, train_loss = self.parameter_packer.unpack_parameters(weights) weights_and_counts.append((updated_weights, sample_count)) train_losses_and_counts.append((sample_count, train_loss)) diff --git a/fl4health/strategies/feddg_ga.py b/fl4health/strategies/feddg_ga.py index db4f1ef24..108758707 100644 --- a/fl4health/strategies/feddg_ga.py +++ b/fl4health/strategies/feddg_ga.py @@ -3,14 +3,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np -from flwr.common import ( - EvaluateIns, - MetricsAggregationFn, - NDArrays, - Parameters, - ndarrays_to_parameters, - parameters_to_ndarrays, -) +from flwr.common import EvaluateIns, MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters from flwr.common.logger import log from flwr.common.typing import EvaluateRes, FitIns, FitRes, Scalar from flwr.server.client_manager import ClientManager @@ -18,6 +11,7 @@ from flwr.server.strategy import FedAvg from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager +from fl4health.utils.functions import decode_and_pseudo_sort_results class SignalForTypeException(Exception): @@ -329,6 +323,7 @@ def aggregate_evaluate( (Tuple[Optional[float], Dict[str, Scalar]]) A tuple containing the aggregated evaluation loss and the aggregated evaluation metrics. """ + loss_aggregated, metrics_aggregated = super().aggregate_evaluate(server_round, results, failures) self.evaluation_metrics = {} @@ -363,8 +358,13 @@ def weight_and_aggregate_results(self, results: List[Tuple[ClientProxy, FitRes]] # and will be below. log(INFO, f"Current adjustment weights are all initialized to {self.initial_adjustment_weight}") + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = decode_and_pseudo_sort_results(results) + aggregated_results: Optional[NDArrays] = None - for client_proxy, fit_res in results: + for client_proxy, weights, _ in decoded_and_sorted_results: cid = client_proxy.cid # initializing adjustment weights for this client if they don't exist yet @@ -373,7 +373,7 @@ def weight_and_aggregate_results(self, results: List[Tuple[ClientProxy, FitRes]] self.adjustment_weights[cid] = self.initial_adjustment_weight # apply adjustment weights - weighted_client_parameters = parameters_to_ndarrays(fit_res.parameters) + weighted_client_parameters = weights for i in range(len(weighted_client_parameters)): weighted_client_parameters[i] = weighted_client_parameters[i] * self.adjustment_weights[cid] diff --git a/fl4health/strategies/fedpca.py b/fl4health/strategies/fedpca.py index a0e4c5979..97ba6f8b5 100644 --- a/fl4health/strategies/fedpca.py +++ b/fl4health/strategies/fedpca.py @@ -2,19 +2,13 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np -from flwr.common import ( - MetricsAggregationFn, - NDArray, - NDArrays, - Parameters, - ndarrays_to_parameters, - parameters_to_ndarrays, -) +from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters, ndarrays_to_parameters from flwr.common.logger import log from flwr.common.typing import FitRes, Scalar from flwr.server.client_proxy import ClientProxy from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results class FedPCA(BasicFedAvg): @@ -122,10 +116,14 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [weights for _, weights, _ in decode_and_pseudo_sort_results(results)] + client_singular_values = [] client_singular_vectors = [] - for _, fit_res in results: - A = parameters_to_ndarrays(fit_res.parameters) + for A in decoded_and_sorted_results: singular_vectors, singular_values = A[0], A[1] client_singular_vectors.append(singular_vectors) client_singular_values.append(singular_values) diff --git a/fl4health/strategies/flash.py b/fl4health/strategies/flash.py index 789226c47..97fbe8dae 100644 --- a/fl4health/strategies/flash.py +++ b/fl4health/strategies/flash.py @@ -149,9 +149,11 @@ def aggregate_fit( failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: """Aggregate fit results using the Flash method.""" + fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( server_round=server_round, results=results, failures=failures ) + if fedavg_parameters_aggregated is None: return None, {} diff --git a/fl4health/strategies/model_merge_strategy.py b/fl4health/strategies/model_merge_strategy.py index c30e723e9..cd5e06cf3 100644 --- a/fl4health/strategies/model_merge_strategy.py +++ b/fl4health/strategies/model_merge_strategy.py @@ -20,6 +20,7 @@ from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager from fl4health.strategies.aggregate_utils import aggregate_results +from fl4health.utils.functions import decode_and_pseudo_sort_results class ModelMergeStrategy(Strategy): @@ -188,12 +189,15 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) for _, fit_res in results + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [ + (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] + # Aggregate them in an weighted or unweighted fashion based on self.weighted_aggregation. - aggregated_arrays = aggregate_results(weights_results, self.weighted_aggregation) + aggregated_arrays = aggregate_results(decoded_and_sorted_results, self.weighted_aggregation) # Convert back to parameters parameters_aggregated = ndarrays_to_parameters(aggregated_arrays) @@ -246,7 +250,7 @@ def aggregate_evaluate( def evaluate(self, server_round: int, parameters: Parameters) -> Optional[Tuple[float, Dict[str, Scalar]]]: """ - Evaluate the model parameters after the merging has occured. This function can be used to perform centralized + Evaluate the model parameters after the merging has occurred. This function can be used to perform centralized (i.e., server-side) evaluation of model parameters. Args: diff --git a/fl4health/strategies/scaffold.py b/fl4health/strategies/scaffold.py index 842f56630..df2f8edf2 100644 --- a/fl4health/strategies/scaffold.py +++ b/fl4health/strategies/scaffold.py @@ -21,6 +21,7 @@ from fl4health.client_managers.base_sampling_manager import BaseFractionSamplingManager from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates from fl4health.strategies.basic_fedavg import BasicFedAvg +from fl4health.utils.functions import decode_and_pseudo_sort_results from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -179,12 +180,14 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results with packed params of model weights and client control variate updates - updated_params = [parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results] + # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in + # summing the numpy arrays during aggregation. This ensures that addition will occur in the same order, + # reducing numerical fluctuation. + decoded_and_sorted_results = [weights for _, weights, _ in decode_and_pseudo_sort_results(results)] # x = 1 / |S| * sum(x_i) and c = 1 / |S| * sum(delta_c_i) # Aggregation operation over packed params (includes both weights and control variate updates) - aggregated_params = self.aggregate(updated_params) + aggregated_params = self.aggregate(decoded_and_sorted_results) weights, control_variates_update = self.parameter_packer.unpack_parameters(aggregated_params) diff --git a/fl4health/utils/functions.py b/fl4health/utils/functions.py index a546b42f9..460aa949e 100644 --- a/fl4health/utils/functions.py +++ b/fl4health/utils/functions.py @@ -1,6 +1,10 @@ -from typing import Any, Tuple +from typing import Any, List, Tuple +import numpy as np import torch +from flwr.common import parameters_to_ndarrays +from flwr.common.typing import FitRes, NDArrays +from flwr.server.client_proxy import ClientProxy class BernoulliSample(torch.autograd.Function): @@ -41,3 +45,65 @@ def backward(ctx: torch.Any, grad_output: torch.Tensor) -> torch.Tensor: # type def sigmoid_inverse(x: torch.Tensor) -> torch.Tensor: return -torch.log(1 / x - 1) + + +def select_zeroeth_element(array: np.ndarray) -> float: + """ + Helper function that simply selects the first element of an array (index 0 across all dimensions). + + Args: + array (np.ndarray): Array from which the very first element is selected + + Returns: + float: zeroeth element value. + """ + indices = tuple(0 for _ in array.shape) + return array[indices] + + +def pseudo_sort_scoring_function(client_result: Tuple[ClientProxy, NDArrays, int]) -> float: + """ + This function provides the "score" that is used to sort a list of Tuple[ClientProxy, NDArrays, int]. We select + the zeroeth (index 0 across all dimensions) element from each of the arrays in the NDArrays list, sum them, and + add the integer (client sample counts) to the sum to come up with a score for sorting. Note that + the underlying numpy arrays in NDArrays may not all be of numerical type. So we limit to selecting elements from + arrays of floats. + + Args: + client_result (Tuple[ClientProxy, NDArrays, int]]): Elements to use to determine the score. + + Returns: + float: Sum of a the zeroeth elements of each array in the NDArrays and the int of the tuple + """ + _, client_arrays, sample_count = client_result + zeroeth_params = [ + select_zeroeth_element(array) for array in client_arrays if np.issubdtype(array.dtype, np.floating) + ] + return np.sum(zeroeth_params) + sample_count + + +def decode_and_pseudo_sort_results( + results: List[Tuple[ClientProxy, FitRes]] +) -> List[Tuple[ClientProxy, NDArrays, int]]: + """ + This function is used to convert the results of client training into NDArrays and to apply a pseudo sort + based on the zeroeth elements in the weights and the sample counts. As long as the numpy seed has been set on the + server this process should be deterministic when repeatedly running the same server code leading to deterministic + sorting (assuming the clients are deterministically training their weights as well). This allows, for example, + for weights from the clients to be summed in a deterministic order during aggregation. + + NOTE: Client proxies would be nice to use for this task, but the CIDs are set by uuid deep in the flower library + and are, therefore, not pinnable without a ton of work. + + Args: + results (List[Tuple[ClientProxy, FitRes]]): Results from a federated training round. + + Returns: + List[Tuple[ClientProxy, NDArrays, int]]: The ordered set of weights as NDarrays and the corresponding + number of examples + """ + ndarrays_results = [ + (client_proxy, parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for client_proxy, fit_res in results + ] + return sorted(ndarrays_results, key=lambda x: pseudo_sort_scoring_function(x)) diff --git a/research/picai/reporting/__init__.py b/research/picai/reporting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utils/functions_test.py b/tests/utils/functions_test.py index 8da907ebd..1f8b024b9 100644 --- a/tests/utils/functions_test.py +++ b/tests/utils/functions_test.py @@ -1,6 +1,20 @@ +from typing import List, Tuple + +import numpy as np +import pytest import torch +from flwr.common import Code, Status, ndarrays_to_parameters +from flwr.common.typing import FitRes, NDArrays +from flwr.server.client_proxy import ClientProxy -from fl4health.utils.functions import bernoulli_sample, sigmoid_inverse +from fl4health.utils.functions import ( + bernoulli_sample, + decode_and_pseudo_sort_results, + pseudo_sort_scoring_function, + select_zeroeth_element, + sigmoid_inverse, +) +from tests.test_utils.custom_client_proxy import CustomClientProxy def test_bernoulli_gradient() -> None: @@ -21,3 +35,59 @@ def test_sigmoid_inverse() -> None: z = torch.sigmoid(x) assert torch.allclose(sigmoid_inverse(z), x) torch.seed() + + +def test_select_zeroeth_element() -> None: + np.random.seed(42) + array = np.random.rand(10, 10) + random_element = select_zeroeth_element(array) + assert pytest.approx(random_element, abs=1e-5) == 0.3745401188473625 + np.random.seed(None) + + +def test_pseudo_sort_scoring_function() -> None: + np.random.seed(42) + array_list = [np.random.rand(10, 10) for _ in range(2)] + [np.random.rand(5, 5) for _ in range(2)] + sort_value = pseudo_sort_scoring_function((CustomClientProxy("c0"), array_list, 13)) + assert pytest.approx(sort_value, abs=1e-5) == 14.291990594067467 + np.random.seed(None) + + +def test_pseudo_sort_scoring_function_with_mixed_types() -> None: + np.random.seed(42) + array_list = ( + [np.random.rand(10, 10) for _ in range(2)] + + [np.array(["Cat", "Dog"]), np.array([True, False])] + + [np.random.rand(5, 5) for _ in range(2)] + ) + sort_value = pseudo_sort_scoring_function((CustomClientProxy("c0"), array_list, 13)) + assert pytest.approx(sort_value, abs=1e-5) == 14.291990594067467 + np.random.seed(None) + + +def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> FitRes: + return FitRes( + status=Status(Code.OK, ""), + parameters=ndarrays_to_parameters(parameters), + num_examples=num_examples, + metrics={"metric": metric}, + ) + + +def test_decode_and_pseudo_sort_results() -> None: + np.random.seed(42) + client0_res = construct_fit_res([np.ones((3, 3)), np.ones((4, 4))], 0.1, 100) + client1_res = construct_fit_res([np.ones((3, 3)), np.full((4, 4), 2.0)], 0.2, 75) + client2_res = construct_fit_res([np.full((3, 3), 3.0), np.full((4, 4), 3.0)], 0.3, 50) + clients_res: List[Tuple[ClientProxy, FitRes]] = [ + (CustomClientProxy("c0"), client0_res), + (CustomClientProxy("c1"), client1_res), + (CustomClientProxy("c2"), client2_res), + ] + + sorted_results = decode_and_pseudo_sort_results(clients_res) + assert sorted_results[0][2] == 50 + assert sorted_results[1][2] == 75 + assert sorted_results[2][2] == 100 + + np.random.seed(None)