Skip to content

Commit

Permalink
Refactored server info into config parser
Browse files Browse the repository at this point in the history
  • Loading branch information
lotif committed May 10, 2024
1 parent ee47b17 commit f064952
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 163 deletions.
20 changes: 4 additions & 16 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Definitions for the MongoDB database entities."""

import json
import uuid
from enum import Enum
from typing import Annotated, List, Optional
Expand All @@ -9,6 +8,7 @@

from florist.api.clients.common import Client
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser


JOB_COLLECTION_NAME = "job"
Expand Down Expand Up @@ -65,24 +65,12 @@ class Job(BaseModel):
status: JobStatus = Field(default=JobStatus.NOT_STARTED)
model: Optional[Annotated[Model, Field(...)]]
server_address: Optional[Annotated[str, Field(...)]]
server_info: Optional[Annotated[str, Field(...)]]
server_config: Optional[Annotated[str, Field(...)]]
config_parser: Optional[Annotated[ConfigParser, Field(...)]]
redis_host: Optional[Annotated[str, Field(...)]]
redis_port: Optional[Annotated[str, Field(...)]]
clients_info: Optional[Annotated[List[ClientInfo], Field(...)]]

@classmethod
def is_valid_server_info(cls, server_info: Optional[str]) -> bool:
"""
Validate if server info is a json string.
:param server_info: (str) the json string with the server info.
:return: True if server_info is None or a valid JSON string, False otherwise.
:raises: (json.JSONDecodeError) if there is an error decoding the server info into json
"""
if server_info is not None:
json.loads(server_info)
return True

class Config:
"""MongoDB config for the Job DB entity."""

Expand All @@ -93,7 +81,7 @@ class Config:
"status": "NOT_STARTED",
"model": "MNIST",
"server_address": "localhost:8080",
"server_info": '{"n_server_rounds": 3, "batch_size": 8}',
"server_config": '{"n_server_rounds": 3, "batch_size": 8}',
"redis_host": "localhost",
"redis_port": "6879",
"clients_info": [
Expand Down
49 changes: 0 additions & 49 deletions florist/api/models/common.py

This file was deleted.

15 changes: 1 addition & 14 deletions florist/api/models/mnist.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
"""Definitions for the MNIST model."""

from typing import List

import torch
import torch.nn.functional as f
from torch import nn

from florist.api.models.common import AbstractModel


class MnistNet(AbstractModel):
class MnistNet(nn.Module):
"""Implementation of the Mnist model."""

def __init__(self) -> None:
Expand All @@ -33,12 +29,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.view(-1, 16 * 4 * 4)
x = f.relu(self.fc1(x))
return f.relu(self.fc2(x))

@classmethod
def mandatory_server_info_fields(cls) -> List[str]:
"""
Define the list of mandatory server info fields for training the MNIST model.
:return: a list of mandatory fields for this model, namely `["n_server_rounds", "batch_size", "local_epochs"]`.
"""
return ["n_server_rounds", "batch_size", "local_epochs"]
12 changes: 1 addition & 11 deletions florist/api/routes/server/job.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""FastAPI routes for the job."""

from json import JSONDecodeError
from typing import Any, Dict, List

from fastapi import APIRouter, Body, HTTPException, Request, status
from fastapi import APIRouter, Body, Request, status
from fastapi.encoders import jsonable_encoder

