Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start training button #67

Merged
merged 25 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9cc3d7f
Add Start button to Not Started Jobs
jewelltaylor Jun 18, 2024
61bbeb2
Add skeleton of handle click function
jewelltaylor Jun 19, 2024
75181fa
Add change status endpoint
jewelltaylor Jun 19, 2024
6185377
Add tests for change job status endpoint
jewelltaylor Jun 19, 2024
150cfb0
Add back accidentally deleted test file
jewelltaylor Jun 19, 2024
dd491e1
Merge branch 'change-status' into start-training-button
jewelltaylor Jun 19, 2024
cc7884b
Update TableRow component to include handle click function
jewelltaylor Jun 19, 2024
eba6b92
Small revisions trying to figure out reload after data is added
jewelltaylor Jun 19, 2024
15468ef
Move change status in train endpoint
jewelltaylor Jun 20, 2024
a3fc65e
Add general exception to catch things other than AssertionError
jewelltaylor Jun 20, 2024
edba53f
Merge branch 'change-status' into start-training-button
jewelltaylor Jun 20, 2024
9cf850b
Fix presets of Job and Client model to match ports in README
jewelltaylor Jun 20, 2024
3b68a0f
Add change of status to finished with error in case of start training…
jewelltaylor Jun 20, 2024
455256d
Update test to test non assertion error case and that it returns 500 …
jewelltaylor Jun 20, 2024
4363ca6
Merge branch 'change-status' into start-training-button
jewelltaylor Jun 20, 2024
5e85007
Start to add tests
jewelltaylor Jun 20, 2024
7480179
Avoid conditional calling of usePost
jewelltaylor Jun 20, 2024
152e258
Fix test_training tests that were broken with the changing of placeme…
jewelltaylor Jun 21, 2024
f8d1251
Add index to table rows for unique identification
jewelltaylor Jun 21, 2024
4d5b29c
Add more comprehensive tests for start button
jewelltaylor Jun 21, 2024
80127ce
Add some comments and fix default values in job model
jewelltaylor Jun 21, 2024
ba3274c
Address CRs by Marcelo
jewelltaylor Jun 26, 2024
7929594
Merge branch 'main' into start-training-button
jewelltaylor Jun 26, 2024
c3247b3
Encapsulate start training logic in start training button. Change to …
jewelltaylor Jun 26, 2024
be54640
Fix icon className so it properly displays
jewelltaylor Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions florist/api/db/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ class Config:
schema_extra = {
"example": {
"client": "MNIST",
"service_address": "locahost:8081",
"service_address": "localhost:8001",
jewelltaylor marked this conversation as resolved.
Show resolved Hide resolved
"data_path": "path/to/data",
"redis_host": "localhost",
"redis_port": "6880",
"redis_port": "6380",
"uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f",
"metrics": '{"type": "client", "initialized": "2024-03-25 11:20:56.819569", "rounds": {"1": {"fit_start": "2024-03-25 11:20:56.827081"}}}',
},
Expand Down Expand Up @@ -216,19 +216,19 @@ class Config:
"_id": "066de609-b04a-4b30-b46c-32537c7f1f6e",
"status": "NOT_STARTED",
"model": "MNIST",
"server_address": "localhost:8080",
"server_config": '{"n_server_rounds": 3, "batch_size": 8}',
"server_address": "localhost:8000",
"server_config": '{"n_server_rounds": 3, "batch_size": 8, "local_epochs": 1}',
"server_uuid": "d73243cf-8b89-473b-9607-8cd0253a101d",
"server_metrics": '{"type": "server", "fit_start": "2024-04-23 15:33:12.865604", "rounds": {"1": {"fit_start": "2024-04-23 15:33:12.869001"}}}',
"redis_host": "localhost",
"redis_port": "6879",
"redis_port": "6379",
"clients_info": [
{
"client": "MNIST",
"service_address": "locahost:8081",
"service_address": "localhost:8001",
"data_path": "path/to/data",
"redis_host": "localhost",
"redis_port": "6880",
"redis_port": "6380",
"client_uuid": "0c316680-1375-4e07-84c3-a732a2e6d03f",
},
],
Expand Down
9 changes: 7 additions & 2 deletions florist/api/routes/server/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks
If not successful, returns the appropriate error code with a JSON with the format below:
{"error": <error message>}
"""
job = None

try:
job = await Job.find_by_id(job_id, request.app.database)

assert job is not None, f"Job with id {job_id} not found."
jewelltaylor marked this conversation as resolved.
Show resolved Hide resolved

assert job.status == JobStatus.NOT_STARTED, f"Job status ({job.status.value}) is not NOT_STARTED"
await job.set_status(JobStatus.IN_PROGRESS, request.app.database)

if job.config_parser is None:
job.config_parser = ConfigParser.BASIC
Expand Down Expand Up @@ -101,7 +103,6 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks

client_uuids.append(json_response["uuid"])

await job.set_status(JobStatus.IN_PROGRESS, request.app.database)
await job.set_uuids(server_uuid, client_uuids, request.app.database)

# Start the server training listener as a background task to update
Expand All @@ -112,10 +113,14 @@ async def start(job_id: str, request: Request, background_tasks: BackgroundTasks
return JSONResponse({"server_uuid": server_uuid, "client_uuids": client_uuids})

except AssertionError as err:
if job is not None:
await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
return JSONResponse(content={"error": str(err)}, status_code=400)

except Exception as ex:
LOGGER.exception(ex)
if job is not None:
await job.set_status(JobStatus.FINISHED_WITH_ERROR, request.app.database)
return JSONResponse({"error": str(ex)}, status_code=500)


Expand Down
6 changes: 5 additions & 1 deletion florist/app/jobs/hooks.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { useState } from "react";
import useSWR from "swr";
import useSWR, { mutate } from "swr";

import { fetcher } from "../client_imports";

Expand Down Expand Up @@ -47,3 +47,7 @@ export const usePost = () => {

return { post, response, isLoading, error };
};

export function refreshJobsByJobStatus(statuses: Array<string>) {
statuses.forEach((status: string) => mutate(`/api/server/job/${status}`));
}
60 changes: 56 additions & 4 deletions florist/app/jobs/page.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"use client";

import { useEffect } from "react";
import { ReactElement } from "react/React";

import { useGetJobsByJobStatus } from "./hooks";
import { refreshJobsByJobStatus, useGetJobsByJobStatus, usePost } from "./hooks";

import Link from "next/link";
import Image from "next/image";
Expand All @@ -18,6 +19,7 @@ export const validStatuses = {
};

interface JobData {
_id: string;
status: string;
model: string;
server_address: string;
Expand Down Expand Up @@ -90,6 +92,39 @@ export function NewJobButton(): ReactElement {
);
}

export function StartJobButton({ rowId, jobId }: { rowId: number; jobId: string }): ReactElement {
const { post, response, isLoading, error } = usePost();

const handleClickStartJobButton = async () => {
event.preventDefault();

if (isLoading) {
// Preventing double submit if already in progress
return;
}

const queryParams = new URLSearchParams({ job_id: jobId });
const url = `/api/server/training/start?${queryParams.toString()}`;
await post(url, JSON.stringify({}));
};

// Only refresh the job data if there is an error or response
useEffect(() => refreshJobsByJobStatus(Object.keys(validStatuses)), [error, response]);

return (
<div>
<button
data-testid={`start-training-button-${rowId}`}
onClick={handleClickStartJobButton}
className="btn btn-primary btn-sm mb-0"
title="Start"
>
<i className="material-icons text-sm">play_circle_outline</i>
</button>
</div>
);
}

export function Status({ status, data }: { status: StatusProp; data: Object }): ReactElement {
return (
<div className="row">
Expand Down Expand Up @@ -126,9 +161,10 @@ export function StatusTable({ data, status }: { data: Array<JobData>; status: St
<th className="text-uppercase text-secondary text-xxs font-weight-bolder opacity-7">
Client Service Addresses
</th>
<th></th>
</tr>
</thead>
<TableRows data={data} />
<TableRows data={data} status={status} />
</table>
</div>
</div>
Expand All @@ -144,26 +180,41 @@ export function StatusTable({ data, status }: { data: Array<JobData>; status: St
}
}

export function TableRows({ data }: { data: Array<JobData> }): ReactElement {
export function TableRows({ data, status }: { data: Array<JobData>; status: StatusProp }): ReactElement {
const tableRows = data.map((d, i) => (
<TableRow key={i} model={d.model} serverAddress={d.server_address} clientsInfo={d.clients_info} />
<TableRow
key={i}
rowId={i}
model={d.model}
serverAddress={d.server_address}
clientsInfo={d.clients_info}
status={status}
jobId={d._id}
/>
));

return <tbody>{tableRows}</tbody>;
}

export function TableRow({
rowId,
model,
serverAddress,
clientsInfo,
status,
jobId,
}: {
rowId: number;
model: string;
serverAddress: string;
clientsInfo: Array<ClientInfo>;
status: StatusProp;
jobId: string;
}): ReactElement {
if (clientsInfo === null) {
return <td />;
}

return (
<tr>
<td>
Expand All @@ -183,6 +234,7 @@ export function TableRow({
</span>
</div>
</td>
<td>{validStatuses[status] == "Not Started" ? <StartJobButton rowId={rowId} jobId={jobId} /> : null}</td>
</tr>
);
}
Expand Down
26 changes: 17 additions & 9 deletions florist/tests/unit/api/routes/server/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
from pytest import raises
from typing import Dict, Any, Tuple
from unittest.mock import Mock, patch, ANY
from unittest.mock import AsyncMock, Mock, patch, ANY

from florist.api.db.entities import Job, JobStatus, JOB_COLLECTION_NAME
from florist.api.models.mnist import MnistNet
Expand Down Expand Up @@ -135,7 +135,8 @@ async def test_start_fail_unsupported_client() -> None:
assert "value is not a valid enumeration member" in json_body["error"]


async def test_start_fail_missing_info() -> None:
@patch("florist.api.db.entities.Job.set_status")
async def test_start_fail_missing_info(mock_set_status: Mock) -> None:
fields_to_be_removed = ["model", "server_config", "clients_info", "server_address", "redis_host", "redis_port"]

for field_to_be_removed in fields_to_be_removed:
Expand All @@ -154,7 +155,8 @@ async def test_start_fail_missing_info() -> None:
assert f"Missing Job information: {field_to_be_removed}" in json_body["error"]


async def test_start_fail_invalid_server_config() -> None:
@patch("florist.api.db.entities.Job.set_status")
async def test_start_fail_invalid_server_config(mock_set_status: Mock) -> None:
# Arrange
test_job_id = "test-job-id"
_, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks()
Expand All @@ -170,7 +172,8 @@ async def test_start_fail_invalid_server_config() -> None:
assert f"server_config is not a valid json string." in json_body["error"]


async def test_start_fail_empty_clients_info() -> None:
@patch("florist.api.db.entities.Job.set_status")
async def test_start_fail_empty_clients_info(_: Mock) -> None:
# Arrange
test_job_id = "test-job-id"
_, test_job, _, mock_fastapi_request = _setup_test_job_and_mocks()
Expand All @@ -186,8 +189,9 @@ async def test_start_fail_empty_clients_info() -> None:
assert f"Missing Job information: clients_info" in json_body["error"]


@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
async def test_start_launch_server_exception(mock_launch_local_server: Mock) -> None:
async def test_start_launch_server_exception(mock_launch_local_server: Mock, _: Mock) -> None:
# Arrange
test_job_id = "test-job-id"
_, _, _, mock_fastapi_request = _setup_test_job_and_mocks()
Expand All @@ -204,9 +208,10 @@ async def test_start_launch_server_exception(mock_launch_local_server: Mock) ->
assert json_body == {"error": str(test_exception)}


@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_server: Mock) -> None:
async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None:
# Arrange
test_job_id = "test-job-id"
_, _, _, mock_fastapi_request = _setup_test_job_and_mocks()
Expand All @@ -226,10 +231,11 @@ async def test_start_wait_for_metric_exception(mock_redis: Mock, mock_launch_loc
assert json_body == {"error": str(test_exception)}


@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.monitoring.metrics.time") # just so time.sleep does not actually sleep
async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None:
async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_launch_local_server: Mock, mock_set_status: Mock) -> None:
# Arrange
test_job_id = "test-job-id"
_, _, _, mock_fastapi_request = _setup_test_job_and_mocks()
Expand All @@ -250,10 +256,11 @@ async def test_start_wait_for_metric_timeout(_: Mock, mock_redis: Mock, mock_lau
assert json_body == {"error": "Metric 'fit_start' not been found after 20 retries."}


@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
async def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None:
async def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None:
# Arrange
test_job_id = "test-job-id"
_, _, _, mock_fastapi_request = _setup_test_job_and_mocks()
Expand All @@ -279,10 +286,11 @@ async def test_start_fail_response(mock_requests: Mock, mock_redis: Mock, mock_l
assert json_body == {"error": f"Client response returned 403. Response: error"}


@patch("florist.api.db.entities.Job.set_status")
@patch("florist.api.routes.server.training.launch_local_server")
@patch("florist.api.monitoring.metrics.redis")
@patch("florist.api.routes.server.training.requests")
async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock) -> None:
async def test_start_no_uuid_in_response(mock_requests: Mock, mock_redis: Mock, mock_launch_local_server: Mock, _: Mock) -> None:
# Arrange
test_job_id = "test-job-id"
_, _, _, mock_fastapi_request = _setup_test_job_and_mocks()
Expand Down
Loading
Loading