diff --git a/Dockerfile b/Dockerfile index 162acec..2abacf8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,13 +1,13 @@ -ARG GO_VERSION=1.22.3 -ARG RUST_VERSION=1.78 +ARG GO_VERSION=1.22.5 +ARG RUST_VERSION=1.79 ARG ONNXRUNTIME_VERSION=1.18.0 ARG BUILD_PLATFORM=linux/amd64 - +ARG CGO_LDFLAGS="-L./usr/lib/libtokenizers.a" #--- rust build of tokenizer --- FROM --platform=$BUILD_PLATFORM rust:$RUST_VERSION AS tokenizer -RUN git clone https://github.com/knights-analytics/tokenizers -b main && \ +RUN git clone https://github.com/knights-analytics/tokenizers -b rebase && \ cd tokenizers && \ cargo build --release @@ -16,6 +16,7 @@ RUN git clone https://github.com/knights-analytics/tokenizers -b main && \ FROM --platform=$BUILD_PLATFORM public.ecr.aws/amazonlinux/amazonlinux:2023 AS hugot-build ARG GO_VERSION ARG ONNXRUNTIME_VERSION +ARG CGO_LDFLAGS RUN dnf -y install gcc jq bash tar xz gzip glibc-static libstdc++ wget zip git && \ ln -s /usr/lib64/libstdc++.so.6 /usr/lib64/libstdc++.so && \ diff --git a/README.md b/README.md index 2b1909f..e0cc1d7 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Currently, we have implementations for the following transfomer pipelines: - [featureExtraction](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.FeatureExtractionPipeline) - [textClassification](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TextClassificationPipeline) - [tokenClassification](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TokenClassificationPipeline) +- [zeroShotClassification](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.ZeroShotClassificationPipeline) Implementations for additional pipelines will follow. We also very gladly accept PRs to expand the set of pipelines! 
See [here](https://huggingface.co/docs/transformers/en/main_classes/pipelines) for the missing pipelines that can be implemented, and the contributing section below if you want to lend a hand. @@ -70,6 +71,7 @@ Pipelines are also tested on specifically NLP use cases. In particular, we use t - feature extraction: all-MiniLM-L6-v2 - text classification: distilbert-base-uncased-finetuned-sst-2-english - token classification: distilbert-NER and Roberta-base-go_emotions +- zero shot classification: protectai/deberta-v3-base-zeroshot-v1-onnx If you encounter any further issues or want further features, please open an issue. @@ -81,13 +83,13 @@ Hugot can be used in two ways: as a library in your go application, or as a comm To use Hugot as a library in your application, you will need the following dependencies on your system: -- the tokenizers.a file obtained from building the [tokenizer](https://github.com/Knights-Analytics/tokenizers) go library (which is itself a fork of https://github.com/daulet/tokenizers). This file should be at /usr/lib/tokenizers.a so that hugot can load it. +- the tokenizers.a file obtained from building the [tokenizer](https://github.com/daulet/tokenizers) go library. This file should be at /usr/lib/tokenizers.a so that hugot can load it. - the onnxruntime.go file obtained from the onnxruntime project. This is dynamically linked by hugot and used by the onnxruntime inference library [onnxruntime_go](https://github.com/yalue/onnxruntime_go). This file should be at /usr/lib/onnxruntime.so or /usr/lib64/onnxruntime.so You can get the libtokenizers.a in two ways. 
Assuming you have rust installed, you can compile the tokenizers library and get the required libtokenizers.a: ``` -git clone https://github.com/Knights-Analytics/tokenizers -b main && \ +git clone https://github.com/daulet/tokenizers -b main && \ cd tokenizers && \ cargo build --release mv target/release/libtokenizers.a /usr/lib/libtokenizers.a diff --git a/cmd/main.go b/cmd/main.go index 843145b..6e7a503 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -296,7 +296,6 @@ func main() { } func writeOutputs(wg *sync.WaitGroup, processedChannel chan []byte, errorChannel chan error, writeTarget io.WriteCloser) { - for processedChannel != nil || errorChannel != nil { select { case output, ok := <-processedChannel: diff --git a/cmd/main_test.go b/cmd/main_test.go index 829d859..690cf0b 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -87,7 +87,7 @@ func TestFeatureExtractionCli(t *testing.T) { } baseArgs := os.Args[0:1] - testModel := path.Join("../models", "KnightsAnalytics_all-MiniLM-L6-v2") + testModel := path.Join("../models", "sentence-transformers_all-MiniLM-L6-v2") testDataDir := path.Join(os.TempDir(), "hugoTestData") err := os.MkdirAll(testDataDir, os.ModePerm) diff --git a/downloader.go b/downloader.go index 68c78e7..1cfbd6b 100644 --- a/downloader.go +++ b/downloader.go @@ -14,8 +14,6 @@ import ( "time" hfd "github.com/bodaay/HuggingFaceModelDownloader/hfdownloader" - - util "github.com/knights-analytics/hugot/utils" ) // DownloadOptions is a struct of options that can be passed to DownloadModel @@ -122,7 +120,7 @@ func checkURL(client *http.Client, url string, authToken string) (bool, bool, er var dirs []hfFile for _, f := range filesList { - if f.Path == "tokenizer.json" { + if filepath.Base(f.Path) == "tokenizer.json" { tokenizerFound = true } if filepath.Ext(f.Path) == ".onnx" { @@ -153,26 +151,3 @@ func checkURL(client *http.Client, url string, authToken string) (bool, bool, er return tokenizerFound, onnxFound, nil } - -func 
downloadModelIfNotExists(session *Session, modelName string, destination string) string { - - modelNameFS := modelName - if strings.Contains(modelNameFS, ":") { - modelNameFS = strings.Split(modelName, ":")[0] - } - modelNameFS = path.Join(destination, strings.Replace(modelNameFS, "/", "_", -1)) - - fullModelPath := path.Join(destination, modelNameFS) - exists, err := util.FileSystem.Exists(context.Background(), fullModelPath) - if err != nil { - panic(err) - } - if exists { - return fullModelPath - } - fullModelPath, err = session.DownloadModel(modelName, destination, NewDownloadOptions()) - if err != nil { - panic(err) - } - return fullModelPath -} diff --git a/go.mod b/go.mod index 67cfd2c..bf876e9 100644 --- a/go.mod +++ b/go.mod @@ -4,21 +4,23 @@ go 1.20 replace github.com/viant/afsc => github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df +replace github.com/daulet/tokenizers => github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267 + require ( github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c + github.com/daulet/tokenizers v0.8.0 github.com/json-iterator/go v1.1.12 - github.com/knights-analytics/tokenizers v0.12.1 github.com/mattn/go-isatty v0.0.20 github.com/stretchr/testify v1.9.0 github.com/urfave/cli/v2 v2.27.2 github.com/viant/afs v1.25.1 github.com/viant/afsc v1.9.2 github.com/yalue/onnxruntime_go v1.10.0 - golang.org/x/exp v0.0.0-20240529005216-23cca8864a10 + golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7 ) require ( - github.com/aws/aws-sdk-go v1.53.12 // indirect + github.com/aws/aws-sdk-go v1.54.19 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.17.0 // indirect @@ -32,8 +34,8 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect - golang.org/x/crypto v0.23.0 
// indirect - golang.org/x/oauth2 v0.20.0 // indirect - golang.org/x/sys v0.20.0 // indirect + golang.org/x/crypto v0.25.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect + golang.org/x/sys v0.22.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4b8b1d4..826cfcf 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/aws/aws-sdk-go v1.53.12 h1:8f8K+YaTy2qwtGwVIo2Ftq22UCH96xQAX7Q0lyZKDiA= -github.com/aws/aws-sdk-go v1.53.12/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go v1.54.19 h1:tyWV+07jagrNiCcGRzRhdtVjQs7Vy41NwsuOcl0IbVI= +github.com/aws/aws-sdk-go v1.54.19/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c h1:3TPq2BhzOquTGmbS53KeGcM1yalBUb/4zQM1wmaINrE= github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c/go.mod h1:p6JQ7mJjWx82F+SrFfj9RkoHlKEGXR4959uX/vkMbzE= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= @@ -21,8 +21,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df h1:rVna1iJaI7gj5RonGys0dZ0iLy7upULdcbRQd9F2qg8= github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df/go.mod h1:yZo80n1EB2eMwmmec7BekX6clpd7uY+joUpDRIBbeYs= -github.com/knights-analytics/tokenizers v0.12.1 h1:5bIxk3SQKXIHKxlzAOmqPXgFeKE+LCvbXS3hpTgOAX4= -github.com/knights-analytics/tokenizers v0.12.1/go.mod h1:TD+zVXlFlS4QyP6/RN8SPSAKkT2hpMmF64WdrdbBfts= +github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267 h1:M2jdyK5zl/AUe1ZBLUWqAAjSu6LwF9ZFegk+UBMjVjY= +github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/kr/pretty v0.1.0 
h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= @@ -54,19 +54,17 @@ github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGC github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yalue/onnxruntime_go v1.10.0 h1:om1yzOQYv/4GlsSP5HIZvS6G3WF3THv4x5rhO5AFERU= github.com/yalue/onnxruntime_go v1.10.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/exp v0.0.0-20240529005216-23cca8864a10 h1:vpzMC/iZhYFAjJzHU0Cfuq+w1vLLsF2vLkDrPjzKYck= -golang.org/x/exp v0.0.0-20240529005216-23cca8864a10/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo= -golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7 h1:wDLEX9a7YQoKdKNQt88rtydkqDxeGaBUTnIYc3iG/mA= +golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= 
-golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/hugot.go b/hugot.go index 5daf02b..1d25bed 100644 --- a/hugot.go +++ b/hugot.go @@ -14,10 +14,11 @@ import ( // Session allows for the creation of new pipelines and holds the pipeline already created. type Session struct { - featureExtractionPipelines pipelineMap[*pipelines.FeatureExtractionPipeline] - tokenClassificationPipelines pipelineMap[*pipelines.TokenClassificationPipeline] - textClassificationPipelines pipelineMap[*pipelines.TextClassificationPipeline] - ortOptions *ort.SessionOptions + featureExtractionPipelines pipelineMap[*pipelines.FeatureExtractionPipeline] + tokenClassificationPipelines pipelineMap[*pipelines.TokenClassificationPipeline] + textClassificationPipelines pipelineMap[*pipelines.TextClassificationPipeline] + zeroShotClassifcationPipelines pipelineMap[*pipelines.ZeroShotClassificationPipeline] + ortOptions *ort.SessionOptions } type pipelineMap[T pipelines.Pipeline] map[string]T @@ -38,28 +39,32 @@ func (m pipelineMap[T]) GetStats() []string { return stats } -// TokenClassificationConfig is the configuration for a token classification pipeline -type TokenClassificationConfig = pipelines.PipelineConfig[*pipelines.TokenClassificationPipeline] +// FeatureExtractionConfig is the configuration for 
a feature extraction pipeline +type FeatureExtractionConfig = pipelines.PipelineConfig[*pipelines.FeatureExtractionPipeline] + +// FeatureExtractionOption is an option for a feature extraction pipeline +type FeatureExtractionOption = pipelines.PipelineOption[*pipelines.FeatureExtractionPipeline] // TextClassificationConfig is the configuration for a text classification pipeline type TextClassificationConfig = pipelines.PipelineConfig[*pipelines.TextClassificationPipeline] -// FeatureExtractionConfig is the configuration for a feature extraction pipeline -type FeatureExtractionConfig = pipelines.PipelineConfig[*pipelines.FeatureExtractionPipeline] +// type ZSCConfig = pipelines.PipelineConfig[*pipelines.ZeroShotClassificationPipeline] -// TokenClassificationOption is an option for a token classification pipeline -type TokenClassificationOption = pipelines.PipelineOption[*pipelines.TokenClassificationPipeline] +type ZeroShotClassificationConfig = pipelines.PipelineConfig[*pipelines.ZeroShotClassificationPipeline] // TextClassificationOption is an option for a text classification pipeline type TextClassificationOption = pipelines.PipelineOption[*pipelines.TextClassificationPipeline] -// FeatureExtractionOption is an option for a feature extraction pipeline -type FeatureExtractionOption = pipelines.PipelineOption[*pipelines.FeatureExtractionPipeline] +// TokenClassificationConfig is the configuration for a token classification pipeline +type TokenClassificationConfig = pipelines.PipelineConfig[*pipelines.TokenClassificationPipeline] + +// TokenClassificationOption is an option for a token classification pipeline +type TokenClassificationOption = pipelines.PipelineOption[*pipelines.TokenClassificationPipeline] // NewSession is the main entrypoint to hugot and is used to create a new hugot session object. // ortLibraryPath should be the path to onnxruntime.so. 
If it's the empty string, hugot will try // to load the library from the default location (/usr/lib/onnxruntime.so). -// A new session must be destroyed when it's not needed anymore to avoid memory leaks. See the Destroy method. +// A new session must be destroyed when it's not needed any more to avoid memory leaks. See the Destroy method. // Note moreover that there can be at most one hugot session active (i.e., the Session object is a singleton), // otherwise NewSession will return an error. func NewSession(options ...WithOption) (*Session, error) { @@ -69,9 +74,10 @@ func NewSession(options ...WithOption) (*Session, error) { } session := &Session{ - featureExtractionPipelines: map[string]*pipelines.FeatureExtractionPipeline{}, - tokenClassificationPipelines: map[string]*pipelines.TokenClassificationPipeline{}, - textClassificationPipelines: map[string]*pipelines.TextClassificationPipeline{}, + featureExtractionPipelines: map[string]*pipelines.FeatureExtractionPipeline{}, + textClassificationPipelines: map[string]*pipelines.TextClassificationPipeline{}, + tokenClassificationPipelines: map[string]*pipelines.TokenClassificationPipeline{}, + zeroShotClassifcationPipelines: map[string]*pipelines.ZeroShotClassificationPipeline{}, } // set session options and initialise @@ -248,6 +254,14 @@ func NewPipeline[T pipelines.Pipeline](s *Session, pipelineConfig pipelines.Pipe } s.featureExtractionPipelines[config.Name] = pipelineInitialised pipeline = any(pipelineInitialised).(T) + case *pipelines.ZeroShotClassificationPipeline: + config := any(pipelineConfig).(pipelines.PipelineConfig[*pipelines.ZeroShotClassificationPipeline]) + pipelineInitialised, err := pipelines.NewZeroShotClassificationPipeline(config, s.ortOptions) + if err != nil { + return pipeline, err + } + s.zeroShotClassifcationPipelines[config.Name] = pipelineInitialised + pipeline = any(pipelineInitialised).(T) default: return pipeline, fmt.Errorf("not implemented") } @@ -276,18 +290,25 @@ func GetPipeline[T 
pipelines.Pipeline](s *Session, name string) (T, error) { return pipeline, &pipelineNotFoundError{pipelineName: name} } return any(p).(T), nil + case *pipelines.ZeroShotClassificationPipeline: + p, ok := s.zeroShotClassifcationPipelines[name] + if !ok { + return pipeline, &pipelineNotFoundError{pipelineName: name} + } + return any(p).(T), nil default: return pipeline, errors.New("pipeline type not supported") } } // Destroy deletes the hugot session and onnxruntime environment and all initialized pipelines, freeing memory. -// A hugot session should be destroyed when not neeeded anymore, preferably with a defer() call. +// A hugot session should be destroyed when not neeeded any more, preferably with a defer() call. func (s *Session) Destroy() error { return errors.Join( s.featureExtractionPipelines.Destroy(), s.tokenClassificationPipelines.Destroy(), s.textClassificationPipelines.Destroy(), + s.zeroShotClassifcationPipelines.Destroy(), s.ortOptions.Destroy(), ort.DestroyEnvironment(), ) @@ -302,67 +323,10 @@ func (s *Session) Destroy() error { // the average time per onnxruntime inference batch call func (s *Session) GetStats() []string { // slices.Concat() is not implemented in experimental x/exp/slices package - return append(append(s.tokenClassificationPipelines.GetStats(), + return append(append(append( + s.tokenClassificationPipelines.GetStats(), s.textClassificationPipelines.GetStats()...), - s.featureExtractionPipelines.GetStats()..., + s.featureExtractionPipelines.GetStats()...), + s.zeroShotClassifcationPipelines.GetStats()..., ) } - -// deprecated methods - -// NewTokenClassificationPipeline creates and returns a new token classification pipeline object. -// modelPath should be the path to a folder with the onnx exported transformer model. Name is an identifier -// for the pipeline (see GetTokenClassificationPipeline). 
-// Deprecated: use NewPipeline -func (s *Session) NewTokenClassificationPipeline(modelPath string, name string, opts ...TokenClassificationOption) (*pipelines.TokenClassificationPipeline, error) { - config := pipelines.PipelineConfig[*pipelines.TokenClassificationPipeline]{ - ModelPath: modelPath, - Name: name, - Options: opts, - } - return NewPipeline(s, config) -} - -// NewTextClassificationPipeline creates and returns a new text classification pipeline object. -// modelPath should be the path to a folder with the onnx exported transformer model. Name is an identifier -// for the pipeline (see GetTextClassificationPipeline). -// Deprecated: use NewPipeline -func (s *Session) NewTextClassificationPipeline(modelPath string, name string, opts ...TextClassificationOption) (*pipelines.TextClassificationPipeline, error) { - config := pipelines.PipelineConfig[*pipelines.TextClassificationPipeline]{ - ModelPath: modelPath, - Name: name, - Options: opts, - } - return NewPipeline(s, config) -} - -// NewFeatureExtractionPipeline creates and returns a new feature extraction pipeline object. -// modelPath should be the path to a folder with the onnx exported transformer model. Name is an identifier -// for the pipeline (see GetFeatureExtractionPipeline). -// Deprecated: use NewPipeline -func (s *Session) NewFeatureExtractionPipeline(modelPath string, name string, opts ...FeatureExtractionOption) (*pipelines.FeatureExtractionPipeline, error) { - config := pipelines.PipelineConfig[*pipelines.FeatureExtractionPipeline]{ - ModelPath: modelPath, - Name: name, - Options: opts, - } - return NewPipeline(s, config) -} - -// GetFeatureExtractionPipeline returns a feature extraction pipeline by name. If the name does not exist, it will return an error. -// Deprecated: use GetPipeline. 
-func (s *Session) GetFeatureExtractionPipeline(name string) (*pipelines.FeatureExtractionPipeline, error) { - return GetPipeline[*pipelines.FeatureExtractionPipeline](s, name) -} - -// GetTextClassificationPipeline returns a text classification pipeline by name. If the name does not exist, it will return an error. -// Deprecated: use GetPipeline. -func (s *Session) GetTextClassificationPipeline(name string) (*pipelines.TextClassificationPipeline, error) { - return GetPipeline[*pipelines.TextClassificationPipeline](s, name) -} - -// GetTokenClassificationPipeline returns a token classification pipeline by name. If the name does not exist, it will return an error. -// Deprecated: use GetPipeline. -func (s *Session) GetTokenClassificationPipeline(name string) (*pipelines.TokenClassificationPipeline, error) { - return GetPipeline[*pipelines.TokenClassificationPipeline](s, name) -} diff --git a/hugot_test.go b/hugot_test.go index 5cea24d..75a4770 100644 --- a/hugot_test.go +++ b/hugot_test.go @@ -10,10 +10,11 @@ import ( "strings" "testing" - "github.com/stretchr/testify/assert" - "github.com/knights-analytics/hugot/pipelines" util "github.com/knights-analytics/hugot/utils" + "github.com/stretchr/testify/assert" + + ort "github.com/yalue/onnxruntime_go" ) //go:embed testData/tokenExpected.json @@ -33,6 +34,163 @@ func TestDownloadValidation(t *testing.T) { // a model without tokenizer.json or .onnx model should error err = validateDownloadHfModel("ByteDance/SDXL-Lightning", "main", "") assert.Error(t, err) + // a model with the required files in a subfolder should not error + err = validateDownloadHfModel("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "main", "") + assert.NoError(t, err) +} + +// FEATURE EXTRACTION + +func TestFeatureExtractionPipelineValidation(t *testing.T) { + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) + check(t, err) + defer func(session *Session) { + err := session.Destroy() + check(t, err) + 
}(session) + + modelPath := "./models/sentence-transformers_all-MiniLM-L6-v2" + config := FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipeline", + } + pipeline, err := NewPipeline(session, config) + check(t, err) + + pipeline.InputsMeta[0].Dimensions = ort.NewShape(-1, -1, -1) + + err = pipeline.Validate() + assert.Error(t, err) + + pipeline.InputsMeta[0].Dimensions = ort.NewShape(1, 1, 1, 1) + err = pipeline.Validate() + assert.Error(t, err) +} + +func TestFeatureExtractionPipeline(t *testing.T) { + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) + check(t, err) + defer func(session *Session) { + err := session.Destroy() + check(t, err) + }(session) + + modelPath := "./models/sentence-transformers_all-MiniLM-L6-v2" + + config := FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipeline", + } + pipeline, err := NewPipeline(session, config) + check(t, err) + + var expectedResults map[string][][]float32 + err = json.Unmarshal(resultsByte, &expectedResults) + check(t, err) + var testResults [][]float32 + + // test 'robert smith' + testResults = expectedResults["test1output"] + batchResult, err := pipeline.RunPipeline([]string{"robert smith"}) + if err != nil { + t.Fatalf(err.Error()) + } + for i := range batchResult.Embeddings { + e := floatsEqual(batchResult.Embeddings[i], testResults[i]) + if e != nil { + t.Logf("Test 1: The neural network didn't produce the correct result on loop %d: %s\n", i, e) + t.FailNow() + } + } + + // test ['robert smith junior', 'francis ford coppola'] + testResults = expectedResults["test2output"] + batchResult, err = pipeline.RunPipeline([]string{"robert smith junior", "francis ford coppola"}) + if err != nil { + t.FailNow() + } + for i := range batchResult.Embeddings { + e := floatsEqual(batchResult.Embeddings[i], testResults[i]) + if e != nil { + t.Logf("Test 1: The neural network didn't produce the correct result on loop %d: %s\n", i, e) + t.FailNow() + } + } + + // determinism 
test to make sure embeddings of a string are not influenced by other strings in the batch + testPairs := map[string][][]string{} + testPairs["identity"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "yo"}} + testPairs["contextOverlap"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "yo mama yo"}} + testPairs["contextDisjoint"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "another test"}} + + for k, sentencePair := range testPairs { + // these vectors should be the same + firstBatchResult, err2 := pipeline.RunPipeline(sentencePair[0]) + check(t, err2) + firstEmbedding := firstBatchResult.Embeddings[0] + + secondBatchResult, err3 := pipeline.RunPipeline(sentencePair[1]) + check(t, err3) + secondEmbedding := secondBatchResult.Embeddings[0] + e := floatsEqual(firstEmbedding, secondEmbedding) + if e != nil { + t.Logf("Equality failed for determinism test %s test with pairs %s and %s", k, strings.Join(sentencePair[0], ","), strings.Join(sentencePair[1], ",")) + t.Log("First vector", firstEmbedding) + t.Log("second vector", secondEmbedding) + t.Fail() + } + } + + zero := uint64(0) + assert.Greater(t, pipeline.PipelineTimings.NumCalls, zero, "PipelineTimings.NumCalls should be greater than 0") + assert.Greater(t, pipeline.PipelineTimings.TotalNS, zero, "PipelineTimings.TotalNS should be greater than 0") + assert.Greater(t, pipeline.TokenizerTimings.NumCalls, zero, "TokenizerTimings.NumCalls should be greater than 0") + assert.Greater(t, pipeline.TokenizerTimings.TotalNS, zero, "TokenizerTimings.TotalNS should be greater than 0") + + // test normalization + testResults = expectedResults["normalizedOutput"] + config = FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipelineNormalise", + Options: []FeatureExtractionOption{ + pipelines.WithNormalization(), + }, + } + pipeline, err = NewPipeline(session, config) + check(t, err) + normalizationStrings := []string{"Onnxruntime is a great inference backend"} + normalizedEmbedding, err := 
pipeline.RunPipeline(normalizationStrings) + check(t, err) + for i, embedding := range normalizedEmbedding.Embeddings { + e := floatsEqual(embedding, testResults[i]) + if e != nil { + t.Fatalf("Normalization test failed: %s", normalizationStrings[i]) + } + } + + // test getting sentence embeddings + configSentence := FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipelineSentence", + Options: []FeatureExtractionOption{pipelines.WithOutputName("sentence_embedding")}, + } + pipelineSentence, err := NewPipeline(session, configSentence) + check(t, err) + outputSentence, err := pipelineSentence.RunPipeline([]string{"Onnxruntime is a great inference backend"}) + if err != nil { + t.FailNow() + } + fmt.Println(outputSentence.Embeddings[0]) + configSentence = FeatureExtractionConfig{ + ModelPath: modelPath, + Name: "testPipelineToken", + } + pipelineToken, err := NewPipeline(session, configSentence) + check(t, err) + _, err = pipelineToken.RunPipeline([]string{"Onnxruntime is a great inference backend"}) + if err != nil { + t.FailNow() + } } // Text classification @@ -51,7 +209,9 @@ func TestTextClassificationPipeline(t *testing.T) { errDestroy := session.Destroy() check(t, errDestroy) }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-base-uncased-finetuned-sst-2-english" + modelPathMulti := "./models/SamLowe_roberta-base-go_emotions-onnx" + config := TextClassificationConfig{ ModelPath: modelPath, Name: "testPipelineSimple", @@ -62,7 +222,6 @@ func TestTextClassificationPipeline(t *testing.T) { sentimentPipeline, err := NewPipeline(session, config) check(t, err) - modelPathMulti := downloadModelIfNotExists(session, "SamLowe/roberta-base-go_emotions-onnx", "./models") configMulti := TextClassificationConfig{ ModelPath: modelPathMulti, Name: "testPipelineSimpleMulti", @@ -248,7 +407,7 @@ func 
TestTextClassificationPipelineValidation(t *testing.T) { err := session.Destroy() check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-base-uncased-finetuned-sst-2-english" config := TextClassificationConfig{ ModelPath: modelPath, @@ -259,20 +418,290 @@ func TestTextClassificationPipelineValidation(t *testing.T) { } sentimentPipeline, err := NewPipeline(session, config) check(t, err) - sentimentPipeline.IdLabelMap = map[int]string{} - err = sentimentPipeline.Validate() - assert.Error(t, err) - if err != nil { - errInt := err.(interface{ Unwrap() []error }) - assert.Equal(t, 3, len(errInt.Unwrap())) + + t.Run("id-label-map", func(t *testing.T) { + labelMapInitial := sentimentPipeline.IDLabelMap + defer func() { + sentimentPipeline.IDLabelMap = labelMapInitial + }() + sentimentPipeline.IDLabelMap = map[int]string{} + err = sentimentPipeline.Validate() + assert.Error(t, err) + }) + + t.Run("output-shape", func(t *testing.T) { + dimensionInitial := sentimentPipeline.OutputsMeta[0].Dimensions + defer func() { + sentimentPipeline.OutputsMeta[0].Dimensions = dimensionInitial + }() + sentimentPipeline.OutputsMeta[0].Dimensions = ort.NewShape(-1, -1, -1) + err = sentimentPipeline.Validate() + assert.Error(t, err) + }) +} + +func TestZeroShotClassificationPipeline(t *testing.T) { + session, err := NewSession() + check(t, err) + defer func(session *Session) { + err := session.Destroy() + check(t, err) + }(session) + + modelPath := "./models/protectai_deberta-v3-base-zeroshot-v1-onnx" + + config := ZeroShotClassificationConfig{ + ModelPath: modelPath, + Name: "testPipeline", + Options: []pipelines.PipelineOption[*pipelines.ZeroShotClassificationPipeline]{ + pipelines.WithHypothesisTemplate("This example is {}."), + pipelines.WithLabels([]string{"fun", "dangerous"}), + }, } - sentimentPipeline.OutputDim = 0 - err = 
sentimentPipeline.Validate() - assert.Error(t, err) - if err != nil { - errInt := err.(interface{ Unwrap() []error }) - assert.Equal(t, 3, len(errInt.Unwrap())) + + classificationPipeline, err := NewPipeline(session, config) + check(t, err) + + tests := []struct { + pipeline *pipelines.ZeroShotClassificationPipeline + name string + sequences []string + labels []string + multilabel bool + expected pipelines.ZeroShotOutput + }{ + { + pipeline: classificationPipeline, + name: "single sequence, single label, no multilabel", + sequences: []string{"I am going to the park"}, + labels: []string{"fun"}, + multilabel: false, + expected: pipelines.ZeroShotOutput{ + ClassificationOutputs: []pipelines.ZeroShotClassificationOutput{ + { + Sequence: "I am going to the park", + SortedValues: []struct { + Key string + Value float64 + }{ + { + Key: "fun", + Value: 0.0009069009101949632, + }, + }, + }, + }, + }, + }, + { + pipeline: classificationPipeline, + name: "multiple sequences, multiple labels, no multilabel", + sequences: []string{"I am going to the park", "I will watch Interstellar tonight"}, + labels: []string{"fun", "movie"}, + multilabel: false, + expected: pipelines.ZeroShotOutput{ + ClassificationOutputs: []pipelines.ZeroShotClassificationOutput{ + { + Sequence: "I am going to the park", + SortedValues: []struct { + Key string + Value float64 + }{ + { + Key: "fun", + Value: 0.7746766209602356, + }, + { + Key: "movie", + Value: 0.2253233790397644, + }, + }, + }, + { + Sequence: "I will watch Interstellar tonight", + SortedValues: []struct { + Key string + Value float64 + }{ + { + Key: "movie", + Value: 0.9984978437423706, + }, + { + Key: "fun", + Value: 0.001502170693129301, + }, + }, + }, + }, + }, + }, + { + pipeline: classificationPipeline, + name: "multiple sequences, multiple labels, multilabel", + sequences: []string{"I am going to the park", "I will watch Interstellar tonight"}, + labels: []string{"fun", "movie"}, + multilabel: true, + expected: 
pipelines.ZeroShotOutput{ + ClassificationOutputs: []pipelines.ZeroShotClassificationOutput{ + { + Sequence: "I am going to the park", + SortedValues: []struct { + Key string + Value float64 + }{ + { + Key: "fun", + Value: 0.0009069009101949632, + }, + { + Key: "movie", + Value: 0.00009480675362283364, + }, + }, + }, + { + Sequence: "I will watch Interstellar tonight", + SortedValues: []struct { + Key string + Value float64 + }{ + { + Key: "movie", + Value: 0.9985591769218445, + }, + { + Key: "fun", + Value: 0.0006653196760453284, + }, + }, + }, + }, + }, + }, + { + pipeline: classificationPipeline, + name: "multiple sequences, single label, multilabel", + sequences: []string{"I am going to the park", "I will watch Interstellar tonight"}, + labels: []string{"fun"}, + multilabel: true, + expected: pipelines.ZeroShotOutput{ + ClassificationOutputs: []pipelines.ZeroShotClassificationOutput{ + { + Sequence: "I am going to the park", + SortedValues: []struct { + Key string + Value float64 + }{ + { + Key: "fun", + Value: 0.0009069009101949632, + }, + }, + }, + { + Sequence: "I will watch Interstellar tonight", + SortedValues: []struct { + Key string + Value float64 + }{ + { + Key: "fun", + Value: 0.0006653196760453284, + }, + }, + }, + }, + }, + }, + { + pipeline: classificationPipeline, + name: "single sequence, multiple labels, multilabel=false", + sequences: []string{"Please don't bother me, I'm in a rush"}, + labels: []string{"busy", "relaxed", "stressed"}, + multilabel: false, + expected: pipelines.ZeroShotOutput{ + ClassificationOutputs: []pipelines.ZeroShotClassificationOutput{ + { + Sequence: "Please don't bother me, I'm in a rush", + SortedValues: []struct { + Key string + Value float64 + }{ + { + Key: "stressed", + Value: 0.8865461349487305, + }, + { + Key: "busy", + Value: 0.10629364103078842, + }, + { + Key: "relaxed", + Value: 0.007160270120948553, + }, + }, + }, + }, + }, + }, } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
classificationPipeline.Labels = tt.labels + classificationPipeline.Multilabel = tt.multilabel + batchResult, _ := tt.pipeline.RunPipeline(tt.sequences) + assert.Equal(t, len(batchResult.GetOutput()), len(tt.expected.ClassificationOutputs)) + + for ind, expected := range tt.expected.ClassificationOutputs { + expectedResult := expected.SortedValues + testResult := batchResult.ClassificationOutputs[ind].SortedValues + assert.Equal(t, len(expectedResult), len(testResult)) + assert.Equal(t, tt.expected.ClassificationOutputs[ind].Sequence, batchResult.ClassificationOutputs[ind].Sequence) + for i := range testResult { + assert.True(t, almostEqual(testResult[i].Value, expectedResult[i].Value)) + } + } + }) + } +} + +func TestZeroShotClassificationPipelineValidation(t *testing.T) { + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) + check(t, err) + defer func(session *Session) { + err := session.Destroy() + check(t, err) + }(session) + modelPath := "./models/protectai_deberta-v3-base-zeroshot-v1-onnx" + + config := TextClassificationConfig{ + ModelPath: modelPath, + Name: "testPipelineSimple", + } + sentimentPipeline, err := NewPipeline(session, config) + check(t, err) + + t.Run("id-label-map", func(t *testing.T) { + labelMapInitial := sentimentPipeline.IDLabelMap + defer func() { + sentimentPipeline.IDLabelMap = labelMapInitial + }() + sentimentPipeline.IDLabelMap = map[int]string{} + err = sentimentPipeline.Validate() + assert.Error(t, err) + }) + + t.Run("output-shape", func(t *testing.T) { + dimensionInitial := sentimentPipeline.OutputsMeta[0].Dimensions + defer func() { + sentimentPipeline.OutputsMeta[0].Dimensions = dimensionInitial + }() + sentimentPipeline.OutputsMeta[0].Dimensions = ort.NewShape(-1, -1, -1) + err = sentimentPipeline.Validate() + assert.Error(t, err) + }) } // Token classification @@ -285,7 +714,7 @@ func TestTokenClassificationPipeline(t *testing.T) { check(t, err) }(session) - modelPath := 
downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-NER" configSimple := TokenClassificationConfig{ ModelPath: modelPath, Name: "testPipelineSimple", @@ -362,7 +791,7 @@ func TestTokenClassificationPipelineValidation(t *testing.T) { check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-NER" configSimple := TokenClassificationConfig{ ModelPath: modelPath, Name: "testPipelineSimple", @@ -374,20 +803,25 @@ func TestTokenClassificationPipelineValidation(t *testing.T) { pipelineSimple, err2 := NewPipeline(session, configSimple) check(t, err2) - pipelineSimple.IdLabelMap = map[int]string{} - err = pipelineSimple.Validate() - assert.Error(t, err) - if err != nil { - errInt := err.(interface{ Unwrap() []error }) - assert.Equal(t, 2, len(errInt.Unwrap())) - } - pipelineSimple.OutputDim = 0 - err = pipelineSimple.Validate() - assert.Error(t, err) - if err != nil { - errInt := err.(interface{ Unwrap() []error }) - assert.Equal(t, 2, len(errInt.Unwrap())) - } + t.Run("id-label-map", func(t *testing.T) { + labelMapInitial := pipelineSimple.IDLabelMap + defer func() { + pipelineSimple.IDLabelMap = labelMapInitial + }() + pipelineSimple.IDLabelMap = map[int]string{} + err = pipelineSimple.Validate() + assert.Error(t, err) + }) + + t.Run("output-shape", func(t *testing.T) { + dimensionInitial := pipelineSimple.OutputsMeta[0].Dimensions + defer func() { + pipelineSimple.OutputsMeta[0].Dimensions = dimensionInitial + }() + pipelineSimple.OutputsMeta[0].Dimensions = ort.NewShape(-1, -1, -1) + err = pipelineSimple.Validate() + assert.Error(t, err) + }) } func TestNoSameNamePipeline(t *testing.T) { @@ -398,7 +832,7 @@ func TestNoSameNamePipeline(t *testing.T) { check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models") + 
modelPath := "./models/KnightsAnalytics_distilbert-NER" configSimple := TokenClassificationConfig{ ModelPath: modelPath, Name: "testPipelineSimple", @@ -415,129 +849,6 @@ func TestNoSameNamePipeline(t *testing.T) { assert.Error(t, err3) } -// feature extraction - -func TestFeatureExtractionPipeline(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) - check(t, err) - defer func(session *Session) { - err := session.Destroy() - check(t, err) - }(session) - - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") - - config := FeatureExtractionConfig{ - ModelPath: modelPath, - Name: "testPipeline", - } - pipeline, err := NewPipeline(session, config) - check(t, err) - - var expectedResults map[string][][]float32 - err = json.Unmarshal(resultsByte, &expectedResults) - check(t, err) - var testResults [][]float32 - - // test 'robert smith' - testResults = expectedResults["test1output"] - for i := 1; i <= 10; i++ { - batchResult, err := pipeline.RunPipeline([]string{"robert smith"}) - check(t, err) - e := floatsEqual(batchResult.Embeddings[0], testResults[0]) - if e != nil { - t.Logf("Test 1: The neural network didn't produce the correct result on loop %d: %s\n", i, e) - t.FailNow() - } - } - - // test ['robert smith junior', 'francis ford coppola'] - testResults = expectedResults["test2output"] - for i := 1; i <= 10; i++ { - batchResult, err := pipeline.RunPipeline([]string{"robert smith junior", "francis ford coppola"}) - check(t, err) - for j, res := range batchResult.Embeddings { - e := floatsEqual(res, testResults[j]) - if e != nil { - t.Logf("Test 2: The neural network didn't produce the correct result on loop %d: %s\n", i, e) - t.FailNow() - } - } - } - - // determinism test to make sure embeddings of a string are not influenced by other strings in the batch - testPairs := map[string][][]string{} - testPairs["identity"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "yo"}} - 
testPairs["contextOverlap"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "yo mama yo"}} - testPairs["contextDisjoint"] = [][]string{{"sinopharm", "yo"}, {"sinopharm", "another test"}} - - for k, sentencePair := range testPairs { - // these vectors should be the same - firstBatchResult, err2 := pipeline.RunPipeline(sentencePair[0]) - check(t, err2) - firstEmbedding := firstBatchResult.Embeddings[0] - - secondBatchResult, err3 := pipeline.RunPipeline(sentencePair[1]) - check(t, err3) - secondEmbedding := secondBatchResult.Embeddings[0] - e := floatsEqual(firstEmbedding, secondEmbedding) - if e != nil { - t.Logf("Equality failed for determinism test %s test with pairs %s and %s", k, strings.Join(sentencePair[0], ","), strings.Join(sentencePair[1], ",")) - t.Log("First vector", firstEmbedding) - t.Log("second vector", secondEmbedding) - t.Fail() - } - } - - zero := uint64(0) - assert.Greater(t, pipeline.PipelineTimings.NumCalls, zero, "PipelineTimings.NumCalls should be greater than 0") - assert.Greater(t, pipeline.PipelineTimings.TotalNS, zero, "PipelineTimings.TotalNS should be greater than 0") - assert.Greater(t, pipeline.TokenizerTimings.NumCalls, zero, "TokenizerTimings.NumCalls should be greater than 0") - assert.Greater(t, pipeline.TokenizerTimings.TotalNS, zero, "TokenizerTimings.TotalNS should be greater than 0") - - // test normalization - testResults = expectedResults["normalizedOutput"] - config = FeatureExtractionConfig{ - ModelPath: modelPath, - Name: "testPipelineNormalise", - Options: []FeatureExtractionOption{ - pipelines.WithNormalization(), - }, - } - pipeline, err = NewPipeline(session, config) - check(t, err) - normalizationStrings := []string{"Onnxruntime is a great inference backend"} - normalizedEmbedding, err := pipeline.RunPipeline(normalizationStrings) - check(t, err) - for i, embedding := range normalizedEmbedding.Embeddings { - e := floatsEqual(embedding, testResults[i]) - if e != nil { - t.Fatalf("Normalization test failed: %s", 
normalizationStrings[i]) - } - } -} - -func TestFeatureExtractionPipelineValidation(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) - check(t, err) - defer func(session *Session) { - err := session.Destroy() - check(t, err) - }(session) - - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") - config := FeatureExtractionConfig{ - ModelPath: modelPath, - Name: "testPipeline", - } - pipeline, err := NewPipeline(session, config) - check(t, err) - - pipeline.OutputDim = 0 - err = pipeline.Validate() - assert.Error(t, err) -} - // README: test the readme examples func TestReadmeExample(t *testing.T) { @@ -616,7 +927,7 @@ func TestCuda(t *testing.T) { } }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") + modelPath := "./models/KnightsAnalytics_all-MiniLM-L6-v2" config := FeatureExtractionConfig{ ModelPath: modelPath, Name: "benchmarkEmbedding", @@ -655,7 +966,7 @@ func runBenchmarkEmbedding(strings *[]string, cuda bool) { } }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") + modelPath := "./models/KnightsAnalytics_all-MiniLM-L6-v2" config := FeatureExtractionConfig{ ModelPath: modelPath, Name: "benchmarkEmbedding", @@ -697,7 +1008,17 @@ func BenchmarkCPUEmbedding(b *testing.B) { } } -// utilities +// Utilities + +func checkClassificationOutput(t *testing.T, inputResult []pipelines.ClassificationOutput, inputExpected []pipelines.ClassificationOutput) { + t.Helper() + assert.Equal(t, len(inputResult), len(inputExpected)) + for i, output := range inputResult { + resultExpected := inputExpected[i] + assert.Equal(t, output.Label, resultExpected.Label) + assert.True(t, almostEqual(float64(output.Score), float64(resultExpected.Score))) + } +} // Returns an error if any element between a and b don't match. 
func floatsEqual(a, b []float32) error { @@ -718,16 +1039,6 @@ func floatsEqual(a, b []float32) error { return nil } -func checkClassificationOutput(t *testing.T, inputResult []pipelines.ClassificationOutput, inputExpected []pipelines.ClassificationOutput) { - t.Helper() - assert.Equal(t, len(inputResult), len(inputExpected)) - for i, output := range inputResult { - resultExpected := inputExpected[i] - assert.Equal(t, output.Label, resultExpected.Label) - assert.True(t, almostEqual(float64(output.Score), float64(resultExpected.Score))) - } -} - func almostEqual(a, b float64) bool { return math.Abs(a-b) <= 0.0001 } diff --git a/pipelines/featureExtraction.go b/pipelines/featureExtraction.go index 5aa1a8a..bd9a868 100644 --- a/pipelines/featureExtraction.go +++ b/pipelines/featureExtraction.go @@ -2,25 +2,25 @@ package pipelines import ( "errors" + "fmt" + "math" + "strings" + "sync/atomic" + "time" ort "github.com/yalue/onnxruntime_go" + "github.com/daulet/tokenizers" util "github.com/knights-analytics/hugot/utils" - "github.com/knights-analytics/tokenizers" ) // FeatureExtractionPipeline A feature extraction pipeline is a go version of // https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/feature_extraction.py - -// types - type FeatureExtractionPipeline struct { - BasePipeline + basePipeline Normalization bool -} - -type FeatureExtractionPipelineConfig struct { - IdLabelMap map[int]string `json:"id2label"` + OutputName string + Output ort.InputOutputInfo } type FeatureExtractionOutput struct { @@ -35,15 +35,24 @@ func (t *FeatureExtractionOutput) GetOutput() []any { return out } -// options +// PIPELINE OPTIONS +// WithNormalization applies normalization to the mean pooled output of the feature pipeline. 
func WithNormalization() PipelineOption[*FeatureExtractionPipeline] { return func(pipeline *FeatureExtractionPipeline) { pipeline.Normalization = true } } -// NewFeatureExtractionPipeline Initialize a feature extraction pipeline +// WithOutputName if there are multiple outputs from the underlying model, which output should +// be returned. If not passed, the first output from the feature pipeline is returned. +func WithOutputName(outputName string) PipelineOption[*FeatureExtractionPipeline] { + return func(pipeline *FeatureExtractionPipeline) { + pipeline.OutputName = outputName + } +} + +// NewFeatureExtractionPipeline init a feature extraction pipeline. func NewFeatureExtractionPipeline(config PipelineConfig[*FeatureExtractionPipeline], ortOptions *ort.SessionOptions) (*FeatureExtractionPipeline, error) { pipeline := &FeatureExtractionPipeline{} pipeline.ModelPath = config.ModelPath @@ -55,79 +64,207 @@ func NewFeatureExtractionPipeline(config PipelineConfig[*FeatureExtractionPipeli o(pipeline) } - // tokenizer + // tokenizer init pipeline.TokenizerOptions = []tokenizers.EncodeOption{tokenizers.WithReturnTypeIDs(), tokenizers.WithReturnAttentionMask()} + tk, err := loadTokenizer(pipeline.ModelPath) + if err != nil { + return nil, err + } + pipeline.Tokenizer = tk - pipeline.PipelineTimings = &Timings{} - pipeline.TokenizerTimings = &Timings{} + // onnx model init + model, err := loadOnnxModelBytes(pipeline.ModelPath, pipeline.OnnxFilename) + if err != nil { + return nil, err + } - // load onnx model - err := pipeline.loadModel() + // init of inputs and outputs + inputs, outputs, err := loadInputOutputMeta(model) if err != nil { return nil, err } + pipeline.InputsMeta = inputs + pipeline.OutputsMeta = outputs - // the dimension of the output is taken from the output meta. 
For the moment we assume that there is only one output - pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[2]) + // filter outputs + if pipeline.OutputName != "" { + for _, output := range outputs { + if output.Name == pipeline.OutputName { + pipeline.Output = output + break + } + } + if pipeline.Output.Name == "" { + return nil, fmt.Errorf("output %s is not available, outputs are: %s", pipeline.OutputName, strings.Join(getNames(outputs), ", ")) + } + } else { + pipeline.Output = outputs[0] // we take the first output otherwise, like transformers does + } - err = pipeline.Validate() + // creation of the session. Only one output (either token or sentence embedding). + session, err := createSession(model, inputs, []ort.InputOutputInfo{pipeline.Output}, ortOptions) if err != nil { return nil, err } + pipeline.OrtSession = session + + // initialize timings + pipeline.PipelineTimings = &timings{} + pipeline.TokenizerTimings = &timings{} + + // validate pipeline + err = pipeline.Validate() + if err != nil { + errDestroy := pipeline.Destroy() + return nil, errors.Join(err, errDestroy) + } return pipeline, nil } +// INTERFACE IMPLEMENTATION + +// GetMetadata returns metadata information about the pipeline, in particular: +// OutputInfo: names and dimensions of the output layer. +func (p *FeatureExtractionPipeline) GetMetadata() PipelineMetadata { + return PipelineMetadata{ + OutputsInfo: []OutputInfo{ + { + Name: p.OutputName, + Dimensions: p.Output.Dimensions, + }, + }, + } +} + +// Destroy frees the feature extraction pipeline resources. +func (p *FeatureExtractionPipeline) Destroy() error { + return destroySession(p.Tokenizer, p.OrtSession) +} + +// GetStats returns the runtime statistics for the pipeline. 
+func (p *FeatureExtractionPipeline) GetStats() []string { + return []string{ + fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName), + fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.TokenizerTimings.TotalNS), + p.TokenizerTimings.NumCalls, + time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))), + fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.PipelineTimings.TotalNS), + p.PipelineTimings.NumCalls, + time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))), + } +} + +// Validate checks that the pipeline is valid. func (p *FeatureExtractionPipeline) Validate() error { var validationErrors []error - if p.OutputDim <= 0 { - validationErrors = append(validationErrors, errors.New("pipeline configuration invalid: outputDim parameter must be greater than zero")) + for _, input := range p.InputsMeta { + dims := []int64(input.Dimensions) + if len(dims) > 3 { + validationErrors = append(validationErrors, fmt.Errorf("inputs and outputs currently can have at most 3 dimensions")) + } + nDynamicDimensions := 0 + for _, d := range dims { + if d == -1 { + nDynamicDimensions++ + } + } + if nDynamicDimensions > 2 { + validationErrors = append(validationErrors, fmt.Errorf(`input %s has dimensions: %s. + There can only be max 2 dynamic dimensions (batch size and sequence length)`, + input.Name, input.Dimensions.String())) + } } return errors.Join(validationErrors...) } -// Postprocess Parse the results of the forward pass into the output. Token embeddings are mean pooled. 
-func (p *FeatureExtractionPipeline) Postprocess(batch PipelineBatch) (*FeatureExtractionOutput, error) { - maxSequence := batch.MaxSequence - vectorCounter := 0 - tokenCounter := 0 - inputCounter := 0 - outputs := make([][]float32, len(batch.Input)) - tokens := make([][]float32, maxSequence) - vectors := make([]float32, p.OutputDim) - - for _, result := range batch.OutputTensor { - vectors[vectorCounter] = result - if vectorCounter == p.OutputDim-1 { - tokens[tokenCounter] = vectors - vectorCounter = 0 - vectors = make([]float32, p.OutputDim) - if tokenCounter == maxSequence-1 { - outputs[inputCounter] = meanPooling(tokens, batch.Input[inputCounter], maxSequence, p.OutputDim) - tokenCounter = 0 - tokens = make([][]float32, maxSequence) - inputCounter++ +// Preprocess tokenizes the input strings. +func (p *FeatureExtractionPipeline) Preprocess(batch *PipelineBatch, inputs []string) error { + start := time.Now() + tokenizeInputs(batch, p.Tokenizer, inputs, p.TokenizerOptions) + atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1) + atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start))) + err := createInputTensors(batch, p.InputsMeta) + return err +} + +// Forward performs the forward inference of the feature extraction pipeline. +func (p *FeatureExtractionPipeline) Forward(batch *PipelineBatch) error { + start := time.Now() + err := runSessionOnBatch(batch, p.OrtSession, []ort.InputOutputInfo{p.Output}) + if err != nil { + return err + } + atomic.AddUint64(&p.PipelineTimings.NumCalls, 1) + atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start))) + return nil +} + +// Postprocess parses the first output from the network similar to the transformers implementation. +func (p *FeatureExtractionPipeline) Postprocess(batch *PipelineBatch) (*FeatureExtractionOutput, error) { + // TODO: this works if token embeddings are returned or sentence embeddings are returned. + // in the former case embeddings are mean pooled. 
In the latter they are just returned. + // to make this more general for other pipelines and to allow return of raw token embeddings, + // we need an ndarray type that can be the return type of this pipeline. Need to think + // about how to do this in a lightweight manner. + + batchEmbeddings := make([][]float32, len(batch.Input)) + outputDimensions := []int64(p.Output.Dimensions) + embeddingDimension := outputDimensions[len(outputDimensions)-1] + maxSequenceLength := batch.MaxSequenceLength + + // now take the output slice and gather the results as a "matrix" + outputEmbedding := make([]float32, embeddingDimension) + outputEmbeddingCounter := 0 + tokenEmbeddings := make([][]float32, maxSequenceLength) + tokenEmbeddingsCounter := 0 + batchInputCounter := 0 + + for _, result := range batch.OutputTensors[0].GetData() { + outputEmbedding[outputEmbeddingCounter] = result + if outputEmbeddingCounter == int(embeddingDimension)-1 { + // we gathered one embedding + if len(outputDimensions) <= 2 { + // it is already a sentence embedding, just add it to batch outputs + batchEmbeddings[batchInputCounter] = outputEmbedding + outputEmbedding = make([]float32, embeddingDimension) + batchInputCounter++ } else { - tokenCounter++ + // output is embedding for a token, add to token embeddings + tokenEmbeddings[tokenEmbeddingsCounter] = outputEmbedding + outputEmbedding = make([]float32, embeddingDimension) + if tokenEmbeddingsCounter == maxSequenceLength-1 { + // computed all embeddings for the tokens, calculate sentence embedding, add to batch outputs, and reset token embeddings and counter + batchEmbeddings[batchInputCounter] = meanPooling(tokenEmbeddings, batch.Input[batchInputCounter], maxSequenceLength, int(embeddingDimension)) + tokenEmbeddings = make([][]float32, maxSequenceLength) + tokenEmbeddingsCounter = 0 + batchInputCounter++ + } else { + // still more tokens to go + tokenEmbeddingsCounter++ + } } + outputEmbeddingCounter = 0 } else { - vectorCounter++ + // still more 
elements of the embedding to go + outputEmbeddingCounter++ } } // Normalize embeddings (if asked), like in https://huggingface.co/sentence-transformers/all-mpnet-base-v2 if p.Normalization { - for i, output := range outputs { - outputs[i] = util.Normalize(output, 2) + for i, output := range batchEmbeddings { + batchEmbeddings[i] = util.Normalize(output, 2) } } - return &FeatureExtractionOutput{Embeddings: outputs}, nil + return &FeatureExtractionOutput{Embeddings: batchEmbeddings}, nil } -func meanPooling(tokens [][]float32, input TokenizedInput, maxSequence int, dimensions int) []float32 { - +func meanPooling(tokens [][]float32, input tokenizedInput, maxSequence int, dimensions int) []float32 { length := len(input.AttentionMask) vector := make([]float32, dimensions) for j := 0; j < maxSequence; j++ { @@ -146,16 +283,30 @@ func meanPooling(tokens [][]float32, input TokenizedInput, maxSequence int, dime return vector } -// Run the pipeline on a string batch +// Run the pipeline on a batch of strings. func (p *FeatureExtractionPipeline) Run(inputs []string) (PipelineBatchOutput, error) { return p.RunPipeline(inputs) } +// RunPipeline is like Run, but returns the concrete feature extraction output type rather than the interface. 
func (p *FeatureExtractionPipeline) RunPipeline(inputs []string) (*FeatureExtractionOutput, error) { - batch := p.Preprocess(inputs) - batch, forwardError := p.Forward(batch) - if forwardError != nil { - return nil, forwardError + var runErrors []error + batch := NewBatch() + defer func(*PipelineBatch) { + runErrors = append(runErrors, batch.Destroy()) + }(batch) + + runErrors = append(runErrors, p.Preprocess(batch, inputs)) + if e := errors.Join(runErrors...); e != nil { + return nil, e } - return p.Postprocess(batch) + + runErrors = append(runErrors, p.Forward(batch)) + if e := errors.Join(runErrors...); e != nil { + return nil, e + } + + result, postErr := p.Postprocess(batch) + runErrors = append(runErrors, postErr) + return result, errors.Join(runErrors...) } diff --git a/pipelines/pipeline.go b/pipelines/pipeline.go index fd3e02b..cc7cd40 100644 --- a/pipelines/pipeline.go +++ b/pipelines/pipeline.go @@ -5,20 +5,17 @@ import ( "errors" "fmt" "io" - "math" "os" "strings" - "sync/atomic" - "time" - "github.com/knights-analytics/tokenizers" + "github.com/daulet/tokenizers" ort "github.com/yalue/onnxruntime_go" util "github.com/knights-analytics/hugot/utils" ) -// BasePipeline is a basic pipeline type used for struct composition in the other pipelines. -type BasePipeline struct { +// BasePipeline can be embedded by a pipeline. 
+type basePipeline struct { ModelPath string OnnxFilename string PipelineName string @@ -28,27 +25,37 @@ type BasePipeline struct { TokenizerOptions []tokenizers.EncodeOption InputsMeta []ort.InputOutputInfo OutputsMeta []ort.InputOutputInfo - hasTokenTypeIds bool - hasAttentionMask bool - OutputDim int - TokenizerTimings *Timings - PipelineTimings *Timings + TokenizerTimings *timings + PipelineTimings *timings +} + +type OutputInfo struct { + Name string + Dimensions []int64 +} + +type PipelineMetadata struct { + OutputsInfo []OutputInfo } type PipelineBatchOutput interface { GetOutput() []any } +// Pipeline is the interface that any pipeline must implement. type Pipeline interface { - Destroy() error - GetStats() []string - GetOutputDim() int - Validate() error - Run([]string) (PipelineBatchOutput, error) + Destroy() error // Destroy the pipeline along with its onnx session + GetStats() []string // Get the pipeline running stats + Validate() error // Validate the pipeline for correctness + GetMetadata() PipelineMetadata // Return metadata information for the pipeline + Run([]string) (PipelineBatchOutput, error) // Run the pipeline on an input } +// PipelineOption is an option for a pipeline type. type PipelineOption[T Pipeline] func(eo T) +// PipelineConfig is a configuration for a pipeline type that can be used +// to create that pipeline. type PipelineConfig[T Pipeline] struct { ModelPath string Name string @@ -56,81 +63,84 @@ type PipelineConfig[T Pipeline] struct { Options []PipelineOption[T] } -type Timings struct { +type timings struct { NumCalls uint64 TotalNS uint64 } -type TokenizedInput struct { +// tokenizedInput holds the result of running tokenizer on an input. 
+type tokenizedInput struct { Raw string Tokens []string - TokenIds []uint32 - TypeIds []uint32 + TokenIDs []uint32 + TypeIDs []uint32 AttentionMask []uint32 SpecialTokensMask []uint32 MaxAttentionIndex int Offsets []tokenizers.Offset } +// PipelineBatch represents a batch of inputs that runs through the pipeline. type PipelineBatch struct { - Input []TokenizedInput - IdsTensor []int64 - TypeIdsTensor []int64 - AttentionMasksTensor []int64 - MaxSequence int - OutputTensor []float32 + Input []tokenizedInput + InputTensors []*ort.Tensor[int64] + MaxSequenceLength int + OutputTensors []*ort.Tensor[float32] } -func (p *BasePipeline) GetOutputDim() int { - return p.OutputDim -} +func (b *PipelineBatch) Destroy() error { + destroyErrors := make([]error, 0, len(b.InputTensors)+len(b.OutputTensors)) -func getOnnxFiles(path string) ([][]string, error) { - var onnxFiles [][]string - walker := func(_ context.Context, _ string, parent string, info os.FileInfo, _ io.Reader) (toContinue bool, err error) { - if strings.HasSuffix(info.Name(), ".onnx") { - onnxFiles = append(onnxFiles, []string{util.PathJoinSafe(path, parent), info.Name()}) - } - return true, nil + for _, tensor := range b.InputTensors { + destroyErrors = append(destroyErrors, tensor.Destroy()) } - err := util.FileSystem.Walk(context.Background(), path, walker) - return onnxFiles, err + + for _, tensor := range b.OutputTensors { + destroyErrors = append(destroyErrors, tensor.Destroy()) + } + return errors.Join(destroyErrors...) +} + +// NewBatch initializes a new batch for inference. +func NewBatch() *PipelineBatch { + return &PipelineBatch{} } -// Load the ort model supporting the pipeline. 
-func (p *BasePipeline) loadModel() error { - tokenizerBytes, err := util.ReadFileBytes(util.PathJoinSafe(p.ModelPath, "tokenizer.json")) +func loadTokenizer(modelPath string) (*tokenizers.Tokenizer, error) { + tokenizerBytes, err := util.ReadFileBytes(util.PathJoinSafe(modelPath, "tokenizer.json")) if err != nil { - return err + return nil, err } tk, err := tokenizers.FromBytes(tokenizerBytes) if err != nil { - return err + return nil, err } + return tk, nil +} - // we look for .onnx files. +func loadOnnxModelBytes(modelPath string, modelFilename string) ([]byte, error) { var modelOnnxFile string - onnxFiles, err := getOnnxFiles(p.ModelPath) + onnxFiles, err := getOnnxFiles(modelPath) if err != nil { - return err + return nil, err } if len(onnxFiles) == 0 { - return fmt.Errorf("no .onnx file detected at %s. There should be exactly .onnx file", p.ModelPath) + return nil, fmt.Errorf("no .onnx file detected at %s. There should be exactly .onnx file", modelPath) } if len(onnxFiles) > 1 { - if p.OnnxFilename == "" { - return fmt.Errorf("multiple .onnx file detected at %s and no OnnxFilename specified", p.ModelPath) + if modelFilename == "" { + return nil, fmt.Errorf("multiple .onnx file detected at %s and no OnnxFilename specified", modelPath) } modelNameFound := false for i := range onnxFiles { - if onnxFiles[i][1] == p.OnnxFilename { + if onnxFiles[i][1] == modelFilename { modelNameFound = true modelOnnxFile = util.PathJoinSafe(onnxFiles[i]...) } } if !modelNameFound { - return fmt.Errorf("file %s not found at %s", p.OnnxFilename, p.ModelPath) + return nil, fmt.Errorf("file %s not found at %s", modelFilename, modelPath) } } else { modelOnnxFile = util.PathJoinSafe(onnxFiles[0]...) 
@@ -138,70 +148,57 @@ func (p *BasePipeline) loadModel() error { onnxBytes, err := util.ReadFileBytes(modelOnnxFile) if err != nil { - return err + return nil, err } + return onnxBytes, err +} +func loadInputOutputMeta(onnxBytes []byte) ([]ort.InputOutputInfo, []ort.InputOutputInfo, error) { inputs, outputs, err := ort.GetInputOutputInfoWithONNXData(onnxBytes) if err != nil { - return err + return nil, nil, err } + return inputs, outputs, nil +} - p.InputsMeta = inputs - p.OutputsMeta = outputs - - inputNames := make([]string, len(inputs)) - for i, meta := range inputs { - inputNames[i] = meta.Name - switch meta.Name { - case "token_type_ids": - p.hasTokenTypeIds = true - case "attention_mask": - p.hasAttentionMask = true - } +func createSession(onnxBytes []byte, inputs, outputs []ort.InputOutputInfo, options *ort.SessionOptions) (*ort.DynamicAdvancedSession, error) { + var inputNames []string + var outputNames []string + for _, v := range inputs { + inputNames = append(inputNames, v.Name) } - outputNames := make([]string, len(outputs)) - for i, meta := range outputs { - outputNames[i] = meta.Name + for _, v := range outputs { + outputNames = append(outputNames, v.Name) } session, err := ort.NewDynamicAdvancedSessionWithONNXData( onnxBytes, inputNames, outputNames, - p.OrtOptions, + options, ) - if err != nil { - return err - } - - p.OrtSession = session - p.Tokenizer = tk - return nil + return session, err } -func (p *BasePipeline) Destroy() error { - var finalErr error - errTokenizer := p.Tokenizer.Close() - if errTokenizer != nil { - finalErr = errTokenizer - } - ortError := p.OrtSession.Destroy() - if ortError != nil { - finalErr = ortError +func getOnnxFiles(path string) ([][]string, error) { + var onnxFiles [][]string + walker := func(_ context.Context, _ string, parent string, info os.FileInfo, _ io.Reader) (toContinue bool, err error) { + if strings.HasSuffix(info.Name(), ".onnx") { + onnxFiles = append(onnxFiles, []string{util.PathJoinSafe(path, parent), 
info.Name()}) + } + return true, nil } - return finalErr + err := util.FileSystem.Walk(context.Background(), path, walker) + return onnxFiles, err } -// Preprocess the input strings in the batch -func (p *BasePipeline) Preprocess(inputs []string) PipelineBatch { - start := time.Now() - - outputs := make([]TokenizedInput, len(inputs)) +func tokenizeInputs(batch *PipelineBatch, tk *tokenizers.Tokenizer, inputs []string, options []tokenizers.EncodeOption) { + outputs := make([]tokenizedInput, len(inputs)) maxSequence := 0 for i, input := range inputs { - output := p.Tokenizer.EncodeWithOptions(input, + output := tk.EncodeWithOptions(input, true, - p.TokenizerOptions..., + options..., ) maxAttentionIndex := 0 @@ -211,11 +208,11 @@ func (p *BasePipeline) Preprocess(inputs []string) PipelineBatch { } } - outputs[i] = TokenizedInput{ + outputs[i] = tokenizedInput{ Raw: input, Tokens: output.Tokens, - TokenIds: output.IDs, - TypeIds: output.TypeIDs, + TokenIDs: output.IDs, + TypeIDs: output.TypeIDs, AttentionMask: output.AttentionMask, MaxAttentionIndex: maxAttentionIndex, SpecialTokensMask: output.SpecialTokensMask, @@ -225,114 +222,121 @@ func (p *BasePipeline) Preprocess(inputs []string) PipelineBatch { maxSequence = maxAttentionIndex } } - - atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1) - atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start))) - batch := p.convertInputToTensors(outputs, maxSequence+1) - return batch + batch.Input = outputs + batch.MaxSequenceLength = maxSequence + 1 } -func (p *BasePipeline) getInputTensors(batch PipelineBatch, actualBatchSize int64, maxSequence int64) ([]ort.ArbitraryTensor, error) { - inputTensors := make([]ort.ArbitraryTensor, len(p.InputsMeta)) - var err error - - for i, input := range p.InputsMeta { - var inputTensor *ort.Tensor[int64] - - // create the tensor for the input name - switch input.Name { - case "input_ids": - inputTensor, err = ort.NewTensor(ort.NewShape(actualBatchSize, maxSequence), 
batch.IdsTensor) - case "token_type_ids": - inputTensor, err = ort.NewTensor(ort.NewShape(actualBatchSize, maxSequence), batch.TypeIdsTensor) - case "attention_mask": - inputTensor, err = ort.NewTensor(ort.NewShape(actualBatchSize, maxSequence), batch.AttentionMasksTensor) +// createInputTensors creates ort input tensors. +func createInputTensors(batch *PipelineBatch, inputsMeta []ort.InputOutputInfo) error { + tensorSize := len(batch.Input) * (batch.MaxSequenceLength) + batchSize := int64(len(batch.Input)) + + inputTensors := make([]*ort.Tensor[int64], len(inputsMeta)) + var tensorCreationErr error + + for i, inputMeta := range inputsMeta { + backingSlice := make([]int64, tensorSize) + counter := 0 + + for _, input := range batch.Input { + length := len(input.TokenIDs) + for j := 0; j < batch.MaxSequenceLength; j++ { + if j+1 <= length { + switch inputMeta.Name { + case "input_ids": + backingSlice[counter] = int64(input.TokenIDs[j]) + case "token_type_ids": + backingSlice[counter] = int64(input.TypeIDs[j]) + case "attention_mask": + backingSlice[counter] = int64(input.AttentionMask[j]) + default: + return fmt.Errorf("input %s not recognized", inputMeta.Name) + } + } else { + backingSlice[counter] = 0 // pad with zero + } + counter++ + } + } + inputTensors[i], tensorCreationErr = ort.NewTensor(ort.NewShape(batchSize, int64(batch.MaxSequenceLength)), backingSlice) + if tensorCreationErr != nil { + return tensorCreationErr } - - inputTensors[i] = inputTensor } - return inputTensors, err + batch.InputTensors = inputTensors + return nil } -// Forward pass of the neural network on the tokenized input -func (p *BasePipeline) Forward(batch PipelineBatch) (PipelineBatch, error) { - start := time.Now() - - actualBatchSize := int64(len(batch.Input)) - maxSequence := int64(batch.MaxSequence) - inputTensors, err := p.getInputTensors(batch, actualBatchSize, maxSequence) - if err != nil { - return batch, err - } - - outputTensor, err4 := 
ort.NewEmptyTensor[float32](ort.NewShape(actualBatchSize, maxSequence, int64(p.OutputDim))) - if err4 != nil { - return batch, err4 - } - - defer func(inputTensors []ort.ArbitraryTensor) { - for _, tensor := range inputTensors { - err = errors.Join(err, tensor.Destroy()) - } - }(inputTensors) - - // Run Onnx model - errOnnx := p.OrtSession.Run(inputTensors, []ort.ArbitraryTensor{outputTensor}) - if errOnnx != nil { - return batch, errOnnx +func getNames(info []ort.InputOutputInfo) []string { + names := make([]string, 0, len(info)) + for _, v := range info { + names = append(names, v.Name) } - batch.OutputTensor = outputTensor.GetData() - defer func(outputTensor *ort.Tensor[float32]) { - err = errors.Join(err, outputTensor.Destroy()) - }(outputTensor) - - atomic.AddUint64(&p.PipelineTimings.NumCalls, 1) - atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start))) - return batch, err + return names } -// convert tokenized input to the format required by the onnxruntime library -func (p *BasePipeline) convertInputToTensors(inputs []TokenizedInput, maxSequence int) PipelineBatch { - tensorSize := len(inputs) * maxSequence - counter := 0 - - idsTensor := make([]int64, tensorSize) - typeIdsTensor := make([]int64, tensorSize) - attentionMasksTensor := make([]int64, tensorSize) - - for _, input := range inputs { - length := len(input.TokenIds) - for j := 0; j < maxSequence; j++ { - if j+1 <= length { - idsTensor[counter] = int64(input.TokenIds[j]) - if p.hasTokenTypeIds { - typeIdsTensor[counter] = int64(input.TypeIds[j]) - } - if p.hasAttentionMask { - attentionMasksTensor[counter] = int64(input.AttentionMask[j]) +func runSessionOnBatch(batch *PipelineBatch, session *ort.DynamicAdvancedSession, outputs []ort.InputOutputInfo) error { + actualBatchSize := int64(len(batch.Input)) + maxSequenceLength := int64(batch.MaxSequenceLength) + + // allocate vectors with right dimensions for the output + outputTensors := make([]*ort.Tensor[float32], len(outputs)) + 
arbitraryOutputTensors := make([]ort.ArbitraryTensor, len(outputs)) + var outputCreationErr error + + for outputIndex, meta := range outputs { + var batchDimSet bool + var tokenDimSet bool + actualDims := make([]int64, 0, len(meta.Dimensions)) + + for _, dim := range meta.Dimensions { + if dim == -1 { + if !batchDimSet { + actualDims = append(actualDims, actualBatchSize) + batchDimSet = true + } else if !tokenDimSet { + actualDims = append(actualDims, maxSequenceLength) + tokenDimSet = true + } else { + return fmt.Errorf("only two axis can be dynamic (batch size and number of tokens)") } } else { - // padding all vectors to max sequence length - idsTensor[counter] = 0 - typeIdsTensor[counter] = 0 - attentionMasksTensor[counter] = 0 + actualDims = append(actualDims, dim) } - counter++ } + outputShape := ort.NewShape(actualDims...) + outputTensors[outputIndex], outputCreationErr = ort.NewEmptyTensor[float32](outputShape) + if outputCreationErr != nil { + return outputCreationErr + } + arbitraryOutputTensors[outputIndex] = ort.ArbitraryTensor(outputTensors[outputIndex]) } - return PipelineBatch{ - Input: inputs, - IdsTensor: idsTensor, - TypeIdsTensor: typeIdsTensor, - AttentionMasksTensor: attentionMasksTensor, - MaxSequence: maxSequence, + + // Run Onnx model + arbitraryInputTensors := make([]ort.ArbitraryTensor, len(batch.InputTensors)) + for i, t := range batch.InputTensors { + arbitraryInputTensors[i] = ort.ArbitraryTensor(t) + } + + errOnnx := session.Run(arbitraryInputTensors, arbitraryOutputTensors) + if errOnnx != nil { + return errOnnx } + + // store resulting tensors + batch.OutputTensors = outputTensors + return nil } -func (p *BasePipeline) GetStats() []string { - return []string{ - fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName), - fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s", time.Duration(p.TokenizerTimings.TotalNS), p.TokenizerTimings.NumCalls, 
time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))), - fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s", time.Duration(p.PipelineTimings.TotalNS), p.PipelineTimings.NumCalls, time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))), +func destroySession(tk *tokenizers.Tokenizer, session *ort.DynamicAdvancedSession) error { + var finalErr error + errTokenizer := tk.Close() + if errTokenizer != nil { + finalErr = errTokenizer } + ortError := session.Destroy() + if ortError != nil { + finalErr = ortError + } + return finalErr } diff --git a/pipelines/textClassification.go b/pipelines/textClassification.go index 02504bd..5435d41 100644 --- a/pipelines/textClassification.go +++ b/pipelines/textClassification.go @@ -3,27 +3,28 @@ package pipelines import ( "errors" "fmt" + "math" "sync/atomic" "time" util "github.com/knights-analytics/hugot/utils" + "github.com/daulet/tokenizers" jsoniter "github.com/json-iterator/go" - "github.com/knights-analytics/tokenizers" ort "github.com/yalue/onnxruntime_go" ) // types type TextClassificationPipeline struct { - BasePipeline - IdLabelMap map[int]string + basePipeline + IDLabelMap map[int]string AggregationFunctionName string ProblemType string } type TextClassificationPipelineConfig struct { - IdLabelMap map[int]string `json:"id2label"` + IDLabelMap map[int]string `json:"id2label"` } type ClassificationOutput struct { @@ -71,7 +72,7 @@ func WithMultiLabel() PipelineOption[*TextClassificationPipeline] { } } -// NewTextClassificationPipeline initializes a new text classification pipeline +// NewTextClassificationPipeline initializes a new text classification pipeline. 
func NewTextClassificationPipeline(config PipelineConfig[*TextClassificationPipeline], ortOptions *ort.SessionOptions) (*TextClassificationPipeline, error) { pipeline := &TextClassificationPipeline{} pipeline.ModelPath = config.ModelPath @@ -94,10 +95,17 @@ func NewTextClassificationPipeline(config PipelineConfig[*TextClassificationPipe } } + // tokenizer init pipeline.TokenizerOptions = []tokenizers.EncodeOption{ tokenizers.WithReturnAttentionMask(), } + tk, err := loadTokenizer(pipeline.ModelPath) + if err != nil { + return nil, err + } + pipeline.Tokenizer = tk + // read id to label map configPath := util.PathJoinSafe(pipeline.ModelPath, "config.json") pipelineInputConfig := TextClassificationPipelineConfig{} mapBytes, err := util.ReadFileBytes(configPath) @@ -109,88 +117,135 @@ func NewTextClassificationPipeline(config PipelineConfig[*TextClassificationPipe return nil, err } - pipeline.IdLabelMap = pipelineInputConfig.IdLabelMap - pipeline.PipelineTimings = &Timings{} - pipeline.TokenizerTimings = &Timings{} + pipeline.IDLabelMap = pipelineInputConfig.IDLabelMap - // load onnx model - loadErr := pipeline.loadModel() - if loadErr != nil { - return nil, loadErr + // onnx model init + model, err := loadOnnxModelBytes(pipeline.ModelPath, pipeline.OnnxFilename) + if err != nil { + return nil, err } - pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[1]) + // init of inputs and outputs + inputs, outputs, err := loadInputOutputMeta(model) + if err != nil { + return nil, err + } + pipeline.InputsMeta = inputs + pipeline.OutputsMeta = outputs - // validate - validationErrors := pipeline.Validate() - if validationErrors != nil { - return nil, validationErrors + // creation of the session + session, err := createSession(model, inputs, pipeline.OutputsMeta, ortOptions) + if err != nil { + return nil, err } + pipeline.OrtSession = session + // initialize timings + pipeline.PipelineTimings = &timings{} + pipeline.TokenizerTimings = &timings{} + + // validate + err 
= pipeline.Validate() + if err != nil { + errDestroy := pipeline.Destroy() + return nil, errors.Join(err, errDestroy) + } return pipeline, nil } +// INTERFACE IMPLEMENTATION + +// GetMetadata returns metadata information about the pipeline, in particular: +// OutputInfo: names and dimensions of the output layer used for text classification. +func (p *TextClassificationPipeline) GetMetadata() PipelineMetadata { + return PipelineMetadata{ + OutputsInfo: []OutputInfo{ + { + Name: p.OutputsMeta[0].Name, + Dimensions: p.OutputsMeta[0].Dimensions, + }, + }, + } +} + +// Destroy frees the text classification pipeline resources. +func (p *TextClassificationPipeline) Destroy() error { + return destroySession(p.Tokenizer, p.OrtSession) +} + +// GetStats returns the runtime statistics for the pipeline. +func (p *TextClassificationPipeline) GetStats() []string { + return []string{ + fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName), + fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.TokenizerTimings.TotalNS), + p.TokenizerTimings.NumCalls, + time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))), + fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.PipelineTimings.TotalNS), + p.PipelineTimings.NumCalls, + time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))), + } +} + +// Validate checks that the pipeline is valid. 
func (p *TextClassificationPipeline) Validate() error { var validationErrors []error - if len(p.IdLabelMap) < 1 { - validationErrors = append(validationErrors, fmt.Errorf("only single label classification models are currently supported and more than one label is required")) + if len(p.IDLabelMap) <= 0 { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for token classification pipeline must be greater than zero")) } - if p.OutputDim <= 0 { - validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: outputDim parameter must be greater than zero")) + + outDims := p.OutputsMeta[0].Dimensions + if len(outDims) != 2 { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: text classification must have 2 dimensional output")) } - if len(p.IdLabelMap) <= 0 { - validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for token classification pipeline must be greater than zero")) + dynamicBatch := false + for _, d := range outDims { + if d == -1 { + if dynamicBatch { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: text classification must have max one dynamic dimensions (input)")) + break + } + dynamicBatch = true + } } - if len(p.IdLabelMap) != p.OutputDim { - validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map does not match model output dimension")) + nLogits := int(outDims[len(outDims)-1]) + if len(p.IDLabelMap) != nLogits { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map does not match number of logits in output (%d)", nLogits)) } return errors.Join(validationErrors...) } -func (p *TextClassificationPipeline) Forward(batch PipelineBatch) (PipelineBatch, error) { +// Preprocess tokenizes the input strings. 
+func (p *TextClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error { start := time.Now() + tokenizeInputs(batch, p.Tokenizer, inputs, p.TokenizerOptions) + atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1) + atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start))) + err := createInputTensors(batch, p.InputsMeta) + return err +} - actualBatchSize := int64(len(batch.Input)) - maxSequence := int64(batch.MaxSequence) - inputTensors, err := p.getInputTensors(batch, actualBatchSize, maxSequence) +func (p *TextClassificationPipeline) Forward(batch *PipelineBatch) error { + start := time.Now() + err := runSessionOnBatch(batch, p.OrtSession, p.OutputsMeta) if err != nil { - return batch, err - } - - defer func(inputTensors []ort.ArbitraryTensor) { - for _, tensor := range inputTensors { - err = errors.Join(err, tensor.Destroy()) - } - }(inputTensors) - - outputTensor, errTensor := ort.NewEmptyTensor[float32](ort.NewShape(actualBatchSize, int64(p.OutputDim))) - if errTensor != nil { - return batch, errTensor - } - - defer func(outputTensor *ort.Tensor[float32]) { - err = errors.Join(err, outputTensor.Destroy()) - }(outputTensor) - - // Run Onnx model - errOnnx := p.OrtSession.Run(inputTensors, []ort.ArbitraryTensor{outputTensor}) - if errOnnx != nil { - return batch, errOnnx + return err } - batch.OutputTensor = outputTensor.GetData() - atomic.AddUint64(&p.PipelineTimings.NumCalls, 1) atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start))) - return batch, err + return nil } -func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClassificationOutput, error) { - outputTensor := batch.OutputTensor +func (p *TextClassificationPipeline) Postprocess(batch *PipelineBatch) (*TextClassificationOutput, error) { + outputTensor := batch.OutputTensors[0] + outputDims := p.OutputsMeta[0].Dimensions + nLogit := outputDims[len(outputDims)-1] output := make([][]float32, len(batch.Input)) inputCounter := 0 
vectorCounter := 0 - inputVector := make([]float32, p.OutputDim) + inputVector := make([]float32, nLogit) var aggregationFunction func([]float32) []float32 switch p.AggregationFunctionName { case "SIGMOID": @@ -201,13 +256,12 @@ func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClas return nil, fmt.Errorf("aggregation function %s is not supported", p.AggregationFunctionName) } - for _, result := range outputTensor { + for _, result := range outputTensor.GetData() { inputVector[vectorCounter] = result - if vectorCounter == p.OutputDim-1 { - + if vectorCounter == int(nLogit)-1 { output[inputCounter] = aggregationFunction(inputVector) vectorCounter = 0 - inputVector = make([]float32, p.OutputDim) + inputVector = make([]float32, nLogit) inputCounter++ } else { vectorCounter++ @@ -229,7 +283,7 @@ func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClas err = errArgMax continue } - class, ok := p.IdLabelMap[index] + class, ok := p.IDLabelMap[index] if !ok { err = fmt.Errorf("class with index number %d not found in id label map", index) } @@ -239,9 +293,9 @@ func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClas } batchClassificationOutputs.ClassificationOutputs[i] = inputClassificationOutputs case "multiLabel": - inputClassificationOutputs := make([]ClassificationOutput, len(p.IdLabelMap)) + inputClassificationOutputs := make([]ClassificationOutput, len(p.IDLabelMap)) for j := range output[i] { - class, ok := p.IdLabelMap[j] + class, ok := p.IDLabelMap[j] if !ok { err = fmt.Errorf("class with index number %d not found in id label map", j) } @@ -258,16 +312,29 @@ func (p *TextClassificationPipeline) Postprocess(batch PipelineBatch) (*TextClas return &batchClassificationOutputs, err } -// Run the pipeline on a string batch +// Run the pipeline on a string batch. 
func (p *TextClassificationPipeline) Run(inputs []string) (PipelineBatchOutput, error) { return p.RunPipeline(inputs) } func (p *TextClassificationPipeline) RunPipeline(inputs []string) (*TextClassificationOutput, error) { - batch := p.Preprocess(inputs) - batch, err := p.Forward(batch) - if err != nil { - return nil, err + var runErrors []error + batch := NewBatch() + defer func(*PipelineBatch) { + runErrors = append(runErrors, batch.Destroy()) + }(batch) + + runErrors = append(runErrors, p.Preprocess(batch, inputs)) + if e := errors.Join(runErrors...); e != nil { + return nil, e } - return p.Postprocess(batch) + + runErrors = append(runErrors, p.Forward(batch)) + if e := errors.Join(runErrors...); e != nil { + return nil, e + } + + result, postErr := p.Postprocess(batch) + runErrors = append(runErrors, postErr) + return result, errors.Join(runErrors...) } diff --git a/pipelines/tokenClassification.go b/pipelines/tokenClassification.go index 6372435..7128e9d 100644 --- a/pipelines/tokenClassification.go +++ b/pipelines/tokenClassification.go @@ -3,30 +3,31 @@ package pipelines import ( "errors" "fmt" + "math" + "slices" "strings" - - // according to https://freshman.tech/snippets/go/check-if-slice-contains-element - "golang.org/x/exp/slices" + "sync/atomic" + "time" ort "github.com/yalue/onnxruntime_go" util "github.com/knights-analytics/hugot/utils" + "github.com/daulet/tokenizers" jsoniter "github.com/json-iterator/go" - "github.com/knights-analytics/tokenizers" ) -// types - +// TokenClassificationPipeline is a go version of huggingface tokenClassificationPipeline. 
+// https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/token_classification.py
 type TokenClassificationPipeline struct {
-	BasePipeline
-	IdLabelMap          map[int]string
+	basePipeline
+	IDLabelMap          map[int]string
 	AggregationStrategy string
 	IgnoreLabels        []string
 }
 
 type TokenClassificationPipelineConfig struct {
-	IdLabelMap map[int]string `json:"id2label"`
+	IDLabelMap map[int]string `json:"id2label"`
 }
 
 type Entity struct {
@@ -35,7 +36,7 @@ type Entity struct {
 	Scores    []float32
 	Index     int
 	Word      string
-	TokenId   uint32
+	TokenID   uint32
 	Start     uint
 	End       uint
 	IsSubword bool
@@ -55,12 +56,17 @@ func (t *TokenClassificationOutput) GetOutput() []any {
 
 // options
 
+// TODO: need to implement the other types of aggregation (max etc)
+
+// WithSimpleAggregation sets the aggregation strategy for the token labels to simple
+// It reproduces simple aggregation from the huggingface implementation.
 func WithSimpleAggregation() PipelineOption[*TokenClassificationPipeline] {
 	return func(pipeline *TokenClassificationPipeline) {
 		pipeline.AggregationStrategy = "SIMPLE"
 	}
 }
 
+// WithoutAggregation returns the token labels.
 func WithoutAggregation() PipelineOption[*TokenClassificationPipeline] {
 	return func(pipeline *TokenClassificationPipeline) {
 		pipeline.AggregationStrategy = "NONE"
@@ -73,7 +79,7 @@ func WithIgnoreLabels(ignoreLabels []string) PipelineOption[*TokenClassification
 	}
 }
 
-// NewTokenClassificationPipeline Initializes a feature extraction pipeline
+// NewTokenClassificationPipeline initializes a token classification pipeline.
func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPipeline], ortOptions *ort.SessionOptions) (*TokenClassificationPipeline, error) { pipeline := &TokenClassificationPipeline{} pipeline.ModelPath = config.ModelPath @@ -84,7 +90,7 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi o(pipeline) } - // inputs and encoding options + // tokenizer init pipeline.TokenizerOptions = []tokenizers.EncodeOption{ tokenizers.WithReturnTokens(), tokenizers.WithReturnTypeIDs(), @@ -92,8 +98,27 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi tokenizers.WithReturnSpecialTokensMask(), tokenizers.WithReturnOffsets(), } + tk, err := loadTokenizer(pipeline.ModelPath) + if err != nil { + return nil, err + } + pipeline.Tokenizer = tk + + // onnx model init + model, err := loadOnnxModelBytes(pipeline.ModelPath, pipeline.OnnxFilename) + if err != nil { + return nil, err + } + + // init of inputs and outputs + inputs, outputs, err := loadInputOutputMeta(model) + if err != nil { + return nil, err + } + pipeline.InputsMeta = inputs + pipeline.OutputsMeta = outputs - // load json model config and set pipeline settings + // Id label map configPath := util.PathJoinSafe(config.ModelPath, "config.json") pipelineInputConfig := TokenClassificationPipelineConfig{} mapBytes, err := util.ReadFileBytes(configPath) @@ -105,13 +130,9 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi if err != nil { return nil, err } - pipeline.IdLabelMap = pipelineInputConfig.IdLabelMap - - pipeline.PipelineTimings = &Timings{} - pipeline.TokenizerTimings = &Timings{} - - // defaults + pipeline.IDLabelMap = pipelineInputConfig.IDLabelMap + // default strategies if not set if pipeline.AggregationStrategy == "" { pipeline.AggregationStrategy = "SIMPLE" } @@ -119,14 +140,15 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi pipeline.IgnoreLabels = []string{"O"} } - 
// load onnx model
-	errModel := pipeline.loadModel()
-	if errModel != nil {
-		return nil, errModel
-	}
+	pipeline.PipelineTimings = &timings{}
+	pipeline.TokenizerTimings = &timings{}
 
-	// the dimension of the output is taken from the output meta.
-	pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[2])
+	// creation of the session.
+	session, err := createSession(model, inputs, outputs, ortOptions)
+	if err != nil {
+		return nil, err
+	}
+	pipeline.OrtSession = session
 
 	err = pipeline.Validate()
 	if err != nil {
@@ -135,54 +157,121 @@ func NewTokenClassificationPipeline(config PipelineConfig[*TokenClassificationPi
 	return pipeline, nil
 }
 
+// INTERFACE IMPLEMENTATION
+
+// GetMetadata returns metadata information about the pipeline, in particular:
+// OutputInfo: names and dimensions of the output layer used for token classification.
+func (p *TokenClassificationPipeline) GetMetadata() PipelineMetadata {
+	return PipelineMetadata{
+		OutputsInfo: []OutputInfo{
+			{
+				Name:       p.OutputsMeta[0].Name,
+				Dimensions: p.OutputsMeta[0].Dimensions,
+			},
+		},
+	}
+}
+
+// Destroy frees the token classification pipeline resources.
+func (p *TokenClassificationPipeline) Destroy() error {
+	return destroySession(p.Tokenizer, p.OrtSession)
+}
+
+// GetStats returns the runtime statistics for the pipeline.
+func (p *TokenClassificationPipeline) GetStats() []string { + return []string{ + fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName), + fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.TokenizerTimings.TotalNS), + p.TokenizerTimings.NumCalls, + time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))), + fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.PipelineTimings.TotalNS), + p.PipelineTimings.NumCalls, + time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))), + } +} + +// Validate checks that the pipeline is valid. func (p *TokenClassificationPipeline) Validate() error { var validationErrors []error - if p.OutputDim <= 0 { - validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: outputDim parameter must be greater than zero")) + outputDim := p.OutputsMeta[0].Dimensions + if len(outputDim) != 3 { + validationErrors = append(validationErrors, + fmt.Errorf("output for token classification must be three dimensional (input, sequence, logits)")) } - if len(p.IdLabelMap) <= 0 { - validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: length of id2label map for token classification p must be greater than zero")) + + if outputDim[len(outputDim)-1] == -1 { + validationErrors = append(validationErrors, + fmt.Errorf("logit dimension cannot be dynamic")) } - if len(p.IdLabelMap) != p.OutputDim { - validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: length of id2label map does not match model output dimension")) + if len(p.IDLabelMap) <= 0 { + validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: length of id2label map for token classification p must be greater than zero")) } return errors.Join(validationErrors...) 
} -// Postprocess function for a token classification pipeline -func (p *TokenClassificationPipeline) Postprocess(batch PipelineBatch) (*TokenClassificationOutput, error) { +// Preprocess tokenizes the input strings. +func (p *TokenClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error { + start := time.Now() + tokenizeInputs(batch, p.Tokenizer, inputs, p.TokenizerOptions) + atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1) + atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start))) + err := createInputTensors(batch, p.InputsMeta) + return err +} - outputs := make([][][]float32, len(batch.Input)) // holds the final output - inputVectors := make([][]float32, 0, batch.MaxSequence) // holds the embeddings of each original token (no padding) for an input - tokenVector := make([]float32, p.OutputDim) // holds the vector embedding for a token - inputTokens := batch.Input[0].TokenIds +// Forward performs the forward inference of the pipeline. +func (p *TokenClassificationPipeline) Forward(batch *PipelineBatch) error { + start := time.Now() + err := runSessionOnBatch(batch, p.OrtSession, p.OutputsMeta) + if err != nil { + return err + } + atomic.AddUint64(&p.PipelineTimings.NumCalls, 1) + atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start))) + return nil +} + +// Postprocess function for a token classification pipeline. 
+func (p *TokenClassificationPipeline) Postprocess(batch *PipelineBatch) (*TokenClassificationOutput, error) { + if len(batch.Input) == 0 { + return &TokenClassificationOutput{}, nil + } + + outputDims := p.OutputsMeta[0].Dimensions + tokenLogitsDim := int(outputDims[len(outputDims)-1]) + outputs := make([][][]float32, len(batch.Input)) // holds the final output + inputVectors := make([][]float32, 0, batch.MaxSequenceLength) // holds the embeddings of each original token (no padding) for an input + tokenVector := make([]float32, tokenLogitsDim) // holds the vector embedding for a token + inputTokens := batch.Input[0].TokenIDs // original tokens from the input excluding the padded tokens tokenVectorCounter := 0 tokenCounter := 0 inputCounter := 0 nInputs := len(batch.Input) - // construct the output vectors, however discard the embeddings of the padding tokens so that the output vector length + // construct the output vectors by gathering the logits, + // however discard the embeddings of the padding tokens so that the output vector length // for an input is equal to the number of original tokens - - for _, result := range batch.OutputTensor { + for _, result := range batch.OutputTensors[0].GetData() { tokenVector[tokenVectorCounter] = result - if tokenVectorCounter == p.OutputDim-1 { + if tokenVectorCounter == tokenLogitsDim-1 { // raw result vector for token is now complete if tokenCounter < len(inputTokens) { // it is an original token (not resulting from padding), keep it inputVectors = append(inputVectors, util.SoftMax(tokenVector)) } tokenVectorCounter = 0 - tokenVector = make([]float32, p.OutputDim) - if tokenCounter == batch.MaxSequence-1 { + tokenVector = make([]float32, tokenLogitsDim) + if tokenCounter == batch.MaxSequenceLength-1 { // we went through all tokens in the sequence for this input outputs[inputCounter] = inputVectors tokenCounter = 0 - inputVectors = make([][]float32, 0, batch.MaxSequence) + inputVectors = make([][]float32, 0, 
batch.MaxSequenceLength) inputCounter++ if inputCounter < nInputs { - inputTokens = batch.Input[inputCounter].TokenIds + inputTokens = batch.Input[inputCounter].TokenIDs } } else { tokenCounter++ @@ -216,8 +305,7 @@ func (p *TokenClassificationPipeline) Postprocess(batch PipelineBatch) (*TokenCl } // GatherPreEntities from batch of logits to list of pre-aggregated outputs -func (p *TokenClassificationPipeline) GatherPreEntities(input TokenizedInput, output [][]float32) []Entity { - +func (p *TokenClassificationPipeline) GatherPreEntities(input tokenizedInput, output [][]float32) []Entity { sentence := input.Raw var preEntities []Entity @@ -229,7 +317,7 @@ func (p *TokenClassificationPipeline) GatherPreEntities(input TokenizedInput, ou } // TODO: the python code uses id_to_token to get the token here which is a method on the rust tokenizer, check if it's better word := input.Tokens[j] - tokenId := input.TokenIds[j] + tokenID := input.TokenIDs[j] // TODO: the determination of subword can probably be better done by exporting the words field from the tokenizer directly startInd := input.Offsets[j][0] endInd := input.Offsets[j][1] @@ -239,7 +327,7 @@ func (p *TokenClassificationPipeline) GatherPreEntities(input TokenizedInput, ou // in that case set the subword as in the python code preEntities = append(preEntities, Entity{ Word: word, - TokenId: tokenId, + TokenID: tokenID, Scores: tokenScores, Start: startInd, End: endInd, @@ -250,7 +338,7 @@ func (p *TokenClassificationPipeline) GatherPreEntities(input TokenizedInput, ou return preEntities } -func (p *TokenClassificationPipeline) Aggregate(input TokenizedInput, preEntities []Entity) ([]Entity, error) { +func (p *TokenClassificationPipeline) Aggregate(input tokenizedInput, preEntities []Entity) ([]Entity, error) { entities := make([]Entity, len(preEntities)) if p.AggregationStrategy == "SIMPLE" || p.AggregationStrategy == "NONE" { for i, preEntity := range preEntities { @@ -258,7 +346,7 @@ func (p 
*TokenClassificationPipeline) Aggregate(input TokenizedInput, preEntitie if argMaxErr != nil { return nil, argMaxErr } - label, ok := p.IdLabelMap[entityIdx] + label, ok := p.IDLabelMap[entityIdx] if !ok { return nil, fmt.Errorf("could not determine entity type for input %s, predicted entity index %d", input.Raw, entityIdx) } @@ -267,7 +355,7 @@ func (p *TokenClassificationPipeline) Aggregate(input TokenizedInput, preEntitie Score: score, Index: preEntity.Index, Word: preEntity.Word, - TokenId: preEntity.TokenId, + TokenID: preEntity.TokenID, Start: preEntity.Start, End: preEntity.End, } @@ -310,7 +398,7 @@ func (p *TokenClassificationPipeline) groupSubEntities(entities []Entity) Entity tokens := make([]uint32, len(entities)) for i, s := range entities { scores[i] = s.Score - tokens[i] = s.TokenId + tokens[i] = s.TokenID } score := util.Mean(scores) // note: here we directly appeal to the tokenizer decoder with the tokenIds @@ -326,7 +414,7 @@ func (p *TokenClassificationPipeline) groupSubEntities(entities []Entity) Entity } } -// GroupEntities group together adjacent tokens with the same entity predicted +// GroupEntities group together adjacent tokens with the same entity predicted. func (p *TokenClassificationPipeline) GroupEntities(entities []Entity) ([]Entity, error) { var entityGroups []Entity var currentGroupDisagg []Entity @@ -355,16 +443,30 @@ func (p *TokenClassificationPipeline) GroupEntities(entities []Entity) ([]Entity return entityGroups, nil } -// Run the pipeline on a string batch +// Run the pipeline on a string batch. func (p *TokenClassificationPipeline) Run(inputs []string) (PipelineBatchOutput, error) { return p.RunPipeline(inputs) } +// RunPipeline is like Run but returns the concrete type rather than the interface. 
func (p *TokenClassificationPipeline) RunPipeline(inputs []string) (*TokenClassificationOutput, error) { - batch := p.Preprocess(inputs) - batch, errForward := p.Forward(batch) - if errForward != nil { - return nil, errForward + var runErrors []error + batch := NewBatch() + defer func(*PipelineBatch) { + runErrors = append(runErrors, batch.Destroy()) + }(batch) + + runErrors = append(runErrors, p.Preprocess(batch, inputs)) + if e := errors.Join(runErrors...); e != nil { + return nil, e } - return p.Postprocess(batch) + + runErrors = append(runErrors, p.Forward(batch)) + if e := errors.Join(runErrors...); e != nil { + return nil, e + } + + result, postErr := p.Postprocess(batch) + runErrors = append(runErrors, postErr) + return result, errors.Join(runErrors...) } diff --git a/pipelines/zeroShotClassification.go b/pipelines/zeroShotClassification.go new file mode 100644 index 0000000..9743013 --- /dev/null +++ b/pipelines/zeroShotClassification.go @@ -0,0 +1,514 @@ +package pipelines + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "math" + "os" + "sort" + "strings" + "sync/atomic" + "time" + + util "github.com/knights-analytics/hugot/utils" + ort "github.com/yalue/onnxruntime_go" + + "github.com/daulet/tokenizers" + jsoniter "github.com/json-iterator/go" +) + +/** +sample usage: + +package main +import ( + "fmt" + "github.com/knights-analytics/hugot" + "github.com/knights-analytics/hugot/pipelines" +) + +func main() { + session, err := hugot.NewSession() + check(err) + defer func(session *hugot.Session) { + err := session.Destroy() + check(err) + }(session) + + modelPath, err := session.DownloadModel("protectai/deberta-v3-base-zeroshot-v1-onnx", "./models", hugot.NewDownloadOptions()) + check(err) + + config := hugot.ZeroShotClassificationConfig{ + ModelPath: modelPath, + Name: "testPipeline", + Options: []pipelines.PipelineOption[*pipelines.ZeroShotClassificationPipeline]{ + pipelines.WithHypothesisTemplate("This example is {}."), + 
pipelines.WithLabels([]string{"fun", "dangerous"}), + }, + } + + sentimentPipeline, err := hugot.NewPipeline(session, config) + check(err) + + outputs, _ := sentimentPipeline.RunPipeline([]string{"I am going to war today"}) + fmt.Println("raw outputs: ", outputs) + fmt.Println("outputs.GetOutput(): ", outputs.GetOutput()) +} + +main() +**/ + +type ZeroShotClassificationPipeline struct { + basePipeline + IDLabelMap map[int]string + Sequences []string + Labels []string + HypothesisTemplate string + Multilabel bool + entailmentID int + separatorToken string +} + +type ZeroShotClassificationPipelineConfig struct { + IDLabelMap map[int]string `json:"id2label"` +} + +type ZeroShotClassificationOutput struct { + Sequence string + SortedValues []struct { + Key string + Value float64 + } +} + +type ZeroShotOutput struct { + ClassificationOutputs []ZeroShotClassificationOutput +} + +// options + +// WithMultilabel can be used to set whether the pipeline is multilabel. +func WithMultilabel(multilabel bool) PipelineOption[*ZeroShotClassificationPipeline] { + return func(pipeline *ZeroShotClassificationPipeline) { + pipeline.Multilabel = multilabel + } +} + +// WithLabels can be used to set the labels to classify the examples. +func WithLabels(labels []string) PipelineOption[*ZeroShotClassificationPipeline] { + return func(pipeline *ZeroShotClassificationPipeline) { + pipeline.Labels = labels + } +} + +// WithHypothesisTemplate can be used to set the hypothesis template for classification. +func WithHypothesisTemplate(hypothesisTemplate string) PipelineOption[*ZeroShotClassificationPipeline] { + return func(pipeline *ZeroShotClassificationPipeline) { + pipeline.HypothesisTemplate = hypothesisTemplate + } +} + +// GetOutput converts raw output to readable output. 
+func (t *ZeroShotOutput) GetOutput() []any {
+	out := make([]any, len(t.ClassificationOutputs))
+	for i, o := range t.ClassificationOutputs {
+		out[i] = any(o)
+	}
+	return out
+}
+
+// createSequencePairs creates all pairs between input sequences and labels.
+func createSequencePairs(sequences interface{}, labels []string, hypothesisTemplate string) ([][][]string, []string, error) {
+	// Check if labels or sequences are empty
+	if len(labels) == 0 || sequences == nil {
+		return nil, nil, errors.New("you must include at least one label and at least one sequence")
+	}
+
+	// Check if hypothesisTemplate can be formatted with labels
+	if fmt.Sprintf(hypothesisTemplate, labels[0]) == hypothesisTemplate {
+		return nil, nil, fmt.Errorf(`the provided hypothesis_template "%s" was not able to be formatted with the target labels. Make sure the passed template includes formatting syntax such as {{}} where the label should go`, hypothesisTemplate)
+	}
+
+	// Convert sequences to []string if it's a single string
+	var seqs []string
+	switch v := sequences.(type) {
+	case string:
+		seqs = []string{v}
+	case []string:
+		seqs = v
+	default:
+		return nil, nil, errors.New("sequences must be either a string or a []string")
+	}
+
+	// Create sequence_pairs
+	var sequencePairs [][][]string
+	for _, sequence := range seqs {
+		var temp [][]string
+		for _, label := range labels {
+			hypothesis := strings.Replace(hypothesisTemplate, "{}", label, 1)
+			temp = append(temp, []string{sequence, hypothesis})
+		}
+		sequencePairs = append(sequencePairs, temp)
+	}
+	return sequencePairs, seqs, nil
+}
+
+// NewZeroShotClassificationPipeline creates a new zero shot classification pipeline.
+func NewZeroShotClassificationPipeline(config PipelineConfig[*ZeroShotClassificationPipeline], ortOptions *ort.SessionOptions) (*ZeroShotClassificationPipeline, error) { + pipeline := &ZeroShotClassificationPipeline{} + pipeline.ModelPath = config.ModelPath + pipeline.PipelineName = config.Name + pipeline.OrtOptions = ortOptions + pipeline.OnnxFilename = config.OnnxFilename + pipeline.entailmentID = -1 // Default value + pipeline.HypothesisTemplate = "This example is {}." + + for _, o := range config.Options { + o(pipeline) + } + + if len(pipeline.Labels) == 0 { + return nil, fmt.Errorf("no labels provided, please provide labels using the WithLabels() option") + } + + pipeline.TokenizerOptions = []tokenizers.EncodeOption{ + tokenizers.WithReturnTypeIDs(), + tokenizers.WithReturnAttentionMask(), + } + + tk, err := loadTokenizer(pipeline.ModelPath) + if err != nil { + return nil, err + } + pipeline.Tokenizer = tk + + // read id to label map + configPath := util.PathJoinSafe(pipeline.ModelPath, "config.json") + pipelineInputConfig := ZeroShotClassificationPipelineConfig{} + mapBytes, err := util.ReadFileBytes(configPath) + if err != nil { + return nil, err + } + err = jsoniter.Unmarshal(mapBytes, &pipelineInputConfig) + if err != nil { + return nil, err + } + + // Set IDLabelMap + pipeline.IDLabelMap = pipelineInputConfig.IDLabelMap + + // Find entailment ID + for id, label := range pipeline.IDLabelMap { + if strings.HasPrefix(strings.ToLower(label), "entail") { + pipeline.entailmentID = id + break + } + } + + configPath1 := util.PathJoinSafe(pipeline.ModelPath, "special_tokens_map.json") + file, err := os.Open(configPath1) + if err != nil { + return nil, fmt.Errorf("cannot read special_tokens_map.json at %s", pipeline.ModelPath) + } + defer func() { + err = file.Close() + }() + + byteValue, _ := io.ReadAll(file) + var result map[string]interface{} + err = json.Unmarshal(byteValue, &result) + if err != nil { + return nil, fmt.Errorf("cannot unmarshal 
special_tokens_map.json at %s", pipeline.ModelPath) + } + + sepToken, ok := result["sep_token"] + if !ok { + return nil, fmt.Errorf("no sep token detected in special_tokens_map.json at %s", pipeline.ModelPath) + } + + switch v := sepToken.(type) { + case map[string]interface{}: + t, ok := v["content"] + if !ok { + return nil, fmt.Errorf("sep_token is map but no content field is available") + } + tString, ok := t.(string) + if !ok { + return nil, fmt.Errorf("sep_token cannot be converted to string: %v", t) + } + pipeline.separatorToken = tString + case string: + pipeline.separatorToken = v + default: + return nil, fmt.Errorf("sep_token has unexpected type: %v", v) + } + + // onnx model init + model, err := loadOnnxModelBytes(pipeline.ModelPath, pipeline.OnnxFilename) + if err != nil { + return nil, err + } + + inputs, outputs, err := loadInputOutputMeta(model) + if err != nil { + return nil, err + } + pipeline.InputsMeta = inputs + pipeline.OutputsMeta = outputs + + session, err := createSession(model, inputs, pipeline.OutputsMeta, ortOptions) + if err != nil { + return nil, err + } + pipeline.OrtSession = session + + pipeline.PipelineTimings = &timings{} + pipeline.TokenizerTimings = &timings{} + return pipeline, err +} + +func (p *ZeroShotClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error { + start := time.Now() + tokenizeInputs(batch, p.Tokenizer, inputs, p.TokenizerOptions) + atomic.AddUint64(&p.TokenizerTimings.NumCalls, 1) + atomic.AddUint64(&p.TokenizerTimings.TotalNS, uint64(time.Since(start))) + err := createInputTensors(batch, p.InputsMeta) + return err +} + +func (p *ZeroShotClassificationPipeline) Forward(batch *PipelineBatch) error { + start := time.Now() + err := runSessionOnBatch(batch, p.OrtSession, p.OutputsMeta) + if err != nil { + return err + } + atomic.AddUint64(&p.PipelineTimings.NumCalls, 1) + atomic.AddUint64(&p.PipelineTimings.TotalNS, uint64(time.Since(start))) + return nil +} + +func (p 
*ZeroShotClassificationPipeline) Postprocess(outputTensors [][][]float32, labels []string, sequences []string) (*ZeroShotOutput, error) { + classificationOutputs := make([]ZeroShotClassificationOutput, 0, len(sequences)) + + LabelLikelihood := make(map[string]float64) + if p.Multilabel || len(p.Labels) == 1 { + for ind, sequence := range outputTensors { + output := ZeroShotClassificationOutput{ + Sequence: sequences[ind], + } + + var entailmentLogits []float32 + var contradictionLogits []float32 + + var entailmentID int + var contradictionID int + switch p.entailmentID { + case -1: + entailmentID = len(sequence[0]) - 1 + contradictionID = 0 + default: + entailmentID = p.entailmentID + contradictionID = 0 + if entailmentID == 0 { + contradictionID = len(sequence[0]) - 1 + } + } + + for _, tensor := range sequence { + entailmentLogits = append(entailmentLogits, tensor[entailmentID]) + contradictionLogits = append(contradictionLogits, tensor[contradictionID]) + } + + for i := range entailmentLogits { + logits := []float64{float64(contradictionLogits[i]), float64(entailmentLogits[i])} + expLogits := []float64{math.Exp(logits[0]), math.Exp(logits[1])} + sumExpLogits := expLogits[0] + expLogits[1] + score := expLogits[1] / sumExpLogits + LabelLikelihood[labels[i]] = score + } + + // Define ss as a slice of anonymous structs + var ss []struct { + Key string + Value float64 + } + for k, v := range LabelLikelihood { + ss = append(ss, struct { + Key string + Value float64 + }{k, v}) + } + + // Sort the slice by the value field + sort.Slice(ss, func(i, j int) bool { + return ss[i].Value > ss[j].Value + }) + + output.SortedValues = ss + classificationOutputs = append(classificationOutputs, output) + } + return &ZeroShotOutput{ + ClassificationOutputs: classificationOutputs, + }, nil + } + + for ind, sequence := range outputTensors { + output := ZeroShotClassificationOutput{} + + var entailmentLogits []float32 + var entailmentID int + switch p.entailmentID { + case -1: + 
entailmentID = len(sequence[0]) - 1 + default: + entailmentID = p.entailmentID + } + for _, tensor := range sequence { + entailmentLogits = append(entailmentLogits, tensor[entailmentID]) + } + + var numerator []float64 + var logitSum float64 + for _, logit := range entailmentLogits { + exp := math.Exp(float64(logit)) + numerator = append(numerator, exp) + logitSum += exp + } + + var quotient []float64 + + for ind, i := range numerator { + quotient = append(quotient, i/logitSum) + LabelLikelihood[labels[ind]] = quotient[ind] + } + + output.Sequence = sequences[ind] + + // Define ss as a slice of anonymous structs + var ss []struct { + Key string + Value float64 + } + for k, v := range LabelLikelihood { + ss = append(ss, struct { + Key string + Value float64 + }{k, v}) + } + + // Sort the slice by the value field + sort.Slice(ss, func(i, j int) bool { + return ss[i].Value > ss[j].Value + }) + + output.SortedValues = ss + classificationOutputs = append(classificationOutputs, output) + } + return &ZeroShotOutput{ + ClassificationOutputs: classificationOutputs, + }, nil +} + +func (p *ZeroShotClassificationPipeline) RunPipeline(inputs []string) (*ZeroShotOutput, error) { + var outputTensors [][][]float32 + batch := NewBatch() + var runErrors []error + defer func(*PipelineBatch) { + runErrors = append(runErrors, batch.Destroy()) + }(batch) + + sequencePairs, _, err := createSequencePairs(inputs, p.Labels, p.HypothesisTemplate) + if err != nil { + return nil, err + } + + for _, sequence := range sequencePairs { + var sequenceTensors [][]float32 + for _, pair := range sequence { + // have to do this because python inserts a separator token in between the two clauses when tokenizing + // separator token isn't universal and depends on its value in special_tokens_map.json of model + // still isn't perfect because some models (protectai/MoritzLaurer-roberta-base-zeroshot-v2.0-c-onnx for example) + // insert two separator tokens while others 
(protectai/deberta-v3-base-zeroshot-v1-onnx and others) only insert one + // need to find a way to determine how many to insert or find a better way to tokenize inputs + // The difference in outputs for one separator vs two is very small (differences in the thousandths place), but they + // definitely are different + concatenatedString := pair[0] + p.separatorToken + pair[1] + runErrors = append(runErrors, p.Preprocess(batch, []string{concatenatedString})) + if e := errors.Join(runErrors...); e != nil { + return nil, e + } + runErrors = append(runErrors, p.Forward(batch)) + if e := errors.Join(runErrors...); e != nil { + return nil, e + } + sequenceTensors = append(sequenceTensors, batch.OutputTensors[0].GetData()) + } + outputTensors = append(outputTensors, sequenceTensors) + } + + outputs, err := p.Postprocess(outputTensors, p.Labels, inputs) + runErrors = append(runErrors, err) + return outputs, errors.Join(runErrors...) +} + +// PIPELINE INTERFACE IMPLEMENTATION + +func (p *ZeroShotClassificationPipeline) Destroy() error { + return destroySession(p.Tokenizer, p.OrtSession) +} + +func (p *ZeroShotClassificationPipeline) GetStats() []string { + return []string{ + fmt.Sprintf("Statistics for pipeline: %s", p.PipelineName), + fmt.Sprintf("Tokenizer: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.TokenizerTimings.TotalNS), + p.TokenizerTimings.NumCalls, + time.Duration(float64(p.TokenizerTimings.TotalNS)/math.Max(1, float64(p.TokenizerTimings.NumCalls)))), + fmt.Sprintf("ONNX: Total time=%s, Execution count=%d, Average query time=%s", + time.Duration(p.PipelineTimings.TotalNS), + p.PipelineTimings.NumCalls, + time.Duration(float64(p.PipelineTimings.TotalNS)/math.Max(1, float64(p.PipelineTimings.NumCalls)))), + } +} + +func (p *ZeroShotClassificationPipeline) GetMetadata() PipelineMetadata { + return PipelineMetadata{ + OutputsInfo: []OutputInfo{ + { + Name: p.OutputsMeta[0].Name, + Dimensions: p.OutputsMeta[0].Dimensions, + }, + }, + } 
+} + +func (p *ZeroShotClassificationPipeline) Run(inputs []string) (PipelineBatchOutput, error) { + return p.RunPipeline(inputs) +} + +func (p *ZeroShotClassificationPipeline) Validate() error { + var validationErrors []error + + if len(p.IDLabelMap) <= 0 { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for token classification pipeline must be greater than zero")) + } + + outDims := p.OutputsMeta[0].Dimensions + if len(outDims) != 2 { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: zero shot classification must have 2 dimensional output")) + } + + dynamicBatch := false + for _, d := range outDims { + if d == -1 { + if dynamicBatch { + validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: text classification must have max one dynamic dimensions (input)")) + break + } + dynamicBatch = true + } + } + return errors.Join(validationErrors...) +} diff --git a/testData/downloadModels.go b/testData/downloadModels.go index a54fa21..3dc39f8 100644 --- a/testData/downloadModels.go +++ b/testData/downloadModels.go @@ -31,7 +31,8 @@ func main() { } downloadOptions := hugot.NewDownloadOptions() for _, modelName := range []string{ - "KnightsAnalytics/all-MiniLM-L6-v2", + "sentence-transformers/all-MiniLM-L6-v2", + "protectai/deberta-v3-base-zeroshot-v1-onnx", "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "KnightsAnalytics/distilbert-NER", "SamLowe/roberta-base-go_emotions-onnx"} {