diff --git a/frontend/.gitignore b/frontend/.gitignore index a547bf36..e440e68b 100644 --- a/frontend/.gitignore +++ b/frontend/.gitignore @@ -22,3 +22,6 @@ dist-ssr *.njsproj *.sln *.sw? + +tsconfig.app.tsbuildinfo +tsconfig.node.tsbuildinfo \ No newline at end of file diff --git a/frontend/src/app/providers/model-creation-provider.tsx b/frontend/src/app/providers/model-creation-provider.tsx index 470de263..2ae7b75b 100644 --- a/frontend/src/app/providers/model-creation-provider.tsx +++ b/frontend/src/app/providers/model-creation-provider.tsx @@ -11,8 +11,10 @@ import { UseMutationResult } from "@tanstack/react-query"; import React, { createContext, useContext, useEffect, useState } from "react"; import { useToast } from "./toast-provider"; import { useNavigate } from "react-router-dom"; -import { TTrainingDataset } from "@/types"; +import { TModel, TTrainingDataset } from "@/types"; import { TCreateTrainingDatasetArgs } from "@/features/model-creation/api/create-trainings"; +import { useCreateModel, useCreateModelTrainingRequest } from "@/features/model-creation/hooks/use-models"; +import { TCreateModelArgs } from "@/features/model-creation/api/create-models"; // The names here is the same with the initialFormState object keys as well as the form validation config export enum MODEL_CREATION_FORM_NAME { @@ -30,6 +32,7 @@ export enum MODEL_CREATION_FORM_NAME { TMS_URL = "tmsURL", TMS_URL_VALIDITY = "tmsURLValidation", SELECTED_TRAINING_DATASET_ID = "selectedTrainingDatasetId", + TRAINING_AREAS = "trainingAreas", } export const FORM_VALIDATION_CONFIG = { @@ -81,6 +84,7 @@ const initialFormState = { // training dataset selection selectedTrainingDatasetId: "", zoomLevels: [20, 21], + trainingAreas: [], // Defaults to basic configurations trainingType: TrainingType.BASIC, epoch: 2, @@ -107,6 +111,12 @@ const ModelCreationFormContext = createContext<{ TCreateTrainingDatasetArgs, unknown >; + createNewModelMutation: UseMutationResult< + TModel, + Error, + TCreateModelArgs, + unknown + >; }>({ formData: initialFormState, setFormData: () => {}, @@ -117,6 +127,12 @@ const ModelCreationFormContext = createContext<{ TCreateTrainingDatasetArgs, unknown >, + createNewModelMutation: {} as UseMutationResult< + TModel, + Error, + TCreateModelArgs, + unknown + >, }); export const ModelCreationFormProvider: React.FC<{ @@ -142,6 +158,15 @@ export const ModelCreationFormProvider: React.FC<{ ) => { setFormData((prev) => ({ ...prev, [field]: value })); }; + const trainingRequestMutation = useCreateModelTrainingRequest({mutationConfig:{ + onSuccess:()=>{ + notify("Training request submitted successfully", "success"); + }, + onError: (error) => { + const errorText = error?.response?.data[0] ?? "An error ocurred while submitting training request" + notify(errorText, "danger"); + }, + }}); const createNewTrainingDatasetMutation = useCreateTrainingDataset({ mutationConfig: { @@ -157,7 +182,36 @@ export const ModelCreationFormProvider: React.FC<{ navigate(APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_AREA); }, onError: () => { - notify("Error creating dataset", "danger"); + + notify("An error occurred while creating dataset", "danger"); + }, + }, + }); + + const createNewModelMutation = useCreateModel({ + mutationConfig: { + onSuccess: (data) => { + notify("Model created successfully", "success"); + // Submit the model for training request + trainingRequestMutation.mutate({ + model:data.id, + input_boundary_width:formData.boundaryWidth, + input_contact_spacing:formData.contactSpacing, + epochs:formData.epoch, + batch_size:formData.batchSize, + zoom_level:formData.zoomLevels + + }) + + setFormData(initialFormState); + + navigate( + `${APPLICATION_ROUTES.CREATE_NEW_MODEL_CONFIRMATION}?id=${data.id}`, + + ); + }, + onError: () => { + notify("An error ocurred while creating model", "danger"); }, }, }); @@ -176,6 +230,7 @@ export const ModelCreationFormProvider: React.FC<{ setFormData, handleChange, createNewTrainingDatasetMutation, + createNewModelMutation, }} > {children} diff --git a/frontend/src/app/routes/models/model-details.tsx b/frontend/src/app/routes/models/model-details.tsx index bf78ead1..f9f028be 100644 --- a/frontend/src/app/routes/models/model-details.tsx +++ b/frontend/src/app/routes/models/model-details.tsx @@ -17,7 +17,8 @@ import { useDialog } from "@/hooks/use-dialog"; import { APP_CONTENT, APPLICATION_ROUTES } from "@/utils"; import { useEffect } from "react"; import { useNavigate, useParams } from "react-router-dom"; - +import TrainingInProgressImage from "@/assets/images/training_in_prorgress.png"; +import { Image } from "@/components/ui/image"; export const ModelDetailsPage = () => { const { id } = useParams<{ id: string }>(); const { isOpened, closeDialog, openDialog } = useDialog(); @@ -64,10 +65,23 @@ export const ModelDetailsPage = () => { - + {!data?.published_training ? ( +
+ Model training in progress +

