diff --git a/api/x509cert.go b/api/x509cert.go index 3f879910..8bfd0fec 100644 --- a/api/x509cert.go +++ b/api/x509cert.go @@ -56,18 +56,43 @@ func (s *SigningService) GetX509CACertificate(ctx context.Context, keyMeta *prot return nil, status.Errorf(codes.InvalidArgument, "Bad request: %v", err) } + // Create a context with server side timeout. + reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout) + defer cancel() // Cancel ctx as soon as GetX509CACertificate returns. + if !s.KeyUsages[config.X509CertEndpoint][keyMeta.Identifier] { statusCode = http.StatusBadRequest err = fmt.Errorf("cannot use key %q for %q", keyMeta.Identifier, config.X509CertEndpoint) return nil, status.Errorf(codes.InvalidArgument, "Bad request: %v", err) } - cert, err := s.GetX509CACert(ctx, keyMeta.Identifier) - if err != nil { - statusCode = http.StatusInternalServerError - return nil, status.Error(codes.Internal, "Internal server error") + type resp struct { + cert []byte + err error + } + respCh := make(chan resp) + go func() { + cert, err := s.GetX509CACert(ctx, keyMeta.Identifier) + respCh <- resp{cert, err} + }() + + select { + case <-ctx.Done(): + statusCode = http.StatusBadRequest + err = fmt.Errorf("client canceled request for %q", config.SSHHostCertEndpoint) + return nil, status.Errorf(codes.Canceled, "%v", err) + case <-reqCtx.Done(): + // Handle the server timeout requests. + statusCode = http.StatusServiceUnavailable + err = fmt.Errorf("request timed out for %q", config.SSHHostCertEndpoint) + return nil, status.Errorf(codes.DeadlineExceeded, "%v", err) + case response := <-respCh: + if response.err != nil { + statusCode = http.StatusInternalServerError + return nil, status.Error(codes.Internal, "Internal server error") + } + return &proto.X509Certificate{Cert: string(response.cert)}, nil } - return &proto.X509Certificate{Cert: string(cert)}, nil } // PostX509Certificate signs the given CSR using the specified key and returns a PEM encoded X509 certificate. @@ -89,7 +114,7 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto return nil, status.Errorf(codes.InvalidArgument, "Bad request: %v", err) } - // create a context with server side timeout + // Create a context with server side timeout. reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout) defer cancel() // Cancel ctx as soon as PostX509Certificate returns @@ -113,24 +138,22 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto } type resp struct { - data []byte + cert []byte err error } respCh := make(chan resp) go func() { - data, err := s.SignX509Cert(reqCtx, req, request.KeyMeta.Identifier) - respCh <- resp{data, err} + cert, err := s.SignX509Cert(reqCtx, req, request.KeyMeta.Identifier) + respCh <- resp{cert, err} }() select { case <-ctx.Done(): - // client canceled the request. Cancel any pending server request and return - cancel() statusCode = http.StatusBadRequest err = fmt.Errorf("client canceled request for %q", config.X509CertEndpoint) return nil, status.Errorf(codes.Canceled, "%v", err) case <-reqCtx.Done(): - // server request timed out. + // Handle the server timeout requests. statusCode = http.StatusServiceUnavailable err = fmt.Errorf("request timed out for %q", config.X509CertEndpoint) return nil, status.Errorf(codes.DeadlineExceeded, "%v", err) @@ -139,6 +162,6 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto statusCode = http.StatusInternalServerError return nil, status.Error(codes.Internal, "Internal server error") } - return &proto.X509Certificate{Cert: string(response.data)}, nil + return &proto.X509Certificate{Cert: string(response.cert)}, nil } } diff --git a/api/x509cert_test.go b/api/x509cert_test.go index c12e1f65..fbf7b1a8 100644 --- a/api/x509cert_test.go +++ b/api/x509cert_test.go @@ -78,11 +78,18 @@ func TestGetX509CertificateAvailableSigningKeys(t *testing.T) { func TestGetX509CACertificate(t *testing.T) { t.Parallel() + ctx := context.Background() + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + timeoutCtx, timeCancel := context.WithTimeout(ctx, timeout) + defer timeCancel() testcases := map[string]struct { + ctx context.Context KeyUsages map[string]map[string]bool KeyMeta *proto.KeyMeta // if expectedCert set to nil, we are expecting an error while testing expectedCert *proto.X509Certificate + timeout time.Duration }{ "emptyKeyUsages": { KeyMeta: &proto.KeyMeta{Identifier: "randomid"}, @@ -116,24 +123,40 @@ func TestGetX509CACertificate(t *testing.T) { KeyMeta: &proto.KeyMeta{Identifier: "x509id2"}, expectedCert: nil, }, + "requestTimeout": { + ctx: timeoutCtx, + KeyUsages: x509keyUsage, + KeyMeta: &proto.KeyMeta{Identifier: "x509id"}, + expectedCert: nil, + timeout: timeout, + }, + "requestCancelled": { + ctx: cancelCtx, + KeyUsages: x509keyUsage, + KeyMeta: &proto.KeyMeta{Identifier: "x509id"}, + expectedCert: nil, + timeout: timeout, + }, } for label, tt := range testcases { tt := tt label := label + if tt.ctx == nil { + tt.ctx = ctx + } t.Run(label, func(t *testing.T) { t.Parallel() - var ctx context.Context // bad certsign should return error anyways - msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true} + msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true, timeout: tt.timeout} ssBad := initMockSigningService(msspBad) - _, err := ssBad.GetX509CACertificate(ctx, tt.KeyMeta) + _, err := ssBad.GetX509CACertificate(tt.ctx, tt.KeyMeta) if err == nil { t.Fatalf("in test %v: bad signing service should return error but got nil", label) } // good certsign - msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false} + msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false, timeout: tt.timeout} ssGood := initMockSigningService(msspGood) - cert, err := ssGood.GetX509CACertificate(ctx, tt.KeyMeta) + cert, err := ssGood.GetX509CACertificate(tt.ctx, tt.KeyMeta) if err != nil && tt.expectedCert != nil { t.Fatalf("in test %v: not expecting error but got error %v", label, err) }