From 191a2c2ade345ba79ed0c189f15b8a6044107fe7 Mon Sep 17 00:00:00 2001 From: flyinghermit Date: Sat, 21 Sep 2024 20:50:30 -0400 Subject: [PATCH] move acs_url and relay_state input validation to rpc create and update methods --- lib/auth/auth_with_roles.go | 18 ++ lib/auth/auth_with_roles_test.go | 183 +++++++++++++++ .../local/saml_idp_service_provider.go | 36 +-- .../local/saml_idp_service_provider_test.go | 211 +----------------- lib/services/saml_idp_service_provider.go | 57 ++++- .../saml_idp_service_provider_test.go | 8 +- 6 files changed, 275 insertions(+), 238 deletions(-) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 1913df092e1d3..7867c4307cc53 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -6807,6 +6807,16 @@ func (a *ServerWithRoles) CreateSAMLIdPServiceProvider(ctx context.Context, sp t return trace.Wrap(err) } + if err := services.ValidateSAMLIdPACSURLAndRelayStateInputs(sp); err != nil { + return trace.Wrap(err) + } + + if sp.GetEntityDescriptor() != "" { + if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputStrictFilter); err != nil { + return trace.Wrap(err) + } + } + err = a.authServer.CreateSAMLIdPServiceProvider(ctx, sp) return trace.Wrap(err) } @@ -6846,6 +6856,14 @@ func (a *ServerWithRoles) UpdateSAMLIdPServiceProvider(ctx context.Context, sp t return trace.Wrap(err) } + if err := services.ValidateSAMLIdPACSURLAndRelayStateInputs(sp); err != nil { + return trace.Wrap(err) + } + + if err := services.ValidateAndFilterEntityDescriptor(sp, services.SAMLACSInputStrictFilter); err != nil { + return trace.Wrap(err) + } + err = a.authServer.UpdateSAMLIdPServiceProvider(ctx, sp) return trace.Wrap(err) } diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index f5b27984f5979..2147390841169 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -6000,6 +6000,189 @@ func TestUpdateSAMLIdPServiceProvider(t *testing.T) { } } +func TestCreateSAMLIdPServiceProviderInvalidInputs(t *testing.T) { + ctx := context.Background() + srv := newTestTLSServer(t) + user, _ := createSAMLIdPTestUsers(t, srv.Auth()) + client, err := srv.NewClient(TestUser(user)) + require.NoError(t, err) + + tests := []struct { + name string + entityDescriptor string + entityID string + acsURL string + relayState string + errAssertion require.ErrorAssertionFunc + }{ + { + name: "missing url scheme in acs input", + entityID: "sp", + acsURL: "sp", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "invalid scheme") + }, + }, + { + name: "missing url scheme for acs in ed", + entityDescriptor: services.NewSAMLTestSPMetadata("sp", "sp"), + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "invalid url scheme") + }, + }, + { + name: "http url scheme in acs", + entityID: "sp", + acsURL: "http://sp", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "invalid scheme") + }, + }, + { + name: "http url scheme for acs in ed", + entityDescriptor: services.NewSAMLTestSPMetadata("sp", "http://sp"), + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "unsupported ACS bindings") + }, + }, + { + name: "unsupported scheme in acs", + entityID: "sp", + acsURL: "gopher://sp", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "invalid scheme") + }, + }, + { + name: "unsupported scheme for acs in ed", + entityDescriptor: services.NewSAMLTestSPMetadata("sp", "gopher://sp"), + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "invalid url scheme") + }, + }, + { + name: "invalid character in acs", + entityID: "sp", + acsURL: "https://sp>", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "unsupported character") + }, + }, + { + name: "invalid character in acs in ed", + entityDescriptor: services.NewSAMLTestSPMetadata("sp", "https://sp>"), + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "unsupported ACS bindings") + }, + }, + { + name: "invalid character in relay state", + entityID: "sp", + acsURL: "https://sp", + relayState: "default_state"), + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "unsupported ACS bindings") + }, + }, + { + name: "invalid character in relay state", + entityDescriptor: services.NewSAMLTestSPMetadata("https://sp", "https://sp"), + relayState: "default_state ` -func newSAMLSPMetadata(entityID, acsURL string) string { - return fmt.Sprintf(samlSPMetadata, entityID, acsURL) -} - -// samlSPMetadata mimics metadata generated by saml.ServiceProvider.Metadata() -const samlSPMetadata = ` - - urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified - - - -` - func TestCreateSAMLIdPServiceProvider_fetchOrGenerateEntityDescriptor(t *testing.T) { ctx := context.Background() backend, err := memory.New(memory.Config{ @@ -303,7 +291,7 @@ func TestCreateSAMLIdPServiceProvider_fetchOrGenerateEntityDescriptor(t *testing sp2FromBackend, err := service2.GetSAMLIdPServiceProvider(ctx, sp2.GetName()) require.NoError(t, err) - metadataTemplate := newSAMLSPMetadata(notFoundURL, testSPServer.URL) + metadataTemplate := services.NewSAMLTestSPMetadata(notFoundURL, testSPServer.URL) expected, err := samlsp.ParseMetadata([]byte(metadataTemplate)) require.NoError(t, err) @@ -329,7 +317,7 @@ func TestCreateSAMLIdPServiceProvider_fetchAndSetEntityDescriptor(t *testing.T) fmt.Fprintln(w, "test") default: location := fmt.Sprintf("https://%s", r.Host) - metadata := newSAMLSPMetadata(location, location) + metadata := services.NewSAMLTestSPMetadata(location, location) w.WriteHeader(http.StatusOK) fmt.Fprintln(w, metadata) } @@ -601,196 +589,3 @@ func TestCreateSAMLIdPServiceProvider_GetTeleportSPSSODescriptor(t *testing.T) { index, _ := GetTeleportSPSSODescriptor(ed.SPSSODescriptors) require.Equal(t, 3, index) } - -func TestCreateSAMLIdPServiceProviderInvalidInputs(t *testing.T) { - ctx := context.Background() - - backend, err := memory.New(memory.Config{ - Context: ctx, - Clock: clockwork.NewFakeClock(), - }) - require.NoError(t, err) - - service, err := NewSAMLIdPServiceProviderService(backend) - require.NoError(t, err) - - tests := []struct { - name string - entityDescriptor string - entityID string - acsURL string - relayState string - errAssertion require.ErrorAssertionFunc - }{ - { - name: "missing url scheme in acs input", - entityID: "sp", - acsURL: "sp", - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "invalid scheme") - }, - }, - { - name: "missing url scheme for acs in ed", - entityDescriptor: newSAMLSPMetadata("sp", "sp"), - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "invalid url scheme") - }, - }, - { - name: "http url scheme in acs", - entityID: "sp", - acsURL: "http://sp", - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "invalid scheme") - }, - }, - { - name: "http url scheme for acs in ed", - entityDescriptor: newSAMLSPMetadata("sp", "http://sp"), - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "unsupported ACS bindings") - }, - }, - { - name: "unsupported scheme in acs", - entityID: "sp", - acsURL: "gopher://sp", - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "invalid scheme") - }, - }, - { - name: "unsupported scheme for acs in ed", - entityDescriptor: newSAMLSPMetadata("sp", "gopher://sp"), - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "invalid url scheme") - }, - }, - { - name: "invalid character in acs", - entityID: "sp", - acsURL: "https://sp>", - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "unsupported character") - }, - }, - { - name: "invalid character in acs in ed", - entityDescriptor: newSAMLSPMetadata("sp", "https://sp>"), - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "unsupported ACS bindings") - }, - }, - { - name: "invalid character in relay state", - entityID: "sp", - acsURL: "https://sp", - relayState: "default_state"), - errAssertion: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "unsupported ACS bindings") - }, - }, - { - name: "invalid character in relay state", - entityDescriptor: newSAMLSPMetadata("https://sp", "https://sp"), - relayState: "default_state"!;` +// SAMLACSInputFilteringThreshold defines level of strictness for entity descriptor filtering. +type SAMLACSInputFilteringThreshold string + +const ( + // SAMLACSInputStrictFilter indicates ValidateAndFilterEntityDescriptor to return an error on + // any instance of unsupported ACS value. + SAMLACSInputStrictFilter SAMLACSInputFilteringThreshold = "SAMLACSInputStrictFilter" + // SAMLACSInputPermissiveFilter indicates ValidateAndFilterEntityDescriptor to ignore an error on + // any instance of unsupported ACS value. + SAMLACSInputPermissiveFilter SAMLACSInputFilteringThreshold = "SAMLACSInputPermissiveFilter" +) + +// ValidateAndFilterEntityDescriptor validates entity id and ACS value. It specifically: +// - checks for a valid entity descriptor XML format. +// - checks for a matching entity ID field in both the entity_id field and entity ID contained in the value of +// entity_descriptor field. +// - performs filtering on the Assertion Consumer service (ACS) binding format or its location URL endpoint. +// filterThreshold dictates if ValidateAndFilterEntityDescriptor should return or ignore error on filtering result. +func ValidateAndFilterEntityDescriptor(sp types.SAMLIdPServiceProvider, filterThreshold SAMLACSInputFilteringThreshold) error { + edOriginal, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) + if err != nil { + return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err) + } + + if edOriginal.EntityID != sp.GetEntityID() { + return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") + } + + if err := FilterSAMLEntityDescriptor(edOriginal, false /* quiet */); err != nil { + if filterThreshold == SAMLACSInputStrictFilter { + return trace.BadParameter("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err) + } + } + + return nil +} + // validateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location // is a valid HTTPS endpoint. func validateAssertionConsumerServicesEndpoint(acs string) error { @@ -198,3 +237,17 @@ func ValidateSAMLIdPACSURLAndRelayStateInputs(sp types.SAMLIdPServiceProvider) e return nil } + +// NewSAMLTestSPMetadata creates a new entity descriptor for tests. +func NewSAMLTestSPMetadata(entityID, acsURL string) string { + return fmt.Sprintf(samlTestSPMetadata, entityID, acsURL) +} + +// samlTestSPMetadata mimics metadata format generated by saml.ServiceProvider.Metadata() +const samlTestSPMetadata = ` + + urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified + + + + ` diff --git a/lib/services/saml_idp_service_provider_test.go b/lib/services/saml_idp_service_provider_test.go index 1d5bc457578ae..da194186aeadf 100644 --- a/lib/services/saml_idp_service_provider_test.go +++ b/lib/services/saml_idp_service_provider_test.go @@ -89,7 +89,7 @@ func TestFilterSAMLEntityDescriptor(t *testing.T) { ACS(saml.HTTPPostBinding, "https://example.com/acs"). ACS(saml.HTTPPostBinding, "http://example.com/acs"). Done(), - ok: true, + ok: false, before: 2, after: 1, name: "scheme filtering", @@ -100,7 +100,7 @@ func TestFilterSAMLEntityDescriptor(t *testing.T) { ACS("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST-SimpleSign", "https://example.com/POST-SimpleSign"). ACS(saml.HTTPPostBinding, "https://example.com/acs"). Done(), - ok: true, + ok: false, before: 3, after: 1, name: "binding filtering", @@ -127,9 +127,9 @@ func TestFilterSAMLEntityDescriptor(t *testing.T) { err = FilterSAMLEntityDescriptor(ed, false /* quiet */) if !tt.ok { require.Error(t, err) - return + } else { + require.NoError(t, err) } - require.NoError(t, err) require.Equal(t, tt.after, getACSCount(ed)) })