Skip to content

Commit

Permalink
Upgrade fl4health and make appropriate changes (#110)
Browse files Browse the repository at this point in the history
* Upgrade fl4health and make appropriate changes

* Update documentation

* Upgrade python version to 3.10 in workflow files

* Upgrade python version to 3.10 in integration workflow file

* Upgrade static code check workflow python version

* Fix issues in test_launch integration test

* Changes to make compatible with new reporter structure

Updated key name from type to host_type in tests and entities

* Install most recent version of fl4health from main branch

* Try adding fl4health via http instead of ssh to avoid auth issues on github server

* Change type to host_type in front end

* Revert change from type to host_type

* Add back change from type to host_type and fix UI test with old names
  • Loading branch information
jewelltaylor authored Nov 7, 2024
1 parent d0604c4 commit 59c8cef
Show file tree
Hide file tree
Showing 22 changed files with 4,301 additions and 2,718 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
cache: 'poetry'
- run: |
python3 -m pip install --upgrade pip && python3 -m pip install poetry
poetry env use '3.9'
poetry env use '3.10'
source .venv/bin/activate
poetry install --with docs,test
cd docs && rm -rf source/reference/api/_autosummary && make html
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/docs_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
cache: 'poetry'
- run: |
python3 -m pip install --upgrade pip && python3 -m pip install poetry
poetry env use '3.9'
poetry env use '3.10'
source $(poetry env info --path)/bin/activate
poetry install --with docs,test
cd docs && rm -rf source/reference/api/_autosummary && make html
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/integration_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
- name: Setup redis
uses: supercharge/[email protected]
with:
Expand All @@ -57,7 +57,7 @@ jobs:
mongodb-version: 7.0.8
- name: Install dependencies and check code
run: |
poetry env use '3.9'
poetry env use '3.10'
source .venv/bin/activate
poetry install --with docs,test
coverage run -m pytest florist/tests/integration && coverage xml && coverage report -m
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
- name: Build package
run: poetry build
- name: Publish package
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/static_code_checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
cache: 'poetry'
- name: Setup yarn
uses: mskelton/setup-yarn@v3
- name: Install dependencies and check code
run: |
yarn
poetry env use '3.9'
poetry env use '3.10'
source .venv/bin/activate
poetry install --with test --all-extras
pre-commit run --all-files --verbose
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ jobs:
virtualenvs-in-project: true
- uses: actions/[email protected]
with:
python-version: '3.9'
python-version: '3.10'
- name: Install python dependencies and check code
run: |
poetry env use '3.9'
poetry env use '3.10'
source .venv/bin/activate
poetry install --with docs,test
coverage run -m pytest florist/tests/unit && coverage xml && coverage report -m
Expand Down
2 changes: 1 addition & 1 deletion florist/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
data_path=Path(data_path),
metrics=[],
device=device,
metrics_reporter=metrics_reporter,
reporters=[metrics_reporter],
)

log_file_name = str(get_client_log_file_path(client_uuid))
Expand Down
6 changes: 3 additions & 3 deletions florist/api/clients/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.dataset import MnistDataset
from fl4health.utils.dataset import TensorDataset
from fl4health.utils.load_data import load_mnist_data
from flwr.common.typing import Config
from torch import nn
Expand All @@ -18,15 +18,15 @@
class MnistClient(BasicClient): # type: ignore[misc]
"""Implementation of the MNIST client."""

def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]:
def get_data_loaders(self, config: Config) -> Tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]:
"""
Return the data loader for MNIST data.
:param config: (Config) the Config object for this client.
:return: (Tuple[DataLoader[TensorDataset], DataLoader[TensorDataset]]) a tuple with the train data loader
and validation data loader respectively.
"""
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"])
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=int(config["batch_size"]))
return train_loader, val_loader

def get_model(self, config: Config) -> nn.Module:
Expand Down
14 changes: 7 additions & 7 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Config:
"redis_host": "localhost",
"redis_port": "6380",
"uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f",
"metrics": '{"type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}',
"metrics": '{"host_type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}',
},
}

Expand All @@ -83,7 +83,7 @@ class Job(BaseModel):
clients_info: Optional[Annotated[List[ClientInfo], Field(...)]]

