From cc6b1cc6e117d8a504b48628bf6ad965329338b1 Mon Sep 17 00:00:00 2001 From: riccardopinosio Date: Tue, 12 Mar 2024 09:08:44 +0000 Subject: [PATCH] fix tests --- cmd/main.go | 41 +++++------ cmd/main_test.go | 69 ++++++++++++++++--- .../{test.jsonl => textClassification.jsonl} | 0 cmd/testData/tokenClassification.jsonl | 2 + go.mod | 3 +- go.sum | 9 +-- hugot_test.go | 67 ++++++++++-------- 7 files changed, 126 insertions(+), 65 deletions(-) rename cmd/testData/{test.jsonl => textClassification.jsonl} (100%) create mode 100644 cmd/testData/tokenClassification.jsonl diff --git a/cmd/main.go b/cmd/main.go index a5d66d0..197a0fd 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,6 +1,7 @@ package main import ( + "bufio" "context" "encoding/json" "errors" @@ -14,6 +15,7 @@ import ( "github.com/knights-analytics/hugot" "github.com/knights-analytics/hugot/pipelines" util "github.com/knights-analytics/hugot/utils" + "github.com/mattn/go-isatty" "github.com/urfave/cli/v2" ) @@ -170,6 +172,8 @@ var runCommand = &cli.Command{ } }() + // read inputs + exists, err := util.FileSystem.Exists(ctx.Context, inputPath) if err != nil { return err @@ -196,19 +200,14 @@ var runCommand = &cli.Command{ if inputPath != "" { return fmt.Errorf("file %s does not exist", inputPath) } - // otherwise read from stdin - fi, err := os.Stdin.Stat() - if err != nil { - return err - } - size := fi.Size() - if size > 0 { + + if !isatty.IsTerminal(os.Stdin.Fd()) && !isatty.IsCygwinTerminal(os.Stdin.Fd()) { + // there is something to process on stdin err := readInputs(os.Stdin, inputChannel) if err != nil { return err } } - return nil } close(inputChannel) @@ -290,21 +289,23 @@ func processWithPipeline(wg *sync.WaitGroup, inputChannel chan []input, processe func readInputs(inputSource io.Reader, inputChannel chan []input) error { inputBatch := make([]input, 0, 20) - d := json.NewDecoder(inputSource) - for { + + scanner := bufio.NewScanner(inputSource) + for scanner.Scan() { var line input - if err := d.Decode(&line); err == io.EOF { - inputChannel <- inputBatch - break // done decoding file - } else if err != nil { + err := json.Unmarshal(scanner.Bytes(), &line) + if err != nil { return err - } else { - inputBatch = append(inputBatch, line) - if len(inputBatch) == batchSize { - inputChannel <- inputBatch - inputBatch = []input{} - } } + inputBatch = append(inputBatch, line) + if len(inputBatch) == batchSize { + inputChannel <- inputBatch + inputBatch = []input{} + } + } + // flush + if len(inputBatch) > 0 { + inputChannel <- inputBatch } return nil } diff --git a/cmd/main_test.go b/cmd/main_test.go index 1800f4d..fc06b58 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -11,10 +11,13 @@ import ( "github.com/urfave/cli/v2" ) -//go:embed testData/test.jsonl -var testData []byte +//go:embed testData/textClassification.jsonl +var textClassificationData []byte -func TestCliRun(t *testing.T) { +//go:embed testData/tokenClassification.jsonl +var tokenClassificationData []byte + +func TestTextClassificationCli(t *testing.T) { app := &cli.App{ Name: "hugot", Usage: "Huggingface transformers from the command line - alpha", @@ -33,9 +36,9 @@ func TestCliRun(t *testing.T) { recurseDir := path.Join(testDataDir, "cliRecurseTest") err := os.MkdirAll(recurseDir, os.ModePerm) check(t, err) - err = os.WriteFile(path.Join(testDataDir, "test-0.jsonl"), testData, os.ModePerm) + err = os.WriteFile(path.Join(testDataDir, "test-0.jsonl"), textClassificationData, os.ModePerm) check(t, err) - err = os.WriteFile(path.Join(recurseDir, "test-1.jsonl"), testData, os.ModePerm) + err = os.WriteFile(path.Join(recurseDir, "test-1.jsonl"), textClassificationData, os.ModePerm) check(t, err) defer func() { err := os.RemoveAll(testDataDir) @@ -48,7 +51,7 @@ func TestCliRun(t *testing.T) { } } -func TestStdinEmpty(t *testing.T) { +func TestTokenClassificationCli(t *testing.T) { app := &cli.App{ Name: "hugot", Usage: "Huggingface transformers from the command line - alpha", @@ -60,14 +63,60 @@ func TestStdinEmpty(t *testing.T) { if modelFolder == "" { modelFolder = "../models/" } - testModel := path.Join(modelFolder, "distilbert-base-uncased-finetuned-sst-2-english") - inputPath = "" + testModel := path.Join(modelFolder, "distilbert-NER") + + testDataDir := path.Join(os.TempDir(), "hugoTestData") + err := os.MkdirAll(testDataDir, os.ModePerm) + check(t, err) + err = os.WriteFile(path.Join(testDataDir, "test-token-classification.jsonl"), tokenClassificationData, os.ModePerm) + check(t, err) + defer func() { + err := os.RemoveAll(testDataDir) + check(t, err) + }() - // check cli doesn't hang on empty stdin but terminates - args := append(baseArgs, "run", fmt.Sprintf("--model=%s", testModel), "--type=textClassification") + args := append(baseArgs, "run", fmt.Sprintf("--input=%s", path.Join(testDataDir, "test-token-classification.jsonl")), + fmt.Sprintf("--model=%s", testModel), "--type=tokenClassification", fmt.Sprintf("--output=%s", testDataDir)) if err := app.Run(args); err != nil { check(t, err) } + result, err := os.ReadFile(path.Join(testDataDir, "worker-0.jsonl")) + check(t, err) + fmt.Println(string(result)) +} + +func TestFeatureExtractionCli(t *testing.T) { + app := &cli.App{ + Name: "hugot", + Usage: "Huggingface transformers from the command line - alpha", + Commands: []*cli.Command{runCommand}, + } + baseArgs := os.Args[0:1] + + modelFolder := os.Getenv("TEST_MODELS_FOLDER") + if modelFolder == "" { + modelFolder = "../models/" + } + testModel := path.Join(modelFolder, "all-MiniLM-L6-v2") + + testDataDir := path.Join(os.TempDir(), "hugoTestData") + err := os.MkdirAll(testDataDir, os.ModePerm) + check(t, err) + err = os.WriteFile(path.Join(testDataDir, "test-feature-extraction.jsonl"), tokenClassificationData, os.ModePerm) + check(t, err) + defer func() { + err := os.RemoveAll(testDataDir) + check(t, err) + }() + + args := append(baseArgs, "run", fmt.Sprintf("--input=%s", path.Join(testDataDir, "test-feature-extraction.jsonl")), + fmt.Sprintf("--model=%s", testModel), "--type=featureExtraction", fmt.Sprintf("--output=%s", testDataDir)) + if err := app.Run(args); err != nil { + check(t, err) + } + result, err := os.ReadFile(path.Join(testDataDir, "worker-0.jsonl")) + check(t, err) + fmt.Println(string(result)) } func check(t *testing.T, err error) { diff --git a/cmd/testData/test.jsonl b/cmd/testData/textClassification.jsonl similarity index 100% rename from cmd/testData/test.jsonl rename to cmd/testData/textClassification.jsonl diff --git a/cmd/testData/tokenClassification.jsonl b/cmd/testData/tokenClassification.jsonl new file mode 100644 index 0000000..5ac093b --- /dev/null +++ b/cmd/testData/tokenClassification.jsonl @@ -0,0 +1,2 @@ +{"input": "Rome is a city in Italy."} +{"input": "Microsoft is headquartered in Seattle, Unites States."} \ No newline at end of file diff --git a/go.mod b/go.mod index 36724b7..807ac18 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22 require ( github.com/json-iterator/go v1.1.12 github.com/knights-analytics/tokenizers v0.10.0 + github.com/mattn/go-isatty v0.0.20 github.com/stretchr/testify v1.9.0 github.com/urfave/cli/v2 v2.27.1 github.com/viant/afs v1.25.0 @@ -14,7 +15,7 @@ require ( require ( cloud.google.com/go/storage v1.39.0 // indirect - github.com/aws/aws-sdk-go v1.50.35 // indirect + github.com/aws/aws-sdk-go v1.50.36 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-errors/errors v1.5.1 // indirect diff --git a/go.sum b/go.sum index 4147e3c..8fe6044 100644 --- a/go.sum +++ b/go.sum @@ -8,10 +8,8 @@ cloud.google.com/go/iam v1.1.6 h1:bEa06k05IO4f4uJonbB5iAgKTPpABy1ayxaIZV/GHVc= cloud.google.com/go/iam v1.1.6/go.mod h1:O0zxdPeGBoFdWW3HWmBxJsk0pfvNM/p/qa82rWOGTwI= cloud.google.com/go/storage v1.39.0 h1:brbjUa4hbDHhpQf48tjqMaXEV+f1OGoaTmQau9tmCsA= cloud.google.com/go/storage v1.39.0/go.mod h1:OAEj/WZwUYjA3YHQ10/YcN9ttGuEpLwvaoyBXIPikEk= -github.com/aws/aws-sdk-go v1.50.34 h1:J1LjHzWNN/yVxQDTr0NIlI5vz9xRPvWiNCjQ4+5wh58= -github.com/aws/aws-sdk-go v1.50.34/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= -github.com/aws/aws-sdk-go v1.50.35 h1:llQnNddBI/64pK7pwUFBoWYmg8+XGQUCs214eMbSDZc= -github.com/aws/aws-sdk-go v1.50.35/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go v1.50.36 h1:PjWXHwZPuTLMR1NIb8nEjLucZBMzmf84TLoLbD8BZqk= +github.com/aws/aws-sdk-go v1.50.36/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -51,6 +49,8 @@ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -98,6 +98,7 @@ golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI= golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= diff --git a/hugot_test.go b/hugot_test.go index 3a92aa1..002780d 100644 --- a/hugot_test.go +++ b/hugot_test.go @@ -4,7 +4,6 @@ import ( _ "embed" "encoding/json" "fmt" - "log" "math" "os" "path" @@ -30,9 +29,10 @@ var onnxruntimeSharedLibrary = "" func TestTextClassificationPipeline(t *testing.T) { session, err := hugot.NewSession(onnxruntimeSharedLibrary) - check(err) + check(t, err) defer func(session *hugot.Session) { - check(session.Destroy()) + err := session.Destroy() + check(t, err) }(session) modelFolder := os.Getenv("TEST_MODELS_FOLDER") if modelFolder == "" { @@ -40,7 +40,7 @@ func TestTextClassificationPipeline(t *testing.T) { } modelPath := path.Join(modelFolder, "distilbert-base-uncased-finetuned-sst-2-english") sentimentPipeline, err := session.NewTextClassificationPipeline(modelPath, "testPipeline") - check(err) + check(t, err) tests := []struct { pipeline *pipelines.TextClassificationPipeline @@ -78,7 +78,7 @@ func TestTextClassificationPipeline(t *testing.T) { if !ok { t.FailNow() } - check(err) + check(t, err) for i, expected := range tt.expected.ClassificationOutputs { checkClassificationOutput(t, expected, result.ClassificationOutputs[i]) } @@ -91,9 +91,10 @@ func TestTextClassificationPipeline(t *testing.T) { func TestTextClassificationPipelineValidation(t *testing.T) { session, err := hugot.NewSession(onnxruntimeSharedLibrary) - check(err) + check(t, err) defer func(session *hugot.Session) { - check(session.Destroy()) + err := session.Destroy() + check(t, err) }(session) modelFolder := os.Getenv("TEST_MODELS_FOLDER") if modelFolder == "" { @@ -101,7 +102,7 @@ func TestTextClassificationPipelineValidation(t *testing.T) { } modelPath := path.Join(modelFolder, "distilbert-base-uncased-finetuned-sst-2-english") sentimentPipeline, err := session.NewTextClassificationPipeline(modelPath, "testPipeline", pipelines.WithAggregationFunction(util.SoftMax)) - check(err) + check(t, err) sentimentPipeline.IdLabelMap = map[int]string{} err = sentimentPipeline.Validate() assert.Error(t, err) @@ -122,9 +123,10 @@ func TestTextClassificationPipelineValidation(t *testing.T) { func TestTokenClassificationPipeline(t *testing.T) { session, err := hugot.NewSession(onnxruntimeSharedLibrary) - check(err) + check(t, err) defer func(session *hugot.Session) { - check(session.Destroy()) + err := session.Destroy() + check(t, err) }(session) modelFolder := os.Getenv("TEST_MODELS_FOLDER") @@ -133,12 +135,13 @@ func TestTokenClassificationPipeline(t *testing.T) { } modelPath := path.Join(modelFolder, "distilbert-NER") pipelineSimple, err2 := session.NewTokenClassificationPipeline(modelPath, "testPipelineSimple", pipelines.WithSimpleAggregation()) - check(err2) + check(t, err2) pipelineNone, err3 := session.NewTokenClassificationPipeline(modelPath, "testPipelineNone", pipelines.WithoutAggregation()) - check(err3) + check(t, err3) var expectedResults map[int]pipelines.TokenClassificationOutput - check(json.Unmarshal(tokenExpectedByte, &expectedResults)) + err4 := json.Unmarshal(tokenExpectedByte, &expectedResults) + check(t, err4) tests := []struct { pipeline *pipelines.TokenClassificationPipeline @@ -169,7 +172,7 @@ func TestTokenClassificationPipeline(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { batchResult, err := tt.pipeline.Run(tt.strings) - check(err) + check(t, err) result, ok := batchResult.(*pipelines.TokenClassificationOutput) if !ok { t.FailNow() @@ -190,9 +193,10 @@ func TestTokenClassificationPipeline(t *testing.T) { func TestTokenClassificationPipelineValidation(t *testing.T) { session, err := hugot.NewSession(onnxruntimeSharedLibrary) - check(err) + check(t, err) defer func(session *hugot.Session) { - check(session.Destroy()) + err := session.Destroy() + check(t, err) }(session) modelFolder := os.Getenv("TEST_MODELS_FOLDER") @@ -201,7 +205,7 @@ func TestTokenClassificationPipelineValidation(t *testing.T) { } modelPath := path.Join(modelFolder, "distilbert-NER") pipelineSimple, err2 := session.NewTokenClassificationPipeline(modelPath, "testPipelineSimple", pipelines.WithSimpleAggregation()) - check(err2) + check(t, err2) pipelineSimple.IdLabelMap = map[int]string{} err = pipelineSimple.Validate() @@ -223,9 +227,10 @@ func TestTokenClassificationPipelineValidation(t *testing.T) { func TestFeatureExtractionPipeline(t *testing.T) { session, err := hugot.NewSession(onnxruntimeSharedLibrary) - check(err) + check(t, err) defer func(session *hugot.Session) { - check(session.Destroy()) + err := session.Destroy() + check(t, err) }(session) modelFolder := os.Getenv("TEST_MODELS_FOLDER") @@ -234,17 +239,18 @@ func TestFeatureExtractionPipeline(t *testing.T) { } modelPath := path.Join(modelFolder, "all-MiniLM-L6-v2") pipeline, err := session.NewFeatureExtractionPipeline(modelPath, "testPipeline") - check(err) + check(t, err) var expectedResults map[string][][]float32 - check(json.Unmarshal(resultsByte, &expectedResults)) + err = json.Unmarshal(resultsByte, &expectedResults) + check(t, err) var testResults [][]float32 // test 'robert smith' testResults = expectedResults["test1output"] for i := 1; i <= 10; i++ { batchResult, err := pipeline.Run([]string{"robert smith"}) - check(err) + check(t, err) result, ok := batchResult.(*pipelines.FeatureExtractionOutput) if !ok { t.FailNow() @@ -260,7 +266,7 @@ func TestFeatureExtractionPipeline(t *testing.T) { testResults = expectedResults["test2output"] for i := 1; i <= 10; i++ { batchResult, err := pipeline.Run([]string{"robert smith junior", "francis ford coppola"}) - check(err) + check(t, err) result, ok := batchResult.(*pipelines.FeatureExtractionOutput) if !ok { t.FailNow() @@ -283,7 +289,7 @@ func TestFeatureExtractionPipeline(t *testing.T) { for k, sentencePair := range testPairs { // these vectors should be the same firstBatchResult, err2 := pipeline.Run(sentencePair[0]) - check(err2) + check(t, err2) firstResult, ok := firstBatchResult.(*pipelines.FeatureExtractionOutput) if !ok { t.FailNow() @@ -291,7 +297,7 @@ func TestFeatureExtractionPipeline(t *testing.T) { firstEmbedding := firstResult.Embeddings[0] secondBatchResult, err3 := pipeline.Run(sentencePair[1]) - check(err3) + check(t, err3) secondResult, ok := secondBatchResult.(*pipelines.FeatureExtractionOutput) if !ok { t.FailNow() @@ -315,9 +321,10 @@ func TestFeatureExtractionPipeline(t *testing.T) { func TestFeatureExtractionPipelineValidation(t *testing.T) { session, err := hugot.NewSession(onnxruntimeSharedLibrary) - check(err) + check(t, err) defer func(session *hugot.Session) { - check(session.Destroy()) + err := session.Destroy() + check(t, err) }(session) modelFolder := os.Getenv("TEST_MODELS_FOLDER") @@ -326,7 +333,7 @@ func TestFeatureExtractionPipelineValidation(t *testing.T) { } modelPath := path.Join(modelFolder, "all-MiniLM-L6-v2") pipeline, err := session.NewFeatureExtractionPipeline(modelPath, "testPipeline") - check(err) + check(t, err) pipeline.OutputDim = 0 err = pipeline.Validate() @@ -367,9 +374,9 @@ func almostEqual(a, b float64) bool { return math.Abs(a-b) <= 0.0001 } -func check(err error) { +func check(t *testing.T, err error) { if err != nil { - log.Panic(err) + t.Fatalf("Test failed with error %s", err.Error()) } }