Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade fl4heath and made appropriate changes #110

Merged
merged 14 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/docs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
cache: 'poetry'
- run: |
python3 -m pip install --upgrade pip && python3 -m pip install poetry
poetry env use '3.9'
poetry env use '3.10'
source .venv/bin/activate
poetry install --with docs,test
cd docs && rm -rf source/reference/api/_autosummary && make html
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/docs_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
cache: 'poetry'
- run: |
python3 -m pip install --upgrade pip && python3 -m pip install poetry
poetry env use '3.9'
poetry env use '3.10'
source $(poetry env info --path)/bin/activate
poetry install --with docs,test
cd docs && rm -rf source/reference/api/_autosummary && make html
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/integration_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
- name: Setup redis
uses: supercharge/[email protected]
with:
Expand All @@ -55,7 +55,7 @@ jobs:
mongodb-version: 7.0.8
- name: Install dependencies and check code
run: |
poetry env use '3.9'
poetry env use '3.10'
source .venv/bin/activate
poetry install --with docs,test
coverage run -m pytest florist/tests/integration && coverage xml && coverage report -m
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
- name: Build package
run: poetry build
- name: Publish package
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/static_code_checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
cache: 'poetry'
- name: Setup yarn
uses: mskelton/setup-yarn@v3
- name: Install dependencies and check code
run: |
yarn
poetry env use '3.9'
poetry env use '3.10'
source .venv/bin/activate
poetry install --with test --all-extras
pre-commit run --all-files
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
- name: Install python dependencies and check code
run: |
poetry env use '3.9'
poetry env use '3.10'
source .venv/bin/activate
poetry install --with docs,test
coverage run -m pytest florist/tests/unit && coverage xml && coverage report -m
Expand Down
2 changes: 1 addition & 1 deletion florist/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
data_path=Path(data_path),
metrics=[],
device=device,
metrics_reporter=metrics_reporter,
reporters=[metrics_reporter],
)

log_file_name = str(get_client_log_file_path(client_uuid))
Expand Down
6 changes: 3 additions & 3 deletions florist/api/clients/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.dataset import MnistDataset
from fl4health.utils.dataset import TensorDataset
from fl4health.utils.load_data import load_mnist_data
from flwr.common.typing import Config
from torch import nn
Expand All @@ -18,15 +18,15 @@
class MnistClient(BasicClient): # type: ignore[misc]
"""Implementation of the MNIST client."""

def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]:
def get_data_loaders(self, config: Config) -> Tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]:
"""
Return the data loader for MNIST data.

:param config: (Config) the Config object for this client.
:return: (Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]) a tuple with the train data loader
and validation data loader respectively.
"""
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"])
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=int(config["batch_size"]))
return train_loader, val_loader

def get_model(self, config: Config) -> nn.Module:
Expand Down
10 changes: 5 additions & 5 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Job(BaseModel):
clients_info: Optional[Annotated[List[ClientInfo], Field(...)]]

@classmethod
async def find_by_id(cls, job_id: str, database: AsyncIOMotorDatabase) -> Optional["Job"]:
async def find_by_id(cls, job_id: str, database: AsyncIOMotorDatabase[Any]) -> Optional["Job"]:
"""
Find a job in the database by its id.

Expand All @@ -98,7 +98,7 @@ async def find_by_id(cls, job_id: str, database: AsyncIOMotorDatabase) -> Option
return Job(**result)

@classmethod
async def find_by_status(cls, status: JobStatus, limit: int, database: AsyncIOMotorDatabase) -> List["Job"]:
async def find_by_status(cls, status: JobStatus, limit: int, database: AsyncIOMotorDatabase[Any]) -> List["Job"]:
"""
Return all jobs with the given status.

Expand All @@ -114,7 +114,7 @@ async def find_by_status(cls, status: JobStatus, limit: int, database: AsyncIOMo
assert isinstance(result, list)
return [Job(**r) for r in result]

async def create(self, database: AsyncIOMotorDatabase) -> str:
async def create(self, database: AsyncIOMotorDatabase[Any]) -> str:
"""
Save this instance under a new record in the database.

Expand All @@ -126,7 +126,7 @@ async def create(self, database: AsyncIOMotorDatabase) -> str:
assert isinstance(result.inserted_id, str)
return result.inserted_id

async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: AsyncIOMotorDatabase) -> None:
async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: AsyncIOMotorDatabase[Any]) -> None:
"""
Save the server and clients' UUIDs in the database under the current job's id.

Expand All @@ -152,7 +152,7 @@ async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: A
)
assert_updated_successfully(update_result)