@classmethod
async def find_by_id(cls, job_id: str, database: AsyncIOMotorDatabase) -> Optional["Job"]:
async def find_by_id(cls, job_id: str, database: AsyncIOMotorDatabase[Any]) -> Optional["Job"]:
"""
Find a job in the database by its id.
Expand All @@ -98,7 +98,7 @@ async def find_by_id(cls, job_id: str, database: AsyncIOMotorDatabase) -> Option
return Job(**result)

@classmethod
async def find_by_status(cls, status: JobStatus, limit: int, database: AsyncIOMotorDatabase) -> List["Job"]:
async def find_by_status(cls, status: JobStatus, limit: int, database: AsyncIOMotorDatabase[Any]) -> List["Job"]:
"""
Return all jobs with the given status.
Expand All @@ -114,7 +114,7 @@ async def find_by_status(cls, status: JobStatus, limit: int, database: AsyncIOMo
assert isinstance(result, list)
return [Job(**r) for r in result]

async def create(self, database: AsyncIOMotorDatabase) -> str:
async def create(self, database: AsyncIOMotorDatabase[Any]) -> str:
"""
Save this instance under a new record in the database.
Expand All @@ -126,7 +126,7 @@ async def create(self, database: AsyncIOMotorDatabase) -> str:
assert isinstance(result.inserted_id, str)
return result.inserted_id

async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: AsyncIOMotorDatabase) -> None:
async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: AsyncIOMotorDatabase[Any]) -> None:
"""
Save the server and clients' UUIDs in the database under the current job's id.
Expand All @@ -152,7 +152,7 @@ async def set_uuids(self, server_uuid: str, client_uuids: List[str], database: A
)
assert_updated_successfully(update_result)

async def set_status(self, status: JobStatus, database: AsyncIOMotorDatabase) -> None:
async def set_status(self, status: JobStatus, database: AsyncIOMotorDatabase[Any]) -> None:
"""
Save the status in the database under the current job's id.
Expand Down Expand Up @@ -232,7 +232,7 @@ class Config:
"server_address": "localhost:8000",
"server_config": '{"n_server_rounds": 3, "batch_size": 8, "local_epochs": 1}',
"server_uuid": "d73243cf-8b89-473b-9607-8cd0253a101d",
"server_metrics": '{"type": "server", "fit_start": "2024-04-23 15:33:12.865604", "rounds": {"1": {"fit_start": "2024-04-23 15:33:12.869001"}}}',
"server_metrics": '{"host_type": "server", "fit_start": "2024-04-23 15:33:12.865604", "rounds": {"1": {"fit_start": "2024-04-23 15:33:12.869001"}}}',
"redis_host": "localhost",
"redis_port": "6379",
"clients_info": [
Expand Down
6 changes: 3 additions & 3 deletions florist/api/launchers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import flwr as fl
from fl4health.clients.basic_client import BasicClient
from fl4health.server.base_server import FlServer
from flwr.common.logger import DEFAULT_FORMATTER
from flwr.server import ServerConfig


DEFAULT_FORMATTER = logging.Formatter("%(levelname)s %(name)s %(asctime)s | %(filename)s:%(lineno)d | %(message)s")


def redirect_logging_from_console_to_file(log_file_path: str) -> None:
"""
Redirect loggers outputting to console to specified file.
Expand Down Expand Up @@ -60,7 +62,6 @@ def start_server(
config=ServerConfig(num_rounds=n_server_rounds),
)
server.shutdown()
server.metrics_reporter.dump()


def start_client(client: BasicClient, server_address: str, client_log_file_name: str) -> None:
Expand All @@ -79,7 +80,6 @@ def start_client(client: BasicClient, server_address: str, client_log_file_name:
sys.stderr = log_file
fl.client.start_numpy_client(server_address=server_address, client=client)
client.shutdown()
client.metrics_reporter.dump()


def launch_server(
Expand Down
87 changes: 69 additions & 18 deletions florist/api/monitoring/metrics.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
"""Classes for the instrumentation of metrics reporting from clients and servers."""

import datetime
import json
import time
import uuid
from logging import DEBUG, Logger
from typing import Any, Dict, Optional

import redis
from fl4health.reporting.metrics import DateTimeEncoder, MetricsReporter
from fl4health.reporting.base_reporter import BaseReporter
from flwr.common.logger import log
from redis.client import PubSub


class RedisMetricsReporter(MetricsReporter): # type: ignore
class DateTimeEncoder(json.JSONEncoder):
"""Converts a datetime object to string in order to make json encoding easier."""

def default(self, o: Any) -> Any:
"""
Return string of datetime if datetime object is passed else return result of the default encoder method.
:param o: Object to encode.
"""
if isinstance(o, datetime.datetime):
return str(o)
return json.JSONEncoder.default(self, o)


class RedisMetricsReporter(BaseReporter): # type: ignore
"""
Save the metrics to a Redis instance while it records them.
Expand All @@ -27,32 +43,65 @@ def __init__(self, host: str, port: str, run_id: Optional[str] = None):
:param run_id: (Optional[str]) the identifier for the run which these metrics are from.
It will be used as the name of the object in Redis. Optional, default is a random UUID.
"""
super().__init__(run_id)
self.host = host
self.port = port
self.run_id = run_id
self.initialized = False

self.redis_connection: Optional[redis.Redis] = None
self.metrics: Dict[str, Any] = {}

def add_to_metrics(self, data: Dict[str, Any]) -> None:
def initialize(self, **kwargs: Any) -> None:
"""
Add a dictionary of data into the main metrics dictionary.
Initialize RedisMetricsReporter with a run_id and set initialized to True.
At the end, dumps the current state of the metrics to Redis.
:param data: (Dict[str, Any]) Data to be added to the metrics dictionary via .update().
:param kwargs: (Any) The keyword arguments required to initialize the Reporter.
"""
super().add_to_metrics(data)
self.dump()

def add_to_metrics_at_round(self, fl_round: int, data: Dict[str, Any]) -> None:
# If run_id was not specified on init try first to initialize with client name
if self.run_id is None:
self.run_id = kwargs.get("id")
# If client name was not provided, init run id manually
if self.run_id is None:
self.run_id = str(uuid.uuid4())

self.initialized = True

def report(
self,
data: dict[str, Any],
round: int | None = None, # noqa: A002
epoch: int | None = None,
step: int | None = None,
) -> None:
"""Send data to the reporter.
The report method is called by the client/server at frequent intervals (i.e. step, epoch, round) and sometimes
outside of a FL round (for high level summary data). This reporter only records data at the 'round' level (or
outside of any round) and therefore ignores the data from calls made every epoch or every step.
Args:
data (dict): The data to maybe report from the server or client.
round (int | None, optional): The current FL round. If None, this indicates that the method was called
outside of a round (e.g. for summary information). Defaults to None.
epoch (int | None, optional): The current epoch. If None then this method was not called within the scope
of an epoch. Defaults to None.
step (int | None, optional): The current step (total). If None then this method was called outside the
scope of a training or evaluation step (eg. at the end of an epoch or round) Defaults to None.
"""
Add a dictionary of data into the metrics dictionary for a specific FL round.
if not self.initialized:
self.initialize()

At the end, dumps the current state of the metrics to Redis.
if round is None: # Reports outside of a fit round
self.metrics.update(data)
# Ensure we don't report for each epoch or step
elif epoch is None and step is None:
if "rounds" not in self.metrics:
self.metrics["rounds"] = {}
if round not in self.metrics["rounds"]:
self.metrics["rounds"][round] = {}

self.metrics["rounds"][round].update(data)

:param fl_round: (int) the FL round these metrics are from.
:param data: (Dict[str, Any]) Data to be added to the round's metrics dictionary via .update().
"""
super().add_to_metrics_at_round(fl_round, data)
self.dump()

def dump(self) -> None:
Expand All @@ -64,6 +113,8 @@ def dump(self) -> None:
if self.redis_connection is None:
self.redis_connection = redis.Redis(host=self.host, port=self.port)

assert self.run_id is not None, "Run ID is None, ensure reporter is initialized prior to dumping metrics."

encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder)
log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}")
self.redis_connection.set(self.run_id, encoded_metrics)
Expand Down
2 changes: 1 addition & 1 deletion florist/api/servers/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def launch_local_server(
metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=server_uuid)
server_constructor = partial(
get_server,
reporters=[metrics_reporter],
model=model,
n_clients=n_clients,
batch_size=batch_size,
local_epochs=local_epochs,
metrics_reporter=metrics_reporter,
)

log_file_name = str(get_server_log_file_path(server_uuid))
Expand Down
8 changes: 4 additions & 4 deletions florist/api/servers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Dict, Union

from fl4health.client_managers.base_sampling_manager import SimpleClientManager
from fl4health.reporting.metrics import MetricsReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.server.base_server import FlServer
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from flwr.common.parameter import ndarrays_to_parameters
Expand Down Expand Up @@ -33,11 +33,11 @@ def fit_config(batch_size: int, local_epochs: int, current_server_round: int) ->

def get_server(
model: nn.Module,
reporters: list[BaseReporter],
fit_config: Callable[[int, int, int], Dict[str, int]] = fit_config,
n_clients: int = 2,
batch_size: int = 8,
local_epochs: int = 1,
metrics_reporter: MetricsReporter = None,
) -> FlServer:
"""
Return a server instance with FedAvg aggregation strategy.
Expand All @@ -47,7 +47,7 @@ def get_server(
:param n_clients: (int) the number of clients that will participate on training. Optional, default is 2.
:param batch_size: (int) the size of the batch of samples. Optional, default is 8.
:param local_epochs: (int) the number of local epochs the clients will run. Optional, default is 1.
:param metrics_reporter: (fl4health.reporting.metrics.MetricsReporter) An optional metrics reporter instance.
:param reporters: (list[fl4health.reporting.base_reporter.BaseReporter]) The list of reporter instances
passed to the FlServer. Required; there is no default value.
:return: (fl4health.server.base_server.FlServer) An instance of FlServer with FedAvg as strategy.
"""
Expand All @@ -64,4 +64,4 @@ def get_server(
initial_parameters=initial_model_parameters,
)
client_manager = SimpleClientManager()
return FlServer(strategy=strategy, client_manager=client_manager, metrics_reporter=metrics_reporter)
return FlServer(strategy=strategy, client_manager=client_manager, reporters=reporters)
Loading

0 comments on commit 59c8cef

Please sign in to comment.