Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the client's "start" endpoint #10

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ next-env.d.ts
/florist/tsconfig.json

/metrics/
/logs/
58 changes: 58 additions & 0 deletions florist/api/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
"""FLorist client FastAPI endpoints."""
import uuid
from pathlib import Path

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from florist.api.clients.common import Clients
from florist.api.launchers.local import launch_client
from florist.api.monitoring.metrics import RedisMetricsReporter


LOG_FOLDER = Path("logs/client/")

app = FastAPI()

Expand All @@ -14,3 +24,51 @@ def connect() -> JSONResponse:
:return: JSON `{"status": "ok"}`
"""
return JSONResponse({"status": "ok"})


@app.get("/api/client/start")
def start(server_address: str, client: str, data_path: str, redis_host: str, redis_port: str) -> JSONResponse:
"""
Start a client.

:param server_address: (str) the address of the server this client should report to.
It should be comprised of the host name and port separated by colon (e.g. "localhost:8080").
:param client: (str) the name of the client. Should be one of the enum values of florist.api.client.Clients.
:param data_path: (str) the path where the training data is located.
:param redis_host: (str) the host name for the Redis instance for metrics reporting.
:param redis_port: (str) the port for the Redis instance for metrics reporting.
:return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the
format below, which can be used to pull metrics from Redis.
{"uuid": <client uuid>}
If not successful, returns the appropriate error code with a JSON with the format below:
{"error": <error message>}
"""
try:
if client not in Clients.list():
return JSONResponse(
content={"error": f"Client '{client}' not supported. Supported clients: {Clients.list()}"},
status_code=400,
)

client_uuid = str(uuid.uuid4())
metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=client_uuid)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

client_class = Clients.class_for_client(Clients[client])
client_obj = client_class(
data_path=Path(data_path),
metrics=[],
device=device,
metrics_reporter=metrics_reporter,
)

LOG_FOLDER.mkdir(parents=True, exist_ok=True)
log_file_name = LOG_FOLDER / f"{client_uuid}.out"

launch_client(client_obj, server_address, str(log_file_name))

return JSONResponse({"uuid": client_uuid})

except Exception as ex:
return JSONResponse({"error": str(ex)}, status_code=500)
1 change: 1 addition & 0 deletions florist/api/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Implementations for the clients."""
36 changes: 36 additions & 0 deletions florist/api/clients/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Common functions and definitions for clients."""
from enum import Enum
from typing import List

from fl4health.clients.basic_client import BasicClient

from florist.api.clients.mnist import MnistClient


class Clients(Enum):
"""Enumeration of supported clients."""

MNIST = "MNIST"

@classmethod
def class_for_client(cls, client: "Clients") -> type[BasicClient]:
"""
Return the class for a given client.

:param client: The client enumeration object.
:return: A subclass of BasicClient corresponding to the given client.
:raises ValueError: if the client is not supported.
"""
if client == Clients.MNIST:
return MnistClient

raise ValueError(f"Client {client.value} not supported.")

@classmethod
def list(cls) -> List[str]:
"""
List all the supported clients.

:return: a list of supported clients.
"""
return [client.value for client in Clients]
82 changes: 82 additions & 0 deletions florist/api/clients/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Implementation of the MNIST client and model."""
from typing import Tuple

import torch
import torch.nn.functional as f
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.dataset import MnistDataset
from fl4health.utils.load_data import load_mnist_data
from flwr.common.typing import Config
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader


class MnistClient(BasicClient): # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why we have to add the type: ignore in when we previously did not have to have one. Is it because you made the return type of get_dataloaders more specific?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because before it was in the testing folder. The static code checking for tests is mostly disabled because of mocking and other code practices that are OK to do in testing but not in code that runs in prod.

"""Implementation of the MNIST client."""

def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]:
"""
Return the data loader for MNIST data.

:param config: (Config) the Config object for this client.
:return: (Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]) a tuple with the train data loader
and validation data loader respectively.
"""
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"])
return train_loader, val_loader

def get_model(self, config: Config) -> nn.Module:
"""
Return the model for MNIST data.

:param config: (Config) the Config object for this client.
:return: (torch.nn.Module) An instance of florist.api.clients.mnist.MnistNet.
"""
return MnistNet()

