Skip to content

Commit

Permalink
Have the file system signer read the cert files when it needs to in c…
Browse files Browse the repository at this point in the history
…ase they have changed
  • Loading branch information
krmichelos committed May 24, 2024
1 parent 44ac515 commit 2481d8d
Show file tree
Hide file tree
Showing 15 changed files with 243 additions and 140 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ tst/certs/
credential-process-data/
tst/softhsm/
tst/softhsm2.conf

.idea/
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
repos:
- repo: https://github.com/dnephin/pre-commit-golang
rev: v0.5.1
hooks:
- id: go-mod-tidy
- id: go-fmt
17 changes: 15 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
VERSION=1.1.1

release:
.PHONY: release
release: build/bin/aws_signing_helper

build/bin/aws_signing_helper:
go build -buildmode=pie -ldflags "-X 'github.com/aws/rolesanywhere-credential-helper/cmd.Version=${VERSION}' -linkmode=external -w -s" -trimpath -o build/bin/aws_signing_helper main.go

.PHONY: clean
clean:
rm -rf build

# Setting up SoftHSM for PKCS#11 tests.
# This portion is largely copied from https://gitlab.com/openconnect/openconnect/-/blob/v9.12/tests/Makefile.am#L363.
SHM2_UTIL=SOFTHSM2_CONF=tst/softhsm2.conf.tmp softhsm2-util
Expand All @@ -20,7 +27,7 @@ PKCS12CERTS := $(patsubst %-cert.pem, %.p12, $(RSACERTS) $(ECCERTS))

# It's hard to do a file-based rule for the contents of the SoftHSM token.
# So just populate it as a side-effect of creating the softhsm2.conf file.
tst/softhsm2.conf: tst/softhsm2.conf.template $(PKCS8KEYS) $(RSACERTS) $(ECCERTS)
tst/softhsm2.conf: tst/softhsm2.conf.template $(PKCS8KEYS) $(RSACERTS) $(ECCERTS) tst/certs/rsa-2048-2-sha256-cert.pem
rm -rf tst/softhsm/*
sed 's|@top_srcdir@|${curdir}|g' $< > $@.tmp
$(SHM2_UTIL) --show-slots
Expand Down Expand Up @@ -50,6 +57,7 @@ tst/softhsm2.conf: tst/softhsm2.conf.template $(PKCS8KEYS) $(RSACERTS) $(ECCERTS
--mark-always-authenticate
mv [email protected] $@

.PHONY: test
test: test-certs tst/softhsm2.conf
SOFTHSM2_CONF=$(curdir)/tst/softhsm2.conf go test -v ./...

Expand All @@ -62,6 +70,9 @@ test: test-certs tst/softhsm2.conf
%-sha256-cert.pem: %-key.pem
SUBJ=$$(echo "$@" | sed -r 's|.*/([^/]+)-cert.pem|\1|'); \
openssl req -x509 -new -key $< -out $@ -days 10000 -subj "/CN=roles-anywhere-$${SUBJ}" -sha256
%-2-sha256-cert.pem: %-key.pem
SUBJ=$$(echo "$@" | sed -r 's|.*/([^/]+)-cert.pem|\1|'); \
openssl req -x509 -new -key $< -out $@ -days 10000 -subj "/CN=roles-anywhere-$${SUBJ}" -sha256
%-sha384-cert.pem: %-key.pem
SUBJ=$$(echo "$@" | sed -r 's|.*/([^/]+)-cert.pem|\1|'); \
openssl req -x509 -new -key $< -out $@ -days 10000 -subj "/CN=roles-anywhere-$${SUBJ}" -sha384
Expand Down Expand Up @@ -111,8 +122,10 @@ $(certsdir)/cert-bundle-with-comments.pem: $(RSACERTS) $(ECCERTS)
echo "Comment in bundle\n" >> $@; \
done

.PHONY: test-certs
test-certs: $(PKCS8KEYS) $(RSAKEYS) $(ECKEYS) $(RSACERTS) $(ECCERTS) $(PKCS12CERTS) $(certsdir)/cert-bundle.pem $(certsdir)/cert-bundle-with-comments.pem tst/softhsm2.conf