+ Model training is not activated yet. Properties will be + available after a successful and activated training. +

+
+ ) : ( + + )}
=> { + return await ( + await apiClient.post(`${API_ENDPOINTS.CREATE_MODELS}`, { + dataset, + name, + description, + base_model, + }) + ).data; +}; diff --git a/frontend/src/features/model-creation/api/create-trainings.ts b/frontend/src/features/model-creation/api/create-trainings.ts index 44ef57d0..e2be1d18 100644 --- a/frontend/src/features/model-creation/api/create-trainings.ts +++ b/frontend/src/features/model-creation/api/create-trainings.ts @@ -1,5 +1,9 @@ import { API_ENDPOINTS, apiClient } from "@/services"; -import { TTrainingAreaFeature, TTrainingDataset } from "@/types"; +import { + TTrainingAreaFeature, + TTrainingDataset, + TTrainingDetails, +} from "@/types"; export type TCreateTrainingDatasetArgs = { name: string; @@ -13,7 +17,7 @@ export const createTrainingDataset = async ({ status = 0, }: TCreateTrainingDatasetArgs): Promise => { return await ( - await apiClient.post(`${API_ENDPOINTS.CREATE_TRAINING_DATASETS}`, { + await apiClient.post(API_ENDPOINTS.CREATE_TRAINING_DATASETS, { name, source_imagery, status, @@ -31,9 +35,41 @@ export const createTrainingArea = async ({ geom, }: TCreateTrainingAreaArgs): Promise => { return await ( - await apiClient.post(`${API_ENDPOINTS.CREATE_TRAINING_AREA}`, { + await apiClient.post(API_ENDPOINTS.CREATE_TRAINING_AREA, { dataset, geom, }) ).data; }; + +export type TCreateTrainingRequestArgs = { + batch_size: number; + epochs: number; + input_boundary_width: number; + input_contact_spacing: number; + model: string; + zoom_level: number[]; +}; + +export const createTrainingRequest = async ({ + batch_size, + epochs, + input_boundary_width, + input_contact_spacing, + model, + zoom_level, +}: TCreateTrainingRequestArgs): Promise => { + return await ( + await apiClient.post(API_ENDPOINTS.CREATE_TRAINING_REQUEST, { + batch_size, + epochs, + description: "", + freeze_layer: false, + input_contact_spacing, + input_boundary_width, + model, + multimask: false, + zoom_level, + }) + ).data; +}; diff --git a/frontend/src/features/model-creation/components/creation-success-confirmation.tsx b/frontend/src/features/model-creation/components/creation-success-confirmation.tsx index 37622f62..14845d50 100644 --- a/frontend/src/features/model-creation/components/creation-success-confirmation.tsx +++ b/frontend/src/features/model-creation/components/creation-success-confirmation.tsx @@ -1,9 +1,25 @@ + import ModelCreationSuccess from "@/assets/images/model_creation_success.png"; import { Button } from "@/components/ui/button"; import { Image } from "@/components/ui/image"; +import { Link } from "@/components/ui/link"; +import { APPLICATION_ROUTES } from "@/utils"; +import { useEffect } from "react"; import ConfettiExplosion from "react-confetti-explosion"; +import { useNavigate, useSearchParams } from "react-router-dom"; const ModelCreationSuccessConfirmation = () => { + const navigate = useNavigate(); + const [searchParams] = useSearchParams(); + // Model ID should be in the url params upon successful creation + const modelId = searchParams.get("id"); + + useEffect(() => { + if (!modelId) { + navigate(APPLICATION_ROUTES.MODELS); + } + }, [modelId]); + return (
{ height={10000} /> Model Creation Success Icon -

Model 15 is Created!

+

Model {modelId} is Created!

Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore.

- - + + + + + +
); diff --git a/frontend/src/features/model-creation/components/model-summary.tsx b/frontend/src/features/model-creation/components/model-summary.tsx index 340c1fbb..b957685e 100644 --- a/frontend/src/features/model-creation/components/model-summary.tsx +++ b/frontend/src/features/model-creation/components/model-summary.tsx @@ -1,6 +1,7 @@ import { useModelFormContext } from "@/app/providers/model-creation-provider"; import { DatabaseIcon, + MapIcon, // MapIcon, // RAMIcon, SaveIcon, @@ -11,6 +12,9 @@ import { } from "@/components/ui/icons"; import { StepHeading } from "@/features/model-creation/components/"; import { IconProps } from "@/types"; +import { APPLICATION_ROUTES } from "@/utils"; +import { useEffect } from "react"; +import { useNavigate } from "react-router-dom"; const SummaryItem = ({ icon: Icon, @@ -60,13 +64,13 @@ const ModelSummaryStep = () => { label: "Dataset ID", content: formData.selectedTrainingDatasetId, }, - // These will be retrieved from API. + // { icon: RAMIcon, label: "Dataset Size", content: "250 Images" }, - // { - // icon: MapIcon, - // label: "Open Aerial Imagery", - // content: "San Jose Mission 10C Flight 1", - // }, + { + icon: MapIcon, + label: "Open Aerial Imagery", + content: "", + }, { icon: ZoomInIcon, label: "Zoom Levels", @@ -83,6 +87,12 @@ const ModelSummaryStep = () => { ], }, ]; + const navigate = useNavigate() + useEffect(() => { + if (!formData.modelName && !formData.modelDescription) { + navigate(APPLICATION_ROUTES.MODELS); + } + }, [formData]); return (
diff --git a/frontend/src/features/model-creation/components/progress-buttons.tsx b/frontend/src/features/model-creation/components/progress-buttons.tsx index 0c9ba69d..ea0b8a11 100644 --- a/frontend/src/features/model-creation/components/progress-buttons.tsx +++ b/frontend/src/features/model-creation/components/progress-buttons.tsx @@ -9,6 +9,7 @@ import { APPLICATION_ROUTES } from "@/utils"; import { useMemo } from "react"; import { useNavigate } from "react-router-dom"; import { TrainingDatasetOption } from "./training-dataset"; +import { TTrainingArea } from "@/types"; type ProgressButtonsProps = { currentPath: string; @@ -23,24 +24,42 @@ const ProgressButtons: React.FC = ({ }) => { const navigate = useNavigate(); - const { formData, handleChange, createNewTrainingDatasetMutation } = - useModelFormContext(); + const { + formData, + handleChange, + createNewTrainingDatasetMutation, + createNewModelMutation, + } = useModelFormContext(); const nextPage = () => { - // If on training dataset creation page, then submit the form before proceeding - if ( - currentPath == APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_DATASET && - formData.trainingDatasetOption === TrainingDatasetOption.CREATE_NEW - ) { - createNewTrainingDatasetMutation.mutate({ - source_imagery: formData.tmsURL, - name: formData.datasetName, - }); - // Navigation will happen in the context if successful. - } else { - if (currentPageIndex < pages.length - 1) { - navigate(pages[currentPageIndex + 1].path); - } + switch (currentPath) { + case APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_DATASET: + if ( + formData.trainingDatasetOption === TrainingDatasetOption.CREATE_NEW + ) { + createNewTrainingDatasetMutation.mutate({ + source_imagery: formData.tmsURL, + name: formData.datasetName, + }); + } else { + if (currentPageIndex < pages.length - 1) { + navigate(pages[currentPageIndex + 1].path); + } + } + break; + case APPLICATION_ROUTES.CREATE_NEW_MODEL_SUMMARY: + createNewModelMutation.mutate({ + dataset: formData.selectedTrainingDatasetId, + name: formData.modelName, + description: formData.modelDescription, + base_model: formData.baseModel.toUpperCase(), + }); + break; + default: + if (currentPageIndex < pages.length - 1) { + navigate(pages[currentPageIndex + 1].path); + } + break; } }; @@ -86,6 +105,7 @@ const ProgressButtons: React.FC = ({ FORM_VALIDATION_CONFIG[MODEL_CREATION_FORM_NAME.MODEL_DESCRIPTION] .minLength ); + case APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_DATASET: // if the user hasn't selected any of the options, then they can not proceed to next page. if (formData.trainingDatasetOption === TrainingDatasetOption.NONE) { @@ -117,11 +137,21 @@ const ProgressButtons: React.FC = ({ case APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_SETTINGS: // confirm that the user has selected at least an option return formData.zoomLevels.length > 0; + case APPLICATION_ROUTES.CREATE_NEW_MODEL_TRAINING_AREA: + //@ts-expect-error bad type definition + const trainingAreas: TTrainingArea = formData.trainingAreas; + // Ensure that no geometry is null before they can proceed + return ( + trainingAreas?.features?.filter((area) => area.geometry !== null) + .length > 0 + ); default: return true; } }, [formData, currentPath]); + // Handle model creation when on the last page + return (
{ + updateTMSLayer(); + updateGeoJSONLayer(); + }; + onStyleData(); + // Attach the listener for style changes + map.on("styledata", onStyleData); return () => { if (map.getLayer(trainingAreasLayerId)) { @@ -94,8 +99,10 @@ const TrainingAreaMap = ({ if (map.getSource(TMSSourceId)) { map.removeSource(TMSSourceId); } + + map.off("styledata", onStyleData); }; - }, [map, mapData]); + }, [map, mapData, tileJSONURL]); return ( 0 + ? [ + { + value: "Training Areas", + mapLayerId: trainingAreasLayerId, + }, + ] + : []), ]} /> ); diff --git a/frontend/src/features/model-creation/components/training-area/training-area.tsx b/frontend/src/features/model-creation/components/training-area/training-area.tsx index c77c2189..4ff52645 100644 --- a/frontend/src/features/model-creation/components/training-area/training-area.tsx +++ b/frontend/src/features/model-creation/components/training-area/training-area.tsx @@ -6,11 +6,14 @@ import { import { StepHeading } from "@/features/model-creation/components/"; import TrainingAreaMap from "./training-area-map"; import { Button, ButtonWithIcon } from "@/components/ui/button"; -import { useModelFormContext } from "@/app/providers/model-creation-provider"; +import { + MODEL_CREATION_FORM_NAME, + useModelFormContext, +} from "@/app/providers/model-creation-provider"; import { useGetTMSTileJSON } from "../../hooks/use-tms-tilejson"; import { useDialog } from "@/hooks/use-dialog"; import FileUploadDialog from "../dialogs/file-upload-dialog"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import TrainingAreaList from "./training-area-list"; import { useGetTrainingAreas } from "../../hooks/use-training-areas"; import { useMap } from "@/app/providers/map-provider"; @@ -23,6 +26,7 @@ const TrainingAreaForm = () => { const { isPending, data, isError } = useGetTMSTileJSON(tileJSONURL); const { closeDialog, isOpened, toggle } = useDialog(); + const { handleChange } = useModelFormContext(); const { map } = useMap(); @@ -38,6 +42,16 @@ const TrainingAreaForm = () => { isPlaceholderData, } = useGetTrainingAreas(Number(formData.selectedTrainingDatasetId), offset); + useEffect(() => { + if (!trainingAreasData) return; + // update the form data when the data changes + // @ts-expect-error bad type definition + handleChange( + MODEL_CREATION_FORM_NAME.TRAINING_AREAS, + trainingAreasData?.results, + ); + }, [trainingAreasData]); + return ( <> ; +}; + +export const useCreateModel = ({ mutationConfig }: useCreateModelOptions) => { + const { onSuccess, ...restConfig } = mutationConfig || {}; + + return useMutation({ + mutationFn: (args: TCreateModelArgs) => createModel(args), + onSuccess: (...args) => { + onSuccess?.(...args); + }, + ...restConfig, + }); +}; + +type useCreateModelTrainingOptions = { + mutationConfig?: MutationConfig; +}; + +export const useCreateModelTrainingRequest = ({ + mutationConfig, +}: useCreateModelTrainingOptions) => { + const { onSuccess, ...restConfig } = mutationConfig || {}; + + return useMutation({ + mutationFn: (args: TCreateTrainingRequestArgs) => + createTrainingRequest(args), + onSuccess: (...args) => { + onSuccess?.(...args); + }, + ...restConfig, + }); +}; diff --git a/frontend/src/features/models/api/factory.ts b/frontend/src/features/models/api/factory.ts index 8d941317..5ebf6549 100644 --- a/frontend/src/features/models/api/factory.ts +++ b/frontend/src/features/models/api/factory.ts @@ -89,6 +89,7 @@ export const getTrainingWorkspaceQueryOptions = ( return queryOptions({ queryKey: ["training-workspace", datasetId, trainingId, directory_name], queryFn: () => getTrainingWorkspace(datasetId, trainingId, directory_name), + enabled: trainingId !== null, }); }; diff --git a/frontend/src/features/models/components/directory-tree.tsx b/frontend/src/features/models/components/directory-tree.tsx index 06d35867..65535afc 100644 --- a/frontend/src/features/models/components/directory-tree.tsx +++ b/frontend/src/features/models/components/directory-tree.tsx @@ -101,9 +101,11 @@ const DirectoryTree: React.FC = ({ const fetchDirectoryData = async (path: string = "") => { try { - return await queryClient.fetchQuery({ - ...getTrainingWorkspaceQueryOptions(datasetId, trainingId, path), - }); + if (trainingId !== null) { + return await queryClient.fetchQuery({ + ...getTrainingWorkspaceQueryOptions(datasetId, trainingId, path), + }); + } } catch { setHasError(true); return null; diff --git a/frontend/src/features/models/components/model-details-info.tsx b/frontend/src/features/models/components/model-details-info.tsx index bc406e34..a829129e 100644 --- a/frontend/src/features/models/components/model-details-info.tsx +++ b/frontend/src/features/models/components/model-details-info.tsx @@ -39,6 +39,7 @@ const ModelDetailsInfo = ({
-

{APP_CONTENT.models.modelsDetailsCard.viewTrainingArea}

-
+ @@ -88,6 +90,7 @@ const ModelDetailsInfo = ({ capitalizeText={false} onClick={openModelFilesDialog} prefixIcon={DirectoryIcon} + disabled={data?.published_training === null} />
diff --git a/frontend/src/features/models/components/model-feedbacks.tsx b/frontend/src/features/models/components/model-feedbacks.tsx index ad3655fa..91c77b65 100644 --- a/frontend/src/features/models/components/model-feedbacks.tsx +++ b/frontend/src/features/models/components/model-feedbacks.tsx @@ -13,7 +13,7 @@ const ModelFeedbacks = ({ trainingId }: { trainingId: number }) => { return ( <>

- {isError ? "N/A" : data?.count} + {isError ? "N/A" : (data?.count ?? 0)} {APP_CONTENT.models.modelsDetailsCard.feedbacks} @@ -24,6 +24,7 @@ const ModelFeedbacks = ({ trainingId }: { trainingId: number }) => { variant="dark" size="medium" prefixIcon={ChatbubbleIcon} + disabled={trainingId === null} />

diff --git a/frontend/src/features/models/hooks/use-training.ts b/frontend/src/features/models/hooks/use-training.ts index d130d1d9..8658c34e 100644 --- a/frontend/src/features/models/hooks/use-training.ts +++ b/frontend/src/features/models/hooks/use-training.ts @@ -14,6 +14,7 @@ export const useTrainingDetails = (id: number) => { throwOnError: (error) => error?.response?.status >= 500, refetchInterval: 10000, // 10 seconds refetchIntervalInBackground: true, + enabled: id !== null, }); }; @@ -32,6 +33,7 @@ export const useTrainingFeedbacks = (id: number) => { ...getTrainingFeedbacksQueryOptions(id), //@ts-expect-error bad type definition throwOnError: (error) => error?.response?.status >= 500, + enabled: id !== null, }); }; export const useTrainingWorkspace = ( @@ -43,6 +45,7 @@ export const useTrainingWorkspace = ( ...getTrainingWorkspaceQueryOptions(datasetId, trainingId, directory_name), //@ts-expect-error bad type definition throwOnError: (error) => error?.response?.status >= 500, + enabled: trainingId !== null, }); }; diff --git a/frontend/src/services/api-routes.ts b/frontend/src/services/api-routes.ts index 0db09aa2..b150c249 100644 --- a/frontend/src/services/api-routes.ts +++ b/frontend/src/services/api-routes.ts @@ -13,6 +13,7 @@ export const API_ENDPOINTS = { //Models GET_MODELS: "model/", + CREATE_MODELS: "model/", GET_MODEL_DETAILS: (id: string) => `model/${id}`, GET_MODELS_CENTROIDS: "models/centroid", @@ -24,6 +25,8 @@ export const API_ENDPOINTS = { CREATE_TRAINING_AREA: "aoi/", + CREATE_TRAINING_REQUEST: "training/", + DELETE_TRAINING_AREA: (id: number) => `aoi/${id}/`, GET_TRAINING_AREAS: (datasetId: number, offset: number, limit: number) =>