Skip to content

Commit

Permalink
Properly propagate private key for azure app (#48550)
Browse files Browse the repository at this point in the history
* Properly propagate new private key for azure app

* minor refactor

* safety check
  • Loading branch information
greedy52 authored Nov 12, 2024
1 parent 10d8666 commit 26cb848
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 18 deletions.
33 changes: 26 additions & 7 deletions lib/srv/alpnproxy/azure_msi_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"sync"
"time"

"github.com/gravitational/trace"
Expand All @@ -45,15 +46,16 @@ type AzureMSIMiddleware struct {
// ClientID to be returned in a claim.
ClientID string

// Key used to sign JWT
Key crypto.Signer

// Clock is used to override time in tests.
Clock clockwork.Clock
// Log is the Logger.
Log logrus.FieldLogger
// Secret to be provided by the client.
Secret string

// privateKey used to sign JWT
privateKey crypto.Signer
privateKeyMu sync.RWMutex
}

var _ LocalProxyHTTPMiddleware = &AzureMSIMiddleware{}
Expand All @@ -66,9 +68,6 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
m.Log = logrus.WithField(teleport.ComponentKey, "azure_msi")
}

if m.Key == nil {
return trace.BadParameter("missing Key")
}
if m.Secret == "" {
return trace.BadParameter("missing Secret")
}
Expand Down Expand Up @@ -96,6 +95,22 @@ func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Req
return false
}

// SetPrivateKey updates the private key.
func (m *AzureMSIMiddleware) SetPrivateKey(privateKey crypto.Signer) {
m.privateKeyMu.Lock()
defer m.privateKeyMu.Unlock()
m.privateKey = privateKey
}
func (m *AzureMSIMiddleware) getPrivateKey() (crypto.Signer, error) {
m.privateKeyMu.RLock()
defer m.privateKeyMu.RUnlock()
if m.privateKey == nil {
// Use a plain error to return status code 500.
return nil, trace.Errorf("missing private key set in AzureMSIMiddleware")
}
return m.privateKey, nil
}

