Skip to content

Commit

Permalink
make pkcs11 request timeout configurable (#152)
Browse files Browse the repository at this point in the history
* make pkcs11 request timeout configurable

* address comments
  • Loading branch information
hkadakia authored Mar 4, 2022
1 parent 1724ff3 commit 0ac26c4
Show file tree
Hide file tree
Showing 15 changed files with 87 additions and 77 deletions.
7 changes: 4 additions & 3 deletions api/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ import (
// SigningService implements proto.SigningServer interface.
type SigningService struct {
crypki.CertSign
KeyUsages map[string]map[string]bool
MaxValidity map[string]uint64
RequestChan map[string]chan scheduler.Request
KeyUsages map[string]map[string]bool
MaxValidity map[string]uint64
RequestChan map[string]chan scheduler.Request
RequestTimeout uint
proto.UnimplementedSigningServer
}

Expand Down
12 changes: 7 additions & 5 deletions api/sign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ var (
)

type mockSigningServiceParam struct {
KeyUsages map[string]map[string]bool
MaxValidity map[string]uint64
sendError bool
timeout time.Duration
KeyUsages map[string]map[string]bool
MaxValidity map[string]uint64
RequestTimeout uint
sendError bool
randSleepTimeout time.Duration
}

type mockBadCertSign struct {
Expand Down Expand Up @@ -136,12 +137,13 @@ func initMockSigningService(mssp mockSigningServiceParam) *SigningService {
ss := &SigningService{}
ss.KeyUsages = mssp.KeyUsages
ss.MaxValidity = mssp.MaxValidity
ss.RequestTimeout = mssp.RequestTimeout
if mssp.sendError {
ss.CertSign = &mockBadCertSign{}
} else {
ss.CertSign = &mockGoodCertSign{}
}
time.Sleep(mssp.timeout)
time.Sleep(mssp.randSleepTimeout)
return ss
}

Expand Down
4 changes: 2 additions & 2 deletions api/sshhost.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *SigningService) GetHostSSHCertificateSigningKey(ctx context.Context, ke
}

// create a context with server side timeout
reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout)
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(s.RequestTimeout)*time.Second)
defer cancel() // Cancel ctx as soon as GetHostSSHCertificateSigningKey returns

if !s.KeyUsages[config.SSHHostCertEndpoint][keyMeta.Identifier] {
Expand Down Expand Up @@ -123,7 +123,7 @@ func (s *SigningService) PostHostSSHCertificate(ctx context.Context, request *pr
}

// create a context with server side timeout
reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout)
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(s.RequestTimeout)*time.Second)
defer cancel() // Cancel ctx as soon as PostHostSSHCertificate returns

