Skip to content

Commit

Permalink
change: use value in batch results and cast to concrete type in postprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardopinosio committed Sep 10, 2024
1 parent ed4ef41 commit 2e34127
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 26 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/viant/afs v1.25.1
github.com/viant/afsc v1.9.3
github.com/yalue/onnxruntime_go v1.12.0
golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ github.com/yalue/onnxruntime_go v1.12.0 h1:UtrSZOV9cY9j8ualjiakzRSn7H+bvu6QyCHAA
github.com/yalue/onnxruntime_go v1.12.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e h1:I88y4caeGeuDQxgdoFPUq097j7kNfw6uvuiNxUBfcBk=
golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk=
golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY=
golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down
3 changes: 2 additions & 1 deletion pipelines/featureExtraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,9 @@ func (p *FeatureExtractionPipeline) Postprocess(batch *PipelineBatch) (*FeatureE
tokenEmbeddings := make([][]float32, maxSequenceLength)
tokenEmbeddingsCounter := 0
batchInputCounter := 0
outputTensor := batch.OutputValues[0].(*ort.Tensor[float32])

for _, result := range batch.OutputTensors[0].GetData() {
for _, result := range outputTensor.GetData() {
outputEmbedding[outputEmbeddingCounter] = result
if outputEmbeddingCounter == int(embeddingDimension)-1 {
// we gathered one embedding
Expand Down
28 changes: 10 additions & 18 deletions pipelines/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,19 @@ type tokenizedInput struct {
// PipelineBatch represents a batch of inputs that runs through the pipeline.
type PipelineBatch struct {
Input []tokenizedInput
InputTensors []*ort.Tensor[int64]
MaxSequenceLength int
OutputTensors []*ort.Tensor[float32]
InputValues []ort.Value
OutputValues []ort.Value
}

func (b *PipelineBatch) Destroy() error {
destroyErrors := make([]error, 0, len(b.InputTensors)+len(b.OutputTensors))
destroyErrors := make([]error, 0, len(b.InputValues)+len(b.OutputValues))

for _, tensor := range b.InputTensors {
for _, tensor := range b.InputValues {
destroyErrors = append(destroyErrors, tensor.Destroy())
}

for _, tensor := range b.OutputTensors {
for _, tensor := range b.OutputValues {
destroyErrors = append(destroyErrors, tensor.Destroy())
}
return errors.Join(destroyErrors...)
Expand Down Expand Up @@ -231,7 +231,7 @@ func createInputTensors(batch *PipelineBatch, inputsMeta []ort.InputOutputInfo)
tensorSize := len(batch.Input) * (batch.MaxSequenceLength)
batchSize := int64(len(batch.Input))

inputTensors := make([]*ort.Tensor[int64], len(inputsMeta))
inputTensors := make([]ort.Value, len(inputsMeta))
var tensorCreationErr error

for i, inputMeta := range inputsMeta {
Expand Down Expand Up @@ -263,7 +263,7 @@ func createInputTensors(batch *PipelineBatch, inputsMeta []ort.InputOutputInfo)
return tensorCreationErr
}
}
batch.InputTensors = inputTensors
batch.InputValues = inputTensors
return nil
}

Expand Down Expand Up @@ -297,8 +297,7 @@ func runSessionOnBatch(batch *PipelineBatch, session *ort.DynamicAdvancedSession
maxSequenceLength := int64(batch.MaxSequenceLength)

// allocate vectors with right dimensions for the output
outputTensors := make([]*ort.Tensor[float32], len(outputs))
arbitraryOutputTensors := make([]ort.ArbitraryTensor, len(outputs))
outputTensors := make([]ort.Value, len(outputs))
var outputCreationErr error

for outputIndex, meta := range outputs {
Expand Down Expand Up @@ -326,22 +325,15 @@ func runSessionOnBatch(batch *PipelineBatch, session *ort.DynamicAdvancedSession
if outputCreationErr != nil {
return outputCreationErr
}
arbitraryOutputTensors[outputIndex] = ort.ArbitraryTensor(outputTensors[outputIndex])
}

// Run Onnx model
arbitraryInputTensors := make([]ort.ArbitraryTensor, len(batch.InputTensors))
for i, t := range batch.InputTensors {
arbitraryInputTensors[i] = ort.ArbitraryTensor(t)
}

errOnnx := session.Run(arbitraryInputTensors, arbitraryOutputTensors)
errOnnx := session.Run(batch.InputValues, outputTensors)
if errOnnx != nil {
return errOnnx
}

// store resulting tensors
batch.OutputTensors = outputTensors
batch.OutputValues = outputTensors
return nil
}

Expand Down
6 changes: 4 additions & 2 deletions pipelines/textClassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (p *TextClassificationPipeline) Forward(batch *PipelineBatch) error {
}

func (p *TextClassificationPipeline) Postprocess(batch *PipelineBatch) (*TextClassificationOutput, error) {
outputTensor := batch.OutputTensors[0]
outputValue := batch.OutputValues[0]
outputDims := p.OutputsMeta[0].Dimensions
nLogit := outputDims[len(outputDims)-1]
output := make([][]float32, len(batch.Input))
Expand All @@ -257,7 +257,9 @@ func (p *TextClassificationPipeline) Postprocess(batch *PipelineBatch) (*TextCla
return nil, fmt.Errorf("aggregation function %s is not supported", p.AggregationFunctionName)
}

for _, result := range outputTensor.GetData() {
resultTensor := outputValue.(*ort.Tensor[float32])

for _, result := range resultTensor.GetData() {
inputVector[vectorCounter] = result
if vectorCounter == int(nLogit)-1 {
output[inputCounter] = aggregationFunction(inputVector)
Expand Down
3 changes: 2 additions & 1 deletion pipelines/tokenClassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ func (p *TokenClassificationPipeline) Postprocess(batch *PipelineBatch) (*TokenC
// construct the output vectors by gathering the logits,
// however discard the embeddings of the padding tokens so that the output vector length
// for an input is equal to the number of original tokens
for _, result := range batch.OutputTensors[0].GetData() {
outputTensor := batch.OutputValues[0].(*ort.Tensor[float32])
for _, result := range outputTensor.GetData() {
tokenVector[tokenVectorCounter] = result
if tokenVectorCounter == tokenLogitsDim-1 {
// raw result vector for token is now complete
Expand Down
3 changes: 2 additions & 1 deletion pipelines/zeroShotClassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,8 @@ func (p *ZeroShotClassificationPipeline) RunPipeline(inputs []string) (*ZeroShot
if e := errors.Join(runErrors...); e != nil {
return nil, e
}
sequenceTensors = append(sequenceTensors, batch.OutputTensors[0].GetData())
outputTensor := batch.OutputValues[0].(*ort.Tensor[float32])
sequenceTensors = append(sequenceTensors, outputTensor.GetData())
}
outputTensors = append(outputTensors, sequenceTensors)
}
Expand Down

0 comments on commit 2e34127

Please sign in to comment.