Skip to content

Commit

Permalink
[v15] Require device trust for initial device registration endpoints (#…
Browse files Browse the repository at this point in the history
…38657)

* Require device trust for initial device registration endpoints.

* Address comments.

* Wrap error to be user readable.

* Add privilege_exception token comment.

* Fix test.
  • Loading branch information
Joerger authored Feb 27, 2024
1 parent 82eb3db commit 4c8cd9a
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 8 deletions.
5 changes: 1 addition & 4 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3093,7 +3093,7 @@ func (a *Server) CreateRegisterChallenge(ctx context.Context, req *proto.CreateR
}
username = token.GetUser()

case req.ExistingMFAResponse != nil: // Authenticated user without token, tsh.
default: // Authenticated user without token, tsh.
var err error
username, err = authz.GetClientUsername(ctx)
if err != nil {
Expand All @@ -3116,9 +3116,6 @@ func (a *Server) CreateRegisterChallenge(ctx context.Context, req *proto.CreateR
if err != nil {
return nil, trace.Wrap(err)
}

default:
return nil, trace.BadParameter("either a token or an MFA response are required")
}

regChal, err := a.createRegisterChallenge(ctx, &newRegisterChallengeRequest{
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/auth_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ func TestCreateRegisterChallenge(t *testing.T) {
DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN,
DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_MFA,
})
assert.ErrorContains(t, err, "token or an MFA response")
assert.ErrorContains(t, err, "second factor authentication required")

// Acquire and solve an authn challenge.
authnChal, err := authClient.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{
Expand Down
36 changes: 33 additions & 3 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -6335,17 +6335,35 @@ func (a *ServerWithRoles) CreateAuthenticateChallenge(ctx context.Context, req *

// CreatePrivilegeToken is implemented by AuthService.CreatePrivilegeToken.
func (a *ServerWithRoles) CreatePrivilegeToken(ctx context.Context, req *proto.CreatePrivilegeTokenRequest) (*types.UserTokenV3, error) {
// Device trust: authorize device before issuing a privileged token without an MFA response.
//
// This is an exceptional case for that that results in a "privilege_exception" token, which can
// used to register a user's first MFA device thorugh the WebUI. Since a register challenge can
// be created on behalf of the user using this token (e.g. by the Proxy Service), we must enforce
// the device trust requirement seen in [CreatePrivilegeToken] here instead.
if mfaResp := req.GetExistingMFAResponse(); mfaResp.GetTOTP() == nil && mfaResp.GetWebauthn() == nil {
if err := a.enforceGlobalModeTrustedDevice(ctx); err != nil {
return nil, trace.Wrap(err, "device trust is required for users to create a privileged token without an MFA check")
}
}

return a.authServer.CreatePrivilegeToken(ctx, req)
}

// CreateRegisterChallenge is implemented by AuthService.CreateRegisterChallenge.
func (a *ServerWithRoles) CreateRegisterChallenge(ctx context.Context, req *proto.CreateRegisterChallengeRequest) (*proto.MFARegisterChallenge, error) {
switch {
case req.TokenID != "":
case req.ExistingMFAResponse != nil:
if req.TokenID == "" {
if !authz.IsLocalOrRemoteUser(a.context) {
return nil, trace.BadParameter("only end users are allowed issue registration challenges without a privilege token")
}

// Device trust: authorize device before issuing a register challenge without an MFA response or privilege token.
// This is an exceptional case for users registering their first MFA challenge through `tsh`.
if mfaResp := req.GetExistingMFAResponse(); mfaResp.GetTOTP() == nil && mfaResp.GetWebauthn() == nil {
if err := a.enforceGlobalModeTrustedDevice(ctx); err != nil {
return nil, trace.Wrap(err, "device trust is required for users to register their first MFA device")
}
}
}

// The following serve as means of authentication for this RPC:
Expand All @@ -6354,6 +6372,18 @@ func (a *ServerWithRoles) CreateRegisterChallenge(ctx context.Context, req *prot
return a.authServer.CreateRegisterChallenge(ctx, req)
}

// enforceGlobalModeTrustedDevice is used to enforce global device trust requirements
// for key endpoints.
func (a *ServerWithRoles) enforceGlobalModeTrustedDevice(ctx context.Context) error {
authPref, err := a.GetAuthPreference(ctx)
if err != nil {
return trace.Wrap(err)
}

err = dtauthz.VerifyTLSUser(authPref.GetDeviceTrust(), a.context.Identity.GetIdentity())
return trace.Wrap(err)
}

// GetAccountRecoveryCodes is implemented by AuthService.GetAccountRecoveryCodes.
func (a *ServerWithRoles) GetAccountRecoveryCodes(ctx context.Context, req *proto.GetAccountRecoveryCodesRequest) (*proto.RecoveryCodes, error) {
// User in context can retrieve their own recovery codes.
Expand Down
114 changes: 114 additions & 0 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import (
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/defaults"
dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
Expand Down Expand Up @@ -977,6 +978,119 @@ func TestGenerateUserCerts_deviceAuthz(t *testing.T) {
}
}

// Test that device trust is required for a user registering their first MFA device.
func TestRegisterFirstDevice_deviceAuthz(t *testing.T) {
modules.SetTestModules(t, &modules.TestModules{
TestBuildType: modules.BuildEnterprise, // required for Device Trust.
})

testServer := newTestTLSServer(t)

ctx := context.Background()
authServer := testServer.Auth()

// Create a user for testing.
user, _, err := CreateUserAndRole(testServer.Auth(), "llama", []string{"llama"}, nil)
require.NoError(t, err, "CreateUserAndRole failed")
username := user.GetName()

// Create clients with and without device extensions.
clientWithoutDevice, err := testServer.NewClient(TestUser(username))
require.NoError(t, err, "NewClient failed")

clientWithDevice, err := testServer.NewClient(
TestUserWithDeviceExtensions(username, tlsca.DeviceExtensions{
DeviceID: "deviceid1",
AssetTag: "assettag1",
CredentialID: "credentialid1",
}))
require.NoError(t, err, "NewClient failed")

// updateAuthPref is a helper used throughout the test.
updateAuthPref := func(t *testing.T, modify func(ap types.AuthPreference)) {
authPref, err := authServer.GetAuthPreference(ctx)
require.NoError(t, err, "GetAuthPreference failed")

modify(authPref)

require.NoError(t,
authServer.SetAuthPreference(ctx, authPref),
"SetAuthPreference failed")
}

// Enable webauthn
updateAuthPref(t, func(authPref types.AuthPreference) {
authPref.SetSecondFactor(constants.SecondFactorOptional)
authPref.SetWebauthn(&types.Webauthn{
RPID: "localhost",
})
})

assertSuccess := func(t *testing.T, err error) {
assert.NoError(t, err)
}
assertAccessDenied := func(t *testing.T, err error) {
assert.True(t, trace.IsAccessDenied(err), "expected access denied error but got %v", err)
assert.ErrorContains(t, err, dtauthz.ErrTrustedDeviceRequired.Error())
}

tests := []struct {
name string
clusterDeviceMode string
client *Client
skipLoginCerts bool // aka non-MFA issuance.
skipSingleUseCerts bool // aka MFA/streaming issuance.
assertErr func(t *testing.T, err error)
}{
{
name: "mode=optional without extensions",
clusterDeviceMode: constants.DeviceTrustModeOptional,
client: clientWithoutDevice,
assertErr: assertSuccess,
},
{
name: "mode=optional with extensions",
clusterDeviceMode: constants.DeviceTrustModeOptional,
client: clientWithDevice,
assertErr: assertSuccess,
},
{
name: "nok: mode=required without extensions",
clusterDeviceMode: constants.DeviceTrustModeRequired,
client: clientWithoutDevice,
assertErr: assertAccessDenied,
},
{
name: "mode=required with extensions",
clusterDeviceMode: constants.DeviceTrustModeRequired,
client: clientWithDevice,
assertErr: assertSuccess,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
updateAuthPref(t, func(ap types.AuthPreference) {
ap.SetDeviceTrust(&types.DeviceTrust{
Mode: test.clusterDeviceMode,
})
})

t.Run("CreatePrivilegeTokenRequest", func(t *testing.T) {
_, err := test.client.CreatePrivilegeToken(ctx, &proto.CreatePrivilegeTokenRequest{})
test.assertErr(t, err)
})

t.Run("CreateRegisterChallenge", func(t *testing.T) {
_, err := test.client.CreateRegisterChallenge(ctx, &proto.CreateRegisterChallengeRequest{
DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN,
DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_MFA,
})
test.assertErr(t, err)
})
})
}
}

func mustCreateDatabase(t *testing.T, name, protocol, uri string) *types.DatabaseV3 {
database, err := types.NewDatabaseV3(
types.Metadata{
Expand Down

0 comments on commit 4c8cd9a

Please sign in to comment.