Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add integration tests with Redis #11

Merged
merged 13 commits into from
Mar 13, 2024
2 changes: 1 addition & 1 deletion .github/workflows/docs_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
source $(poetry env info --path)/bin/activate
poetry install --with docs,test
cd docs && rm -rf source/reference/api/_autosummary && make html
cd .. && coverage run -m pytest -m "not integration_test" && coverage xml && coverage report -m
cd .. && coverage run -m pytest florist/tests/unit && coverage xml && coverage report -m
# - name: Upload coverage to Codecov
# uses: Wandalen/[email protected]
# with:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
source $(poetry env info --path)/bin/activate
poetry install --with docs,test
cd docs && rm -rf source/reference/api/_autosummary && make html
cd .. && coverage run -m pytest -m "not integration_test" && coverage xml && coverage report -m
cd .. && coverage run -m pytest florist/tests/unit && coverage xml && coverage report -m
# - name: Upload coverage to Codecov
# uses: Wandalen/[email protected]
# with:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/integration_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ jobs:
- uses: actions/[email protected]
with:
python-version: '3.9'
- name: Setup redis
uses: supercharge/[email protected]
with:
redis-version: 7.2.4
- name: Install dependencies and check code
run: |
poetry env use '3.9'
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,4 @@ next-env.d.ts
/florist/tsconfig.json

