diff --git a/api/types/mfa.go b/api/types/mfa.go index 6fb00f042ff69..bd0f42aa21729 100644 --- a/api/types/mfa.go +++ b/api/types/mfa.go @@ -16,8 +16,6 @@ package types import ( "bytes" - "encoding/json" - "strings" "time" "github.com/gogo/protobuf/jsonpb" @@ -175,113 +173,3 @@ func (d *MFADevice) UnmarshalJSON(buf []byte) error { err := unmarshaler.Unmarshal(bytes.NewReader(buf), d) return trace.Wrap(err) } - -// MarshalJSON marshals SAMLForceAuthn to string. -func (s *SAMLForceAuthn) MarshalYAML() (interface{}, error) { - val, err := s.encode() - if err != nil { - return nil, trace.Wrap(err) - } - return val, nil -} - -// UnmarshalYAML supports parsing SAMLForceAuthn from string. -func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error { - var val interface{} - err := unmarshal(&val) - if err != nil { - return trace.Wrap(err) - } - - err = s.decode(val) - return trace.Wrap(err) -} - -// MarshalJSON marshals SAMLForceAuthn to string. -func (s *SAMLForceAuthn) MarshalJSON() ([]byte, error) { - val, err := s.encode() - if err != nil { - return nil, trace.Wrap(err) - } - out, err := json.Marshal(val) - return out, trace.Wrap(err) -} - -// UnmarshalJSON supports parsing SAMLForceAuthn from string. -func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error { - var val interface{} - err := json.Unmarshal(data, &val) - if err != nil { - return trace.Wrap(err) - } - - err = s.decode(val) - return trace.Wrap(err) -} - -const ( - // SAMLForceAuthnOTPString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_OTP - SAMLForceAuthnOTPString = "otp" - // SAMLForceAuthnWebauthnString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_WEBAUTHN - SAMLForceAuthnWebauthnString = "webauthn" - // SAMLForceAuthnSSOString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_SSO - SAMLForceAuthnSSOString = "sso" -) - -func (s *SAMLForceAuthn) encode() (string, error) { - switch *s { - case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED: - return "", nil - case SAMLForceAuthn_FORCE_AUTHN_NO: - return "no", nil - case SAMLForceAuthn_FORCE_AUTHN_YES: - return "yes", nil - default: - return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s) - } -} - -func (s *SAMLForceAuthn) decode(val any) error { - switch v := val.(type) { - case string: - // try parsing as a boolean - switch strings.ToLower(v) { - case "": - *s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED - case "yes", "yeah", "y", "true", "1", "on": - *s = SAMLForceAuthn_FORCE_AUTHN_YES - case "no", "nope", "n", "false", "0", "off": - *s = SAMLForceAuthn_FORCE_AUTHN_NO - default: - return trace.BadParameter("SAMLForceAuthn invalid value %v", val) - } - case bool: - if v { - *s = SAMLForceAuthn_FORCE_AUTHN_YES - } else { - *s = SAMLForceAuthn_FORCE_AUTHN_NO - } - case int32: - return trace.Wrap(s.setFromEnum(v)) - case int64: - return trace.Wrap(s.setFromEnum(int32(v))) - case int: - return trace.Wrap(s.setFromEnum(int32(v))) - case float64: - return trace.Wrap(s.setFromEnum(int32(v))) - case float32: - return trace.Wrap(s.setFromEnum(int32(v))) - default: - return trace.BadParameter("SAMLForceAuthn invalid type %T", val) - } - return nil -} - -// setFromEnum sets the value from enum value as int32. -func (s *SAMLForceAuthn) setFromEnum(val int32) error { - if _, ok := SAMLForceAuthn_name[val]; !ok { - return trace.BadParameter("invalid SAMLForceAuthn enum %v", val) - } - *s = SAMLForceAuthn(val) - return nil -} diff --git a/api/types/saml.go b/api/types/saml.go index 0295aeb3e66c1..e67ce51104817 100644 --- a/api/types/saml.go +++ b/api/types/saml.go @@ -17,6 +17,7 @@ limitations under the License. package types import ( + "encoding/json" "slices" "strings" "time" @@ -524,3 +525,107 @@ func (r *SAMLAuthRequest) Check() error { } return nil } + +// MarshalJSON marshals SAMLForceAuthn to string. +func (s *SAMLForceAuthn) MarshalYAML() (any, error) { + val, err := s.Encode() + if err != nil { + return nil, trace.Wrap(err) + } + return val, nil +} + +// UnmarshalYAML supports parsing SAMLForceAuthn from string. +func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(any) error) error { + var val any + if err := unmarshal(&val); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(s.Decode(val)) +} + +// MarshalJSON marshals SAMLForceAuthn to string. +func (s *SAMLForceAuthn) MarshalJSON() ([]byte, error) { + val, err := s.Encode() + if err != nil { + return nil, trace.Wrap(err) + } + out, err := json.Marshal(val) + return out, trace.Wrap(err) +} + +// UnmarshalJSON supports parsing SAMLForceAuthn from string. +func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error { + var val any + if err := json.Unmarshal(data, &val); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(s.Decode(val)) +} + +const ( + // SAMLForceAuthnOTPString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_OTP + SAMLForceAuthnOTPString = "otp" + // SAMLForceAuthnWebauthnString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_WEBAUTHN + SAMLForceAuthnWebauthnString = "webauthn" + // SAMLForceAuthnSSOString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_SSO + SAMLForceAuthnSSOString = "sso" +) + +func (s *SAMLForceAuthn) Encode() (string, error) { + switch *s { + case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED: + return "", nil + case SAMLForceAuthn_FORCE_AUTHN_NO: + return "no", nil + case SAMLForceAuthn_FORCE_AUTHN_YES: + return "yes", nil + default: + return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s) + } +} + +func (s *SAMLForceAuthn) Decode(val any) error { + switch v := val.(type) { + case string: + // try parsing as a boolean + switch strings.ToLower(v) { + case "": + *s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED + case "yes", "yeah", "y", "true", "1", "on": + *s = SAMLForceAuthn_FORCE_AUTHN_YES + case "no", "nope", "n", "false", "0", "off": + *s = SAMLForceAuthn_FORCE_AUTHN_NO + default: + return trace.BadParameter("SAMLForceAuthn invalid value %v", val) + } + case bool: + if v { + *s = SAMLForceAuthn_FORCE_AUTHN_YES + } else { + *s = SAMLForceAuthn_FORCE_AUTHN_NO + } + case int32: + return trace.Wrap(s.setFromEnum(v)) + case int64: + return trace.Wrap(s.setFromEnum(int32(v))) + case int: + return trace.Wrap(s.setFromEnum(int32(v))) + case float64: + return trace.Wrap(s.setFromEnum(int32(v))) + case float32: + return trace.Wrap(s.setFromEnum(int32(v))) + default: + return trace.BadParameter("SAMLForceAuthn invalid type %T", val) + } + return nil +} + +// setFromEnum sets the value from enum value as int32. +func (s *SAMLForceAuthn) setFromEnum(val int32) error { + if _, ok := SAMLForceAuthn_name[val]; !ok { + return trace.BadParameter("invalid SAMLForceAuthn enum %v", val) + } + *s = SAMLForceAuthn(val) + return nil +} diff --git a/api/types/saml_test.go b/api/types/saml_test.go index 228b0afd9f35b..1ac3702c5104a 100644 --- a/api/types/saml_test.go +++ b/api/types/saml_test.go @@ -19,6 +19,7 @@ package types_test import ( "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -107,3 +108,36 @@ func TestSAMLForceAuthn(t *testing.T) { }) } } + +func TestEncodeDecodeSAMLForceAuthn(t *testing.T) { + for _, tt := range []struct { + forceAuthn types.SAMLForceAuthn + encoded string + }{ + { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED, + encoded: "", + }, { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_NO, + encoded: "no", + }, { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_YES, + encoded: "yes", + }, + } { + t.Run(tt.forceAuthn.String(), func(t *testing.T) { + t.Run("encode", func(t *testing.T) { + encoded, err := tt.forceAuthn.Encode() + assert.NoError(t, err) + assert.Equal(t, tt.encoded, encoded) + }) + + t.Run("decode", func(t *testing.T) { + var decoded types.SAMLForceAuthn + err := decoded.Decode(tt.encoded) + assert.NoError(t, err) + assert.Equal(t, tt.forceAuthn, decoded) + }) + }) + } +}