diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 2c58972148e3b..f482375240ab5 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -31,6 +31,7 @@ import ( "crypto/subtle" "crypto/tls" "crypto/x509" + "encoding/pem" "errors" "fmt" "io" @@ -2226,7 +2227,8 @@ type AugmentUserCertificateOpts struct { // Used by Device Trust to add device extensions to the user certificate. func (a *Server) AugmentContextUserCertificates( ctx context.Context, - authCtx *authz.Context, opts *AugmentUserCertificateOpts, + authCtx *authz.Context, + opts *AugmentUserCertificateOpts, ) (*proto.Certs, error) { switch { case authCtx == nil: @@ -2235,10 +2237,126 @@ func (a *Server) AugmentContextUserCertificates( return nil, trace.BadParameter("opts required") } + // Fetch user TLS certificate. + x509Cert, err := authz.UserCertificateFromContext(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + identity := authCtx.Identity.GetIdentity() + + return a.augmentUserCertificates(ctx, augmentUserCertificatesOpts{ + checker: authCtx.Checker, + x509Cert: x509Cert, + x509Identity: &identity, + sshAuthorizedKey: opts.SSHAuthorizedKey, + deviceExtensions: opts.DeviceExtensions, + }) +} + +// AugmentWebSessionCertificatesOpts aggregates arguments for +// [AugmentWebSessionCertificates]. +type AugmentWebSessionCertificatesOpts struct { + // WebSessionID is the identifier for the WebSession. + WebSessionID string + + // DeviceExtensions are the device-aware extensions to add to the certificates + // being augmented. + DeviceExtensions *DeviceExtensions +} + +// AugmentWebSessionCertificates is a variant of +// [AugmentContextUserCertificates] that operates directly in the certificates +// stored in a WebSession. +// +// The authCtx user must be the owner of the session. Unlike +// [AugmentContextUserCertificates], the user certificate doesn't need to be +// present in the ctx, as the session certificates are used. +// +// On success the WebSession is updated with device extension certificates. +func (a *Server) AugmentWebSessionCertificates( + ctx context.Context, + authCtx *authz.Context, + opts *AugmentWebSessionCertificatesOpts, +) error { + switch { + case authCtx == nil: + return trace.BadParameter("authCtx required") + case opts == nil: + return trace.BadParameter("opts required") + case opts.WebSessionID == "": + return trace.BadParameter("opts.WebSessionID required") + } + + identity := authCtx.Identity.GetIdentity() + + // Get and validate session. + sessions := a.WebSessions() + session, err := sessions.Get(ctx, types.GetWebSessionRequest{ + User: identity.Username, + SessionID: opts.WebSessionID, + }) + if err != nil { + return trace.Wrap(err) + } + // Sanity check: session must belong to user. + if session.GetUser() != identity.Username { + return trace.AccessDenied("identity and session user mismatch") + } + + // Coerce session before doing more expensive operations. + sessionV2, ok := session.(*types.WebSessionV2) + if !ok { + return trace.BadParameter("unexpected WebSession type: %T", session) + } + + // Parse X.509 certificate. + block, _ := pem.Decode(session.GetTLSCert()) + if block == nil { + return trace.BadParameter("cannot decode session TLS certificate") + } + x509Cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return trace.Wrap(err) + } + x509Identity, err := tlsca.FromSubject(x509Cert.Subject, x509Cert.NotAfter) + if err != nil { + return trace.Wrap(err) + } + + // Augment certificates. + newCerts, err := a.augmentUserCertificates(ctx, augmentUserCertificatesOpts{ + checker: authCtx.Checker, + x509Cert: x509Cert, + x509Identity: x509Identity, + sshAuthorizedKey: session.GetPub(), + deviceExtensions: opts.DeviceExtensions, + }) + if err != nil { + return trace.Wrap(err) + } + + // Update WebSession. + sessionV2.Spec.Pub = newCerts.SSH + sessionV2.Spec.TLSCert = newCerts.TLS + return trace.Wrap(sessions.Upsert(ctx, sessionV2)) +} + +type augmentUserCertificatesOpts struct { + checker services.AccessChecker + x509Cert *x509.Certificate + x509Identity *tlsca.Identity + sshAuthorizedKey []byte + deviceExtensions *DeviceExtensions +} + +func (a *Server) augmentUserCertificates( + ctx context.Context, + opts augmentUserCertificatesOpts, +) (*proto.Certs, error) { // Is at least one extension present? // Are the extensions valid? - identity := authCtx.Identity.GetIdentity() - dev := opts.DeviceExtensions + dev := opts.deviceExtensions switch { case dev == nil: // Only extension that currently exists. return nil, trace.BadParameter("at least one opts extension must be present") @@ -2248,28 +2366,27 @@ func (a *Server) AugmentContextUserCertificates( return nil, trace.BadParameter("opts.DeviceExtensions.AssetTag required") case dev.CredentialID == "": return nil, trace.BadParameter("opts.DeviceExtensions.CredentialID required") - // Do not reissue if device extensions are already present. - case identity.DeviceExtensions.DeviceID != "", - identity.DeviceExtensions.AssetTag != "", - identity.DeviceExtensions.CredentialID != "": - return nil, trace.BadParameter("device extensions already present") } - // Fetch user TLS certificate. - x509Cert, err := authz.UserCertificateFromContext(ctx) - if err != nil { - return nil, trace.Wrap(err) - } + x509Cert := opts.x509Cert + x509Identity := opts.x509Identity - // Sanity check: x509Cert matches identity. - // Both the TLS certificate and the identity come from the same source, so - // they are unlikely to mismatch unless Teleport itself mixes it up. - if x509Cert.Subject.CommonName != identity.Username { + // Sanity check: x509Cert identity matches x509Identity. + if x509Cert.Subject.CommonName != x509Identity.Username { return nil, trace.BadParameter("identity and x509 user mismatch") } + // Do not reissue if device extensions are already present. + // Note that the certIdentity extensions could differ from the "current" + // identity extensions if this was not the cert used to authenticate. + if x509Identity.DeviceExtensions.DeviceID != "" || + x509Identity.DeviceExtensions.AssetTag != "" || + x509Identity.DeviceExtensions.CredentialID != "" { + return nil, trace.BadParameter("device extensions already present") + } + // Parse and verify SSH certificate. - sshAuthorizedKey := opts.SSHAuthorizedKey + sshAuthorizedKey := opts.sshAuthorizedKey var sshCert *ssh.Certificate if len(sshAuthorizedKey) > 0 { var err error @@ -2306,9 +2423,9 @@ func (a *Server) AugmentContextUserCertificates( switch { case sshCert.CertType != ssh.UserCert: return nil, trace.BadParameter("ssh cert type mismatch") - case sshCert.KeyId != identity.Username: + case sshCert.KeyId != x509Identity.Username: return nil, trace.BadParameter("identity and SSH user mismatch") - case !slices.Equal(filterAndSortPrincipals(sshCert.ValidPrincipals), filterAndSortPrincipals(identity.Principals)): + case !slices.Equal(filterAndSortPrincipals(sshCert.ValidPrincipals), filterAndSortPrincipals(x509Identity.Principals)): return nil, trace.BadParameter("identity and SSH principals mismatch") case !apisshutils.KeysEqual(sshCert.Key, xPubKey): return nil, trace.BadParameter("x509 and SSH public key mismatch") @@ -2376,18 +2493,18 @@ func (a *Server) AugmentContextUserCertificates( return nil, trace.Wrap(err) } if err := a.verifyLocksForUserCerts(verifyLocksForUserCertsReq{ - checker: authCtx.Checker, + checker: opts.checker, defaultMode: authPref.GetLockingMode(), - username: identity.Username, - mfaVerified: identity.MFAVerified, - activeAccessRequests: identity.ActiveRequests, - deviceID: opts.DeviceExtensions.DeviceID, // Check lock against requested device. + username: x509Identity.Username, + mfaVerified: x509Identity.MFAVerified, + activeAccessRequests: x509Identity.ActiveRequests, + deviceID: dev.DeviceID, // Check lock against requested device. }); err != nil { return nil, trace.Wrap(err) } // Augment TLS certificate. - newIdentity := identity + newIdentity := x509Identity newIdentity.DeviceExtensions.DeviceID = dev.DeviceID newIdentity.DeviceExtensions.AssetTag = dev.AssetTag newIdentity.DeviceExtensions.CredentialID = dev.CredentialID diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index b82f1c4c5caf8..074918fa9dd10 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -1933,6 +1933,210 @@ func TestServer_AugmentContextUserCertificates_errors(t *testing.T) { } } +func TestServer_AugmentWebSessionCertificates(t *testing.T) { + t.Parallel() + + testServer := newTestTLSServer(t) + authServer := testServer.Auth() + ctx := context.Background() + + userData := setupUserForAugmentWebSessionCertificatesTest(t, testServer) + + // Safe to reuse, user-independent. + deviceExts := &DeviceExtensions{ + DeviceID: "my-device-id", + AssetTag: "my-device-asset-tag", + CredentialID: "my-device-credential-id", + } + + assertSSHCert := func(t *testing.T, sshCert []byte) { + cert, err := sshutils.ParseCertificate(sshCert) + require.NoError(t, err, "ParseCertificate") + + // Not empty is good enough here, other tests assert this deeply. + assert.NotEmpty(t, cert.Extensions[teleport.CertExtensionDeviceID], "DeviceID empty") + assert.NotEmpty(t, cert.Extensions[teleport.CertExtensionDeviceAssetTag], "AssetTag empty") + assert.NotEmpty(t, cert.Extensions[teleport.CertExtensionDeviceCredentialID], "CredentialID empty") + } + + assertX509Cert := func(t *testing.T, x509PEM []byte) { + _, identity := parseX509PEMAndIdentity(t, x509PEM) + + // Not empty is good enough here, other tests assert this deeply. + assert.NotEmpty(t, identity.DeviceExtensions.DeviceID, "DeviceID empty") + assert.NotEmpty(t, identity.DeviceExtensions.AssetTag, "AssetTag empty") + assert.NotEmpty(t, identity.DeviceExtensions.CredentialID, "CredentialID empty") + } + + t.Run("ok", func(t *testing.T) { + t.Parallel() // Get the errors suite going asap. + + opts := &AugmentWebSessionCertificatesOpts{ + WebSessionID: userData.webSessionID, + DeviceExtensions: deviceExts, + } + err := authServer.AugmentWebSessionCertificates(ctx, userData.authCtx, opts) + require.NoError(t, err, "AugmentWebSessionCertificates") + + // Verify WebSession certificates. + webSession, err := authServer.WebSessions().Get(ctx, types.GetWebSessionRequest{ + User: userData.user, + SessionID: userData.webSessionID, + }) + require.NoError(t, err, "WebSessions().Get() failed: %v", err) + assertSSHCert(t, webSession.GetPub()) + assertX509Cert(t, webSession.GetTLSCert()) + + // Scenario requires augmented certs to work. + t.Run("cannot re-augment the same session", func(t *testing.T) { + err := authServer.AugmentWebSessionCertificates(ctx, userData.authCtx, opts) + const wantErr = "extensions already present" + assert.ErrorContains(t, err, wantErr, "AugmentWebSessionCertificates error mismatch") + }) + }) + + user2Data := setupUserForAugmentWebSessionCertificatesTest(t, testServer) + user2Opts := &AugmentWebSessionCertificatesOpts{ + WebSessionID: user2Data.webSessionID, + DeviceExtensions: deviceExts, + } + + t.Run("errors", func(t *testing.T) { + tests := []struct { + name string + authCtx *authz.Context + opts *AugmentWebSessionCertificatesOpts + wantErr string + assertErr func(error) bool // defaults to trace.IsBadParameter + }{ + { + name: "authCtx nil", + opts: user2Opts, + wantErr: "authCtx required", + }, + { + name: "opts nil", + authCtx: user2Data.authCtx, + wantErr: "opts required", + }, + { + name: "opts.WebSessionID is empty", + authCtx: user2Data.authCtx, + opts: func() *AugmentWebSessionCertificatesOpts { + opts := *user2Opts + opts.WebSessionID = "" + return &opts + }(), + wantErr: "WebSessionID required", + }, + { + name: "opts.DeviceExtensions nil", + authCtx: user2Data.authCtx, + opts: func() *AugmentWebSessionCertificatesOpts { + opts := *user2Opts + opts.DeviceExtensions = nil + return &opts + }(), + wantErr: "at least one opts extension", + }, + { + // This is the most we can mismatch assuming the session is well-formed. + // Internally the method will still check the certificates against each + // other. + name: "user/session mismatch", + authCtx: userData.authCtx, // user1 + opts: user2Opts, // user2 + wantErr: "session not found", + assertErr: trace.IsNotFound, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + err := authServer.AugmentWebSessionCertificates(ctx, test.authCtx, test.opts) + assert.ErrorContains(t, err, test.wantErr, "AugmentWebSessionCertificates error mismatch") + + assertErr := test.assertErr + if assertErr == nil { + assertErr = trace.IsBadParameter + } + assert.True(t, + assertErr(err), + "AugmentWebSessionCertificates: assertErr failed: err=%v (%T)", err, trace.Unwrap(err)) + }) + } + }) +} + +type augmentUserData struct { + user string + pass []byte + pubKey []byte // SSH "AuthorizedKey" format + authCtx *authz.Context + webSessionID string +} + +func setupUserForAugmentWebSessionCertificatesTest(t *testing.T, testServer *TestTLSServer) *augmentUserData { + authServer := testServer.Auth() + authorizer := testServer.APIConfig.Authorizer + ctx := context.Background() + + user := &augmentUserData{ + user: "llama_" + uuid.NewString(), + pass: []byte("passwordforllamaA1!"), + } + + // Create user and assign a password. + _, _, err := CreateUserAndRole(authServer, user.user, []string{user.user}, nil /* allowRules */) + require.NoError(t, err, "CreateUserAndRole") + require.NoError(t, + authServer.UpsertPassword(user.user, user.pass), + "UpsertPassword", + ) + + // Generate underlying keys for SSH and TLS. + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err, "GenerateKey") + pubKeySSH, err := ssh.NewPublicKey(privKey.Public()) + require.NoError(t, err, "NewPublicKey") + user.pubKey = ssh.MarshalAuthorizedKey(pubKeySSH) + + // Authenticate user with SSH and use the replies certificates as the context + // identity. + authnReq := AuthenticateUserRequest{ + Username: user.user, + PublicKey: user.pubKey, + Pass: &PassCreds{ + Password: user.pass, + }, + } + authnResp, err := authServer.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ + AuthenticateUserRequest: authnReq, + TTL: 1 * time.Hour, + }) + require.NoError(t, err, "AuthenticateSSHUser") + + _, userIdentity := parseX509PEMAndIdentity(t, authnResp.TLSCert) + + // Prepare ctx and authz.Context for user. + userCtx := context.Background() + userCtx = authz.ContextWithUser(userCtx, authz.LocalUser{ + Username: user.user, + Identity: *userIdentity, + }) + user.authCtx, err = authorizer.Authorize(userCtx) + require.NoError(t, err, "Authorize failed") + + // Prepare a WebSession to be augmented. + session, err := authServer.AuthenticateWebUser(ctx, authnReq) + require.NoError(t, err, "AuthenticateWebUser") + user.webSessionID = session.GetName() + + return user +} + func TestGenerateUserCertIPPinning(t *testing.T) { modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})