-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pick a database to store training results (#16)
* WIP Making the start server api, needs tests * Modifying integration test * Fixing test, adding code comments to integration test * Happy path test * Finished tests for server * Finished tests for client info * Finished tests for metrics * Adding inits * Adding batch_size and local_epochs to server params * Fixing additional test * WIP adding mongodb and create job route * WIP adding a test, need more setup * Small change * CR by John * Skipping one more security vulnerability with pillow * Moving test util classes to the right place, implementing fixture * Small code cleanup * Small code cleanup [2] * Better startup and shutdown * CR by John * Fixing pip-audit error
- Loading branch information
Showing
16 changed files
with
410 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,10 @@ jobs: | |
uses: supercharge/[email protected] | ||
with: | ||
redis-version: 7.2.4 | ||
- name: Setup MongoDB | ||
uses: supercharge/[email protected] | ||
with: | ||
mongodb-version: 7.0.8 | ||
- name: Install dependencies and check code | ||
run: | | ||
poetry env use '3.9' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[pytest] | ||
asyncio_mode = auto |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Classes and definitions for the database.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""Definitions for the MongoDB database entities.""" | ||
import uuid | ||
from typing import Annotated, Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
from florist.api.servers.common import Model | ||
|
||
|
||
JOB_DATABASE_NAME = "job" | ||
|
||
|
||
class Job(BaseModel): | ||
"""Define the Job DB entity.""" | ||
|
||
id: str = Field(default_factory=uuid.uuid4, alias="_id") | ||
model: Optional[Annotated[Model, Field(...)]] | ||
redis_host: Optional[Annotated[str, Field(...)]] | ||
redis_port: Optional[Annotated[str, Field(...)]] | ||
|
||
class Config: | ||
"""MongoDB config for the Job DB entity.""" | ||
|
||
allow_population_by_field_name = True | ||
schema_extra = { | ||
"example": { | ||
"_id": "066de609-b04a-4b30-b46c-32537c7f1f6e", | ||
"model": "MNIST", | ||
"redis_host": "locahost", | ||
"redis_port": "6879", | ||
}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
"""FastAPI routes for the job.""" | ||
from typing import Any, Dict | ||
|
||
from fastapi import APIRouter, Body, Request, status | ||
from fastapi.encoders import jsonable_encoder | ||
|
||
from florist.api.db.entities import JOB_DATABASE_NAME, Job | ||
|
||
|
||
router = APIRouter() | ||
|
||
|
||
@router.post( | ||
path="/", | ||
response_description="Create a new job", | ||
status_code=status.HTTP_201_CREATED, | ||
response_model=Job, | ||
) | ||
async def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: B008 | ||
""" | ||
Create a new training job. | ||
If calling from the REST API, it will receive the job attributes as the Request Body in raw/JSON format. | ||
See `florist.api.db.entities.Job` to check the list of attributes and their requirements. | ||
:param request: (fastapi.Request) the FastAPI request object. | ||
:param job: (Job) The Job instance to be saved in the database. | ||
:return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database. | ||
""" | ||
json_job = jsonable_encoder(job) | ||
result = await request.app.database[JOB_DATABASE_NAME].insert_one(json_job) | ||
|
||
created_job = await request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id}) | ||
assert isinstance(created_job, dict) | ||
|
||
return created_job |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,31 @@ | ||
"""FLorist server FastAPI endpoints and routes.""" | ||
from contextlib import asynccontextmanager | ||
from typing import Any, AsyncGenerator | ||
|
||
from fastapi import FastAPI | ||
from motor.motor_asyncio import AsyncIOMotorClient | ||
|
||
from florist.api.routes.server.job import router as job_router | ||
from florist.api.routes.server.training import router as training_router | ||
|
||
|
||
app = FastAPI() | ||
MONGODB_URI = "mongodb://localhost:27017/" | ||
DATABASE_NAME = "florist-server" | ||
|
||
|
||
@asynccontextmanager | ||
async def lifespan(app: FastAPI) -> AsyncGenerator[Any, Any]: | ||
"""Set up function for app startup and shutdown.""" | ||
# Set up mongodb | ||
app.db_client = AsyncIOMotorClient(MONGODB_URI) # type: ignore[attr-defined] | ||
app.database = app.db_client[DATABASE_NAME] # type: ignore[attr-defined] | ||
|
||
yield | ||
|
||
# Shut down mongodb | ||
app.db_client.close() # type: ignore[attr-defined] | ||
|
||
|
||
app = FastAPI(lifespan=lifespan) | ||
app.include_router(training_router, tags=["training"], prefix="/api/server/training") | ||
app.include_router(job_router, tags=["job"], prefix="/api/server/job") |
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from unittest.mock import ANY | ||
|
||
from florist.api.db.entities import Job | ||
from florist.api.routes.server.job import new_job | ||
from florist.api.servers.common import Model | ||
from florist.tests.integration.api.utils import mock_request | ||
|
||
|
||
async def test_new_job(mock_request) -> None: | ||
test_empty_job = Job() | ||
result = await new_job(mock_request, test_empty_job) | ||
|
||
assert result == { | ||
"_id": ANY, | ||
"model": None, | ||
"redis_host": None, | ||
"redis_port": None, | ||
} | ||
assert isinstance(result["_id"], str) | ||
|
||
test_job = Job(id="test-id", model=Model.MNIST, redis_host="test-redis-host", redis_port="test-redis-port") | ||
result = await new_job(mock_request, test_job) | ||
|
||
assert result == { | ||
"_id": test_job.id, | ||
"model": test_job.model.value, | ||
"redis_host": test_job.redis_host, | ||
"redis_port": test_job.redis_port, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.