Skip to content

Commit

Permalink
enforce certificate correctness in TBSCertificate.SignWith
Browse files Browse the repository at this point in the history
  • Loading branch information
JackDoanRivian committed Nov 5, 2024
1 parent 50850ee commit c096f5d
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 12 deletions.
4 changes: 4 additions & 0 deletions cert/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string
after = time.Now().Add(time.Second * 60).Round(time.Second)
}

if len(networks) == 0 {
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
}

var pub, priv []byte
switch curve {
case Curve_CURVE25519:
Expand Down
98 changes: 91 additions & 7 deletions cert/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,89 @@ func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Cert
}
}

// readyToSign checks all signing requirements that don't require us to cross-reference with a CA
func (t *TBSCertificate) readyToSign() error {
if t.PublicKey == nil {
return fmt.Errorf("public key not set")
}

if !t.IsCA && len(t.Networks) == 0 {
return fmt.Errorf("non-CA certificates must contain at least one network")
}

hasV4Networks := false
hasV6Networks := false
for _, n := range t.Networks {
if !n.IsValid() || !n.Addr().IsValid() {
return fmt.Errorf("invalid network: %s", n)
}
if t.Version == Version1 && n.Addr().Is6() {
return fmt.Errorf("certificate v1 may not contain IPv6 networks: %v", t.Networks)
}
if n.Addr().Zone() != "" {
return fmt.Errorf("networks may not contain zones: %s", n)
}
if n.Addr().Is4In6() {
return fmt.Errorf("4in6 networks are not allowed: %s", n)
}
if !t.IsCA && n.Addr().IsUnspecified() {
return fmt.Errorf("non-CA certificates must not use the zero address as a network: %s", n)
}

hasV4Networks = hasV4Networks || n.Addr().Is4()
hasV6Networks = hasV6Networks || n.Addr().Is6()
}

slices.SortFunc(t.Networks, comparePrefix)
err := findDuplicatePrefix(t.Networks)
if err != nil {
return err
}

for _, n := range t.UnsafeNetworks {
if !n.IsValid() || !n.Addr().IsValid() {
return fmt.Errorf("invalid unsafe network: %s", n)
}
if n.Addr().Zone() != "" {
return fmt.Errorf("unsafe_networks may not contain zones: %s", n)
}
//todo are unsafe networks that overlap networks allowed?

if n.Addr().Is6() {
if t.Version == Version1 {
return fmt.Errorf("certificate v1 may not contain IPv6 unsafe networks: %v", t.Networks)
}
if !hasV6Networks && !t.IsCA {
return fmt.Errorf("IPv6 unsafe networks require an IPv6 address assignment")
}
} else if n.Addr().Is4() {
if !hasV4Networks && !t.IsCA {
return fmt.Errorf("IPv4 unsafe networks require an IPv4 address assignment")
}
}
}

slices.SortFunc(t.UnsafeNetworks, comparePrefix)
err = findDuplicatePrefix(t.UnsafeNetworks)
if err != nil {
return err
}

return nil
}

// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature.
// You should only use SignWith if you do not have direct access to your private key.
func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) {
if curve != t.Curve {
return nil, fmt.Errorf("curve in cert and private key supplied don't match")
}

//TODO: make sure we have all minimum properties to sign, like a public key
//TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs
//readyToSign sorts Networks and UnsafeNetworks for us
err := t.readyToSign()
if err != nil {
return nil, err
}

if signer != nil {
if t.IsCA {
Expand All @@ -107,20 +181,17 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
}
}

slices.SortFunc(t.Networks, comparePrefix)
slices.SortFunc(t.UnsafeNetworks, comparePrefix)

var c beingSignedCertificate
switch t.Version {
case Version1:
c = &certificateV1{}
err := c.fromTBSCertificate(t)
err = c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
case Version2:
c = &certificateV2{}
err := c.fromTBSCertificate(t)
err = c.fromTBSCertificate(t)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -158,3 +229,16 @@ func comparePrefix(a, b netip.Prefix) int {
}
return addr
}

// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
if len(sortedPrefixes) < 2 {
return nil
}
for i := 1; i < len(sortedPrefixes); i++ {
if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
return fmt.Errorf("duplicate network detected: %v", sortedPrefixes[i])
}
}
return nil
}
20 changes: 15 additions & 5 deletions cmd/nebula-cert/print_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func Test_printCert(t *testing.T) {
tf.Truncate(0)
tf.Seek(0, 0)
ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, []string{"hi"})
c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})

p, _ := c.MarshalPEM()
tf.Write(p)
Expand All @@ -97,7 +97,9 @@ func Test_printCert(t *testing.T) {
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
Expand All @@ -116,7 +118,9 @@ func Test_printCert(t *testing.T) {
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
Expand All @@ -135,7 +139,9 @@ func Test_printCert(t *testing.T) {
"isCa": false,
"issuer": "`+c.Issuer()+`",
"name": "test",
"networks": [],
"networks": [
"10.0.0.123/8"
],
"notAfter": "0001-01-01T00:00:00Z",
"notBefore": "0001-01-01T00:00:00Z",
"publicKey": "`+pk+`",
Expand Down Expand Up @@ -166,7 +172,7 @@ func Test_printCert(t *testing.T) {
assert.Nil(t, err)
assert.Equal(
t,
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
`,
ob.String(),
)
Expand Down Expand Up @@ -212,6 +218,10 @@ func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, aft
after = ca.NotAfter()
}

if len(networks) == 0 {
networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
}

pub, rawPriv := x25519Keypair()
nc := &cert.TBSCertificate{
Version: cert.Version1,
Expand Down

0 comments on commit c096f5d

Please sign in to comment.