"start training" API (server side) #15

Merged · 12 commits · Apr 10, 2024
6 changes: 3 additions & 3 deletions .github/workflows/static_code_checks.yaml
@@ -47,12 +47,12 @@ jobs:
with:
virtual-environment: .venv/
# Ignoring security vulnerabilities in Pillow because pycyclops cannot update it to the
# version that fixes them (>10.0.1).
# Remove those when the issue below is fixed and pycyclops changes its requirements:
# https://github.com/SeldonIO/alibi/issues/991
# version that fixes them (>10.3.0).
# Remove those when FL4Health is released with the update to pillow > 10
ignore-vulns: |
PYSEC-2023-175
PYSEC-2023-227
GHSA-j7hp-h8jx-5ppr
GHSA-56pw-mpj4-fxww
GHSA-3f63-hfp8-52jq
GHSA-44wm-f244-xhp3
1 change: 1 addition & 0 deletions .gitignore
@@ -169,3 +169,4 @@ next-env.d.ts

/metrics/
/logs/
/.ruff_cache/
12 changes: 12 additions & 0 deletions CONTRIBUTING.md
@@ -48,6 +48,11 @@ Then, run the server's and client's Redis instances by following
[these instructions](README.md#start-servers-redis-instance) and
[these instructions](README.md#start-clients-redis-instance) respectively.

To start the server in development mode, run:
```shell
yarn dev
```

## Running the tests

To run the unit tests, simply execute:
@@ -72,3 +77,10 @@ analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ru

Last but not least, we use type hints in our code, which are then checked using
[mypy](https://mypy.readthedocs.io/en/stable/).

## Documentation

Backend code API documentation can be found at https://vectorinstitute.github.io/FLorist/.

Backend REST API documentation can be found at https://localhost:8000/docs once the server
is running locally.
9 changes: 7 additions & 2 deletions README.md
@@ -63,10 +63,10 @@ docker start redis-florist-server

### Start back-end and front-end servers

Use Yarn to run both the back-end and front-end on server mode:
Use Yarn to run both the back-end and front-end on production server mode:

```shell
yarn dev
yarn prod
```

The front-end will be available at `http://localhost:3000`. If you want to access
@@ -95,3 +95,8 @@ uvicorn florist.api.client:app --reload --port 8001
```

The service will be available at `http://localhost:8001`.

# Contributing
If you are interested in contributing to the library, please see [CONTRIBUTING.MD](CONTRIBUTING.md).
This file contains many details around contributing to the code base, including development
practices, code checks, tests, and more.
14 changes: 6 additions & 8 deletions florist/api/client.py
@@ -6,7 +6,7 @@
from fastapi import FastAPI
from fastapi.responses import JSONResponse

from florist.api.clients.common import Clients
from florist.api.clients.common import Client
from florist.api.launchers.local import launch_client
from florist.api.monitoring.logs import get_client_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter
@@ -30,7 +30,7 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
"""
Start a client.

:param server_address: (str) the address of the server this client should report to.
:param server_address: (str) the address of the FL server the FL client should report to.
It should be comprised of the host name and port separated by colon (e.g. "localhost:8080").
:param client: (str) the name of the client. Should be one of the enum values of florist.api.clients.common.Client.
:param data_path: (str) the path where the training data is located.
@@ -43,18 +43,16 @@ def start(server_address: str, client: str, data_path: str, redis_host: str, red
{"error": <error message>}
"""
try:
if client not in Clients.list():
return JSONResponse(
content={"error": f"Client '{client}' not supported. Supported clients: {Clients.list()}"},
status_code=400,
)
if client not in Client.list():
error_msg = f"Client '{client}' not supported. Supported clients: {Client.list()}"
return JSONResponse(content={"error": error_msg}, status_code=400)

client_uuid = str(uuid.uuid4())
metrics_reporter = RedisMetricsReporter(host=redis_host, port=redis_port, run_id=client_uuid)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

client_class = Clients.class_for_client(Clients[client])
client_class = Client.class_for_client(Client[client])
client_obj = client_class(
data_path=Path(data_path),
metrics=[],
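A minimal sketch of calling the client's `start` endpoint once the service from the README is running on port 8001. This is not part of the diff: the HTTP method, the route path, and all parameter values below are assumptions for illustration only (the route decorator is outside this hunk, so check `florist/api/client.py` for the real path and method).

```python
# Sketch only -- the method, path, and values here are assumptions, not taken from this diff.
import requests

response = requests.get(
    "http://localhost:8001/api/client/start",  # hypothetical path; see the route decorator in client.py
    params={
        "server_address": "localhost:8080",  # FL server host:port, per the docstring above
        "client": "MNIST",                   # must be one of Client.list()
        "data_path": "/tmp/mnist",           # where the training data is located
        "redis_host": "localhost",
        "redis_port": "6379",                # example Redis port
    },
)
print(response.status_code, response.json())  # {"error": ...} with status 400 for unsupported clients
```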
14 changes: 7 additions & 7 deletions florist/api/clients/common.py
@@ -7,21 +7,21 @@
from florist.api.clients.mnist import MnistClient


class Clients(Enum):
class Client(Enum):
"""Enumeration of supported clients."""

MNIST = "MNIST"

@classmethod
def class_for_client(cls, client: "Clients") -> type[BasicClient]:
def class_for_client(cls, client: "Client") -> type[BasicClient]:
"""
Return the class for a given client.

:param client: The client enumeration object.
:return: A subclass of BasicClient corresponding to the given client.
:param client: (Client) The client enumeration object.
:return: (type[BasicClient]) A subclass of BasicClient corresponding to the given client.
:raises ValueError: if the client is not supported.
"""
if client == Clients.MNIST:
if client == Client.MNIST:
return MnistClient

raise ValueError(f"Client {client.value} not supported.")
@@ -31,6 +31,6 @@ def list(cls) -> List[str]:
"""
List all the supported clients.

:return: a list of supported clients.
:return: (List[str]) a list of supported clients.
"""
return [client.value for client in Clients]
return [client.value for client in Client]
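To make the rename easier to follow, here is a short usage sketch (not part of the diff) mirroring how `florist/api/client.py` above consumes the renamed `Client` enum:

```python
# Sketch: how the renamed Client enum is used (mirrors florist/api/client.py).
from florist.api.clients.common import Client

print(Client.list())  # ['MNIST'] -- the values the "start" endpoint accepts

client_class = Client.class_for_client(Client["MNIST"])
print(client_class.__name__)  # "MnistClient"

# An unknown name fails at the enum lookup itself (KeyError), while a
# defined-but-unsupported member would raise the ValueError inside class_for_client.
```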
29 changes: 2 additions & 27 deletions florist/api/clients/mnist.py
@@ -2,7 +2,6 @@
from typing import Tuple

import torch
import torch.nn.functional as f
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.dataset import MnistDataset
from fl4health.utils.load_data import load_mnist_data
@@ -12,6 +11,8 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from florist.api.models.mnist import MnistNet


class MnistClient(BasicClient): # type: ignore
"""Implementation of the MNIST client."""
@@ -54,29 +55,3 @@ def get_criterion(self, config: Config) -> _Loss:
:return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
"""
return torch.nn.CrossEntropyLoss()


class MnistNet(nn.Module):
"""Implementation of the Mnist model."""

def __init__(self) -> None:
"""Initialize an instance of MnistNet."""
super().__init__()
self.conv1 = nn.Conv2d(1, 8, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(8, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass for the given tensor.

:param x: (torch.Tensor) the tensor to perform the forward pass on.
:return: (torch.Tensor) a result tensor after the forward pass.
"""
x = self.pool(f.relu(self.conv1(x)))
x = self.pool(f.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = f.relu(self.fc1(x))
return f.relu(self.fc2(x))
15 changes: 0 additions & 15 deletions florist/api/index.py

This file was deleted.

1 change: 1 addition & 0 deletions florist/api/models/__init__.py
@@ -0,0 +1 @@
"""Contains the models definitions."""
30 changes: 30 additions & 0 deletions florist/api/models/mnist.py
@@ -0,0 +1,30 @@
"""Definitions for the MNIST model."""
import torch
import torch.nn.functional as f
from torch import nn


class MnistNet(nn.Module):
"""Implementation of the Mnist model."""

def __init__(self) -> None:
"""Initialize an instance of MnistNet."""
super().__init__()
self.conv1 = nn.Conv2d(1, 8, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(8, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Perform a forward pass for the given tensor.

:param x: (torch.Tensor) the tensor to perform the forward pass on.
:return: (torch.Tensor) a result tensor after the forward pass.
"""
x = self.pool(f.relu(self.conv1(x)))
x = self.pool(f.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = f.relu(self.fc1(x))
return f.relu(self.fc2(x))
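As a quick sanity check for the relocated model (not part of the PR), the layer sizes line up with standard 1x28x28 MNIST inputs:

```python
# Sketch: the relocated MnistNet maps a 1x28x28 MNIST image to 10 outputs.
import torch

from florist.api.models.mnist import MnistNet

model = MnistNet()
dummy_batch = torch.randn(4, 1, 28, 28)  # batch of 4 single-channel 28x28 images
output = model(dummy_batch)
print(output.shape)  # torch.Size([4, 10]) -- 16 * 4 * 4 features after the second pool, then fc1 -> fc2
```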
56 changes: 55 additions & 1 deletion florist/api/monitoring/metrics.py
@@ -1,6 +1,7 @@
"""Classes for the instrumentation of metrics reporting from clients and servers."""
import json
from logging import DEBUG
import time
from logging import DEBUG, Logger
from typing import Any, Dict, Optional

import redis
@@ -64,3 +65,56 @@ def dump(self) -> None:
encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder)
log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}")
self.redis_connection.set(self.run_id, encoded_metrics)


def wait_for_metric(
uuid: str,
metric: str,
redis_host: str,
redis_port: str,
logger: Logger,
max_retries: int = 20,
seconds_to_sleep_between_retries: int = 1,
) -> None:
"""
Check the metrics on Redis under the given UUID and wait until the given metric appears.

If the metric is not there yet, it will retry up to `max_retries` times,
sleeping `seconds_to_sleep_between_retries` seconds between retries.

:param uuid: (str) The UUID to pull the metrics from Redis.
:param metric: (str) The metric to look for.
:param redis_host: (str) The hostname of the Redis instance the metrics are being reported to.
:param redis_port: (str) The port of the Redis instance the metrics are being reported to.
:param logger: (logging.Logger) A logger instance to write logs to.
:param max_retries: (int) The maximum number of retries. Optional, default is 20.
:param seconds_to_sleep_between_retries: (int) The amount of seconds to sleep between retries.
Optional, default is 1.
:raises Exception: If it retries `max_retries` times and the right metrics have not been found.
"""
redis_connection = redis.Redis(host=redis_host, port=redis_port)

retry = 0
while retry < max_retries:
result = redis_connection.get(uuid)

if result is not None:
assert isinstance(result, bytes)
json_result = json.loads(result.decode("utf8"))
if metric in json_result:
logger.debug(f"Metric '{metric}' has been found. Result: {json_result}")
return

logger.debug(
f"Metric '{metric}' has not been found yet, sleeping for {seconds_to_sleep_between_retries}s. "
f"Retry: {retry}. Result: {json_result}"
)
else:
logger.debug(
f"Metric '{metric}' has not been found yet, sleeping for {seconds_to_sleep_between_retries}s. "
f"Retry: {retry}. Result is None."
)
time.sleep(seconds_to_sleep_between_retries)
retry += 1

raise Exception(f"Metric '{metric}' not been found after {max_retries} retries.")
1 change: 1 addition & 0 deletions florist/api/routes/__init__.py
@@ -0,0 +1 @@
"""FastAPI routes."""
1 change: 1 addition & 0 deletions florist/api/routes/server/__init__.py
@@ -0,0 +1 @@
"""FastAPI server routes."""