From d2ab220f7f5cfd6b809a2a419d64dc002fda47e6 Mon Sep 17 00:00:00 2001
From: philippe horne
Date: Fri, 5 Nov 2021 15:35:09 -0400
Subject: [PATCH 01/16] Initial addition of segmentation route.
---
api/routes/segmentation.go | 47 ++++++++++++++++++++++++++++++++++++++
api/task/segment.go | 29 +++++++++++++++++++++++
go.mod | 2 +-
go.sum | 4 ++--
main.go | 1 +
5 files changed, 80 insertions(+), 3 deletions(-)
create mode 100644 api/routes/segmentation.go
create mode 100644 api/task/segment.go
diff --git a/api/routes/segmentation.go b/api/routes/segmentation.go
new file mode 100644
index 000000000..6156cba7f
--- /dev/null
+++ b/api/routes/segmentation.go
@@ -0,0 +1,47 @@
+package routes
+
+import (
+ "net/http"
+
+ "github.com/pkg/errors"
+ "goji.io/v3/pat"
+
+ "github.com/uncharted-distil/distil/api/env"
+ api "github.com/uncharted-distil/distil/api/model"
+ "github.com/uncharted-distil/distil/api/task"
+)
+
+// SegmentationHandler will segment the specified remote sensing dataset.
+func SegmentationHandler(metaCtor api.MetadataStorageCtor, dataCtor api.DataStorageCtor, config env.Config) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ // get dataset name
+ dataset := pat.Param(r, "dataset")
+
+ metaStorage, err := metaCtor()
+ if err != nil {
+ handleError(w, err)
+ return
+ }
+
+ ds, err := metaStorage.FetchDataset(dataset, false, false, false)
+ if err != nil {
+ handleError(w, err)
+ return
+ }
+
+ outputURI, err := task.Segment(ds)
+ if err != nil {
+ handleError(w, errors.Wrap(err, "unable segment dataset"))
+ return
+ }
+
+ // marshal output into JSON
+ err = handleJSON(w, map[string]interface{}{
+ "uri": outputURI,
+ })
+ if err != nil {
+ handleError(w, errors.Wrap(err, "unable marshal clustering result into JSON"))
+ return
+ }
+ }
+}
diff --git a/api/task/segment.go b/api/task/segment.go
new file mode 100644
index 000000000..f09908771
--- /dev/null
+++ b/api/task/segment.go
@@ -0,0 +1,29 @@
+package task
+
+import (
+ "github.com/uncharted-distil/distil-compute/primitive/compute/description"
+ "github.com/uncharted-distil/distil/api/env"
+ api "github.com/uncharted-distil/distil/api/model"
+)
+
+// Segment segments an image into separate parts.
+func Segment(dataset *api.Dataset) (string, error) {
+ envConfig, err := env.LoadConfig()
+ if err != nil {
+ return "", err
+ }
+
+ datasetInputDir := env.ResolvePath(dataset.Source, dataset.Folder)
+
+ step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", envConfig.RemoteSensingNumJobs)
+ if err != nil {
+ return "", err
+ }
+
+ datasetURI, err := submitPipeline([]string{datasetInputDir}, step, true)
+ if err != nil {
+ return "", err
+ }
+
+ return datasetURI, nil
+}
diff --git a/go.mod b/go.mod
index c4abdba77..4af24a519 100644
--- a/go.mod
+++ b/go.mod
@@ -34,7 +34,7 @@ require (
github.com/russross/blackfriday v2.0.0+incompatible
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/stretchr/testify v1.6.1
- github.com/uncharted-distil/distil-compute v0.0.0-20211105135441-4d7fd71299f8
+ github.com/uncharted-distil/distil-compute v0.0.0-20211105184340-305f023f7b6b
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a
github.com/unchartedsoftware/plog v0.0.0-20200807135627-83d59e50ced5
diff --git a/go.sum b/go.sum
index 20f3f02bf..71dd8b217 100644
--- a/go.sum
+++ b/go.sum
@@ -213,8 +213,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/uncharted-distil/distil-compute v0.0.0-20211105135441-4d7fd71299f8 h1:E6ERSR1kCCGPJe/R7sPttpAWeohrHu8hrQImdk7W17c=
-github.com/uncharted-distil/distil-compute v0.0.0-20211105135441-4d7fd71299f8/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
+github.com/uncharted-distil/distil-compute v0.0.0-20211105184340-305f023f7b6b h1:0n76IL/t7ewpJSyhIGkrDRJ7Zp8F5W3fm4Vfyd3SvCo=
+github.com/uncharted-distil/distil-compute v0.0.0-20211105184340-305f023f7b6b/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb h1:wDsXsrF8qM34nLeQ9xW+zbEdRNATk5sgOwuwCTrZmvY=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb/go.mod h1:Xhb77n2q8yDvcVS3Mvw0XlpdNMiFsL+vOlvoe556ivc=
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a h1:BPJrlnjdhxMBrJWiU4/Gl3PVdCUlY9JspWFTJ9UVO0Y=
diff --git a/main.go b/main.go
index dd4227229..ff11df2f0 100644
--- a/main.go
+++ b/main.go
@@ -297,6 +297,7 @@ func main() {
registerRoutePost(mux, "/distil/update/:dataset", routes.UpdateHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
registerRoutePost(mux, "/distil/clone-result/:produce-request-id", routes.CloningResultsHandler(esMetadataStorageCtor, pgDataStorageCtor, pgSolutionStorageCtor, config))
registerRoutePost(mux, "/distil/clone/:dataset", routes.CloningHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
+ registerRoutePost(mux, "/distil/segment/:dataset", routes.SegmentationHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
registerRoutePost(mux, "/distil/save-dataset/:dataset", routes.SaveDatasetHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
registerRoutePost(mux, "/distil/add-field/:dataset", routes.AddFieldHandler(esMetadataStorageCtor, pgDataStorageCtor))
registerRoutePost(mux, "/distil/extract/:dataset", routes.ExtractHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
From 3cf15f9dc0867ca0263b9777f7b1e23fa450cb85 Mon Sep 17 00:00:00 2001
From: philippe horne
Date: Fri, 12 Nov 2021 15:40:17 -0500
Subject: [PATCH 02/16] Added working segmentation route as well as a
placeholder band combination.
---
api/model/storage/postgres/dataset.go | 6 +-
api/routes/image_pack.go | 45 ++++++++-
api/routes/multiband_image.go | 126 +++++++++++++++++---------
api/routes/segmentation.go | 10 +-
api/task/dataset.go | 35 ++++++-
api/task/segment.go | 118 +++++++++++++++++++++++-
api/util/imagery/imagery.go | 4 +
go.mod | 2 +-
go.sum | 4 +-
main.go | 2 +-
10 files changed, 296 insertions(+), 56 deletions(-)
diff --git a/api/model/storage/postgres/dataset.go b/api/model/storage/postgres/dataset.go
index 4b6033e69..b6604d913 100644
--- a/api/model/storage/postgres/dataset.go
+++ b/api/model/storage/postgres/dataset.go
@@ -357,8 +357,10 @@ func (s *Storage) FetchDataset(dataset string, storageName string,
filteredVars := []*model.Variable{}
selectedVars := map[string]bool{}
- for _, v := range filterParams.Variables {
- selectedVars[v] = true
+ if limitSelectedFields && filterParams != nil {
+ for _, v := range filterParams.Variables {
+ selectedVars[v] = true
+ }
}
// only include data with distilrole data and index
diff --git a/api/routes/image_pack.go b/api/routes/image_pack.go
index d545dfc32..5d4799361 100644
--- a/api/routes/image_pack.go
+++ b/api/routes/image_pack.go
@@ -80,7 +80,9 @@ func MultiBandImagePackHandler(ctor api.MetadataStorageCtor, dataCtor api.DataSt
funcPointer := getImages
optramMap := map[string]imagery.OptramEdges{}
precision := 0
- if params.Band != "" {
+ if params.Band == imagery.Segmentation {
+ funcPointer = getSegmentationImages
+ } else if params.Band != "" {
// if band is not empty then get multiBandImages
funcPointer = getMultiBandImages
if params.Band == imagery.OPTRAM {
@@ -200,6 +202,7 @@ func getImages(imagePackRequest *ImagePackRequest, _ map[string]imagery.OptramEd
}
result <- chanStruct{data: temp, IDs: IDs, errorIDs: errorIDs}
}
+
func getMultiBandImages(multiBandPackRequest *ImagePackRequest, optramMap map[string]imagery.OptramEdges, precision int, threadID int, numThreads int, result chan chanStruct, ctor api.MetadataStorageCtor, dataCtor api.DataStorageCtor) {
temp := [][]byte{}
IDs := []string{}
@@ -278,6 +281,46 @@ func getMultiBandImages(multiBandPackRequest *ImagePackRequest, optramMap map[st
result <- chanStruct{data: temp, IDs: IDs, errorIDs: errorIDs}
}
+func getSegmentationImages(multiBandPackRequest *ImagePackRequest, optramMap map[string]imagery.OptramEdges, precision int, threadID int, numThreads int, result chan chanStruct, ctor api.MetadataStorageCtor, dataCtor api.DataStorageCtor) {
+ temp := [][]byte{}
+ IDs := []string{}
+ errorIDs := []string{}
+ // get common storage
+ storage, err := ctor()
+ if err != nil {
+ log.Error(err)
+ return
+ }
+
+ res, err := storage.FetchDataset(multiBandPackRequest.Dataset, false, false, false)
+ if err != nil {
+ log.Error(err)
+ return
+ }
+
+ sourcePath := path.Join(env.GetResourcePath(), res.ID, "media")
+
+ // loop through image info
+ for i := threadID; i < len(multiBandPackRequest.ImageIDs); i += numThreads {
+ imageID := multiBandPackRequest.ImageIDs[i]
+ img, err := getSegmentationImage(imageID, sourcePath, true, ThumbnailDimensions)
+ if err != nil {
+ handleThreadError(&errorIDs, &imageID, &err)
+ continue
+ }
+
+ imageBytes, err := imagery.ImageToJPEG(img)
+ if err != nil {
+ handleThreadError(&errorIDs, &imageID, &err)
+ continue
+ }
+ temp = append(temp, imageBytes)
+ IDs = append(IDs, imageID)
+ }
+
+ result <- chanStruct{data: temp, IDs: IDs, errorIDs: errorIDs}
+}
+
func handleThreadError(errorIDs *[]string, imageID *string, err *error) {
*errorIDs = append(*errorIDs, *imageID)
log.Error(*err)
diff --git a/api/routes/multiband_image.go b/api/routes/multiband_image.go
index 2f2c54334..8918ede56 100644
--- a/api/routes/multiband_image.go
+++ b/api/routes/multiband_image.go
@@ -16,13 +16,18 @@
package routes
import (
+ "bytes"
"encoding/json"
"fmt"
+ "image"
+ "image/draw"
+ "io/ioutil"
"net/http"
"path"
"strconv"
"strings"
+ "github.com/nfnt/resize"
"github.com/pkg/errors"
"github.com/uncharted-distil/distil-compute/metadata"
"github.com/uncharted-distil/distil-compute/model"
@@ -69,66 +74,77 @@ func MultiBandImageHandler(ctor api.MetadataStorageCtor, dataCtor api.DataStorag
handleError(w, err)
return
}
-
res, err := storage.FetchDataset(dataset, false, false, false)
if err != nil {
handleError(w, err)
return
}
- sourcePath := env.ResolvePath(res.Source, res.Folder)
+ options := imagery.Options{Gain: 2.5, Gamma: 2.2, GainL: 1.0, Scale: false} // default options for color correction
- // need to read the dataset doc to determine the path to the data resource
- metaDisk, err := metadata.LoadMetadataFromOriginalSchema(path.Join(sourcePath, compute.D3MDataSchema), false)
- if err != nil {
- handleError(w, err)
- return
- }
- for _, dr := range metaDisk.DataResources {
- if dr.IsCollection && dr.ResType == model.ResTypeImage {
- sourcePath = model.GetResourcePathFromFolder(sourcePath, dr)
- break
+ var img *image.RGBA
+ if bandCombo == imagery.Segmentation {
+ sourcePath := path.Join(env.GetResourcePath(), res.ID, "media")
+ img, err = getSegmentationImage(imageID, sourcePath, false, 0)
+ if err != nil {
+ handleError(w, err)
+ return
}
- }
- options := imagery.Options{Gain: 2.5, Gamma: 2.2, GainL: 1.0, Scale: false} // default options for color correction
- if paramOption != "" {
- err := json.Unmarshal([]byte(paramOption), &options)
+ } else {
+ sourcePath := env.ResolvePath(res.Source, res.Folder)
+
+ // need to read the dataset doc to determine the path to the data resource
+ metaDisk, err := metadata.LoadMetadataFromOriginalSchema(path.Join(sourcePath, compute.D3MDataSchema), false)
if err != nil {
handleError(w, err)
return
}
- }
- if isThumbnail {
- imageScale = imagery.ImageScale{Width: ThumbnailDimensions, Height: ThumbnailDimensions}
- // if thumbnail scale should be 0
- options.Scale = false
- }
+ for _, dr := range metaDisk.DataResources {
+ if dr.IsCollection && dr.ResType == model.ResTypeImage {
+ sourcePath = model.GetResourcePathFromFolder(sourcePath, dr)
+ break
+ }
+ }
+ if paramOption != "" {
+ err := json.Unmarshal([]byte(paramOption), &options)
+ if err != nil {
+ handleError(w, err)
+ return
+ }
+ }
+ if isThumbnail {
+ imageScale = imagery.ImageScale{Width: ThumbnailDimensions, Height: ThumbnailDimensions}
+ // thumbnails should not be upscaled
+ options.Scale = false
+ }
- // need to get the band -> filename from the data
- bandMapping, err := getBandMapping(res, []string{imageID}, dataStorage)
- if err != nil {
- handleError(w, err)
- return
- }
- var optramMap map[string]imagery.OptramEdges
- optramPath := ""
- edge := imagery.OptramEdges{}
- precision := 0
- if bandCombo == imagery.OPTRAM {
- optramPath = strings.Join([]string{env.ResolvePath(res.Source, res.Folder), imagery.OPTRAMJSONFile}, "/")
- optramMap, precision, err = imagery.ReadOptramFile(optramPath)
+ // need to get the band -> filename from the data
+ bandMapping, err := getBandMapping(res, []string{imageID}, dataStorage)
if err != nil {
handleError(w, err)
return
}
- geoHash := imagery.ParseGeoHashFromID(imageID, precision)
- edge = optramMap[geoHash]
- }
+ var optramMap map[string]imagery.OptramEdges
+ optramPath := ""
+ edge := imagery.OptramEdges{}
+ precision := 0
+ if bandCombo == imagery.OPTRAM {
+ optramPath = strings.Join([]string{env.ResolvePath(res.Source, res.Folder), imagery.OPTRAMJSONFile}, "/")
+ optramMap, precision, err = imagery.ReadOptramFile(optramPath)
+ if err != nil {
+ handleError(w, err)
+ return
+ }
+ geoHash := imagery.ParseGeoHashFromID(imageID, precision)
+ edge = optramMap[geoHash]
+ }
- img, err := imagery.ImageFromCombination(sourcePath, bandMapping[imageID], bandCombo, imageScale, &edge, ramp, options)
- if err != nil {
- handleError(w, err)
- return
+ img, err = imagery.ImageFromCombination(sourcePath, bandMapping[imageID], bandCombo, imageScale, &edge, ramp, options)
+ if err != nil {
+ handleError(w, err)
+ return
+ }
}
+
if options.Scale && config.ShouldScaleImages {
img = c_util.UpscaleImage(img, c_util.GetModelType(config.ModelType))
}
@@ -145,6 +161,32 @@ func MultiBandImageHandler(ctor api.MetadataStorageCtor, dataCtor api.DataStorag
}
}
+func getSegmentationImage(imageID string, sourcePath string, thumbnail bool, dimensions int) (*image.RGBA, error) {
+ data, err := ioutil.ReadFile(path.Join(sourcePath, fmt.Sprintf("%s-segmentation.png", imageID)))
+ if err != nil {
+ return nil, errors.Wrapf(err, "unable to read segmentation image")
+ }
+
+ img, _, err := image.Decode(bytes.NewReader(data))
+ if err != nil {
+ return nil, errors.Wrapf(err, "unable to decode segmentation image")
+ }
+ dimensionsY := dimensions
+ dimensionsX := dimensions
+ if thumbnail {
+ img = resize.Thumbnail(uint(dimensionsX), uint(dimensionsY), img, resize.Lanczos3)
+ } else {
+ size := img.Bounds().Size()
+ dimensionsX = size.X
+ dimensionsY = size.Y
+ }
+
+ rgbaImg := image.NewRGBA(image.Rect(0, 0, dimensionsX, dimensionsY))
+ draw.Draw(rgbaImg, image.Rect(0, 0, dimensionsX, dimensionsY), img, img.Bounds().Min, draw.Src)
+
+ return rgbaImg, nil
+}
+
func getBandMapping(ds *api.Dataset, groupKeys []string, dataStorage api.DataStorage) (map[string]map[string]string, error) {
// build a filter to only include rows matching a group id
var groupingCol string
diff --git a/api/routes/segmentation.go b/api/routes/segmentation.go
index 6156cba7f..93e2d11cf 100644
--- a/api/routes/segmentation.go
+++ b/api/routes/segmentation.go
@@ -16,6 +16,8 @@ func SegmentationHandler(metaCtor api.MetadataStorageCtor, dataCtor api.DataStor
return func(w http.ResponseWriter, r *http.Request) {
// get dataset name
dataset := pat.Param(r, "dataset")
+ // get variable name
+ variable := pat.Param(r, "variable")
metaStorage, err := metaCtor()
if err != nil {
@@ -23,13 +25,19 @@ func SegmentationHandler(metaCtor api.MetadataStorageCtor, dataCtor api.DataStor
return
}
+ dataStorage, err := dataCtor()
+ if err != nil {
+ handleError(w, err)
+ return
+ }
+
ds, err := metaStorage.FetchDataset(dataset, false, false, false)
if err != nil {
handleError(w, err)
return
}
- outputURI, err := task.Segment(ds)
+ outputURI, err := task.Segment(ds, dataStorage, variable)
if err != nil {
handleError(w, errors.Wrap(err, "unable segment dataset"))
return
diff --git a/api/task/dataset.go b/api/task/dataset.go
index 8dacf0644..8ebe70447 100644
--- a/api/task/dataset.go
+++ b/api/task/dataset.go
@@ -129,7 +129,40 @@ func CopyDiskDataset(existingURI string, newURI string, newDatasetID string, new
// it to disk in D3M dataset format.
func ExportDataset(dataset string, metaStorage api.MetadataStorage, dataStorage api.DataStorage, filterParams *api.FilterParams) (string, string, error) {
// TODO: most likely need to either get a unique folder name for output or error if already exists
- return exportDiskDataset(dataset, dataset, env.ResolvePath(metadata.Augmented, dataset), metaStorage, dataStorage, false, filterParams)
+ datasetID, datasetPath, err := exportDiskDataset(dataset, dataset, env.ResolvePath(metadata.Augmented, dataset), metaStorage, dataStorage, false, filterParams)
+ if err != nil {
+ return "", "", err
+ }
+
+ // update the metadata stored to have the index reflect what is on disk
+ metaDisk, err := serialization.ReadMetadata(path.Join(datasetPath, compute.D3MDataSchema))
+ if err != nil {
+ return "", "", err
+ }
+
+ metaStored, err := metaStorage.FetchDataset(datasetID, true, true, true)
+ if err != nil {
+ return "", "", err
+ }
+
+ diskVars := api.MapVariables(metaDisk.GetMainDataResource().Variables, func(variable *model.Variable) string { return variable.Key })
+ notIncludedCount := 0
+ for _, v := range metaStored.Variables {
+ vDisk := diskVars[v.Key]
+ if vDisk == nil {
+ // variable not in disk dataset so arbitrarily give it an index > # of variables on disk
+ v.Index = len(diskVars) + notIncludedCount
+ notIncludedCount++
+ } else {
+ v.Index = vDisk.Index
+ }
+ }
+ err = metaStorage.UpdateDataset(metaStored)
+ if err != nil {
+ return "", "", err
+ }
+
+ return datasetID, datasetPath, nil
}
// CreateDatasetFromResult creates a new dataset based on a result set & the input
diff --git a/api/task/segment.go b/api/task/segment.go
index f09908771..7d4aed940 100644
--- a/api/task/segment.go
+++ b/api/task/segment.go
@@ -1,29 +1,137 @@
package task
import (
+ "fmt"
+ "os"
+ "path"
+ "strconv"
+
+ "github.com/pkg/errors"
+
+ "github.com/uncharted-distil/distil-compute/model"
"github.com/uncharted-distil/distil-compute/primitive/compute/description"
+ "github.com/uncharted-distil/distil-compute/primitive/compute/result"
"github.com/uncharted-distil/distil/api/env"
api "github.com/uncharted-distil/distil/api/model"
+ "github.com/uncharted-distil/distil/api/util"
+ "github.com/uncharted-distil/distil/api/util/imagery"
)
// Segment segments an image into separate parts.
-func Segment(dataset *api.Dataset) (string, error) {
+func Segment(ds *api.Dataset, dataStorage api.DataStorage, variableName string) (string, error) {
envConfig, err := env.LoadConfig()
if err != nil {
return "", err
}
- datasetInputDir := env.ResolvePath(dataset.Source, dataset.Folder)
+ datasetInputDir := env.ResolvePath(ds.Source, ds.Folder)
+
+ var variable *model.Variable
+ for _, v := range ds.Variables {
+ if v.Key == variableName {
+ variable = v
+ break
+ }
+ }
+
+ step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", variable, envConfig.RemoteSensingNumJobs)
+ if err != nil {
+ return "", err
+ }
- step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", envConfig.RemoteSensingNumJobs)
+ resultURI, err := submitPipeline([]string{datasetInputDir}, step, true)
if err != nil {
return "", err
}
- datasetURI, err := submitPipeline([]string{datasetInputDir}, step, true)
+ // read the file and parse the output mask
+ result, err := result.ParseResultCSV(resultURI)
if err != nil {
return "", err
}
- return datasetURI, nil
+ // need to pull the data to properly map d3m index to expected file names
+ // filenames should be "groupid-segmentation.png" for now
+ // TODO: may need to build the grouping key from multiple fields when moving away from test
+ var groupingKey *model.Variable
+ for _, v := range ds.Variables {
+ if v.DistilRole == model.VarDistilRoleGrouping {
+ groupingKey = v
+ break
+ }
+ }
+ if groupingKey == nil {
+ return "", errors.Errorf("no grouping found to use for output filename")
+ }
+ mapping, err := getFieldMapping(ds, groupingKey.Key, dataStorage)
+ if err != nil {
+ return "", err
+ }
+
+ // need to output all the masks as images
+ imageOutputFolder := path.Join(env.GetResourcePath(), ds.ID, "media")
+ for _, r := range result[1:] {
+ // create the image that captures the mask
+ d3mIndex := r[0].(string)
+ rawMask := r[1].([]interface{})
+ rawFloats := make([][]float64, len(rawMask))
+ for i, f := range rawMask {
+ dataF := f.([]interface{})
+ nestedFloats := make([]float64, len(dataF))
+ for j, nf := range dataF {
+ fp, err := strconv.ParseFloat(nf.(string), 64)
+ if err != nil {
+ return "", errors.Wrapf(err, "unable to parse mask")
+ }
+ nestedFloats[j] = fp
+ }
+ rawFloats[i] = nestedFloats
+ }
+
+ filter := imagery.ConfidenceMatrixToImage(rawFloats, imagery.MagmaColorScale, uint8(100))
+ imageBytes, err := imagery.ImageToPNG(filter)
+ if err != nil {
+ return "", err
+ }
+
+ // write the image to disk using a basic naming convention
+ imageFilename := path.Join(imageOutputFolder, fmt.Sprintf("%s-segmentation.png", mapping[d3mIndex]))
+ err = util.WriteFileWithDirs(imageFilename, imageBytes, os.ModePerm)
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return "", nil
+}
+
+func getFieldMapping(ds *api.Dataset, fieldName string, dataStorage api.DataStorage) (map[string]string, error) {
+ filter := &api.FilterParams{Variables: []string{model.D3MIndexFieldName, fieldName}}
+
+ // pull back the index and field values for all rows
+ data, err := dataStorage.FetchData(ds.ID, ds.StorageName, filter, true, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // cycle through results to build the field mapping
+ fieldColumn, ok := data.Columns[fieldName]
+ if !ok {
+ return nil, errors.Errorf("'%s' column not found in stored data", fieldName)
+ }
+ fieldColumnIndex := fieldColumn.Index
+ d3mColumn, ok := data.Columns[model.D3MIndexFieldName]
+ if !ok {
+ return nil, errors.Errorf("'%s' column not found in stored data", model.D3MIndexFieldName)
+ }
+ d3mColumnIndex := d3mColumn.Index
+
+ mapping := map[string]string{}
+ for _, r := range data.Values {
+ d3mIndexData := fmt.Sprintf("%.0f", r[d3mColumnIndex].Value.(float64))
+ fieldData := r[fieldColumnIndex].Value.(string)
+ mapping[d3mIndexData] = fieldData
+ }
+
+ return mapping, nil
}
diff --git a/api/util/imagery/imagery.go b/api/util/imagery/imagery.go
index e1905e1ea..471b9c317 100644
--- a/api/util/imagery/imagery.go
+++ b/api/util/imagery/imagery.go
@@ -71,6 +71,9 @@ const (
// AtmosphericRemoval identifies a band mapping that displays an image in near true color with atmospheric effects reduced.
AtmosphericRemoval = "atmospheric_removal"
+ // Segmentation identifies a placeholder band mapping to display image segmentation output.
+ Segmentation = "segmentation"
+
// ShortwaveInfrared identifies a band mapping that displays an image in shortwave infrared.
ShortwaveInfrared = "shortwave_infrared"
@@ -203,6 +206,7 @@ func init() {
MNDWI: {MNDWI, "Modified Normalized Difference Water Index", []string{"b03", "b11"}, BrownYellowBlueRamp, NormalizingTransform, false},
RSWIR: {RSWIR, "Red and Shortwave Infrared", []string{"b04", "b11"}, BrownYellowBlueRamp, NormalizingTransform, false},
OPTRAM: {OPTRAM, "OPTRAM", []string{"b08", "b04", "b12"}, RedYellowGreenRamp, OptramTransform, false},
+ Segmentation: {Segmentation, "Segmentation", []string{}, nil, nil, false},
}
}
diff --git a/go.mod b/go.mod
index 4af24a519..1d39db182 100644
--- a/go.mod
+++ b/go.mod
@@ -34,7 +34,7 @@ require (
github.com/russross/blackfriday v2.0.0+incompatible
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/stretchr/testify v1.6.1
- github.com/uncharted-distil/distil-compute v0.0.0-20211105184340-305f023f7b6b
+ github.com/uncharted-distil/distil-compute v0.0.0-20211112201613-074edcd7ab1d
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a
github.com/unchartedsoftware/plog v0.0.0-20200807135627-83d59e50ced5
diff --git a/go.sum b/go.sum
index 71dd8b217..6d3074212 100644
--- a/go.sum
+++ b/go.sum
@@ -213,8 +213,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/uncharted-distil/distil-compute v0.0.0-20211105184340-305f023f7b6b h1:0n76IL/t7ewpJSyhIGkrDRJ7Zp8F5W3fm4Vfyd3SvCo=
-github.com/uncharted-distil/distil-compute v0.0.0-20211105184340-305f023f7b6b/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
+github.com/uncharted-distil/distil-compute v0.0.0-20211112201613-074edcd7ab1d h1:CV0kH3rZFLDqjvSpNtvI2IUS9NVITwOQVgxkAaBo+KQ=
+github.com/uncharted-distil/distil-compute v0.0.0-20211112201613-074edcd7ab1d/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb h1:wDsXsrF8qM34nLeQ9xW+zbEdRNATk5sgOwuwCTrZmvY=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb/go.mod h1:Xhb77n2q8yDvcVS3Mvw0XlpdNMiFsL+vOlvoe556ivc=
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a h1:BPJrlnjdhxMBrJWiU4/Gl3PVdCUlY9JspWFTJ9UVO0Y=
diff --git a/main.go b/main.go
index ff11df2f0..2c30c7044 100644
--- a/main.go
+++ b/main.go
@@ -297,7 +297,7 @@ func main() {
registerRoutePost(mux, "/distil/update/:dataset", routes.UpdateHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
registerRoutePost(mux, "/distil/clone-result/:produce-request-id", routes.CloningResultsHandler(esMetadataStorageCtor, pgDataStorageCtor, pgSolutionStorageCtor, config))
registerRoutePost(mux, "/distil/clone/:dataset", routes.CloningHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
- registerRoutePost(mux, "/distil/segment/:dataset", routes.SegmentationHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
+ registerRoutePost(mux, "/distil/segment/:dataset/:variable", routes.SegmentationHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
registerRoutePost(mux, "/distil/save-dataset/:dataset", routes.SaveDatasetHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
registerRoutePost(mux, "/distil/add-field/:dataset", routes.AddFieldHandler(esMetadataStorageCtor, pgDataStorageCtor))
registerRoutePost(mux, "/distil/extract/:dataset", routes.ExtractHandler(esMetadataStorageCtor, pgDataStorageCtor, config))
From d1268526a7c58c30eb95509d59f5cacf8dd2a2dc Mon Sep 17 00:00:00 2001
From: philippe horne
Date: Tue, 16 Nov 2021 09:56:44 -0500
Subject: [PATCH 03/16] Added config param to enable the segmentation layer (disabled by default).
---
api/env/config.go | 1 +
api/util/imagery/imagery.go | 6 +++++-
go.mod | 2 +-
go.sum | 4 ++--
4 files changed, 9 insertions(+), 4 deletions(-)
diff --git a/api/env/config.go b/api/env/config.go
index 167ef5a13..19e507acd 100644
--- a/api/env/config.go
+++ b/api/env/config.go
@@ -84,6 +84,7 @@ type Config struct {
ResourceSubFolder string `env:"RESOURCE_SUBFOLDER" envDefault:"resources"`
ShouldScaleImages bool `env:"SHOULD_SCALE_IMAGES" envDefault:"false"` // enables and disables image scaling
SkipPreprocessing bool `env:"SKIP_PREPROCESSING" envDefault:"false"`
+ SegmentationEnabled bool `env:"SEGMENTATION_ENABLED" envDefault:"false"`
SolutionComputeEndpoint string `env:"SOLUTION_COMPUTE_ENDPOINT" envDefault:"localhost:50051"`
SolutionComputePullTimeout int `env:"SOLUTION_COMPUTE_PULL_TIMEOUT" envDefault:"60"`
SolutionComputePullMax int `env:"SOLUTION_COMPUTE_PULL_MAX" envDefault:"10"`
diff --git a/api/util/imagery/imagery.go b/api/util/imagery/imagery.go
index 471b9c317..314103a90 100644
--- a/api/util/imagery/imagery.go
+++ b/api/util/imagery/imagery.go
@@ -186,6 +186,7 @@ var (
)
func init() {
+ config, _ := env.LoadConfig()
// initialize the band combination structures - needs to be done in init so that referenced color ramps are built
SentinelBandCombinations = map[string]*BandCombination{
NaturalColors1: {NaturalColors1, "Natural Colors", []string{"b04", "b03", "b02"}, nil, nil, false},
@@ -206,7 +207,10 @@ func init() {
MNDWI: {MNDWI, "Modified Normalized Difference Water Index", []string{"b03", "b11"}, BrownYellowBlueRamp, NormalizingTransform, false},
RSWIR: {RSWIR, "Red and Shortwave Infrared", []string{"b04", "b11"}, BrownYellowBlueRamp, NormalizingTransform, false},
OPTRAM: {OPTRAM, "OPTRAM", []string{"b08", "b04", "b12"}, RedYellowGreenRamp, OptramTransform, false},
- Segmentation: {Segmentation, "Segmentation", []string{}, nil, nil, false},
+ }
+
+ if config.SegmentationEnabled {
+ SentinelBandCombinations[Segmentation] = &BandCombination{Segmentation, "Segmentation", []string{}, nil, nil, false}
}
}
diff --git a/go.mod b/go.mod
index 1d39db182..5547f6b2a 100644
--- a/go.mod
+++ b/go.mod
@@ -34,7 +34,7 @@ require (
github.com/russross/blackfriday v2.0.0+incompatible
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/stretchr/testify v1.6.1
- github.com/uncharted-distil/distil-compute v0.0.0-20211112201613-074edcd7ab1d
+ github.com/uncharted-distil/distil-compute v0.0.0-20211116145504-3a728e358d77
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a
github.com/unchartedsoftware/plog v0.0.0-20200807135627-83d59e50ced5
diff --git a/go.sum b/go.sum
index 6d3074212..df7428ee3 100644
--- a/go.sum
+++ b/go.sum
@@ -213,8 +213,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/uncharted-distil/distil-compute v0.0.0-20211112201613-074edcd7ab1d h1:CV0kH3rZFLDqjvSpNtvI2IUS9NVITwOQVgxkAaBo+KQ=
-github.com/uncharted-distil/distil-compute v0.0.0-20211112201613-074edcd7ab1d/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
+github.com/uncharted-distil/distil-compute v0.0.0-20211116145504-3a728e358d77 h1:h/3JZ7rDTSwnJ6MY6w4QyrUZFlqlWstAxoXkG+60Iww=
+github.com/uncharted-distil/distil-compute v0.0.0-20211116145504-3a728e358d77/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb h1:wDsXsrF8qM34nLeQ9xW+zbEdRNATk5sgOwuwCTrZmvY=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb/go.mod h1:Xhb77n2q8yDvcVS3Mvw0XlpdNMiFsL+vOlvoe556ivc=
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a h1:BPJrlnjdhxMBrJWiU4/Gl3PVdCUlY9JspWFTJ9UVO0Y=
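The gating pattern here is a plain env-driven flag checked at init time; note that the patch calls env.LoadConfig() in init() and discards the error, so a failed config load silently leaves the layer off. A minimal sketch of the same pattern using os.Getenv directly (names below are illustrative, not the real config API):

package main

import (
	"fmt"
	"os"
)

// bandCombinations stands in for SentinelBandCombinations; values simplified.
var bandCombinations = map[string]string{
	"natural_colors": "Natural Colors",
}

func init() {
	// mirrors envDefault:"false" above: absent or false means no segmentation layer
	if os.Getenv("SEGMENTATION_ENABLED") == "true" {
		bandCombinations["segmentation"] = "Segmentation"
	}
}

func main() {
	_, enabled := bandCombinations["segmentation"]
	fmt.Println("segmentation layer registered:", enabled)
}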
From be7dc5f74d8fa11a9e7d8f22c488808bd8eb3a66 Mon Sep 17 00:00:00 2001
From: phorne
Date: Mon, 11 Jul 2022 17:07:16 -0400
Subject: [PATCH 04/16] Initial changes to run segmentation as a model search.
---
api/compute/segment.go | 58 ++++++++++
api/compute/solution_request.go | 170 +++++++++++++++++++++++++---
api/compute/split.go | 52 +++++++++
api/compute/task.go | 4 +-
api/model/grouped_variables.go | 35 ++++++
api/task/cleaning.go | 2 +-
api/task/segment.go | 48 +++-----
go.mod | 2 +-
go.sum | 4 +-
public/components/SettingsModal.vue | 26 +++++
public/store/dataset/index.ts | 1 +
11 files changed, 349 insertions(+), 53 deletions(-)
create mode 100644 api/compute/segment.go
diff --git a/api/compute/segment.go b/api/compute/segment.go
new file mode 100644
index 000000000..0e3069feb
--- /dev/null
+++ b/api/compute/segment.go
@@ -0,0 +1,58 @@
+//
+// Copyright © 2021 Uncharted Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package compute
+
+import (
+ "strconv"
+
+ "github.com/pkg/errors"
+
+ "github.com/uncharted-distil/distil/api/util/imagery"
+)
+
+// BuildSegmentationImage uses the raw segmentation output to build an image layer.
+func BuildSegmentationImage(rawSegmentation [][]interface{}) (map[string][]byte, error) {
+ // output is mapping of d3m index to new segmentation layer
+ output := map[string][]byte{}
+ // need to output all the masks as images
+ for _, r := range rawSegmentation[1:] {
+ // create the image that captures the mask
+ d3mIndex := r[0].(string)
+ rawMask := r[1].([]interface{})
+ rawFloats := make([][]float64, len(rawMask))
+ for i, f := range rawMask {
+ dataF := f.([]interface{})
+ nestedFloats := make([]float64, len(dataF))
+ for j, nf := range dataF {
+ fp, err := strconv.ParseFloat(nf.(string), 64)
+ if err != nil {
+ return nil, errors.Wrapf(err, "unable to parse mask")
+ }
+ nestedFloats[j] = fp
+ }
+ rawFloats[i] = nestedFloats
+ }
+
+ filter := imagery.ConfidenceMatrixToImage(rawFloats, imagery.MagmaColorScale, uint8(100))
+ imageBytes, err := imagery.ImageToPNG(filter)
+ if err != nil {
+ return nil, err
+ }
+ output[d3mIndex] = imageBytes
+ }
+
+ return output, nil
+}
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 22dbd009b..9bc3f1e4f 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -18,6 +18,7 @@ package compute
import (
"context"
"fmt"
+ "os"
"path"
"path/filepath"
"strconv"
@@ -33,8 +34,11 @@ import (
"github.com/uncharted-distil/distil-compute/pipeline"
"github.com/uncharted-distil/distil-compute/primitive/compute"
"github.com/uncharted-distil/distil-compute/primitive/compute/description"
+ "github.com/uncharted-distil/distil-compute/primitive/compute/result"
+ "github.com/uncharted-distil/distil/api/env"
api "github.com/uncharted-distil/distil/api/model"
"github.com/uncharted-distil/distil/api/serialization"
+ "github.com/uncharted-distil/distil/api/util"
"github.com/uncharted-distil/distil/api/util/json"
log "github.com/unchartedsoftware/plog"
"google.golang.org/grpc/codes"
@@ -563,7 +567,6 @@ func describeSolution(client *compute.Client, initialSearchSolutionID string) (*
func (s *SolutionRequest) dispatchRequest(client *compute.Client, solutionStorage api.SolutionStorage,
dataStorage api.DataStorage, searchContext pipelineSearchContext) {
-
// update request status
err := s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, compute.RequestRunningStatus)
if err != nil {
@@ -631,6 +634,110 @@ func (s *SolutionRequest) dispatchRequest(client *compute.Client, solutionStorag
s.finished <- nil
}
+func dispatchSegmentation(s *SolutionRequest, solutionStorage api.SolutionStorage, metaStorage api.MetadataStorage,
+ dataStorage api.DataStorage, client *compute.Client, datasetInputDir string, step *description.FullySpecifiedPipeline) {
+ // need a request ID
+ uuid, err := uuid.NewV4()
+ if err != nil {
+ s.finished <- errors.Wrapf(err, "unable to generate request id")
+ return
+ }
+
+ // create the backing data
+ err = s.persistRequestStatus(s.requestChannel, solutionStorage, uuid.String(), s.Dataset, compute.RequestRunningStatus)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+
+ // run the pipeline
+ resultURI, err := SubmitPipeline(client, []string{datasetInputDir}, nil, nil, step, nil, true)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+
+ // update status and respond to client as needed
+
+ // read the file and parse the output mask
+ result, err := result.ParseResultCSV(resultURI)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+
+ images, err := BuildSegmentationImage(result)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+
+ // get the grouping key since it makes up part of the filename
+ dataset, err := metaStorage.FetchDataset(s.Dataset, true, true, false)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+
+ var groupingKey *model.Variable
+ for _, v := range dataset.Variables {
+ if v.HasRole(model.VarDistilRoleGrouping) {
+ groupingKey = v
+ break
+ }
+ }
+ if groupingKey == nil {
+ s.finished <- errors.Errorf("no grouping found to use for output filename")
+ return
+ }
+
+ // get the d3m index -> grouping key mapping
+ mapping, err := api.BuildFieldMapping(dataset.ID, dataset.StorageName, model.D3MIndexFieldName, groupingKey.Key, dataStorage)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+
+ imageOutputFolder := path.Join(env.GetResourcePath(), dataset.ID, "media")
+ for d3mIndex, imageBytes := range images {
+ imageFilename := path.Join(imageOutputFolder, fmt.Sprintf("%s-segmentation.png", mapping[d3mIndex]))
+ err = util.WriteFileWithDirs(imageFilename, imageBytes, os.ModePerm)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+ }
+
+ s.finished <- nil
+}
+
+func processSegmentation(s *SolutionRequest, client *compute.Client, solutionStorage api.SolutionStorage, metaStorage api.MetadataStorage, dataStorage api.DataStorage) error {
+ // create the fully specified pipeline
+ envConfig, err := env.LoadConfig()
+ if err != nil {
+ return err
+ }
+
+ // fetch the source dataset
+ dataset, err := metaStorage.FetchDataset(s.Dataset, true, true, false)
+ if err != nil {
+ return err
+ }
+ s.DatasetMetadata = dataset
+
+ datasetInputDir := env.ResolvePath(dataset.Source, dataset.Folder)
+
+ step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", s.TargetFeature, envConfig.RemoteSensingNumJobs)
+ if err != nil {
+ return err
+ }
+
+ // dispatch it as if it were a model search
+ go dispatchSegmentation(s, solutionStorage, metaStorage, dataStorage, client, datasetInputDir, step)
+
+ return nil
+}
+
// PersistAndDispatch persists the solution request and dispatches it.
func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionStorage api.SolutionStorage, metaStorage api.MetadataStorage, dataStorage api.DataStorage) error {
@@ -706,18 +813,6 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
}
s.Filters = updatedFilters
- // get the target
- datasetInputDir := filteredDatasetPath
- meta, err := serialization.ReadMetadata(path.Join(datasetInputDir, compute.D3MDataSchema))
- if err != nil {
- return err
- }
- metaVars := meta.GetMainDataResource().Variables
- targetVariable, err = findVariable(targetVariable.Key, metaVars)
- if err != nil {
- return err
- }
-
if dataset.LearningDataset != "" {
s.useParquet = true
groupingVariableIndex = -1
@@ -733,6 +828,11 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
return err
}
s.Task = task.Task
+
+ if HasTaskType(task, compute.SegmentationTask) {
+ return processSegmentation(s, client, solutionStorage, metaStorage, dataStorage)
+ }
+
// check if TimestampSplitValue is not 0
if s.TimestampSplitValue > 0 {
found := false
@@ -751,6 +851,24 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
}
}
+ // HACK: SEGMENTATION TASK NEEDS TO ACT ON BASE DATASET!
+ // CURRENTLY SET TO IGNORE PREFILTERING!!
+ if HasTaskType(task, compute.SegmentationTask) {
+ filteredDatasetPath = env.ResolvePath(dataset.Source, dataset.Folder)
+ s.useParquet = false
+ }
+
+ // get the target
+ meta, err := serialization.ReadMetadata(path.Join(filteredDatasetPath, compute.D3MDataSchema))
+ if err != nil {
+ return err
+ }
+ metaVars := meta.GetMainDataResource().Variables
+ targetVariable, err = findVariable(targetVariable.Key, metaVars)
+ if err != nil {
+ return err
+ }
+
// when dealing with categorical data we want to stratify
stratify := model.IsCategorical(s.TargetFeature.Type)
// create the splitter to use for the train / test split
@@ -779,8 +897,30 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
s.Filters = mapFilterKeys(s.Dataset, s.Filters, dataset.Variables)
// generate the pre-processing pipeline to enforce feature selection and semantic type changes
+ // HACK: IF SEGMENTATION, THEN SUBMIT THE FULLY SPECIFIED PIPELINE!!!
var preprocessing *pipeline.PipelineDescription
- if !client.SkipPreprocessing {
+ if HasTaskType(task, compute.SegmentationTask) {
+ envConfig, err := env.LoadConfig()
+ if err != nil {
+ return err
+ }
+
+ ps, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", s.TargetFeature, envConfig.RemoteSensingNumJobs)
+ if err != nil {
+ return err
+ }
+ preprocessing = ps.Pipeline
+
+ // remove the segmentation task
+ tasksUpdated := []string{}
+ for _, t := range s.Task {
+ if t != compute.SegmentationTask {
+ tasksUpdated = append(tasksUpdated, t)
+ }
+ }
+ s.Task = tasksUpdated
+
+ } else if !client.SkipPreprocessing {
if dataset.LearningDataset == "" {
preprocessing, err = s.createPreprocessingPipeline(variables, metaStorage)
} else {
@@ -846,7 +986,7 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
searchID: requestID,
dataset: dataset.ID,
storageName: dataset.StorageName,
- sourceDatasetURI: datasetInputDir,
+ sourceDatasetURI: filteredDatasetPath,
trainDatasetURI: datasetPathTrain,
testDatasetURI: datasetPathTest,
produceDatasetURI: datasetPathTest,
diff --git a/api/compute/split.go b/api/compute/split.go
index 4d90964b8..8e0030ba0 100644
--- a/api/compute/split.go
+++ b/api/compute/split.go
@@ -56,6 +56,10 @@ type basicSplitter struct {
trainTestSplit float64
}
+type copySplitter struct {
+ rowLimits rowLimits
+}
+
type stratifiedSplitter struct {
rowLimits rowLimits
targetCol int
@@ -204,6 +208,50 @@ func (b *basicSplitter) sample(data [][]string, maxRows int) [][]string {
return output
}
+func (c *copySplitter) hash(schemaFile string, params ...interface{}) (uint64, error) {
+ // generate the hash from the params
+ hashStruct := struct {
+ Schema string
+ Copy bool
+ RowLimits rowLimits
+ Params []interface{}
+ }{
+ Schema: schemaFile,
+ Copy: true,
+ RowLimits: c.rowLimits,
+ Params: params,
+ }
+ hash, err := hashstructure.Hash(hashStruct, nil)
+ if err != nil {
+ return 0, errors.Wrap(err, "failed to generate persisted data hash")
+ }
+ return hash, nil
+}
+
+func (c *copySplitter) split(data [][]string) ([][]string, [][]string, error) {
+ log.Infof("splitting data using copy splitter...")
+ // create the output
+ outputTrain := [][]string{}
+ outputTest := [][]string{}
+
+ // handle the header
+ inputData, outputTrain, outputTest := splitTrainTestHeader(data, outputTrain, outputTest, true)
+
+ numTrainingRows := c.rowLimits.trainingRows(len(inputData))
+
+ // sample to meet row limit constraints
+ output := c.sample(inputData, numTrainingRows)
+
+ return append(outputTrain, output...), append(outputTest, output...), nil
+}
+
+func (c *copySplitter) sample(data [][]string, maxRows int) [][]string {
+ output := [][]string{}
+ output, _ = shuffleAndWrite(data[1:], -1, maxRows, 0, false, output, nil, float64(1))
+
+ return output
+}
+
func (s *stratifiedSplitter) hash(schemaFile string, params ...interface{}) (uint64, error) {
// generate the hash from the params
hashStruct := struct {
@@ -496,6 +544,10 @@ func createSplitter(taskType []string, targetFieldIndex int, groupingFieldIndex
trainTestSplit: trainTestSplit,
},
}
+ } else if task == compute.SegmentationTask {
+ return &copySplitter{
+ rowLimits: limits,
+ }
}
}
// if not null
diff --git a/api/compute/task.go b/api/compute/task.go
index db4543124..408f01363 100644
--- a/api/compute/task.go
+++ b/api/compute/task.go
@@ -88,7 +88,7 @@ func ResolveTask(storage api.DataStorage, datasetStorageName string, targetVaria
if model.IsImage(feature.Type) {
task = append(task, compute.ImageTask)
} else if model.IsMultiBandImage(feature.Type) {
- task = append(task, compute.RemoteSensingTask)
+ task = append(task, compute.RemoteSensingTask, compute.SegmentationTask)
} else if model.IsTimeSeries(feature.Type) {
task = append(task, compute.TimeSeriesTask)
}
@@ -110,7 +110,7 @@ func ResolveTask(storage api.DataStorage, datasetStorageName string, targetVaria
task = append(task, compute.SemiSupervisedTask)
}
// If there are 3 labels (2 + empty), update this as a binary classification task
- if len(targetCounts) == 2 {
+ if len(targetCounts) == 3 {
task = append(task, compute.BinaryTask)
} else {
task = append(task, compute.MultiClassTask)
diff --git a/api/model/grouped_variables.go b/api/model/grouped_variables.go
index 9da46392c..70560e63d 100644
--- a/api/model/grouped_variables.go
+++ b/api/model/grouped_variables.go
@@ -18,6 +18,8 @@ package model
import (
"fmt"
+ "github.com/pkg/errors"
+
"github.com/uncharted-distil/distil-compute/model"
log "github.com/unchartedsoftware/plog"
)
@@ -171,6 +173,39 @@ func GetClusterColFromGrouping(group model.BaseGrouping) (string, bool) {
return "", false
}
+// BuildFieldMapping builds a mapping from a source field to a target field.
+func BuildFieldMapping(dsID string, dsStorageName string, sourceFieldName string,
+ targetFieldName string, dataStorage DataStorage) (map[string]string, error) {
+ filter := &FilterParams{Variables: []string{sourceFieldName, targetFieldName}}
+
+ // pull back the source and target field values for all rows
+ data, err := dataStorage.FetchData(dsID, dsStorageName, filter, true, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ // cycle through results to build the field mapping
+ targetFieldColumn, ok := data.Columns[targetFieldName]
+ if !ok {
+ return nil, errors.Errorf("'%s' column not found in stored data", targetFieldName)
+ }
+ targetFieldColumnIndex := targetFieldColumn.Index
+ sourceColumn, ok := data.Columns[sourceFieldName]
+ if !ok {
+ return nil, errors.Errorf("'%s' column not found in stored data", sourceFieldName)
+ }
+ sourceColumnIndex := sourceColumn.Index
+
+ mapping := map[string]string{}
+ for _, r := range data.Values {
+ sourceData := fmt.Sprintf("%.0f", r[sourceColumnIndex].Value.(float64))
+ fieldData := r[targetFieldColumnIndex].Value.(string)
+ mapping[sourceData] = fieldData
+ }
+
+ return mapping, nil
+}
+
// UpdateFilterKey updates the supplied filter key to point to a group-specific column, rather than relying on the group variable
// name.
func UpdateFilterKey(metaStore MetadataStorage, dataset string, dataMode DataMode, filter *model.Filter, variable *model.Variable) {
diff --git a/api/task/cleaning.go b/api/task/cleaning.go
index ec9e86542..7b2ffe5ba 100644
--- a/api/task/cleaning.go
+++ b/api/task/cleaning.go
@@ -59,7 +59,7 @@ func Clean(schemaFile string, dataset string, params *IngestParams, config *Inge
}
// create & submit the solution request
- pip, err := description.CreateDataCleaningPipeline("Mary Poppins", "", vars)
+ pip, err := description.CreateDataCleaningPipeline("Mary Poppins", "", vars, true)
if err != nil {
return "", errors.Wrap(err, "unable to create format pipeline")
}
diff --git a/api/task/segment.go b/api/task/segment.go
index 1d7b8aed4..7ac28e4f6 100644
--- a/api/task/segment.go
+++ b/api/task/segment.go
@@ -1,3 +1,18 @@
+//
+// Copyright © 2021 Uncharted Software Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
package task
import (
@@ -63,7 +78,7 @@ func Segment(ds *api.Dataset, dataStorage api.DataStorage, variableName string)
if groupingKey == nil {
return "", errors.Errorf("no grouping found to use for output filename")
}
- mapping, err := getFieldMapping(ds, groupingKey.Key, dataStorage)
+ mapping, err := api.BuildFieldMapping(ds.ID, ds.StorageName, model.D3MIndexFieldName, groupingKey.Key, dataStorage)
if err != nil {
return "", err
}
@@ -104,34 +119,3 @@ func Segment(ds *api.Dataset, dataStorage api.DataStorage, variableName string)
return "", nil
}
-
-func getFieldMapping(ds *api.Dataset, fieldName string, dataStorage api.DataStorage) (map[string]string, error) {
- filter := &api.FilterParams{Variables: []string{model.D3MIndexFieldName, fieldName}}
-
- // pull back the index and field values for all rows
- data, err := dataStorage.FetchData(ds.ID, ds.StorageName, filter, true, nil)
- if err != nil {
- return nil, err
- }
-
- // cycle through results to build the field mapping
- fieldColumn, ok := data.Columns[fieldName]
- if !ok {
- return nil, errors.Errorf("'%s' column not found in stored data", fieldName)
- }
- fieldColumnIndex := fieldColumn.Index
- d3mColumn, ok := data.Columns[model.D3MIndexFieldName]
- if !ok {
- return nil, errors.Errorf("'%s' column not found in stored data", model.D3MIndexFieldName)
- }
- d3mColumnIndex := d3mColumn.Index
-
- mapping := map[string]string{}
- for _, r := range data.Values {
- d3mIndexData := fmt.Sprintf("%.0f", r[d3mColumnIndex].Value.(float64))
- fieldData := r[fieldColumnIndex].Value.(string)
- mapping[d3mIndexData] = fieldData
- }
-
- return mapping, nil
-}
diff --git a/go.mod b/go.mod
index 5547f6b2a..711222db8 100644
--- a/go.mod
+++ b/go.mod
@@ -34,7 +34,7 @@ require (
github.com/russross/blackfriday v2.0.0+incompatible
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/stretchr/testify v1.6.1
- github.com/uncharted-distil/distil-compute v0.0.0-20211116145504-3a728e358d77
+ github.com/uncharted-distil/distil-compute v0.0.0-20220706150456-b5974c46e396
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a
github.com/unchartedsoftware/plog v0.0.0-20200807135627-83d59e50ced5
diff --git a/go.sum b/go.sum
index df7428ee3..e83463aa4 100644
--- a/go.sum
+++ b/go.sum
@@ -213,8 +213,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/uncharted-distil/distil-compute v0.0.0-20211116145504-3a728e358d77 h1:h/3JZ7rDTSwnJ6MY6w4QyrUZFlqlWstAxoXkG+60Iww=
-github.com/uncharted-distil/distil-compute v0.0.0-20211116145504-3a728e358d77/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
+github.com/uncharted-distil/distil-compute v0.0.0-20220706150456-b5974c46e396 h1:PXewTNQEdzPji0CY6Kld51dOnLJDNJz3Vqv8RLV/S2Q=
+github.com/uncharted-distil/distil-compute v0.0.0-20220706150456-b5974c46e396/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb h1:wDsXsrF8qM34nLeQ9xW+zbEdRNATk5sgOwuwCTrZmvY=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb/go.mod h1:Xhb77n2q8yDvcVS3Mvw0XlpdNMiFsL+vOlvoe556ivc=
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a h1:BPJrlnjdhxMBrJWiU4/Gl3PVdCUlY9JspWFTJ9UVO0Y=
diff --git a/public/components/SettingsModal.vue b/public/components/SettingsModal.vue
index 9b502f76a..387954294 100644
--- a/public/components/SettingsModal.vue
+++ b/public/components/SettingsModal.vue
@@ -131,6 +131,18 @@
+
+
+
+ {{ task }}
+
+
+
@@ -196,6 +208,20 @@ export default Vue.extend({
});
},
+ multipleTasks(): boolean {
+ // hack: only true when both classification and segmentation are possible
+ return (
+ this.task.includes(TaskTypes.REMOTE_SENSING) && this.tasks.length > 1
+ );
+ },
+
+ tasks(): string[] {
+ // hack to only allow for classification and segmentation
+ return this.task
+ .split(",")
+ .filter((t) => t != TaskTypes.REMOTE_SENSING && t != TaskTypes.BINARY);
+ },
+
task(): string {
return routeGetters.getRouteTask(this.$store) ?? "";
},
diff --git a/public/store/dataset/index.ts b/public/store/dataset/index.ts
index c34c06289..61b2352a9 100644
--- a/public/store/dataset/index.ts
+++ b/public/store/dataset/index.ts
@@ -267,6 +267,7 @@ export interface TimeseriesExtrema {
// task string definitions - should mirror those defined in the MIT/LL d3m problem schema
export enum TaskTypes {
CLASSIFICATION = "classification",
+ SEGMENTATION = "segmentation",
REGRESSION = "regression",
CLUSTERING = "clustering",
LINK_PREDICTION = "linkPrediction",
From 5122b29dd09941adb7cbf5f1679603937264e867 Mon Sep 17 00:00:00 2001
From: phorne
Date: Fri, 15 Jul 2022 10:25:58 -0400
Subject: [PATCH 05/16] Added fitted pipeline id to the output of a fully
specified pipeline execution.
---
api/compute/filter.go | 2 +-
api/compute/pipeline.go | 30 ++++--
api/compute/solution_request.go | 167 ++++++++++++++++++++++++--------
api/task/pipelines.go | 6 +-
go.mod | 2 +-
go.sum | 4 +-
6 files changed, 156 insertions(+), 55 deletions(-)
diff --git a/api/compute/filter.go b/api/compute/filter.go
index 9f39c5dc5..a608095ad 100644
--- a/api/compute/filter.go
+++ b/api/compute/filter.go
@@ -85,7 +85,7 @@ func filterData(client *compute.Client, ds *api.Dataset, filterParams *api.Filte
// output the filtered results as the data in the filtered dataset
_, outputDataFile := getPreFilteringOutputDataFile(outputFolder)
- err = util.CopyFile(filteredData, outputDataFile)
+ err = util.CopyFile(filteredData.ResultURI, outputDataFile)
if err != nil {
return "", nil, err
}
diff --git a/api/compute/pipeline.go b/api/compute/pipeline.go
index 1be073ca9..0140635bb 100644
--- a/api/compute/pipeline.go
+++ b/api/compute/pipeline.go
@@ -115,10 +115,16 @@ type QueueItem struct {
// QueueResponse represents the result from processing a queue item.
type QueueResponse struct {
- Output interface{}
+ Output *PipelineOutput
Error error
}
+// PipelineOutput represents an output from executing a queued pipeline.
+type PipelineOutput struct {
+ ResultURI string
+ FittedSolutionID string
+}
+
// Queue uses a buffered channel to queue tasks and provides the result via channels.
type Queue struct {
mu sync.RWMutex
@@ -234,7 +240,7 @@ func InitializeQueue(config *env.Config) {
// SubmitPipeline executes pipelines using the client and returns the result URI.
func SubmitPipeline(client *compute.Client, datasets []string, datasetsProduce []string, searchRequest *pipeline.SearchSolutionsRequest,
- fullySpecifiedStep *description.FullySpecifiedPipeline, allowedValueTypes []string, shouldCache bool) (string, error) {
+ fullySpecifiedStep *description.FullySpecifiedPipeline, allowedValueTypes []string, shouldCache bool) (*PipelineOutput, error) {
request := compute.NewExecPipelineRequest(datasets, datasetsProduce, fullySpecifiedStep.Pipeline)
@@ -254,12 +260,12 @@ func SubmitPipeline(client *compute.Client, datasets []string, datasetsProduce [
if cache.readEnabled {
if shouldCache {
if err != nil {
- return "", err
+ return nil, err
}
entry, found := cache.cache.Get(hashedPipelineUniqueKey)
if found {
log.Infof("returning cached entry for pipeline")
- return entry.(string), nil
+ return entry.(*PipelineOutput), nil
}
}
} else {
@@ -268,7 +274,7 @@ func SubmitPipeline(client *compute.Client, datasets []string, datasetsProduce [
// get equivalency key for enqueuing
hashedPipelineEquivKey, err := queueTask.hashEquivalent()
if err != nil {
- return "", err
+ return nil, err
}
resultChan := queue.Enqueue(hashedPipelineEquivKey, queueTask)
@@ -276,17 +282,16 @@ func SubmitPipeline(client *compute.Client, datasets []string, datasetsProduce [
result := <-resultChan
if result.Error != nil {
- return "", result.Error
+ return nil, result.Error
}
- datasetURI := result.Output.(string)
- cache.cache.Set(hashedPipelineUniqueKey, datasetURI, gc.DefaultExpiration)
+ cache.cache.Set(hashedPipelineUniqueKey, result.Output, gc.DefaultExpiration)
err = cache.PersistCache()
if err != nil {
log.Warnf("error persisting cache: %v", err)
}
- return datasetURI, nil
+ return result.Output, nil
}
func runPipelineQueue(queue *Queue) {
@@ -316,6 +321,7 @@ func runPipelineQueue(queue *Queue) {
// listen for completion
var errPipeline error
var datasetURI string
+ var fittedSolutionID string
err = pipelineTask.request.Listen(func(status compute.ExecPipelineStatus) {
// check for error
if status.Error != nil {
@@ -324,6 +330,7 @@ func runPipelineQueue(queue *Queue) {
if status.Progress == compute.RequestCompletedStatus {
datasetURI = status.ResultURI
+ fittedSolutionID = status.FittedSolutionID
}
})
if err != nil {
@@ -342,7 +349,10 @@ func runPipelineQueue(queue *Queue) {
datasetURI = strings.Replace(datasetURI, "file://", "", -1)
- queueTask.returnResult(&QueueResponse{Output: datasetURI})
+ queueTask.returnResult(&QueueResponse{&PipelineOutput{
+ ResultURI: datasetURI,
+ FittedSolutionID: fittedSolutionID,
+ }, nil})
}
log.Infof("ending queue processing")
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 9bc3f1e4f..472e48deb 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -634,33 +634,64 @@ func (s *SolutionRequest) dispatchRequest(client *compute.Client, solutionStorag
s.finished <- nil
}
-func dispatchSegmentation(s *SolutionRequest, solutionStorage api.SolutionStorage, metaStorage api.MetadataStorage,
+func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage api.SolutionStorage, metaStorage api.MetadataStorage,
dataStorage api.DataStorage, client *compute.Client, datasetInputDir string, step *description.FullySpecifiedPipeline) {
- // need a request ID
- uuid, err := uuid.NewV4()
+ log.Infof("dispatching segmentation pipeline")
+
+ // create the backing data
+ err := s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, s.Dataset, compute.RequestRunningStatus)
if err != nil {
- s.finished <- errors.Wrapf(err, "unable to generate request id")
+ s.finished <- err
return
}
- // create the backing data
- err = s.persistRequestStatus(s.requestChannel, solutionStorage, uuid.String(), s.Dataset, compute.RequestRunningStatus)
+ c := newStatusChannel()
+ // add the solution to the request
+ uuidGen, err := uuid.NewV4()
if err != nil {
- s.finished <- err
+ s.finished <- errors.Wrapf(err, "unable to generate solution id")
return
}
+ solutionID := uuidGen.String()
+ s.addSolution(c)
+ s.persistSolution(c, solutionStorage, requestID, solutionID, "")
+ s.persistSolutionStatus(c, solutionStorage, requestID, solutionID, compute.SolutionPendingStatus)
// run the pipeline
- resultURI, err := SubmitPipeline(client, []string{datasetInputDir}, nil, nil, step, nil, true)
+ pipelineResult, err := SubmitPipeline(client, []string{datasetInputDir}, nil, nil, step, nil, true)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+ s.persistSolutionStatus(c, solutionStorage, requestID, solutionID, compute.SolutionScoringStatus)
+
+ // HACK: MAKE UP A SOLUTION SCORE!!!
+ err = solutionStorage.PersistSolutionScore(solutionID, util.F1Micro, 0.5)
if err != nil {
s.finished <- err
return
}
+ s.persistSolutionStatus(c, solutionStorage, requestID, solutionID, compute.SolutionProducingStatus)
// update status and respond to client as needed
+ uuidGen, err = uuid.NewV4()
+ if err != nil {
+ s.finished <- errors.Wrapf(err, "unable to generate solution id")
+ return
+ }
+ resultID := uuidGen.String()
+ c <- SolutionStatus{
+ RequestID: requestID,
+ SolutionID: solutionID,
+ ResultID: resultID,
+ Progress: compute.SolutionCompletedStatus,
+ Timestamp: time.Now(),
+ }
+ close(c)
// read the file and parse the output mask
- result, err := result.ParseResultCSV(resultURI)
+ log.Infof("processing segmentation pipeline output")
+ result, err := result.ParseResultCSV(pipelineResult.ResultURI)
if err != nil {
s.finished <- err
return
@@ -708,6 +739,45 @@ func dispatchSegmentation(s *SolutionRequest, solutionStorage api.SolutionStorag
}
}
+ // HACK: INPUT FAKE RESULTS TO THE DB!!!
+ // FAKE RESULTS SHOULD JUST BE A CONSTANT!
+ uuidGen, err = uuid.NewV4()
+ if err != nil {
+ s.finished <- errors.Wrapf(err, "unable to generate produce request id")
+ return
+ }
+ produceRequestID := uuidGen.String()
+
+ // HACK: CREATE FAKE RESULTS TO PERSIST AS THE ACTUAL RESULTS SHOULD NOT BE STORED IN THE DB!!!
+ resultOutput := []string{fmt.Sprintf("%s,%s,%s", model.D3MIndexFieldName, s.TargetFeature.HeaderName, "confidence")}
+ for i := 1; i < len(result); i++ {
+ resultOutput = append(resultOutput, fmt.Sprintf("%s,%s,%d", result[i][0].(string), "segmented", 1))
+ }
+ resultOutputURI := fmt.Sprintf("%s-distil-%s",
+ pipelineResult.ResultURI[:len(pipelineResult.ResultURI)-4], pipelineResult.ResultURI[len(pipelineResult.ResultURI)-4:])
+ log.Infof("writing distil formatted segmentation results to '%s'", resultOutputURI)
+ err = util.WriteFileWithDirs(resultOutputURI, []byte(strings.Join(resultOutput, "\n")), os.ModePerm)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+
+ log.Infof("persisting results in URI '%s'", resultOutputURI)
+ err = s.persistSolutionResults(c, client, solutionStorage, dataStorage, requestID,
+ dataset.ID, dataset.StorageName, solutionID, pipelineResult.FittedSolutionID, produceRequestID, resultID, resultOutputURI)
+ if err != nil {
+ s.finished <- errors.Wrapf(err, "unable to persist solution result")
+ return
+ }
+
+ log.Infof("segmentation pipeline processing complete")
+
+ err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, compute.RequestCompletedStatus)
+ if err != nil {
+ s.finished <- err
+ return
+ }
+ close(s.requestChannel)
s.finished <- nil
}
@@ -724,6 +794,7 @@ func processSegmentation(s *SolutionRequest, client *compute.Client, solutionSto
return nil
}
s.DatasetMetadata = dataset
+ variablesMap := api.MapVariables(dataset.Variables, func(v *model.Variable) string { return v.Key })
datasetInputDir := env.ResolvePath(dataset.Source, dataset.Folder)
@@ -732,8 +803,53 @@ func processSegmentation(s *SolutionRequest, client *compute.Client, solutionSto
return err
}
+ // need a request ID
+ uuidGen, err := uuid.NewV4()
+ if err != nil {
+ return err
+ }
+ requestID := uuidGen.String()
+
+ // persist the request
+ err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, compute.RequestPendingStatus)
+ if err != nil {
+ return err
+ }
+
+ // store the request features - note that we are storing the original request filters, not the expanded
+ // list that was generated
+ // also note that augmented features should not be included
+ for _, v := range s.Filters.Variables {
+ var typ string
+ // ignore the index field
+ if v == model.D3MIndexFieldName {
+ continue
+ } else if variablesMap[v].HasRole(model.VarDistilRoleAugmented) {
+ continue
+ }
+
+ if v == s.TargetFeature.Key {
+ // store target feature
+ typ = model.FeatureTypeTarget
+ } else {
+ // store training feature
+ typ = model.FeatureTypeTrain
+ }
+ err = solutionStorage.PersistRequestFeature(requestID, v, typ)
+ if err != nil {
+ return err
+ }
+ }
+
+ // store the original request filters
+ // HACK: NO FILTERS SUPPORTED FOR SEGMENTATION!
+ err = solutionStorage.PersistRequestFilters(requestID, s.Filters)
+ if err != nil {
+ return err
+ }
+
// dispatch it as if it were a model search
- go dispatchSegmentation(s, solutionStorage, metaStorage, dataStorage, client, datasetInputDir, step)
+ go dispatchSegmentation(s, requestID, solutionStorage, metaStorage, dataStorage, client, datasetInputDir, step)
return nil
}
@@ -851,13 +967,6 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
}
}
- // HACK: SEGMENTATION TASK NEEDS TO ACT ON BASE DATASET!
- // CURRENTLY SET TO IGNORE PREFILTERING!!
- if HasTaskType(task, compute.SegmentationTask) {
- filteredDatasetPath = env.ResolvePath(dataset.Source, dataset.Folder)
- s.useParquet = false
- }
-
// get the target
meta, err := serialization.ReadMetadata(path.Join(filteredDatasetPath, compute.D3MDataSchema))
if err != nil {
@@ -897,30 +1006,8 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
s.Filters = mapFilterKeys(s.Dataset, s.Filters, dataset.Variables)
// generate the pre-processing pipeline to enforce feature selection and semantic type changes
- // HACK: IF SEGMENTATION, THEN SUBMIT THE FULLY SPECIFIED PIPELINE!!!
var preprocessing *pipeline.PipelineDescription
- if HasTaskType(task, compute.SegmentationTask) {
- envConfig, err := env.LoadConfig()
- if err != nil {
- return err
- }
-
- ps, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", s.TargetFeature, envConfig.RemoteSensingNumJobs)
- if err != nil {
- return err
- }
- preprocessing = ps.Pipeline
-
- // remove the segmentation task
- tasksUpdated := []string{}
- for _, t := range s.Task {
- if t != compute.SegmentationTask {
- tasksUpdated = append(tasksUpdated, t)
- }
- }
- s.Task = tasksUpdated
-
- } else if !client.SkipPreprocessing {
+ if !client.SkipPreprocessing {
if dataset.LearningDataset == "" {
preprocessing, err = s.createPreprocessingPipeline(variables, metaStorage)
} else {
diff --git a/api/task/pipelines.go b/api/task/pipelines.go
index 51fe3f092..0e56498d8 100644
--- a/api/task/pipelines.go
+++ b/api/task/pipelines.go
@@ -62,7 +62,11 @@ func SetClient(computeClient *compute.Client) {
}
func submitPipeline(datasets []string, step *description.FullySpecifiedPipeline, shouldCache bool) (string, error) {
- return sr.SubmitPipeline(client, datasets, nil, nil, step, nil, shouldCache)
+ result, err := sr.SubmitPipeline(client, datasets, nil, nil, step, nil, shouldCache)
+ if err != nil {
+ return "", err
+ }
+ return result.ResultURI, nil
}
func getD3MIndexField(dr *model.DataResource) int {
diff --git a/go.mod b/go.mod
index 711222db8..bc36dac11 100644
--- a/go.mod
+++ b/go.mod
@@ -34,7 +34,7 @@ require (
github.com/russross/blackfriday v2.0.0+incompatible
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/stretchr/testify v1.6.1
- github.com/uncharted-distil/distil-compute v0.0.0-20220706150456-b5974c46e396
+ github.com/uncharted-distil/distil-compute v0.0.0-20220714184701-da71999368f3
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a
github.com/unchartedsoftware/plog v0.0.0-20200807135627-83d59e50ced5
diff --git a/go.sum b/go.sum
index e83463aa4..1e8b7da11 100644
--- a/go.sum
+++ b/go.sum
@@ -213,8 +213,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/uncharted-distil/distil-compute v0.0.0-20220706150456-b5974c46e396 h1:PXewTNQEdzPji0CY6Kld51dOnLJDNJz3Vqv8RLV/S2Q=
-github.com/uncharted-distil/distil-compute v0.0.0-20220706150456-b5974c46e396/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
+github.com/uncharted-distil/distil-compute v0.0.0-20220714184701-da71999368f3 h1:U09LSIsskN8fE87e2XbKgWlu7/3Bj2plhRw8Df7/NrE=
+github.com/uncharted-distil/distil-compute v0.0.0-20220714184701-da71999368f3/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb h1:wDsXsrF8qM34nLeQ9xW+zbEdRNATk5sgOwuwCTrZmvY=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb/go.mod h1:Xhb77n2q8yDvcVS3Mvw0XlpdNMiFsL+vOlvoe556ivc=
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a h1:BPJrlnjdhxMBrJWiU4/Gl3PVdCUlY9JspWFTJ9UVO0Y=
From 2b4aad3789569c7240f9559ae25fb9c38bfece9e Mon Sep 17 00:00:00 2001
From: phorne
Date: Fri, 15 Jul 2022 14:05:25 -0400
Subject: [PATCH 06/16] Solution id for fully specified pipelines now captured
as part of the execution.
---
api/compute/pipeline.go | 4 ++++
api/compute/solution_request.go | 30 +++++++++++++-----------------
go.mod | 2 +-
go.sum | 4 ++--
4 files changed, 20 insertions(+), 20 deletions(-)
diff --git a/api/compute/pipeline.go b/api/compute/pipeline.go
index 0140635bb..1682d9138 100644
--- a/api/compute/pipeline.go
+++ b/api/compute/pipeline.go
@@ -121,6 +121,7 @@ type QueueResponse struct {
// PipelineOutput represents an output from executing a queued pipeline.
type PipelineOutput struct {
+ SolutionID string
ResultURI string
FittedSolutionID string
}
@@ -322,6 +323,7 @@ func runPipelineQueue(queue *Queue) {
var errPipeline error
var datasetURI string
var fittedSolutionID string
+ var solutionID string
err = pipelineTask.request.Listen(func(status compute.ExecPipelineStatus) {
// check for error
if status.Error != nil {
@@ -331,6 +333,7 @@ func runPipelineQueue(queue *Queue) {
if status.Progress == compute.RequestCompletedStatus {
datasetURI = status.ResultURI
fittedSolutionID = status.FittedSolutionID
+ solutionID = status.SolutionID
}
})
if err != nil {
@@ -352,6 +355,7 @@ func runPipelineQueue(queue *Queue) {
queueTask.returnResult(&QueueResponse{&PipelineOutput{
ResultURI: datasetURI,
FittedSolutionID: fittedSolutionID,
+ SolutionID: solutionID,
}, nil})
}
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 472e48deb..227251614 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -646,16 +646,6 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
}
c := newStatusChannel()
- // add the solution to the request
- uuidGen, err := uuid.NewV4()
- if err != nil {
- s.finished <- errors.Wrapf(err, "unable to generate solution id")
- return
- }
- solutionID := uuidGen.String()
- s.addSolution(c)
- s.persistSolution(c, solutionStorage, requestID, solutionID, "")
- s.persistSolutionStatus(c, solutionStorage, requestID, solutionID, compute.SolutionPendingStatus)
// run the pipeline
pipelineResult, err := SubmitPipeline(client, []string{datasetInputDir}, nil, nil, step, nil, true)
@@ -663,18 +653,24 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
s.finished <- err
return
}
- s.persistSolutionStatus(c, solutionStorage, requestID, solutionID, compute.SolutionScoringStatus)
+
+ // add the solution to the request
+ // doing this after submission to have the solution id available!
+ s.addSolution(c)
+ s.persistSolution(c, solutionStorage, requestID, pipelineResult.SolutionID, "")
+ s.persistSolutionStatus(c, solutionStorage, requestID, pipelineResult.SolutionID, compute.SolutionPendingStatus)
+ s.persistSolutionStatus(c, solutionStorage, requestID, pipelineResult.SolutionID, compute.SolutionScoringStatus)
// HACK: MAKE UP A SOLUTION SCORE!!!
- err = solutionStorage.PersistSolutionScore(solutionID, util.F1Micro, 0.5)
+ err = solutionStorage.PersistSolutionScore(pipelineResult.SolutionID, util.F1Micro, 0.5)
if err != nil {
s.finished <- err
return
}
- s.persistSolutionStatus(c, solutionStorage, requestID, solutionID, compute.SolutionProducingStatus)
+ s.persistSolutionStatus(c, solutionStorage, requestID, pipelineResult.SolutionID, compute.SolutionProducingStatus)
// update status and respond to client as needed
- uuidGen, err = uuid.NewV4()
+ uuidGen, err := uuid.NewV4()
if err != nil {
s.finished <- errors.Wrapf(err, "unable to generate solution id")
return
@@ -682,7 +678,7 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
resultID := uuidGen.String()
c <- SolutionStatus{
RequestID: requestID,
- SolutionID: solutionID,
+ SolutionID: pipelineResult.SolutionID,
ResultID: resultID,
Progress: compute.SolutionCompletedStatus,
Timestamp: time.Now(),
@@ -763,8 +759,8 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
}
log.Infof("persisting results in URI '%s'", resultOutputURI)
- err = s.persistSolutionResults(c, client, solutionStorage, dataStorage, requestID,
- dataset.ID, dataset.StorageName, solutionID, pipelineResult.FittedSolutionID, produceRequestID, resultID, resultOutputURI)
+ err = s.persistSolutionResults(c, client, solutionStorage, dataStorage, requestID, dataset.ID,
+ dataset.StorageName, pipelineResult.SolutionID, pipelineResult.FittedSolutionID, produceRequestID, resultID, resultOutputURI)
if err != nil {
s.finished <- errors.Wrapf(err, "unable to persist solution result")
return
diff --git a/go.mod b/go.mod
index bc36dac11..2327dd0b2 100644
--- a/go.mod
+++ b/go.mod
@@ -34,7 +34,7 @@ require (
github.com/russross/blackfriday v2.0.0+incompatible
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/stretchr/testify v1.6.1
- github.com/uncharted-distil/distil-compute v0.0.0-20220714184701-da71999368f3
+ github.com/uncharted-distil/distil-compute v0.0.0-20220715171604-26f9f01bab93
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a
github.com/unchartedsoftware/plog v0.0.0-20200807135627-83d59e50ced5
diff --git a/go.sum b/go.sum
index 1e8b7da11..1f4b7d3fc 100644
--- a/go.sum
+++ b/go.sum
@@ -213,8 +213,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/uncharted-distil/distil-compute v0.0.0-20220714184701-da71999368f3 h1:U09LSIsskN8fE87e2XbKgWlu7/3Bj2plhRw8Df7/NrE=
-github.com/uncharted-distil/distil-compute v0.0.0-20220714184701-da71999368f3/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
+github.com/uncharted-distil/distil-compute v0.0.0-20220715171604-26f9f01bab93 h1:UNSU3FX3h4k8wrzzXWLtX2kl4bb2AW7BqoV2FkQigRs=
+github.com/uncharted-distil/distil-compute v0.0.0-20220715171604-26f9f01bab93/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb h1:wDsXsrF8qM34nLeQ9xW+zbEdRNATk5sgOwuwCTrZmvY=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb/go.mod h1:Xhb77n2q8yDvcVS3Mvw0XlpdNMiFsL+vOlvoe556ivc=
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a h1:BPJrlnjdhxMBrJWiU4/Gl3PVdCUlY9JspWFTJ9UVO0Y=
From efef6d999539d34fdffd2b95fcc3cbdbe780d921 Mon Sep 17 00:00:00 2001
From: phorne
Date: Tue, 19 Jul 2022 15:12:20 -0400
Subject: [PATCH 07/16] Segmentation predictions updated to output the same
data as the model search.
---
api/compute/pipeline.go | 3 ++
api/compute/search.go | 2 +-
api/compute/solution_request.go | 51 ++++++++++++++++++++++++++-------
api/task/prediction.go | 4 ++-
api/ws/pipeline.go | 11 +++++++
5 files changed, 59 insertions(+), 12 deletions(-)
diff --git a/api/compute/pipeline.go b/api/compute/pipeline.go
index 1682d9138..fbd979de7 100644
--- a/api/compute/pipeline.go
+++ b/api/compute/pipeline.go
@@ -201,6 +201,9 @@ func (q *Queue) Done() {
// InitializeCache sets up an empty cache or if a source file provided, reads
// the cache from the source file.
func InitializeCache(sourceFile string, readEnabled bool) error {
+ // register the output type for the cache!
+ gob.Register(&PipelineOutput{})
+
var c *gc.Cache
if util.FileExists(sourceFile) {
b, err := ioutil.ReadFile(sourceFile)
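The gob.Register call above is needed because cache entries sit behind interface{} values, and gob refuses to encode an unregistered concrete type when the cache is persisted. A self-contained sketch of the failure mode (types copied locally for illustration):

    package main

    import (
        "bytes"
        "encoding/gob"
        "fmt"
    )

    type PipelineOutput struct {
        SolutionID       string
        ResultURI        string
        FittedSolutionID string
    }

    func main() {
        // cache items hold arbitrary values, so gob sees an interface type
        items := map[string]interface{}{
            "pipeline-key": &PipelineOutput{ResultURI: "/tmp/result.csv"},
        }

        var buf bytes.Buffer
        // fails: "gob: type not registered for interface"
        fmt.Println(gob.NewEncoder(&buf).Encode(items))

        // registering the concrete type lets the cache round-trip it
        gob.Register(&PipelineOutput{})
        buf.Reset()
        fmt.Println(gob.NewEncoder(&buf).Encode(items)) // <nil>
    }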
diff --git a/api/compute/search.go b/api/compute/search.go
index 2d4953d0f..8638f7c29 100644
--- a/api/compute/search.go
+++ b/api/compute/search.go
@@ -278,7 +278,7 @@ func (s *SolutionRequest) dispatchSolutionSearchPipeline(statusChan chan Solutio
if ok {
// reformat result to have one row per d3m index since confidences
// can produce one row / class
- resultURI, err = reformatResult(resultURI)
+ resultURI, err = reformatResult(resultURI, s.TargetFeature.HeaderName, &Task{s.Task})
if err != nil {
return nil, err
}
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 227251614..709461003 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -327,7 +327,7 @@ func (s *SolutionRequest) createPreprocessingPipeline(featureVariables []*model.
}
// GeneratePredictions produces predictions using the specified fitted solution.
-func GeneratePredictions(datasetURI string, solutionID string, fittedSolutionID string, client *compute.Client) (*PredictionResult, error) {
+func GeneratePredictions(datasetURI string, solutionID string, fittedSolutionID string, task *Task, targetName string, client *compute.Client) (*PredictionResult, error) {
// check if the solution can be explained
desc, err := client.GetSolutionDescription(context.Background(), solutionID)
if err != nil {
@@ -359,7 +359,7 @@ func GeneratePredictions(datasetURI string, solutionID string, fittedSolutionID
if err != nil {
return nil, err
}
- resultURI, err = reformatResult(resultURI)
+ resultURI, err = reformatResult(resultURI, targetName, task)
if err != nil {
return nil, err
}
@@ -745,14 +745,13 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
produceRequestID := uuidGen.String()
// HACK: CREATE FAKE RESULTS TO PERSIST AS THE ACTUAL RESULTS SHOULD NOT BE STORED IN THE DB!!!
- resultOutput := []string{fmt.Sprintf("%s,%s,%s", model.D3MIndexFieldName, s.TargetFeature.HeaderName, "confidence")}
- for i := 1; i < len(result); i++ {
- resultOutput = append(resultOutput, fmt.Sprintf("%s,%s,%d", result[i][0].(string), "segmented", 1))
+ dataReader := serialization.GetStorage(pipelineResult.ResultURI)
+ dataResult, err := dataReader.ReadData(pipelineResult.ResultURI)
+ if err != nil {
+ s.finished <- err
+ return
}
- resultOutputURI := fmt.Sprintf("%s-distil-%s",
- pipelineResult.ResultURI[:len(pipelineResult.ResultURI)-4], pipelineResult.ResultURI[len(pipelineResult.ResultURI)-4:])
- log.Infof("writing distil formatted segmentation results to '%s'", resultOutputURI)
- err = util.WriteFileWithDirs(resultOutputURI, []byte(strings.Join(resultOutput, "\n")), os.ModePerm)
+ resultOutputURI, err := createSegmentationResult(pipelineResult.ResultURI, s.TargetFeature.HeaderName, dataResult)
if err != nil {
s.finished <- err
return
@@ -1135,7 +1134,7 @@ type confidenceValue struct {
row int
}
-func reformatResult(resultURI string) (string, error) {
+func reformatResult(resultURI string, targetName string, task *Task) (string, error) {
// read data from original file
dataReader := serialization.GetStorage(resultURI)
data, err := dataReader.ReadData(resultURI)
@@ -1143,6 +1142,11 @@ func reformatResult(resultURI string) (string, error) {
return "", err
}
+ // segmentation results need to be reduced to tagging segmented images
+ if HasTaskType(task, compute.SegmentationTask) && isSegmentationOutput(resultURI) {
+ return createSegmentationResult(resultURI, targetName, data)
+ }
+
// only need to reformat if confidences are there (column count >= 3)
if len(data[0]) < 3 {
return resultURI, nil
@@ -1191,3 +1195,30 @@ func reformatResult(resultURI string) (string, error) {
return filteredURI, nil
}
+
+// isSegmentationOutput checks if a result is from an image segmentation pipeline.
+// NOTE: returns false if it cannot confirm the result is segmentation output (e.g., the CSV fails to parse)!
+func isSegmentationOutput(resultURI string) bool {
+ result, err := result.ParseResultCSV(resultURI)
+ if err != nil {
+ return false
+ }
+
+ // segmentation output has the header "d3mIndex,positive_mask"
+ return len(result[0]) == 2 && result[0][0].(string) == model.D3MIndexFieldName && result[0][1].(string) == "positive_mask"
+}
+
+func createSegmentationResult(resultURI string, targetName string, result [][]string) (string, error) {
+ resultOutput := []string{fmt.Sprintf("%s,%s,%s", model.D3MIndexFieldName, targetName, "confidence")}
+ for i := 1; i < len(result); i++ {
+ resultOutput = append(resultOutput, fmt.Sprintf("%s,%s,%d", result[i][0], "segmented", 1))
+ }
+ resultOutputURI := fmt.Sprintf("%s-distil-%s", resultURI[:len(resultURI)-4], resultURI[len(resultURI)-4:])
+ log.Infof("writing distil formatted segmentation results to '%s'", resultOutputURI)
+ err := util.WriteFileWithDirs(resultOutputURI, []byte(strings.Join(resultOutput, "\n")), os.ModePerm)
+ if err != nil {
+ return "", err
+ }
+
+ return resultOutputURI, nil
+}
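For reference, the reshaping performed above, with hypothetical values ("<target>" stands in for s.TargetFeature.HeaderName):

    input (pipeline output)      output (persisted result)
    d3mIndex,positive_mask       d3mIndex,<target>,confidence
    12,<encoded mask>       ->   12,segmented,1
    13,<encoded mask>            13,segmented,1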
diff --git a/api/task/prediction.go b/api/task/prediction.go
index 4813dd73b..4c2d7f5f8 100644
--- a/api/task/prediction.go
+++ b/api/task/prediction.go
@@ -206,6 +206,7 @@ type PredictParams struct {
DatasetConstructor DatasetConstructor
OutputPath string
IndexFields []string
+ Task *comp.Task
Target *model.Variable
MetaStorage api.MetadataStorage
DataStorage api.DataStorage
@@ -708,7 +709,8 @@ func Predict(params *PredictParams) (string, error) {
// submit the new dataset for predictions
log.Infof("generating predictions using data found at '%s'", params.SchemaPath)
- predictionResult, err := comp.GeneratePredictions(params.SchemaPath, solution.SolutionID, params.FittedSolutionID, client)
+ predictionResult, err := comp.GeneratePredictions(params.SchemaPath,
+ solution.SolutionID, params.FittedSolutionID, params.Task, params.Target.HeaderName, client)
if err != nil {
return "", err
}
diff --git a/api/ws/pipeline.go b/api/ws/pipeline.go
index 333b833f6..59723a689 100644
--- a/api/ws/pipeline.go
+++ b/api/ws/pipeline.go
@@ -430,6 +430,7 @@ func handlePredict(conn *Connection, client *compute.Client, metadataCtor apiMod
SolutionID: sr.SolutionID,
FittedSolutionID: request.FittedSolutionID,
OutputPath: path.Join(config.D3MOutputDir, config.AugmentedSubFolder),
+ Task: requestTask,
Target: targetVar,
MetaStorage: metaStorage,
DataStorage: dataStorage,
@@ -445,6 +446,16 @@ func handlePredict(conn *Connection, client *compute.Client, metadataCtor apiMod
handleErr(conn, msg, errors.Wrap(err, "unable to create raw dataset"))
return
}
+
+ // if the task is a segmentation task, run it against the base dataset
+ if api.HasTaskType(requestTask, compute.SegmentationTask) {
+ dsPred, err := metaStorage.FetchDataset(datasetName, true, true, true)
+ if err != nil {
+ handleErr(conn, msg, errors.Wrap(err, "unable to resolve prediction dataset"))
+ return
+ }
+ datasetPath = path.Join(env.ResolvePath(dsPred.Source, dsPred.Folder), compute.D3MDataSchema)
+ }
predictParams.Dataset = datasetName
predictParams.SchemaPath = datasetPath
From a4518475d92b09e61fea3e98736971a693471977 Mon Sep 17 00:00:00 2001
From: phorne
Date: Wed, 20 Jul 2022 12:45:04 -0400
Subject: [PATCH 08/16] Segmentation predictions now write the segmentation
layer to disk.
---
api/compute/search.go | 2 +-
api/compute/solution_request.go | 121 ++++++++++++++++----------------
api/task/prediction.go | 4 +-
3 files changed, 63 insertions(+), 64 deletions(-)
diff --git a/api/compute/search.go b/api/compute/search.go
index 8638f7c29..a2ae05d5e 100644
--- a/api/compute/search.go
+++ b/api/compute/search.go
@@ -278,7 +278,7 @@ func (s *SolutionRequest) dispatchSolutionSearchPipeline(statusChan chan Solutio
if ok {
// reformat result to have one row per d3m index since confidences
// can produce one row / class
- resultURI, err = reformatResult(resultURI, s.TargetFeature.HeaderName, &Task{s.Task})
+ resultURI, err = reformatResult(resultURI, s.TargetFeature.HeaderName)
if err != nil {
return nil, err
}
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 709461003..39a157e65 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -327,7 +327,8 @@ func (s *SolutionRequest) createPreprocessingPipeline(featureVariables []*model.
}
// GeneratePredictions produces predictions using the specified fitted solution.
-func GeneratePredictions(datasetURI string, solutionID string, fittedSolutionID string, task *Task, targetName string, client *compute.Client) (*PredictionResult, error) {
+func GeneratePredictions(datasetID string, datasetURI string, solutionID string, fittedSolutionID string, task *Task,
+ targetName string, metaStorage api.MetadataStorage, dataStorage api.DataStorage, client *compute.Client) (*PredictionResult, error) {
// check if the solution can be explained
desc, err := client.GetSolutionDescription(context.Background(), solutionID)
if err != nil {
@@ -359,7 +360,14 @@ func GeneratePredictions(datasetURI string, solutionID string, fittedSolutionID
if err != nil {
return nil, err
}
- resultURI, err = reformatResult(resultURI, targetName, task)
+
+ // segmentation results need to be reduced to tagging segmented images
+ if HasTaskType(task, compute.SegmentationTask) && isSegmentationOutput(resultURI) {
+ resultURI, err = createSegmentationResult(datasetID, resultURI, targetName, metaStorage, dataStorage)
+ } else {
+ resultURI, err = reformatResult(resultURI, targetName)
+ }
+
if err != nil {
return nil, err
}
@@ -685,56 +693,14 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
}
close(c)
- // read the file and parse the output mask
- log.Infof("processing segmentation pipeline output")
- result, err := result.ParseResultCSV(pipelineResult.ResultURI)
- if err != nil {
- s.finished <- err
- return
- }
-
- images, err := BuildSegmentationImage(result)
- if err != nil {
- s.finished <- err
- return
- }
-
// get the grouping key since it makes up part of the filename
+ log.Infof("processing segmentation pipeline output")
dataset, err := metaStorage.FetchDataset(s.Dataset, true, true, false)
if err != nil {
s.finished <- err
return
}
- var groupingKey *model.Variable
- for _, v := range dataset.Variables {
- if v.HasRole(model.VarDistilRoleGrouping) {
- groupingKey = v
- break
- }
- }
- if groupingKey == nil {
- s.finished <- errors.Errorf("no grouping found to use for output filename")
- return
- }
-
- // get the d3m index -> grouping key mapping
- mapping, err := api.BuildFieldMapping(dataset.ID, dataset.StorageName, model.D3MIndexFieldName, groupingKey.Key, dataStorage)
- if err != nil {
- s.finished <- err
- return
- }
-
- imageOutputFolder := path.Join(env.GetResourcePath(), dataset.ID, "media")
- for d3mIndex, imageBytes := range images {
- imageFilename := path.Join(imageOutputFolder, fmt.Sprintf("%s-segmentation.png", mapping[d3mIndex]))
- err = util.WriteFileWithDirs(imageFilename, imageBytes, os.ModePerm)
- if err != nil {
- s.finished <- err
- return
- }
- }
-
// HACK: INPUT FAKE RESULTS TO THE DB!!!
// FAKE RESULTS SHOULD JUST BE A CONSTANT!
uuidGen, err = uuid.NewV4()
@@ -745,13 +711,7 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
produceRequestID := uuidGen.String()
// HACK: CREATE FAKE RESULTS TO PERSIST AS THE ACTUAL RESULTS SHOULD NOT BE STORED IN THE DB!!!
- dataReader := serialization.GetStorage(pipelineResult.ResultURI)
- dataResult, err := dataReader.ReadData(pipelineResult.ResultURI)
- if err != nil {
- s.finished <- err
- return
- }
- resultOutputURI, err := createSegmentationResult(pipelineResult.ResultURI, s.TargetFeature.HeaderName, dataResult)
+ resultOutputURI, err := createSegmentationResult(s.Dataset, pipelineResult.ResultURI, s.TargetFeature.HeaderName, metaStorage, dataStorage)
if err != nil {
s.finished <- err
return
@@ -1134,7 +1094,7 @@ type confidenceValue struct {
row int
}
-func reformatResult(resultURI string, targetName string, task *Task) (string, error) {
+func reformatResult(resultURI string, targetName string) (string, error) {
// read data from original file
dataReader := serialization.GetStorage(resultURI)
data, err := dataReader.ReadData(resultURI)
@@ -1142,11 +1102,6 @@ func reformatResult(resultURI string, targetName string, task *Task) (string, er
return "", err
}
- // segmentation results need to be reduced to tagging segmented images
- if HasTaskType(task, compute.SegmentationTask) && isSegmentationOutput(resultURI) {
- return createSegmentationResult(resultURI, targetName, data)
- }
-
// only need to reformat if confidences are there (column count >= 3)
if len(data[0]) < 3 {
return resultURI, nil
@@ -1208,14 +1163,58 @@ func isSegmentationOutput(resultURI string) bool {
return len(result[0]) == 2 && result[0][0].(string) == model.D3MIndexFieldName && result[0][1].(string) == "positive_mask"
}
-func createSegmentationResult(resultURI string, targetName string, result [][]string) (string, error) {
+func createSegmentationResult(datasetID string, resultURI string,
+ targetName string, metaStorage api.MetadataStorage, dataStorage api.DataStorage) (string, error) {
+ log.Infof("processing segmentation pipeline output")
+ result, err := result.ParseResultCSV(resultURI)
+ if err != nil {
+ return "", err
+ }
+
+ images, err := BuildSegmentationImage(result)
+ if err != nil {
+ return "", err
+ }
+
+ // get the grouping key since it makes up part of the filename
+ dataset, err := metaStorage.FetchDataset(datasetID, true, true, false)
+ if err != nil {
+ return "", err
+ }
+
+ var groupingKey *model.Variable
+ for _, v := range dataset.Variables {
+ if v.HasRole(model.VarDistilRoleGrouping) {
+ groupingKey = v
+ break
+ }
+ }
+ if groupingKey == nil {
+ return "", errors.Errorf("no grouping found to use for output filename")
+ }
+
+ // get the d3m index -> grouping key mapping
+ mapping, err := api.BuildFieldMapping(dataset.ID, dataset.StorageName, model.D3MIndexFieldName, groupingKey.Key, dataStorage)
+ if err != nil {
+ return "", err
+ }
+
+ imageOutputFolder := path.Join(env.GetResourcePath(), dataset.ID, "media")
+ for d3mIndex, imageBytes := range images {
+ imageFilename := path.Join(imageOutputFolder, fmt.Sprintf("%s-segmentation.png", mapping[d3mIndex]))
+ err = util.WriteFileWithDirs(imageFilename, imageBytes, os.ModePerm)
+ if err != nil {
+ return "", err
+ }
+ }
+
resultOutput := []string{fmt.Sprintf("%s,%s,%s", model.D3MIndexFieldName, targetName, "confidence")}
for i := 1; i < len(result); i++ {
resultOutput = append(resultOutput, fmt.Sprintf("%s,%s,%d", result[i][0], "segmented", 1))
}
- resultOutputURI := fmt.Sprintf("%s-distil-%s", resultURI[:len(resultURI)-4], resultURI[len(resultURI)-4:])
+ resultOutputURI := fmt.Sprintf("%s-distil%s", resultURI[:len(resultURI)-4], resultURI[len(resultURI)-4:])
log.Infof("writing distil formatted segmentation results to '%s'", resultOutputURI)
- err := util.WriteFileWithDirs(resultOutputURI, []byte(strings.Join(resultOutput, "\n")), os.ModePerm)
+ err = util.WriteFileWithDirs(resultOutputURI, []byte(strings.Join(resultOutput, "\n")), os.ModePerm)
if err != nil {
return "", err
}
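With the separator fix above, a result at the hypothetical path '/out/predictions.csv' now maps to '/out/predictions-distil.csv'; the previous "%s-distil-%s" format produced '/out/predictions-distil-.csv', with a stray dash before the extension.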
diff --git a/api/task/prediction.go b/api/task/prediction.go
index 4c2d7f5f8..e6ae3265d 100644
--- a/api/task/prediction.go
+++ b/api/task/prediction.go
@@ -709,8 +709,8 @@ func Predict(params *PredictParams) (string, error) {
// submit the new dataset for predictions
log.Infof("generating predictions using data found at '%s'", params.SchemaPath)
- predictionResult, err := comp.GeneratePredictions(params.SchemaPath,
- solution.SolutionID, params.FittedSolutionID, params.Task, params.Target.HeaderName, client)
+ predictionResult, err := comp.GeneratePredictions(params.Dataset, params.SchemaPath,
+ solution.SolutionID, params.FittedSolutionID, params.Task, params.Target.HeaderName, params.MetaStorage, params.DataStorage, client)
if err != nil {
return "", err
}
From 173c8089040ceb90d8c0fbcb2804064a481aaf30 Mon Sep 17 00:00:00 2001
From: phorne
Date: Thu, 21 Jul 2022 15:09:43 -0400
Subject: [PATCH 09/16] Users can now explicitly specify the task to run.
---
api/compute/solution_request.go | 14 ++++++++++----
api/ws/pipeline.go | 9 +++++----
public/components/CreateSolutionsForm.vue | 8 +++++++-
public/components/SettingsModal.vue | 21 +++++++++++++++++++++
public/store/requests/actions.ts | 2 ++
public/store/route/getters.ts | 8 ++++++++
public/store/route/module.ts | 1 +
public/util/routes.ts | 1 +
8 files changed, 55 insertions(+), 9 deletions(-)
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 39a157e65..3e7c0f754 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -894,11 +894,17 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
if err != nil {
return err
}
- task, err := ResolveTask(dataStorage, dataset.StorageName, s.TargetFeature, trainingVariables)
- if err != nil {
- return err
+
+ var task *Task
+ if len(s.Task) > 0 {
+ task = &Task{s.Task}
+ } else {
+ task, err = ResolveTask(dataStorage, dataset.StorageName, s.TargetFeature, trainingVariables)
+ if err != nil {
+ return err
+ }
+ s.Task = task.Task
}
- s.Task = task.Task
if HasTaskType(task, compute.SegmentationTask) {
return processSegmentation(s, client, solutionStorage, metaStorage, dataStorage)
diff --git a/api/ws/pipeline.go b/api/ws/pipeline.go
index 59723a689..ff2793cb8 100644
--- a/api/ws/pipeline.go
+++ b/api/ws/pipeline.go
@@ -158,12 +158,13 @@ func handleCreateSolutions(conn *Connection, client *compute.Client, metadataCto
// load defaults
config, _ := env.LoadConfig()
- if len(request.Task) == 0 {
- request.Task = api.DefaultTaskType(request.TargetFeature.Type, request.ProblemType)
- log.Infof("Defaulting task type to `%s`", request.Task)
+ metricTasks := request.Task
+ if len(metricTasks) == 0 {
+ metricTasks = api.DefaultTaskType(request.TargetFeature.Type, request.ProblemType)
+ log.Infof("Defaulting metric task type to `%s`", metricTasks)
}
if len(request.Metrics) == 0 {
- request.Metrics = api.DefaultMetrics(request.Task)
+ request.Metrics = api.DefaultMetrics(metricTasks)
log.Infof("Defaulting metrics to `%s`", strings.Join(request.Metrics, ","))
}
if request.MaxTime == 0 {
diff --git a/public/components/CreateSolutionsForm.vue b/public/components/CreateSolutionsForm.vue
index d1d95c740..02e722a0b 100644
--- a/public/components/CreateSolutionsForm.vue
+++ b/public/components/CreateSolutionsForm.vue
@@ -173,6 +173,8 @@ export default Vue.extend({
// flag as pending
this.pending = true;
// dispatch action that triggers request send to server
+ const selectedTask = routeGetters.getRouteSelectedTask(this.$store);
+ const taskToRun = selectedTask ? selectedTask.split(",") : null;
const routeSplit = routeGetters.getRouteTrainTestSplit(this.$store);
const defaultSplit = appGetters.getTrainTestSplit(this.$store);
const timestampSplit = routeGetters.getRouteTimestampSplit(this.$store);
@@ -193,6 +195,7 @@ export default Vue.extend({
quality: routeGetters.getModelQuality(this.$store),
trainTestSplit: !!routeSplit ? routeSplit : defaultSplit,
timestampSplitValue: timestampSplit,
+ task: taskToRun,
} as SolutionRequestMsg;
// Add optional values to the request
@@ -208,13 +211,16 @@ export default Vue.extend({
this.pending = false;
const dataMode = routeGetters.getDataMode(this.$store);
const dataModeDefault = dataMode ? dataMode : DataMode.Default;
+ const taskUsed = selectedTask
+ ? selectedTask
+ : routeGetters.getRouteTask(this.$store);
// transition to result screen
const entry = createRouteEntry(RESULTS_ROUTE, {
dataset: routeGetters.getRouteDataset(this.$store),
target: routeGetters.getRouteTargetVariable(this.$store),
solutionId: res.solutionId,
- task: routeGetters.getRouteTask(this.$store),
+ task: taskUsed,
dataMode: dataModeDefault,
varModes: varModesToString(
routeGetters.getDecodedVarModes(this.$store)
diff --git a/public/components/SettingsModal.vue b/public/components/SettingsModal.vue
index 387954294..2f1b225cc 100644
--- a/public/components/SettingsModal.vue
+++ b/public/components/SettingsModal.vue
@@ -183,6 +183,7 @@ export default Vue.extend({
// fill this from the API later, first posting back the target's type
// then getting a list of allowed scoring methods with keys, description
selectedMetric: null,
+ selectedTask: null,
trainingCount: 1,
timestampSplitValue: new Date(),
splitByTime: false,
@@ -226,6 +227,25 @@ export default Vue.extend({
return routeGetters.getRouteTask(this.$store) ?? "";
},
+ rebuildTask(): string {
+ // hack: for remote sensing, submit either classification or segmentation, never both
+ if (this.multipleTasks) {
+ // expand a selected task to the full remote sensing task list; null when nothing is selected
+ if (this.selectedTask) {
+ return (
+ TaskTypes.REMOTE_SENSING +
+ "," +
+ TaskTypes.BINARY +
+ "," +
+ this.selectedTask
+ );
+ }
+ return null;
+ }
+
+ return this.task;
+ },
+
totalDataCount(): number {
return datasetGetters.getIncludedTableDataNumRows(this.$store);
},
@@ -336,6 +356,7 @@ export default Vue.extend({
modelTimeLimit: this.timeLimit,
modelQuality: this.speedQuality,
metrics: this.selectedMetric,
+ selectedTask: this.rebuildTask,
trainTestSplit: this.trainingRatio,
timestampSplit:
this.hasTimeRange && this.splitByTime
diff --git a/public/store/requests/actions.ts b/public/store/requests/actions.ts
index f0355f89e..ce78abb5f 100644
--- a/public/store/requests/actions.ts
+++ b/public/store/requests/actions.ts
@@ -74,6 +74,7 @@ export interface SolutionRequestMsg {
target: string;
timestampSplitValue?: number;
trainTestSplit: number;
+ task?: string[];
}
// Solution status message used in web socket context
@@ -596,6 +597,7 @@ export const actions = {
filters: request.filters,
trainTestSplit: request.trainTestSplit,
timestampSplitValue: request.timestampSplitValue,
+ task: request.task,
});
});
},
diff --git a/public/store/route/getters.ts b/public/store/route/getters.ts
index c49422d79..42c3f3341 100644
--- a/public/store/route/getters.ts
+++ b/public/store/route/getters.ts
@@ -604,6 +604,14 @@ export const getters = {
return task;
},
+ getRouteSelectedTask(state: Route, getters: any): string {
+ const selectedTask = state.query.selectedTask as string;
+ if (!selectedTask) {
+ return null;
+ }
+ return selectedTask;
+ },
+
getDataMode(state: Route, getters: any): DataMode {
const mode = state.query.dataMode as string;
if (!mode) {
diff --git a/public/store/route/module.ts b/public/store/route/module.ts
index 3ea2e2c8d..09e6c7f32 100644
--- a/public/store/route/module.ts
+++ b/public/store/route/module.ts
@@ -134,6 +134,7 @@ export const getters = {
getGeoZoom: read(moduleGetters.getGeoZoom),
getGroupingType: read(moduleGetters.getGroupingType),
getRouteTask: read(moduleGetters.getRouteTask),
+ getRouteSelectedTask: read(moduleGetters.getRouteSelectedTask),
getColorScale: read(moduleGetters.getColorScale),
getColorScaleVariable: read(moduleGetters.getColorScaleVariable),
getImageLayerScale: read(moduleGetters.getImageLayerScale),
diff --git a/public/util/routes.ts b/public/util/routes.ts
index c3f78340b..23d31a8c9 100644
--- a/public/util/routes.ts
+++ b/public/util/routes.ts
@@ -59,6 +59,7 @@ export interface RouteArgs {
resultTrainingVarsSearch?: string;
trainingVarsSearch?: string;
task?: string;
+ selectedTask?: string;
dataMode?: string;
varModes?: string;
varRanked?: string;
From c2961e9dbd6209c34799965664b336386c715342 Mon Sep 17 00:00:00 2001
From: phorne
Date: Fri, 22 Jul 2022 11:51:06 -0400
Subject: [PATCH 10/16] Added postgres schema version check on startup.
---
api/env/config.go | 1 +
api/model/storage/postgres/storage.go | 45 +++++++++++++-
api/postgres/postgres.go | 90 +++++++++++++++++++++++++++
main.go | 6 +-
4 files changed, 138 insertions(+), 4 deletions(-)
diff --git a/api/env/config.go b/api/env/config.go
index 19e507acd..33c2539ad 100644
--- a/api/env/config.go
+++ b/api/env/config.go
@@ -76,6 +76,7 @@ type Config struct {
PostgresPassword string `env:"PG_PASSWORD" envDefault:""`
PostgresPort int `env:"PG_PORT" envDefault:"5432"`
PostgresRandomSeed float64 `env:"PG_RANDOM_SEED" envDefault:"0.2"`
+ PostgresUpdate bool `env:"PG_UPDATE" envDefault:"false"`
PostgresUser string `env:"PG_USER" envDefault:"distil"`
PublicSubFolder string `env:"PUBLIC_SUBFOLDER" envDefault:"public"`
RankingOutputPath string `env:"RANKING_OUTPUT_PATH" envDefault:"importance.json"`
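The new PG_UPDATE flag is opt-in and defaults to false; the schema check added below only runs when the environment enables it, e.g. (hypothetical invocation) PG_UPDATE=true ./distil.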
diff --git a/api/model/storage/postgres/storage.go b/api/model/storage/postgres/storage.go
index 26bf9e2a0..4ff5dcea9 100644
--- a/api/model/storage/postgres/storage.go
+++ b/api/model/storage/postgres/storage.go
@@ -22,6 +22,7 @@ import (
"github.com/uncharted-distil/distil-compute/model"
log "github.com/unchartedsoftware/plog"
+ "github.com/uncharted-distil/distil/api/env"
api "github.com/uncharted-distil/distil/api/model"
"github.com/uncharted-distil/distil/api/postgres"
)
@@ -46,10 +47,48 @@ func NewDataStorage(clientCtor postgres.ClientCtor, batchClientCtor postgres.Cli
}
// NewSolutionStorage returns a constructor for a solution storage.
-func NewSolutionStorage(clientCtor postgres.ClientCtor, metadataCtor api.MetadataStorageCtor) api.SolutionStorageCtor {
- return func() (api.SolutionStorage, error) {
- return newStorage(clientCtor, nil, metadataCtor)
+func NewSolutionStorage(clientCtor postgres.ClientCtor, metadataCtor api.MetadataStorageCtor, updateStorage bool) (api.SolutionStorageCtor, error) {
+ if updateStorage {
+ config, err := env.LoadConfig()
+ if err != nil {
+ return nil, err
+ }
+
+ // Connect to the database.
+ postgresConfig := &postgres.Config{
+ Password: config.PostgresPassword,
+ User: config.PostgresUser,
+ Database: config.PostgresDatabase,
+ Host: config.PostgresHost,
+ Port: config.PostgresPort,
+ PostgresLogLevel: "error",
+ }
+ pg, err := postgres.NewDatabase(postgresConfig, false)
+ if err != nil {
+ return nil, errors.Wrapf(err, "unable to initialize a new database")
+ }
+
+ latestSchema, err := pg.IsLatestSchema()
+ if err != nil {
+ return nil, err
+ }
+
+ if !latestSchema {
+ err = pg.InitializeConfig()
+ if err != nil {
+ return nil, err
+ }
+ }
}
+ return func() (api.SolutionStorage, error) {
+ storage, err := newStorage(clientCtor, nil, metadataCtor)
+
+ if err != nil {
+ return nil, err
+ }
+
+ return storage, nil
+ }, nil
}
func newStorage(clientCtor postgres.ClientCtor, batchClientCtor postgres.ClientCtor, metadataCtor api.MetadataStorageCtor) (*Storage, error) {
diff --git a/api/postgres/postgres.go b/api/postgres/postgres.go
index 6aaaaa385..45251eaf2 100644
--- a/api/postgres/postgres.go
+++ b/api/postgres/postgres.go
@@ -80,9 +80,18 @@ const (
// WordStemTableName is the name of the table for the word stems.
WordStemTableName = "word_stem"
+ configTableName = "config"
+ version = "0.1"
+
+ configTableCreationSQL = `CREATE TABLE %s (
+ key text,
+ value text
+ );`
+
requestTableCreationSQL = `CREATE TABLE %s (
request_id text,
dataset varchar(200),
+ task text,
progress varchar(40),
created_time timestamp,
last_updated_time timestamp
@@ -165,6 +174,8 @@ const (
resultTableSuffix = "_result"
variableTableSuffix = "_variable"
explainTableSuffix = "_explain"
+
+ distilSchemaKey = "distil-schema-version"
)
var (
@@ -229,6 +240,79 @@ func NewDatabase(config *Config, batch bool) (*Database, error) {
return database, nil
}
+// IsLatestSchema returns true if the solution metadata schema matches the latest.
+func (d *Database) IsLatestSchema() (bool, error) {
+ // check for the presence of the config table
+ configExists, err := d.tableExists(configTableName)
+ if err != nil {
+ return false, err
+ }
+
+ // if the config table isn't there, then the schema isn't the latest
+ if !configExists {
+ return false, nil
+ }
+
+ // check the version stored in the config table against the latest version
+ config, err := d.loadConfig()
+ if err != nil {
+ return false, err
+ }
+
+ return config[distilSchemaKey] == version, nil
+}
+
+func (d *Database) loadConfig() (map[string]string, error) {
+ log.Infof("reading postgres config")
+ sql := fmt.Sprintf("SELECT key, value FROM %s;", configTableName)
+
+ rows, err := d.Client.Query(sql)
+ if err != nil {
+ return nil, errors.Wrapf(err, "unable to query postgres config")
+ }
+ defer rows.Close()
+
+ configData := map[string]string{}
+ for rows.Next() {
+ var key string
+ var value string
+
+ err = rows.Scan(&key, &value)
+ if err != nil {
+ return nil, errors.Wrapf(err, "unable to parse postgres config")
+ }
+ configData[key] = value
+ }
+
+ log.Infof("postgres config: %v", configData)
+
+ return configData, nil
+}
+
+func (d *Database) tableExists(name string) (bool, error) {
+ sql := "SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1);"
+
+ rows, err := d.Client.Query(sql, name)
+ if err != nil {
+ return false, errors.Wrapf(err, "unable to verify if a table exists")
+ }
+ defer rows.Close()
+
+ rows.Next()
+ var exists bool
+ err = rows.Scan(&exists)
+ if err != nil {
+ return false, errors.Wrap(err, "unable to parse table existance result")
+ }
+
+ return exists, nil
+}
+
+// InitializeConfig sets up the config table with the current config values.
+func (d *Database) InitializeConfig() error {
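+ // NOTE: currently a no-op stub; the version population logic presumably lands in a follow-up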
+ return nil
+}
+
// CreateSolutionMetadataTables creates an empty table for the solution results.
func (d *Database) CreateSolutionMetadataTables() error {
// Create the solution tables.
@@ -240,6 +324,12 @@ func (d *Database) CreateSolutionMetadataTables() error {
return errors.Wrap(err, "failed to drop table")
}
+ _ = d.DropTable(configTableName)
+ _, err = d.Client.Exec(fmt.Sprintf(configTableCreationSQL, configTableName))
+ if err != nil {
+ return errors.Wrap(err, "failed to drop table")
+ }
+
_ = d.DropTable(RequestTableName)
_, err = d.Client.Exec(fmt.Sprintf(requestTableCreationSQL, RequestTableName))
if err != nil {
diff --git a/main.go b/main.go
index 2c30c7044..127e45790 100644
--- a/main.go
+++ b/main.go
@@ -166,7 +166,11 @@ func main() {
pgDataStorageCtor := pg.NewDataStorage(postgresClientCtor, postgresBatchClientCtor, esMetadataStorageCtor)
// instantiate the postgres solution storage constructor.
- pgSolutionStorageCtor := pg.NewSolutionStorage(postgresClientCtor, esMetadataStorageCtor)
+ pgSolutionStorageCtor, err := pg.NewSolutionStorage(postgresClientCtor, esMetadataStorageCtor, config.PostgresUpdate)
+ if err != nil {
+ log.Errorf("%+v", err)
+ os.Exit(1)
+ }
// Instantiate the solution compute client
solutionClient, err := task.NewDefaultClient(config, userAgent, discoveryLogger)
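For reference, once populated the config table is expected to hold the schema version under the distil-schema-version key:

    key                   | value
    ----------------------+-------
    distil-schema-version | 0.1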
From d7f15b3aa883f88389853ecc5f7a61ea002d1317 Mon Sep 17 00:00:00 2001
From: phorne
Date: Fri, 22 Jul 2022 14:24:30 -0400
Subject: [PATCH 11/16] Added storage of request task to solution storage.
---
api/compute/search.go | 1 +
api/compute/solution_request.go | 27 ++++++++++++++-----------
api/model/model.go | 1 +
api/model/storage.go | 2 +-
api/model/storage/postgres/request.go | 20 +++++++++---------
api/model/storage/postgres/storage.go | 2 +-
api/postgres/postgres.go | 29 ++++++++++++++++++++++-----
7 files changed, 54 insertions(+), 28 deletions(-)
diff --git a/api/compute/search.go b/api/compute/search.go
index a2ae05d5e..e84aff02c 100644
--- a/api/compute/search.go
+++ b/api/compute/search.go
@@ -39,6 +39,7 @@ type searchResult struct {
type pipelineSearchContext struct {
searchID string
dataset string
+ task []string
storageName string
sourceDatasetURI string
trainDatasetURI string
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 3e7c0f754..240f9737e 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -495,10 +495,11 @@ func (s *SolutionRequest) persistSolutionStatus(statusChan chan SolutionStatus,
}
}
-func (s *SolutionRequest) persistRequestError(statusChan chan SolutionStatus, solutionStorage api.SolutionStorage, searchID string, dataset string, err error) {
+func (s *SolutionRequest) persistRequestError(statusChan chan SolutionStatus,
+ solutionStorage api.SolutionStorage, searchID string, dataset string, task []string, err error) {
// persist the updated state
// NOTE: ignoring error
- _ = solutionStorage.PersistRequest(searchID, dataset, compute.RequestErroredStatus, time.Now())
+ _ = solutionStorage.PersistRequest(searchID, dataset, task, compute.RequestErroredStatus, time.Now())
// notify of error
statusChan <- SolutionStatus{
@@ -509,12 +510,13 @@ func (s *SolutionRequest) persistRequestError(statusChan chan SolutionStatus, so
}
}
-func (s *SolutionRequest) persistRequestStatus(statusChan chan SolutionStatus, solutionStorage api.SolutionStorage, searchID string, dataset string, status string) error {
+func (s *SolutionRequest) persistRequestStatus(statusChan chan SolutionStatus,
+ solutionStorage api.SolutionStorage, searchID string, dataset string, task []string, status string) error {
// persist the updated state
- err := solutionStorage.PersistRequest(searchID, dataset, status, time.Now())
+ err := solutionStorage.PersistRequest(searchID, dataset, task, status, time.Now())
if err != nil {
// notify of error
- s.persistRequestError(statusChan, solutionStorage, searchID, dataset, err)
+ s.persistRequestError(statusChan, solutionStorage, searchID, dataset, task, err)
return err
}
@@ -576,7 +578,7 @@ func describeSolution(client *compute.Client, initialSearchSolutionID string) (*
func (s *SolutionRequest) dispatchRequest(client *compute.Client, solutionStorage api.SolutionStorage,
dataStorage api.DataStorage, searchContext pipelineSearchContext) {
// update request status
- err := s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, compute.RequestRunningStatus)
+ err := s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, searchContext.task, compute.RequestRunningStatus)
if err != nil {
s.finished <- err
return
@@ -626,9 +628,9 @@ func (s *SolutionRequest) dispatchRequest(client *compute.Client, solutionStorag
// update request status
if err != nil {
- s.persistRequestError(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, err)
+ s.persistRequestError(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, searchContext.task, err)
} else {
- if err = s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, compute.RequestCompletedStatus); err != nil {
+ if err = s.persistRequestStatus(s.requestChannel, solutionStorage, searchContext.searchID, searchContext.dataset, searchContext.task, compute.RequestCompletedStatus); err != nil {
log.Errorf("failed to persist status %s for search %s", compute.RequestCompletedStatus, searchContext.searchID)
}
}
@@ -647,7 +649,7 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
log.Infof("dispatching segmentation pipeline")
// create the backing data
- err := s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, s.Dataset, compute.RequestRunningStatus)
+ err := s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, s.Dataset, s.Task, compute.RequestRunningStatus)
if err != nil {
s.finished <- err
return
@@ -727,7 +729,7 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
log.Infof("segmentation pipeline processing complete")
- err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, compute.RequestCompletedStatus)
+ err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, s.Task, compute.RequestCompletedStatus)
if err != nil {
s.finished <- err
return
@@ -766,7 +768,7 @@ func processSegmentation(s *SolutionRequest, client *compute.Client, solutionSto
requestID := uuidGen.String()
// persist the request
- err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, compute.RequestPendingStatus)
+ err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, s.Task, compute.RequestPendingStatus)
if err != nil {
return err
}
@@ -993,7 +995,7 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
}
// persist the request
- err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, compute.RequestPendingStatus)
+ err = s.persistRequestStatus(s.requestChannel, solutionStorage, requestID, dataset.ID, s.Task, compute.RequestPendingStatus)
if err != nil {
return err
}
@@ -1033,6 +1035,7 @@ func (s *SolutionRequest) PersistAndDispatch(client *compute.Client, solutionSto
searchContext := pipelineSearchContext{
searchID: requestID,
dataset: dataset.ID,
+ task: s.Task,
storageName: dataset.StorageName,
sourceDatasetURI: filteredDatasetPath,
trainDatasetURI: datasetPathTrain,
diff --git a/api/model/model.go b/api/model/model.go
index c1f10ea2c..526b91c4a 100644
--- a/api/model/model.go
+++ b/api/model/model.go
@@ -52,6 +52,7 @@ type ExportedModel struct {
type Request struct {
RequestID string `json:"requestId"`
Dataset string `json:"dataset"`
+ Task []string `json:"task"`
Progress string `json:"progress"`
CreatedTime time.Time `json:"timestamp"`
LastUpdatedTime time.Time `json:"lastUpdatedTime"`
diff --git a/api/model/storage.go b/api/model/storage.go
index 42745c544..d735460f0 100644
--- a/api/model/storage.go
+++ b/api/model/storage.go
@@ -138,7 +138,7 @@ type SolutionStorageCtor func() (SolutionStorage, error)
// solution storage.
type SolutionStorage interface {
PersistPrediction(requestID string, dataset string, target string, fittedSolutionID string, progress string, createdTime time.Time) error
- PersistRequest(requestID string, dataset string, progress string, createdTime time.Time) error
+ PersistRequest(requestID string, dataset string, task []string, progress string, createdTime time.Time) error
PersistRequestFeature(requestID string, featureName string, featureType string) error
PersistRequestFilters(requestID string, filters *FilterParams) error
PersistSolution(requestID string, solutionID string, explainedSolutionID string, createdTime time.Time) error
diff --git a/api/model/storage/postgres/request.go b/api/model/storage/postgres/request.go
index df4aa0f68..de07b82c6 100644
--- a/api/model/storage/postgres/request.go
+++ b/api/model/storage/postgres/request.go
@@ -28,10 +28,10 @@ import (
)
// PersistRequest persists a request to Postgres.
-func (s *Storage) PersistRequest(requestID string, dataset string, progress string, createdTime time.Time) error {
- sql := fmt.Sprintf("INSERT INTO %s (request_id, dataset, progress, created_time, last_updated_time) VALUES ($1, $2, $3, $4, $4);", postgres.RequestTableName)
+func (s *Storage) PersistRequest(requestID string, dataset string, task []string, progress string, createdTime time.Time) error {
+ sql := fmt.Sprintf("INSERT INTO %s (request_id, dataset, task, progress, created_time, last_updated_time) VALUES ($1, $2, $3, $4, $4, $5);", postgres.RequestTableName)
- _, err := s.client.Exec(sql, requestID, dataset, progress, createdTime)
+ _, err := s.client.Exec(sql, requestID, dataset, strings.Join(task, ","), progress, createdTime)
return errors.Wrapf(err, "failed to persist request to Postgres")
}
@@ -93,7 +93,7 @@ func (s *Storage) PersistRequestFilters(requestID string, filters *api.FilterPar
// FetchRequest pulls request information from Postgres.
func (s *Storage) FetchRequest(requestID string) (*api.Request, error) {
- sql := fmt.Sprintf("SELECT request_id, dataset, progress, created_time, last_updated_time FROM %s WHERE request_id = $1 ORDER BY created_time desc LIMIT 1;", postgres.RequestTableName)
+ sql := fmt.Sprintf("SELECT request_id, dataset, task, progress, created_time, last_updated_time FROM %s WHERE request_id = $1 ORDER BY created_time desc LIMIT 1;", postgres.RequestTableName)
rows, err := s.client.Query(sql, requestID)
if err != nil {
@@ -114,7 +114,7 @@ func (s *Storage) FetchRequest(requestID string) (*api.Request, error) {
// FetchRequestByResultUUID pulls request information from Postgres using
// a result UUID.
func (s *Storage) FetchRequestByResultUUID(resultUUID string) (*api.Request, error) {
- sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.progress, req.created_time, req.last_updated_time "+
+ sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.task, req.progress, req.created_time, req.last_updated_time "+
"FROM %s as req INNER JOIN %s as sol ON req.request_id = sol.request_id INNER JOIN %s as sol_res ON sol.solution_id = sol_res.solution_id "+
"WHERE sol_res.result_uuid = $1;", postgres.RequestTableName, postgres.SolutionTableName, postgres.SolutionResultTableName)
@@ -139,7 +139,7 @@ func (s *Storage) FetchRequestByResultUUID(resultUUID string) (*api.Request, err
// FetchRequestBySolutionID pulls request information from Postgres using
// a solution ID.
func (s *Storage) FetchRequestBySolutionID(solutionID string) (*api.Request, error) {
- sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.progress, req.created_time, req.last_updated_time "+
+ sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.task, req.progress, req.created_time, req.last_updated_time "+
"FROM %s as req INNER JOIN %s as sol ON req.request_id = sol.request_id "+
"WHERE sol.solution_id = $1;", postgres.RequestTableName, postgres.SolutionTableName)
@@ -164,7 +164,7 @@ func (s *Storage) FetchRequestBySolutionID(solutionID string) (*api.Request, err
// FetchRequestByFittedSolutionID pulls request information from Postgres using
// a fitted solution ID.
func (s *Storage) FetchRequestByFittedSolutionID(fittedSolutionID string) (*api.Request, error) {
- sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.progress, req.created_time, req.last_updated_time "+
+ sql := fmt.Sprintf("SELECT req.request_id, req.dataset, req.task, req.progress, req.created_time, req.last_updated_time "+
"FROM %s as req INNER JOIN %s as sol ON req.request_id = sol.request_id INNER JOIN %s sr on sr.solution_id = sol.solution_id "+
"WHERE sr.fitted_solution_id = $1;", postgres.RequestTableName, postgres.SolutionTableName, postgres.SolutionResultTableName)
@@ -189,11 +189,12 @@ func (s *Storage) FetchRequestByFittedSolutionID(fittedSolutionID string) (*api.
func (s *Storage) loadRequest(rows pgx.Rows) (*api.Request, error) {
var requestID string
var dataset string
+ var task string
var progress string
var createdTime time.Time
var lastUpdatedTime time.Time
- err := rows.Scan(&requestID, &dataset, &progress, &createdTime, &lastUpdatedTime)
+ err := rows.Scan(&requestID, &dataset, &task, &progress, &createdTime, &lastUpdatedTime)
if err != nil {
return nil, errors.Wrap(err, "Unable to parse request from Postgres")
}
@@ -211,6 +212,7 @@ func (s *Storage) loadRequest(rows pgx.Rows) (*api.Request, error) {
return &api.Request{
RequestID: requestID,
Dataset: dataset,
+ Task: strings.Split(task, ","),
Progress: progress,
CreatedTime: createdTime,
LastUpdatedTime: lastUpdatedTime,
@@ -350,7 +352,7 @@ func (s *Storage) FetchRequestFilters(requestID string, features []*api.Feature)
// FetchRequestByDatasetTarget pulls requests associated with a given dataset and target from postgres.
func (s *Storage) FetchRequestByDatasetTarget(dataset string, target string) ([]*api.Request, error) {
// get the solution ids
- sql := fmt.Sprintf("SELECT DISTINCT ON(request.request_id) request.request_id, request.dataset, request.progress, request.created_time, request.last_updated_time "+
+ sql := fmt.Sprintf("SELECT DISTINCT ON(request.request_id) request.request_id, request.dataset, request.task, request.progress, request.created_time, request.last_updated_time "+
"FROM %s request INNER JOIN %s rf ON request.request_id = rf.request_id "+
"INNER JOIN %s solution ON request.request_id = solution.request_id",
postgres.RequestTableName, postgres.RequestFeatureTableName, postgres.SolutionTableName)
diff --git a/api/model/storage/postgres/storage.go b/api/model/storage/postgres/storage.go
index 4ff5dcea9..0c754fe3d 100644
--- a/api/model/storage/postgres/storage.go
+++ b/api/model/storage/postgres/storage.go
@@ -74,7 +74,7 @@ func NewSolutionStorage(clientCtor postgres.ClientCtor, metadataCtor api.Metadat
}
if !latestSchema {
- err = pg.InitializeConfig()
+ err = pg.UpdateSchema()
if err != nil {
return nil, err
}
diff --git a/api/postgres/postgres.go b/api/postgres/postgres.go
index 45251eaf2..1ca75f397 100644
--- a/api/postgres/postgres.go
+++ b/api/postgres/postgres.go
@@ -243,6 +243,7 @@ func NewDatabase(config *Config, batch bool) (*Database, error) {
// IsLatestSchema returns true if the solution metadata schema matches the latest.
func (d *Database) IsLatestSchema() (bool, error) {
// check for the presence of the config table
+ log.Infof("verifying that postgres is using the latest schema")
configExists, err := d.tableExists(configTableName)
if err != nil {
return false, err
@@ -250,6 +251,7 @@ func (d *Database) IsLatestSchema() (bool, error) {
// if the config table isn't there, then it isn't the latest
if !configExists {
+ log.Infof("postgres not using latest schema as the config table does not exist")
return false, nil
}
@@ -259,9 +261,31 @@ func (d *Database) IsLatestSchema() (bool, error) {
return false, err
}
+ log.Infof("postgres schema version %s and the latest version is %s", config[distilSchemaKey], version)
return config[distilSchemaKey] == version, nil
}
+// UpdateSchema updates the metadata schema and stores the version to the database.
+func (d *Database) UpdateSchema() error {
+ // recreate metadata tables
+ err := d.CreateSolutionMetadataTables()
+ if err != nil {
+ return err
+ }
+
+ // write the version to the config table
+ configToStore := map[string]string{distilSchemaKey: version}
+ for k, v := range configToStore {
+ sql := fmt.Sprintf("INSERT INTO %s (key, value) VALUES ($1, $2);", configTableName)
+ _, err = d.Client.Exec(sql, k, v)
+ if err != nil {
+ return errors.Wrapf(err, "unable to store postgres config")
+ }
+ }
+
+ return nil
+}
+
func (d *Database) loadConfig() (map[string]string, error) {
log.Infof("reading postgres config")
sql := fmt.Sprintf("SELECT key, value FROM %s;", configTableName)
@@ -308,11 +332,6 @@ func (d *Database) tableExists(name string) (bool, error) {
return exists, nil
}
-// InitializeConfig sets up the config table with the current config values.
-func (d *Database) InitializeConfig() error {
- return nil
-}
-
// CreateSolutionMetadataTables creates an empty table for the solution results.
func (d *Database) CreateSolutionMetadataTables() error {
// Create the solution tables.
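
The migration path above is a check-then-rebuild protocol: IsLatestSchema looks for the config table, compares the stored version against the compiled-in one, and UpdateSchema recreates the metadata tables and stamps the new version when they differ. A minimal standalone sketch of the same pattern follows; the table and key names are hypothetical stand-ins for Distil's internals, not the actual constants.

package schemasketch

import (
	"database/sql"
	"errors"
	"fmt"
)

const (
	schemaVersion = "1.1"           // hypothetical compiled-in version
	configTable   = "distil_config" // hypothetical config table name
	schemaKey     = "schema_version"
)

// isLatestSchema mirrors the check above: a missing config table means
// no recorded version, which is treated as stale.
func isLatestSchema(db *sql.DB) (bool, error) {
	var exists bool
	err := db.QueryRow(
		"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = $1);",
		configTable).Scan(&exists)
	if err != nil || !exists {
		return false, err
	}
	var stored string
	err = db.QueryRow(fmt.Sprintf(
		"SELECT value FROM %s WHERE key = $1;", configTable), schemaKey).Scan(&stored)
	if errors.Is(err, sql.ErrNoRows) {
		return false, nil
	}
	if err != nil {
		return false, err
	}
	return stored == schemaVersion, nil
}

// updateSchema rebuilds the metadata tables, then stamps the new
// version, matching the UpdateSchema flow in the patch.
func updateSchema(db *sql.DB, rebuild func(*sql.DB) error) error {
	if err := rebuild(db); err != nil {
		return err
	}
	_, err := db.Exec(fmt.Sprintf(
		"INSERT INTO %s (key, value) VALUES ($1, $2);", configTable),
		schemaKey, schemaVersion)
	return err
}

Because the version is written only after the tables are rebuilt, a failed migration leaves the old version in place and is retried on the next startup.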
From 50bf17f5a661a23956694ffb57a5bb3af0156a44 Mon Sep 17 00:00:00 2001
From: phorne
Date: Fri, 22 Jul 2022 15:05:26 -0400
Subject: [PATCH 12/16] Persisted fitted solutions now also save the task.
---
api/model/model.go | 1 +
api/model/storage/elastic/model.go | 5 +++++
api/model/storage/postgres/request.go | 2 +-
api/task/solution.go | 1 +
4 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/api/model/model.go b/api/model/model.go
index 526b91c4a..5b13fb6d5 100644
--- a/api/model/model.go
+++ b/api/model/model.go
@@ -42,6 +42,7 @@ type ExportedModel struct {
FittedSolutionID string `json:"fittedSolutionId"`
DatasetID string `json:"datasetId"`
DatasetName string `json:"datasetName"`
+ Task []string `json:"task"`
Target *SolutionVariable `json:"target"`
Variables []string `json:"variables"`
VariableDetails []*SolutionVariable `json:"variableDetails"`
diff --git a/api/model/storage/elastic/model.go b/api/model/storage/elastic/model.go
index c8d42f939..76667b35e 100644
--- a/api/model/storage/elastic/model.go
+++ b/api/model/storage/elastic/model.go
@@ -122,6 +122,10 @@ func (s *Storage) parseModels(res *elastic.SearchResult, includeDeleted bool) ([
if !ok {
return nil, errors.New("failed to parse the dataset id")
}
+
+ // get the task
+ tasks, _ := json.StringArray(src, "task")
+
// extract the target
targetInfo, ok := json.Get(src, "target")
if !ok {
@@ -153,6 +157,7 @@ func (s *Storage) parseModels(res *elastic.SearchResult, includeDeleted bool) ([
FittedSolutionID: fittedSolutionID,
DatasetID: datasetID,
DatasetName: name,
+ Task: tasks,
Target: target,
Variables: variables,
VariableDetails: variableDetails,
diff --git a/api/model/storage/postgres/request.go b/api/model/storage/postgres/request.go
index de07b82c6..fe7e11029 100644
--- a/api/model/storage/postgres/request.go
+++ b/api/model/storage/postgres/request.go
@@ -29,7 +29,7 @@ import (
// PersistRequest persists a request to Postgres.
func (s *Storage) PersistRequest(requestID string, dataset string, task []string, progress string, createdTime time.Time) error {
- sql := fmt.Sprintf("INSERT INTO %s (request_id, dataset, task, progress, created_time, last_updated_time) VALUES ($1, $2, $3, $4, $4, $5);", postgres.RequestTableName)
+ sql := fmt.Sprintf("INSERT INTO %s (request_id, dataset, task, progress, created_time, last_updated_time) VALUES ($1, $2, $3, $4, $5, $5);", postgres.RequestTableName)
_, err := s.client.Exec(sql, requestID, dataset, strings.Join(task, ","), progress, createdTime)
diff --git a/api/task/solution.go b/api/task/solution.go
index 2b39e891c..dcd4c3cca 100644
--- a/api/task/solution.go
+++ b/api/task/solution.go
@@ -88,6 +88,7 @@ func SaveFittedSolution(fittedSolutionID string, modelName string, modelDescript
FittedSolutionID: fittedSolutionID,
DatasetID: request.Dataset,
DatasetName: dataset.Name,
+ Task: request.Task,
Variables: vars,
VariableDetails: varDetails,
Target: target,
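
Postgres and Elastic flatten the task list differently: the request table stores it as one comma-joined text column, while the exported-model document keeps a JSON string array. A short sketch of the comma-join round-trip, including the edge case that strings.Split("", ",") yields a single empty element rather than an empty slice:

package tasksketch

import "strings"

// joinTask flattens a task list for storage in a single text column,
// as PersistRequest does with strings.Join.
func joinTask(task []string) string {
	return strings.Join(task, ",")
}

// splitTask reverses joinTask. strings.Split("", ",") returns [""],
// so an empty column is mapped back to an empty task list rather than
// a single empty task name.
func splitTask(col string) []string {
	if col == "" {
		return nil
	}
	return strings.Split(col, ",")
}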
From a0f80ad7c0980e8d193278c6dea59571bad21e1e Mon Sep 17 00:00:00 2001
From: phorne
Date: Fri, 22 Jul 2022 15:54:57 -0400
Subject: [PATCH 13/16] Prediction task uses the model task if it was saved.
---
api/task/prediction.go | 10 +++-------
api/ws/pipeline.go | 26 ++++++++++++++++++++------
2 files changed, 23 insertions(+), 13 deletions(-)
diff --git a/api/task/prediction.go b/api/task/prediction.go
index e6ae3265d..e7fa48420 100644
--- a/api/task/prediction.go
+++ b/api/task/prediction.go
@@ -211,7 +211,7 @@ type PredictParams struct {
MetaStorage api.MetadataStorage
DataStorage api.DataStorage
SolutionStorage api.SolutionStorage
- ModelStorage api.ExportedModelStorage
+ ExportedModel *api.ExportedModel
IngestConfig *IngestTaskConfig
Config *env.Config
}
@@ -696,12 +696,8 @@ func Predict(params *PredictParams) (string, error) {
// Ensure the ta2 has fitted solution loaded. If the model wasn't saved, it should be available
// as part of the session.
- exportedModel, err := params.ModelStorage.FetchModelByID(params.FittedSolutionID)
- if err != nil {
- return "", err
- }
- if exportedModel != nil {
- _, err = LoadFittedSolution(exportedModel.FilePath, params.SolutionStorage, params.MetaStorage)
+ if params.ExportedModel != nil {
+ _, err = LoadFittedSolution(params.ExportedModel.FilePath, params.SolutionStorage, params.MetaStorage)
if err != nil {
return "", err
}
diff --git a/api/ws/pipeline.go b/api/ws/pipeline.go
index ff2793cb8..8992d8958 100644
--- a/api/ws/pipeline.go
+++ b/api/ws/pipeline.go
@@ -413,13 +413,27 @@ func handlePredict(conn *Connection, client *compute.Client, metadataCtor apiMod
return
}
- // resolve the task so we know what type of data we should be expecting
- requestTask, err := api.ResolveTask(dataStorage, meta.StorageName, targetVar, variables)
+ // get the exported model
+ exportedModel, err := modelStorage.FetchModelByID(request.FittedSolutionID)
if err != nil {
handleErr(conn, msg, err)
return
}
+ // if the task wasn't saved, determine it via the target and features
+ var modelTask *api.Task
+ if exportedModel != nil && len(exportedModel.Task) > 0 {
+ modelTask = &api.Task{
+ Task: exportedModel.Task,
+ }
+ } else {
+ modelTask, err = api.ResolveTask(dataStorage, meta.StorageName, targetVar, variables)
+ if err != nil {
+ handleErr(conn, msg, err)
+ return
+ }
+ }
+
// config objects required for ingest
config, _ := env.LoadConfig()
ingestConfig := task.NewConfig(config)
@@ -431,25 +445,25 @@ func handlePredict(conn *Connection, client *compute.Client, metadataCtor apiMod
SolutionID: sr.SolutionID,
FittedSolutionID: request.FittedSolutionID,
OutputPath: path.Join(config.D3MOutputDir, config.AugmentedSubFolder),
- Task: requestTask,
+ Task: modelTask,
Target: targetVar,
MetaStorage: metaStorage,
DataStorage: dataStorage,
SolutionStorage: solutionStorage,
- ModelStorage: modelStorage,
+ ExportedModel: exportedModel,
Config: &config,
IngestConfig: ingestConfig,
SourceDatasetID: meta.ID,
}
- datasetName, datasetPath, err := getPredictionDataset(requestTask, request, predictParams)
+ datasetName, datasetPath, err := getPredictionDataset(modelTask, request, predictParams)
if err != nil {
handleErr(conn, msg, errors.Wrap(err, "unable to create raw dataset"))
return
}
// if the task is a segmentation task, run it against the base dataset
- if api.HasTaskType(requestTask, compute.SegmentationTask) {
+ if api.HasTaskType(modelTask, compute.SegmentationTask) {
dsPred, err := metaStorage.FetchDataset(datasetName, true, true, true)
if err != nil {
handleErr(conn, msg, errors.Wrap(err, "unable to resolve prediction dataset"))
From ac6d853bb431de27d88150f1ddbd8d34b94c4e34 Mon Sep 17 00:00:00 2001
From: phorne
Date: Mon, 25 Jul 2022 12:45:05 -0400
Subject: [PATCH 14/16] Segmentation pipelines are no longer cached as they
need to be treated like models and not fully specified pipelines.
---
api/compute/solution_request.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 240f9737e..5019068ca 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -658,7 +658,7 @@ func dispatchSegmentation(s *SolutionRequest, requestID string, solutionStorage
c := newStatusChannel()
// run the pipeline
- pipelineResult, err := SubmitPipeline(client, []string{datasetInputDir}, nil, nil, step, nil, true)
+ pipelineResult, err := SubmitPipeline(client, []string{datasetInputDir}, nil, nil, step, nil, false)
if err != nil {
s.finished <- err
return
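
The flipped final argument to SubmitPipeline disables result caching for segmentation. A hypothetical cache wrapper sketching why: a cache keyed only on the pipeline description and its inputs is sound for fully-specified pipelines, but a segmentation pipeline's output depends on a fitted model, so replaying a cached URI would return stale segment maps. The actual SubmitPipeline caching logic may differ.

package cachesketch

import (
	"crypto/sha256"
	"encoding/hex"
	"strings"
)

// cacheKey derives a lookup key from the pipeline description and its
// inputs; only deterministic, fully-specified pipelines can safely
// reuse a prior result under the same key.
func cacheKey(pipeline string, inputs []string) string {
	sum := sha256.Sum256([]byte(pipeline + "|" + strings.Join(inputs, "|")))
	return hex.EncodeToString(sum[:])
}

// submit runs the pipeline, consulting the cache only when allowed.
// Segmentation bypasses the cache (shouldCache=false) because its
// output depends on a fitted model, not just the pipeline text.
func submit(pipeline string, inputs []string, shouldCache bool,
	cache map[string]string, run func() string) string {
	key := cacheKey(pipeline, inputs)
	if shouldCache {
		if uri, ok := cache[key]; ok {
			return uri
		}
	}
	uri := run()
	if shouldCache {
		cache[key] = uri
	}
	return uri
}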
From 38f1743ea5b4c59b1cc1b226b7c9e032befcc375 Mon Sep 17 00:00:00 2001
From: Chris Bethune
Date: Fri, 19 Aug 2022 10:12:53 -0400
Subject: [PATCH 15/16] Adds parameters for GPU batch size and fixes segment
map opacity
---
api/compute/segment.go | 2 +-
api/compute/solution_request.go | 3 ++-
api/task/segment.go | 5 +++--
run.sh | 3 +++
4 files changed, 9 insertions(+), 4 deletions(-)
diff --git a/api/compute/segment.go b/api/compute/segment.go
index 0e3069feb..a68f3f314 100644
--- a/api/compute/segment.go
+++ b/api/compute/segment.go
@@ -46,7 +46,7 @@ func BuildSegmentationImage(rawSegmentation [][]interface{}) (map[string][]byte,
rawFloats[i] = nestedFloats
}
- filter := imagery.ConfidenceMatrixToImage(rawFloats, imagery.MagmaColorScale, uint8(100))
+ filter := imagery.ConfidenceMatrixToImage(rawFloats, imagery.MagmaColorScale, uint8(255))
imageBytes, err := imagery.ImageToPNG(filter)
if err != nil {
return nil, err
diff --git a/api/compute/solution_request.go b/api/compute/solution_request.go
index 5019068ca..f4fb7a3b7 100644
--- a/api/compute/solution_request.go
+++ b/api/compute/solution_request.go
@@ -755,7 +755,8 @@ func processSegmentation(s *SolutionRequest, client *compute.Client, solutionSto
datasetInputDir := env.ResolvePath(dataset.Source, dataset.Folder)
- step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", s.TargetFeature, envConfig.RemoteSensingNumJobs)
+ step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", s.TargetFeature,
+ envConfig.RemoteSensingNumJobs, envConfig.RemoteSensingGPUBatchSize)
if err != nil {
return err
}
diff --git a/api/task/segment.go b/api/task/segment.go
index 7ac28e4f6..6dbd8a582 100644
--- a/api/task/segment.go
+++ b/api/task/segment.go
@@ -49,7 +49,8 @@ func Segment(ds *api.Dataset, dataStorage api.DataStorage, variableName string)
}
}
- step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", variable, envConfig.RemoteSensingNumJobs)
+ step, err := description.CreateRemoteSensingSegmentationPipeline("segmentation", "basic image segmentation", variable,
+ envConfig.RemoteSensingNumJobs, envConfig.RemoteSensingGPUBatchSize)
if err != nil {
return "", err
}
@@ -103,7 +104,7 @@ func Segment(ds *api.Dataset, dataStorage api.DataStorage, variableName string)
rawFloats[i] = nestedFloats
}
- filter := imagery.ConfidenceMatrixToImage(rawFloats, imagery.MagmaColorScale, uint8(100))
+ filter := imagery.ConfidenceMatrixToImage(rawFloats, imagery.MagmaColorScale, uint8(255))
imageBytes, err := imagery.ImageToPNG(filter)
if err != nil {
return "", err
diff --git a/run.sh b/run.sh
index dfc5cfd03..62525d636 100755
--- a/run.sh
+++ b/run.sh
@@ -35,6 +35,9 @@ export TILE_REQUEST_URL=https://server.arcgisonline.com/ArcGIS/rest/services/Wor
export INGEST_SAMPLE_ROW_LIMIT=200000
# export MAX_TRAINING_ROWS=500
# export MAX_TEST_ROWS=500
+export PG_UPDATE=true
+export SEGMENTATION_ENABLED=true
+export REMOTE_SENSING_GPU_BATCH_SIZE=4
ulimit -n 4096
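
The new run.sh flags opt the deployment into the schema update, enable segmentation, and bound GPU memory use via the batch size. A hedged sketch of reading such a variable with a safe default; the real parsing lives in distil's env config package, which this only approximates:

package envsketch

import (
	"os"
	"strconv"
)

// gpuBatchSize reads REMOTE_SENSING_GPU_BATCH_SIZE, falling back to a
// default when unset or malformed; the batch size bounds how many
// images the segmentation primitive pushes to the GPU at once.
func gpuBatchSize(def int) int {
	raw := os.Getenv("REMOTE_SENSING_GPU_BATCH_SIZE")
	if raw == "" {
		return def
	}
	n, err := strconv.Atoi(raw)
	if err != nil || n < 1 {
		return def
	}
	return n
}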
From b4207b9ff318045d4ee15eebccc9b8b65c3fb22d Mon Sep 17 00:00:00 2001
From: Chris Bethune
Date: Fri, 19 Aug 2022 11:16:15 -0400
Subject: [PATCH 16/16] updates to latest distil-compute
---
go.mod | 2 +-
go.sum | 2 ++
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/go.mod b/go.mod
index 2327dd0b2..5c67f4875 100644
--- a/go.mod
+++ b/go.mod
@@ -34,7 +34,7 @@ require (
github.com/russross/blackfriday v2.0.0+incompatible
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/stretchr/testify v1.6.1
- github.com/uncharted-distil/distil-compute v0.0.0-20220715171604-26f9f01bab93
+ github.com/uncharted-distil/distil-compute v0.0.0-20220818194426-a130f919e111
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a
github.com/unchartedsoftware/plog v0.0.0-20200807135627-83d59e50ced5
diff --git a/go.sum b/go.sum
index 1f4b7d3fc..86f82af75 100644
--- a/go.sum
+++ b/go.sum
@@ -215,6 +215,8 @@ github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/uncharted-distil/distil-compute v0.0.0-20220715171604-26f9f01bab93 h1:UNSU3FX3h4k8wrzzXWLtX2kl4bb2AW7BqoV2FkQigRs=
github.com/uncharted-distil/distil-compute v0.0.0-20220715171604-26f9f01bab93/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
+github.com/uncharted-distil/distil-compute v0.0.0-20220818194426-a130f919e111 h1:HRYDNq9tSNcqZ02mzfnp8Ee+piBUSGt6vXUjjZxKxIM=
+github.com/uncharted-distil/distil-compute v0.0.0-20220818194426-a130f919e111/go.mod h1:iFA7B2kb+WJfkzukdwfZJVY3o/ZFEjHPsA8k2N6I+B8=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb h1:wDsXsrF8qM34nLeQ9xW+zbEdRNATk5sgOwuwCTrZmvY=
github.com/uncharted-distil/distil-image-upscale v0.0.0-20210923132226-8eaee866ebdb/go.mod h1:Xhb77n2q8yDvcVS3Mvw0XlpdNMiFsL+vOlvoe556ivc=
github.com/uncharted-distil/gdal v0.0.0-20200504224203-25f2e6a0dc2a h1:BPJrlnjdhxMBrJWiU4/Gl3PVdCUlY9JspWFTJ9UVO0Y=