diff --git a/README.md b/README.md index d918909..a4fb7ed 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ usage: scep [] [] type --help to see usage for each subcommand ``` -Use the `ca -init` subcommand to create a new CA and private key. +Use the `ca -init` subcommand to create a new CA and private key. CA sub-command usage: ``` @@ -95,6 +95,8 @@ Usage of ca: password to store rsa key -keySize int rsa key size (default 4096) + -common_name string + common name (CN) for CA cert (default "MICROMDM SCEP CA") -organization string organization for CA cert (default "scep-ca") -organizational_unit string diff --git a/challenge/challenge.go b/challenge/challenge.go index 2a0049b..d1b0ee4 100644 --- a/challenge/challenge.go +++ b/challenge/challenge.go @@ -2,6 +2,7 @@ package challenge import ( + "context" "crypto/x509" "errors" @@ -16,8 +17,8 @@ type Store interface { } // Middleware wraps next in a CSRSigner that verifies and invalidates the challenge -func Middleware(store Store, next scepserver.CSRSigner) scepserver.CSRSignerFunc { - return func(m *scep.CSRReqMessage) (*x509.Certificate, error) { +func Middleware(store Store, next scepserver.CSRSignerContext) scepserver.CSRSignerContextFunc { + return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { // TODO: compare challenge only for PKCSReq? valid, err := store.HasChallenge(m.ChallengePassword) if err != nil { @@ -26,6 +27,6 @@ func Middleware(store Store, next scepserver.CSRSigner) scepserver.CSRSignerFunc if !valid { return nil, errors.New("invalid challenge") } - return next.SignCSR(m) + return next.SignCSRContext(ctx, m) } } diff --git a/challenge/challenge_bolt_test.go b/challenge/challenge_bolt_test.go index 003acbd..1a34cb5 100644 --- a/challenge/challenge_bolt_test.go +++ b/challenge/challenge_bolt_test.go @@ -1,6 +1,7 @@ package challenge import ( + "context" "io/ioutil" "os" "testing" @@ -69,12 +70,14 @@ func TestDynamicChallenge(t *testing.T) { ChallengePassword: challengePassword, } - _, err = signer.SignCSR(csrReq) + ctx := context.Background() + + _, err = signer.SignCSRContext(ctx, csrReq) if err != nil { t.Error(err) } - _, err = signer.SignCSR(csrReq) + _, err = signer.SignCSRContext(ctx, csrReq) if err == nil { t.Error("challenge should not be valid twice") } diff --git a/cmd/scepclient/csr.go b/cmd/scepclient/csr.go index 99ac5b7..5d03b25 100644 --- a/cmd/scepclient/csr.go +++ b/cmd/scepclient/csr.go @@ -18,8 +18,8 @@ const ( ) type csrOptions struct { - cn, org, country, ou, locality, province, challenge string - key *rsa.PrivateKey + cn, org, country, ou, locality, province, dnsName, challenge string + key *rsa.PrivateKey } func loadOrMakeCSR(path string, opts *csrOptions) (*x509.CertificateRequest, error) { @@ -44,6 +44,7 @@ func loadOrMakeCSR(path string, opts *csrOptions) (*x509.CertificateRequest, err CertificateRequest: x509.CertificateRequest{ Subject: subject, SignatureAlgorithm: x509.SHA256WithRSA, + DNSNames: subjOrNil(opts.dnsName), }, } if opts.challenge != "" { diff --git a/cmd/scepclient/scepclient.go b/cmd/scepclient/scepclient.go index 3b61222..3cfe22d 100644 --- a/cmd/scepclient/scepclient.go +++ b/cmd/scepclient/scepclient.go @@ -50,6 +50,7 @@ type runCfg struct { debug bool logfmt string caCertMsg string + dnsName string } func run(cfg runCfg) error { @@ -88,6 +89,7 @@ func run(cfg runCfg) error { province: cfg.province, challenge: cfg.challenge, key: key, + dnsName: cfg.dnsName, } csr, err := loadOrMakeCSR(cfg.csrPath, opts) @@ -234,10 +236,11 @@ func logCerts(logger log.Logger, certs []*x509.Certificate) { // validateFingerprint makes sure fingerprint looks like a hash. // We remove spaces and colons from fingerprint as it may come in various forms: -// e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 -// E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855 -// e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855 -// e3:b0:c4:42:98:fc:1c:14:9a:fb:f4:c8:99:6f:b9:24:27:ae:41:e4:64:9b:93:4c:a4:95:99:1b:78:52:b8:55 +// +// e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 +// E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855 +// e3b0c442 98fc1c14 9afbf4c8 996fb924 27ae41e4 649b934c a495991b 7852b855 +// e3:b0:c4:42:98:fc:1c:14:9a:fb:f4:c8:99:6f:b9:24:27:ae:41:e4:64:9b:93:4c:a4:95:99:1b:78:52:b8:55 func validateFingerprint(fingerprint string) (hash []byte, err error) { fingerprint = strings.NewReplacer(" ", "", ":", "").Replace(fingerprint) hash, err = hex.DecodeString(fingerprint) @@ -279,6 +282,7 @@ func main() { flProvince = flag.String("province", "", "province for certificate") flCountry = flag.String("country", "US", "country code in certificate") flCACertMessage = flag.String("cacert-message", "", "message sent with GetCACert operation") + flDNSName = flag.String("dnsname", "", "DNS name to be included in the certificate (SAN)") // in case of multiple certificate authorities, we need to figure out who the recipient of the encrypted // data is. @@ -340,6 +344,7 @@ func main() { debug: *flDebugLogging, logfmt: logfmt, caCertMsg: *flCACertMessage, + dnsName: *flDNSName, } if err := run(cfg); err != nil { diff --git a/cmd/scepserver/scepserver.go b/cmd/scepserver/scepserver.go index f2a243f..97c0170 100644 --- a/cmd/scepserver/scepserver.go +++ b/cmd/scepserver/scepserver.go @@ -43,7 +43,8 @@ func main() { //main flags var ( flVersion = flag.Bool("version", false, "prints version information") - flPort = flag.String("port", envString("SCEP_HTTP_LISTEN_PORT", "8080"), "port to listen on") + flHTTPAddr = flag.String("http-addr", envString("SCEP_HTTP_ADDR", ""), "http listen address. defaults to \":8080\"") + flPort = flag.String("port", envString("SCEP_HTTP_LISTEN_PORT", "8080"), "http port to listen on (if you want to specify an address, use -http-addr instead)") flDepotPath = flag.String("depot", envString("SCEP_FILE_DEPOT", "depot"), "path to ca folder") flCAPass = flag.String("capass", envString("SCEP_CA_PASS", ""), "passwd for the ca.key") flClDuration = flag.String("crtvalid", envString("SCEP_CERT_VALID", "365"), "validity for new client certificates in days") @@ -52,6 +53,7 @@ func main() { flCSRVerifierExec = flag.String("csrverifierexec", envString("SCEP_CSR_VERIFIER_EXEC", ""), "will be passed the CSRs for verification") flDebug = flag.Bool("debug", envBool("SCEP_LOG_DEBUG"), "enable debug logging") flLogJSON = flag.Bool("log-json", envBool("SCEP_LOG_JSON"), "output JSON logs") + flSignServerAttrs = flag.Bool("sign-server-attrs", envBool("SCEP_SIGN_SERVER_ATTRS"), "sign cert attrs for server usage") ) flag.Usage = func() { flag.PrintDefaults() @@ -67,7 +69,19 @@ func main() { fmt.Println(version) os.Exit(0) } - port := ":" + *flPort + + // -http-addr and -port conflict. Don't allow the user to set both. + httpAddrSet := setByUser("http-addr", "SCEP_HTTP_ADDR") + portSet := setByUser("port", "SCEP_HTTP_LISTEN_PORT") + var httpAddr string + if httpAddrSet && portSet { + fmt.Fprintln(os.Stderr, "cannot set both -http-addr and -port") + os.Exit(1) + } else if httpAddrSet { + httpAddr = *flHTTPAddr + } else { + httpAddr = ":" + *flPort + } var logger log.Logger { @@ -125,14 +139,17 @@ func main() { lginfo.Log("err", "missing CA certificate") os.Exit(1) } - var signer scepserver.CSRSigner = scepdepot.NewSigner( - depot, + signerOpts := []scepdepot.Option{ scepdepot.WithAllowRenewalDays(allowRenewal), scepdepot.WithValidityDays(clientValidity), scepdepot.WithCAPass(*flCAPass), - ) + } + if *flSignServerAttrs { + signerOpts = append(signerOpts, scepdepot.WithSeverAttrs()) + } + var signer scepserver.CSRSignerContext = scepserver.SignCSRAdapter(scepdepot.NewSigner(depot, signerOpts...)) if *flChallengePassword != "" { - signer = scepserver.ChallengeMiddleware(*flChallengePassword, signer) + signer = scepserver.StaticChallengeMiddleware(*flChallengePassword, signer) } if csrVerifier != nil { signer = csrverifier.Middleware(csrVerifier, signer) @@ -156,8 +173,8 @@ func main() { // start http server errs := make(chan error, 2) go func() { - lginfo.Log("transport", "http", "address", port, "msg", "listening") - errs <- http.ListenAndServe(port, h) + lginfo.Log("transport", "http", "address", httpAddr, "msg", "listening") + errs <- http.ListenAndServe(httpAddr, h) }() go func() { c := make(chan os.Signal) @@ -170,14 +187,15 @@ func main() { func caMain(cmd *flag.FlagSet) int { var ( - flDepotPath = cmd.String("depot", "depot", "path to ca folder") - flInit = cmd.Bool("init", false, "create a new CA") - flYears = cmd.Int("years", 10, "default CA years") - flKeySize = cmd.Int("keySize", 4096, "rsa key size") - flOrg = cmd.String("organization", "scep-ca", "organization for CA cert") - flOrgUnit = cmd.String("organizational_unit", "SCEP CA", "organizational unit (OU) for CA cert") - flPassword = cmd.String("key-password", "", "password to store rsa key") - flCountry = cmd.String("country", "US", "country for CA cert") + flDepotPath = cmd.String("depot", "depot", "path to ca folder") + flInit = cmd.Bool("init", false, "create a new CA") + flYears = cmd.Int("years", 10, "default CA years") + flKeySize = cmd.Int("keySize", 4096, "rsa key size") + flCommonName = cmd.String("common_name", "MICROMDM SCEP CA", "common name (CN) for CA cert") + flOrg = cmd.String("organization", "scep-ca", "organization for CA cert") + flOrgUnit = cmd.String("organizational_unit", "SCEP CA", "organizational unit (OU) for CA cert") + flPassword = cmd.String("key-password", "", "password to store rsa key") + flCountry = cmd.String("country", "US", "country for CA cert") ) cmd.Parse(os.Args[2:]) if *flInit { @@ -187,7 +205,7 @@ func caMain(cmd *flag.FlagSet) int { fmt.Println(err) return 1 } - if err := createCertificateAuthority(key, *flYears, *flOrg, *flOrgUnit, *flCountry, *flDepotPath); err != nil { + if err := createCertificateAuthority(key, *flYears, *flCommonName, *flOrg, *flOrgUnit, *flCountry, *flDepotPath); err != nil { fmt.Println(err) return 1 } @@ -232,9 +250,10 @@ func createKey(bits int, password []byte, depot string) (*rsa.PrivateKey, error) return key, nil } -func createCertificateAuthority(key *rsa.PrivateKey, years int, organization string, organizationalUnit string, country string, depot string) error { +func createCertificateAuthority(key *rsa.PrivateKey, years int, commonName string, organization string, organizationalUnit string, country string, depot string) error { cert := scepdepot.NewCACert( scepdepot.WithYears(years), + scepdepot.WithCommonName(commonName), scepdepot.WithOrganization(organization), scepdepot.WithOrganizationalUnit(organizationalUnit), scepdepot.WithCountry(country), @@ -288,3 +307,13 @@ func envBool(key string) bool { } return false } + +func setByUser(flagName, envName string) bool { + userDefinedFlags := make(map[string]bool) + flag.Visit(func(f *flag.Flag) { + userDefinedFlags[f.Name] = true + }) + flagSet := userDefinedFlags[flagName] + _, envSet := os.LookupEnv(envName) + return flagSet || envSet +} diff --git a/cryptoutil/cryptoutil_test.go b/cryptoutil/cryptoutil_test.go index ab83c2e..53a73ee 100644 --- a/cryptoutil/cryptoutil_test.go +++ b/cryptoutil/cryptoutil_test.go @@ -4,18 +4,23 @@ import ( "crypto" "crypto/ecdsa" "crypto/elliptic" + "crypto/rand" "crypto/rsa" "math/big" "testing" ) func TestGenerateSubjectKeyID(t *testing.T) { + ecKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatal(err) + } for _, test := range []struct { testName string pub crypto.PublicKey }{ {"RSA", &rsa.PublicKey{N: big.NewInt(123), E: 65537}}, - {"ECDSA", &ecdsa.PublicKey{X: big.NewInt(123), Y: big.NewInt(123), Curve: elliptic.P224()}}, + {"ECDSA", ecKey.Public()}, } { test := test t.Run(test.testName, func(t *testing.T) { diff --git a/csrverifier/csrverifier.go b/csrverifier/csrverifier.go index bfc350b..da6f5aa 100644 --- a/csrverifier/csrverifier.go +++ b/csrverifier/csrverifier.go @@ -2,6 +2,7 @@ package csrverifier import ( + "context" "crypto/x509" "errors" @@ -15,8 +16,8 @@ type CSRVerifier interface { } // Middleware wraps next in a CSRSigner that runs verifier -func Middleware(verifier CSRVerifier, next scepserver.CSRSigner) scepserver.CSRSignerFunc { - return func(m *scep.CSRReqMessage) (*x509.Certificate, error) { +func Middleware(verifier CSRVerifier, next scepserver.CSRSignerContext) scepserver.CSRSignerContextFunc { + return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { ok, err := verifier.Verify(m.RawDecrypted) if err != nil { return nil, err @@ -24,6 +25,6 @@ func Middleware(verifier CSRVerifier, next scepserver.CSRSigner) scepserver.CSRS if !ok { return nil, errors.New("CSR verify failed") } - return next.SignCSR(m) + return next.SignCSRContext(ctx, m) } } diff --git a/depot/bolt/depot.go b/depot/bolt/depot.go index a0c1161..cdeee81 100644 --- a/depot/bolt/depot.go +++ b/depot/bolt/depot.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "math/big" + "sync" "github.com/micromdm/scep/v2/depot" @@ -18,6 +19,7 @@ import ( // https://github.com/boltdb/bolt type Depot struct { *bolt.DB + serialMu sync.RWMutex } const ( @@ -36,7 +38,7 @@ func NewBoltDepot(db *bolt.DB) (*Depot, error) { if err != nil { return nil, err } - return &Depot{db}, nil + return &Depot{DB: db}, nil } // For some read operations Bolt returns a direct memory reference to @@ -93,26 +95,28 @@ func (db *Depot) Put(cn string, crt *x509.Certificate) error { if crt == nil || crt.Raw == nil { return fmt.Errorf("%q does not specify a valid certificate for storage", cn) } - serial, err := db.Serial() - if err != nil { - return err - } - - err = db.Update(func(tx *bolt.Tx) error { + err := db.Update(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(certBucket)) if bucket == nil { return fmt.Errorf("bucket %q not found!", certBucket) } - name := cn + "." + serial.String() + name := cn + "." + crt.SerialNumber.String() return bucket.Put([]byte(name), crt.Raw) }) + return err +} + +func (db *Depot) Serial() (*big.Int, error) { + db.serialMu.Lock() + defer db.serialMu.Unlock() + s, err := db.readSerial() if err != nil { - return err + return nil, err } - return db.incrementSerial(serial) + return s, db.incrementSerial(s) } -func (db *Depot) Serial() (*big.Int, error) { +func (db *Depot) readSerial() (*big.Int, error) { s := big.NewInt(2) if !db.hasKey([]byte("serial")) { if err := db.writeSerial(s); err != nil { @@ -132,10 +136,7 @@ func (db *Depot) Serial() (*big.Int, error) { s = s.SetBytes(k) return nil }) - if err != nil { - return nil, err - } - return s, nil + return s, err } func (db *Depot) writeSerial(s *big.Int) error { @@ -156,7 +157,7 @@ func (db *Depot) hasKey(name []byte) bool { if bucket == nil { return fmt.Errorf("bucket %q not found!", certBucket) } - k := bucket.Get([]byte("serial")) + k := bucket.Get(name) if k != nil { present = true } @@ -166,15 +167,8 @@ func (db *Depot) hasKey(name []byte) bool { } func (db *Depot) incrementSerial(s *big.Int) error { - serial := s.Add(s, big.NewInt(1)) - err := db.Update(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte(certBucket)) - if bucket == nil { - return fmt.Errorf("bucket %q not found!", certBucket) - } - return bucket.Put([]byte("serial"), []byte(serial.Bytes())) - }) - return err + serial := new(big.Int).Add(s, big.NewInt(1)) + return db.writeSerial(serial) } func (db *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeOldCertificate bool) (bool, error) { @@ -185,8 +179,7 @@ func (db *Depot) HasCN(cn string, allowTime int, cert *x509.Certificate, revokeO } var hasCN bool err := db.View(func(tx *bolt.Tx) error { - // TODO: "scep_certificates" is internal const in micromdm/scep - curs := tx.Bucket([]byte("scep_certificates")).Cursor() + curs := tx.Bucket([]byte(certBucket)).Cursor() prefix := []byte(cert.Subject.CommonName) for k, v := curs.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = curs.Next() { if bytes.Compare(v, cert.Raw) == 0 { diff --git a/depot/bolt/depot_test.go b/depot/bolt/depot_test.go index e64dbed..34a9e94 100644 --- a/depot/bolt/depot_test.go +++ b/depot/bolt/depot_test.go @@ -100,7 +100,7 @@ func TestDepot_incrementSerial(t *testing.T) { if err := db.incrementSerial(tt.args); (err != nil) != tt.wantErr { t.Errorf("%q. Depot.incrementSerial() error = %v, wantErr %v", tt.name, err, tt.wantErr) } - got, _ := db.Serial() + got, _ := db.readSerial() if !reflect.DeepEqual(got, tt.want) { t.Errorf("%q. Depot.Serial() = %v, want %v", tt.name, got, tt.want) } diff --git a/depot/file/depot.go b/depot/file/depot.go index b843b95..0eb6d74 100644 --- a/depot/file/depot.go +++ b/depot/file/depot.go @@ -16,6 +16,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" ) @@ -31,7 +32,9 @@ func NewFileDepot(path string) (*fileDepot, error) { } type fileDepot struct { - dirPath string + dirPath string + serialMu sync.Mutex + dbMu sync.Mutex } func (d *fileDepot) CA(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) { @@ -75,10 +78,7 @@ func (d *fileDepot) Put(cn string, crt *x509.Certificate) error { return err } - serial, err := d.Serial() - if err != nil { - return err - } + serial := crt.SerialNumber if crt.Subject.CommonName == "" { // this means our cn was replaced by the certificate Signature @@ -103,14 +103,12 @@ func (d *fileDepot) Put(cn string, crt *x509.Certificate) error { return err } - if err := d.incrementSerial(serial); err != nil { - return err - } - return nil } func (d *fileDepot) Serial() (*big.Int, error) { + d.serialMu.Lock() + defer d.serialMu.Unlock() name := d.path("serial") s := big.NewInt(2) if err := d.check("serial"); err != nil { @@ -136,6 +134,9 @@ func (d *fileDepot) Serial() (*big.Int, error) { if !ok { return nil, errors.New("could not convert " + string(data) + " to serial number") } + if err := d.incrementSerial(serial); err != nil { + return serial, err + } return serial, nil } @@ -255,6 +256,8 @@ func (d *fileDepot) HasCN(_ string, allowTime int, cert *x509.Certificate, revok } func (d *fileDepot) writeDB(cn string, serial *big.Int, filename string, cert *x509.Certificate) error { + d.dbMu.Lock() + defer d.dbMu.Unlock() var dbEntry bytes.Buffer @@ -365,8 +368,9 @@ func (d *fileDepot) path(name string) string { } const ( - rsaPrivateKeyPEMBlockType = "RSA PRIVATE KEY" - certificatePEMBlockType = "CERTIFICATE" + rsaPrivateKeyPEMBlockType = "RSA PRIVATE KEY" + pkcs8PrivateKeyPEMBlockType = "PRIVATE KEY" + certificatePEMBlockType = "CERTIFICATE" ) // load an encrypted private key from disk @@ -375,15 +379,33 @@ func loadKey(data []byte, password []byte) (*rsa.PrivateKey, error) { if pemBlock == nil { return nil, errors.New("PEM decode failed") } - if pemBlock.Type != rsaPrivateKeyPEMBlockType { + switch pemBlock.Type { + case rsaPrivateKeyPEMBlockType: + if x509.IsEncryptedPEMBlock(pemBlock) { + b, err := x509.DecryptPEMBlock(pemBlock, password) + if err != nil { + return nil, err + } + return x509.ParsePKCS1PrivateKey(b) + } + return x509.ParsePKCS1PrivateKey(pemBlock.Bytes) + case pkcs8PrivateKeyPEMBlockType: + priv, err := x509.ParsePKCS8PrivateKey(pemBlock.Bytes) + if err != nil { + return nil, err + } + switch priv := priv.(type) { + case *rsa.PrivateKey: + return priv, nil + // case *dsa.PublicKey: + // case *ecdsa.PublicKey: + // case ed25519.PublicKey: + default: + panic("unsupported type of public key. SCEP need RSA private key") + } + default: return nil, errors.New("unmatched type or headers") } - - b, err := x509.DecryptPEMBlock(pemBlock, password) - if err != nil { - return nil, err - } - return x509.ParsePKCS1PrivateKey(b) } // load an encrypted private key from disk diff --git a/depot/signer.go b/depot/signer.go index 46c2f0f..3e3bdb5 100644 --- a/depot/signer.go +++ b/depot/signer.go @@ -3,7 +3,6 @@ package depot import ( "crypto/rand" "crypto/x509" - "sync" "time" "github.com/micromdm/scep/v2/cryptoutil" @@ -13,10 +12,11 @@ import ( // Signer signs x509 certificates and stores them in a Depot type Signer struct { depot Depot - mu sync.Mutex caPass string allowRenewalDays int validityDays int + serverAttrs bool + signatureAlgo x509.SignatureAlgorithm } // Option customizes Signer @@ -28,6 +28,7 @@ func NewSigner(depot Depot, opts ...Option) *Signer { depot: depot, allowRenewalDays: 14, validityDays: 365, + signatureAlgo: 0, } for _, opt := range opts { opt(s) @@ -35,6 +36,15 @@ func NewSigner(depot Depot, opts ...Option) *Signer { return s } +// WithSignatureAlgorithm sets the signature algorithm to be used to sign certificates. +// When set to a non-zero value, this would take preference over the default behaviour of +// matching the signing algorithm from the x509 CSR. +func WithSignatureAlgorithm(a x509.SignatureAlgorithm) Option { + return func(s *Signer) { + s.signatureAlgo = a + } +} + // WithCAPass specifies the password to use with an encrypted CA key func WithCAPass(pass string) Option { return func(s *Signer) { @@ -56,6 +66,12 @@ func WithValidityDays(v int) Option { } } +func WithSeverAttrs() Option { + return func(s *Signer) { + s.serverAttrs = true + } +} + // SignCSR signs a certificate using Signer's Depot CA func (s *Signer) SignCSR(m *scep.CSRReqMessage) (*x509.Certificate, error) { id, err := cryptoutil.GenerateSubjectKeyID(m.CSR.PublicKey) @@ -63,32 +79,39 @@ func (s *Signer) SignCSR(m *scep.CSRReqMessage) (*x509.Certificate, error) { return nil, err } - s.mu.Lock() - defer s.mu.Unlock() - serial, err := s.depot.Serial() if err != nil { return nil, err } + var signatureAlgo x509.SignatureAlgorithm + if s.signatureAlgo != 0 { + signatureAlgo = s.signatureAlgo + } + // create cert template tmpl := &x509.Certificate{ SerialNumber: serial, Subject: m.CSR.Subject, - NotBefore: time.Now().Add(-600).UTC(), + NotBefore: time.Now().Add(time.Second * -600).UTC(), NotAfter: time.Now().AddDate(0, 0, s.validityDays).UTC(), SubjectKeyId: id, KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{ x509.ExtKeyUsageClientAuth, }, - SignatureAlgorithm: m.CSR.SignatureAlgorithm, + SignatureAlgorithm: signatureAlgo, DNSNames: m.CSR.DNSNames, EmailAddresses: m.CSR.EmailAddresses, IPAddresses: m.CSR.IPAddresses, URIs: m.CSR.URIs, } + if s.serverAttrs { + tmpl.KeyUsage |= x509.KeyUsageDataEncipherment | x509.KeyUsageKeyEncipherment + tmpl.ExtKeyUsage = append(tmpl.ExtKeyUsage, x509.ExtKeyUsageServerAuth) + } + caCerts, caKey, err := s.depot.CA([]byte(s.caPass)) if err != nil { return nil, err diff --git a/go.mod b/go.mod index 0d81571..f732350 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,6 @@ require ( github.com/groob/finalizer v0.0.0-20170707115354-4c2ed49aabda github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 // indirect github.com/pkg/errors v0.8.0 - go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 - golang.org/x/net v0.0.0-20170726083632-f5079bd7f6f7 // indirect - golang.org/x/sys v0.0.0-20170728174421-0f826bdd13b5 // indirect + github.com/smallstep/pkcs7 v0.0.0-20231107075624-be1870d87d13 + golang.org/x/net v0.17.0 // indirect ) diff --git a/go.sum b/go.sum index 80502af..37afd10 100644 --- a/go.sum +++ b/go.sum @@ -16,9 +16,46 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 h1:CCriYyAfq1Br1aIYettdHZTy8mBTIPo7We18TuO/bak= -go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352/go.mod h1:SNgMg+EgDFwmvSmLRTNKC5fegJjB7v23qTQ0XLGUNHk= -golang.org/x/net v0.0.0-20170726083632-f5079bd7f6f7 h1:1Pw+ZX4dmGORIwGkTwnUr7RFuMhfpCYHXRZNF04XPYs= -golang.org/x/net v0.0.0-20170726083632-f5079bd7f6f7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/sys v0.0.0-20170728174421-0f826bdd13b5 h1:NAjcSWsnFBcOQGn/lxvHouhL7iPC53X8+znVzzQkAEg= -golang.org/x/sys v0.0.0-20170728174421-0f826bdd13b5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/smallstep/pkcs7 v0.0.0-20231107075624-be1870d87d13 h1:qRxEt9ESQhAg1kjmgJ8oyyzlc9zkAjOooe7bcKjKORQ= +github.com/smallstep/pkcs7 v0.0.0-20231107075624-be1870d87d13/go.mod h1:SoUAr/4M46rZ3WaLstHxGhLEgoYIDRqxQEXLOmOEB0Y= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/scep/scep.go b/scep/scep.go index 2c82453..8c81615 100644 --- a/scep/scep.go +++ b/scep/scep.go @@ -18,7 +18,7 @@ import ( "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/pkg/errors" - "go.mozilla.org/pkcs7" + "github.com/smallstep/pkcs7" ) // errors @@ -85,7 +85,6 @@ const ( // reasons: type FailInfo string -// const ( BadAlg FailInfo = "0" BadMessageCheck = "1" diff --git a/server/csrsigner.go b/server/csrsigner.go index 604776f..1ddfad0 100644 --- a/server/csrsigner.go +++ b/server/csrsigner.go @@ -1,6 +1,7 @@ package scepserver import ( + "context" "crypto/subtle" "crypto/x509" "errors" @@ -8,6 +9,22 @@ import ( "github.com/micromdm/scep/v2/scep" ) +// CSRSignerContext is a handler for signing CSRs by a CA/RA. +// +// SignCSRContext should take the CSR in the CSRReqMessage and return a +// Certificate signed by the CA. +type CSRSignerContext interface { + SignCSRContext(context.Context, *scep.CSRReqMessage) (*x509.Certificate, error) +} + +// CSRSignerContextFunc is an adapter for CSR signing by the CA/RA. +type CSRSignerContextFunc func(context.Context, *scep.CSRReqMessage) (*x509.Certificate, error) + +// SignCSR calls f(ctx, m). +func (f CSRSignerContextFunc) SignCSRContext(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { + return f(ctx, m) +} + // CSRSigner is a handler for CSR signing by the CA/RA // // SignCSR should take the CSR in the CSRReqMessage and return a @@ -16,29 +33,36 @@ type CSRSigner interface { SignCSR(*scep.CSRReqMessage) (*x509.Certificate, error) } -// CSRSignerFunc is an adapter for CSR signing by the CA/RA +// CSRSignerFunc is an adapter for CSR signing by the CA/RA. type CSRSignerFunc func(*scep.CSRReqMessage) (*x509.Certificate, error) -// SignCSR calls f(m) +// SignCSR calls f(m). func (f CSRSignerFunc) SignCSR(m *scep.CSRReqMessage) (*x509.Certificate, error) { return f(m) } -// NopCSRSigner does nothing -func NopCSRSigner() CSRSignerFunc { - return func(m *scep.CSRReqMessage) (*x509.Certificate, error) { +// NopCSRSigner does nothing. +func NopCSRSigner() CSRSignerContextFunc { + return func(_ context.Context, _ *scep.CSRReqMessage) (*x509.Certificate, error) { return nil, nil } } -// ChallengeMiddleware wraps next in a CSRSigner that validates the challenge from the CSR -func ChallengeMiddleware(challenge string, next CSRSigner) CSRSignerFunc { +// StaticChallengeMiddleware wraps next and validates the challenge from the CSR. +func StaticChallengeMiddleware(challenge string, next CSRSignerContext) CSRSignerContextFunc { challengeBytes := []byte(challenge) - return func(m *scep.CSRReqMessage) (*x509.Certificate, error) { + return func(ctx context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { // TODO: compare challenge only for PKCSReq? if subtle.ConstantTimeCompare(challengeBytes, []byte(m.ChallengePassword)) != 1 { return nil, errors.New("invalid challenge") } + return next.SignCSRContext(ctx, m) + } +} + +// SignCSRAdapter adapts a next (i.e. no context) to a context signer. +func SignCSRAdapter(next CSRSigner) CSRSignerContextFunc { + return func(_ context.Context, m *scep.CSRReqMessage) (*x509.Certificate, error) { return next.SignCSR(m) } } diff --git a/server/csrsigner_test.go b/server/csrsigner_test.go index b4c17bf..62696de 100644 --- a/server/csrsigner_test.go +++ b/server/csrsigner_test.go @@ -1,6 +1,7 @@ package scepserver import ( + "context" "testing" "github.com/micromdm/scep/v2/scep" @@ -8,18 +9,20 @@ import ( func TestChallengeMiddleware(t *testing.T) { testPW := "RIGHT" - signer := ChallengeMiddleware(testPW, NopCSRSigner()) + signer := StaticChallengeMiddleware(testPW, NopCSRSigner()) csrReq := &scep.CSRReqMessage{ChallengePassword: testPW} - _, err := signer.SignCSR(csrReq) + ctx := context.Background() + + _, err := signer.SignCSRContext(ctx, csrReq) if err != nil { t.Error(err) } csrReq.ChallengePassword = "WRONG" - _, err = signer.SignCSR(csrReq) + _, err = signer.SignCSRContext(ctx, csrReq) if err == nil { t.Error("invalid challenge should generate an error") } diff --git a/server/service.go b/server/service.go index 58ef85e..b20eb47 100644 --- a/server/service.go +++ b/server/service.go @@ -47,7 +47,7 @@ type service struct { // The (chainable) CSR signing function. Intended to handle all // SCEP request functionality such as CSR & challenge checking, CA // issuance, RA proxying, etc. - signer CSRSigner + signer CSRSignerContext /// info logging is implemented in the service middleware layer. debugLogger log.Logger @@ -80,7 +80,7 @@ func (svc *service) PKIOperation(ctx context.Context, data []byte) ([]byte, erro return nil, err } - crt, err := svc.signer.SignCSR(msg.CSRReqMessage) + crt, err := svc.signer.SignCSRContext(ctx, msg.CSRReqMessage) if err == nil && crt == nil { err = errors.New("no signed certificate") } @@ -119,7 +119,7 @@ func WithAddlCA(ca *x509.Certificate) ServiceOption { } // NewService creates a new scep service -func NewService(crt *x509.Certificate, key *rsa.PrivateKey, signer CSRSigner, opts ...ServiceOption) (Service, error) { +func NewService(crt *x509.Certificate, key *rsa.PrivateKey, signer CSRSignerContext, opts ...ServiceOption) (Service, error) { s := &service{ crt: crt, key: key, diff --git a/server/service_bolt_test.go b/server/service_bolt_test.go index 9bd40b0..74df9f0 100644 --- a/server/service_bolt_test.go +++ b/server/service_bolt_test.go @@ -46,7 +46,7 @@ func TestCaCert(t *testing.T) { caCert := certs[0] // SCEP service - svc, err := scepserver.NewService(caCert, key, scepdepot.NewSigner(depot)) + svc, err := scepserver.NewService(caCert, key, scepserver.SignCSRAdapter(scepdepot.NewSigner(depot))) if err != nil { t.Fatal(err) } @@ -131,6 +131,12 @@ func TestCaCert(t *testing.T) { t.Error("no established chain between issued cert and CA") } + if csr.SignatureAlgorithm != respCert.SignatureAlgorithm { + t.Fatal(fmt.Errorf("cert signature algo %s different from csr signature algo %s", + csr.SignatureAlgorithm.String(), + respCert.SignatureAlgorithm.String())) + } + // verify unique certificate serials for _, ser := range serCollector { if respCert.SerialNumber.Cmp(ser) == 0 {