From 2e3412739627ea14ee97ee736cbd89a526050319 Mon Sep 17 00:00:00 2001
From: riccardopinosio <rpinosio@gmail.com>
Date: Tue, 10 Sep 2024 08:19:52 +0000
Subject: [PATCH] change: use ort.Value in batch results and cast to concrete
 tensor type in postprocessing

---
 go.mod                              |  2 +-
 go.sum                              |  4 ++--
 pipelines/featureExtraction.go      |  3 ++-
 pipelines/pipeline.go               | 28 ++++++++++------------------
 pipelines/textClassification.go     |  6 ++++--
 pipelines/tokenClassification.go    |  3 ++-
 pipelines/zeroShotClassification.go |  3 ++-
 7 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/go.mod b/go.mod
index 9bf6ec2..83cb58d 100644
--- a/go.mod
+++ b/go.mod
@@ -12,7 +12,7 @@ require (
 	github.com/viant/afs v1.25.1
 	github.com/viant/afsc v1.9.3
 	github.com/yalue/onnxruntime_go v1.12.0
-	golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e
+	golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
 )
 
 require (
diff --git a/go.sum b/go.sum
index 3770eae..e6336c1 100644
--- a/go.sum
+++ b/go.sum
@@ -57,8 +57,8 @@ github.com/yalue/onnxruntime_go v1.12.0 h1:UtrSZOV9cY9j8ualjiakzRSn7H+bvu6QyCHAA
 github.com/yalue/onnxruntime_go v1.12.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
 golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
 golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
-golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e h1:I88y4caeGeuDQxgdoFPUq097j7kNfw6uvuiNxUBfcBk=
-golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
+golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
+golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
 golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
 golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
 golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
diff --git a/pipelines/featureExtraction.go b/pipelines/featureExtraction.go
index c485bf9..3dffadf 100644
--- a/pipelines/featureExtraction.go
+++ b/pipelines/featureExtraction.go
@@ -225,8 +225,9 @@ func (p *FeatureExtractionPipeline) Postprocess(batch *PipelineBatch) (*FeatureE
 	tokenEmbeddings := make([][]float32, maxSequenceLength)
 	tokenEmbeddingsCounter := 0
 	batchInputCounter := 0
+	outputTensor := batch.OutputValues[0].(*ort.Tensor[float32])
 
-	for _, result := range batch.OutputTensors[0].GetData() {
+	for _, result := range outputTensor.GetData() {
 		outputEmbedding[outputEmbeddingCounter] = result
 		if outputEmbeddingCounter == int(embeddingDimension)-1 {
 			// we gathered one embedding
diff --git a/pipelines/pipeline.go b/pipelines/pipeline.go
index a3ff679..c05e008 100644
--- a/pipelines/pipeline.go
+++ b/pipelines/pipeline.go
@@ -83,19 +83,19 @@ type tokenizedInput struct {
 // PipelineBatch represents a batch of inputs that runs through the pipeline.
 type PipelineBatch struct {
 	Input             []tokenizedInput
-	InputTensors      []*ort.Tensor[int64]
 	MaxSequenceLength int
-	OutputTensors     []*ort.Tensor[float32]
+	InputValues       []ort.Value
+	OutputValues      []ort.Value
 }
 
 func (b *PipelineBatch) Destroy() error {
-	destroyErrors := make([]error, 0, len(b.InputTensors)+len(b.OutputTensors))
+	destroyErrors := make([]error, 0, len(b.InputValues)+len(b.OutputValues))
 
-	for _, tensor := range b.InputTensors {
+	for _, tensor := range b.InputValues {
 		destroyErrors = append(destroyErrors, tensor.Destroy())
 	}
 
-	for _, tensor := range b.OutputTensors {
+	for _, tensor := range b.OutputValues {
 		destroyErrors = append(destroyErrors, tensor.Destroy())
 	}
 	return errors.Join(destroyErrors...)
@@ -231,7 +231,7 @@ func createInputTensors(batch *PipelineBatch, inputsMeta []ort.InputOutputInfo)
 	tensorSize := len(batch.Input) * (batch.MaxSequenceLength)
 	batchSize := int64(len(batch.Input))
 
-	inputTensors := make([]*ort.Tensor[int64], len(inputsMeta))
+	inputTensors := make([]ort.Value, len(inputsMeta))
 	var tensorCreationErr error
 
 	for i, inputMeta := range inputsMeta {
@@ -263,7 +263,7 @@ func createInputTensors(batch *PipelineBatch, inputsMeta []ort.InputOutputInfo)
 			return tensorCreationErr
 		}
 	}
-	batch.InputTensors = inputTensors
+	batch.InputValues = inputTensors
 	return nil
 }
 
@@ -297,8 +297,7 @@ func runSessionOnBatch(batch *PipelineBatch, session *ort.DynamicAdvancedSession
 	maxSequenceLength := int64(batch.MaxSequenceLength)
 
 	// allocate vectors with right dimensions for the output
-	outputTensors := make([]*ort.Tensor[float32], len(outputs))
-	arbitraryOutputTensors := make([]ort.ArbitraryTensor, len(outputs))
+	outputTensors := make([]ort.Value, len(outputs))
 	var outputCreationErr error
 
 	for outputIndex, meta := range outputs {
@@ -326,22 +325,15 @@ func runSessionOnBatch(batch *PipelineBatch, session *ort.DynamicAdvancedSession
 		if outputCreationErr != nil {
 			return outputCreationErr
 		}
-		arbitraryOutputTensors[outputIndex] = ort.ArbitraryTensor(outputTensors[outputIndex])
 	}
 
-	// Run Onnx model
-	arbitraryInputTensors := make([]ort.ArbitraryTensor, len(batch.InputTensors))
-	for i, t := range batch.InputTensors {
-		arbitraryInputTensors[i] = ort.ArbitraryTensor(t)
-	}
-
-	errOnnx := session.Run(arbitraryInputTensors, arbitraryOutputTensors)
+	errOnnx := session.Run(batch.InputValues, outputTensors)
 	if errOnnx != nil {
 		return errOnnx
 	}
 
 	// store resulting tensors
-	batch.OutputTensors = outputTensors
+	batch.OutputValues = outputTensors
 	return nil
 }
 
diff --git a/pipelines/textClassification.go b/pipelines/textClassification.go
index a0358bf..791b3c2 100644
--- a/pipelines/textClassification.go
+++ b/pipelines/textClassification.go
@@ -240,7 +240,7 @@ func (p *TextClassificationPipeline) Forward(batch *PipelineBatch) error {
 }
 
 func (p *TextClassificationPipeline) Postprocess(batch *PipelineBatch) (*TextClassificationOutput, error) {
-	outputTensor := batch.OutputTensors[0]
+	outputValue := batch.OutputValues[0]
 	outputDims := p.OutputsMeta[0].Dimensions
 	nLogit := outputDims[len(outputDims)-1]
 	output := make([][]float32, len(batch.Input))
@@ -257,7 +257,9 @@ func (p *TextClassificationPipeline) Postprocess(batch *PipelineBatch) (*TextCla
 		return nil, fmt.Errorf("aggregation function %s is not supported", p.AggregationFunctionName)
 	}
 
-	for _, result := range outputTensor.GetData() {
+	resultTensor := outputValue.(*ort.Tensor[float32])
+
+	for _, result := range resultTensor.GetData() {
 		inputVector[vectorCounter] = result
 		if vectorCounter == int(nLogit)-1 {
 			output[inputCounter] = aggregationFunction(inputVector)
diff --git a/pipelines/tokenClassification.go b/pipelines/tokenClassification.go
index cb57b72..58ae2ff 100644
--- a/pipelines/tokenClassification.go
+++ b/pipelines/tokenClassification.go
@@ -256,7 +256,8 @@ func (p *TokenClassificationPipeline) Postprocess(batch *PipelineBatch) (*TokenC
 	// construct the output vectors by gathering the logits,
 	// however discard the embeddings of the padding tokens so that the output vector length
 	// for an input is equal to the number of original tokens
-	for _, result := range batch.OutputTensors[0].GetData() {
+	outputTensor := batch.OutputValues[0].(*ort.Tensor[float32])
+	for _, result := range outputTensor.GetData() {
 		tokenVector[tokenVectorCounter] = result
 		if tokenVectorCounter == tokenLogitsDim-1 {
 			// raw result vector for token is now complete
diff --git a/pipelines/zeroShotClassification.go b/pipelines/zeroShotClassification.go
index 7f46940..f5b526f 100644
--- a/pipelines/zeroShotClassification.go
+++ b/pipelines/zeroShotClassification.go
@@ -444,7 +444,8 @@ func (p *ZeroShotClassificationPipeline) RunPipeline(inputs []string) (*ZeroShot
 			if e := errors.Join(runErrors...); e != nil {
 				return nil, e
 			}
-			sequenceTensors = append(sequenceTensors, batch.OutputTensors[0].GetData())
+			outputTensor := batch.OutputValues[0].(*ort.Tensor[float32])
+			sequenceTensors = append(sequenceTensors, outputTensor.GetData())
 		}
 		outputTensors = append(outputTensors, sequenceTensors)
 	}