Skip to content

Commit

Permalink
[v16] Reuse existing SAMLConnector signing key when possible (#44666)
Browse files Browse the repository at this point in the history
* Reuse existing SAMLConnector signing key when possible

* lint
  • Loading branch information
hugoShaka authored Jul 29, 2024
1 parent 61f6c2d commit 3ac9d95
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 0 deletions.
27 changes: 27 additions & 0 deletions lib/auth/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ func (a *Server) UpsertSAMLConnector(ctx context.Context, connector types.SAMLCo
return nil, trace.Wrap(err)
}

// If someone is applying a SAML Connector obtained with `tctl get` without secrets, the signing key pair is
// not empty (cert is set) but the private key is missing. Such a SAML resource is invalid and not usable.
if connector.GetSigningKeyPair().PrivateKey == "" {
err := services.FillSAMLSigningKeyFromExisting(ctx, connector, a.Services)
if err != nil {
return nil, trace.Wrap(err)
}
}

upserted, err := a.Services.UpsertSAMLConnector(ctx, connector)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -94,6 +103,17 @@ func (a *Server) UpdateSAMLConnector(ctx context.Context, connector types.SAMLCo
return nil, trace.Wrap(err)
}

// If someone is applying a SAML Connector obtained with `tctl get` without secrets, the signing key pair is
// not empty (cert is set) but the private key is missing. In this case we want to look up the existing SAML
// connector and populate the singing key from it if it's the same certificate. This avoids accidentally clearing
// the private key and creating an unusable connector.
if connector.GetSigningKeyPair().PrivateKey == "" {
err := services.FillSAMLSigningKeyFromExisting(ctx, connector, a.Services)
if err != nil {
return nil, trace.Wrap(err)
}
}

updated, err := a.Services.UpdateSAMLConnector(ctx, connector)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -129,6 +149,13 @@ func (a *Server) CreateSAMLConnector(ctx context.Context, connector types.SAMLCo
return nil, trace.Wrap(err)
}

// If someone is applying a SAML Connector obtained with `tctl get` without secrets, the signing key pair is
// not empty (cert is set) but the private key is missing. This SAML Connector is invalid, we must reject it
// with an actionable message.
if connector.GetSigningKeyPair().PrivateKey == "" {
return nil, trace.BadParameter("Missing private key for signing connector. " + services.ErrMsgHowToFixMissingPrivateKey)
}

created, err := a.Services.CreateSAMLConnector(ctx, connector)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
26 changes: 26 additions & 0 deletions lib/services/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ import (
"github.com/gravitational/teleport/lib/utils"
)

type SAMLConnectorGetter interface {
GetSAMLConnector(ctx context.Context, id string, withSecrets bool) (types.SAMLConnector, error)
}

const ErrMsgHowToFixMissingPrivateKey = "You must either specify the singing key pair (obtain the existing one with `tctl get saml --with-secrets`) or let Teleport generate a new one (remove singing_key_pair in the resource you're trying to create)."

// ValidateSAMLConnector validates the SAMLConnector and sets default values.
// If a remote to fetch roles is specified, roles will be validated to exist.
func ValidateSAMLConnector(sc types.SAMLConnector, rg RoleGetter) error {
Expand Down Expand Up @@ -330,3 +336,23 @@ func MarshalSAMLConnector(samlConnector types.SAMLConnector, opts ...MarshalOpti
return nil, trace.BadParameter("unrecognized SAML connector version %T", samlConnector)
}
}

// FillSAMLSigningKeyFromExisting looks up the existing SAML connector and populates the signing key if it's missing.
// This must be called only if the SAML Connector signing key pair has been initialized (ValidateSAMLConnector) and
// the private key is still empty.
func FillSAMLSigningKeyFromExisting(ctx context.Context, connector types.SAMLConnector, sg SAMLConnectorGetter) error {
existing, err := sg.GetSAMLConnector(ctx, connector.GetName(), true /* with secrets */)
switch {
case trace.IsNotFound(err):
return trace.BadParameter("failed to create SAML connector, the SAML connector has no signing key set. " + ErrMsgHowToFixMissingPrivateKey)
case err != nil:
return trace.BadParameter("failed to update SAML connector, the SAML connector has no signing key set and looking up the existing connector failed with the error: %s. %s", err.Error(), ErrMsgHowToFixMissingPrivateKey)
}

existingSkp := existing.GetSigningKeyPair()
if existingSkp == nil || existingSkp.Cert != connector.GetSigningKeyPair().Cert {
return trace.BadParameter("failed to update the SAML connector, the SAML connector has no signing key and its signing certificate does not match the existing one. " + ErrMsgHowToFixMissingPrivateKey)
}
connector.SetSigningKeyPair(existingSkp)
return nil
}
112 changes: 112 additions & 0 deletions lib/services/saml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ package services

import (
"context"
"crypto/x509/pkix"
"strings"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
Expand All @@ -30,6 +32,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/utils"
)

func TestParseFromMetadata(t *testing.T) {
Expand Down Expand Up @@ -127,6 +130,115 @@ func TestValidateRoles(t *testing.T) {
}
}

type mockSAMLGetter map[string]types.SAMLConnector

func (m mockSAMLGetter) GetSAMLConnector(_ context.Context, id string, withSecrets bool) (types.SAMLConnector, error) {
connector, ok := m[id]
if !ok {
return nil, trace.NotFound("%s not found", id)
}
return connector, nil
}

func TestFillSAMLSigningKeyFromExisting(t *testing.T) {
t.Parallel()

// Test setup: generate the fixtures
existingKeyPEM, existingCertPEM, err := utils.GenerateSelfSignedSigningCert(pkix.Name{
Organization: []string{"Teleport OSS"},
CommonName: "teleport.localhost.localdomain",
}, nil, 10*365*24*time.Hour)
require.NoError(t, err)

existingSkp := &types.AsymmetricKeyPair{
PrivateKey: string(existingKeyPEM),
Cert: string(existingCertPEM),
}

existingConnectorName := "existing"
existingConnectors := mockSAMLGetter{
existingConnectorName: &types.SAMLConnectorV2{
Spec: types.SAMLConnectorSpecV2{
SigningKeyPair: existingSkp,
},
},
}

_, unrelatedCertPEM, err := utils.GenerateSelfSignedSigningCert(pkix.Name{
Organization: []string{"Teleport OSS"},
CommonName: "teleport.localhost.localdomain",
}, nil, 10*365*24*time.Hour)
require.NoError(t, err)

// Test setup: define test cases
testCases := []struct {
name string
connectorName string
connectorSpec types.SAMLConnectorSpecV2
assertErr require.ErrorAssertionFunc
assertResult require.ValueAssertionFunc
}{
{
name: "should read singing key from existing connector with matching cert",
connectorName: existingConnectorName,
connectorSpec: types.SAMLConnectorSpecV2{
SigningKeyPair: &types.AsymmetricKeyPair{
PrivateKey: "",
Cert: string(existingCertPEM),
},
},
assertErr: require.NoError,
assertResult: func(t require.TestingT, value interface{}, args ...interface{}) {
require.Implements(t, (*types.SAMLConnector)(nil), value)
connector := value.(types.SAMLConnector)
skp := connector.GetSigningKeyPair()
require.Equal(t, existingSkp, skp)
},
},
{
name: "should error when there's no existing connector",
connectorName: "non-existing",
connectorSpec: types.SAMLConnectorSpecV2{
SigningKeyPair: &types.AsymmetricKeyPair{
PrivateKey: "",
Cert: string(unrelatedCertPEM),
},
},
assertErr: require.Error,
},
{
name: "should error when existing connector cert is not matching",
connectorName: existingConnectorName,
connectorSpec: types.SAMLConnectorSpecV2{
SigningKeyPair: &types.AsymmetricKeyPair{
PrivateKey: "",
Cert: string(unrelatedCertPEM),
},
},
assertErr: require.Error,
},
}

// Test execution
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
connector := &types.SAMLConnectorV2{
Metadata: types.Metadata{
Name: tc.connectorName,
},
Spec: tc.connectorSpec,
}

err := FillSAMLSigningKeyFromExisting(ctx, connector, existingConnectors)
tc.assertErr(t, err)
if tc.assertResult != nil {
tc.assertResult(t, connector)
}
})
}
}

// roleSet is a basic set of roles keyed by role name. It implements the
// RoleGetter interface, returning the role if it exists, or a trace.NotFound
// error if it does not exist.
Expand Down

0 comments on commit 3ac9d95

Please sign in to comment.