diff --git a/crypki.go b/crypki.go index ec66ec82..31a97101 100644 --- a/crypki.go +++ b/crypki.go @@ -46,6 +46,8 @@ const ( SHA256WithRSA ECDSAWithSHA256 ECDSAWithSHA384 + SHA512WithRSA + SHAWithRSA // for backward compatibility ) const ( diff --git a/pkcs11/algosigner.go b/pkcs11/algosigner.go new file mode 100644 index 00000000..3ef54078 --- /dev/null +++ b/pkcs11/algosigner.go @@ -0,0 +1,85 @@ +// Copyright 2021 Yahoo. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package pkcs11 + +import ( + "crypto" + "errors" + "io" + + "golang.org/x/crypto/ssh" + + "github.com/theparanoids/crypki" +) + +type sshAlgorithmSigner struct { + algorithm string + signer ssh.AlgorithmSigner +} + +func (s *sshAlgorithmSigner) PublicKey() ssh.PublicKey { + return s.signer.PublicKey() +} + +func (s *sshAlgorithmSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { + return s.signer.SignWithAlgorithm(rand, data, s.algorithm) +} + +func getSignatureAlgorithm(publicAlgo crypki.PublicKeyAlgorithm, signAlgo crypki.SignatureAlgorithm) (algorithm string, err error) { + switch publicAlgo { + case crypki.RSA: + { + switch signAlgo { + case crypki.ECDSAWithSHA256, crypki.ECDSAWithSHA384: + err = errors.New("public key algo & signature algo mismatch, unable to get AlgorithmSigner") + case crypki.SHAWithRSA: + algorithm = ssh.SigAlgoRSA + case crypki.SHA512WithRSA: + algorithm = ssh.SigAlgoRSASHA2512 + case crypki.SHA256WithRSA: + algorithm = ssh.SigAlgoRSASHA2256 + default: + algorithm = ssh.SigAlgoRSASHA2256 + } + } + case crypki.ECDSA: + // For ECDSA public algorithm, signature algo does not exist. We pass in + // empty algorithm & the crypto library will ensure the right algorithm is chosen + // for signing the cert. + return + default: + err = errors.New("public key algorithm not supported") + } + return +} + +func newAlgorithmSignerFromSigner(signer crypto.Signer, publicAlgo crypki.PublicKeyAlgorithm, signAlgo crypki.SignatureAlgorithm) (ssh.Signer, error) { + sshSigner, err := ssh.NewSignerFromSigner(signer) + if err != nil { + return nil, err + } + algorithmSigner, ok := sshSigner.(ssh.AlgorithmSigner) + if !ok { + return nil, errors.New("unable to cast to ssh.AlgorithmSigner") + } + algorithm, err := getSignatureAlgorithm(publicAlgo, signAlgo) + if err != nil { + return nil, err + } + s := sshAlgorithmSigner{ + signer: algorithmSigner, + algorithm: algorithm, + } + return &s, nil +} diff --git a/pkcs11/algosigner_test.go b/pkcs11/algosigner_test.go new file mode 100644 index 00000000..f0ef7e22 --- /dev/null +++ b/pkcs11/algosigner_test.go @@ -0,0 +1,87 @@ +// Copyright 2021 Yahoo. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package pkcs11 + +import ( + "testing" + + "golang.org/x/crypto/ssh" + + "github.com/theparanoids/crypki" +) + +func TestGetSignatureAlgorithm(t *testing.T) { + t.Parallel() + tests := map[string]struct { + pubAlgo crypki.PublicKeyAlgorithm + signAlgo crypki.SignatureAlgorithm + want string + wantError bool + }{ + "rsa pub rsa 256 signing": { + pubAlgo: crypki.RSA, + signAlgo: crypki.SHA256WithRSA, + want: ssh.SigAlgoRSASHA2256, + wantError: false, + }, + "rsa pub rsa 512 signing": { + pubAlgo: crypki.RSA, + signAlgo: crypki.SHA512WithRSA, + want: ssh.SigAlgoRSASHA2512, + wantError: false, + }, + "rsa pub sha1 signing": { + pubAlgo: crypki.RSA, + signAlgo: crypki.SHAWithRSA, + want: ssh.SigAlgoRSA, + wantError: false, + }, + "rsa pub ec signing": { + pubAlgo: crypki.RSA, + signAlgo: crypki.ECDSAWithSHA384, + want: "", + wantError: true, + }, + "rsa pub no signing algo": { + pubAlgo: crypki.RSA, + signAlgo: crypki.UnknownSignatureAlgorithm, + want: ssh.SigAlgoRSASHA2256, + wantError: false, + }, + "ec pub ec sign": { + pubAlgo: crypki.ECDSA, + signAlgo: crypki.ECDSAWithSHA384, + want: "", + wantError: false, + }, + "default pub key algo": { + pubAlgo: crypki.UnknownPublicKeyAlgorithm, + signAlgo: crypki.UnknownSignatureAlgorithm, + want: "", + wantError: true, + }, + } + for name, tt := range tests { + tt := tt + t.Run(name, func(t *testing.T) { + got, err := getSignatureAlgorithm(tt.pubAlgo, tt.signAlgo) + if (err != nil) != tt.wantError { + t.Errorf("%s: got %s want %s", name, got, tt.want) + } + if got != tt.want { + t.Errorf("%s: got %s want %s", name, got, tt.want) + } + }) + } +} diff --git a/pkcs11/signer.go b/pkcs11/signer.go index c51420d1..9cacc46e 100644 --- a/pkcs11/signer.go +++ b/pkcs11/signer.go @@ -209,7 +209,7 @@ func (s *signer) SignSSHCert(ctx context.Context, reqChan chan scheduler.Request pt = time.Since(pStart).Nanoseconds() / time.Microsecond.Nanoseconds() defer pool.put(signer) - sshSigner, err := ssh.NewSignerFromSigner(signer) + sshSigner, err := newAlgorithmSignerFromSigner(signer, signer.publicKeyAlgorithm(), signer.signAlgorithm()) if err != nil { return nil, fmt.Errorf("failed to new ssh signer from signer, error :%v", err) } diff --git a/pkcs11/signer_test.go b/pkcs11/signer_test.go index 12676165..abc15e15 100644 --- a/pkcs11/signer_test.go +++ b/pkcs11/signer_test.go @@ -223,23 +223,23 @@ func TestSignSSHCert(t *testing.T) { reqChan := make(chan scheduler.Request) go dummyScheduler(ctx, reqChan) testcases := map[string]struct { - ctx context.Context - cert *ssh.Certificate - keyType crypki.PublicKeyAlgorithm - identifier string - priority proto.Priority - isBadSigner bool - expectError bool + ctx context.Context + cert *ssh.Certificate + keyType crypki.PublicKeyAlgorithm + identifier string + priority proto.Priority + isBadSigner bool + expectError bool + expectedSignatureAlgo string }{ - "host-cert-rsa": {ctx, hostCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Low, false, false}, - "host-cert-ec": {ctx, hostCertEC, crypki.ECDSA, defaultIdentifier, proto.Priority_Medium, false, false}, - "host-cert-bad-identifier": {ctx, hostCertRSA, crypki.RSA, badIdentifier, proto.Priority_High, false, true}, - "host-cert-bad-signer": {ctx, hostCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Low, true, true}, - "user-cert-rsa": {ctx, userCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Unspecified_priority, false, false}, - "user-cert-ec": {ctx, userCertEC, crypki.ECDSA, defaultIdentifier, proto.Priority_Medium, false, false}, - "user-cert-bad-identifier": {ctx, userCertRSA, crypki.RSA, badIdentifier, proto.Priority_High, false, true}, - "user-cert-bad-signer": {ctx, userCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Low, true, true}, - "user-cert-request-timeout": {timeoutCtx, userCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Low, false, true}, + "host-cert-rsa": {ctx, hostCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Low, false, false, ssh.SigAlgoRSASHA2256}, + "host-cert-ec": {ctx, hostCertEC, crypki.ECDSA, defaultIdentifier, proto.Priority_Medium, false, false, ssh.KeyAlgoECDSA256}, + "host-cert-bad-signer": {ctx, hostCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Low, true, true, ""}, + "user-cert-rsa": {ctx, userCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Unspecified_priority, false, false, ssh.SigAlgoRSASHA2256}, + "user-cert-ec": {ctx, userCertEC, crypki.ECDSA, defaultIdentifier, proto.Priority_Medium, false, false, ssh.KeyAlgoECDSA256}, + "user-cert-bad-identifier": {ctx, userCertRSA, crypki.RSA, badIdentifier, proto.Priority_High, false, true, ""}, + "user-cert-bad-signer": {ctx, userCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Low, true, true, ""}, + "user-cert-request-timeout": {timeoutCtx, userCertRSA, crypki.RSA, defaultIdentifier, proto.Priority_Low, false, true, ""}, } for label, tt := range testcases { label, tt := label, tt @@ -269,6 +269,9 @@ func TestSignSSHCert(t *testing.T) { if err := cc.CheckCert("alice", cert); err != nil { t.Fatalf("check cert failed: %v", err) } + if tt.expectedSignatureAlgo != cert.Signature.Format { + t.Fatalf("mismatch signature algorithm, got %s want %s", cert.Signature.Format, tt.expectedSignatureAlgo) + } }) } }