diff --git a/automod/visual/hiveai_client.go b/automod/visual/hiveai_client.go index 8b4a78a4d..ea6163b58 100644 --- a/automod/visual/hiveai_client.go +++ b/automod/visual/hiveai_client.go @@ -20,6 +20,8 @@ import ( type HiveAIClient struct { Client http.Client ApiToken string + + PreScreenClient *PreScreenClient } // schema: https://docs.thehive.ai/reference/classification diff --git a/automod/visual/hiveai_rule.go b/automod/visual/hiveai_rule.go index cc5946f8a..ac1af4bff 100644 --- a/automod/visual/hiveai_rule.go +++ b/automod/visual/hiveai_rule.go @@ -13,11 +13,29 @@ func (hal *HiveAIClient) HiveLabelBlobRule(c *automod.RecordContext, blob lexuti return nil } + var psclabel string + if hal.PreScreenClient != nil { + val, err := hal.PreScreenClient.PreScreenImage(c.Ctx, data) + if err != nil { + c.Logger.Info("prescreen-request-error", "err", err) + } else { + psclabel = val + } + } + labels, err := hal.LabelBlob(c.Ctx, blob, data) if err != nil { return err } + if psclabel == "sfw" { + if len(labels) > 0 { + c.Logger.Info("prescreen-safe-failure", "uri", c.RecordOp.ATURI()) + } else { + c.Logger.Info("prescreen-safe-success", "uri", c.RecordOp.ATURI()) + } + } + for _, l := range labels { c.AddRecordLabel(l) } diff --git a/automod/visual/prescreen.go b/automod/visual/prescreen.go new file mode 100644 index 000000000..1bb6cfe25 --- /dev/null +++ b/automod/visual/prescreen.go @@ -0,0 +1,130 @@ +package visual + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "mime/multipart" + "net/http" + "sync" + "time" +) + +const failureThresh = 10 + +type PreScreenClient struct { + Host string + Token string + + breakerEOL time.Time + breakerLk sync.Mutex + failures int + + c *http.Client +} + +func NewPreScreenClient(host, token string) *PreScreenClient { + c := &http.Client{ + Timeout: time.Second * 5, + } + + return &PreScreenClient{ + Host: host, + Token: token, + c: c, + } +} + +func (c *PreScreenClient) available() bool { + c.breakerLk.Lock() + defer c.breakerLk.Unlock() + if c.breakerEOL.IsZero() { + return true + } + + if time.Now().After(c.breakerEOL) { + c.breakerEOL = time.Time{} + return true + } + + return false +} + +func (c *PreScreenClient) recordCallResult(success bool) { + c.breakerLk.Lock() + defer c.breakerLk.Unlock() + if !c.breakerEOL.IsZero() { + return + } + + if success { + c.failures = 0 + } else { + c.failures++ + if c.failures > failureThresh { + c.breakerEOL = time.Now().Add(time.Minute) + c.failures = 0 + } + } +} + +func (c *PreScreenClient) PreScreenImage(ctx context.Context, blob []byte) (string, error) { + if !c.available() { + return "", fmt.Errorf("pre-screening temporarily unavailable") + } + + res, err := c.checkImage(ctx, blob) + if err != nil { + c.recordCallResult(false) + return "", err + } + + c.recordCallResult(true) + return res, nil +} + +type PreScreenResult struct { + Result string `json:"result"` +} + +func (c *PreScreenClient) checkImage(ctx context.Context, data []byte) (string, error) { + url := c.Host + "/predict" + + body := new(bytes.Buffer) + writer := multipart.NewWriter(body) + + part, err := writer.CreateFormFile("files", "image") + if err != nil { + return "", err + } + + part.Write(data) + + if err := writer.Close(); err != nil { + return "", err + } + + req, err := http.NewRequest("POST", url, body) + if err != nil { + return "", err + } + + req = req.WithContext(ctx) + + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Authorization", "Bearer "+c.Token) + + resp, err := c.c.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var out PreScreenResult + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return "", err + } + + return out.Result, nil +} diff --git a/cmd/hepa/main.go b/cmd/hepa/main.go index bd2447e95..6cc0bcb82 100644 --- a/cmd/hepa/main.go +++ b/cmd/hepa/main.go @@ -138,6 +138,16 @@ func run(args []string) error { Usage: "force a fixed number of parallel firehose workers. default (or 0) for auto-scaling; 200 works for a large instance", EnvVars: []string{"HEPA_FIREHOSE_PARALLELISM"}, }, + &cli.StringFlag{ + Name: "prescreen-host", + Usage: "hostname of prescreen server", + EnvVars: []string{"HEPA_PRESCREEN_HOST"}, + }, + &cli.StringFlag{ + Name: "prescreen-token", + Usage: "secret token for prescreen server", + EnvVars: []string{"HEPA_PRESCREEN_TOKEN"}, + }, } app.Commands = []*cli.Command{ @@ -242,6 +252,8 @@ var runCmd = &cli.Command{ RatelimitBypass: cctx.String("ratelimit-bypass"), RulesetName: cctx.String("ruleset"), FirehoseParallelism: cctx.Int("firehose-parallelism"), + PreScreenHost: cctx.String("prescreen-host"), + PreScreenToken: cctx.String("prescreen-token"), }, ) if err != nil { @@ -316,6 +328,8 @@ func configEphemeralServer(cctx *cli.Context) (*Server, error) { RatelimitBypass: cctx.String("ratelimit-bypass"), RulesetName: cctx.String("ruleset"), FirehoseParallelism: cctx.Int("firehose-parallelism"), + PreScreenHost: cctx.String("prescreen-host"), + PreScreenToken: cctx.String("prescreen-token"), }, ) } diff --git a/cmd/hepa/server.go b/cmd/hepa/server.go index 265db91b3..5f3c032c0 100644 --- a/cmd/hepa/server.go +++ b/cmd/hepa/server.go @@ -61,6 +61,8 @@ type Config struct { RulesetName string RatelimitBypass string FirehoseParallelism int + PreScreenHost string + PreScreenToken string } func NewServer(dir identity.Directory, config Config) (*Server, error) { @@ -169,6 +171,11 @@ func NewServer(dir identity.Directory, config Config) (*Server, error) { logger.Info("configuring Hive AI image labeler") hc := visual.NewHiveAIClient(config.HiveAPIToken) extraBlobRules = append(extraBlobRules, hc.HiveLabelBlobRule) + + if config.PreScreenHost != "" { + psc := visual.NewPreScreenClient(config.PreScreenHost, config.PreScreenToken) + hc.PreScreenClient = psc + } } if config.AbyssHost != "" && config.AbyssPassword != "" {