async def set_status(self, status: JobStatus, database: AsyncIOMotorDatabase) -> None:
async def set_status(self, status: JobStatus, database: AsyncIOMotorDatabase[Any]) -> None:
"""
Save the status in the database under the current job's id.

Expand Down
6 changes: 3 additions & 3 deletions florist/api/launchers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import flwr as fl
from fl4health.clients.basic_client import BasicClient
from fl4health.server.base_server import FlServer
from flwr.common.logger import DEFAULT_FORMATTER
from flwr.server import ServerConfig


DEFAULT_FORMATTER = logging.Formatter("%(levelname)s %(name)s %(asctime)s | %(filename)s:%(lineno)d | %(message)s")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess flwr removed the DEFAULT_FORMATTER they were exporting previously.

Here there reference it in the documentation so I just sourced it from there: https://flower.ai/docs/framework/how-to-configure-logging.html



def redirect_logging_from_console_to_file(log_file_path: str) -> None:
"""
Redirect loggers outputting to console to specified file.
Expand Down Expand Up @@ -60,7 +62,6 @@ def start_server(
config=ServerConfig(num_rounds=n_server_rounds),
)
server.shutdown()
server.metrics_reporter.dump()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are dumping on every call to report I figured we can remove these?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess, yes... I put it there just in case, it doesn't hurt, but I guess it's better to remove if it's not really necessary :)



