From 2e3412739627ea14ee97ee736cbd89a526050319 Mon Sep 17 00:00:00 2001 From: riccardopinosio <rpinosio@gmail.com> Date: Tue, 10 Sep 2024 08:19:52 +0000 Subject: [PATCH] change: use value in batch results and cast to concrete type in postprocessing --- go.mod | 2 +- go.sum | 4 ++-- pipelines/featureExtraction.go | 3 ++- pipelines/pipeline.go | 28 ++++++++++------------------ pipelines/textClassification.go | 6 ++++-- pipelines/tokenClassification.go | 3 ++- pipelines/zeroShotClassification.go | 3 ++- 7 files changed, 23 insertions(+), 26 deletions(-) diff --git a/go.mod b/go.mod index 9bf6ec2..83cb58d 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/viant/afs v1.25.1 github.com/viant/afsc v1.9.3 github.com/yalue/onnxruntime_go v1.12.0 - golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e + golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 ) require ( diff --git a/go.sum b/go.sum index 3770eae..e6336c1 100644 --- a/go.sum +++ b/go.sum @@ -57,8 +57,8 @@ github.com/yalue/onnxruntime_go v1.12.0 h1:UtrSZOV9cY9j8ualjiakzRSn7H+bvu6QyCHAA github.com/yalue/onnxruntime_go v1.12.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= -golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e h1:I88y4caeGeuDQxgdoFPUq097j7kNfw6uvuiNxUBfcBk= -golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/pipelines/featureExtraction.go b/pipelines/featureExtraction.go index c485bf9..3dffadf 100644 --- a/pipelines/featureExtraction.go +++ b/pipelines/featureExtraction.go @@ -225,8 +225,9 @@ func (p *FeatureExtractionPipeline) Postprocess(batch *PipelineBatch) (*FeatureE tokenEmbeddings := make([][]float32, maxSequenceLength) tokenEmbeddingsCounter := 0 batchInputCounter := 0 + outputTensor := batch.OutputValues[0].(*ort.Tensor[float32]) - for _, result := range batch.OutputTensors[0].GetData() { + for _, result := range outputTensor.GetData() { outputEmbedding[outputEmbeddingCounter] = result if outputEmbeddingCounter == int(embeddingDimension)-1 { // we gathered one embedding diff --git a/pipelines/pipeline.go b/pipelines/pipeline.go index a3ff679..c05e008 100644 --- a/pipelines/pipeline.go +++ b/pipelines/pipeline.go @@ -83,19 +83,19 @@ type tokenizedInput struct { // PipelineBatch represents a batch of inputs that runs through the pipeline. type PipelineBatch struct { Input []tokenizedInput - InputTensors []*ort.Tensor[int64] MaxSequenceLength int - OutputTensors []*ort.Tensor[float32] + InputValues []ort.Value + OutputValues []ort.Value } func (b *PipelineBatch) Destroy() error { - destroyErrors := make([]error, 0, len(b.InputTensors)+len(b.OutputTensors)) + destroyErrors := make([]error, 0, len(b.InputValues)+len(b.OutputValues)) - for _, tensor := range b.InputTensors { + for _, tensor := range b.InputValues { destroyErrors = append(destroyErrors, tensor.Destroy()) } - for _, tensor := range b.OutputTensors { + for _, tensor := range b.OutputValues { destroyErrors = append(destroyErrors, tensor.Destroy()) } return errors.Join(destroyErrors...) @@ -231,7 +231,7 @@ func createInputTensors(batch *PipelineBatch, inputsMeta []ort.InputOutputInfo) tensorSize := len(batch.Input) * (batch.MaxSequenceLength) batchSize := int64(len(batch.Input)) - inputTensors := make([]*ort.Tensor[int64], len(inputsMeta)) + inputTensors := make([]ort.Value, len(inputsMeta)) var tensorCreationErr error for i, inputMeta := range inputsMeta { @@ -263,7 +263,7 @@ func createInputTensors(batch *PipelineBatch, inputsMeta []ort.InputOutputInfo) return tensorCreationErr } } - batch.InputTensors = inputTensors + batch.InputValues = inputTensors return nil } @@ -297,8 +297,7 @@ func runSessionOnBatch(batch *PipelineBatch, session *ort.DynamicAdvancedSession maxSequenceLength := int64(batch.MaxSequenceLength) // allocate vectors with right dimensions for the output - outputTensors := make([]*ort.Tensor[float32], len(outputs)) - arbitraryOutputTensors := make([]ort.ArbitraryTensor, len(outputs)) + outputTensors := make([]ort.Value, len(outputs)) var outputCreationErr error for outputIndex, meta := range outputs { @@ -326,22 +325,15 @@ func runSessionOnBatch(batch *PipelineBatch, session *ort.DynamicAdvancedSession if outputCreationErr != nil { return outputCreationErr } - arbitraryOutputTensors[outputIndex] = ort.ArbitraryTensor(outputTensors[outputIndex]) } - // Run Onnx model - arbitraryInputTensors := make([]ort.ArbitraryTensor, len(batch.InputTensors)) - for i, t := range batch.InputTensors { - arbitraryInputTensors[i] = ort.ArbitraryTensor(t) - } - - errOnnx := session.Run(arbitraryInputTensors, arbitraryOutputTensors) + errOnnx := session.Run(batch.InputValues, outputTensors) if errOnnx != nil { return errOnnx } // store resulting tensors - batch.OutputTensors = outputTensors + batch.OutputValues = outputTensors return nil } diff --git a/pipelines/textClassification.go b/pipelines/textClassification.go index a0358bf..791b3c2 100644 --- a/pipelines/textClassification.go +++ b/pipelines/textClassification.go @@ -240,7 +240,7 @@ func (p *TextClassificationPipeline) Forward(batch *PipelineBatch) error { } func (p *TextClassificationPipeline) Postprocess(batch *PipelineBatch) (*TextClassificationOutput, error) { - outputTensor := batch.OutputTensors[0] + outputValue := batch.OutputValues[0] outputDims := p.OutputsMeta[0].Dimensions nLogit := outputDims[len(outputDims)-1] output := make([][]float32, len(batch.Input)) @@ -257,7 +257,9 @@ func (p *TextClassificationPipeline) Postprocess(batch *PipelineBatch) (*TextCla return nil, fmt.Errorf("aggregation function %s is not supported", p.AggregationFunctionName) } - for _, result := range outputTensor.GetData() { + resultTensor := outputValue.(*ort.Tensor[float32]) + + for _, result := range resultTensor.GetData() { inputVector[vectorCounter] = result if vectorCounter == int(nLogit)-1 { output[inputCounter] = aggregationFunction(inputVector) diff --git a/pipelines/tokenClassification.go b/pipelines/tokenClassification.go index cb57b72..58ae2ff 100644 --- a/pipelines/tokenClassification.go +++ b/pipelines/tokenClassification.go @@ -256,7 +256,8 @@ func (p *TokenClassificationPipeline) Postprocess(batch *PipelineBatch) (*TokenC // construct the output vectors by gathering the logits, // however discard the embeddings of the padding tokens so that the output vector length // for an input is equal to the number of original tokens - for _, result := range batch.OutputTensors[0].GetData() { + outputTensor := batch.OutputValues[0].(*ort.Tensor[float32]) + for _, result := range outputTensor.GetData() { tokenVector[tokenVectorCounter] = result if tokenVectorCounter == tokenLogitsDim-1 { // raw result vector for token is now complete diff --git a/pipelines/zeroShotClassification.go b/pipelines/zeroShotClassification.go index 7f46940..f5b526f 100644 --- a/pipelines/zeroShotClassification.go +++ b/pipelines/zeroShotClassification.go @@ -444,7 +444,8 @@ func (p *ZeroShotClassificationPipeline) RunPipeline(inputs []string) (*ZeroShot if e := errors.Join(runErrors...); e != nil { return nil, e } - sequenceTensors = append(sequenceTensors, batch.OutputTensors[0].GetData()) + outputTensor := batch.OutputValues[0].(*ort.Tensor[float32]) + sequenceTensors = append(sequenceTensors, outputTensor.GetData()) } outputTensors = append(outputTensors, sequenceTensors) }