Skip to content

Commit

Permalink
CR by John
Browse files Browse the repository at this point in the history
  • Loading branch information
lotif committed Mar 13, 2024
1 parent 5dc9b2e commit de94897
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 33 deletions.
34 changes: 1 addition & 33 deletions florist/api/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
"""FLorist client FastAPI endpoints."""
import uuid
from enum import Enum
from pathlib import Path
from typing import List

import torch
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fl4health.clients.basic_client import BasicClient

from florist.api.clients.mnist import MnistClient
from florist.api.clients.common import Clients
from florist.api.launchers.local import launch_client
from florist.api.monitoring.metrics import RedisMetricsReporter

Expand All @@ -19,35 +16,6 @@
app = FastAPI()


class Clients(Enum):
"""Enumeration of supported clients."""

MNIST = "MNIST"

@classmethod
def class_for_client(cls, client: "Clients") -> type[BasicClient]:
"""
Return the class for a given client.
:param client: The client enumeration object.
:return: A subclass of BasicClient corresponding to the given client.
:raises ValueError: if the client is not supported.
"""
if client == Clients.MNIST:
return MnistClient

raise ValueError(f"Client {client.value} not supported.")

@classmethod
def list(cls) -> List[str]:
"""
List all the supported clients.
:return: a list of supported clients.
"""
return [client.value for client in Clients]


@app.get("/api/client/connect")
def connect() -> JSONResponse:
"""
Expand Down
36 changes: 36 additions & 0 deletions florist/api/clients/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Common functions and definitions for clients."""
from enum import Enum
from typing import List

from fl4health.clients.basic_client import BasicClient

from florist.api.clients.mnist import MnistClient


class Clients(Enum):
"""Enumeration of supported clients."""

MNIST = "MNIST"

@classmethod
def class_for_client(cls, client: "Clients") -> type[BasicClient]:
"""
Return the class for a given client.
:param client: The client enumeration object.
:return: A subclass of BasicClient corresponding to the given client.
:raises ValueError: if the client is not supported.
"""
if client == Clients.MNIST:
return MnistClient

raise ValueError(f"Client {client.value} not supported.")

@classmethod
def list(cls) -> List[str]:
"""
List all the supported clients.
:return: a list of supported clients.
"""
return [client.value for client in Clients]

0 comments on commit de94897

Please sign in to comment.