Skip to content

Commit

Permalink
add rigging for hive to pre-screen images before sending to hive (#728)
Browse files Browse the repository at this point in the history
for now, just going to log if the pre-screen fails at its job
  • Loading branch information
whyrusleeping authored Aug 27, 2024
2 parents ce365b1 + 0773430 commit 05d4210
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 0 deletions.
2 changes: 2 additions & 0 deletions automod/visual/hiveai_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
type HiveAIClient struct {
Client http.Client
ApiToken string

PreScreenClient *PreScreenClient
}

// schema: https://docs.thehive.ai/reference/classification
Expand Down
18 changes: 18 additions & 0 deletions automod/visual/hiveai_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,29 @@ func (hal *HiveAIClient) HiveLabelBlobRule(c *automod.RecordContext, blob lexuti
return nil
}

var psclabel string
if hal.PreScreenClient != nil {
val, err := hal.PreScreenClient.PreScreenImage(c.Ctx, data)
if err != nil {
c.Logger.Info("prescreen-request-error", "err", err)
} else {
psclabel = val
}
}

labels, err := hal.LabelBlob(c.Ctx, blob, data)
if err != nil {
return err
}

if psclabel == "sfw" {
if len(labels) > 0 {
c.Logger.Info("prescreen-safe-failure", "uri", c.RecordOp.ATURI())
} else {
c.Logger.Info("prescreen-safe-success", "uri", c.RecordOp.ATURI())
}
}

for _, l := range labels {
c.AddRecordLabel(l)
}
Expand Down
130 changes: 130 additions & 0 deletions automod/visual/prescreen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package visual

import (
"bytes"
"context"
"encoding/json"
"fmt"
"mime/multipart"
"net/http"
"sync"
"time"
)

const failureThresh = 10

type PreScreenClient struct {
Host string
Token string

breakerEOL time.Time
breakerLk sync.Mutex
failures int

c *http.Client
}

func NewPreScreenClient(host, token string) *PreScreenClient {
c := &http.Client{
Timeout: time.Second * 5,
}

return &PreScreenClient{
Host: host,
Token: token,
c: c,
}
}

func (c *PreScreenClient) available() bool {
c.breakerLk.Lock()
defer c.breakerLk.Unlock()
if c.breakerEOL.IsZero() {
return true
}

if time.Now().After(c.breakerEOL) {
c.breakerEOL = time.Time{}
return true
}

return false
}

func (c *PreScreenClient) recordCallResult(success bool) {
c.breakerLk.Lock()
defer c.breakerLk.Unlock()
if !c.breakerEOL.IsZero() {
return
}

if success {
c.failures = 0
} else {
c.failures++
if c.failures > failureThresh {
c.breakerEOL = time.Now().Add(time.Minute)
c.failures = 0
}
}
}

func (c *PreScreenClient) PreScreenImage(ctx context.Context, blob []byte) (string, error) {
if !c.available() {
return "", fmt.Errorf("pre-screening temporarily unavailable")
}

res, err := c.checkImage(ctx, blob)
if err != nil {
c.recordCallResult(false)
return "", err
}

c.recordCallResult(true)
return res, nil
}

type PreScreenResult struct {
Result string `json:"result"`
}

func (c *PreScreenClient) checkImage(ctx context.Context, data []byte) (string, error) {
url := c.Host + "/predict"

body := new(bytes.Buffer)
writer := multipart.NewWriter(body)

part, err := writer.CreateFormFile("files", "image")
if err != nil {
return "", err
}

part.Write(data)

if err := writer.Close(); err != nil {
return "", err
}

req, err := http.NewRequest("POST", url, body)
if err != nil {
return "", err
}

req = req.WithContext(ctx)

req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", "Bearer "+c.Token)

resp, err := c.c.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()

var out PreScreenResult
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
return "", err
}

return out.Result, nil
}
14 changes: 14 additions & 0 deletions cmd/hepa/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,16 @@ func run(args []string) error {
Usage: "force a fixed number of parallel firehose workers. default (or 0) for auto-scaling; 200 works for a large instance",
EnvVars: []string{"HEPA_FIREHOSE_PARALLELISM"},
},
&cli.StringFlag{
Name: "prescreen-host",
Usage: "hostname of prescreen server",
EnvVars: []string{"HEPA_PRESCREEN_HOST"},
},
&cli.StringFlag{
Name: "prescreen-token",
Usage: "secret token for prescreen server",
EnvVars: []string{"HEPA_PRESCREEN_TOKEN"},
},
}

app.Commands = []*cli.Command{
Expand Down Expand Up @@ -242,6 +252,8 @@ var runCmd = &cli.Command{
RatelimitBypass: cctx.String("ratelimit-bypass"),
RulesetName: cctx.String("ruleset"),
FirehoseParallelism: cctx.Int("firehose-parallelism"),
PreScreenHost: cctx.String("prescreen-host"),
PreScreenToken: cctx.String("prescreen-token"),
},
)
if err != nil {
Expand Down Expand Up @@ -316,6 +328,8 @@ func configEphemeralServer(cctx *cli.Context) (*Server, error) {
RatelimitBypass: cctx.String("ratelimit-bypass"),
RulesetName: cctx.String("ruleset"),
FirehoseParallelism: cctx.Int("firehose-parallelism"),
PreScreenHost: cctx.String("prescreen-host"),
PreScreenToken: cctx.String("prescreen-token"),
},
)
}
Expand Down
7 changes: 7 additions & 0 deletions cmd/hepa/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ type Config struct {
RulesetName string
RatelimitBypass string
FirehoseParallelism int
PreScreenHost string
PreScreenToken string
}

func NewServer(dir identity.Directory, config Config) (*Server, error) {
Expand Down Expand Up @@ -169,6 +171,11 @@ func NewServer(dir identity.Directory, config Config) (*Server, error) {
logger.Info("configuring Hive AI image labeler")
hc := visual.NewHiveAIClient(config.HiveAPIToken)
extraBlobRules = append(extraBlobRules, hc.HiveLabelBlobRule)

if config.PreScreenHost != "" {
psc := visual.NewPreScreenClient(config.PreScreenHost, config.PreScreenToken)
hc.PreScreenClient = psc
}
}

if config.AbyssHost != "" && config.AbyssPassword != "" {
Expand Down

0 comments on commit 05d4210

Please sign in to comment.