Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the built in Go facility to verify certs with trusted CAs #60

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ func (ctx *SigningContext) constructSignedInfo(el *etree.Element, enveloped bool

dataId := el.SelectAttrValue(ctx.IdAttribute, "")
if dataId == "" {
return nil, errors.New("Missing data ID")
reference.CreateAttr(URIAttr, "")
} else {
reference.CreateAttr(URIAttr, "#"+dataId)
}

reference.CreateAttr(URIAttr, "#"+dataId)

// /SignedInfo/Reference/Transforms
transforms := ctx.createNamespacedElement(reference, TransformsTag)
Expand Down
11 changes: 0 additions & 11 deletions sign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,6 @@ func TestSignErrors(t *testing.T) {

_, err := ctx.SignEnveloped(authnRequest)
require.Error(t, err)

randomKeyStore = RandomKeyStoreForTest()
ctx = NewDefaultSigningContext(randomKeyStore)

authnRequest = &etree.Element{
Space: "samlp",
Tag: "AuthnRequest",
}

_, err = ctx.SignEnveloped(authnRequest)
require.Error(t, err)
}

func TestSignNonDefaultID(t *testing.T) {
Expand Down
60 changes: 48 additions & 12 deletions validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,17 @@ func (ctx *ValidationContext) verifySignedInfo(sig *types.Signature, canonicaliz
}

func (ctx *ValidationContext) validateSignature(el *etree.Element, sig *types.Signature, cert *x509.Certificate) (*etree.Element, error) {
idAttr := el.SelectAttr(ctx.IdAttribute)
if idAttr == nil || idAttr.Value == "" {
return nil, errors.New("Missing ID attribute")
idAttrEl := el.SelectAttr(ctx.IdAttribute)
idAttr := ""
if idAttrEl != nil {
idAttr = idAttrEl.Value
}

var ref *types.Reference

// Find the first reference which references the top-level element
for _, _ref := range sig.SignedInfo.References {
if _ref.URI == "" || _ref.URI[1:] == idAttr.Value {
if _ref.URI == "" || _ref.URI[1:] == idAttr {
ref = &_ref
}
}
Expand Down Expand Up @@ -298,9 +299,10 @@ func contains(roots []*x509.Certificate, cert *x509.Certificate) bool {

// findSignature searches for a Signature element referencing the passed root element.
func (ctx *ValidationContext) findSignature(el *etree.Element) (*types.Signature, error) {
idAttr := el.SelectAttr(ctx.IdAttribute)
if idAttr == nil || idAttr.Value == "" {
return nil, errors.New("Missing ID attribute")
idAttrEl := el.SelectAttr(ctx.IdAttribute)
idAttr := ""
if idAttrEl != nil {
idAttr = idAttrEl.Value
}

var sig *types.Signature
Expand Down Expand Up @@ -380,7 +382,7 @@ func (ctx *ValidationContext) findSignature(el *etree.Element) (*types.Signature
// Traverse references in the signature to determine whether it has at least
// one reference to the top level element. If so, conclude the search.
for _, ref := range _sig.SignedInfo.References {
if ref.URI == "" || ref.URI[1:] == idAttr.Value {
if ref.URI == "" || ref.URI[1:] == idAttr {
sig = _sig
return etreeutils.ErrTraversalHalted
}
Expand All @@ -400,7 +402,7 @@ func (ctx *ValidationContext) findSignature(el *etree.Element) (*types.Signature
return sig, nil
}

func (ctx *ValidationContext) verifyCertificate(sig *types.Signature) (*x509.Certificate, error) {
func (ctx *ValidationContext) verifyCertificate(sig *types.Signature, verify bool) (*x509.Certificate, error) {
now := ctx.Clock.Now()

roots, err := ctx.CertificateStore.Certificates()
Expand Down Expand Up @@ -436,8 +438,24 @@ func (ctx *ValidationContext) verifyCertificate(sig *types.Signature) (*x509.Cer
}

// Verify that the certificate is one we trust
if !contains(roots, cert) {
return nil, errors.New("Could not verify certificate against trusted certs")
if verify {
pool := x509.NewCertPool()
for _, c := range roots {
pool.AddCert(c)
}
opts := x509.VerifyOptions{
Roots: pool,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageCodeSigning},
}

_, err := cert.Verify(opts)
if err != nil {
return nil, err
}
} else {
if !contains(roots, cert) {
return nil, errors.New("Could not verify certificate against trusted certs")
}
}

if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
Expand All @@ -458,7 +476,25 @@ func (ctx *ValidationContext) Validate(el *etree.Element) (*etree.Element, error
return nil, err
}

cert, err := ctx.verifyCertificate(sig)
cert, err := ctx.verifyCertificate(sig, false)
if err != nil {
return nil, err
}

return ctx.validateSignature(el, sig, cert)
}

// ValidateWithRootTrust does the same as Verify except it actually verifies the root CA is trusted as well
func (ctx *ValidationContext) ValidateWithRootTrust(el *etree.Element) (*etree.Element, error) {
// Make a copy of the element to avoid mutating the one we were passed.
el = el.Copy()

sig, err := ctx.findSignature(el)
if err != nil {
return nil, err
}

cert, err := ctx.verifyCertificate(sig, true)
if err != nil {
return nil, err
}
Expand Down