Make the client's "start" endpoint #10
Changes from all commits
```
@@ -168,3 +168,4 @@ next-env.d.ts
/florist/tsconfig.json

/metrics/
/logs/
```
New file:

```python
"""Implementations for the clients."""
```
New file:

```python
"""Common functions and definitions for clients."""
from enum import Enum
from typing import List

from fl4health.clients.basic_client import BasicClient

from florist.api.clients.mnist import MnistClient


class Clients(Enum):
    """Enumeration of supported clients."""

    MNIST = "MNIST"

    @classmethod
    def class_for_client(cls, client: "Clients") -> type[BasicClient]:
        """
        Return the class for a given client.

        :param client: The client enumeration object.
        :return: A subclass of BasicClient corresponding to the given client.
        :raises ValueError: if the client is not supported.
        """
        if client == Clients.MNIST:
            return MnistClient

        raise ValueError(f"Client {client.value} not supported.")

    @classmethod
    def list(cls) -> List[str]:
        """
        List all the supported clients.

        :return: a list of supported clients.
        """
        return [client.value for client in Clients]
```
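This enum is presumably what will let the upcoming "start" endpoint map a requested client name to a class. A small usage sketch follows; the import path is an assumption, since file names are not visible in this diff view.

```python
# Usage sketch for the Clients enum above (assumed module path).
from florist.api.clients.common import Clients

print(Clients.list())  # ["MNIST"] -- e.g. to advertise the supported clients

# Resolve a requested client name into its BasicClient subclass.
client_class = Clients.class_for_client(Clients.MNIST)
print(client_class.__name__)  # "MnistClient"

# Any other value raises ValueError("Client <value> not supported.")
```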
New file:

```python
"""Implementation of the MNIST client and model."""
from typing import Tuple

import torch
import torch.nn.functional as f
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.dataset import MnistDataset
from fl4health.utils.load_data import load_mnist_data
from flwr.common.typing import Config
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader


class MnistClient(BasicClient):  # type: ignore
    """Implementation of the MNIST client."""

    def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]:
        """
        Return the data loader for MNIST data.

        :param config: (Config) the Config object for this client.
        :return: (Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]) a tuple with the train data loader
            and validation data loader respectively.
        """
        train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size=config["batch_size"])
        return train_loader, val_loader

    def get_model(self, config: Config) -> nn.Module:
        """
        Return the model for MNIST data.

        :param config: (Config) the Config object for this client.
        :return: (torch.nn.Module) An instance of florist.api.clients.mnist.MnistNet.
        """
        return MnistNet()

    def get_optimizer(self, config: Config) -> Optimizer:
        """
        Return the optimizer for MNIST data.

        :param config: (Config) the Config object for this client.
        :return: (torch.optim.Optimizer) An instance of torch.optim.SGD with learning
            rate of 0.001 and momentum of 0.9.
        """
        return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

    def get_criterion(self, config: Config) -> _Loss:
        """
        Return the loss for MNIST data.

        :param config: (Config) the Config object for this client.
        :return: (torch.nn.modules.loss._Loss) an instance of torch.nn.CrossEntropyLoss.
        """
        return torch.nn.CrossEntropyLoss()


class MnistNet(nn.Module):
    """Implementation of the Mnist model."""

    def __init__(self) -> None:
        """Initialize an instance of MnistNet."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform a forward pass for the given tensor.

        :param x: (torch.Tensor) the tensor to perform the forward pass on.
        :return: (torch.Tensor) a result tensor after the forward pass.
        """
        x = self.pool(f.relu(self.conv1(x)))
        x = self.pool(f.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = f.relu(self.fc1(x))
        return f.relu(self.fc2(x))
```
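As a quick sanity check of the layer sizes above, here is a sketch of a forward pass on a dummy MNIST-shaped batch; the import path is an assumption.

```python
import torch

from florist.api.clients.mnist import MnistNet  # assumed module path

model = MnistNet()

# A fake batch of 4 single-channel 28x28 images, the shape the MNIST loaders produce.
dummy_batch = torch.randn(4, 1, 28, 28)

# 28x28 -> conv1(5x5) -> 24x24 -> pool -> 12x12 -> conv2(5x5) -> 8x8 -> pool -> 4x4,
# which is why fc1 expects 16 * 4 * 4 = 256 input features.
logits = model(dummy_batch)
print(logits.shape)  # torch.Size([4, 10]), one score per digit class
```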
New file:

```python
"""Launchers for servers and clients."""
```
New file:

```python
"""Classes and functions for monitoring of clients and servers' execution."""
```
Changes to the Redis metrics reporter:

```diff
@@ -9,23 +9,25 @@


 class RedisMetricsReporter(MetricsReporter):  # type: ignore
-    """Save the metrics to a Redis instance while it records them."""
+    """
+    Save the metrics to a Redis instance while it records them.
+
+    Lazily instantiates a Redis connection when the first metrics are recorded.
+    """

-    def __init__(
-        self,
-        redis_connection: redis.client.Redis,
-        run_id: Optional[str] = None,
-    ):
+    def __init__(self, host: str, port: str, run_id: Optional[str] = None):
         """
         Init an instance of RedisMetricsReporter.

-        :param redis_connection: (redis.client.Redis) the connection object to a Redis. Should be the output
-            of redis.Redis(host=host, port=port)
+        :param host: (str) The host address where the Redis instance is running.
+        :param port: (str) The port where the Redis instance is running on the host.
         :param run_id: (Optional[str]) the identifier for the run which these metrics are from.
             It will be used as the name of the object in Redis. Optional, default is a random UUID.
         """
         super().__init__(run_id)
-        self.redis_connection = redis_connection
+        self.host = host
+        self.port = port
+        self.redis_connection: Optional[redis.Redis] = None

     def add_to_metrics(self, data: Dict[str, Any]) -> None:
         """
@@ -51,7 +53,14 @@ def add_to_metrics_at_round(self, fl_round: int, data: Dict[str, Any]) -> None:
         self.dump()

     def dump(self) -> None:
-        """Dump the current metrics to Redis under the run_id name."""
+        """
+        Dump the current metrics to Redis under the run_id name.
+
+        Will instantiate a Redis connection if it's the first time it runs for this instance.
+        """
+        if self.redis_connection is None:
+            self.redis_connection = redis.Redis(host=self.host, port=self.port)
+
         encoded_metrics = json.dumps(self.metrics, cls=DateTimeEncoder)
         log(DEBUG, f"Dumping metrics to redis at key '{self.run_id}': {encoded_metrics}")
         self.redis_connection.set(self.run_id, encoded_metrics)
```

Review comment (on lines 55 to 56 of the new `dump` docstring): "This is beyond the scope of this PR, but I am curious if you think we should explore how to dump metrics to Redis at more frequent intervals. If we only do so at the end, then in the case of a crash we lose the metrics. Do you think this is something worthwhile to explore, or not really important enough as of now to be thinking about?"

Reply: "I did it in this class, I added a call to […]"
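For context, here is a hedged sketch of how the reporter is used after this change. The module path below and a Redis server reachable at localhost:6379 are assumptions; neither is shown in this diff view.

```python
import json

import redis

from florist.api.monitoring.metrics import RedisMetricsReporter  # assumed module path

reporter = RedisMetricsReporter(host="localhost", port="6379", run_id="test-run")

# Record some metrics; the Redis connection itself is only created lazily,
# inside dump(), as shown in the hunk above.
reporter.add_to_metrics({"type": "client"})
reporter.add_to_metrics_at_round(1, {"train_loss": 0.42})  # add_to_metrics_at_round ends with self.dump()

# The full metrics dict is stored as JSON under the run_id key.
raw = redis.Redis(host="localhost", port=6379).get("test-run")
print(json.loads(raw))
```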
This file was deleted.
Updated imports:

```diff
@@ -8,9 +8,9 @@
 import torch
 from fl4health.server.base_server import FlServer

-from florist.api.launchers.launch import launch
-from florist.tests.utils.api.fl4health_utils import MnistClient, get_server_fedavg
-from florist.tests.utils.api.models import MnistNet
+from florist.api.launchers.local import launch
+from florist.api.clients.mnist import MnistClient, MnistNet
+from florist.tests.utils.api.fl4health_utils import get_server_fedavg


 def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
```

Review comment (on the `florist.api.launchers.local` import): "Good call, I like the distinction between local and 'distributed' launchers that we will build out down the line."
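The `fit_config` signature visible in the hunk context is what feeds values such as `batch_size` into `MnistClient.get_data_loaders`. Its body is not part of this diff; a typical implementation under that signature might look like the sketch below.

```python
from typing import Dict


def fit_config(batch_size: int, local_epochs: int, current_server_round: int) -> Dict[str, int]:
    """Sketch only -- the real body is not shown in this diff."""
    # Flower passes the returned dict to each client as its Config for the round,
    # which is where config["batch_size"] in get_data_loaders comes from.
    return {
        "batch_size": batch_size,
        "local_epochs": local_epochs,
        "current_server_round": current_server_round,
    }
```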
Review comment (on the `# type: ignore` added to `MnistClient`): "Curious why we have to add the `type: ignore` when we previously did not have to have one. Is it because you made the return type of `get_data_loaders` more specific?"

Reply: "It's because it was in the testing folder before. The static code checking for tests is mostly disabled because of mocking and other code practices that are OK to do in testing but not in code that runs in prod."