Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardopinosio committed Mar 12, 2024
1 parent a0a9770 commit cc6b1cc
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 65 deletions.
41 changes: 21 additions & 20 deletions cmd/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"bufio"
"context"
"encoding/json"
"errors"
Expand All @@ -14,6 +15,7 @@ import (
"github.com/knights-analytics/hugot"
"github.com/knights-analytics/hugot/pipelines"
util "github.com/knights-analytics/hugot/utils"
"github.com/mattn/go-isatty"
"github.com/urfave/cli/v2"
)

Expand Down Expand Up @@ -170,6 +172,8 @@ var runCommand = &cli.Command{
}
}()

// read inputs

exists, err := util.FileSystem.Exists(ctx.Context, inputPath)
if err != nil {
return err
Expand All @@ -196,19 +200,14 @@ var runCommand = &cli.Command{
if inputPath != "" {
return fmt.Errorf("file %s does not exist", inputPath)
}
// otherwise read from stdin
fi, err := os.Stdin.Stat()
if err != nil {
return err
}
size := fi.Size()
if size > 0 {

if !isatty.IsTerminal(os.Stdin.Fd()) && !isatty.IsCygwinTerminal(os.Stdin.Fd()) {
// there is something to process on stdin
err := readInputs(os.Stdin, inputChannel)
if err != nil {
return err
}
}
return nil
}

close(inputChannel)
Expand Down Expand Up @@ -290,21 +289,23 @@ func processWithPipeline(wg *sync.WaitGroup, inputChannel chan []input, processe

func readInputs(inputSource io.Reader, inputChannel chan []input) error {
inputBatch := make([]input, 0, 20)
d := json.NewDecoder(inputSource)
for {

scanner := bufio.NewScanner(inputSource)
for scanner.Scan() {
var line input
if err := d.Decode(&line); err == io.EOF {
inputChannel <- inputBatch
break // done decoding file
} else if err != nil {
err := json.Unmarshal(scanner.Bytes(), &line)
if err != nil {
return err
} else {
inputBatch = append(inputBatch, line)
if len(inputBatch) == batchSize {
inputChannel <- inputBatch
inputBatch = []input{}
}
}
inputBatch = append(inputBatch, line)
if len(inputBatch) == batchSize {
inputChannel <- inputBatch
inputBatch = []input{}
}
}
// flush
if len(inputBatch) > 0 {
inputChannel <- inputBatch
}
return nil
}
Expand Down
69 changes: 59 additions & 10 deletions cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ import (
"github.com/urfave/cli/v2"
)

//go:embed testData/test.jsonl
var testData []byte
//go:embed testData/textClassification.jsonl
var textClassificationData []byte

func TestCliRun(t *testing.T) {
//go:embed testData/tokenClassification.jsonl
var tokenClassificationData []byte

func TestTextClassificationCli(t *testing.T) {
app := &cli.App{
Name: "hugot",
Usage: "Huggingface transformers from the command line - alpha",
Expand All @@ -33,9 +36,9 @@ func TestCliRun(t *testing.T) {
recurseDir := path.Join(testDataDir, "cliRecurseTest")
err := os.MkdirAll(recurseDir, os.ModePerm)
check(t, err)
err = os.WriteFile(path.Join(testDataDir, "test-0.jsonl"), testData, os.ModePerm)
err = os.WriteFile(path.Join(testDataDir, "test-0.jsonl"), textClassificationData, os.ModePerm)
check(t, err)
err = os.WriteFile(path.Join(recurseDir, "test-1.jsonl"), testData, os.ModePerm)
err = os.WriteFile(path.Join(recurseDir, "test-1.jsonl"), textClassificationData, os.ModePerm)
check(t, err)
defer func() {
err := os.RemoveAll(testDataDir)
Expand All @@ -48,7 +51,7 @@ func TestCliRun(t *testing.T) {
}
}

func TestStdinEmpty(t *testing.T) {
func TestTokenClassificationCli(t *testing.T) {
app := &cli.App{
Name: "hugot",
Usage: "Huggingface transformers from the command line - alpha",
Expand All @@ -60,14 +63,60 @@ func TestStdinEmpty(t *testing.T) {
if modelFolder == "" {
modelFolder = "../models/"
}
testModel := path.Join(modelFolder, "distilbert-base-uncased-finetuned-sst-2-english")
inputPath = ""
testModel := path.Join(modelFolder, "distilbert-NER")

testDataDir := path.Join(os.TempDir(), "hugoTestData")
err := os.MkdirAll(testDataDir, os.ModePerm)
check(t, err)
err = os.WriteFile(path.Join(testDataDir, "test-token-classification.jsonl"), tokenClassificationData, os.ModePerm)
check(t, err)
defer func() {
err := os.RemoveAll(testDataDir)
check(t, err)
}()

// check cli doesn't hang on empty stdin but terminates
args := append(baseArgs, "run", fmt.Sprintf("--model=%s", testModel), "--type=textClassification")
args := append(baseArgs, "run", fmt.Sprintf("--input=%s", path.Join(testDataDir, "test-token-classification.jsonl")),
fmt.Sprintf("--model=%s", testModel), "--type=tokenClassification", fmt.Sprintf("--output=%s", testDataDir))
if err := app.Run(args); err != nil {
check(t, err)
}
result, err := os.ReadFile(path.Join(testDataDir, "worker-0.jsonl"))
check(t, err)
fmt.Println(string(result))
}

// TestFeatureExtractionCli drives the "run" command end to end with
// --type=featureExtraction. It writes a JSONL input file into a temporary
// directory, points the command at the all-MiniLM-L6-v2 model folder (taken
// from TEST_MODELS_FOLDER, falling back to ../models/), uses the same
// directory as output, and prints the contents of worker-0.jsonl.
func TestFeatureExtractionCli(t *testing.T) {
	app := &cli.App{
		Name:     "hugot",
		Usage:    "Huggingface transformers from the command line - alpha",
		Commands: []*cli.Command{runCommand},
	}
	// Cap the slice at one element: with a plain os.Args[0:1] the append
	// below could write into os.Args' backing array and overwrite the real
	// command-line arguments of the test binary.
	baseArgs := os.Args[0:1:1]

	modelFolder := os.Getenv("TEST_MODELS_FOLDER")
	if modelFolder == "" {
		modelFolder = "../models/"
	}
	testModel := path.Join(modelFolder, "all-MiniLM-L6-v2")

	testDataDir := path.Join(os.TempDir(), "hugoTestData")
	err := os.MkdirAll(testDataDir, os.ModePerm)
	check(t, err)
	// The token classification sentences are reused as feature extraction input.
	err = os.WriteFile(path.Join(testDataDir, "test-feature-extraction.jsonl"), tokenClassificationData, os.ModePerm)
	check(t, err)
	defer func() {
		err := os.RemoveAll(testDataDir)
		check(t, err)
	}()

	args := append(baseArgs, "run", fmt.Sprintf("--input=%s", path.Join(testDataDir, "test-feature-extraction.jsonl")),
		fmt.Sprintf("--model=%s", testModel), "--type=featureExtraction", fmt.Sprintf("--output=%s", testDataDir))
	if err := app.Run(args); err != nil {
		check(t, err)
	}
	result, err := os.ReadFile(path.Join(testDataDir, "worker-0.jsonl"))
	check(t, err)
	fmt.Println(string(result))
}

func check(t *testing.T, err error) {
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions cmd/testData/tokenClassification.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"input": "Rome is a city in Italy."}
{"input": "Microsoft is headquartered in Seattle, United States."}
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.22
require (
github.com/json-iterator/go v1.1.12
github.com/knights-analytics/tokenizers v0.10.0
github.com/mattn/go-isatty v0.0.20
github.com/stretchr/testify v1.9.0
github.com/urfave/cli/v2 v2.27.1
github.com/viant/afs v1.25.0
Expand All @@ -14,7 +15,7 @@ require (

require (
cloud.google.com/go/storage v1.39.0 // indirect
github.com/aws/aws-sdk-go v1.50.35 // indirect
github.com/aws/aws-sdk-go v1.50.36 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-errors/errors v1.5.1 // indirect
Expand Down
9 changes: 5 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ cloud.google.com/go/iam v1.1.6 h1:bEa06k05IO4f4uJonbB5iAgKTPpABy1ayxaIZV/GHVc=
cloud.google.com/go/iam v1.1.6/go.mod h1:O0zxdPeGBoFdWW3HWmBxJsk0pfvNM/p/qa82rWOGTwI=
cloud.google.com/go/storage v1.39.0 h1:brbjUa4hbDHhpQf48tjqMaXEV+f1OGoaTmQau9tmCsA=
cloud.google.com/go/storage v1.39.0/go.mod h1:OAEj/WZwUYjA3YHQ10/YcN9ttGuEpLwvaoyBXIPikEk=
github.com/aws/aws-sdk-go v1.50.34 h1:J1LjHzWNN/yVxQDTr0NIlI5vz9xRPvWiNCjQ4+5wh58=
github.com/aws/aws-sdk-go v1.50.34/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/aws/aws-sdk-go v1.50.35 h1:llQnNddBI/64pK7pwUFBoWYmg8+XGQUCs214eMbSDZc=
github.com/aws/aws-sdk-go v1.50.35/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/aws/aws-sdk-go v1.50.36 h1:PjWXHwZPuTLMR1NIb8nEjLucZBMzmf84TLoLbD8BZqk=
github.com/aws/aws-sdk-go v1.50.36/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
Expand Down Expand Up @@ -51,6 +49,8 @@ github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
Expand Down Expand Up @@ -98,6 +98,7 @@ golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI=
golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
Expand Down
Loading

0 comments on commit cc6b1cc

Please sign in to comment.