Skip to content

Commit

Permalink
SAML IdP attribute mapping types and config handler (#35584)
Browse files Browse the repository at this point in the history
* attribute mapping proto def

* - Get(Set)AttributeMapping methods
- Check for duplciate attribute names

* check and get urn name format for configuredattribute name format

* fix typo

* urn -> uri

* tratis -> traits

* use trace.BadParameterError struct for better tracing

* - CheckAndSetDefaults() for SAMLAttributeMapping type
- use map instead of slice for finding duplicate names
  • Loading branch information
flyinghermit authored Dec 14, 2023
1 parent dad8c44 commit b4e5e91
Show file tree
Hide file tree
Showing 5 changed files with 2,016 additions and 1,558 deletions.
14 changes: 14 additions & 0 deletions api/proto/teleport/legacy/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5779,6 +5779,20 @@ message SAMLIdPServiceProviderSpecV1 {
string EntityID = 2 [(gogoproto.jsontag) = "entity_id"];
// ACSURL is the endpoint where SAML authentication response will be redirected.
string ACSURL = 3 [(gogoproto.jsontag) = "acs_url"];
// AttributeMapping is used to map Service Provider requested attributes to
// username, role and traits in Teleport.
repeated SAMLAttributeMapping AttributeMapping = 4 [(gogoproto.jsontag) = "attribute_mapping"];
}

// SAMLAttributeMapping represents SAML Service Provider requested attribute
// name, format and its values.
message SAMLAttributeMapping {
// name is an attribute name.
string name = 1 [(gogoproto.jsontag) = "name"];
// name_format is an attribute name format.
string name_format = 2 [(gogoproto.jsontag) = "name_format"];
// value is an attribute value definable with predicate expression.
string value = 3 [(gogoproto.jsontag) = "value"];
}

// IdPOptions specify options related to access Teleport IdPs.
Expand Down
61 changes: 57 additions & 4 deletions api/types/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,20 @@ import (
"github.com/gravitational/teleport/api/utils"
)

const (
unspecifiedNameFormat = "urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified"
uriNameFormat = "urn:oasis:names:tc:SAML:2.0:attrname-format:uri"
basicNameFormat = "urn:oasis:names:tc:SAML:2.0:attrname-format:basic"
)

var (
// ErrMissingEntityDescriptorAndEntityID is returned when both entity descriptor and entity ID is empty.
ErrEmptyEntityDescriptorAndEntityID = trace.BadParameter("either entity_descriptor or entity_id must be provided")
ErrEmptyEntityDescriptorAndEntityID = &trace.BadParameterError{Message: "either entity_descriptor or entity_id must be provided"}
// ErrMissingEntityDescriptorAndACSURL is returned when both entity descriptor and ACS URL is empty.
ErrEmptyEntityDescriptorAndACSURL = trace.BadParameter("either entity_descriptor or acs_url must be provided")
ErrEmptyEntityDescriptorAndACSURL = &trace.BadParameterError{Message: "either entity_descriptor or acs_url must be provided"}
// ErrDuplicateAttributeName is returned when attribute mapping declares two or more
// attributes with the same name.
ErrDuplicateAttributeName = &trace.BadParameterError{Message: "duplicate attribute name not allowed"}
)

// SAMLIdPServiceProvider specifies configuration for service providers for Teleport's built in SAML IdP.
Expand All @@ -51,6 +60,10 @@ type SAMLIdPServiceProvider interface {
GetACSURL() string
// SetACSURL sets the ACS URL.
SetACSURL(string)
// GetAttributeMapping returns Attribute Mapping.
GetAttributeMapping() []*SAMLAttributeMapping
// SetAttributeMapping sets Attribute Mapping.
SetAttributeMapping([]*SAMLAttributeMapping)
// Copy returns a copy of this saml idp service provider object.
Copy() SAMLIdPServiceProvider
// CloneResource returns a copy of the SAMLIdPServiceProvider as a ResourceWithLabels
Expand Down Expand Up @@ -103,6 +116,16 @@ func (s *SAMLIdPServiceProviderV1) SetACSURL(acsURL string) {
s.Spec.ACSURL = acsURL
}

// GetAttributeMapping returns the Attribute Mapping.
func (s *SAMLIdPServiceProviderV1) GetAttributeMapping() []*SAMLAttributeMapping {
return s.Spec.AttributeMapping
}

// SetAttributeMapping sets Attribute Mapping.
func (s *SAMLIdPServiceProviderV1) SetAttributeMapping(attrMaps []*SAMLAttributeMapping) {
s.Spec.AttributeMapping = attrMaps
}

// String returns the SAML IdP service provider string representation.
func (s *SAMLIdPServiceProviderV1) String() string {
return fmt.Sprintf("SAMLIdPServiceProviderV1(Name=%v)",
Expand Down Expand Up @@ -139,11 +162,11 @@ func (s *SAMLIdPServiceProviderV1) CheckAndSetDefaults() error {

if s.Spec.EntityDescriptor == "" {
if s.Spec.EntityID == "" {
return ErrEmptyEntityDescriptorAndEntityID
return trace.Wrap(ErrEmptyEntityDescriptorAndEntityID)
}

if s.Spec.ACSURL == "" {
return ErrEmptyEntityDescriptorAndACSURL
return trace.Wrap(ErrEmptyEntityDescriptorAndACSURL)
}

}
Expand All @@ -161,6 +184,18 @@ func (s *SAMLIdPServiceProviderV1) CheckAndSetDefaults() error {
s.Spec.EntityID = ed.EntityID
}

attrNames := make(map[string]struct{})
for _, attributeMap := range s.GetAttributeMapping() {
if err := attributeMap.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
// check for duplicate attribute names
if _, ok := attrNames[attributeMap.Name]; ok {
return trace.Wrap(ErrDuplicateAttributeName)
}
attrNames[attributeMap.Name] = struct{}{}
}

return nil
}

Expand All @@ -184,3 +219,21 @@ func (s SAMLIdPServiceProviders) Less(i, j int) bool { return s[i].GetName() < s

// Swap swaps two service providers.
func (s SAMLIdPServiceProviders) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

// CheckAndSetDefaults check and sets SAMLAttributeMapping default values
func (am *SAMLAttributeMapping) CheckAndSetDefaults() error {
// verify name format is one of the supported
// formats - unspecifiedNameFormat, basicNameFormat or uriNameFormat
// and assign it with the URN value of that format.
switch am.NameFormat {
case "", "unspecified", unspecifiedNameFormat:
am.NameFormat = unspecifiedNameFormat
case "basic", basicNameFormat:
am.NameFormat = basicNameFormat
case "uri", uriNameFormat:
am.NameFormat = uriNameFormat
default:
return trace.BadParameter("invalid name format: %s", am.NameFormat)
}
return nil
}
63 changes: 63 additions & 0 deletions api/types/saml_idp_service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func TestNewSAMLIdPServiceProvider(t *testing.T) {
acsURL string
errAssertion require.ErrorAssertionFunc
expectedEntityID string
attributeMapping []*SAMLAttributeMapping
}{
{
name: "valid entity descriptor",
Expand Down Expand Up @@ -82,6 +83,64 @@ func TestNewSAMLIdPServiceProvider(t *testing.T) {
errAssertion: require.NoError,
expectedEntityID: "IAMShowcase",
},
{
name: "duplicate attribute mapping",
entityDescriptor: testEntityDescriptor,
attributeMapping: []*SAMLAttributeMapping{
{
Name: "username",
Value: "user.traits.name",
},
{
Name: "user1",
Value: "user.traits.firstname",
},
{
Name: "username",
Value: "user.traits.givenname",
},
},
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorIs(t, err, ErrDuplicateAttributeName)
},
},
{
name: "valid attribute mapping",
entityDescriptor: testEntityDescriptor,
entityID: "IAMShowcase",
expectedEntityID: "IAMShowcase",
attributeMapping: []*SAMLAttributeMapping{
{
Name: "username",
Value: "user.traits.name",
},
{
Name: "user1",
Value: "user.traits.givenname",
},
},
errAssertion: require.NoError,
},
{
name: "invalid attribute mapping name format",
entityDescriptor: testEntityDescriptor,
entityID: "IAMShowcase",
expectedEntityID: "IAMShowcase",
attributeMapping: []*SAMLAttributeMapping{
{
Name: "username",
Value: "user.traits.name",
NameFormat: "emailAddress",
},
{
Name: "user1",
Value: "user.traits.givenname",
},
},
errAssertion: func(t require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "invalid name format")
},
},
}

for _, test := range tests {
Expand All @@ -92,11 +151,15 @@ func TestNewSAMLIdPServiceProvider(t *testing.T) {
EntityDescriptor: test.entityDescriptor,
EntityID: test.entityID,
ACSURL: test.acsURL,
AttributeMapping: test.attributeMapping,
})

test.errAssertion(t, err)
if sp != nil {
require.Equal(t, test.expectedEntityID, sp.GetEntityID())
if len(sp.GetAttributeMapping()) > 0 {
require.Equal(t, test.attributeMapping, sp.GetAttributeMapping())
}
}
})
}
Expand Down
Loading

0 comments on commit b4e5e91

Please sign in to comment.