/metrics/
/logs/
13 changes: 13 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ Then, run the server and client's Redis instance by following
[these instructions](README.md#start-servers-redis-instance) and
[these instructions](README.md#start-clients-redis-instance) respectively.

## Running the tests

To run the unit tests, simply execute:
```shell
pytest florist/tests/unit
```

To run the integration tests, first make sure you have a Redis server running on your
local machine on port 6379, then execute:
```shell
pytest florist/tests/integration
```

## Coding guidelines

For code style, we recommend the [PEP 8 style guide](https://peps.python.org/pep-0008/).
Expand Down
90 changes: 90 additions & 0 deletions florist/api/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,53 @@
"""FLorist client FastAPI endpoints."""
import uuid
from enum import Enum
from pathlib import Path
from typing import List

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fl4health.clients.basic_client import BasicClient

from florist.api.clients.mnist import MnistClient
from florist.api.launchers.local import launch_client
from florist.api.monitoring.metrics import RedisMetricsReporter


LOG_FOLDER = Path("logs/client/")

app = FastAPI()


class Clients(Enum):
"""Enumeration of supported clients."""

MNIST = "MNIST"

@classmethod
def class_for_client(cls, client: "Clients") -> type[BasicClient]:
"""
Return the class for a given client.

:param client: The client enumeration object.
:return: A subclass of BasicClient corresponding to the given client.
:raises ValueError: if the client is not supported.
"""
if client == Clients.MNIST:
return MnistClient

raise ValueError(f"Client {client.value} not supported.")

@classmethod
def list(cls) -> List[str]:
"""
List all the supported clients.

:return: a list of supported clients.
"""
return [client.value for client in Clients]


@app.get("/api/client/connect")
def connect() -> JSONResponse:
"""
Expand All @@ -14,3 +56,51 @@ def connect() -> JSONResponse:
:return: JSON `{"status": "ok"}`
"""
return JSONResponse({"status": "ok"})


@app.get("/api/client/start")
def start(server_address: str, client: str, data_path: str, redis_host: str, redis_port: str) -> JSONResponse:
"""
Start a client.

:param server_address: (str) the address of the server this client should report to.
It should be comprised of the host name and port separated by colon (e.g. "localhost:8080").
:param client: (str) the name of the client. Should be one of the enum values of florist.api.client.Clients.
:param data_path: (str) the path where the training data is located.
:param redis_host: (str) the host name for the Redis instance for metrics reporting.
:param redis_port: (str) the port for the Redis instance for metrics reporting.
:return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the
format below, which can be used to pull metrics from Redis.
{"uuid": <client uuid>}
If not successful, returns the appropriate error code with a JSON with the format below:
{"error": <error message>}
"""
try:
if client not in Clients.list():
return JSONResponse(
content={"error": f"Client '{client}' not supported. Supported clients: {Clients.list()}"},
status_code=400,
)

client_uuid = str(uuid.uuid4())
metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=client_uuid)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

client_class = Clients.class_for_client(Clients[client])
client_obj = client_class(
data_path=Path(data_path),
metrics=[],
device=device,
metrics_reporter=metrics_reporter,
)

LOG_FOLDER.mkdir(parents=True, exist_ok=True)
log_file_name = LOG_FOLDER / f"{client_uuid}.out"

launch_client(client_obj, server_address, str(log_file_name))

return JSONResponse({"uuid": client_uuid})

except Exception as ex:
return JSONResponse({"error": str(ex)}, status_code=500)
1 change: 1 addition & 0 deletions florist/api/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Implementations for the clients."""
82 changes: 82 additions & 0 deletions florist/api/clients/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Implementation of the MNIST client and model."""
from typing import Tuple

import torch
import torch.nn.functional as f
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.dataset import MnistDataset
from fl4health.utils.load_data import load_mnist_data
from flwr.common.typing import Config
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader


class MnistClient(BasicClient): # type: ignore
"""Implementation of the MNIST client."""

def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]:
"""
Return the data loader for MNIST data.

:param config: (Config) the Config object for this client.
:return: (Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]) a tuple with the train data loader
and validation data loader respectively.
"""
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"])
return train_loader, val_loader

def get_model(self, config: Config) -> nn.Module:
"""
Return the model for MNIST data.

:param config: (Config) the Config object for this client.
:return: (torch.nn.Module) An instance of florist.api.clients.mnist.MnistNet.
"""
return MnistNet()

def get_optimizer(self, config: Config) -> Optimizer:
"""
Return the optimizer for MNIST data.

:param config: (Config) the Config object for this client.
:return: (torch.optim.Optimizer) An instance of torch.optim.SGD with learning
rate of 0.001 and momentum of 0.9.
"""
return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

def get_criterion(self, config: Config) -> _Loss:
"""
Return the loss for MNIST data.

:param config: (Config) the Config object for this client.
:return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
"""
return torch.nn.CrossEntropyLoss()


class MnistNet(nn.Module):
"""Implementation of the Mnist model."""

def __init__(self) -> None:
"""Initialize an instance of MnistNet."""
super().__init__()
self.conv1 = nn.Conv2d(1, 8, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(8, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass for the given tensor.

:param x: (torch.Tensor) the tensor to perform the forward pass on.
:return: (torch.Tensor) a result tensor after the forward pass.
"""
x = self.pool(f.relu(self.conv1(x)))
x = self.pool(f.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = f.relu(self.fc1(x))
return f.relu(self.fc2(x))
1 change: 1 addition & 0 deletions florist/api/launchers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Launchers for servers and clients."""
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Launcher functions for clients and servers."""
"""Launcher functions for local clients and servers."""
import logging
import sys
import time
Expand Down
1 change: 1 addition & 0 deletions florist/api/monitoring/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Classes and functions for monitoring of clients and servers' execution."""
29 changes: 19 additions & 10 deletions florist/api/monitoring/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,25 @@


class RedisMetricsReporter(MetricsReporter): # type: ignore
"""Save the metrics to a Redis instance while it records them."""
"""
Save the metrics to a Redis instance while it records them.

def __init__(
self,
redis_connection: redis.client.Redis,
run_id: Optional[str] = None,
):
Lazily instantiates a Redis connection when the first metrics are recorded.
"""

def __init__(self, host: str, port: str, run_id: Optional[str] = None):
"""
Init an instance of RedisMetricsReporter.

:param redis_connection: (redis.client.Redis) the connection object to a Redis. Should be the output
of redis.Redis(host=host, port=port)
:param host: (str) The host address where the Redis instance is running.
:param port: (str) The port where the Redis instance is running on the host.
:param run_id: (Optional[str]) the identifier for the run which these metrics are from.
It will be used as the name of the object in Redis. Optional, default is a random UUID.
"""
super().__init__(run_id)
self.redis_connection = redis_connection
self.host = host
self.port = port
self.redis_connection: Optional[redis.Redis] = None

def add_to_metrics(self, data: Dict[str, Any]) -> None:
"""
Expand All @@ -51,7 +53,14 @@ def add_to_metrics_at_round(self, fl_round: int, data: Dict[str, Any]) -> None:
self.dump()

def dump(self) -> None:
"""Dump the current metrics to Redis under the run_id name."""
"""
Dump the current metrics to Redis under the run_id name.

Will instantiate a Redis connection if it's the first time it runs for this instance.
"""
if self.redis_connection is None:
self.redis_connection = redis.Redis(host=self.host, port=self.port)

encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder)
log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}")
self.redis_connection.set(self.run_id, encoded_metrics)
60 changes: 0 additions & 60 deletions florist/tests/api/monitoring/test_metrics.py

This file was deleted.

Loading