Add max_num_validation_steps member of config and client and related … #304

Merged · 4 commits · Dec 18, 2024
2 changes: 1 addition & 1 deletion .github/workflows/smoke_tests.yaml
@@ -10,7 +10,7 @@ on:

jobs:
test:
runs-on: ubuntu-latest
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
35 changes: 33 additions & 2 deletions fl4health/clients/basic_client.py
@@ -126,6 +126,11 @@ def __init__(
self.num_val_samples: int
self.num_test_samples: Optional[int] = None
self.learning_rate: Optional[float] = None
# Config can contain a max_num_validation_steps key, which sets an upper bound on the
# number of validation steps taken. If it is not specified, no upper bound is enforced.
# Note that when this key is specified, we cannot guarantee that the validation samples
# used are the same across rounds for clients.
self.max_num_validation_steps: int | None = None

def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None:
"""
@@ -231,6 +236,7 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N
"""
current_server_round = narrow_dict_type(config, "current_server_round", int)

# Parse config to determine whether we train by steps or by epochs
if ("local_epochs" in config) and ("local_steps" in config):
raise ValueError("Config cannot contain both local_epochs and local_steps. Please specify only one.")
elif "local_epochs" in config:
@@ -726,7 +732,8 @@ def _validate_or_test(
include_losses_in_metrics: bool = False,
) -> Tuple[float, Dict[str, Scalar]]:
"""
Evaluate the model on the given validation or test dataset.
Evaluate the model on the given validation or test dataset. If the max_num_validation_steps attribute
is not None and we are in the validation phase, the number of steps is limited to max_num_validation_steps.

Args:
loader (DataLoader): The data loader for the dataset (validation or test).
@@ -748,7 +755,10 @@
metric_manager.clear()
loss_meter.clear()
with torch.no_grad():
for input, target in maybe_progress_bar(loader, self.progress_bar):
for i, (input, target) in enumerate(maybe_progress_bar(loader, self.progress_bar)):
# Limit validation to self.max_num_validation_steps if it is defined
if logging_mode == LoggingMode.VALIDATION and self.max_num_validation_steps == i:
break
input = move_data_to_device(input, self.device)
target = move_data_to_device(target, self.device)
losses, preds = self.val_step(input, target)
@@ -830,11 +840,32 @@ def setup_client(self, config: Config) -> None:
self.val_loader = val_loader
self.test_loader = self.get_test_data_loader(config)

if "max_num_validation_steps" in config:
log(
INFO,
"""
max_num_validation_steps specified in config. Only a random subset of batches will \
be sampled from the validation set if max_num_validation_steps is less \
than the number of batches in the validation dataloader.
""",
)
self.max_num_validation_steps = narrow_dict_type(config, "max_num_validation_steps", int)
else:
self.max_num_validation_steps = None

# The following lines are type ignored because torch datasets are not "Sized"
# IE __len__ is considered optionally defined. In practice, it is almost always defined
# and as such, we will make that assumption.
self.num_train_samples = len(self.train_loader.dataset) # type: ignore

# If max_num_validation_steps is defined, limit the number of validation samples to the
# minimum of batch_size * max_num_validation_steps and the length of the validation set
self.num_val_samples = len(self.val_loader.dataset) # type: ignore
if self.max_num_validation_steps is not None:
val_batch_size = self.val_loader.batch_size
max_val_size = self.max_num_validation_steps * val_batch_size # type: ignore
self.num_val_samples = min(self.num_val_samples, max_val_size) # type: ignore

if self.test_loader:
self.num_test_samples = len(self.test_loader.dataset) # type: ignore

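To make the effect of the new option easier to review, here is a small standalone sketch of the sample-count arithmetic that setup_client now performs. The function name and the concrete numbers below are illustrative only; they are not part of the library API.

def capped_num_val_samples(
    val_dataset_len: int,
    val_batch_size: int,
    max_num_validation_steps: int | None,
) -> int:
    # Mirrors the logic added above: when max_num_validation_steps is configured, the
    # reported number of validation samples is capped at batch_size * max_num_validation_steps;
    # otherwise the full validation set size is used.
    num_val_samples = val_dataset_len
    if max_num_validation_steps is not None:
        num_val_samples = min(num_val_samples, max_num_validation_steps * val_batch_size)
    return num_val_samples

# Illustrative numbers: 32 validation samples with a batch size of 4.
assert capped_num_val_samples(32, 4, None) == 32   # no cap configured
assert capped_num_val_samples(32, 4, 2) == 8       # 2 steps x 4 samples per batch
assert capped_num_val_samples(32, 4, 100) == 32    # a cap larger than the loader has no effect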
20 changes: 18 additions & 2 deletions tests/clients/test_basic_client.py
@@ -7,6 +7,7 @@
import freezegun
import torch
from flwr.common import Scalar
from flwr.common.typing import Config
from freezegun import freeze_time

from fl4health.clients.basic_client import BasicClient
@@ -83,7 +84,7 @@ def test_metrics_reporter_evaluate() -> None:
"testing_metric": 1234,
"val - checkpoint": 123.123,
"test - checkpoint": 123.123,
"test - num_examples": 0,
"test - num_examples": 32,
}
reporter = JsonReporter()
fl_client = MockBasicClient(
@@ -137,6 +138,19 @@ def test_evaluate_after_fit_disabled() -> None:
fl_client.validate.assert_not_called() # type: ignore


def test_num_val_samples_correct() -> None:
fl_client_no_max = MockBasicClient()
fl_client_no_max.setup_client({})
assert fl_client_no_max.max_num_validation_steps is None
assert fl_client_no_max.num_val_samples == 32

fl_client_max = MockBasicClient()
config: Config = {"max_num_validation_steps": 2}
fl_client_max.setup_client(config)
assert fl_client_max.max_num_validation_steps == 2
assert fl_client_max.num_val_samples == 8


class MockBasicClient(BasicClient):
def __init__(
self,
@@ -165,6 +179,7 @@ def __init__(
self.test_loader = MagicMock()
self.num_train_samples = 0
self.num_val_samples = 0
self.max_num_validation_steps = None

# Mocking methods
self.set_parameters = MagicMock() # type: ignore
@@ -176,7 +191,8 @@ def __init__(
self.get_model = MagicMock() # type: ignore
self.get_data_loaders = MagicMock() # type: ignore
mock_data_loader = MagicMock() # type: ignore
mock_data_loader.dataset = []
mock_data_loader.batch_size = 4
mock_data_loader.dataset = [None] * 32
self.get_data_loaders.return_value = mock_data_loader, mock_data_loader
self.get_test_data_loader = MagicMock() # type: ignore
self.get_test_data_loader.return_value = mock_data_loader
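For completeness, a hypothetical server-side fit config that exercises the new key is sketched below. Only current_server_round, local_steps, and max_num_validation_steps appear in the diff above; the concrete values, the fit_config function name, and how the config is wired into a Flower strategy are assumptions for illustration, not part of this PR.

from flwr.common.typing import Config

def fit_config(current_server_round: int) -> Config:
    # Cap validation at 2 batches per round; with a validation batch size of 4 this
    # matches the num_val_samples == 8 expectation in test_num_val_samples_correct.
    return {
        "current_server_round": current_server_round,
        "local_steps": 100,
        "max_num_validation_steps": 2,
    }

Clients that do not receive the max_num_validation_steps key fall back to evaluating the full validation set, as before this change.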