Skip to content

Commit

Permalink
Properly propagate new private key for azure app
Browse files Browse the repository at this point in the history
  • Loading branch information
greedy52 committed Nov 6, 2024
1 parent 779395a commit 9ee501b
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 17 deletions.
34 changes: 27 additions & 7 deletions lib/srv/alpnproxy/azure_msi_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ package alpnproxy

import (
"crypto"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"

"github.com/gravitational/trace"
Expand All @@ -45,15 +47,16 @@ type AzureMSIMiddleware struct {
// ClientID to be returned in a claim.
ClientID string

// Key used to sign JWT
Key crypto.Signer

// Clock is used to override time in tests.
Clock clockwork.Clock
// Log is the Logger.
Log logrus.FieldLogger
// Secret to be provided by the client.
Secret string

// Key used to sign JWT
key crypto.Signer
keyMu sync.RWMutex
}

var _ LocalProxyHTTPMiddleware = &AzureMSIMiddleware{}
Expand All @@ -66,9 +69,6 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
m.Log = logrus.WithField(teleport.ComponentKey, "azure_msi")
}

if m.Key == nil {
return trace.BadParameter("missing Key")
}
if m.Secret == "" {
return trace.BadParameter("missing Secret")
}
Expand Down Expand Up @@ -96,6 +96,26 @@ func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Req
return false
}

func (m *AzureMSIMiddleware) OnSetCert(cert *tls.Certificate) {
m.keyMu.Lock()
defer m.keyMu.Unlock()

if cert != nil {
// Note that the PrivateKey is most likely set by api/utils/keys.TLSCertificateForSigner
signer, ok := cert.PrivateKey.(crypto.Signer)
if ok {
m.key = signer
} else {
m.Log.Warn("Provided tls.Certificate has no valid private key")
}
}
}
func (m *AzureMSIMiddleware) getPrivateKey() crypto.Signer {
m.keyMu.RLock()
defer m.keyMu.RUnlock()
return m.key
}

func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Request) error {
// request validation
if req.URL.Path != ("/" + m.Secret) {
Expand Down Expand Up @@ -176,7 +196,7 @@ func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error)
// Create a new key that can sign and verify tokens.
key, err := jwt.New(&jwt.Config{
Clock: m.Clock,
PrivateKey: m.Key,
PrivateKey: m.getPrivateKey(),
ClusterName: types.TeleportAzureMSIEndpoint, // todo get cluster name
})
if err != nil {
Expand Down
7 changes: 7 additions & 0 deletions lib/srv/alpnproxy/local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ type LocalProxyConfig struct {
CheckCertNeeded bool
// verifyUpstreamConnection is a callback function to verify upstream connection state.
verifyUpstreamConnection func(tls.ConnectionState) error
// TODO
onSetCert func(*tls.Certificate)
}

// LocalProxyMiddleware provides callback functions for LocalProxy.
Expand Down Expand Up @@ -484,6 +486,11 @@ func (l *LocalProxy) SetCert(cert tls.Certificate) {
l.certMu.Lock()
defer l.certMu.Unlock()
l.cfg.Cert = cert

// Callback, if any.
if l.cfg.onSetCert != nil {
l.cfg.onSetCert(&cert)
}
}

// getCertForConn determines if certificates should be used when dialing
Expand Down
7 changes: 7 additions & 0 deletions lib/srv/alpnproxy/local_proxy_config_opt.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,10 @@ func mySQLVersionToProto(database types.Database) string {
// Include MySQL server version
return string(common.ProtocolMySQLWithVerPrefix) + versionBase64
}

func WithOnSetCert(callback func(*tls.Certificate)) LocalProxyConfigOpt {
return func(config *LocalProxyConfig) error {
config.onSetCert = callback
return nil
}
}
15 changes: 5 additions & 10 deletions tool/tsh/common/app_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package common

import (
"context"
"crypto"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -72,17 +71,11 @@ type azureApp struct {
*localProxyApp

cf *CLIConf
signer crypto.Signer
msiSecret string
}

// newAzureApp creates a new Azure app.
func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azureApp, error) {
keyRing, err := tc.LocalAgent().GetCoreKeyRing()
if err != nil {
return nil, trace.Wrap(err)
}

msiSecret, err := getMSISecret()
if err != nil {
return nil, err
Expand All @@ -91,7 +84,6 @@ func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azu
return &azureApp{
localProxyApp: newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify),
cf: cf,
signer: keyRing.TLSPrivateKey,
msiSecret: msiSecret,
}, nil
}
Expand Down Expand Up @@ -133,7 +125,6 @@ func getMSISecret() (string, error) {
// These calls are served entirely locally, which helps the overall performance experienced by the user.
func (a *azureApp) StartLocalProxies(ctx context.Context) error {
azureMiddleware := &alpnproxy.AzureMSIMiddleware{
Key: a.signer,
Secret: a.msiSecret,
// we could, in principle, get the actual TenantID either from live data or from static configuration,
// but at this moment there is no clear advantage over simply issuing a new random identifier.
Expand All @@ -143,7 +134,11 @@ func (a *azureApp) StartLocalProxies(ctx context.Context) error {
}

// HTTPS proxy mode
err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAzureRequests, alpnproxy.WithHTTPMiddleware(azureMiddleware))
err := a.StartLocalProxyWithForwarder(ctx,
alpnproxy.MatchAzureRequests,
alpnproxy.WithHTTPMiddleware(azureMiddleware),
alpnproxy.WithOnSetCert(azureMiddleware.OnSetCert),
)
return trace.Wrap(err)
}

Expand Down

0 comments on commit 9ee501b

Please sign in to comment.