Skip to content

Commit

Permalink
resolve potential deadlock (#154)
Browse files Browse the repository at this point in the history
* resolve potential deadlock

* cancel signing request if client cancelled

* cancel work job when client cancels request

* address comments

* address comments
  • Loading branch information
hkadakia authored Mar 9, 2022
1 parent e7fbcd2 commit 9367c75
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 54 deletions.
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (
defaultIdleTimeout = 30
defaultReadTimeout = 10
defaultWriteTimeout = 10
DefaultPKCS11Timeout = 10
DefaultPKCS11Timeout = 10 // in seconds

// X509CertEndpoint specifies the endpoint for signing X509 certificate.
X509CertEndpoint = "/sig/x509-cert"
Expand Down
41 changes: 12 additions & 29 deletions pkcs11/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@ import (
It also has a channel on which it waits for the response.
*/
type Request struct {
pool sPool // pool is a signer pool per identifier from which to fetch the signer
identifier string // identifier indicates the endpoint for which we are fetching the signer in order to sign it
remainingTime time.Duration // remainingTime indicates the time remaining before either the client cancels or the request times out.
respChan chan signerWithSignAlgorithm // respChan is the channel where the worker sends the signer once it gets it from the pool
pool sPool // pool is a signer pool per identifier from which to fetch the signer
identifier string // identifier indicates the endpoint for which we are fetching the signer in order to sign it
respChan chan signerWithSignAlgorithm // respChan is the channel where the worker sends the signer once it gets it from the pool
}

// signer implements crypki.CertSign interface.
Expand All @@ -67,36 +66,19 @@ type signer struct {
requestTimeout uint
}

// getRemainingRequestTime reports how long the request identified by
// keyIdentifier may still run.
//
// When ctx carries no deadline, the budget is requestTimeout interpreted as a
// number of seconds. When ctx has a deadline, the budget is the time left
// until that deadline; if the deadline has already passed, a zero duration and
// an error naming the key identifier are returned so the caller can stop early.
func getRemainingRequestTime(ctx context.Context, keyIdentifier string, requestTimeout uint) (time.Duration, error) {
	deadline, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		// No deadline on the context: fall back to the configured timeout.
		return time.Duration(requestTimeout) * time.Second, nil
	}

	left := time.Until(deadline)
	if left > 0 {
		return left, nil
	}

	// The context has already expired; there is no point in doing more work.
	return 0, fmt.Errorf("context deadline expired for key identifier %q", keyIdentifier)
}

func getSigner(ctx context.Context, requestChan chan scheduler.Request, pool sPool, keyIdentifier string, priority proto.Priority, requestTimeout uint) (signer signerWithSignAlgorithm, err error) {
func getSigner(ctx context.Context, requestChan chan scheduler.Request, pool sPool, keyIdentifier string, priority proto.Priority) (signer signerWithSignAlgorithm, err error) {
// Need to handle case when we directly invoke SignSSHCert or SignX509Cert for
// either generating the host certs or X509 CA certs. In that case we don't need the server
// running nor do we need to worry about priority scheduling. In that case, we immediately
// fetch the signer from the pool.
if requestChan == nil {
return pool.get(ctx)
}
remTime, err := getRemainingRequestTime(ctx, keyIdentifier, requestTimeout)
if err != nil {
return nil, err
}
respChan := make(chan signerWithSignAlgorithm)
req := &Request{
pool: pool,
identifier: keyIdentifier,
remainingTime: remTime,
respChan: respChan,
pool: pool,
identifier: keyIdentifier,
respChan: respChan,
}
if priority == proto.Priority_Unspecified_priority {
// If priority is unspecified, treat the request as high priority.
Expand All @@ -116,7 +98,8 @@ func getSigner(ctx context.Context, requestChan chan scheduler.Request, pool sPo
return nil, errors.New("client request timed out, skip signing cert request")
}
case <-ctx.Done():
// In order to ensure we don't keep on blocking on the response, we add this check.
// In order to ensure we don't keep on blocking on the response, we close the response channel for this request & return.
close(respChan)
return nil, ctx.Err()
}
return signer, nil
Expand Down Expand Up @@ -210,7 +193,7 @@ func (s *signer) SignSSHCert(ctx context.Context, reqChan chan scheduler.Request
return nil, fmt.Errorf("unknown key identifier %q", keyIdentifier)
}
pStart := time.Now()
signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority, s.requestTimeout)
signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority)
if err != nil {
pt = time.Since(pStart).Nanoseconds() / time.Microsecond.Nanoseconds()
return nil, err
Expand Down Expand Up @@ -258,7 +241,7 @@ func (s *signer) SignX509Cert(ctx context.Context, reqChan chan scheduler.Reques
return nil, fmt.Errorf("unknown key identifier %q", keyIdentifier)
}
pStart := time.Now()
signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority, s.requestTimeout)
signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority)
if err != nil {
pt = time.Since(pStart).Nanoseconds() / time.Microsecond.Nanoseconds()
return nil, err
Expand Down Expand Up @@ -323,7 +306,7 @@ func (s *signer) SignBlob(ctx context.Context, reqChan chan scheduler.Request, d
return nil, fmt.Errorf("unknown key identifier %q", keyIdentifier)
}
pStart := time.Now()
signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority, s.requestTimeout)
signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority)
if err != nil {
pt = time.Since(pStart).Nanoseconds() / time.Microsecond.Nanoseconds()
return nil, err
Expand Down
8 changes: 5 additions & 3 deletions pkcs11/signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func dummyScheduler(ctx context.Context, reqChan chan scheduler.Request) {
req := <-reqChan
go func() {
// create worker with different priorities
worker := &scheduler.Worker{ID: 1, Priority: req.Priority, Quit: make(chan struct{})}
worker := &scheduler.Worker{ID: 1, Priority: req.Priority, Quit: make(chan struct{}), HSMTimeout: 1 * time.Second}
req.DoWorker.DoWork(ctx, worker)
}()
}
Expand Down Expand Up @@ -336,7 +336,8 @@ func TestSignX509RSACert(t *testing.T) {
cp := x509.NewCertPool()
cp.AddCert(caCert)

ctx := context.Background()
ctx, cnc := context.WithTimeout(context.Background(), 1*time.Second)
defer cnc()
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down Expand Up @@ -424,7 +425,8 @@ func TestSignX509ECCert(t *testing.T) {
cp := x509.NewCertPool()
cp.AddCert(caCert)

ctx := context.Background()
ctx, cnc := context.WithTimeout(context.Background(), 1*time.Second)
defer cnc()
reqChan := make(chan scheduler.Request)
testcases := map[string]struct {
ctx context.Context
Expand Down
74 changes: 56 additions & 18 deletions pkcs11/work.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,68 @@ type Work struct {
work *Request // workChan is a channel which has a request enqueue for the worker to work on.
}

//DoWork performs the work of fetching the signer from the pool and sending it back on the response channel
// DoWork performs the work of fetching the signer from the pool and sending it back on the response channel.
// If the client cancels the request or times out, the worker should not wait indefinitely for getting the signer
// from the pool. We also have a PKCS11 timeout, which is the maximum duration for which the worker waits to fetch the
// signer from the pool; the worker cancels the client request if that duration is exceeded.
func (w *Work) DoWork(workerCtx context.Context, worker *scheduler.Worker) {
select {
case <-workerCtx.Done():
log.Printf("%s: worker stopped", worker.String())
return
default:
reqCtx, cancel := context.WithTimeout(context.Background(), w.work.remainingTime)
defer cancel()
reqCtx, cancel := context.WithTimeout(context.Background(), worker.HSMTimeout)
type resp struct {
signer signerWithSignAlgorithm
err error
}

signerRespCh := make(chan resp)
go func(ctx context.Context) {
signer, err := w.work.pool.get(reqCtx)
if err != nil {
worker.TotalTimeout.Inc()
log.Printf("%s: error fetching signer %v", worker.String(), err)
w.work.respChan <- nil
select {
case signerRespCh <- resp{signer, err}:
case <-ctx.Done():
return
}
}(workerCtx)

for {
select {
case <-reqCtx.Done():
// request timed out, increment the timeout counter & return nil.
worker.TotalTimeout.Inc()
w.work.respChan <- nil
default:
case <-workerCtx.Done():
// Case 1: Worker stopped either due to context cancelled or worker timed out.
// This case is to avoid worker being stuck in a blocking call or a deadlock scenario.
log.Printf("%s: worker stopped", worker.String())
cancel()
w.sendResponse(nil)
return
case resp := <-signerRespCh:
// Case 2: Received response. It could either be a pkcs11 timeout or the worker was able to get a signer
// from the signer pool.
if resp.signer == nil || resp.err != nil {
worker.TotalTimeout.Inc()
log.Printf("%s: error fetching signer %v", worker.String(), resp.err)
w.work.respChan <- nil
cancel()
return
}
worker.TotalProcessed.Inc()
w.work.respChan <- signer
w.sendResponse(resp.signer)
cancel()
return
case _, ok := <-w.work.respChan:
// Case 3: Check for current state of respChan. If the client request is cancelled, the client
// will close the respChan. In that case, we should cancel reqCtx & return to avoid extra processing.
if !ok {
log.Printf("%s: worker request timed out, client cancelled request", worker.String())
cancel()
worker.TotalTimeout.Inc()
return
}
}
}
}

// sendResponse delivers resp on the request's respChan unless the client has
// already closed that channel.
//
// The client side closes respChan when its context is cancelled, so the
// channel may be closed at any moment, including after the non-blocking check
// below has passed or while the send is blocked waiting for a receiver. A send
// on a closed channel panics, which would otherwise take the worker down. That
// specific panic is recovered and treated as "client went away"; any other
// panic is re-raised unchanged.
func (w *Work) sendResponse(resp signerWithSignAlgorithm) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// Only swallow the runtime panic caused by the client closing
		// respChan concurrently; anything else is a genuine bug.
		if err, ok := r.(error); ok && err.Error() == "send on closed channel" {
			return
		}
		panic(r)
	}()

	select {
	case <-w.work.respChan:
		// A receive succeeded without blocking; nothing is sent.
		// NOTE(review): presumably this only fires once the client has closed the channel — confirm nothing else sends on respChan.
	default:
		w.work.respChan <- resp
	}
}
10 changes: 7 additions & 3 deletions server/scheduler/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import (
"context"
"fmt"
"log"
"time"

"github.com/theparanoids/crypki/config"
"github.com/theparanoids/crypki/proto"
)

Expand All @@ -36,6 +38,7 @@ type Request struct {
type Worker struct {
ID int // ID is a unique id for the worker
Priority proto.Priority // Priority indicates the priority of the request the worker is handling.
HSMTimeout time.Duration // HSMTimeout is the max time a worker can wait to get signer from pool.
TotalProcessed Counter // TotalProcessed indicates the total requests processed per priority by this worker.
TotalTimeout Counter // TotalTimeout indicates the total requests that timed out before worker could process it.
Quit chan struct{} // Quit is a channel to cancel the worker
Expand All @@ -56,9 +59,10 @@ func (w *Worker) String() string {
// that the worker can add itself to when it is idle. It also creates a slice for storing totalProcessed requests.
func newWorker(workerId int, workerPriority proto.Priority) *Worker {
return &Worker{
ID: workerId,
Priority: workerPriority,
Quit: make(chan struct{}),
ID: workerId,
Priority: workerPriority,
HSMTimeout: config.DefaultPKCS11Timeout * time.Second,
Quit: make(chan struct{}),
}
}

Expand Down

0 comments on commit 9367c75

Please sign in to comment.