maxValidity := s.MaxValidity[config.SSHHostCertEndpoint]
Expand Down
10 changes: 5 additions & 5 deletions api/sshhost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ func TestGetHostSSHCertificateSigningKey(t *testing.T) {
t.Run(label, func(t *testing.T) {
t.Parallel()
// bad certsign should return error anyways
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true, timeout: tt.timeout}
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssBad := initMockSigningService(msspBad)
_, err := ssBad.GetHostSSHCertificateSigningKey(tt.ctx, tt.KeyMeta)
if err == nil {
t.Fatalf("in test %v: bad signing service should return error but got nil", label)
}
// good certsign
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false, timeout: tt.timeout}
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssGood := initMockSigningService(msspGood)
key, err := ssGood.GetHostSSHCertificateSigningKey(tt.ctx, tt.KeyMeta)
if err != nil && tt.expectedSSHKey != nil {
Expand Down Expand Up @@ -181,7 +181,7 @@ func TestPostHostSSHCertificate(t *testing.T) {
defer cancel()
timeoutCtx, timeCancel := context.WithTimeout(ctx, timeout)
defer timeCancel()
defaultMaxValidity := map[string]uint64{config.X509CertEndpoint: 0}
defaultMaxValidity := map[string]uint64{config.SSHHostCertEndpoint: 0}
testcases := map[string]struct {
ctx context.Context
KeyUsages map[string]map[string]bool
Expand Down Expand Up @@ -356,7 +356,7 @@ func TestPostHostSSHCertificate(t *testing.T) {
t.Run(label, func(t *testing.T) {
t.Parallel()
// bad certsign should return error anyways
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: true, timeout: tt.timeout}
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: true, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssBad := initMockSigningService(msspBad)
requestBad := &proto.SSHCertificateSigningRequest{KeyMeta: tt.KeyMeta, PublicKey: tt.PubKey, Validity: tt.validity, KeyId: tt.KeyID}
_, err := ssBad.PostHostSSHCertificate(tt.ctx, requestBad)
Expand All @@ -365,7 +365,7 @@ func TestPostHostSSHCertificate(t *testing.T) {
}

// good certsign
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: false, timeout: tt.timeout}
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: false, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssGood := initMockSigningService(msspGood)
requestGood := &proto.SSHCertificateSigningRequest{KeyMeta: tt.KeyMeta, PublicKey: tt.PubKey, Validity: tt.validity, KeyId: tt.KeyID}
cert, err := ssGood.PostHostSSHCertificate(tt.ctx, requestGood)
Expand Down
4 changes: 2 additions & 2 deletions api/sshuser.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *SigningService) GetUserSSHCertificateSigningKey(ctx context.Context, ke
}

// create a context with server side timeout
reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout)
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(s.RequestTimeout)*time.Second)
defer cancel() // Cancel ctx as soon as GetUserSSHCertificateSigningKey returns

if !s.KeyUsages[config.SSHUserCertEndpoint][keyMeta.Identifier] {
Expand Down Expand Up @@ -123,7 +123,7 @@ func (s *SigningService) PostUserSSHCertificate(ctx context.Context, request *pr
}

// create a context with server side timeout
reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout)
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(s.RequestTimeout)*time.Second)
defer cancel() // Cancel ctx as soon as PostUserSSHCertificate returns

maxValidity := s.MaxValidity[config.SSHUserCertEndpoint]
Expand Down
8 changes: 4 additions & 4 deletions api/sshuser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ func TestGetUserSSHCertificateSigningKey(t *testing.T) {
t.Run(label, func(t *testing.T) {
t.Parallel()
// bad certsign should return error anyways
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true, timeout: tt.timeout}
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssBad := initMockSigningService(msspBad)
_, err := ssBad.GetUserSSHCertificateSigningKey(tt.ctx, tt.KeyMeta)
if err == nil {
t.Fatalf("in test %v: bad signing service should return error but got nil", label)
}
// good certsign
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false, timeout: tt.timeout}
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssGood := initMockSigningService(msspGood)
key, err := ssGood.GetUserSSHCertificateSigningKey(tt.ctx, tt.KeyMeta)
if err != nil && tt.expectedSSHKey != nil {
Expand Down Expand Up @@ -347,7 +347,7 @@ func TestPostUserSSHCertificate(t *testing.T) {
t.Run(label, func(t *testing.T) {
t.Parallel()
// bad certsign should return error anyways
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: true, timeout: tt.timeout}
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: true, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssBad := initMockSigningService(msspBad)
requestBad := &proto.SSHCertificateSigningRequest{KeyMeta: tt.KeyMeta, PublicKey: tt.PubKey, Validity: tt.validity, KeyId: tt.KeyID}
_, err := ssBad.PostUserSSHCertificate(tt.ctx, requestBad)
Expand All @@ -356,7 +356,7 @@ func TestPostUserSSHCertificate(t *testing.T) {
}

// good certsign
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: false, timeout: tt.timeout}
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: false, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssGood := initMockSigningService(msspGood)
requestGood := &proto.SSHCertificateSigningRequest{KeyMeta: tt.KeyMeta, PublicKey: tt.PubKey, Validity: tt.validity, KeyId: tt.KeyID}
cert, err := ssGood.PostUserSSHCertificate(tt.ctx, requestGood)
Expand Down
4 changes: 2 additions & 2 deletions api/x509cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (s *SigningService) GetX509CACertificate(ctx context.Context, keyMeta *prot
}

// Create a context with server side timeout.
reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout)
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(s.RequestTimeout)*time.Second)
defer cancel() // Cancel ctx as soon as GetX509CACertificate returns.

if !s.KeyUsages[config.X509CertEndpoint][keyMeta.Identifier] {
Expand Down Expand Up @@ -115,7 +115,7 @@ func (s *SigningService) PostX509Certificate(ctx context.Context, request *proto
}

// Create a context with server side timeout.
reqCtx, cancel := context.WithTimeout(ctx, config.DefaultPKCS11Timeout)
reqCtx, cancel := context.WithTimeout(ctx, time.Duration(s.RequestTimeout)*time.Second)
defer cancel() // Cancel ctx as soon as PostX509Certificate returns

maxValidity := s.MaxValidity[config.X509CertEndpoint]
Expand Down
28 changes: 17 additions & 11 deletions api/x509cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,15 @@ func TestGetX509CACertificate(t *testing.T) {
defer cancel()
timeoutCtx, timeCancel := context.WithTimeout(ctx, timeout)
defer timeCancel()
defaultRequestTimeout := 10
testcases := map[string]struct {
ctx context.Context
KeyUsages map[string]map[string]bool
KeyMeta *proto.KeyMeta
// if expectedCert set to nil, we are expecting an error while testing
expectedCert *proto.X509Certificate
timeout time.Duration
expectedCert *proto.X509Certificate
timeout time.Duration
requestTimeout uint
}{
"emptyKeyUsages": {
KeyMeta: &proto.KeyMeta{Identifier: "randomid"},
Expand Down Expand Up @@ -131,11 +133,12 @@ func TestGetX509CACertificate(t *testing.T) {
timeout: timeout,
},
"requestCancelled": {
ctx: cancelCtx,
KeyUsages: x509keyUsage,
KeyMeta: &proto.KeyMeta{Identifier: "x509id"},
expectedCert: nil,
timeout: timeout,
ctx: cancelCtx,
KeyUsages: x509keyUsage,
KeyMeta: &proto.KeyMeta{Identifier: "x509id"},
expectedCert: nil,
timeout: timeout,
requestTimeout: 1,
},
}
for label, tt := range testcases {
Expand All @@ -146,15 +149,18 @@ func TestGetX509CACertificate(t *testing.T) {
}
t.Run(label, func(t *testing.T) {
t.Parallel()
if tt.requestTimeout == 0 {
tt.requestTimeout = uint(defaultRequestTimeout)
}
// bad certsign should return error anyways
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true, timeout: tt.timeout}
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: true, randSleepTimeout: tt.timeout, RequestTimeout: tt.requestTimeout}
ssBad := initMockSigningService(msspBad)
_, err := ssBad.GetX509CACertificate(tt.ctx, tt.KeyMeta)
if err == nil {
t.Fatalf("in test %v: bad signing service should return error but got nil", label)
}
// good certsign
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false, timeout: tt.timeout}
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, sendError: false, randSleepTimeout: tt.timeout, RequestTimeout: tt.requestTimeout}
ssGood := initMockSigningService(msspGood)
cert, err := ssGood.GetX509CACertificate(tt.ctx, tt.KeyMeta)
if err != nil && tt.expectedCert != nil {
Expand Down Expand Up @@ -323,15 +329,15 @@ func TestPostX509Certificate(t *testing.T) {
t.Run(label, func(t *testing.T) {
t.Parallel()
// bad certsign should return error anyways
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: true, timeout: tt.timeout}
msspBad := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: true, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssBad := initMockSigningService(msspBad)
requestBad := &proto.X509CertificateSigningRequest{KeyMeta: tt.KeyMeta, Csr: tt.CSR, Validity: tt.validity}
if _, err := ssBad.PostX509Certificate(tt.ctx, requestBad); err == nil {
t.Fatalf("expected error for invalid test %v, got nil", label)
}

// good certsign
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: false, timeout: tt.timeout}
msspGood := mockSigningServiceParam{KeyUsages: tt.KeyUsages, MaxValidity: tt.maxValidity, sendError: false, randSleepTimeout: tt.timeout, RequestTimeout: config.DefaultPKCS11Timeout}
ssGood := initMockSigningService(msspGood)
requestGood := &proto.X509CertificateSigningRequest{KeyMeta: tt.KeyMeta, Csr: tt.CSR, Validity: tt.validity}
cert, err := ssGood.PostX509Certificate(tt.ctx, requestGood)
Expand Down
2 changes: 1 addition & 1 deletion cmd/gen-cacert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func main() {
OrganizationalUnit: cc.OrganizationalUnit,
CommonName: cc.CommonName,
ValidityPeriod: cc.ValidityPeriod,
}}, requireX509CACert, hostname, ips)
}}, requireX509CACert, hostname, ips, config.DefaultPKCS11Timeout)
if err != nil {
log.Fatalf("unable to initialize cert signer: %v", err)
}
Expand Down
17 changes: 11 additions & 6 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"fmt"
"os"
"strings"
"time"
)

