Skip to content

Commit

Permalink
Use any; Move to api/types/saml.go; add test.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Oct 30, 2024
1 parent ca0e6af commit e3b49d2
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 112 deletions.
112 changes: 0 additions & 112 deletions api/types/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ package types

import (
"bytes"
"encoding/json"
"strings"
"time"

"github.com/gogo/protobuf/jsonpb"
Expand Down Expand Up @@ -175,113 +173,3 @@ func (d *MFADevice) UnmarshalJSON(buf []byte) error {
err := unmarshaler.Unmarshal(bytes.NewReader(buf), d)
return trace.Wrap(err)
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s *SAMLForceAuthn) MarshalYAML() (interface{}, error) {
val, err := s.encode()
if err != nil {
return nil, trace.Wrap(err)
}
return val, nil
}

// UnmarshalYAML supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error {
var val interface{}
err := unmarshal(&val)
if err != nil {
return trace.Wrap(err)
}

err = s.decode(val)
return trace.Wrap(err)
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s *SAMLForceAuthn) MarshalJSON() ([]byte, error) {
val, err := s.encode()
if err != nil {
return nil, trace.Wrap(err)
}
out, err := json.Marshal(val)
return out, trace.Wrap(err)
}

// UnmarshalJSON supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error {
var val interface{}
err := json.Unmarshal(data, &val)
if err != nil {
return trace.Wrap(err)
}

err = s.decode(val)
return trace.Wrap(err)
}

const (
// SAMLForceAuthnOTPString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_OTP
SAMLForceAuthnOTPString = "otp"
// SAMLForceAuthnWebauthnString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_WEBAUTHN
SAMLForceAuthnWebauthnString = "webauthn"
// SAMLForceAuthnSSOString is the string representation of SAMLForceAuthn_SECOND_FACTOR_TYPE_SSO
SAMLForceAuthnSSOString = "sso"
)

func (s *SAMLForceAuthn) encode() (string, error) {
switch *s {
case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED:
return "", nil
case SAMLForceAuthn_FORCE_AUTHN_NO:
return "no", nil
case SAMLForceAuthn_FORCE_AUTHN_YES:
return "yes", nil
default:
return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s)
}
}

func (s *SAMLForceAuthn) decode(val any) error {
switch v := val.(type) {
case string:
// try parsing as a boolean
switch strings.ToLower(v) {
case "":
*s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED
case "yes", "yeah", "y", "true", "1", "on":
*s = SAMLForceAuthn_FORCE_AUTHN_YES
case "no", "nope", "n", "false", "0", "off":
*s = SAMLForceAuthn_FORCE_AUTHN_NO
default:
return trace.BadParameter("SAMLForceAuthn invalid value %v", val)
}
case bool:
if v {
*s = SAMLForceAuthn_FORCE_AUTHN_YES
} else {
*s = SAMLForceAuthn_FORCE_AUTHN_NO
}
case int32:
return trace.Wrap(s.setFromEnum(v))
case int64:
return trace.Wrap(s.setFromEnum(int32(v)))
case int:
return trace.Wrap(s.setFromEnum(int32(v)))
case float64:
return trace.Wrap(s.setFromEnum(int32(v)))
case float32:
return trace.Wrap(s.setFromEnum(int32(v)))
default:
return trace.BadParameter("SAMLForceAuthn invalid type %T", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (s *SAMLForceAuthn) setFromEnum(val int32) error {
if _, ok := SAMLForceAuthn_name[val]; !ok {
return trace.BadParameter("invalid SAMLForceAuthn enum %v", val)
}
*s = SAMLForceAuthn(val)
return nil
}
96 changes: 96 additions & 0 deletions api/types/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package types

import (
"encoding/json"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -524,3 +525,98 @@ func (r *SAMLAuthRequest) Check() error {
}
return nil
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s SAMLForceAuthn) MarshalYAML() (interface{}, error) {
val, err := s.encode()
if err != nil {
return nil, trace.Wrap(err)
}
return val, nil
}

// UnmarshalYAML supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error {
var val any
if err := unmarshal(&val); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(s.decode(val))
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s SAMLForceAuthn) MarshalJSON() ([]byte, error) {
val, err := s.encode()
if err != nil {
return nil, trace.Wrap(err)
}
out, err := json.Marshal(val)
return out, trace.Wrap(err)
}

// UnmarshalJSON supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error {
var val any
if err := json.Unmarshal(data, &val); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(s.decode(val))
}

func (s *SAMLForceAuthn) encode() (string, error) {
switch *s {
case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED:
return "", nil
case SAMLForceAuthn_FORCE_AUTHN_NO:
return "no", nil
case SAMLForceAuthn_FORCE_AUTHN_YES:
return "yes", nil
default:
return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s)
}
}

