From ca0e6af383ee21f533251dba8c88c105c6d9f169 Mon Sep 17 00:00:00 2001 From: joerger Date: Tue, 29 Oct 2024 12:18:04 -0700 Subject: [PATCH 1/2] Marshal SAMLForceAuthn to/from string/bool. --- api/types/mfa.go | 112 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/api/types/mfa.go b/api/types/mfa.go index bd0f42aa21729..6fb00f042ff69 100644 --- a/api/types/mfa.go +++ b/api/types/mfa.go @@ -16,6 +16,8 @@ package types import ( "bytes" + "encoding/json" + "strings" "time" "github.com/gogo/protobuf/jsonpb" @@ -173,3 +175,113 @@ func (d *MFADevice) UnmarshalJSON(buf []byte) error { err := unmarshaler.Unmarshal(bytes.NewReader(buf), d) return trace.Wrap(err) } + +// MarshalJSON marshals SAMLForceAuthn to string. +func (s *SAMLForceAuthn) MarshalYAML() (interface{}, error) { + val, err := s.encode() + if err != nil { + return nil, trace.Wrap(err) + } + return val, nil +} + +// UnmarshalYAML supports parsing SAMLForceAuthn from string. +func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error { + var val interface{} + err := unmarshal(&val) + if err != nil { + return trace.Wrap(err) + } + + err = s.decode(val) + return trace.Wrap(err) +} + +// MarshalJSON marshals SAMLForceAuthn to string. +func (s *SAMLForceAuthn) MarshalJSON() ([]byte, error) { + val, err := s.encode() + if err != nil { + return nil, trace.Wrap(err) + } + out, err := json.Marshal(val) + return out, trace.Wrap(err) +} + +// UnmarshalJSON supports parsing SAMLForceAuthn from string. +func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error { + var val interface{} + err := json.Unmarshal(data, &val) + if err != nil { + return trace.Wrap(err) + } + + err = s.decode(val) + return trace.Wrap(err) +} + +const ( + // SAMLForceAuthnOTPString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_OTP + SAMLForceAuthnOTPString = "otp" + // SAMLForceAuthnWebauthnString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_WEBAUTHN + SAMLForceAuthnWebauthnString = "webauthn" + // SAMLForceAuthnSSOString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_SSO + SAMLForceAuthnSSOString = "sso" +) + +func (s *SAMLForceAuthn) encode() (string, error) { + switch *s { + case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED: + return "", nil + case SAMLForceAuthn_FORCE_AUTHN_NO: + return "no", nil + case SAMLForceAuthn_FORCE_AUTHN_YES: + return "yes", nil + default: + return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s) + } +} + +func (s *SAMLForceAuthn) decode(val any) error { + switch v := val.(type) { + case string: + // try parsing as a boolean + switch strings.ToLower(v) { + case "": + *s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED + case "yes", "yeah", "y", "true", "1", "on": + *s = SAMLForceAuthn_FORCE_AUTHN_YES + case "no", "nope", "n", "false", "0", "off": + *s = SAMLForceAuthn_FORCE_AUTHN_NO + default: + return trace.BadParameter("SAMLForceAuthn invalid value %v", val) + } + case bool: + if v { + *s = SAMLForceAuthn_FORCE_AUTHN_YES + } else { + *s = SAMLForceAuthn_FORCE_AUTHN_NO + } + case int32: + return trace.Wrap(s.setFromEnum(v)) + case int64: + return trace.Wrap(s.setFromEnum(int32(v))) + case int: + return trace.Wrap(s.setFromEnum(int32(v))) + case float64: + return trace.Wrap(s.setFromEnum(int32(v))) + case float32: + return trace.Wrap(s.setFromEnum(int32(v))) + default: + return trace.BadParameter("SAMLForceAuthn invalid type %T", val) + } + return nil +} + +// setFromEnum sets the value from enum value as int32. +func (s *SAMLForceAuthn) setFromEnum(val int32) error { + if _, ok := SAMLForceAuthn_name[val]; !ok { + return trace.BadParameter("invalid SAMLForceAuthn enum %v", val) + } + *s = SAMLForceAuthn(val) + return nil +} From e3b49d23f466a6e06f438eab302c56669e4958da Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 30 Oct 2024 10:20:12 -0700 Subject: [PATCH 2/2] Use any; Move to api/types/saml.go; add test. --- api/types/mfa.go | 112 ----------------------------------------- api/types/saml.go | 96 +++++++++++++++++++++++++++++++++++ api/types/saml_test.go | 63 +++++++++++++++++++++++ 3 files changed, 159 insertions(+), 112 deletions(-) diff --git a/api/types/mfa.go b/api/types/mfa.go index 6fb00f042ff69..bd0f42aa21729 100644 --- a/api/types/mfa.go +++ b/api/types/mfa.go @@ -16,8 +16,6 @@ package types import ( "bytes" - "encoding/json" - "strings" "time" "github.com/gogo/protobuf/jsonpb" @@ -175,113 +173,3 @@ func (d *MFADevice) UnmarshalJSON(buf []byte) error { err := unmarshaler.Unmarshal(bytes.NewReader(buf), d) return trace.Wrap(err) } - -// MarshalJSON marshals SAMLForceAuthn to string. -func (s *SAMLForceAuthn) MarshalYAML() (interface{}, error) { - val, err := s.encode() - if err != nil { - return nil, trace.Wrap(err) - } - return val, nil -} - -// UnmarshalYAML supports parsing SAMLForceAuthn from string. -func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error { - var val interface{} - err := unmarshal(&val) - if err != nil { - return trace.Wrap(err) - } - - err = s.decode(val) - return trace.Wrap(err) -} - -// MarshalJSON marshals SAMLForceAuthn to string. -func (s *SAMLForceAuthn) MarshalJSON() ([]byte, error) { - val, err := s.encode() - if err != nil { - return nil, trace.Wrap(err) - } - out, err := json.Marshal(val) - return out, trace.Wrap(err) -} - -// UnmarshalJSON supports parsing SAMLForceAuthn from string. -func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error { - var val interface{} - err := json.Unmarshal(data, &val) - if err != nil { - return trace.Wrap(err) - } - - err = s.decode(val) - return trace.Wrap(err) -} - -const ( - // SAMLForceAuthnOTPString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_OTP - SAMLForceAuthnOTPString = "otp" - // SAMLForceAuthnWebauthnString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_WEBAUTHN - SAMLForceAuthnWebauthnString = "webauthn" - // SAMLForceAuthnSSOString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_SSO - SAMLForceAuthnSSOString = "sso" -) - -func (s *SAMLForceAuthn) encode() (string, error) { - switch *s { - case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED: - return "", nil - case SAMLForceAuthn_FORCE_AUTHN_NO: - return "no", nil - case SAMLForceAuthn_FORCE_AUTHN_YES: - return "yes", nil - default: - return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s) - } -} - -func (s *SAMLForceAuthn) decode(val any) error { - switch v := val.(type) { - case string: - // try parsing as a boolean - switch strings.ToLower(v) { - case "": - *s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED - case "yes", "yeah", "y", "true", "1", "on": - *s = SAMLForceAuthn_FORCE_AUTHN_YES - case "no", "nope", "n", "false", "0", "off": - *s = SAMLForceAuthn_FORCE_AUTHN_NO - default: - return trace.BadParameter("SAMLForceAuthn invalid value %v", val) - } - case bool: - if v { - *s = SAMLForceAuthn_FORCE_AUTHN_YES - } else { - *s = SAMLForceAuthn_FORCE_AUTHN_NO - } - case int32: - return trace.Wrap(s.setFromEnum(v)) - case int64: - return trace.Wrap(s.setFromEnum(int32(v))) - case int: - return trace.Wrap(s.setFromEnum(int32(v))) - case float64: - return trace.Wrap(s.setFromEnum(int32(v))) - case float32: - return trace.Wrap(s.setFromEnum(int32(v))) - default: - return trace.BadParameter("SAMLForceAuthn invalid type %T", val) - } - return nil -} - -// setFromEnum sets the value from enum value as int32. -func (s *SAMLForceAuthn) setFromEnum(val int32) error { - if _, ok := SAMLForceAuthn_name[val]; !ok { - return trace.BadParameter("invalid SAMLForceAuthn enum %v", val) - } - *s = SAMLForceAuthn(val) - return nil -} diff --git a/api/types/saml.go b/api/types/saml.go index 0295aeb3e66c1..5f042f2f4047c 100644 --- a/api/types/saml.go +++ b/api/types/saml.go @@ -17,6 +17,7 @@ limitations under the License. package types import ( + "encoding/json" "slices" "strings" "time" @@ -524,3 +525,98 @@ func (r *SAMLAuthRequest) Check() error { } return nil } + +// MarshalJSON marshals SAMLForceAuthn to string. +func (s SAMLForceAuthn) MarshalYAML() (interface{}, error) { + val, err := s.encode() + if err != nil { + return nil, trace.Wrap(err) + } + return val, nil +} + +// UnmarshalYAML supports parsing SAMLForceAuthn from string. +func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error { + var val any + if err := unmarshal(&val); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(s.decode(val)) +} + +// MarshalJSON marshals SAMLForceAuthn to string. +func (s SAMLForceAuthn) MarshalJSON() ([]byte, error) { + val, err := s.encode() + if err != nil { + return nil, trace.Wrap(err) + } + out, err := json.Marshal(val) + return out, trace.Wrap(err) +} + +// UnmarshalJSON supports parsing SAMLForceAuthn from string. +func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error { + var val any + if err := json.Unmarshal(data, &val); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(s.decode(val)) +} + +func (s *SAMLForceAuthn) encode() (string, error) { + switch *s { + case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED: + return "", nil + case SAMLForceAuthn_FORCE_AUTHN_NO: + return "no", nil + case SAMLForceAuthn_FORCE_AUTHN_YES: + return "yes", nil + default: + return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s) + } +} + +func (s *SAMLForceAuthn) decode(val any) error { + switch v := val.(type) { + case string: + // try parsing as a boolean + switch strings.ToLower(v) { + case "": + *s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED + case "yes", "yeah", "y", "true", "1", "on": + *s = SAMLForceAuthn_FORCE_AUTHN_YES + case "no", "nope", "n", "false", "0", "off": + *s = SAMLForceAuthn_FORCE_AUTHN_NO + default: + return trace.BadParameter("SAMLForceAuthn invalid value %v", val) + } + case bool: + if v { + *s = SAMLForceAuthn_FORCE_AUTHN_YES + } else { + *s = SAMLForceAuthn_FORCE_AUTHN_NO + } + case int32: + return trace.Wrap(s.setFromEnum(v)) + case int64: + return trace.Wrap(s.setFromEnum(int32(v))) + case int: + return trace.Wrap(s.setFromEnum(int32(v))) + case float64: + return trace.Wrap(s.setFromEnum(int32(v))) + case float32: + return trace.Wrap(s.setFromEnum(int32(v))) + default: + return trace.BadParameter("SAMLForceAuthn invalid type %T", val) + } + return nil +} + +// setFromEnum sets the value from enum value as int32. +func (s *SAMLForceAuthn) setFromEnum(val int32) error { + if _, ok := SAMLForceAuthn_name[val]; !ok { + return trace.BadParameter("invalid SAMLForceAuthn enum %v", val) + } + *s = SAMLForceAuthn(val) + return nil +} diff --git a/api/types/saml_test.go b/api/types/saml_test.go index 228b0afd9f35b..933df64e150f6 100644 --- a/api/types/saml_test.go +++ b/api/types/saml_test.go @@ -17,9 +17,13 @@ limitations under the License. package types_test import ( + "encoding/json" + "fmt" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" "github.com/gravitational/teleport/api/types" ) @@ -107,3 +111,62 @@ func TestSAMLForceAuthn(t *testing.T) { }) } } + +func TestSAMLForceAuthn_Encoding(t *testing.T) { + for _, tt := range []struct { + forceAuthn types.SAMLForceAuthn + expectEncoded string + }{ + { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED, + expectEncoded: "", + }, { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_YES, + expectEncoded: "yes", + }, { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_NO, + expectEncoded: "no", + }, + } { + t.Run(tt.forceAuthn.String(), func(t *testing.T) { + type object struct { + ForceAuthn types.SAMLForceAuthn `json:"force_authn" yaml:"force_authn"` + } + o := object{ + ForceAuthn: tt.forceAuthn, + } + objectJSON := fmt.Sprintf(`{"force_authn":%q}`, tt.expectEncoded) + objectYAML := fmt.Sprintf("force_authn: %q\n", tt.expectEncoded) + + t.Run("JSON", func(t *testing.T) { + t.Run("Marshal", func(t *testing.T) { + gotJSON, err := json.Marshal(o) + assert.NoError(t, err, "unexpected error from json.Marshal") + assert.Equal(t, objectJSON, string(gotJSON), "unexpected json.Marshal value") + }) + + t.Run("Unmarshal", func(t *testing.T) { + var gotObject object + err := json.Unmarshal([]byte(objectJSON), &gotObject) + assert.NoError(t, err, "unexpected error from json.Unmarshal") + assert.Equal(t, tt.forceAuthn, gotObject.ForceAuthn, "unexpected json.Unmarshal value") + }) + }) + + t.Run("YAML", func(t *testing.T) { + t.Run("Marshal", func(t *testing.T) { + gotYAML, err := yaml.Marshal(o) + assert.NoError(t, err, "unexpected error from yaml.Marshal") + assert.Equal(t, objectYAML, string(gotYAML), "unexpected yaml.Marshal value") + }) + + t.Run("Unmarshal", func(t *testing.T) { + var gotObject object + err := yaml.Unmarshal([]byte(objectYAML), &gotObject) + assert.NoError(t, err, "unexpected error from yaml.Unmarshal") + assert.Equal(t, tt.forceAuthn, gotObject.ForceAuthn, "unexpected yaml.Unmarshal value") + }) + }) + }) + } +}