Skip to content

Commit

Permalink
feat: adding ability to download models
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardopinosio committed Mar 18, 2024
1 parent 22bd07f commit 3db0c10
Show file tree
Hide file tree
Showing 15 changed files with 543 additions and 99 deletions.
132 changes: 132 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
linters:
disable-all: true
enable:
# Check for pass []any as any in variadic func(...any).
# Rare case but saved me from debugging a few times.
- asasalint

# I prefer plane ASCII identifiers.
# Symbol `∆` instead of `delta` looks cool but no thanks.
- asciicheck

# Checks for dangerous unicode character sequences.
# Super rare but why not to be a bit paranoid?
- bidichk

# Checks whether HTTP response body is closed successfully.
- bodyclose

# Check whether the function uses a non-inherited context.
- contextcheck

# Check for two durations multiplied together.
- durationcheck

# Forces to not skip error check.
- errcheck

# Checks `Err-` prefix for var and `-Error` suffix for error type.
- errname

# Suggests to use `%w` for error-wrapping.
- errorlint

# Checks for pointers to enclosing loop variables.
- exportloopref

# Forces to put `.` at the end of the comment. Code is poetry.
- godot

# Might not be that important but I prefer to keep all of them.
# `gofumpt` is amazing, kudos to Daniel Marti https://github.com/mvdan/gofumpt
- gofmt
- gofumpt
- goimports

# Allow or ban replace directives in go.mod
# or force explanation for retract directives.
- gomoddirectives

# Powerful security-oriented linter. But requires some time to
# configure it properly, see https://github.com/securego/gosec#available-rules
- gosec

# Linter that specializes in simplifying code.
- gosimple

# Official Go tool. Must have.
- govet

# Detects when assignments to existing variables are not used
# Last week I caught a bug with it.
- ineffassign

# Fix all the misspells, amazing thing.
- misspell

# Finds naked/bare returns and requires change them.
- nakedret

# Both require a bit more explicit returns.
- nilerr
- nilnil

# Finds sending HTTP request without context.Context.
- noctx

# Forces comment why another check is disabled.
# Better not to have //nolint: at all ;)
- nolintlint

# Finds slices that could potentially be pre-allocated.
# Small performance win + cleaner code.
- prealloc

# Finds shadowing of Go's predeclared identifiers.
# I hear a lot of complaints from junior developers.
# But after some time they find it very useful.
- predeclared

# Lint your Prometheus metrics name.
- promlinter

# Checks that package variables are not reassigned.
# Super rare case but can catch bad things (like `io.EOF = nil`)
- reassign

# Drop-in replacement of `golint`.
- revive

# Somewhat similar to `bodyclose` but for `database/sql` package.
- rowserrcheck
- sqlclosecheck

# I have found that it's not the same as staticcheck binary :\
- staticcheck

# Is a replacement for `golint`, similar to `revive`.
- stylecheck

# Check struct tags.
- tagliatelle

# Test-related checks. All of them are good.
- tenv
- testableexamples
- thelper
- tparallel

# Remove unnecessary type conversions, make code cleaner
- unconvert

# Might be noisy but better to know what is unused
- unparam

# Must have. Finds unused declarations.
- unused

# Detect the possibility to use variables/constants from stdlib.
- usestdlibvars

# Finds wasted assignment statements.
- wastedassign
8 changes: 3 additions & 5 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ARG GO_VERSION=1.22.0
ARG GO_VERSION=1.22.1
ARG RUST_VERSION=1.76
ARG ONNXRUNTIME_VERSION=1.17.1

Expand Down Expand Up @@ -37,8 +37,6 @@ RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -o test2json -ldflags="-s -w"
curl -LO https://github.com/gotestyourself/gotestsum/releases/download/v1.11.0/gotestsum_1.11.0_linux_amd64.tar.gz && \
tar -xzf gotestsum_1.11.0_linux_amd64.tar.gz --directory /usr/local/bin

COPY ./models /models

# build cli
COPY . /build
WORKDIR /build
Expand All @@ -47,7 +45,7 @@ RUN cd ./cmd && CGO_ENABLED=1 GOOS=linux GOARCH=amd64 go build -a -o ./target ma
# NON-PRIVILEDGED USER
# create non-priviledged testuser with id: 1000
RUN dnf install --disablerepo=* --enablerepo=amazonlinux --allowerasing -y dirmngr && dnf clean all
RUN useradd -u 1000 -m testuser
RUN useradd -u 1000 -m testuser && chown -R testuser:testuser /build

# ENTRYPOINT
COPY ./scripts/entrypoint.sh /entrypoint.sh
Expand All @@ -62,4 +60,4 @@ FROM scratch AS artifacts

