Skip to content

Commit

Permalink
cacert san field changes (#205)
Browse files Browse the repository at this point in the history
* cacert san field changes

* Update NewCertSign call in server module

* remove explicit nil sets

Signed-off-by: Henry Avetisyan <[email protected]>

---------

Signed-off-by: Henry Avetisyan <[email protected]>
Co-authored-by: Henry Avetisyan <[email protected]>
Co-authored-by: Henry Avetisyan <[email protected]>
  • Loading branch information
3 people authored Apr 14, 2023
1 parent 97b4839 commit 497f518
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 17 deletions.
38 changes: 31 additions & 7 deletions cmd/gen-cacert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"log"
"net"
"net/url"
"os"
"time"

Expand All @@ -38,6 +39,9 @@ const (

var cfg string
var caOutPath string
var skipHostname bool
var skipIPs bool
var uri string

func getIPs() (ips []net.IP, err error) {
ifaces, err := net.Interfaces()
Expand Down Expand Up @@ -67,6 +71,10 @@ func getIPs() (ips []net.IP, err error) {
func main() {
flag.StringVar(&cfg, "config", "", "CA cert configuration file")
flag.StringVar(&caOutPath, "out", defaultCAOutPath, "the output path of the generated CA cert")
flag.BoolVar(&skipHostname, "skip-hostname", false, "skip including dnsName attribute in CA cert")
flag.BoolVar(&skipIPs, "skip-ips", false, "skip including IP attribute in CA cert")
flag.StringVar(&uri, "uri", "", "URI value to include in CA cert SAN")

flag.Parse()
log.SetFlags(log.LstdFlags | log.LUTC | log.Lmicroseconds | log.Lshortfile)

Expand All @@ -85,14 +93,30 @@ func main() {
if err := json.Unmarshal(cfgData, cc); err != nil {
log.Fatal(err)
}
hostname, err := os.Hostname()
if err != nil {
log.Fatal(err)

hostname := ""
if !skipHostname {
hostname, err = os.Hostname()
if err != nil {
log.Fatal(err)
}
}

ips, err := getIPs()
if err != nil {
log.Fatal(err)
var ips []net.IP
if !skipIPs {
ips, err = getIPs()
if err != nil {
log.Fatal(err)
}
}

var uris []*url.URL
if uri != "" {
parsedUri, err := url.Parse(uri)
if err != nil {
log.Fatal(err)
}
uris = []*url.URL{parsedUri}
}

ctx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -124,7 +148,7 @@ func main() {
OrganizationalUnit: cc.OrganizationalUnit,
CommonName: cc.CommonName,
ValidityPeriod: cc.ValidityPeriod,
}}, requireX509CACert, hostname, ips, config.DefaultPKCS11Timeout)
}}, requireX509CACert, hostname, ips, uris, config.DefaultPKCS11Timeout)
if err != nil {
log.Fatalf("unable to initialize cert signer: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/sign-x509cert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func main() {
SessionPoolSize: 2,
X509CACertLocation: caPath,
CreateCACertIfNotExist: false,
}}, requireX509CACert, "", nil, config.DefaultPKCS11Timeout) // Hostname and ips should not be needed as CreateCACertIfNotExist is set to be false.
}}, requireX509CACert, "", nil, nil, config.DefaultPKCS11Timeout) // Hostname and ips should not be needed as CreateCACertIfNotExist is set to be false.