from florist.api.db.entities import JOB_COLLECTION_NAME, MAX_RECORDS_TO_FETCH, Job, JobStatus
Expand All @@ -30,15 +29,6 @@ async def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: #
:return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database.
:raises: (HTTPException) status 400 if job.server_info is not None and cannot be parsed into JSON.
"""
try:
is_valid = Job.is_valid_server_info(job.server_info)
if not is_valid:
msg = f"job.server_info is not valid. job.server_info: {job.server_info}."
raise HTTPException(status_code=400, detail=msg)
except JSONDecodeError as e:
msg = f"job.server_info could not be parsed into JSON. job.server_info: {job.server_info}. Error: {e}"
raise HTTPException(status_code=400, detail=msg) from e

json_job = jsonable_encoder(job)
result = await request.app.database[JOB_COLLECTION_NAME].insert_one(json_job)

Expand Down
15 changes: 10 additions & 5 deletions florist/api/routes/server/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from florist.api.db.entities import JOB_COLLECTION_NAME, Job
from florist.api.monitoring.metrics import wait_for_metric
from florist.api.servers.common import Model
from florist.api.servers.config_parsers import ConfigParser
from florist.api.servers.launch import launch_local_server


Expand Down Expand Up @@ -43,19 +44,23 @@ async def start(job_id: str, request: Request) -> JSONResponse:
result = await job_collection.find_one({"_id": job_id})
job = Job(**result)

if job.config_parser is None:
job.config_parser = ConfigParser.BASIC

assert job.model is not None, "Missing Job information: model"
assert job.server_info is not None, "Missing Job information: server_info"
assert job.server_config is not None, "Missing Job information: server_config"
assert job.clients_info is not None and len(job.clients_info) > 0, "Missing Job information: clients_info"
assert job.server_address is not None, "Missing Job information: server_address"
assert job.redis_host is not None, "Missing Job information: redis_host"
assert job.redis_port is not None, "Missing Job information: redis_port"

try:
assert Job.is_valid_server_info(job.server_info), "server_info is not valid"
config_parser = ConfigParser.class_for_parser(job.config_parser)
server_config = config_parser.parse(job.server_config)
except JSONDecodeError as err:
raise AssertionError("server_info is not valid") from err
raise AssertionError("server_config is not a valid json string.") from err

model_class = Model.class_for_model(job.model)
server_info = model_class.parse_server_info(job.server_info)

# Start the server
server_uuid, _ = launch_local_server(
Expand All @@ -64,7 +69,7 @@ async def start(job_id: str, request: Request) -> JSONResponse:
server_address=job.server_address,
redis_host=job.redis_host,
redis_port=job.redis_port,
**server_info,
**server_config,
)
wait_for_metric(server_uuid, "fit_start", job.redis_host, job.redis_port, logger=LOGGER)

Expand Down
5 changes: 3 additions & 2 deletions florist/api/servers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from enum import Enum
from typing import List

from florist.api.models.common import AbstractModel
from torch import nn

from florist.api.models.mnist import MnistNet


Expand All @@ -13,7 +14,7 @@ class Model(Enum):
MNIST = "MNIST"

@classmethod
def class_for_model(cls, model: "Model") -> type[AbstractModel]:
def class_for_model(cls, model: "Model") -> type[nn.Module]:
"""
Return the class for a given model.
Expand Down
72 changes: 72 additions & 0 deletions florist/api/servers/config_parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Parsers for FL server configurations."""

import json
from enum import Enum
from typing import Any, Dict, List


class BasicConfigParser:
"""Parser for basic server configurations."""

@classmethod
def mandatory_fields(cls) -> List[str]:
"""
Define the mandatory fields for basic configuration, namely `n_server_rounds`, `batch_size` and `local_epochs`.
:return: (List[str]) the list of fields for basic server configuration.
"""
return ["n_server_rounds", "batch_size", "local_epochs"]

@classmethod
def parse(cls, config_json_str: str) -> Dict[str, Any]:
"""
Parse a configuration JSON string into a dictionary.
:param config_json_str: (str) the configuration JSON string
:return: (Dict[str, Any]) The configuration JSON string parsed as a dictionary.
"""
config = json.loads(config_json_str)
assert isinstance(config, dict)

mandatory_fields = cls.mandatory_fields()

for mandatory_field in mandatory_fields:
if mandatory_field not in config:
raise IncompleteConfigError(f"Server config does not contain '{mandatory_field}'")

return config


class ConfigParser(Enum):
"""Enum to define the types of server configuration parsers."""

BASIC = "BASIC"

@classmethod
def class_for_parser(cls, config_parser: "ConfigParser") -> type[BasicConfigParser]:
"""
Return the class for a given config parser.
:param config_parser: (ConfigParser) The config parser enumeration instance.
:return: (type[BasicConfigParser]) A subclass of BasicConfigParser corresponding to the given config parser.
:raises ValueError: if the config_parser is not supported.
"""
if config_parser == ConfigParser.BASIC:
return BasicConfigParser

raise ValueError(f"Config parser {config_parser.value} not supported.")

@classmethod
def list(cls) -> List[str]:
"""
List all the supported config parsers.
:return: (List[str]) a list of supported config parsers.
"""
return [config_parser.value for config_parser in ConfigParser]


class IncompleteConfigError(Exception):
"""Defines errors in server info strings that have incomplete information."""

pass
Loading

0 comments on commit f064952

Please sign in to comment.