COPY --from=building /usr/lib64/onnxruntime.so onnxruntime.so
COPY --from=building /usr/lib/libtokenizers.a libtokenizers.a
COPY --from=building /build/cmd/target /hugot-cli-linux-amd64
COPY --from=building /build/cmd/target /hugot-cli-linux-amd64
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,12 @@ defer func(session *hugot.Session) {
err := session.Destroy()
check(err)
}(session)
// we now create a text classification pipeline. It requires the path to the onnx model folder,
// Let's download an onnx sentiment test classification model in the current directory
modelPath, err := session.DownloadModel("KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "./")
check(err)
// we now create a text classification pipeline. It requires the path to the just downloader onnx model folder,
// and a pipeline name
sentimentPipeline, err := session.NewTextClassificationPipeline("/path/to/model/", "testPipeline")
sentimentPipeline, err := session.NewTextClassificationPipeline(modelPath, "testPipeline")
check(err)
// we can now use the pipeline for prediction on a batch of strings
batch := []string{"This movie is disgustingly good !", "The director tried too much"}
Expand All @@ -129,7 +132,7 @@ This will install the hugot binary at $HOME/.local/bin/hugot, and the correspond
The if $HOME/.local/bin is on your $PATH, you can do:

```
hugot run --model=/path/to/onnx/model --input=/path/to/input.jsonl --output=/path/to/folder/output --type=textClassification
hugot run --model=KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english --input=/path/to/input.jsonl --output=/path/to/folder/output --type=textClassification
```

Hugot will load the model, process the input, and write the results in the output folder.
Expand Down Expand Up @@ -157,6 +160,10 @@ echo '{"input":"The director tried too much"}' | hugot run --model=/path/to/mode

To be able to run transformers fully from the command line.

Note that the --model parameter can be:
1. the full path to a model to load
2. the name of a huggingface model. Hugot will first try to look for the model at $HOME/hugot, or will try to download the model from huggingface.

## Contributing

### Development environment
Expand Down
51 changes: 50 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"path"
"path/filepath"
"strings"
"sync"

"github.com/knights-analytics/hugot"
Expand All @@ -25,6 +26,7 @@ var outputPath string
var pipelineType string
var sharedLibraryPath string
var batchSize int
var modelsDir string

var runCommand = &cli.Command{
Name: "run",
Expand All @@ -34,7 +36,8 @@ var runCommand = &cli.Command{
ArgsUsage: `
--input: path to a .jsonl file or a folder with .jsonl files to process. If omitted, the input will be read from stdin.
--output: path to a folder where to write the output. If omitted, the output will be sent to stdout.
--model: path to the .onnx model to load.
--model: model name or path to the .onnx model to load. The hugot cli looks for models with this chain: first use the provided path. If the path does not exist, look for a model
with this name at $HOME/hugot/models. Finally, try to download the model from Huggingface and use it.
--type: pipeline type. Currently implemented types are: featureExtraction, tokenClassification, and textClassification (only single label)
--onnxruntimeSharedLibrary: path to the onnxruntime.so library. If not provided, the cli will try to load it from $HOME/lib/hugot/onnxruntime.so, and from /usr/lib/onnxruntime.so in the last instance.
`,
Expand Down Expand Up @@ -80,11 +83,27 @@ var runCommand = &cli.Command{
Required: false,
Value: 20,
},
&cli.StringFlag{
Name: "modelFolder",
Usage: "Folder where to store downloaded models. Falls back to $HOME/hugot/models if not specified",
Aliases: []string{"f"},
Destination: &modelsDir,
Required: false,
Value: "",
},
},
Action: func(ctx *cli.Context) error {

var onnxLibraryPathOpt hugot.SessionOption

if modelsDir == "" {
userDir, err := os.UserHomeDir()
if err != nil {
return err
}
modelsDir = util.PathJoinSafe(userDir, "hugot", "models")
}

if sharedLibraryPath != "" {
onnxLibraryPathOpt = hugot.WithOnnxLibraryPath(sharedLibraryPath)
} else {
Expand All @@ -110,6 +129,36 @@ var runCommand = &cli.Command{

var pipe pipelines.Pipeline

// is the model a full path to a model
ok, err := util.FileSystem.Exists(ctx.Context, modelPath)
if err != nil {
return err
}
if !ok {
// is the model the name of a model previously downloaded
downloadedModelName := strings.Replace(modelPath, "/", "_", -1)
ok, err = util.FileSystem.Exists(ctx.Context, util.PathJoinSafe(modelsDir, downloadedModelName))
if err != nil {
return err
}
if ok {
modelPath = util.PathJoinSafe(modelsDir, downloadedModelName)
} else {
// is the model the name of a model to download
if strings.Contains(modelPath, ":") {
return fmt.Errorf("filters with : are currently not supported")
}
err = util.FileSystem.Create(context.Background(), modelsDir, os.ModePerm, true)
if err != nil {
return err
}
modelPath, err = session.DownloadModel(modelPath, modelsDir, hugot.NewDownloadOptions())
if err != nil {
return err
}
}
}

switch pipelineType {
case "tokenClassification":
pipe, err = session.NewTokenClassificationPipeline(modelPath, "cliPipeline")
Expand Down
Loading

0 comments on commit 3db0c10

Please sign in to comment.