Skip to content

Commit

Permalink
Pick a database to store training results (#16)
Browse files Browse the repository at this point in the history
* WIP Making the start server api, needs tests

* Modifying integration test

* Fixing test, adding code comments to integration test

* Happy path test

* Finished tests for server

* Finished tests for client info

* Finished tests for metrics

* Adding inits

* Adding batch_size and local_epochs to server params

* Fixing additional test

* WIP adding mongodb and create job route

* WIP adding a test, need more setup

* Small change

* CR by John

* Skipping one more security vulnerability with pillow

* Moving test util classes to the right place, implementing fixture

* Small code cleanup

* Small code cleanup [2]

* Better startup and shutdown

* CR by John

* Fixing pip-audit error
  • Loading branch information
lotif authored Apr 12, 2024
1 parent e694fe7 commit 6281fd7
Show file tree
Hide file tree
Showing 16 changed files with 410 additions and 47 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/integration_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ jobs:
uses: supercharge/[email protected]
with:
redis-version: 7.2.4
- name: Setup MongoDB
uses: supercharge/[email protected]
with:
mongodb-version: 7.0.8
- name: Install dependencies and check code
run: |
poetry env use '3.9'
Expand Down
2 changes: 2 additions & 0 deletions .pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
asyncio_mode = auto
10 changes: 7 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ To run the unit tests, simply execute:
pytest florist/tests/unit
```

To run the integration tests, first make sure you have a Redis server running on your
local machine on port 6379, then execute:
To run the integration tests, first make sure you:
- Have a Redis server running on your local machine on port 6379 by following [these instructions](README.md#start-servers-redis-instance).
- Have a MongoDB server running on your local machine on port 27017 by following [these instructions](README.md#start-mongodbs-instance).

Then execute:
```shell
pytest florist/tests/integration
```
Expand All @@ -73,7 +76,8 @@ For code style, we recommend the [PEP 8 style guide](https://peps.python.org/pep
For docstrings we use [numpy format](https://numpydoc.readthedocs.io/en/latest/format.html).

We use [ruff](https://docs.astral.sh/ruff/) for code formatting and static code
analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ruff/faq/#how-does-ruff-compare-to-flake8). The pre-commit hooks show errors which you need to fix before submitting a PR.
analysis. Ruff checks various rules including [flake8](https://docs.astral.sh/ruff/faq/#how-does-ruff-compare-to-flake8). The pre-commit hooks
show errors which you need to fix before submitting a PR.

Last but not the least, we use type hints in our code which is then checked using
[mypy](https://mypy.readthedocs.io/en/stable/).
Expand Down
25 changes: 24 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ yarn

### Pulling Redis' Docker

Redis is used to fetch the metrics reported by servers and clients during their runs.
[Redis](https://redis.io/) is used to fetch the metrics reported by servers and clients during their runs.


If you don't have Docker installed, follow [these instructions](https://docs.docker.com/desktop/)
Expand All @@ -47,8 +47,31 @@ to install it. Then, pull [Redis' official docker image](https://hub.docker.com/
docker pull redis:7.2.4
```

### Pulling MongoDB's Docker

[MongoDB](https://www.mongodb.com) is used to store information about the training jobs.

If you don't have Docker installed, follow [these instructions](https://docs.docker.com/desktop/)
to install it. Then, pull [MongoDB' official docker image](https://hub.docker.com/_/mongo)
(we currently use version 7.0.8):
```shell
docker pull mongo:7.0.8
```

## Running the server

### Start MongoDB's instance

If it's your first time running it, create a container and run it with the command below:
```shell
docker run --name mongodb-florist -d -p 27017:27017 mongo:7.0.8
```

From the second time on, you can just start it:
```shell
docker start mongodb-florist
```

### Start server's Redis instance

If it's your first time running it, create a container and run it with the command below:
Expand Down
2 changes: 1 addition & 1 deletion florist/api/clients/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from florist.api.models.mnist import MnistNet


class MnistClient(BasicClient): # type: ignore
class MnistClient(BasicClient): # type: ignore[misc]
"""Implementation of the MNIST client."""

def get_data_loaders(self, config: Config) -> Tuple[DataLoader[MnistDataset], DataLoader[MnistDataset]]:
Expand Down
1 change: 1 addition & 0 deletions florist/api/db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Classes and definitions for the database."""
32 changes: 32 additions & 0 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Definitions for the MongoDB database entities."""
import uuid
from typing import Annotated, Optional

from pydantic import BaseModel, Field

from florist.api.servers.common import Model


JOB_DATABASE_NAME = "job"


class Job(BaseModel):
"""Define the Job DB entity."""

id: str = Field(default_factory=uuid.uuid4, alias="_id")
model: Optional[Annotated[Model, Field(...)]]
redis_host: Optional[Annotated[str, Field(...)]]
redis_port: Optional[Annotated[str, Field(...)]]

class Config:
"""MongoDB config for the Job DB entity."""

allow_population_by_field_name = True
schema_extra = {
"example": {
"_id": "066de609-b04a-4b30-b46c-32537c7f1f6e",
"model": "MNIST",
"redis_host": "locahost",
"redis_port": "6879",
},
}
36 changes: 36 additions & 0 deletions florist/api/routes/server/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""FastAPI routes for the job."""
from typing import Any, Dict

from fastapi import APIRouter, Body, Request, status
from fastapi.encoders import jsonable_encoder

from florist.api.db.entities import JOB_DATABASE_NAME, Job


router = APIRouter()


@router.post(
path="/",
response_description="Create a new job",
status_code=status.HTTP_201_CREATED,
response_model=Job,
)
async def new_job(request: Request, job: Job = Body(...)) -> Dict[str, Any]: # noqa: B008
"""
Create a new training job.
If calling from the REST API, it will receive the job attributes as the Request Body in raw/JSON format.
See `florist.api.db.entities.Job` to check the list of attributes and their requirements.
:param request: (fastapi.Request) the FastAPI request object.
:param job: (Job) The Job instance to be saved in the database.
:return: (Dict[str, Any]) A dictionary with the attributes of the new Job instance as saved in the database.
"""
json_job = jsonable_encoder(job)
result = await request.app.database[JOB_DATABASE_NAME].insert_one(json_job)

created_job = await request.app.database[JOB_DATABASE_NAME].find_one({"_id": result.inserted_id})
assert isinstance(created_job, dict)

return created_job
25 changes: 24 additions & 1 deletion florist/api/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,31 @@
"""FLorist server FastAPI endpoints and routes."""
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator

from fastapi import FastAPI
from motor.motor_asyncio import AsyncIOMotorClient

from florist.api.routes.server.job import router as job_router
from florist.api.routes.server.training import router as training_router


app = FastAPI()
MONGODB_URI = "mongodb://localhost:27017/"
DATABASE_NAME = "florist-server"


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[Any, Any]:
"""Set up function for app startup and shutdown."""
# Set up mongodb
app.db_client = AsyncIOMotorClient(MONGODB_URI) # type: ignore[attr-defined]
app.database = app.db_client[DATABASE_NAME] # type: ignore[attr-defined]

yield

# Shut down mongodb
app.db_client.close() # type: ignore[attr-defined]


app = FastAPI(lifespan=lifespan)
app.include_router(training_router, tags=["training"], prefix="/api/server/training")
app.include_router(job_router, tags=["job"], prefix="/api/server/job")
Empty file.
Empty file.
Empty file.
29 changes: 29 additions & 0 deletions florist/tests/integration/api/routes/server/test_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from unittest.mock import ANY

from florist.api.db.entities import Job
from florist.api.routes.server.job import new_job
from florist.api.servers.common import Model
from florist.tests.integration.api.utils import mock_request


async def test_new_job(mock_request) -> None:
test_empty_job = Job()
result = await new_job(mock_request, test_empty_job)

assert result == {
"_id": ANY,
"model": None,
"redis_host": None,
"redis_port": None,
}
assert isinstance(result["_id"], str)

test_job = Job(id="test-id", model=Model.MNIST, redis_host="test-redis-host", redis_port="test-redis-port")
result = await new_job(mock_request, test_job)

assert result == {
"_id": test_job.id,
"model": test_job.model.value,
"redis_host": test_job.redis_host,
"redis_port": test_job.redis_port,
}
41 changes: 41 additions & 0 deletions florist/tests/integration/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import contextlib
import pytest
import time
import threading
import uvicorn

from motor.motor_asyncio import AsyncIOMotorClient
from starlette.requests import Request

from florist.api.server import MONGODB_URI


class TestUvicornServer(uvicorn.Server):
def install_signal_handlers(self):
Expand All @@ -19,3 +25,38 @@ def run_in_thread(self):
finally:
self.should_exit = True
thread.join()


class MockApp:
def __init__(self, database_name: str):
self.db_client = AsyncIOMotorClient(MONGODB_URI)
self.database = self.db_client[database_name]


class MockRequest(Request):
def __init__(self, app: MockApp):
super().__init__({"type": "http"})
self._app = app

@property
def app(self):
return self._app

@app.setter
def app(self, value):
self._app = value


TEST_DATABASE_NAME = "test-database"


@pytest.fixture
async def mock_request() -> MockRequest:
print(f"Creating test detabase '{TEST_DATABASE_NAME}'")
app = MockApp(TEST_DATABASE_NAME)
request = MockRequest(app)

yield request

print(f"Deleting test detabase '{TEST_DATABASE_NAME}'")
await app.db_client.drop_database(TEST_DATABASE_NAME)
Loading

0 comments on commit 6281fd7

Please sign in to comment.