diff --git a/cmd/scepclient/scepclient.go b/cmd/scepclient/scepclient.go index 3cfe22d..d8ac7ba 100644 --- a/cmd/scepclient/scepclient.go +++ b/cmd/scepclient/scepclient.go @@ -115,15 +115,15 @@ func run(cfg runCfg) error { if err != nil { return err } - var certs []*x509.Certificate + var caCerts []*x509.Certificate { if certNum > 1 { - certs, err = scep.CACerts(resp) + caCerts, err = scep.CACerts(resp) if err != nil { return err } } else { - certs, err = x509.ParseCertificates(resp) + caCerts, err = x509.ParseCertificates(resp) if err != nil { return err } @@ -131,7 +131,7 @@ func run(cfg runCfg) error { } if cfg.debug { - logCerts(level.Debug(logger), certs) + logCerts(level.Debug(logger), caCerts) } var signerCert *x509.Certificate @@ -155,7 +155,7 @@ func run(cfg runCfg) error { tmpl := &scep.PKIMessage{ MessageType: msgType, - Recipients: certs, + Recipients: caCerts, SignerKey: key, SignerCert: signerCert, } @@ -182,7 +182,7 @@ func run(cfg runCfg) error { return errors.Wrapf(err, "PKIOperation for %s", msgType) } - respMsg, err = scep.ParsePKIMessage(respBytes, scep.WithLogger(logger), scep.WithCACerts(msg.Recipients)) + respMsg, err = scep.ParsePKIMessage(respBytes, scep.WithLogger(logger), scep.WithCACerts(caCerts)) if err != nil { return errors.Wrapf(err, "parsing pkiMessage response %s", msgType) } @@ -253,7 +253,7 @@ func validateFingerprint(fingerprint string) (hash []byte, err error) { return } -func validateFlags(keyPath, serverURL string) error { +func validateFlags(keyPath, serverURL, caFingerprint string, useKeyEnciphermentSelector bool) error { if keyPath == "" { return errors.New("must specify private key path") } @@ -264,6 +264,9 @@ func validateFlags(keyPath, serverURL string) error { if err != nil { return fmt.Errorf("invalid server-url flag parameter %s", err) } + if caFingerprint != "" && useKeyEnciphermentSelector { + return errors.New("ca-fingerprint and key-encipherment-selector can't be used at the same time") + } return nil } @@ -285,12 +288,15 @@ func main() { flDNSName = flag.String("dnsname", "", "DNS name to be included in the certificate (SAN)") // in case of multiple certificate authorities, we need to figure out who the recipient of the encrypted - // data is. - flCAFingerprint = flag.String("ca-fingerprint", "", "SHA-256 digest of CA certificate for NDES server. Note: Changed from MD5.") + // data is. This can be done using either the CA fingerprint, or based on the key usage encoded in the + // certificates returned by the authority. + flCAFingerprint = flag.String("ca-fingerprint", "", "SHA-256 digest of CA certificate for NDES server. Note: Changed from MD5.") + flKeyEnciphermentSelector = flag.Bool("key-encipherment-selector", false, "Filter CA certificates by key encipherment usage") flDebugLogging = flag.Bool("debug", false, "enable debug logging") flLogJSON = flag.Bool("log-json", false, "use JSON for log output") ) + flag.Parse() // print version information @@ -299,19 +305,22 @@ func main() { os.Exit(0) } - if err := validateFlags(*flPKeyPath, *flServerURL); err != nil { + if err := validateFlags(*flPKeyPath, *flServerURL, *flCAFingerprint, *flKeyEnciphermentSelector); err != nil { fmt.Println(err) os.Exit(1) } caCertsSelector := scep.NopCertsSelector() - if *flCAFingerprint != "" { + switch { + case *flCAFingerprint != "": hash, err := validateFingerprint(*flCAFingerprint) if err != nil { fmt.Printf("invalid fingerprint: %s\n", err) os.Exit(1) } caCertsSelector = scep.FingerprintCertsSelector(fingerprintHashType, hash) + case *flKeyEnciphermentSelector: + caCertsSelector = scep.EnciphermentCertsSelector() } dir := filepath.Dir(*flPKeyPath)