func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Request) error {
// request validation
if req.URL.Path != ("/" + m.Secret) {
Expand Down Expand Up @@ -173,10 +188,14 @@ func (m *AzureMSIMiddleware) fetchMSILoginResp(resource string) ([]byte, error)
}

func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) {
privateKey, err := m.getPrivateKey()
if err != nil {
return "", trace.Wrap(err)
}
// Create a new key that can sign and verify tokens.
key, err := jwt.New(&jwt.Config{
Clock: m.Clock,
PrivateKey: m.Key,
PrivateKey: privateKey,
ClusterName: types.TeleportAzureMSIEndpoint, // todo get cluster name
})
if err != nil {
Expand Down
26 changes: 24 additions & 2 deletions lib/srv/alpnproxy/azure_msi_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
require.NoError(t, err)
return privateKey
}
privateKey := newPrivateKey()
m := &AzureMSIMiddleware{
Identity: "azureTestIdentity",
TenantID: "cafecafe-cafe-4aaa-cafe-cafecafecafe",
ClientID: "decaffff-cafe-4aaa-cafe-cafecafecafe",
Log: logrus.WithField(teleport.ComponentKey, "msi"),
Clock: clockwork.NewFakeClockAt(time.Date(2022, 1, 1, 9, 0, 0, 0, time.UTC)),
Key: newPrivateKey(),
Secret: "my-secret",
}
require.NoError(t, m.CheckAndSetDefaults())
Expand All @@ -68,6 +68,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name string
url string
headers map[string]string
privateKey crypto.Signer
expectedHandle bool
expectedCode int
expectedBody string
Expand All @@ -76,12 +77,14 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
{
name: "ignore non-msi requests",
url: "https://graph.windows.net/foo/bar/baz",
privateKey: privateKey,
expectedHandle: false,
},
{
name: "invalid request, wrong secret",
url: "https://azure-msi.teleport.dev/bad-secret",
headers: nil,
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"invalid secret\"\n }\n}",
Expand All @@ -90,6 +93,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing secret",
url: "https://azure-msi.teleport.dev",
headers: nil,
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"invalid secret\"\n }\n}",
Expand All @@ -98,6 +102,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing metadata",
url: "https://azure-msi.teleport.dev/my-secret",
headers: nil,
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"expected Metadata header with value 'true'\"\n }\n}",
Expand All @@ -106,6 +111,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, bad metadata value",
url: "https://azure-msi.teleport.dev/my-secret",
headers: map[string]string{"Metadata": "false"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"expected Metadata header with value 'true'\"\n }\n}",
Expand All @@ -114,6 +120,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing arguments",
url: "https://azure-msi.teleport.dev/my-secret",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"missing value for parameter 'resource'\"\n }\n}",
Expand All @@ -122,6 +129,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing resource",
url: "https://azure-msi.teleport.dev/my-secret?msi_res_id=azureTestIdentity",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"missing value for parameter 'resource'\"\n }\n}",
Expand All @@ -130,6 +138,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing identity",
url: "https://azure-msi.teleport.dev/my-secret?resource=myresource",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"unexpected value for parameter 'msi_res_id': \"\n }\n}",
Expand All @@ -138,6 +147,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, wrong identity",
url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestWrongIdentity",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"unexpected value for parameter 'msi_res_id': azureTestWrongIdentity\"\n }\n}",
Expand All @@ -146,6 +156,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "well-formatted request",
url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestIdentity",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 200,
verifyBody: func(t *testing.T, body []byte) {
Expand Down Expand Up @@ -182,7 +193,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
return key.VerifyAzureToken(token)
}

claims, err := fromJWT(req.AccessToken, m.Key)
claims, err := fromJWT(req.AccessToken, privateKey)
require.NoError(t, err)
require.Equal(t, jwt.AzureTokenClaims{
TenantID: "cafecafe-cafe-4aaa-cafe-cafecafecafe",
Expand All @@ -202,10 +213,21 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
require.Equal(t, expected.NotBefore, req.NotBefore)
},
},
{
name: "no private key set",
url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestIdentity",
headers: map[string]string{"Metadata": "true"},
privateKey: nil,
expectedHandle: true,
expectedCode: 500,
expectedBody: "{\n \"error\": {\n \"message\": \"missing private key set in AzureMSIMiddleware\"\n }\n}",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m.SetPrivateKey(tt.privateKey)

// prepare request
req, err := http.NewRequest("GET", tt.url, strings.NewReader(""))
require.NoError(t, err)
Expand Down
7 changes: 7 additions & 0 deletions lib/srv/alpnproxy/local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ type LocalProxyConfig struct {
CheckCertNeeded bool
// verifyUpstreamConnection is a callback function to verify upstream connection state.
verifyUpstreamConnection func(tls.ConnectionState) error
// onSetCert is a callback when lp.SetCert is called.
onSetCert func(tls.Certificate)
}

// LocalProxyMiddleware provides callback functions for LocalProxy.
Expand Down Expand Up @@ -484,6 +486,11 @@ func (l *LocalProxy) SetCert(cert tls.Certificate) {
l.certMu.Lock()
defer l.certMu.Unlock()
l.cfg.Cert = cert

// Callback, if any.
if l.cfg.onSetCert != nil {
l.cfg.onSetCert(cert)
}
}

// getCertForConn determines if certificates should be used when dialing
Expand Down
8 changes: 8 additions & 0 deletions lib/srv/alpnproxy/local_proxy_config_opt.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,11 @@ func mySQLVersionToProto(database types.Database) string {
// Include MySQL server version
return string(common.ProtocolMySQLWithVerPrefix) + versionBase64
}

// WithOnSetCert provides a callback when lp.SetCert is called.
func WithOnSetCert(callback func(tls.Certificate)) LocalProxyConfigOpt {
return func(config *LocalProxyConfig) error {
config.onSetCert = callback
return nil
}
}
23 changes: 14 additions & 9 deletions tool/tsh/common/app_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package common
import (
"context"
"crypto"
"crypto/tls"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -72,17 +73,11 @@ type azureApp struct {
*localProxyApp

cf *CLIConf
signer crypto.Signer
msiSecret string
}

// newAzureApp creates a new Azure app.
func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azureApp, error) {
keyRing, err := tc.LocalAgent().GetCoreKeyRing()
if err != nil {
return nil, trace.Wrap(err)
}

msiSecret, err := getMSISecret()
if err != nil {
return nil, err
Expand All @@ -91,7 +86,6 @@ func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azu
return &azureApp{
localProxyApp: newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify),
cf: cf,
signer: keyRing.TLSPrivateKey,
msiSecret: msiSecret,
}, nil
}
Expand Down Expand Up @@ -133,7 +127,6 @@ func getMSISecret() (string, error) {
// These calls are served entirely locally, which helps the overall performance experienced by the user.
func (a *azureApp) StartLocalProxies(ctx context.Context) error {
azureMiddleware := &alpnproxy.AzureMSIMiddleware{
Key: a.signer,
Secret: a.msiSecret,
// we could, in principle, get the actual TenantID either from live data or from static configuration,
// but at this moment there is no clear advantage over simply issuing a new random identifier.
Expand All @@ -143,7 +136,19 @@ func (a *azureApp) StartLocalProxies(ctx context.Context) error {
}

// HTTPS proxy mode
err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAzureRequests, alpnproxy.WithHTTPMiddleware(azureMiddleware))
err := a.StartLocalProxyWithForwarder(ctx,
alpnproxy.MatchAzureRequests,
alpnproxy.WithHTTPMiddleware(azureMiddleware),
alpnproxy.WithOnSetCert(func(cert tls.Certificate) {
// Note that the PrivateKey is most likely set by api/utils/keys.TLSCertificateForSigner.
signer, ok := cert.PrivateKey.(crypto.Signer)
if ok {
azureMiddleware.SetPrivateKey(signer)
} else {
log.Warn("Provided tls.Certificate has no valid private key.")
}
}),
)
return trace.Wrap(err)
}

Expand Down

0 comments on commit 26cb848

Please sign in to comment.