def start_client(client: BasicClient, server_address: str, client_log_file_name: str) -> None:
Expand All @@ -79,7 +80,6 @@ def start_client(client: BasicClient, server_address: str, client_log_file_name:
sys.stderr = log_file
fl.client.start_numpy_client(server_address=server_address, client=client)
client.shutdown()
client.metrics_reporter.dump()


def launch_server(
Expand Down
88 changes: 70 additions & 18 deletions florist/api/monitoring/metrics.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
"""Classes for the instrumentation of metrics reporting from clients and servers."""

import datetime
import json
import time
import uuid
from logging import DEBUG, Logger
from typing import Any, Dict, Optional

import redis
from fl4health.reporting.metrics import DateTimeEncoder, MetricsReporter
from fl4health.reporting.base_reporter import BaseReporter
from flwr.common.logger import log
from redis.client import PubSub


class RedisMetricsReporter(MetricsReporter): # type: ignore
class DateTimeEncoder(json.JSONEncoder):
"""Converts a datetime object to string in order to make json encoding easier."""

def default(self, o: Any) -> Any:
"""
Return string of datetime if datetime object is passed else return result of the default encoder method.

:param o: Object to encode.
"""
if isinstance(o, datetime.datetime):
return str(o)
return json.JSONEncoder.default(self, o)


class RedisMetricsReporter(BaseReporter): # type: ignore
"""
Save the metrics to a Redis instance while it records them.

Expand All @@ -27,32 +43,66 @@ def __init__(self, host: str, port: str, run_id: Optional[str] = None):
:param run_id: (Optional[str]) the identifier for the run which these metrics are from.
It will be used as the name of the object in Redis. Optional, default is a random UUID.
"""
super().__init__(run_id)
self.host = host
self.port = port
self.run_id = run_id
self.initialized = False

self.redis_connection: Optional[redis.Redis] = None
self.metrics: Dict[str, Any] = {}

def add_to_metrics(self, data: Dict[str, Any]) -> None:
def initialize(self, **kwargs: Any) -> None:
"""
Add a dictionary of data into the main metrics dictionary.
Initialize RedisMetricReporter with run_id and set initialized to True.

At the end, dumps the current state of the metrics to Redis.

:param data: (Dict[str, Any]) Data to be added to the metrics dictionary via .update().
:param kwargs: (Any) The keyword arguments required to initialize the Reporter.
"""
super().add_to_metrics(data)
self.dump()

def add_to_metrics_at_round(self, fl_round: int, data: Dict[str, Any]) -> None:
# If run_id was not specified on init try first to initialize with client name
if self.run_id is None:
self.run_id = kwargs.get("id")
# If client name was not provided, init run id manually
if self.run_id is None:
self.run_id = str(uuid.uuid4())

self.initialized = True

def report(
self,
data: dict[str, Any],
round: int | None = None, # noqa: A002
epoch: int | None = None,
step: int | None = None,
) -> None:
"""Send data to the reporter.

The report method is called by the client/server at frequent intervals (ie step, epoch, round) and sometimes
outside of a FL round (for high level summary data). The json reporter is hardcoded to report at the 'round'
level and therefore ignores calls to the report method made every epoch or every step.

Args:
data (dict): The data to maybe report from the server or client.
round (int | None, optional): The current FL round. If None, this indicates that the method was called
outside of a round (e.g. for summary information). Defaults to None.
epoch (int | None, optional): The current epoch. If None then this method was not called within the scope
of an epoch. Defaults to None.
step (int | None, optional): The current step (total). If None then this method was called outside the
scope of a training or evaluation step (eg. at the end of an epoch or round) Defaults to None.
"""
Add a dictionary of data into the metrics dictionary for a specific FL round.
if not self.initialized:
kwargs = {"run_id": self.run_id} if self.run_id is not None else {}
self.initialize(**kwargs)

At the end, dumps the current state of the metrics to Redis.
if round is None: # Reports outside of a fit round
self.metrics.update(data)
# Ensure we don't report for each epoch or step
elif epoch is None and step is None:
if "rounds" not in self.metrics:
self.metrics["rounds"] = {}
if round not in self.metrics["rounds"]:
self.metrics["rounds"][round] = {}

self.metrics["rounds"][round].update(data)

:param fl_round: (int) the FL round these metrics are from.
:param data: (Dict[str, Any]) Data to be added to the round's metrics dictionary via .update().
"""
super().add_to_metrics_at_round(fl_round, data)
self.dump()

def dump(self) -> None:
Expand All @@ -64,6 +114,8 @@ def dump(self) -> None:
if self.redis_connection is None:
self.redis_connection = redis.Redis(host=self.host, port=self.port)

assert self.run_id is not None, "Run ID is None, ensure reporter is initialized prior to dumping metrics."

encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder)
log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}")
self.redis_connection.set(self.run_id, encoded_metrics)
Expand Down
2 changes: 1 addition & 1 deletion florist/api/servers/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def launch_local_server(
metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid)
server_constructor = partial(
get_server,
reporters=[metrics_reporter],
model=model,
n_clients=n_clients,
batch_size=batch_size,
local_epochs=local_epochs,
metrics_reporter=metrics_reporter,
)

log_file_name = str(get_server_log_file_path(server_uuid))
Expand Down
8 changes: 4 additions & 4 deletions florist/api/servers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Dict, Union

from fl4health.client_managers.base_sampling_manager import SimpleClientManager
from fl4health.reporting.metrics import MetricsReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.server.base_server import FlServer
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from flwr.common.parameter import ndarrays_to_parameters
Expand Down Expand Up @@ -33,11 +33,11 @@ def fit_config(batch_size: int, local_epochs: int, current_server_round: int) ->

def get_server(
model: nn.Module,
reporters: list[BaseReporter],
fit_config: Callable[[int, int, int], Dict[str, int]] = fit_config,
n_clients: int = 2,
batch_size: int = 8,
local_epochs: int = 1,
metrics_reporter: MetricsReporter = None,
) -> FlServer:
"""
Return a server instance with FedAvg aggregation strategy.
Expand All @@ -47,7 +47,7 @@ def get_server(
:param n_clients: (int) the number of clients that will participate on training. Optional, default is 2.
:param batch_size: (int) the size of the batch of samples. Optional, default is 8.
:param local_epochs: (int) the number of local epochs the clients will run. Optional, default is 1.
:param metrics_reporter: (fl4health.reporting.metrics.MetricsReporter) An optional metrics reporter instance.
:param reporters: (list[fl4health.reporting.base_reporter.BaseReporter]) An optional metrics reporter instance.
Default is None.
:return: (fl4health.server.base_server.FlServer) An instance of FlServer with FedAvg as strategy.
"""
Expand All @@ -64,4 +64,4 @@ def get_server(
initial_parameters=initial_model_parameters,
)
client_manager = SimpleClientManager()
return FlServer(strategy=strategy, client_manager=client_manager, metrics_reporter=metrics_reporter)
return FlServer(strategy=strategy, client_manager=client_manager, reporters=reporters)
Loading
Loading