Skip to content

Commit

Permalink
move acs_url and relay_state input validation to rpc create and updat…
Browse files Browse the repository at this point in the history
…e methods
  • Loading branch information
flyinghermit committed Sep 22, 2024
1 parent 15fcdf4 commit 191a2c2
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 238 deletions.
18 changes: 18 additions & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -6807,6 +6807,16 @@ func (a *ServerWithRoles) CreateSAMLIdPServiceProvider(ctx context.Context, sp t
return trace.Wrap(err)
}

if err := services.ValidateSAMLIdPACSURLAndRelayStateInputs(sp); err != nil {
return trace.Wrap(err)
}

if sp.GetEntityDescriptor() != "" {
if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputStrictFilter); err != nil {
return trace.Wrap(err)
}
}

err = a.authServer.CreateSAMLIdPServiceProvider(ctx, sp)
return trace.Wrap(err)
}
Expand Down Expand Up @@ -6846,6 +6856,14 @@ func (a *ServerWithRoles) UpdateSAMLIdPServiceProvider(ctx context.Context, sp t
return trace.Wrap(err)
}

if err := services.ValidateSAMLIdPACSURLAndRelayStateInputs(sp); err != nil {
return trace.Wrap(err)
}

if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputStrictFilter); err != nil {
return trace.Wrap(err)
}

err = a.authServer.UpdateSAMLIdPServiceProvider(ctx, sp)
return trace.Wrap(err)
}
Expand Down
183 changes: 183 additions & 0 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6000,6 +6000,189 @@ func TestUpdateSAMLIdPServiceProvider(t *testing.T) {
}
}

func TestCreateSAMLIdPServiceProviderInvalidInputs(t *testing.T) {
ctx := context.Background()
srv := newTestTLSServer(t)
user, _ := createSAMLIdPTestUsers(t, srv.Auth())
client, err := srv.NewClient(TestUser(user))
require.NoError(t, err)

tests := []struct {
name string
entityDescriptor string
entityID string
acsURL string
relayState string
errAssertion require.ErrorAssertionFunc
}{
{
name: "missing url scheme in acs input",
entityID: "sp",
acsURL: "sp",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid scheme")
},
},
{
name: "missing url scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("sp", "sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid url scheme")
},
},
{
name: "http url scheme in acs",
entityID: "sp",
acsURL: "http://sp",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid scheme")
},
},
{
name: "http url scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("sp", "http://sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported ACS bindings")
},
},
{
name: "unsupported scheme in acs",
entityID: "sp",
acsURL: "gopher://sp",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid scheme")
},
},
{
name: "unsupported scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("sp", "gopher://sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid url scheme")
},
},
{
name: "invalid character in acs",
entityID: "sp",
acsURL: "https://sp>",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported character")
},
},
{
name: "invalid character in acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("sp", "https://sp>"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported ACS bindings")
},
},
{
name: "invalid character in relay state",
entityID: "sp",
acsURL: "https://sp",
relayState: "default_state<b",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported character")
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{
Name: "test",
}, types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: test.entityDescriptor,
EntityID: test.entityID,
ACSURL: test.acsURL,
RelayState: test.relayState,
})
require.NoError(t, err)

err = client.CreateSAMLIdPServiceProvider(ctx, sp)
test.errAssertion(t, err)
})
}
}

func TestUpdateSAMLIdPServiceProviderInvalidInputs(t *testing.T) {
ctx := context.Background()
srv := newTestTLSServer(t)
user, _ := createSAMLIdPTestUsers(t, srv.Auth())
client, err := srv.NewClient(TestUser(user))
require.NoError(t, err)

sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{
Name: "sp",
}, types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "https://sp"),
})
require.NoError(t, err)

err = client.CreateSAMLIdPServiceProvider(ctx, sp)
require.NoError(t, err)

tests := []struct {
name string
entityDescriptor string
entityID string
acsURL string
relayState string
errAssertion require.ErrorAssertionFunc
}{
{
name: "missing url scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid url scheme")
},
},
{
name: "http url scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "http://sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported ACS bindings")
},
},
{
name: "unsupported scheme for acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "gopher://sp"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid url scheme")
},
},
{
name: "invalid character in acs in ed",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "https://sp>"),
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported ACS bindings")
},
},
{
name: "invalid character in relay state",
entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "https://sp"),
relayState: "default_state<b",
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "unsupported character")
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{
Name: "sp",
}, types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: test.entityDescriptor,
RelayState: test.relayState,
})
require.NoError(t, err)

err = client.UpdateSAMLIdPServiceProvider(ctx, sp)
test.errAssertion(t, err)
})
}
}

func TestDeleteSAMLIdPServiceProvider(t *testing.T) {
ctx := context.Background()
srv := newTestTLSServer(t)
Expand Down
36 changes: 12 additions & 24 deletions lib/services/local/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co
// CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
if err := services.ValidateSAMLIdPACSURLAndRelayStateInputs(sp); err != nil {
return trace.Wrap(err)
// logging instead of returning an error cause we do not want to break cache writes on a cluster
// that already has a service provider with unsupported characters/scheme in the acs_url or relay_state.
s.log.Warn(err)
}

if sp.GetEntityDescriptor() == "" {
if err := s.configureEntityDescriptorPerPreset(sp); err != nil {
errMsg := fmt.Errorf("failed to configure entity descriptor with the given entity_id %q and acs_url %q: %w",
Expand All @@ -126,7 +127,9 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
}
}

if err := s.validateEntityDescriptor(sp); err != nil {
// we only verify if the entity ID field in the spec matches with the entity descriptor.
// filtering is done only for logging purpose.
if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputPermissiveFilter); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -157,10 +160,14 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
// UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
if err := services.ValidateSAMLIdPACSURLAndRelayStateInputs(sp); err != nil {
return trace.Wrap(err)
// logging instead of returning an error cause we do not want to break cache writes on a cluster
// that already has a service provider with unsupported characters/scheme in the acs_url or relay_state.
s.log.Warn(err)
}

if err := s.validateEntityDescriptor(sp); err != nil {
// we only verify if the entity ID field in the spec matches with the entity descriptor.
// filtering is done only for logging purpose.
if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputPermissiveFilter); err != nil {
return trace.Wrap(err)
}

Expand Down Expand Up @@ -346,25 +353,6 @@ func (s *SAMLIdPServiceProviderService) embedAttributeMapping(sp types.SAMLIdPSe
return nil
}

// validateEntityDescriptor validates entity descriptor XML, entity ID and logs unsupported ACS bindings.
func (s *SAMLIdPServiceProviderService) validateEntityDescriptor(sp types.SAMLIdPServiceProvider) error {
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

// ensure any filtering related issues get logged
if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
return trace.BadParameter("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}

return nil
}

// GetTeleportSPSSODescriptor returns Teleport embedded SPSSODescriptor and its index from a
// list of SPSSODescriptors. The correct SPSSODescriptor is determined by searching for
// AttributeConsumingService element with ServiceNames named teleport_saml_idp_service.
Expand Down
Loading

0 comments on commit 191a2c2

Please sign in to comment.