const (
Expand All @@ -29,9 +28,10 @@ const (
defaultShutdownOnSigningFailureTimerDurationSecond = 60
defaultShutdownOnSigningFailureTimerCount = 10

defaultIdleTimeout = 30
defaultReadTimeout = 10
defaultWriteTimeout = 10
defaultIdleTimeout = 30
defaultReadTimeout = 10
defaultWriteTimeout = 10
DefaultPKCS11Timeout = 10

// X509CertEndpoint specifies the endpoint for signing X509 certificate.
X509CertEndpoint = "/sig/x509-cert"
Expand All @@ -41,8 +41,6 @@ const (
SSHHostCertEndpoint = "/sig/ssh-host-cert"
// BlobEndpoint specifies the endpoint for raw signing.
BlobEndpoint = "/sig/blob"
// DefaultPKCS11Timeout specifies the max time required by HSM to sign a cert.
DefaultPKCS11Timeout = 10 * time.Second
)

var endpoints = map[string]bool{
Expand Down Expand Up @@ -146,6 +144,10 @@ type Config struct {
IdleTimeout uint
ReadTimeout uint
WriteTimeout uint

// PKCS11RequestTimeout indicates the max time an HSM can take to process a signing request for a
// certificate in seconds.
PKCS11RequestTimeout uint `json:"requestTimeout"`
}

// Parse loads configuration values from input file and returns config object and CA cert.
Expand Down Expand Up @@ -279,4 +281,7 @@ func (c *Config) loadDefaults() {
if c.WriteTimeout == 0 {
c.WriteTimeout = defaultWriteTimeout
}
if c.PKCS11RequestTimeout == 0 {
c.PKCS11RequestTimeout = DefaultPKCS11Timeout
}
}
7 changes: 4 additions & 3 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ func TestParse(t *testing.T) {
TimerDurationSecond: 120,
TimerCountLimit: 20,
},
IdleTimeout: 30,
ReadTimeout: 10,
WriteTimeout: 10,
IdleTimeout: 30,
ReadTimeout: 10,
WriteTimeout: 10,
PKCS11RequestTimeout: 15,
}
testcases := map[string]struct {
filePath string
Expand Down
3 changes: 2 additions & 1 deletion config/testdata/testconf-good.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
},
"IdleTimeout": 30,
"ReadTimeout": 10,
"WriteTimeout": 10
"WriteTimeout": 10,
"RequestTimeout": 15
}
Loading

0 comments on commit 0ac26c4

Please sign in to comment.