Skip to content

Commit

Permalink
Have the file system signer watch the files on disk for changes and r…
Browse files Browse the repository at this point in the history
…eload them
  • Loading branch information
krmichelos committed Mar 19, 2024
1 parent 44ac515 commit aa68ad6
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 75 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ tst/certs/
credential-process-data/
tst/softhsm/
tst/softhsm2.conf

.idea/
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
VERSION=1.1.1

release:
.PHONY: release
release: build/bin/aws_signing_helper

build/bin/aws_signing_helper:
go build -buildmode=pie -ldflags "-X 'github.com/aws/rolesanywhere-credential-helper/cmd.Version=${VERSION}' -linkmode=external -w -s" -trimpath -o build/bin/aws_signing_helper main.go

.PHONY: clean
clean:
rm -rf build

# Setting up SoftHSM for PKCS#11 tests.
# This portion is largely copied from https://gitlab.com/openconnect/openconnect/-/blob/v9.12/tests/Makefile.am#L363.
SHM2_UTIL=SOFTHSM2_CONF=tst/softhsm2.conf.tmp softhsm2-util
Expand Down Expand Up @@ -50,6 +57,7 @@ tst/softhsm2.conf: tst/softhsm2.conf.template $(PKCS8KEYS) $(RSACERTS) $(ECCERTS
--mark-always-authenticate
mv [email protected] $@

.PHONY: test
test: test-certs tst/softhsm2.conf
SOFTHSM2_CONF=$(curdir)/tst/softhsm2.conf go test -v ./...

Expand Down Expand Up @@ -111,8 +119,10 @@ $(certsdir)/cert-bundle-with-comments.pem: $(RSACERTS) $(ECCERTS)
echo "Comment in bundle\n" >> $@; \
done

.PHONY: test-certs
test-certs: $(PKCS8KEYS) $(RSAKEYS) $(ECKEYS) $(RSACERTS) $(ECCERTS) $(PKCS12CERTS) $(certsdir)/cert-bundle.pem $(certsdir)/cert-bundle-with-comments.pem tst/softhsm2.conf

.PHONY: test-clean
test-clean:
rm -f $(RSAKEYS) $(ECKEYS)
rm -f $(PKCS8KEYS)
Expand Down
168 changes: 157 additions & 11 deletions aws_signing_helper/file_system_signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,30 @@ import (
"crypto/sha512"
"crypto/x509"
"errors"
"github.com/fsnotify/fsnotify"
"io"
"log"
"os"
"sync"
)

type FileSystemSigner struct {
PrivateKey crypto.PrivateKey
cert *x509.Certificate
certChain []*x509.Certificate
sync.RWMutex

PrivateKey crypto.PrivateKey
bundlePath string
cert *x509.Certificate
certChain []*x509.Certificate
certPath string
isPkcs12 bool
privateKeyPath string

watcher *fsnotify.Watcher
}

func (fileSystemSigner FileSystemSigner) Public() crypto.PublicKey {
func (fileSystemSigner *FileSystemSigner) Public() crypto.PublicKey {
fileSystemSigner.RLock()
defer fileSystemSigner.RUnlock()
{
privateKey, ok := fileSystemSigner.PrivateKey.(ecdsa.PrivateKey)
if ok {
Expand All @@ -34,9 +47,13 @@ func (fileSystemSigner FileSystemSigner) Public() crypto.PublicKey {
return nil
}

func (fileSystemSigner FileSystemSigner) Close() {}
func (fileSystemSigner *FileSystemSigner) Close() {
fileSystemSigner.watcher.Close()
}

func (fileSystemSigner FileSystemSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
func (fileSystemSigner *FileSystemSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
fileSystemSigner.RLock()
defer fileSystemSigner.RUnlock()
var hash []byte
switch opts.HashFunc() {
case crypto.SHA256:
Expand Down Expand Up @@ -72,16 +89,20 @@ func (fileSystemSigner FileSystemSigner) Sign(rand io.Reader, digest []byte, opt
return nil, errors.New("unsupported algorithm")
}

func (fileSystemSigner FileSystemSigner) Certificate() (*x509.Certificate, error) {
func (fileSystemSigner *FileSystemSigner) Certificate() (*x509.Certificate, error) {
fileSystemSigner.RLock()
defer fileSystemSigner.RUnlock()
return fileSystemSigner.cert, nil
}

func (fileSystemSigner FileSystemSigner) CertificateChain() ([]*x509.Certificate, error) {
func (fileSystemSigner *FileSystemSigner) CertificateChain() ([]*x509.Certificate, error) {
fileSystemSigner.RLock()
defer fileSystemSigner.RUnlock()
return fileSystemSigner.certChain, nil
}

// Returns a FileSystemSigner, that signs a payload using the private key passed in
func GetFileSystemSigner(privateKey crypto.PrivateKey, certificate *x509.Certificate, certificateChain []*x509.Certificate) (signer Signer, signingAlgorithm string, err error) {
// GetFileSystemSigner returns a FileSystemSigner, that signs a payload using the private key passed in
func GetFileSystemSigner(privateKey crypto.PrivateKey, certificate *x509.Certificate, certificateChain []*x509.Certificate, privateKeyPath string, certPath string, bundlePath string, isPkcs12 bool) (signer Signer, signingAlgorithm string, err error) {
// Find the signing algorithm
_, isRsaKey := privateKey.(rsa.PrivateKey)
if isRsaKey {
Expand All @@ -96,5 +117,130 @@ func GetFileSystemSigner(privateKey crypto.PrivateKey, certificate *x509.Certifi
return nil, "", errors.New("unsupported algorithm")
}

return FileSystemSigner{privateKey, certificate, certificateChain}, signingAlgorithm, nil
fsSigner := &FileSystemSigner{PrivateKey: privateKey, bundlePath: bundlePath, cert: certificate, certChain: certificateChain, certPath: certPath, isPkcs12: isPkcs12, privateKeyPath: privateKeyPath}
fsSigner.watcher, err = fsnotify.NewWatcher()
if err != nil {
return nil, "", err
}
if certPath != "" {
fsSigner.watcher.Add(certPath)
}
if privateKeyPath != "" {
fsSigner.watcher.Add(privateKeyPath)
}
if bundlePath != "" {
fsSigner.watcher.Add(bundlePath)
}

if Debug {
log.Println("Starting file watcher")
}
go fsSigner.watch()

return fsSigner, signingAlgorithm, nil
}

func (fileSystemSigner *FileSystemSigner) watch() {
for {
select {
case event, ok := <-fileSystemSigner.watcher.Events:
// Channel is closed.
if !ok {
return
}

fileSystemSigner.handleEvent(event)

case err, ok := <-fileSystemSigner.watcher.Errors:
// Channel is closed.
if !ok {
return
}

log.Printf("Certificate watch error: %s", err)
}
}
}

func (fileSystemSigner *FileSystemSigner) handleEvent(event fsnotify.Event) {
if !(isWrite(event) || isRemove(event) || isCreate(event)) {
return
}

if Debug {
log.Printf("Certificate event :%v", event)
}

if isRemove(event) {
if err := fileSystemSigner.watcher.Add(event.Name); err != nil {
log.Printf("Error re-watching file: %s", err)
}
}

if event.Name == fileSystemSigner.certPath {
if fileSystemSigner.isPkcs12 {
chain, privateKey, err := ReadPKCS12Data(fileSystemSigner.certPath)
if err != nil {
log.Printf("Failed to read modified PKCS12 certificate: %s\n", err)
os.Exit(1)
}
fileSystemSigner.Lock()
fileSystemSigner.PrivateKey = privateKey
fileSystemSigner.cert = chain[0]
fileSystemSigner.certChain = chain
fileSystemSigner.Unlock()
} else {
_, cert, err := ReadCertificateData(fileSystemSigner.certPath)
if err != nil {
log.Printf("Failed to read modified certificate: %s\n", err)
os.Exit(1)
}
fileSystemSigner.Lock()
fileSystemSigner.cert = cert
fileSystemSigner.Unlock()
}
if Debug {
log.Printf("Replaced certificate from updated file")
}
}

if event.Name == fileSystemSigner.privateKeyPath {
privateKey, err := ReadPrivateKeyData(fileSystemSigner.privateKeyPath)
if err != nil {
log.Printf("Failed to read modified private key: %s\n", err)
os.Exit(1)
}
fileSystemSigner.Lock()
fileSystemSigner.PrivateKey = privateKey
fileSystemSigner.Unlock()
if Debug {
log.Printf("Replaced private key from updated file")
}
}

if event.Name == fileSystemSigner.bundlePath {
chain, err := GetCertChain(fileSystemSigner.bundlePath)
if err != nil {
log.Printf("Failed to read modified certificate bundle: %s\n", err)
os.Exit(1)
}
fileSystemSigner.Lock()
fileSystemSigner.certChain = chain
fileSystemSigner.Unlock()
if Debug {
log.Printf("Replaced certificate chain from updated file")
}
}
}

func isWrite(event fsnotify.Event) bool {
return event.Op&fsnotify.Write == fsnotify.Write
}

func isCreate(event fsnotify.Event) bool {
return event.Op&fsnotify.Create == fsnotify.Create
}

func isRemove(event fsnotify.Event) bool {
return event.Op&fsnotify.Remove == fsnotify.Remove
}
12 changes: 11 additions & 1 deletion aws_signing_helper/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,19 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials

err := CheckValidToken(w, r)
if err != nil {
log.Printf("Token validation received error: %s\n", err)
return
}

var nextRefreshTime = cred.Expiration.Add(-RefreshTime)
if time.Until(nextRefreshTime) < RefreshTime {
credentialProcessOutput, _ := GenerateCredentials(opts, signer, signatureAlgorithm)
if Debug {
log.Println("Generating credentials")
}
credentialProcessOutput, gcErr := GenerateCredentials(opts, signer, signatureAlgorithm)
if gcErr != nil {
log.Printf("Error generating credentials: %s", gcErr)
}
cred.AccessKeyId = credentialProcessOutput.AccessKeyId
cred.SecretAccessKey = credentialProcessOutput.SecretAccessKey
cred.Token = credentialProcessOutput.SessionToken
Expand All @@ -240,6 +247,9 @@ func AllIssuesHandlers(cred *RefreshableCred, roleName string, opts *Credentials
return
}
} else {
if Debug {
log.Println("Using previous obtained credentials")
}
err := json.NewEncoder(w).Encode(cred)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
Expand Down
49 changes: 28 additions & 21 deletions aws_signing_helper/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func encodeEcdsaSigValue(signature []byte) (out []byte, err error) {
big.NewInt(0).SetBytes(signature[sigLen:])})
}

// Gets the Signer based on the flags passed in by the user (from which the CredentialsOpts structure is derived)
// GetSigner gets the Signer based on the flags passed in by the user (from which the CredentialsOpts structure is derived)
func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string, err error) {
var (
certificate *x509.Certificate
Expand All @@ -185,16 +185,9 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
}

if opts.CertificateId != "" && !strings.HasPrefix(opts.CertificateId, "pkcs11:") {
certificateData, err := ReadCertificateData(opts.CertificateId)
_, cert, err := ReadCertificateData(opts.CertificateId)
if err == nil {
certificateDerData, err := base64.StdEncoding.DecodeString(certificateData.CertificateData)
if err != nil {
return nil, "", err
}
certificate, err = x509.ParseCertificate([]byte(certificateDerData))
if err != nil {
return nil, "", err
}
certificate = cert
} else if opts.PrivateKeyId == "" {
if Debug {
log.Println("not a PEM certificate, so trying PKCS#12")
Expand All @@ -220,20 +213,18 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
privateKey = *rsaPrivateKeyPtr
}
}
return GetFileSystemSigner(privateKey, certificateChain[0], certificateChain)
return GetFileSystemSigner(privateKey, certificateChain[0], certificateChain, opts.PrivateKeyId, opts.CertificateId, opts.CertificateBundleId, true)
} else {
return nil, "", err
}
}

if opts.CertificateBundleId != "" {
certificateChainPointers, err := ReadCertificateBundleData(opts.CertificateBundleId)
chain, err := GetCertChain(opts.CertificateBundleId)
if err != nil {
return nil, "", err
}
for _, certificate := range certificateChainPointers {
certificateChain = append(certificateChain, certificate)
}
certificateChain = chain
}

if strings.HasPrefix(privateKeyId, "pkcs11:") {
Expand All @@ -250,10 +241,13 @@ func GetSigner(opts *CredentialsOpts) (signer Signer, signatureAlgorithm string,
return nil, "", err
}

if certificate == nil {
return nil, "", errors.New("undefined certificate value")
}
if Debug {
log.Println("attempting to use FileSystemSigner")
}
return GetFileSystemSigner(privateKey, certificate, certificateChain)
return GetFileSystemSigner(privateKey, certificate, certificateChain, privateKeyId, opts.CertificateId, opts.CertificateBundleId, false)
}
}

Expand Down Expand Up @@ -709,18 +703,18 @@ func ReadPrivateKeyDataFromPEMBlock(block *pem.Block) (key crypto.PrivateKey, er
return nil, errors.New("unable to parse private key")
}

// Load the certificate referenced by `certificateId` and extract
// ReadCertificateData loads the certificate referenced by `certificateId` and extracts
// details required by the SDK to construct the StringToSign.
func ReadCertificateData(certificateId string) (CertificateData, error) {
func ReadCertificateData(certificateId string) (CertificateData, *x509.Certificate, error) {
block, err := parseDERFromPEM(certificateId, "CERTIFICATE")
if err != nil {
return CertificateData{}, errors.New("could not parse PEM data")
return CertificateData{}, nil, errors.New("could not parse PEM data")
}

cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Println("could not parse certificate", err)
return CertificateData{}, errors.New("could not parse certificate")
return CertificateData{}, nil, errors.New("could not parse certificate")
}

//extract serial number
Expand All @@ -747,5 +741,18 @@ func ReadCertificateData(certificateId string) (CertificateData, error) {
}

//return struct
return CertificateData{keyType, encodedDer, serialNumber, supportedAlgorithms}, nil
return CertificateData{keyType, encodedDer, serialNumber, supportedAlgorithms}, cert, nil
}

// GetCertChain reads a certificate bundle and returns a chain of all the certificates it contains
func GetCertChain(certificateBundleId string) ([]*x509.Certificate, error) {
certificateChainPointers, err := ReadCertificateBundleData(certificateBundleId)
var chain []*x509.Certificate
if err != nil {
return nil, err
}
for _, certificate := range certificateChainPointers {
chain = append(chain, certificate)
}
return chain, nil
}
Loading

0 comments on commit aa68ad6

Please sign in to comment.