Skip to content

Commit

Permalink
Merge pull request choria-io#111 from ripienaar/109
Browse files Browse the repository at this point in the history
(choria-io#109) Sign and Verify tasks in the client and cli
  • Loading branch information
ripienaar authored May 8, 2023
2 parents ced4df5 + 134aa01 commit 8e395d9
Show file tree
Hide file tree
Showing 12 changed files with 401 additions and 94 deletions.
86 changes: 58 additions & 28 deletions ajc/task_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"crypto/ed25519"
"encoding/hex"
"fmt"
"os"
"strings"
"time"

Expand Down Expand Up @@ -40,6 +39,8 @@ type taskCommand struct {
dependencies []string
loadDepResults bool
ed25519Seed string
ed25519PubKey string
optionalSigs bool

limit int
json bool
Expand All @@ -50,6 +51,8 @@ func configureTaskCommand(app *fisk.Application) {
c := &taskCommand{}

tasks := app.Command("tasks", "Manage Tasks").Alias("t").Alias("task")
tasks.Flag("sign", "Signs tasks using an ed25519 seed").StringVar(&c.ed25519Seed)
tasks.Flag("verify", "Verifies tasks using an ed25519 public key").StringVar(&c.ed25519PubKey)

add := tasks.Command("add", "Adds a new Task to a queue").Alias("new").Alias("a").Alias("enqueue").Action(c.addAction)
add.Arg("type", "The task type").Required().StringVar(&c.ttype)
Expand All @@ -59,7 +62,6 @@ func configureTaskCommand(app *fisk.Application) {
add.Flag("tries", "Sets the maximum amount of times this task may be tried").IntVar(&c.maxtries)
add.Flag("depends", "Sets IDs to depend on, comma sep or pass multiple times").StringsVar(&c.dependencies)
add.Flag("load", "Loads results from dependencies before executing task").BoolVar(&c.loadDepResults)
add.Flag("sign", "Signs the task using an ed25519 seed").StringVar(&c.ed25519Seed)

retry := tasks.Command("retry", "Retries delivery of a task currently in the Task Store").Action(c.retryAction)
retry.Arg("id", "The Task ID to view").Required().StringVar(&c.id)
Expand Down Expand Up @@ -106,8 +108,52 @@ func configureTaskCommand(app *fisk.Application) {
configureTaskCronCommand(tasks)
}

func (c *taskCommand) prepare(copts ...aj.ClientOpt) error {
sigOpts, err := c.clientOpts()
if err != nil {
return err
}

return prepare(append(copts, sigOpts...)...)
}

func (c *taskCommand) clientOpts() ([]aj.ClientOpt, error) {
var opts []aj.ClientOpt

if c.optionalSigs {
opts = append(opts, aj.TaskSignaturesOptional())
}

if c.ed25519Seed != "" {
if fileExist(c.ed25519Seed) {
opts = append(opts, aj.TaskSigningSeedFile(c.ed25519Seed))
} else {
seed, err := hex.DecodeString(c.ed25519Seed)
if err != nil {
return nil, err
}

opts = append(opts, aj.TaskSigningKey(ed25519.NewKeyFromSeed(seed)))
}
}

if c.ed25519PubKey != "" {
if fileExist(c.ed25519PubKey) {
opts = append(opts, aj.TaskVerificationKeyFile(c.ed25519PubKey))
} else {
pk, err := hex.DecodeString(c.ed25519PubKey)
if err != nil {
return nil, err
}
opts = append(opts, aj.TaskVerificationKey(pk))
}
}

return opts, nil
}

func (c *taskCommand) retryAction(_ *fisk.ParseContext) error {
err := prepare(aj.BindWorkQueue(c.queue))
err := c.prepare(aj.BindWorkQueue(c.queue))
if err != nil {
return err
}
Expand All @@ -121,7 +167,7 @@ func (c *taskCommand) retryAction(_ *fisk.ParseContext) error {
}

func (c *taskCommand) initAction(_ *fisk.ParseContext) error {
err := prepare(aj.NoStorageInit())
err := c.prepare(aj.NoStorageInit())
if err != nil {
return err
}
Expand All @@ -142,7 +188,7 @@ func (c *taskCommand) initAction(_ *fisk.ParseContext) error {
}

func (c *taskCommand) watchAction(_ *fisk.ParseContext) error {
err := prepare()
err := c.prepare()
if err != nil {
return err
}
Expand Down Expand Up @@ -208,7 +254,7 @@ func (c *taskCommand) processAction(_ *fisk.ParseContext) error {
copts = append(copts, aj.DiscardTaskStates(aj.TaskStateExpired))
}

err := prepare(copts...)
err := c.prepare(copts...)
if err != nil {
return err
}
Expand All @@ -227,7 +273,7 @@ func (c *taskCommand) processAction(_ *fisk.ParseContext) error {
}

func (c *taskCommand) purgeAction(_ *fisk.ParseContext) error {
err := prepare()
err := c.prepare()
if err != nil {
return err
}
Expand Down Expand Up @@ -260,7 +306,7 @@ func (c *taskCommand) purgeAction(_ *fisk.ParseContext) error {
}

func (c *taskCommand) configAction(_ *fisk.ParseContext) error {
err := prepare()
err := c.prepare()
if err != nil {
return err
}
Expand Down Expand Up @@ -291,7 +337,7 @@ func (c *taskCommand) configAction(_ *fisk.ParseContext) error {
}

func (c *taskCommand) lsAction(_ *fisk.ParseContext) error {
err := prepare()
err := c.prepare()
if err != nil {
return err
}
Expand Down Expand Up @@ -323,7 +369,7 @@ func (c *taskCommand) lsAction(_ *fisk.ParseContext) error {
}

func (c *taskCommand) rmAction(_ *fisk.ParseContext) error {
err := prepare()
err := c.prepare()
if err != nil {
return err
}
Expand All @@ -346,7 +392,7 @@ func (c *taskCommand) rmAction(_ *fisk.ParseContext) error {
}

func (c *taskCommand) viewAction(_ *fisk.ParseContext) error {
err := prepare()
err := c.prepare()
if err != nil {
return err
}
Expand Down Expand Up @@ -394,7 +440,7 @@ func (c *taskCommand) viewAction(_ *fisk.ParseContext) error {
}

func (c *taskCommand) addAction(_ *fisk.ParseContext) error {
err := prepare(aj.BindWorkQueue(c.queue))
err := c.prepare(aj.BindWorkQueue(c.queue))
if err != nil {
return err
}
Expand All @@ -420,22 +466,6 @@ func (c *taskCommand) addAction(_ *fisk.ParseContext) error {
opts = append(opts, aj.TaskMaxTries(c.maxtries))
}

if c.ed25519Seed != "" {
var seed []byte
if fileExist(c.ed25519Seed) {
seed, err = os.ReadFile(c.ed25519Seed)
if err != nil {
return err
}
} else {
seed, err = hex.DecodeString(c.ed25519Seed)
if err != nil {
return err
}
}
opts = append(opts, aj.TaskSigner(ed25519.NewKeyFromSeed(seed)))
}

task, err := aj.NewTask(c.ttype, c.payload, opts...)
if err != nil {
return err
Expand Down
1 change: 1 addition & 0 deletions ajc/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func createLogger() {

log = logrus.NewEntry(logger)
}

func prepare(copts ...asyncjobs.ClientOpt) error {
if client != nil {
return nil
Expand Down
148 changes: 147 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ package asyncjobs

import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"os"
"time"

"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -41,6 +46,11 @@ func NewClient(opts ...ClientOpt) (*Client, error) {
}
}

err = copts.validate()
if err != nil {
return nil, err
}

c := &Client{opts: copts, log: copts.logger}
c.storage, err = newJetStreamStorage(copts.nc, copts.retryPolicy, c.log)
if err != nil {
Expand Down Expand Up @@ -85,7 +95,17 @@ func (c *Client) Run(ctx context.Context, router *Mux) error {

// LoadTaskByID loads a task from the backend using its ID
func (c *Client) LoadTaskByID(id string) (*Task, error) {
return c.storage.LoadTaskByID(id)
task, err := c.storage.LoadTaskByID(id)
if err != nil {
return nil, err
}

err = c.verifyTaskSignature(task)
if err != nil {
return nil, err
}

return task, nil
}

// RetryTaskByID will retry a task, first removing an entry from the Work Queue if already there
Expand All @@ -95,9 +115,135 @@ func (c *Client) RetryTaskByID(ctx context.Context, id string) error {

// EnqueueTask adds a task to the named queue which must already exist
func (c *Client) EnqueueTask(ctx context.Context, task *Task) error {
task.Queue = c.opts.queue.Name

err := c.signTask(task)
if err != nil {
return err
}

return c.opts.queue.enqueueTask(ctx, task)
}

func (c *Client) verifyTaskSignature(task *Task) error {
// is disabled
if c.opts.publicKey == nil && c.opts.publicKeyFile == "" {
return nil
}

switch {
case !c.opts.optionalTaskSignatures && task.Signature == "":
return ErrTaskNotSigned

case task.Signature == "":
return nil
}

var pubKey ed25519.PublicKey

switch {
case c.opts.publicKey != nil:
pubKey = c.opts.publicKey

case c.opts.publicKeyFile != "":
kf, err := os.ReadFile(c.opts.publicKeyFile)
if err != nil {
return err
}

kb, err := hex.DecodeString(string(kf))
if err != nil {
return err
}

if len(kb) != ed25519.PublicKeySize {
return fmt.Errorf("invalid public key")
}

pubKey = kb

case c.opts.seedFile != "":
sf, err := os.ReadFile(c.opts.seedFile)
if err != nil {
return err
}

sb, err := hex.DecodeString(string(sf))
if err != nil {
return err
}

if len(sb) != ed25519.SeedSize {
return fmt.Errorf("invalid seed length")
}

pk := ed25519.NewKeyFromSeed(sb)
defer func() {
io.ReadFull(rand.Reader, pk[:])
io.ReadFull(rand.Reader, sb[:])
io.ReadFull(rand.Reader, sf[:])
}()

pubKey = pk.Public().(ed25519.PublicKey)

default:
if !c.opts.optionalTaskSignatures {
return fmt.Errorf("no task verification keys configured")
}

return nil
}

msg, err := task.signatureMessage()
if err != nil {
return fmt.Errorf("%w: %v", ErrTaskSignatureInvalid, err)
}

sig, err := hex.DecodeString(task.Signature)
if err != nil {
return fmt.Errorf("%w: %v", ErrTaskSignatureInvalid, err)
}

if !ed25519.Verify(pubKey, msg, sig) {
return ErrTaskSignatureInvalid
}

return nil
}

func (c *Client) signTask(task *Task) error {
switch {
case c.opts.privateKey != nil:
return task.sign(c.opts.privateKey)

case c.opts.seedFile != "":
sf, err := os.ReadFile(c.opts.seedFile)
if err != nil {
return err
}

sb, err := hex.DecodeString(string(sf))
if err != nil {
return err
}

if len(sb) != ed25519.SeedSize {
return fmt.Errorf("invalid seed length")
}

pk := ed25519.NewKeyFromSeed(sb)
defer func() {
io.ReadFull(rand.Reader, pk[:])
io.ReadFull(rand.Reader, sb[:])
io.ReadFull(rand.Reader, sf[:])
}()

return task.sign(pk)
}

return nil
}

// StorageAdmin access admin features of the storage backend
func (c *Client) StorageAdmin() StorageAdmin {
return c.storage.(*jetStreamStorage)
Expand Down
Loading

0 comments on commit 8e395d9

Please sign in to comment.