if err != nil {
log.Fatalf("unable to initialize cert signer: %v", err)
Expand Down
9 changes: 5 additions & 4 deletions pkcs11/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"log"
"net"
"net/url"
"os"
"time"

Expand Down Expand Up @@ -120,7 +121,7 @@ func getSignerData(ctx context.Context, requestChan chan scheduler.Request, pool
}

// NewCertSign initializes a CertSign object that interacts with PKCS11 compliant device.
func NewCertSign(ctx context.Context, pkcs11ModulePath string, keys []config.KeyConfig, requireX509CACert map[string]bool, hostname string, ips []net.IP, requestTimeout uint) (crypki.CertSign, error) {
func NewCertSign(ctx context.Context, pkcs11ModulePath string, keys []config.KeyConfig, requireX509CACert map[string]bool, hostname string, ips []net.IP, uris []*url.URL, requestTimeout uint) (crypki.CertSign, error) {
p11ctx, err := initPKCS11Context(pkcs11ModulePath)
if err != nil {
return nil, fmt.Errorf("unable to initialize PKCS11 context: %v", err)
Expand Down Expand Up @@ -159,7 +160,7 @@ func NewCertSign(ctx context.Context, pkcs11ModulePath string, keys []config.Key
s.sPool[key.Identifier] = pool
// initialize x509 CA cert if this key will be used for signing x509 certs.
if requireX509CACert[key.Identifier] {
cert, err := getX509CACert(ctx, key, pool, hostname, ips)
cert, err := getX509CACert(ctx, key, pool, hostname, ips, uris)
if err != nil {
log.Fatalf("failed to get x509 CA cert for key with identifier %q, err: %v", key.Identifier, err)
}
Expand Down Expand Up @@ -257,7 +258,7 @@ func (s *signer) SignBlob(ctx context.Context, reqChan chan scheduler.Request, d
// getX509CACert reads and returns x509 CA certificate from X509CACertLocation.
// If the certificate is not valid, and CreateCACertIfNotExist is true, a new CA
// certificate will be generated based on the config, and wrote to X509CACertLocation.
func getX509CACert(ctx context.Context, key config.KeyConfig, pool sPool, hostname string, ips []net.IP) (*x509.Certificate, error) {
func getX509CACert(ctx context.Context, key config.KeyConfig, pool sPool, hostname string, ips []net.IP, uris []*url.URL) (*x509.Certificate, error) {
// Try parse certificate in the given location.
if certBytes, err := os.ReadFile(key.X509CACertLocation); err == nil {
block, _ := pem.Decode(certBytes)
Expand Down Expand Up @@ -293,7 +294,7 @@ func getX509CACert(ctx context.Context, key config.KeyConfig, pool sPool, hostna
}
caConfig.LoadDefaults()

out, err := x509cert.GenCACert(caConfig, signer, hostname, ips, signer.publicKeyAlgorithm(), signer.signAlgorithm())
out, err := x509cert.GenCACert(caConfig, signer, hostname, ips, uris, signer.publicKeyAlgorithm(), signer.signAlgorithm())
if err != nil {
return nil, fmt.Errorf("unable to generate x509 CA certificate: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func Main() {
log.Fatal(err)
}

signer, err := pkcs11.NewCertSign(ctx, cfg.ModulePath, cfg.Keys, keyUsages[config.X509CertEndpoint], hostname, ips, cfg.PKCS11RequestTimeout)
signer, err := pkcs11.NewCertSign(ctx, cfg.ModulePath, cfg.Keys, keyUsages[config.X509CertEndpoint], hostname, ips, nil, cfg.PKCS11RequestTimeout)
if err != nil {
log.Fatalf("unable to initialize cert signer: %v", err)
}
Expand Down
11 changes: 8 additions & 3 deletions x509cert/x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@ import (
"fmt"
"math/big"
"net"
"net/url"
"time"

"github.com/theparanoids/crypki"
)

// GenCACert creates the CA certificate given signer.
func GenCACert(config *crypki.CAConfig, signer crypto.Signer, hostname string, ips []net.IP, pka x509.PublicKeyAlgorithm, sa x509.SignatureAlgorithm) ([]byte, error) {
func GenCACert(config *crypki.CAConfig, signer crypto.Signer, hostname string, ips []net.IP, uris []*url.URL, pka x509.PublicKeyAlgorithm, sa x509.SignatureAlgorithm) ([]byte, error) {
// Backdate start time by one hour as the current system clock may be ahead of other running systems.
start := uint64(time.Now().Unix())
end := start + config.ValidityPeriod
start -= 3600
var country, locality, province, org, orgUnit []string
var country, locality, province, org, orgUnit, dnsNames []string
if config.Country != "" {
country = []string{config.Country}
}
Expand All @@ -39,6 +40,9 @@ func GenCACert(config *crypki.CAConfig, signer crypto.Signer, hostname string, i
if config.OrganizationalUnit != "" {
orgUnit = []string{config.OrganizationalUnit}
}
if hostname != "" {
dnsNames = []string{hostname}
}

subj := pkix.Name{
CommonName: config.CommonName,
Expand All @@ -56,8 +60,9 @@ func GenCACert(config *crypki.CAConfig, signer crypto.Signer, hostname string, i
SignatureAlgorithm: sa,
NotBefore: time.Unix(int64(start), 0),
NotAfter: time.Unix(int64(end), 0),
DNSNames: []string{hostname},
DNSNames: dnsNames,
IPAddresses: ips,
URIs: uris,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
Expand Down
43 changes: 42 additions & 1 deletion x509cert/x509_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/x509/pkix"
"encoding/pem"
"net"
"net/url"
"reflect"
"testing"

Expand All @@ -25,11 +26,15 @@ func TestGenCACert(t *testing.T) {
if err != nil {
t.Fatal(err)
}
spiffeUri, _ := url.Parse("spiffe://paranoids/crypki")
uris := []*url.URL{spiffeUri}

tests := map[string]struct {
cfg *crypki.CAConfig
signer crypto.Signer
hostname string
ips []net.IP
uris []*url.URL
pka x509.PublicKeyAlgorithm
sa x509.SignatureAlgorithm
wantSubj pkix.Name
Expand Down Expand Up @@ -57,6 +62,28 @@ func TestGenCACert(t *testing.T) {
OrganizationalUnit: []string{"Foo Org Unit"},
},
},
"no-hostname-with-uri": {
cfg: &crypki.CAConfig{
Country: "US",
Locality: "Sunnyvale",
State: "CA",
Organization: "Foo Org",
OrganizationalUnit: "Foo Org Unit",
CommonName: "foo.example.com",
},
signer: eckey,
uris: uris,
pka: pka,
sa: sa,
wantSubj: pkix.Name{
CommonName: "foo.example.com",
Country: []string{"US"},
Locality: []string{"Sunnyvale"},
Province: []string{"CA"},
Organization: []string{"Foo Org"},
OrganizationalUnit: []string{"Foo Org Unit"},
},
},
"no-ST": {
cfg: &crypki.CAConfig{
Country: "US",
Expand Down Expand Up @@ -123,7 +150,7 @@ func TestGenCACert(t *testing.T) {
name, tt := name, tt
t.Run(name, func(t *testing.T) {
t.Parallel()
got, err := GenCACert(tt.cfg, tt.signer, tt.hostname, tt.ips, tt.pka, tt.sa)
got, err := GenCACert(tt.cfg, tt.signer, tt.hostname, tt.ips, tt.uris, tt.pka, tt.sa)
if err != nil {
if !tt.expectError {
t.Error("unexpected error")
Expand All @@ -145,6 +172,20 @@ func TestGenCACert(t *testing.T) {
if !reflect.DeepEqual(cert.Subject, tt.wantSubj) {
t.Errorf("subject mismatch:\n got: \n%+v\n want: \n%+v\n", cert.Subject, tt.wantSubj)
}
if len(tt.uris) > 0 {
if !reflect.DeepEqual(cert.URIs, tt.uris) {
t.Errorf("uri mismatch: %+v\n", cert.URIs)
}
}
if tt.hostname != "" {
if tt.hostname != cert.DNSNames[0] {
t.Errorf("dnsName mismatch: got:%s want: %s\n", cert.DNSNames[0], tt.hostname)
}
} else {
if len(cert.DNSNames) > 0 {
t.Errorf("unexpected dnsName values: %s\n", cert.DNSNames[0])
}
}
})
}

Expand Down

0 comments on commit 497f518

Please sign in to comment.