diff --git a/frontend/.gitignore b/frontend/.gitignore
index a547bf36..e440e68b 100644
--- a/frontend/.gitignore
+++ b/frontend/.gitignore
@@ -22,3 +22,6 @@ dist-ssr
*.njsproj
*.sln
*.sw?
+
+tsconfig.app.tsbuildinfo
+tsconfig.node.tsbuildinfo
\ No newline at end of file
diff --git a/frontend/src/app/providers/model-creation-provider.tsx b/frontend/src/app/providers/model-creation-provider.tsx
index 470de263..2ae7b75b 100644
--- a/frontend/src/app/providers/model-creation-provider.tsx
+++ b/frontend/src/app/providers/model-creation-provider.tsx
@@ -11,8 +11,10 @@ import { UseMutationResult } from "@tanstack/react-query";
import React, { createContext, useContext, useEffect, useState } from "react";
import { useToast } from "./toast-provider";
import { useNavigate } from "react-router-dom";
-import { TTrainingDataset } from "@/types";
+import { TModel, TTrainingDataset } from "@/types";
import { TCreateTrainingDatasetArgs } from "@/features/model-creation/api/create-trainings";
+import { useCreateModel, useCreateModelTrainingRequest } from "@/features/model-creation/hooks/use-models";
+import { TCreateModelArgs } from "@/features/model-creation/api/create-models";
// The names here is the same with the initialFormState object keys as well as the form validation config
export enum MODEL_CREATION_FORM_NAME {
@@ -30,6 +32,7 @@ export enum MODEL_CREATION_FORM_NAME {
TMS_URL = "tmsURL",
TMS_URL_VALIDITY = "tmsURLValidation",
SELECTED_TRAINING_DATASET_ID = "selectedTrainingDatasetId",
+ TRAINING_AREAS = "trainingAreas",
}
export const FORM_VALIDATION_CONFIG = {
@@ -81,6 +84,7 @@ const initialFormState = {
// training dataset selection
selectedTrainingDatasetId: "",
zoomLevels: [20, 21],
+ trainingAreas: [],
// Defaults to basic configurations
trainingType: TrainingType.BASIC,
epoch: 2,
@@ -107,6 +111,12 @@ const ModelCreationFormContext = createContext<{
TCreateTrainingDatasetArgs,
unknown
>;
+ createNewModelMutation: UseMutationResult<
+ TModel,
+ Error,
+ TCreateModelArgs,
+ unknown
+ >;
}>({
formData: initialFormState,
setFormData: () => {},
@@ -117,6 +127,12 @@ const ModelCreationFormContext = createContext<{
TCreateTrainingDatasetArgs,
unknown
>,
+ createNewModelMutation: {} as UseMutationResult<
+ TModel,
+ Error,
+ TCreateModelArgs,
+ unknown
+ >,
});
export const ModelCreationFormProvider: React.FC<{
@@ -142,6 +158,15 @@ export const ModelCreationFormProvider: React.FC<{
) => {
setFormData((prev) => ({ ...prev, [field]: value }));
};
+ const trainingRequestMutation = useCreateModelTrainingRequest({mutationConfig:{
+ onSuccess:()=>{
+ notify("Training request submitted successfully", "success");
+ },
+ onError: (error) => {
+ const errorText = error?.response?.data[0] ?? "An error ocurred while submitting training request"
+ notify(errorText, "danger");
+ },
+ }});
const createNewTrainingDatasetMutation = useCreateTrainingDataset({
mutationConfig: {
@@ -157,7 +182,36 @@ export const ModelCreationFormProvider: React.FC<{
navigate(APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_AREA);
},
onError: () => {
- notify("Error creating dataset", "danger");
+
+ notify("An error occurred while creating dataset", "danger");
+ },
+ },
+ });
+
+ const createNewModelMutation = useCreateModel({
+ mutationConfig: {
+ onSuccess: (data) => {
+ notify("Model created successfully", "success");
+ // Submit the model for training request
+ trainingRequestMutation.mutate({
+ model:data.id,
+ input_boundary_width:formData.boundaryWidth,
+ input_contact_spacing:formData.contactSpacing,
+ epochs:formData.epoch,
+ batch_size:formData.batchSize,
+ zoom_level:formData.zoomLevels
+
+ })
+
+ setFormData(initialFormState);
+
+ navigate(
+ `${APPLICATION_ROUTES.CREATE_NEW_MODEL_CONFIRMATION}?id=${data.id}`,
+
+ );
+ },
+ onError: () => {
+ notify("An error ocurred while creating model", "danger");
},
},
});
@@ -176,6 +230,7 @@ export const ModelCreationFormProvider: React.FC<{
setFormData,
handleChange,
createNewTrainingDatasetMutation,
+ createNewModelMutation,
}}
>
{children}
diff --git a/frontend/src/app/routes/models/model-details.tsx b/frontend/src/app/routes/models/model-details.tsx
index bf78ead1..f9f028be 100644
--- a/frontend/src/app/routes/models/model-details.tsx
+++ b/frontend/src/app/routes/models/model-details.tsx
@@ -17,7 +17,8 @@ import { useDialog } from "@/hooks/use-dialog";
import { APP_CONTENT, APPLICATION_ROUTES } from "@/utils";
import { useEffect } from "react";
import { useNavigate, useParams } from "react-router-dom";
-
+import TrainingInProgressImage from "@/assets/images/training_in_prorgress.png";
+import { Image } from "@/components/ui/image";
export const ModelDetailsPage = () => {
const { id } = useParams<{ id: string }>();
const { isOpened, closeDialog, openDialog } = useDialog();
@@ -64,10 +65,23 @@ export const ModelDetailsPage = () => {
-
+ {!data?.published_training ? (
+
+
+
+ Model training is not activated yet. Properties will be
+ available after a successful and activated training.
+
+
+ ) : (
+
+ )}
=> {
+ return await (
+ await apiClient.post(`${API_ENDPOINTS.CREATE_MODELS}`, {
+ dataset,
+ name,
+ description,
+ base_model,
+ })
+ ).data;
+};
diff --git a/frontend/src/features/model-creation/api/create-trainings.ts b/frontend/src/features/model-creation/api/create-trainings.ts
index 44ef57d0..e2be1d18 100644
--- a/frontend/src/features/model-creation/api/create-trainings.ts
+++ b/frontend/src/features/model-creation/api/create-trainings.ts
@@ -1,5 +1,9 @@
import { API_ENDPOINTS, apiClient } from "@/services";
-import { TTrainingAreaFeature, TTrainingDataset } from "@/types";
+import {
+ TTrainingAreaFeature,
+ TTrainingDataset,
+ TTrainingDetails,
+} from "@/types";
export type TCreateTrainingDatasetArgs = {
name: string;
@@ -13,7 +17,7 @@ export const createTrainingDataset = async ({
status = 0,
}: TCreateTrainingDatasetArgs): Promise => {
return await (
- await apiClient.post(`${API_ENDPOINTS.CREATE_TRAINING_DATASETS}`, {
+ await apiClient.post(API_ENDPOINTS.CREATE_TRAINING_DATASETS, {
name,
source_imagery,
status,
@@ -31,9 +35,41 @@ export const createTrainingArea = async ({
geom,
}: TCreateTrainingAreaArgs): Promise => {
return await (
- await apiClient.post(`${API_ENDPOINTS.CREATE_TRAINING_AREA}`, {
+ await apiClient.post(API_ENDPOINTS.CREATE_TRAINING_AREA, {
dataset,
geom,
})
).data;
};
+
+export type TCreateTrainingRequestArgs = {
+ batch_size: number;
+ epochs: number;
+ input_boundary_width: number;
+ input_contact_spacing: number;
+ model: string;
+ zoom_level: number[];
+};
+
+export const createTrainingRequest = async ({
+ batch_size,
+ epochs,
+ input_boundary_width,
+ input_contact_spacing,
+ model,
+ zoom_level,
+}: TCreateTrainingRequestArgs): Promise => {
+ return await (
+ await apiClient.post(API_ENDPOINTS.CREATE_TRAINING_REQUEST, {
+ batch_size,
+ epochs,
+ description: "",
+ freeze_layer: false,
+ input_contact_spacing,
+ input_boundary_width,
+ model,
+ multimask: false,
+ zoom_level,
+ })
+ ).data;
+};
diff --git a/frontend/src/features/model-creation/components/creation-success-confirmation.tsx b/frontend/src/features/model-creation/components/creation-success-confirmation.tsx
index 37622f62..14845d50 100644
--- a/frontend/src/features/model-creation/components/creation-success-confirmation.tsx
+++ b/frontend/src/features/model-creation/components/creation-success-confirmation.tsx
@@ -1,9 +1,25 @@
+
import ModelCreationSuccess from "@/assets/images/model_creation_success.png";
import { Button } from "@/components/ui/button";
import { Image } from "@/components/ui/image";
+import { Link } from "@/components/ui/link";
+import { APPLICATION_ROUTES } from "@/utils";
+import { useEffect } from "react";
import ConfettiExplosion from "react-confetti-explosion";
+import { useNavigate, useSearchParams } from "react-router-dom";
const ModelCreationSuccessConfirmation = () => {
+ const navigate = useNavigate();
+ const [searchParams] = useSearchParams();
+ // Model ID should be in the url params upon successful creation
+ const modelId = searchParams.get("id");
+
+ useEffect(() => {
+ if (!modelId) {
+ navigate(APPLICATION_ROUTES.MODELS);
+ }
+ }, [modelId]);
+
return (
{
height={10000}
/>
- Model 15 is Created!
+ Model {modelId} is Created!
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod
tempor incididunt ut labore.
-
-
+
+
+
+
+
+
);
diff --git a/frontend/src/features/model-creation/components/model-summary.tsx b/frontend/src/features/model-creation/components/model-summary.tsx
index 340c1fbb..b957685e 100644
--- a/frontend/src/features/model-creation/components/model-summary.tsx
+++ b/frontend/src/features/model-creation/components/model-summary.tsx
@@ -1,6 +1,7 @@
import { useModelFormContext } from "@/app/providers/model-creation-provider";
import {
DatabaseIcon,
+ MapIcon,
// MapIcon,
// RAMIcon,
SaveIcon,
@@ -11,6 +12,9 @@ import {
} from "@/components/ui/icons";
import { StepHeading } from "@/features/model-creation/components/";
import { IconProps } from "@/types";
+import { APPLICATION_ROUTES } from "@/utils";
+import { useEffect } from "react";
+import { useNavigate } from "react-router-dom";
const SummaryItem = ({
icon: Icon,
@@ -60,13 +64,13 @@ const ModelSummaryStep = () => {
label: "Dataset ID",
content: formData.selectedTrainingDatasetId,
},
- // These will be retrieved from API.
+
// { icon: RAMIcon, label: "Dataset Size", content: "250 Images" },
- // {
- // icon: MapIcon,
- // label: "Open Aerial Imagery",
- // content: "San Jose Mission 10C Flight 1",
- // },
+ {
+ icon: MapIcon,
+ label: "Open Aerial Imagery",
+ content: "",
+ },
{
icon: ZoomInIcon,
label: "Zoom Levels",
@@ -83,6 +87,12 @@ const ModelSummaryStep = () => {
],
},
];
+ const navigate = useNavigate()
+ useEffect(() => {
+ if (!formData.modelName && !formData.modelDescription) {
+ navigate(APPLICATION_ROUTES.MODELS);
+ }
+ }, [formData]);
return (
diff --git a/frontend/src/features/model-creation/components/progress-buttons.tsx b/frontend/src/features/model-creation/components/progress-buttons.tsx
index 0c9ba69d..ea0b8a11 100644
--- a/frontend/src/features/model-creation/components/progress-buttons.tsx
+++ b/frontend/src/features/model-creation/components/progress-buttons.tsx
@@ -9,6 +9,7 @@ import { APPLICATION_ROUTES } from "@/utils";
import { useMemo } from "react";
import { useNavigate } from "react-router-dom";
import { TrainingDatasetOption } from "./training-dataset";
+import { TTrainingArea } from "@/types";
type ProgressButtonsProps = {
currentPath: string;
@@ -23,24 +24,42 @@ const ProgressButtons: React.FC
= ({
}) => {
const navigate = useNavigate();
- const { formData, handleChange, createNewTrainingDatasetMutation } =
- useModelFormContext();
+ const {
+ formData,
+ handleChange,
+ createNewTrainingDatasetMutation,
+ createNewModelMutation,
+ } = useModelFormContext();
const nextPage = () => {
- // If on training dataset creation page, then submit the form before proceeding
- if (
- currentPath == APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_DATASET &&
- formData.trainingDatasetOption === TrainingDatasetOption.CREATE_NEW
- ) {
- createNewTrainingDatasetMutation.mutate({
- source_imagery: formData.tmsURL,
- name: formData.datasetName,
- });
- // Navigation will happen in the context if successful.
- } else {
- if (currentPageIndex < pages.length - 1) {
- navigate(pages[currentPageIndex + 1].path);
- }
+ switch (currentPath) {
+ case APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_DATASET:
+ if (
+ formData.trainingDatasetOption === TrainingDatasetOption.CREATE_NEW
+ ) {
+ createNewTrainingDatasetMutation.mutate({
+ source_imagery: formData.tmsURL,
+ name: formData.datasetName,
+ });
+ } else {
+ if (currentPageIndex < pages.length - 1) {
+ navigate(pages[currentPageIndex + 1].path);
+ }
+ }
+ break;
+ case APPLICATION_ROUTES.CREATE_NEW_MODEL_SUMMARY:
+ createNewModelMutation.mutate({
+ dataset: formData.selectedTrainingDatasetId,
+ name: formData.modelName,
+ description: formData.modelDescription,
+ base_model: formData.baseModel.toUpperCase(),
+ });
+ break;
+ default:
+ if (currentPageIndex < pages.length - 1) {
+ navigate(pages[currentPageIndex + 1].path);
+ }
+ break;
}
};
@@ -86,6 +105,7 @@ const ProgressButtons: React.FC = ({
FORM_VALIDATION_CONFIG[MODEL_CREATION_FORM_NAME.MODEL_DESCRIPTION]
.minLength
);
+
case APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_DATASET:
// if the user hasn't selected any of the options, then they can not proceed to next page.
if (formData.trainingDatasetOption === TrainingDatasetOption.NONE) {
@@ -117,11 +137,21 @@ const ProgressButtons: React.FC = ({
case APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_SETTINGS:
// confirm that the user has selected at least an option
return formData.zoomLevels.length > 0;
+ case APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_AREA:
+ //@ts-expect-error bad type definition
+ const trainingAreas: TTrainingArea = formData.trainingAreas;
+ // Ensure that no geometry is null before they can proceed
+ return (
+ trainingAreas?.features?.filter((area) => area.geometry !== null)
+ .length > 0
+ );
default:
return true;
}
}, [formData, currentPath]);
+ // Handle model creation when on the last page
+
return (
{
+ updateTMSLayer();
+ updateGeoJSONLayer();
+ };
+ onStyleData();
+ // Attach the listener for style changes
+ map.on("styledata", onStyleData);
return () => {
if (map.getLayer(trainingAreasLayerId)) {
@@ -94,8 +99,10 @@ const TrainingAreaMap = ({
if (map.getSource(TMSSourceId)) {
map.removeSource(TMSSourceId);
}
+
+ map.off("styledata", onStyleData);
};
- }, [map, mapData]);
+ }, [map, mapData, tileJSONURL]);
return (
0
+ ? [
+ {
+ value: "Training Areas",
+ mapLayerId: trainingAreasLayerId,
+ },
+ ]
+ : []),
]}
/>
);
diff --git a/frontend/src/features/model-creation/components/training-area/training-area.tsx b/frontend/src/features/model-creation/components/training-area/training-area.tsx
index c77c2189..4ff52645 100644
--- a/frontend/src/features/model-creation/components/training-area/training-area.tsx
+++ b/frontend/src/features/model-creation/components/training-area/training-area.tsx
@@ -6,11 +6,14 @@ import {
import { StepHeading } from "@/features/model-creation/components/";
import TrainingAreaMap from "./training-area-map";
import { Button, ButtonWithIcon } from "@/components/ui/button";
-import { useModelFormContext } from "@/app/providers/model-creation-provider";
+import {
+ MODEL_CREATION_FORM_NAME,
+ useModelFormContext,
+} from "@/app/providers/model-creation-provider";
import { useGetTMSTileJSON } from "../../hooks/use-tms-tilejson";
import { useDialog } from "@/hooks/use-dialog";
import FileUploadDialog from "../dialogs/file-upload-dialog";
-import { useState } from "react";
+import { useEffect, useState } from "react";
import TrainingAreaList from "./training-area-list";
import { useGetTrainingAreas } from "../../hooks/use-training-areas";
import { useMap } from "@/app/providers/map-provider";
@@ -23,6 +26,7 @@ const TrainingAreaForm = () => {
const { isPending, data, isError } = useGetTMSTileJSON(tileJSONURL);
const { closeDialog, isOpened, toggle } = useDialog();
+ const { handleChange } = useModelFormContext();
const { map } = useMap();
@@ -38,6 +42,16 @@ const TrainingAreaForm = () => {
isPlaceholderData,
} = useGetTrainingAreas(Number(formData.selectedTrainingDatasetId), offset);
+ useEffect(() => {
+ if (!trainingAreasData) return;
+ // update the form data when the data changes
+ // @ts-expect-error bad type definition
+ handleChange(
+ MODEL_CREATION_FORM_NAME.TRAINING_AREAS,
+ trainingAreasData?.results,
+ );
+ }, [trainingAreasData]);
+
return (
<>
;
+};
+
+export const useCreateModel = ({ mutationConfig }: useCreateModelOptions) => {
+ const { onSuccess, ...restConfig } = mutationConfig || {};
+
+ return useMutation({
+ mutationFn: (args: TCreateModelArgs) => createModel(args),
+ onSuccess: (...args) => {
+ onSuccess?.(...args);
+ },
+ ...restConfig,
+ });
+};
+
+type useCreateModelTrainingOptions = {
+ mutationConfig?: MutationConfig;
+};
+
+export const useCreateModelTrainingRequest = ({
+ mutationConfig,
+}: useCreateModelTrainingOptions) => {
+ const { onSuccess, ...restConfig } = mutationConfig || {};
+
+ return useMutation({
+ mutationFn: (args: TCreateTrainingRequestArgs) =>
+ createTrainingRequest(args),
+ onSuccess: (...args) => {
+ onSuccess?.(...args);
+ },
+ ...restConfig,
+ });
+};
diff --git a/frontend/src/features/models/api/factory.ts b/frontend/src/features/models/api/factory.ts
index 8d941317..5ebf6549 100644
--- a/frontend/src/features/models/api/factory.ts
+++ b/frontend/src/features/models/api/factory.ts
@@ -89,6 +89,7 @@ export const getTrainingWorkspaceQueryOptions = (
return queryOptions({
queryKey: ["training-workspace", datasetId, trainingId, directory_name],
queryFn: () => getTrainingWorkspace(datasetId, trainingId, directory_name),
+ enabled: trainingId !== null,
});
};
diff --git a/frontend/src/features/models/components/directory-tree.tsx b/frontend/src/features/models/components/directory-tree.tsx
index 06d35867..65535afc 100644
--- a/frontend/src/features/models/components/directory-tree.tsx
+++ b/frontend/src/features/models/components/directory-tree.tsx
@@ -101,9 +101,11 @@ const DirectoryTree: React.FC = ({
const fetchDirectoryData = async (path: string = "") => {
try {
- return await queryClient.fetchQuery({
- ...getTrainingWorkspaceQueryOptions(datasetId, trainingId, path),
- });
+ if (trainingId !== null) {
+ return await queryClient.fetchQuery({
+ ...getTrainingWorkspaceQueryOptions(datasetId, trainingId, path),
+ });
+ }
} catch {
setHasError(true);
return null;
diff --git a/frontend/src/features/models/components/model-details-info.tsx b/frontend/src/features/models/components/model-details-info.tsx
index bc406e34..a829129e 100644
--- a/frontend/src/features/models/components/model-details-info.tsx
+++ b/frontend/src/features/models/components/model-details-info.tsx
@@ -39,6 +39,7 @@ const ModelDetailsInfo = ({
-
{APP_CONTENT.models.modelsDetailsCard.viewTrainingArea}
-
+
@@ -88,6 +90,7 @@ const ModelDetailsInfo = ({
capitalizeText={false}
onClick={openModelFilesDialog}
prefixIcon={DirectoryIcon}
+ disabled={data?.published_training === null}
/>
diff --git a/frontend/src/features/models/components/model-feedbacks.tsx b/frontend/src/features/models/components/model-feedbacks.tsx
index ad3655fa..91c77b65 100644
--- a/frontend/src/features/models/components/model-feedbacks.tsx
+++ b/frontend/src/features/models/components/model-feedbacks.tsx
@@ -13,7 +13,7 @@ const ModelFeedbacks = ({ trainingId }: { trainingId: number }) => {
return (
<>
- {isError ? "N/A" : data?.count}
+ {isError ? "N/A" : (data?.count ?? 0)}
{APP_CONTENT.models.modelsDetailsCard.feedbacks}
@@ -24,6 +24,7 @@ const ModelFeedbacks = ({ trainingId }: { trainingId: number }) => {
variant="dark"
size="medium"
prefixIcon={ChatbubbleIcon}
+ disabled={trainingId === null}
/>
>
diff --git a/frontend/src/features/models/hooks/use-training.ts b/frontend/src/features/models/hooks/use-training.ts
index d130d1d9..8658c34e 100644
--- a/frontend/src/features/models/hooks/use-training.ts
+++ b/frontend/src/features/models/hooks/use-training.ts
@@ -14,6 +14,7 @@ export const useTrainingDetails = (id: number) => {
throwOnError: (error) => error?.response?.status >= 500,
refetchInterval: 10000, // 10 seconds
refetchIntervalInBackground: true,
+ enabled: id !== null,
});
};
@@ -32,6 +33,7 @@ export const useTrainingFeedbacks = (id: number) => {
...getTrainingFeedbacksQueryOptions(id),
//@ts-expect-error bad type definition
throwOnError: (error) => error?.response?.status >= 500,
+ enabled: id !== null,
});
};
export const useTrainingWorkspace = (
@@ -43,6 +45,7 @@ export const useTrainingWorkspace = (
...getTrainingWorkspaceQueryOptions(datasetId, trainingId, directory_name),
//@ts-expect-error bad type definition
throwOnError: (error) => error?.response?.status >= 500,
+ enabled: trainingId !== null,
});
};
diff --git a/frontend/src/services/api-routes.ts b/frontend/src/services/api-routes.ts
index 0db09aa2..b150c249 100644
--- a/frontend/src/services/api-routes.ts
+++ b/frontend/src/services/api-routes.ts
@@ -13,6 +13,7 @@ export const API_ENDPOINTS = {
//Models
GET_MODELS: "model/",
+ CREATE_MODELS: "model/",
GET_MODEL_DETAILS: (id: string) => `model/${id}`,
GET_MODELS_CENTROIDS: "models/centroid",
@@ -24,6 +25,8 @@ export const API_ENDPOINTS = {
CREATE_TRAINING_AREA: "aoi/",
+ CREATE_TRAINING_REQUEST: "training/",
+
DELETE_TRAINING_AREA: (id: number) => `aoi/${id}/`,
GET_TRAINING_AREAS: (datasetId: number, offset: number, limit: number) =>