diff --git a/lib/auth/db.go b/lib/auth/db.go index 40533a0120596..33f2b1fb71d0c 100644 --- a/lib/auth/db.go +++ b/lib/auth/db.go @@ -113,8 +113,17 @@ func (a *Server) generateDatabaseClientCert(ctx context.Context, req *proto.Data } // db clients should trust the Database Server CA when establishing // connection to a database, so return that CA's certs in the response. - dbServerCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ - Type: types.DatabaseCA, + // + // The only exception is the SQL Server with PKINIT integration, where the + // `kinit` command line needs our client CA to trust the user certificates + // we pass. + returnedCAType := types.DatabaseCA + if req.CertificateExtensions == proto.DatabaseCertRequest_WINDOWS_SMARTCARD { + returnedCAType = types.DatabaseClientCA + } + + returnedCA, err := a.GetCertAuthority(ctx, types.CertAuthID{ + Type: returnedCAType, DomainName: clusterName.GetClusterName(), }, false) if err != nil { @@ -122,7 +131,7 @@ func (a *Server) generateDatabaseClientCert(ctx context.Context, req *proto.Data } return &proto.DatabaseCertResponse{ Cert: cert, - CACerts: services.GetTLSCerts(dbServerCA), + CACerts: services.GetTLSCerts(returnedCA), }, nil } diff --git a/lib/auth/db_test.go b/lib/auth/db_test.go index 13ea93087fd33..2565f168a1905 100644 --- a/lib/auth/db_test.go +++ b/lib/auth/db_test.go @@ -153,6 +153,7 @@ func TestDBCertSigning(t *testing.T) { tests := []struct { name string requester proto.DatabaseCertRequest_Requester + extensions proto.DatabaseCertRequest_Extensions wantCertSigner []byte wantCACerts [][]byte wantKeyUsage []x509.ExtKeyUsage @@ -170,16 +171,32 @@ func TestDBCertSigning(t *testing.T) { wantCACerts: [][]byte{activeDBClientCACert, newDBClientCACert}, wantKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, }, + { + name: "DB service request for SQL Server databases is signed by active db client and trusts db client CAs", + extensions: proto.DatabaseCertRequest_WINDOWS_SMARTCARD, + wantCertSigner: activeDBClientCACert, + wantCACerts: [][]byte{activeDBClientCACert, newDBClientCACert}, + wantKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + }, + { + name: "tctl request for SQL Server databases is signed by new db CA and trusts db client CAs", + requester: proto.DatabaseCertRequest_TCTL, + extensions: proto.DatabaseCertRequest_WINDOWS_SMARTCARD, + wantCertSigner: newDBCACert, + wantCACerts: [][]byte{activeDBClientCACert, newDBClientCACert}, + wantKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() certResp, err := authServer.AuthServer.GenerateDatabaseCert(ctx, &proto.DatabaseCertRequest{ - CSR: csr, - ServerName: "localhost", - TTL: proto.Duration(time.Hour), - RequesterName: tt.requester, + CSR: csr, + ServerName: "localhost", + TTL: proto.Duration(time.Hour), + RequesterName: tt.requester, + CertificateExtensions: tt.extensions, }) require.NoError(t, err) require.Equal(t, tt.wantCACerts, certResp.CACerts)