Skip to content

Commit

Permalink
Use any; Move to api/types/saml.go; add test.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Oct 30, 2024
1 parent ca0e6af commit fc9990f
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 112 deletions.
112 changes: 0 additions & 112 deletions api/types/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ package types

import (
"bytes"
"encoding/json"
"strings"
"time"

"github.com/gogo/protobuf/jsonpb"
Expand Down Expand Up @@ -175,113 +173,3 @@ func (d *MFADevice) UnmarshalJSON(buf []byte) error {
err := unmarshaler.Unmarshal(bytes.NewReader(buf), d)
return trace.Wrap(err)
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s *SAMLForceAuthn) MarshalYAML() (interface{}, error) {
val, err := s.encode()
if err != nil {
return nil, trace.Wrap(err)
}
return val, nil
}

// UnmarshalYAML supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error {
var val interface{}
err := unmarshal(&val)
if err != nil {
return trace.Wrap(err)
}

err = s.decode(val)
return trace.Wrap(err)
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s *SAMLForceAuthn) MarshalJSON() ([]byte, error) {
val, err := s.encode()
if err != nil {
return nil, trace.Wrap(err)
}
out, err := json.Marshal(val)
return out, trace.Wrap(err)
}

// UnmarshalJSON supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error {
var val interface{}
err := json.Unmarshal(data, &val)
if err != nil {
return trace.Wrap(err)
}

err = s.decode(val)
return trace.Wrap(err)
}

const (
// SAMLForceAuthnOTPString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_OTP
SAMLForceAuthnOTPString = "otp"
// SAMLForceAuthnWebauthnString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_WEBAUTHN
SAMLForceAuthnWebauthnString = "webauthn"
// SAMLForceAuthnSSOString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_SSO
SAMLForceAuthnSSOString = "sso"
)

func (s *SAMLForceAuthn) encode() (string, error) {
switch *s {
case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED:
return "", nil
case SAMLForceAuthn_FORCE_AUTHN_NO:
return "no", nil
case SAMLForceAuthn_FORCE_AUTHN_YES:
return "yes", nil
default:
return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s)
}
}

func (s *SAMLForceAuthn) decode(val any) error {
switch v := val.(type) {
case string:
// try parsing as a boolean
switch strings.ToLower(v) {
case "":
*s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED
case "yes", "yeah", "y", "true", "1", "on":
*s = SAMLForceAuthn_FORCE_AUTHN_YES
case "no", "nope", "n", "false", "0", "off":
*s = SAMLForceAuthn_FORCE_AUTHN_NO
default:
return trace.BadParameter("SAMLForceAuthn invalid value %v", val)
}
case bool:
if v {
*s = SAMLForceAuthn_FORCE_AUTHN_YES
} else {
*s = SAMLForceAuthn_FORCE_AUTHN_NO
}
case int32:
return trace.Wrap(s.setFromEnum(v))
case int64:
return trace.Wrap(s.setFromEnum(int32(v)))
case int:
return trace.Wrap(s.setFromEnum(int32(v)))
case float64:
return trace.Wrap(s.setFromEnum(int32(v)))
case float32:
return trace.Wrap(s.setFromEnum(int32(v)))
default:
return trace.BadParameter("SAMLForceAuthn invalid type %T", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (s *SAMLForceAuthn) setFromEnum(val int32) error {
if _, ok := SAMLForceAuthn_name[val]; !ok {
return trace.BadParameter("invalid SAMLForceAuthn enum %v", val)
}
*s = SAMLForceAuthn(val)
return nil
}
105 changes: 105 additions & 0 deletions api/types/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package types

import (
"encoding/json"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -524,3 +525,107 @@ func (r *SAMLAuthRequest) Check() error {
}
return nil
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s *SAMLForceAuthn) MarshalYAML() (any, error) {
val, err := s.Encode()
if err != nil {
return nil, trace.Wrap(err)
}
return val, nil
}

// UnmarshalYAML supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(any) error) error {
var val any
if err := unmarshal(&val); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(s.Decode(val))
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s *SAMLForceAuthn) MarshalJSON() ([]byte, error) {
val, err := s.Encode()
if err != nil {
return nil, trace.Wrap(err)
}
out, err := json.Marshal(val)
return out, trace.Wrap(err)
}

// UnmarshalJSON supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error {
var val any
if err := json.Unmarshal(data, &val); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(s.Decode(val))
}

const (
// SAMLForceAuthnOTPString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_OTP
SAMLForceAuthnOTPString = "otp"
// SAMLForceAuthnWebauthnString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_WEBAUTHN
SAMLForceAuthnWebauthnString = "webauthn"
// SAMLForceAuthnSSOString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_SSO
SAMLForceAuthnSSOString = "sso"
)

func (s *SAMLForceAuthn) Encode() (string, error) {
switch *s {
case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED:
return "", nil
case SAMLForceAuthn_FORCE_AUTHN_NO:
return "no", nil
case SAMLForceAuthn_FORCE_AUTHN_YES:
return "yes", nil
default:
return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s)
}
}

func (s *SAMLForceAuthn) Decode(val any) error {
switch v := val.(type) {
case string:
// try parsing as a boolean
switch strings.ToLower(v) {
case "":
*s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED
case "yes", "yeah", "y", "true", "1", "on":
*s = SAMLForceAuthn_FORCE_AUTHN_YES
case "no", "nope", "n", "false", "0", "off":
*s = SAMLForceAuthn_FORCE_AUTHN_NO
default:
return trace.BadParameter("SAMLForceAuthn invalid value %v", val)
}
case bool:
if v {
*s = SAMLForceAuthn_FORCE_AUTHN_YES
} else {
*s = SAMLForceAuthn_FORCE_AUTHN_NO
}
case int32:
return trace.Wrap(s.setFromEnum(v))
case int64:
return trace.Wrap(s.setFromEnum(int32(v)))
case int:
return trace.Wrap(s.setFromEnum(int32(v)))
case float64:
return trace.Wrap(s.setFromEnum(int32(v)))
case float32:
return trace.Wrap(s.setFromEnum(int32(v)))
default:
return trace.BadParameter("SAMLForceAuthn invalid type %T", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (s *SAMLForceAuthn) setFromEnum(val int32) error {
if _, ok := SAMLForceAuthn_name[val]; !ok {
return trace.BadParameter("invalid SAMLForceAuthn enum %v", val)
}
*s = SAMLForceAuthn(val)
return nil
}
34 changes: 34 additions & 0 deletions api/types/saml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package types_test
import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -107,3 +108,36 @@ func TestSAMLForceAuthn(t *testing.T) {
})
}
}

func TestEncodeDecodeSAMLForceAuthn(t *testing.T) {
for _, tt := range []struct {
forceAuthn types.SAMLForceAuthn
encoded string
}{
{
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED,
encoded: "",
}, {
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_NO,
encoded: "no",
}, {
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_YES,
encoded: "yes",
},
} {
t.Run(tt.forceAuthn.String(), func(t *testing.T) {
t.Run("encode", func(t *testing.T) {
encoded, err := tt.forceAuthn.Encode()
assert.NoError(t, err)
assert.Equal(t, tt.encoded, encoded)
})

t.Run("decode", func(t *testing.T) {
var decoded types.SAMLForceAuthn
err := decoded.Decode(tt.encoded)
assert.NoError(t, err)
assert.Equal(t, tt.forceAuthn, decoded)
})
})
}
}

0 comments on commit fc9990f

Please sign in to comment.