diff --git a/config/config.go b/config/config.go index aba8fab4..32b23507 100644 --- a/config/config.go +++ b/config/config.go @@ -31,7 +31,7 @@ const ( defaultIdleTimeout = 30 defaultReadTimeout = 10 defaultWriteTimeout = 10 - DefaultPKCS11Timeout = 10 + DefaultPKCS11Timeout = 10 // in seconds // X509CertEndpoint specifies the endpoint for signing X509 certificate. X509CertEndpoint = "/sig/x509-cert" diff --git a/pkcs11/signer.go b/pkcs11/signer.go index 2a893f8d..e0483025 100644 --- a/pkcs11/signer.go +++ b/pkcs11/signer.go @@ -42,10 +42,9 @@ import ( It also has a channel on which it waits for the response. */ type Request struct { - pool sPool // pool is a signer pool per identifier from which to fetch the signer - identifier string // identifier indicates the endpoint for which we are fetching the signer in order to sign it - remainingTime time.Duration // remainingTime indicates the time remaining before either the client cancels or the request times out. - respChan chan signerWithSignAlgorithm // respChan is the channel where the worker sends the signer once it gets it from the pool + pool sPool // pool is a signer pool per identifier from which to fetch the signer + identifier string // identifier indicates the endpoint for which we are fetching the signer in order to sign it + respChan chan signerWithSignAlgorithm // respChan is the channel where the worker sends the signer once it gets it from the pool } // signer implements crypki.CertSign interface. @@ -67,19 +66,7 @@ type signer struct { requestTimeout uint } -func getRemainingRequestTime(ctx context.Context, keyIdentifier string, requestTimeout uint) (time.Duration, error) { - remTime := time.Duration(requestTimeout) * time.Second - if deadline, ok := ctx.Deadline(); ok { - remTime = time.Until(deadline) - if remTime <= 0 { - // context expired, we should stop processing and return immediately - return 0, fmt.Errorf("context deadline expired for key identifier %q", keyIdentifier) - } - } - return remTime, nil -} - -func getSigner(ctx context.Context, requestChan chan scheduler.Request, pool sPool, keyIdentifier string, priority proto.Priority, requestTimeout uint) (signer signerWithSignAlgorithm, err error) { +func getSigner(ctx context.Context, requestChan chan scheduler.Request, pool sPool, keyIdentifier string, priority proto.Priority) (signer signerWithSignAlgorithm, err error) { // Need to handle case when we directly invoke SignSSHCert or SignX509Cert for // either generating the host certs or X509 CA certs. In that case we don't need the server // running nor do we need to worry about priority scheduling. In that case, we immediately @@ -87,16 +74,11 @@ func getSigner(ctx context.Context, requestChan chan scheduler.Request, pool sPo if requestChan == nil { return pool.get(ctx) } - remTime, err := getRemainingRequestTime(ctx, keyIdentifier, requestTimeout) - if err != nil { - return nil, err - } respChan := make(chan signerWithSignAlgorithm) req := &Request{ - pool: pool, - identifier: keyIdentifier, - remainingTime: remTime, - respChan: respChan, + pool: pool, + identifier: keyIdentifier, + respChan: respChan, } if priority == proto.Priority_Unspecified_priority { // If priority is unspecified, treat the request as high priority. @@ -116,7 +98,8 @@ func getSigner(ctx context.Context, requestChan chan scheduler.Request, pool sPo return nil, errors.New("client request timed out, skip signing cert request") } case <-ctx.Done(): - // In order to ensure we don't keep on blocking on the response, we add this check. + // In order to ensure we don't keep on blocking on the response, we close the response channel for this request & return. + close(respChan) return nil, ctx.Err() } return signer, nil @@ -210,7 +193,7 @@ func (s *signer) SignSSHCert(ctx context.Context, reqChan chan scheduler.Request return nil, fmt.Errorf("unknown key identifier %q", keyIdentifier) } pStart := time.Now() - signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority, s.requestTimeout) + signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority) if err != nil { pt = time.Since(pStart).Nanoseconds() / time.Microsecond.Nanoseconds() return nil, err @@ -258,7 +241,7 @@ func (s *signer) SignX509Cert(ctx context.Context, reqChan chan scheduler.Reques return nil, fmt.Errorf("unknown key identifier %q", keyIdentifier) } pStart := time.Now() - signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority, s.requestTimeout) + signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority) if err != nil { pt = time.Since(pStart).Nanoseconds() / time.Microsecond.Nanoseconds() return nil, err @@ -323,7 +306,7 @@ func (s *signer) SignBlob(ctx context.Context, reqChan chan scheduler.Request, d return nil, fmt.Errorf("unknown key identifier %q", keyIdentifier) } pStart := time.Now() - signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority, s.requestTimeout) + signer, err := getSigner(ctx, reqChan, pool, keyIdentifier, priority) if err != nil { pt = time.Since(pStart).Nanoseconds() / time.Microsecond.Nanoseconds() return nil, err diff --git a/pkcs11/signer_test.go b/pkcs11/signer_test.go index b474bb7b..9eb31a2b 100644 --- a/pkcs11/signer_test.go +++ b/pkcs11/signer_test.go @@ -113,7 +113,7 @@ func dummyScheduler(ctx context.Context, reqChan chan scheduler.Request) { req := <-reqChan go func() { // create worker with different priorities - worker := &scheduler.Worker{ID: 1, Priority: req.Priority, Quit: make(chan struct{})} + worker := &scheduler.Worker{ID: 1, Priority: req.Priority, Quit: make(chan struct{}), HSMTimeout: 1 * time.Second} req.DoWorker.DoWork(ctx, worker) }() } @@ -336,7 +336,8 @@ func TestSignX509RSACert(t *testing.T) { cp := x509.NewCertPool() cp.AddCert(caCert) - ctx := context.Background() + ctx, cnc := context.WithTimeout(context.Background(), 1*time.Second) + defer cnc() cancelCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -424,7 +425,8 @@ func TestSignX509ECCert(t *testing.T) { cp := x509.NewCertPool() cp.AddCert(caCert) - ctx := context.Background() + ctx, cnc := context.WithTimeout(context.Background(), 1*time.Second) + defer cnc() reqChan := make(chan scheduler.Request) testcases := map[string]struct { ctx context.Context diff --git a/pkcs11/work.go b/pkcs11/work.go index e64a212c..26e5a48c 100644 --- a/pkcs11/work.go +++ b/pkcs11/work.go @@ -26,30 +26,68 @@ type Work struct { work *Request // workChan is a channel which has a request enqueue for the worker to work on. } -//DoWork performs the work of fetching the signer from the pool and sending it back on the response channel +// DoWork performs the work of fetching the signer from the pool and sending it back on the response channel. +// If the client cancels the request or times out, the worker should not wait indefinitely for getting the signer +// from the pool.We also have a PKCS11 timeout which is the maximum duration for which worker waits to fetch the +// signer from pool & cancel the client request if it exceeds that. func (w *Work) DoWork(workerCtx context.Context, worker *scheduler.Worker) { - select { - case <-workerCtx.Done(): - log.Printf("%s: worker stopped", worker.String()) - return - default: - reqCtx, cancel := context.WithTimeout(context.Background(), w.work.remainingTime) - defer cancel() + reqCtx, cancel := context.WithTimeout(context.Background(), worker.HSMTimeout) + type resp struct { + signer signerWithSignAlgorithm + err error + } + + signerRespCh := make(chan resp) + go func(ctx context.Context) { signer, err := w.work.pool.get(reqCtx) - if err != nil { - worker.TotalTimeout.Inc() - log.Printf("%s: error fetching signer %v", worker.String(), err) - w.work.respChan <- nil + select { + case signerRespCh <- resp{signer, err}: + case <-ctx.Done(): return } + }(workerCtx) + + for { select { - case <-reqCtx.Done(): - // request timed out, increment timeout context & return nil. - worker.TotalTimeout.Inc() - w.work.respChan <- nil - default: + case <-workerCtx.Done(): + // Case 1: Worker stopped either due to context cancelled or worker timed out. + // This case is to avoid worker being stuck in a blocking call or a deadlock scenario. + log.Printf("%s: worker stopped", worker.String()) + cancel() + w.sendResponse(nil) + return + case resp := <-signerRespCh: + // Case 2: Received response. It could either be a pkcs11 timeout or thr worker was able to get a signer + // from the signer pool. + if resp.signer == nil || resp.err != nil { + worker.TotalTimeout.Inc() + log.Printf("%s: error fetching signer %v", worker.String(), resp.err) + w.work.respChan <- nil + cancel() + return + } worker.TotalProcessed.Inc() - w.work.respChan <- signer + w.sendResponse(resp.signer) + cancel() + return + case _, ok := <-w.work.respChan: + // Case 3: Check for current state of respChan. If the client request is cancelled, the client + // will close the respChan. In that case, we should cancel reqCtx & return to avoid extra processing. + if !ok { + log.Printf("%s: worker request timed out, client cancelled request", worker.String()) + cancel() + worker.TotalTimeout.Inc() + return + } } } } + +// sendResponse sends the response on the respChan if the channel is not yet closed by the client. +func (w *Work) sendResponse(resp signerWithSignAlgorithm) { + select { + case <-w.work.respChan: + default: + w.work.respChan <- resp + } +} diff --git a/server/scheduler/worker.go b/server/scheduler/worker.go index 24a1086a..6dc52923 100644 --- a/server/scheduler/worker.go +++ b/server/scheduler/worker.go @@ -17,7 +17,9 @@ import ( "context" "fmt" "log" + "time" + "github.com/theparanoids/crypki/config" "github.com/theparanoids/crypki/proto" ) @@ -36,6 +38,7 @@ type Request struct { type Worker struct { ID int // ID is a unique id for the worker Priority proto.Priority // Priority indicates the priority of the request the worker is handling. + HSMTimeout time.Duration // HSMTimeout is the max time a worker can wait to get signer from pool. TotalProcessed Counter // TotalProcessed indicates the total requests processed per priority by this worker. TotalTimeout Counter // TotalTimeout indicates the total requests that timed out before worker could process it. Quit chan struct{} // Quit is a channel to cancel the worker @@ -56,9 +59,10 @@ func (w *Worker) String() string { // that the worker can add itself to when it is idle. It also creates a slice for storing totalProcessed requests. func newWorker(workerId int, workerPriority proto.Priority) *Worker { return &Worker{ - ID: workerId, - Priority: workerPriority, - Quit: make(chan struct{}), + ID: workerId, + Priority: workerPriority, + HSMTimeout: config.DefaultPKCS11Timeout * time.Second, + Quit: make(chan struct{}), } }