func (s *SAMLForceAuthn) decode(val any) error {
switch v := val.(type) {
case string:
// try parsing as a boolean
switch strings.ToLower(v) {
case "":
*s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED
case "yes", "yeah", "y", "true", "1", "on":
*s = SAMLForceAuthn_FORCE_AUTHN_YES
case "no", "nope", "n", "false", "0", "off":
*s = SAMLForceAuthn_FORCE_AUTHN_NO
default:
return trace.BadParameter("SAMLForceAuthn invalid value %v", val)
}
case bool:
if v {
*s = SAMLForceAuthn_FORCE_AUTHN_YES
} else {
*s = SAMLForceAuthn_FORCE_AUTHN_NO
}
case int32:
return trace.Wrap(s.setFromEnum(v))
case int64:
return trace.Wrap(s.setFromEnum(int32(v)))
case int:
return trace.Wrap(s.setFromEnum(int32(v)))
case float64:
return trace.Wrap(s.setFromEnum(int32(v)))
case float32:
return trace.Wrap(s.setFromEnum(int32(v)))
default:
return trace.BadParameter("SAMLForceAuthn invalid type %T", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (s *SAMLForceAuthn) setFromEnum(val int32) error {
if _, ok := SAMLForceAuthn_name[val]; !ok {
return trace.BadParameter("invalid SAMLForceAuthn enum %v", val)
}
*s = SAMLForceAuthn(val)
return nil
}
63 changes: 63 additions & 0 deletions api/types/saml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ limitations under the License.
package types_test

import (
"encoding/json"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"

"github.com/gravitational/teleport/api/types"
)
Expand Down Expand Up @@ -107,3 +111,62 @@ func TestSAMLForceAuthn(t *testing.T) {
})
}
}

func TestSAMLForceAuthn_Encoding(t *testing.T) {
for _, tt := range []struct {
forceAuthn types.SAMLForceAuthn
expectEncoded string
}{
{
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED,
expectEncoded: "",
}, {
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_YES,
expectEncoded: "yes",
}, {
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_NO,
expectEncoded: "no",
},
} {
t.Run(tt.forceAuthn.String(), func(t *testing.T) {
type object struct {
ForceAuthn types.SAMLForceAuthn `json:"force_authn" yaml:"force_authn"`
}
o := object{
ForceAuthn: tt.forceAuthn,
}
objectJSON := fmt.Sprintf(`{"force_authn":%q}`, tt.expectEncoded)
objectYAML := fmt.Sprintf("force_authn: %q\n", tt.expectEncoded)

t.Run("JSON", func(t *testing.T) {
t.Run("Marshal", func(t *testing.T) {
gotJSON, err := json.Marshal(o)
assert.NoError(t, err, "unexpected error from json.Marshal")
assert.Equal(t, objectJSON, string(gotJSON), "unexpected json.Marshal value")
})

t.Run("Unmarshal", func(t *testing.T) {
var gotObject object
err := json.Unmarshal([]byte(objectJSON), &gotObject)
assert.NoError(t, err, "unexpected error from json.Unmarshal")
assert.Equal(t, tt.forceAuthn, gotObject.ForceAuthn, "unexpected json.Unmarshal value")
})
})

t.Run("YAML", func(t *testing.T) {
t.Run("Marshal", func(t *testing.T) {
gotYAML, err := yaml.Marshal(o)
assert.NoError(t, err, "unexpected error from yaml.Marshal")
assert.Equal(t, objectYAML, string(gotYAML), "unexpected yaml.Marshal value")
})

t.Run("Unmarshal", func(t *testing.T) {
var gotObject object
err := yaml.Unmarshal([]byte(objectYAML), &gotObject)
assert.NoError(t, err, "unexpected error from yaml.Unmarshal")
assert.Equal(t, tt.forceAuthn, gotObject.ForceAuthn, "unexpected yaml.Unmarshal value")
})
})
})
}
}

0 comments on commit e3b49d2

Please sign in to comment.