def get_optimizer(self, config: Config) -> Optimizer:
"""
Return the optimizer for MNIST data.

:param config: (Config) the Config object for this client.
:return: (torch.optim.Optimizer) An instance of torch.optim.SGD with learning
rate of 0.001 and momentum of 0.9.
"""
return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

def get_criterion(self, config: Config) -> _Loss:
"""
Return the loss for MNIST data.

:param config: (Config) the Config object for this client.
:return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
"""
return torch.nn.CrossEntropyLoss()


class MnistNet(nn.Module):
"""Implementation of the Mnist model."""

def __init__(self) -> None:
"""Initialize an instance of MnistNet."""
super().__init__()
self.conv1 = nn.Conv2d(1, 8, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(8, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass for the given tensor.

:param x: (torch.Tensor) the tensor to perform the forward pass on.
:return: (torch.Tensor) a result tensor after the forward pass.
"""
x = self.pool(f.relu(self.conv1(x)))
x = self.pool(f.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = f.relu(self.fc1(x))
return f.relu(self.fc2(x))
1 change: 1 addition & 0 deletions florist/api/launchers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Launchers for servers and clients."""
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Launcher functions for clients and servers."""
"""Launcher functions for local clients and servers."""
import logging
import sys
import time
Expand Down
1 change: 1 addition & 0 deletions florist/api/monitoring/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Classes and functions for monitoring of clients and servers' execution."""
29 changes: 19 additions & 10 deletions florist/api/monitoring/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,25 @@


class RedisMetricsReporter(MetricsReporter): # type: ignore
"""Save the metrics to a Redis instance while it records them."""
"""
Save the metrics to a Redis instance while it records them.

def __init__(
self,
redis_connection: redis.client.Redis,
run_id: Optional[str] = None,
):
Lazily instantiates a Redis connection when the first metrics are recorded.
"""

def __init__(self, host: str, port: str, run_id: Optional[str] = None):
"""
Init an instance of RedisMetricsReporter.

:param redis_connection: (redis.client.Redis) the connection object to a Redis. Should be the output
of redis.Redis(host=host, port=port)
:param host: (str) The host address where the Redis instance is running.
:param port: (str) The port where the Redis instance is running on the host.
:param run_id: (Optional[str]) the identifier for the run which these metrics are from.
It will be used as the name of the object in Redis. Optional, default is a random UUID.
"""
super().__init__(run_id)
self.redis_connection = redis_connection
self.host = host
self.port = port
self.redis_connection: Optional[redis.Redis] = None

def add_to_metrics(self, data: Dict[str, Any]) -> None:
"""
Expand All @@ -51,7 +53,14 @@ def add_to_metrics_at_round(self, fl_round: int, data: Dict[str, Any]) -> None:
self.dump()

def dump(self) -> None:
"""Dump the current metrics to Redis under the run_id name."""
"""
Comment on lines 55 to +56
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is beyond the scope of this PR, but I am curious if you think we should explore how to dump metrics to redis at more frequent intervals. If we only do so at the end, then in the case of a crash we lose the metrics. Do you think this is something worthwhile to explore or not really important enough as of now to be thinking about?

Copy link
Collaborator Author

@lotif lotif Mar 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did it in this class, I added a call to dump at the end of each method so it will update redis every time a new metric is recorded. I think this kind of behaviour overkill for the main class in FL4Health but it's necessary here and easy enough to instrument.

Dump the current metrics to Redis under the run_id name.

Will instantiate a Redis connection if it's the first time it runs for this instance.
"""
if self.redis_connection is None:
self.redis_connection = redis.Redis(host=self.host, port=self.port)

encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder)
log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}")
self.redis_connection.set(self.run_id, encoded_metrics)
60 changes: 0 additions & 60 deletions florist/tests/api/monitoring/test_metrics.py

This file was deleted.

6 changes: 3 additions & 3 deletions florist/tests/integration/api/launchers/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import torch
from fl4health.server.base_server import FlServer

from florist.api.launchers.launch import launch
from florist.tests.utils.api.fl4health_utils import MnistClient, get_server_fedavg
from florist.tests.utils.api.models import MnistNet
from florist.api.launchers.local import launch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good call, I like the distinction between local and "distributed" launchers that we will build out down the line

from florist.api.clients.mnist import MnistClient, MnistNet
from florist.tests.utils.api.fl4health_utils import get_server_fedavg


def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
Expand Down
Loading