.PHONY: test-clean
test-clean:
rm -f $(RSAKEYS) $(ECKEYS)
rm -f $(PKCS8KEYS)
Expand Down
81 changes: 64 additions & 17 deletions aws_signing_helper/file_system_signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,37 @@ import (
"errors"
"io"
"log"
"os"
)

type FileSystemSigner struct {
PrivateKey crypto.PrivateKey
cert *x509.Certificate
certChain []*x509.Certificate
bundlePath string
certPath string
isPkcs12 bool
privateKeyPath string
}

func (fileSystemSigner FileSystemSigner) Public() crypto.PublicKey {
func (fileSystemSigner *FileSystemSigner) Public() crypto.PublicKey {
privateKey, _, _ := fileSystemSigner.readCertFiles()
{
privateKey, ok := fileSystemSigner.PrivateKey.(ecdsa.PrivateKey)
privateKey, ok := privateKey.(ecdsa.PrivateKey)
if ok {
return &privateKey.PublicKey
}
}
{
privateKey, ok := fileSystemSigner.PrivateKey.(rsa.PrivateKey)
privateKey, ok := privateKey.(rsa.PrivateKey)
if ok {
return &privateKey.PublicKey
}
}
return nil
}

func (fileSystemSigner FileSystemSigner) Close() {}
func (fileSystemSigner *FileSystemSigner) Close() {}

func (fileSystemSigner FileSystemSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
func (fileSystemSigner *FileSystemSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
privateKey, _, _ := fileSystemSigner.readCertFiles()
var hash []byte
switch opts.HashFunc() {
case crypto.SHA256:
Expand All @@ -52,15 +56,15 @@ func (fileSystemSigner FileSystemSigner) Sign(rand io.Reader, digest []byte, opt
return nil, ErrUnsupportedHash
}

ecdsaPrivateKey, ok := fileSystemSigner.PrivateKey.(ecdsa.PrivateKey)
ecdsaPrivateKey, ok := privateKey.(ecdsa.PrivateKey)
if ok {
sig, err := ecdsa.SignASN1(rand, &ecdsaPrivateKey, hash[:])
if err == nil {
return sig, nil
}
}

rsaPrivateKey, ok := fileSystemSigner.PrivateKey.(rsa.PrivateKey)
rsaPrivateKey, ok := privateKey.(rsa.PrivateKey)
if ok {
sig, err := rsa.SignPKCS1v15(rand, &rsaPrivateKey, opts.HashFunc(), hash[:])
if err == nil {
Expand All @@ -72,16 +76,20 @@ func (fileSystemSigner FileSystemSigner) Sign(rand io.Reader, digest []byte, opt
return nil, errors.New("unsupported algorithm")
}

func (fileSystemSigner FileSystemSigner) Certificate() (*x509.Certificate, error) {
return fileSystemSigner.cert, nil
func (fileSystemSigner *FileSystemSigner) Certificate() (*x509.Certificate, error) {
_, cert, _ := fileSystemSigner.readCertFiles()
return cert, nil
}

func (fileSystemSigner FileSystemSigner) CertificateChain() ([]*x509.Certificate, error) {
return fileSystemSigner.certChain, nil
func (fileSystemSigner *FileSystemSigner) CertificateChain() ([]*x509.Certificate, error) {
_, _, certChain := fileSystemSigner.readCertFiles()
return certChain, nil
}

// Returns a FileSystemSigner, that signs a payload using the private key passed in
func GetFileSystemSigner(privateKey crypto.PrivateKey, certificate *x509.Certificate, certificateChain []*x509.Certificate) (signer Signer, signingAlgorithm string, err error) {
// GetFileSystemSigner returns a FileSystemSigner, that signs a payload using the private key passed in
func GetFileSystemSigner(privateKeyPath string, certPath string, bundlePath string, isPkcs12 bool) (signer Signer, signingAlgorithm string, err error) {
fsSigner := &FileSystemSigner{bundlePath: bundlePath, certPath: certPath, isPkcs12: isPkcs12, privateKeyPath: privateKeyPath}
privateKey, _, _ := fsSigner.readCertFiles()
// Find the signing algorithm
_, isRsaKey := privateKey.(rsa.PrivateKey)
if isRsaKey {
Expand All @@ -96,5 +104,44 @@ func GetFileSystemSigner(privateKey crypto.PrivateKey, certificate *x509.Certifi
return nil, "", errors.New("unsupported algorithm")
}

return FileSystemSigner{privateKey, certificate, certificateChain}, signingAlgorithm, nil
return fsSigner, signingAlgorithm, nil
}

func (fileSystemSigner *FileSystemSigner) readCertFiles() (crypto.PrivateKey, *x509.Certificate, []*x509.Certificate) {
if fileSystemSigner.isPkcs12 {
chain, privateKey, err := ReadPKCS12Data(fileSystemSigner.certPath)
if err != nil {
log.Printf("Failed to read PKCS12 certificate: %s\n", err)
os.Exit(1)
}
return privateKey, chain[0], chain
} else {
privateKey, err := ReadPrivateKeyData(fileSystemSigner.privateKeyPath)
if err != nil {
log.Printf("Failed to read private key: %s\n", err)
os.Exit(1)
}
var chain []*x509.Certificate
if fileSystemSigner.bundlePath != "" {
chain, err = GetCertChain(fileSystemSigner.bundlePath)
if err != nil {
privateKey = nil
log.Printf("Failed to read certificate bundle: %s\n", err)
os.Exit(1)
}
}
var cert *x509.Certificate
if fileSystemSigner.certPath != "" {
_, cert, err = ReadCertificateData(fileSystemSigner.certPath)
if err != nil {
privateKey = nil
log.Printf("Failed to read certificate: %s\n", err)
os.Exit(1)
}
} else if len(chain) > 0 {
cert = chain[0]
}

return privateKey, cert, chain
}
}
12 changes: 11 additions & 1 deletion aws_signing_helper/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,19 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials

err := CheckValidToken(w, r)
if err != nil {
log.Printf("Token validation received error: %s\n", err)
return
}

var nextRefreshTime = cred.Expiration.Add(-RefreshTime)
if time.Until(nextRefreshTime) < RefreshTime {
credentialProcessOutput, _ := GenerateCredentials(opts, signer, signatureAlgorithm)
if Debug {
log.Println("Generating credentials")
}
credentialProcessOutput, gcErr := GenerateCredentials(opts, signer, signatureAlgorithm)
if gcErr != nil {
log.Printf("Error generating credentials: %s\n", gcErr)
}
cred.AccessKeyId = credentialProcessOutput.AccessKeyId
cred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
cred.Token = credentialProcessOutput.SessionToken
Expand All @@ -240,6 +247,9 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials
return
}
} else {
if Debug {
log.Println("Using previously obtained credentials")
}
err := json.NewEncoder(w).Encode(cred)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
Expand Down
64 changes: 29 additions & 35 deletions aws_signing_helper/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,11 @@ func encodeEcdsaSigValue(signature []byte) (out []byte, err error) {
big.NewInt(0).SetBytes(signature[sigLen:])})
}

// Gets the Signer based on the flags passed in by the user (from which the CredentialsOpts structure is derived)
// GetSigner gets the Signer based on the flags passed in by the user (from which the CredentialsOpts structure is derived)
func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string, err error) {
var (
certificate *x509.Certificate
certificateChain []*x509.Certificate
privateKey crypto.PrivateKey
)

privateKeyId := opts.PrivateKeyId
Expand All @@ -185,16 +184,9 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
}

if opts.CertificateId != "" && !strings.HasPrefix(opts.CertificateId, "pkcs11:") {
certificateData, err := ReadCertificateData(opts.CertificateId)
_, cert, err := ReadCertificateData(opts.CertificateId)
if err == nil {
certificateDerData, err := base64.StdEncoding.DecodeString(certificateData.CertificateData)
if err != nil {
return nil, "", err
}
certificate, err = x509.ParseCertificate([]byte(certificateDerData))
if err != nil {
return nil, "", err
}
certificate = cert
} else if opts.PrivateKeyId == "" {
if Debug {
log.Println("not a PEM certificate, so trying PKCS#12")
Expand All @@ -205,35 +197,21 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
" within the PKCS#12 file")
}
// Not a PEM certificate? Try PKCS#12
certificateChain, privateKey, err = ReadPKCS12Data(opts.CertificateId)
_, _, err = ReadPKCS12Data(opts.CertificateId)
if err != nil {
return nil, "", err
}
if privateKey != nil {
ecPrivateKeyPtr, isEcKey := privateKey.(*ecdsa.PrivateKey)
if isEcKey {
privateKey = *ecPrivateKeyPtr
}

rsaPrivateKeyPtr, isRsaKey := privateKey.(*rsa.PrivateKey)
if isRsaKey {
privateKey = *rsaPrivateKeyPtr
}
}
return GetFileSystemSigner(privateKey, certificateChain[0], certificateChain)
return GetFileSystemSigner(opts.PrivateKeyId, opts.CertificateId, opts.CertificateBundleId, true)
} else {
return nil, "", err
}
}

if opts.CertificateBundleId != "" {
certificateChainPointers, err := ReadCertificateBundleData(opts.CertificateBundleId)
certificateChain, err = GetCertChain(opts.CertificateBundleId)
if err != nil {
return nil, "", err
}
for _, certificate := range certificateChainPointers {
certificateChain = append(certificateChain, certificate)
}
}

if strings.HasPrefix(privateKeyId, "pkcs11:") {
Expand All @@ -245,15 +223,18 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
}
return GetPKCS11Signer(opts.LibPkcs11, certificate, certificateChain, opts.PrivateKeyId, opts.CertificateId, opts.ReusePin)
} else {
privateKey, err = ReadPrivateKeyData(privateKeyId)
_, err = ReadPrivateKeyData(privateKeyId)
if err != nil {
return nil, "", err
}

if certificate == nil {
return nil, "", errors.New("undefined certificate value")
}
if Debug {
log.Println("attempting to use FileSystemSigner")
}
return GetFileSystemSigner(privateKey, certificate, certificateChain)
return GetFileSystemSigner(privateKeyId, opts.CertificateId, opts.CertificateBundleId, false)
}
}

Expand Down Expand Up @@ -709,18 +690,18 @@ func ReadPrivateKeyDataFromPEMBlock(block *pem.Block) (key crypto.PrivateKey, er
return nil, errors.New("unable to parse private key")
}

// Load the certificate referenced by `certificateId` and extract
// ReadCertificateData loads the certificate referenced by `certificateId` and extracts
// details required by the SDK to construct the StringToSign.
func ReadCertificateData(certificateId string) (CertificateData, error) {
func ReadCertificateData(certificateId string) (CertificateData, *x509.Certificate, error) {
block, err := parseDERFromPEM(certificateId, "CERTIFICATE")
if err != nil {
return CertificateData{}, errors.New("could not parse PEM data")
return CertificateData{}, nil, errors.New("could not parse PEM data")
}

cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Println("could not parse certificate", err)
return CertificateData{}, errors.New("could not parse certificate")
return CertificateData{}, nil, errors.New("could not parse certificate")
}

//extract serial number
Expand All @@ -747,5 +728,18 @@ func ReadCertificateData(certificateId string) (CertificateData, error) {
}

//return struct
return CertificateData{keyType, encodedDer, serialNumber, supportedAlgorithms}, nil
return CertificateData{keyType, encodedDer, serialNumber, supportedAlgorithms}, cert, nil
}

// GetCertChain reads a certificate bundle and returns a chain of all the certificates it contains
func GetCertChain(certificateBundleId string) ([]*x509.Certificate, error) {
certificateChainPointers, err := ReadCertificateBundleData(certificateBundleId)
var chain []*x509.Certificate
if err != nil {
return nil, err
}
for _, certificate := range certificateChainPointers {
chain = append(chain, certificate)
}
return chain, nil
}
Loading

0 comments on commit 2481d8d

Please sign in to comment.