Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to augment WebSession with device extensions #39147

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 143 additions & 26 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -2226,7 +2227,8 @@ type AugmentUserCertificateOpts struct {
// Used by Device Trust to add device extensions to the user certificate.
func (a *Server) AugmentContextUserCertificates(
ctx context.Context,
authCtx *authz.Context, opts *AugmentUserCertificateOpts,
authCtx *authz.Context,
opts *AugmentUserCertificateOpts,
) (*proto.Certs, error) {
switch {
case authCtx == nil:
Expand All @@ -2235,10 +2237,126 @@ func (a *Server) AugmentContextUserCertificates(
return nil, trace.BadParameter("opts required")
}

// Fetch user TLS certificate.
x509Cert, err := authz.UserCertificateFromContext(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

identity := authCtx.Identity.GetIdentity()

return a.augmentUserCertificates(ctx, augmentUserCertificatesOpts{
checker: authCtx.Checker,
x509Cert: x509Cert,
x509Identity: &identity,
sshAuthorizedKey: opts.SSHAuthorizedKey,
deviceExtensions: opts.DeviceExtensions,
})
}

// AugmentWebSessionCertificatesOpts aggregates arguments for
// [AugmentWebSessionCertificates].
type AugmentWebSessionCertificatesOpts struct {
// WebSessionID is the identifier for the WebSession.
WebSessionID string

// DeviceExtensions are the device-aware extensions to add to the certificates
// being augmented.
DeviceExtensions *DeviceExtensions
}

// AugmentWebSessionCertificates is a variant of
// [AugmentContextUserCertificates] that operates directly in the certificates
// stored in a WebSession.
//
// The authCtx user must be the owner of the session. Unlike
// [AugmentContextUserCertificates], the user certificate doesn't need to be
// present in the ctx, as the session certificates are used.
//
// On success the WebSession is updated with device extension certificates.
func (a *Server) AugmentWebSessionCertificates(
ctx context.Context,
authCtx *authz.Context,
opts *AugmentWebSessionCertificatesOpts,
) error {
switch {
case authCtx == nil:
return trace.BadParameter("authCtx required")
case opts == nil:
return trace.BadParameter("opts required")
case opts.WebSessionID == "":
return trace.BadParameter("opts.WebSessionID required")
}

identity := authCtx.Identity.GetIdentity()

// Get and validate session.
sessions := a.WebSessions()
session, err := sessions.Get(ctx, types.GetWebSessionRequest{
User: identity.Username,
SessionID: opts.WebSessionID,
})
if err != nil {
return trace.Wrap(err)
}
// Sanity check: session must belong to user.
if session.GetUser() != identity.Username {
return trace.AccessDenied("identity and session user mismatch")
codingllama marked this conversation as resolved.
Show resolved Hide resolved
}

// Coerce session before doing more expensive operations.
sessionV2, ok := session.(*types.WebSessionV2)
if !ok {
return trace.BadParameter("unexpected WebSession type: %T", session)
}

// Parse X.509 certificate.
block, _ := pem.Decode(session.GetTLSCert())
if block == nil {
return trace.BadParameter("cannot decode session TLS certificate")
}
x509Cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return trace.Wrap(err)
}
x509Identity, err := tlsca.FromSubject(x509Cert.Subject, x509Cert.NotAfter)
if err != nil {
return trace.Wrap(err)
}

// Augment certificates.
newCerts, err := a.augmentUserCertificates(ctx, augmentUserCertificatesOpts{
checker: authCtx.Checker,
x509Cert: x509Cert,
x509Identity: x509Identity,
sshAuthorizedKey: session.GetPub(),
deviceExtensions: opts.DeviceExtensions,
})
if err != nil {
return trace.Wrap(err)
}

// Update WebSession.
sessionV2.Spec.Pub = newCerts.SSH
sessionV2.Spec.TLSCert = newCerts.TLS
return trace.Wrap(sessions.Upsert(ctx, sessionV2))
}

type augmentUserCertificatesOpts struct {
checker services.AccessChecker
x509Cert *x509.Certificate
x509Identity *tlsca.Identity
sshAuthorizedKey []byte
deviceExtensions *DeviceExtensions
}

func (a *Server) augmentUserCertificates(
ctx context.Context,
opts augmentUserCertificatesOpts,
) (*proto.Certs, error) {
// Is at least one extension present?
// Are the extensions valid?
identity := authCtx.Identity.GetIdentity()
dev := opts.DeviceExtensions
dev := opts.deviceExtensions
switch {
case dev == nil: // Only extension that currently exists.
return nil, trace.BadParameter("at least one opts extension must be present")
Expand All @@ -2248,28 +2366,27 @@ func (a *Server) AugmentContextUserCertificates(
return nil, trace.BadParameter("opts.DeviceExtensions.AssetTag required")
case dev.CredentialID == "":
return nil, trace.BadParameter("opts.DeviceExtensions.CredentialID required")
// Do not reissue if device extensions are already present.
case identity.DeviceExtensions.DeviceID != "",
identity.DeviceExtensions.AssetTag != "",
identity.DeviceExtensions.CredentialID != "":
return nil, trace.BadParameter("device extensions already present")
}

// Fetch user TLS certificate.
x509Cert, err := authz.UserCertificateFromContext(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
x509Cert := opts.x509Cert
x509Identity := opts.x509Identity

// Sanity check: x509Cert matches identity.
// Both the TLS certificate and the identity come from the same source, so
// they are unlikely to mismatch unless Teleport itself mixes it up.
if x509Cert.Subject.CommonName != identity.Username {
// Sanity check: x509Cert identity matches x509Identity.
if x509Cert.Subject.CommonName != x509Identity.Username {
return nil, trace.BadParameter("identity and x509 user mismatch")
}

// Do not reissue if device extensions are already present.
// Note that the certIdentity extensions could differ from the "current"
// identity extensions if this was not the cert used to authenticate.
if x509Identity.DeviceExtensions.DeviceID != "" ||
x509Identity.DeviceExtensions.AssetTag != "" ||
x509Identity.DeviceExtensions.CredentialID != "" {
return nil, trace.BadParameter("device extensions already present")
}

// Parse and verify SSH certificate.
sshAuthorizedKey := opts.SSHAuthorizedKey
sshAuthorizedKey := opts.sshAuthorizedKey
var sshCert *ssh.Certificate
if len(sshAuthorizedKey) > 0 {
var err error
Expand Down Expand Up @@ -2306,9 +2423,9 @@ func (a *Server) AugmentContextUserCertificates(
switch {
case sshCert.CertType != ssh.UserCert:
return nil, trace.BadParameter("ssh cert type mismatch")
case sshCert.KeyId != identity.Username:
case sshCert.KeyId != x509Identity.Username:
return nil, trace.BadParameter("identity and SSH user mismatch")
case !slices.Equal(filterAndSortPrincipals(sshCert.ValidPrincipals), filterAndSortPrincipals(identity.Principals)):
case !slices.Equal(filterAndSortPrincipals(sshCert.ValidPrincipals), filterAndSortPrincipals(x509Identity.Principals)):
return nil, trace.BadParameter("identity and SSH principals mismatch")
case !apisshutils.KeysEqual(sshCert.Key, xPubKey):
return nil, trace.BadParameter("x509 and SSH public key mismatch")
Expand Down Expand Up @@ -2376,18 +2493,18 @@ func (a *Server) AugmentContextUserCertificates(
return nil, trace.Wrap(err)
}
if err := a.verifyLocksForUserCerts(verifyLocksForUserCertsReq{
checker: authCtx.Checker,
checker: opts.checker,
defaultMode: authPref.GetLockingMode(),
username: identity.Username,
mfaVerified: identity.MFAVerified,
activeAccessRequests: identity.ActiveRequests,
deviceID: opts.DeviceExtensions.DeviceID, // Check lock against requested device.
username: x509Identity.Username,
mfaVerified: x509Identity.MFAVerified,
activeAccessRequests: x509Identity.ActiveRequests,
deviceID: dev.DeviceID, // Check lock against requested device.
}); err != nil {
return nil, trace.Wrap(err)
}

// Augment TLS certificate.
newIdentity := identity
newIdentity := x509Identity
newIdentity.DeviceExtensions.DeviceID = dev.DeviceID
newIdentity.DeviceExtensions.AssetTag = dev.AssetTag
newIdentity.DeviceExtensions.CredentialID = dev.CredentialID
Expand Down
Loading
Loading