Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SAMLForceAuthn marshalling #48098

Merged
merged 2 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions api/types/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package types

import (
"encoding/json"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -524,3 +525,98 @@ func (r *SAMLAuthRequest) Check() error {
}
return nil
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s SAMLForceAuthn) MarshalYAML() (interface{}, error) {
val, err := s.encode()
if err != nil {
return nil, trace.Wrap(err)
}
return val, nil
}

// UnmarshalYAML supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error {
var val any
if err := unmarshal(&val); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(s.decode(val))
}

// MarshalJSON marshals SAMLForceAuthn to string.
func (s SAMLForceAuthn) MarshalJSON() ([]byte, error) {
val, err := s.encode()
if err != nil {
return nil, trace.Wrap(err)
}
out, err := json.Marshal(val)
return out, trace.Wrap(err)
}

// UnmarshalJSON supports parsing SAMLForceAuthn from string.
func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error {
var val any
if err := json.Unmarshal(data, &val); err != nil {
return trace.Wrap(err)
}
return trace.Wrap(s.decode(val))
}

func (s *SAMLForceAuthn) encode() (string, error) {
switch *s {
case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED:
return "", nil
case SAMLForceAuthn_FORCE_AUTHN_NO:
return "no", nil
case SAMLForceAuthn_FORCE_AUTHN_YES:
return "yes", nil
default:
return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s)
}
}

func (s *SAMLForceAuthn) decode(val any) error {
switch v := val.(type) {
case string:
// try parsing as a boolean
switch strings.ToLower(v) {
case "":
*s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED
case "yes", "yeah", "y", "true", "1", "on":
*s = SAMLForceAuthn_FORCE_AUTHN_YES
case "no", "nope", "n", "false", "0", "off":
*s = SAMLForceAuthn_FORCE_AUTHN_NO
default:
return trace.BadParameter("SAMLForceAuthn invalid value %v", val)
}
case bool:
if v {
*s = SAMLForceAuthn_FORCE_AUTHN_YES
} else {
*s = SAMLForceAuthn_FORCE_AUTHN_NO
}
case int32:
return trace.Wrap(s.setFromEnum(v))
case int64:
return trace.Wrap(s.setFromEnum(int32(v)))
case int:
return trace.Wrap(s.setFromEnum(int32(v)))
case float64:
return trace.Wrap(s.setFromEnum(int32(v)))
case float32:
return trace.Wrap(s.setFromEnum(int32(v)))
default:
return trace.BadParameter("SAMLForceAuthn invalid type %T", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (s *SAMLForceAuthn) setFromEnum(val int32) error {
if _, ok := SAMLForceAuthn_name[val]; !ok {
return trace.BadParameter("invalid SAMLForceAuthn enum %v", val)
}
*s = SAMLForceAuthn(val)
return nil
}
63 changes: 63 additions & 0 deletions api/types/saml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ limitations under the License.
package types_test

import (
"encoding/json"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"

"github.com/gravitational/teleport/api/types"
)
Expand Down Expand Up @@ -107,3 +111,62 @@ func TestSAMLForceAuthn(t *testing.T) {
})
}
}

func TestSAMLForceAuthn_Encoding(t *testing.T) {
for _, tt := range []struct {
forceAuthn types.SAMLForceAuthn
expectEncoded string
}{
{
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED,
expectEncoded: "",
}, {
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_YES,
expectEncoded: "yes",
}, {
forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_NO,
expectEncoded: "no",
},
} {
t.Run(tt.forceAuthn.String(), func(t *testing.T) {
type object struct {
ForceAuthn types.SAMLForceAuthn `json:"force_authn" yaml:"force_authn"`
}
o := object{
ForceAuthn: tt.forceAuthn,
}
objectJSON := fmt.Sprintf(`{"force_authn":%q}`, tt.expectEncoded)
objectYAML := fmt.Sprintf("force_authn: %q\n", tt.expectEncoded)

t.Run("JSON", func(t *testing.T) {
t.Run("Marshal", func(t *testing.T) {
gotJSON, err := json.Marshal(o)
assert.NoError(t, err, "unexpected error from json.Marshal")
assert.Equal(t, objectJSON, string(gotJSON), "unexpected json.Marshal value")
})

t.Run("Unmarshal", func(t *testing.T) {
var gotObject object
err := json.Unmarshal([]byte(objectJSON), &gotObject)
assert.NoError(t, err, "unexpected error from json.Unmarshal")
assert.Equal(t, tt.forceAuthn, gotObject.ForceAuthn, "unexpected json.Unmarshal value")
})
})

t.Run("YAML", func(t *testing.T) {
t.Run("Marshal", func(t *testing.T) {
gotYAML, err := yaml.Marshal(o)
assert.NoError(t, err, "unexpected error from yaml.Marshal")
assert.Equal(t, objectYAML, string(gotYAML), "unexpected yaml.Marshal value")
})

t.Run("Unmarshal", func(t *testing.T) {
var gotObject object
err := yaml.Unmarshal([]byte(objectYAML), &gotObject)
assert.NoError(t, err, "unexpected error from yaml.Unmarshal")
assert.Equal(t, tt.forceAuthn, gotObject.ForceAuthn, "unexpected yaml.Unmarshal value")
})
})
})
}
}
Loading