From de94897f7c94acfaf3acc58429f04ce93ce77a02 Mon Sep 17 00:00:00 2001 From: Marcelo Lotif Date: Wed, 13 Mar 2024 13:38:55 -0400 Subject: [PATCH] CR by John --- florist/api/client.py | 34 +-------------------------------- florist/api/clients/common.py | 36 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 33 deletions(-) create mode 100644 florist/api/clients/common.py diff --git a/florist/api/client.py b/florist/api/client.py index 714b7dc..8a145fe 100644 --- a/florist/api/client.py +++ b/florist/api/client.py @@ -1,15 +1,12 @@ """FLorist client FastAPI endpoints.""" import uuid -from enum import Enum from pathlib import Path -from typing import List import torch from fastapi import FastAPI from fastapi.responses import JSONResponse -from fl4health.clients.basic_client import BasicClient -from florist.api.clients.mnist import MnistClient +from florist.api.clients.common import Clients from florist.api.launchers.local import launch_client from florist.api.monitoring.metrics import RedisMetricsReporter @@ -19,35 +16,6 @@ app = FastAPI() -class Clients(Enum): - """Enumeration of supported clients.""" - - MNIST = "MNIST" - - @classmethod - def class_for_client(cls, client: "Clients") -> type[BasicClient]: - """ - Return the class for a given client. - - :param client: The client enumeration object. - :return: A subclass of BasicClient corresponding to the given client. - :raises ValueError: if the client is not supported. - """ - if client == Clients.MNIST: - return MnistClient - - raise ValueError(f"Client {client.value} not supported.") - - @classmethod - def list(cls) -> List[str]: - """ - List all the supported clients. - - :return: a list of supported clients. - """ - return [client.value for client in Clients] - - @app.get("/api/client/connect") def connect() -> JSONResponse: """ diff --git a/florist/api/clients/common.py b/florist/api/clients/common.py new file mode 100644 index 0000000..121f06a --- /dev/null +++ b/florist/api/clients/common.py @@ -0,0 +1,36 @@ +"""Common functions and definitions for clients.""" +from enum import Enum +from typing import List + +from fl4health.clients.basic_client import BasicClient + +from florist.api.clients.mnist import MnistClient + + +class Clients(Enum): + """Enumeration of supported clients.""" + + MNIST = "MNIST" + + @classmethod + def class_for_client(cls, client: "Clients") -> type[BasicClient]: + """ + Return the class for a given client. + + :param client: The client enumeration object. + :return: A subclass of BasicClient corresponding to the given client. + :raises ValueError: if the client is not supported. + """ + if client == Clients.MNIST: + return MnistClient + + raise ValueError(f"Client {client.value} not supported.") + + @classmethod + def list(cls) -> List[str]: + """ + List all the supported clients. + + :return: a list of supported clients. + """ + return [client.value for client in Clients]