diff --git a/.github/workflows/smoke_tests.yaml b/.github/workflows/smoke_tests.yaml index 2d0c48828..1da21d605 100644 --- a/.github/workflows/smoke_tests.yaml +++ b/.github/workflows/smoke_tests.yaml @@ -10,7 +10,7 @@ on: jobs: test: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout code uses: actions/checkout@v4 diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 92b278e96..5ead79688 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -126,6 +126,11 @@ def __init__( self.num_val_samples: int self.num_test_samples: Optional[int] = None self.learning_rate: Optional[float] = None + # Config can contain max_num_validation_steps key, which determines an upper bound + # for the validation steps taken. If not specified, no upper bound will be enforced. + # By specifying this in the config we cannot guarantee the validation set is the same + # across rounds for clients. + self.max_num_validation_steps: int | None = None def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: """ @@ -231,6 +236,7 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N """ current_server_round = narrow_dict_type(config, "current_server_round", int) + # Parse config to determine train by steps or train by epochs if ("local_epochs" in config) and ("local_steps" in config): raise ValueError("Config cannot contain both local_epochs and local_steps. Please specify only one.") elif "local_epochs" in config: @@ -726,7 +732,8 @@ def _validate_or_test( include_losses_in_metrics: bool = False, ) -> Tuple[float, Dict[str, Scalar]]: """ - Evaluate the model on the given validation or test dataset. + Evaluate the model on the given validation or test dataset. If max_num_validation_steps attribute + is not None and in validation phase, steps are limited to the value of max_num_validation_steps. 
Args: loader (DataLoader): The data loader for the dataset (validation or test). @@ -748,7 +755,10 @@ def _validate_or_test( metric_manager.clear() loss_meter.clear() with torch.no_grad(): - for input, target in maybe_progress_bar(loader, self.progress_bar): + for i, (input, target) in enumerate(maybe_progress_bar(loader, self.progress_bar)): + # Limit validation to self.max_num_validation_steps if it is defined + if logging_mode == LoggingMode.VALIDATION and self.max_num_validation_steps == i: + break input = move_data_to_device(input, self.device) target = move_data_to_device(target, self.device) losses, preds = self.val_step(input, target) @@ -830,11 +840,32 @@ def setup_client(self, config: Config) -> None: self.val_loader = val_loader self.test_loader = self.get_test_data_loader(config) + if "max_num_validation_steps" in config: + log( + INFO, + """ + max_num_validation_steps specified in config. Only a random subset of batches will \ + be sampled from the validation set if max_num_validation_steps is smaller \ + than the number of batches in the validation dataloader. + """, + ) + self.max_num_validation_steps = narrow_dict_type(config, "max_num_validation_steps", int) + else: + self.max_num_validation_steps = None + # The following lines are type ignored because torch datasets are not "Sized" # IE __len__ is considered optionally defined. In practice, it is almost always defined # and as such, we will make that assumption. 
self.num_train_samples = len(self.train_loader.dataset) # type: ignore + + # if max_num_validation_steps is defined, limit validation set to minimum of + # batch_size * max_num_validation_steps and the length of validation set self.num_val_samples = len(self.val_loader.dataset) # type: ignore + if self.max_num_validation_steps is not None: + val_batch_size = self.val_loader.batch_size + max_val_size = self.max_num_validation_steps * val_batch_size # type: ignore + self.num_val_samples = min(self.num_val_samples, max_val_size) # type: ignore + if self.test_loader: self.num_test_samples = len(self.test_loader.dataset) # type: ignore diff --git a/tests/clients/test_basic_client.py b/tests/clients/test_basic_client.py index bcd7a7dae..337d72f01 100644 --- a/tests/clients/test_basic_client.py +++ b/tests/clients/test_basic_client.py @@ -7,6 +7,7 @@ import freezegun import torch from flwr.common import Scalar +from flwr.common.typing import Config from freezegun import freeze_time from fl4health.clients.basic_client import BasicClient @@ -83,7 +84,7 @@ def test_metrics_reporter_evaluate() -> None: "testing_metric": 1234, "val - checkpoint": 123.123, "test - checkpoint": 123.123, - "test - num_examples": 0, + "test - num_examples": 32, } reporter = JsonReporter() fl_client = MockBasicClient( @@ -137,6 +138,19 @@ def test_evaluate_after_fit_disabled() -> None: fl_client.validate.assert_not_called() # type: ignore +def test_num_val_samples_correct() -> None: + fl_client_no_max = MockBasicClient() + fl_client_no_max.setup_client({}) + assert fl_client_no_max.max_num_validation_steps is None + assert fl_client_no_max.num_val_samples == 32 + + fl_client_max = MockBasicClient() + config: Config = {"max_num_validation_steps": 2} + fl_client_max.setup_client(config) + assert fl_client_max.max_num_validation_steps == 2 + assert fl_client_max.num_val_samples == 8 + + class MockBasicClient(BasicClient): def __init__( self, @@ -165,6 +179,7 @@ def __init__( self.test_loader = 
MagicMock() self.num_train_samples = 0 self.num_val_samples = 0 + self.max_num_validation_steps = None # Mocking methods self.set_parameters = MagicMock() # type: ignore @@ -176,7 +191,8 @@ def __init__( self.get_model = MagicMock() # type: ignore self.get_data_loaders = MagicMock() # type: ignore mock_data_loader = MagicMock() # type: ignore - mock_data_loader.dataset = [] + mock_data_loader.batch_size = 4 + mock_data_loader.dataset = [None] * 32 self.get_data_loaders.return_value = mock_data_loader, mock_data_loader self.get_test_data_loader = MagicMock() # type: ignore self.get_test_data_loader.return_value = mock_data_loader