Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly propagate private key for azure app #48550

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions lib/srv/alpnproxy/azure_msi_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"sync"
"time"

"github.com/gravitational/trace"
Expand All @@ -45,15 +46,16 @@ type AzureMSIMiddleware struct {
// ClientID to be returned in a claim.
ClientID string

// Key used to sign JWT
Key crypto.Signer

// Clock is used to override time in tests.
Clock clockwork.Clock
// Log is the Logger.
Log logrus.FieldLogger
// Secret to be provided by the client.
Secret string

// privateKey used to sign JWT
privateKey crypto.Signer
privateKeyMu sync.RWMutex
}

var _ LocalProxyHTTPMiddleware = &AzureMSIMiddleware{}
Expand All @@ -66,9 +68,6 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
m.Log = logrus.WithField(teleport.ComponentKey, "azure_msi")
}

if m.Key == nil {
return trace.BadParameter("missing Key")
}
if m.Secret == "" {
return trace.BadParameter("missing Secret")
}
Expand Down Expand Up @@ -96,6 +95,22 @@ func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Req
return false
}

// SetPrivateKey updates the private key.
func (m *AzureMSIMiddleware) SetPrivateKey(privateKey crypto.Signer) {
m.privateKeyMu.Lock()
defer m.privateKeyMu.Unlock()
m.privateKey = privateKey
}
func (m *AzureMSIMiddleware) getPrivateKey() (crypto.Signer, error) {
m.privateKeyMu.RLock()
defer m.privateKeyMu.RUnlock()
if m.privateKey == nil {
// Use a plain error to return status code 500.
return nil, trace.Errorf("missing private key set in AzureMSIMiddleware")
}
return m.privateKey, nil
}

func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Request) error {
// request validation
if req.URL.Path != ("/" + m.Secret) {
Expand Down Expand Up @@ -173,10 +188,14 @@ func (m *AzureMSIMiddleware) fetchMSILoginResp(resource string) ([]byte, error)
}

func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) {
privateKey, err := m.getPrivateKey()
if err != nil {
return "", trace.Wrap(err)
}
// Create a new key that can sign and verify tokens.
key, err := jwt.New(&jwt.Config{
Clock: m.Clock,
PrivateKey: m.Key,
PrivateKey: privateKey,
ClusterName: types.TeleportAzureMSIEndpoint, // todo get cluster name
})
if err != nil {
Expand Down
26 changes: 24 additions & 2 deletions lib/srv/alpnproxy/azure_msi_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
require.NoError(t, err)
return privateKey
}
privateKey := newPrivateKey()
m := &AzureMSIMiddleware{
Identity: "azureTestIdentity",
TenantID: "cafecafe-cafe-4aaa-cafe-cafecafecafe",
ClientID: "decaffff-cafe-4aaa-cafe-cafecafecafe",
Log: logrus.WithField(teleport.ComponentKey, "msi"),
Clock: clockwork.NewFakeClockAt(time.Date(2022, 1, 1, 9, 0, 0, 0, time.UTC)),
Key: newPrivateKey(),
Secret: "my-secret",
}
require.NoError(t, m.CheckAndSetDefaults())
Expand All @@ -68,6 +68,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name string
url string
headers map[string]string
privateKey crypto.Signer
expectedHandle bool
expectedCode int
expectedBody string
Expand All @@ -76,12 +77,14 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
{
name: "ignore non-msi requests",
url: "https://graph.windows.net/foo/bar/baz",
privateKey: privateKey,
expectedHandle: false,
},
{
name: "invalid request, wrong secret",
url: "https://azure-msi.teleport.dev/bad-secret",
headers: nil,
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"invalid secret\"\n }\n}",
Expand All @@ -90,6 +93,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing secret",
url: "https://azure-msi.teleport.dev",
headers: nil,
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"invalid secret\"\n }\n}",
Expand All @@ -98,6 +102,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing metadata",
url: "https://azure-msi.teleport.dev/my-secret",
headers: nil,
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"expected Metadata header with value 'true'\"\n }\n}",
Expand All @@ -106,6 +111,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, bad metadata value",
url: "https://azure-msi.teleport.dev/my-secret",
headers: map[string]string{"Metadata": "false"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"expected Metadata header with value 'true'\"\n }\n}",
Expand All @@ -114,6 +120,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing arguments",
url: "https://azure-msi.teleport.dev/my-secret",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"missing value for parameter 'resource'\"\n }\n}",
Expand All @@ -122,6 +129,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing resource",
url: "https://azure-msi.teleport.dev/my-secret?msi_res_id=azureTestIdentity",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"missing value for parameter 'resource'\"\n }\n}",
Expand All @@ -130,6 +138,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, missing identity",
url: "https://azure-msi.teleport.dev/my-secret?resource=myresource",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"unexpected value for parameter 'msi_res_id': \"\n }\n}",
Expand All @@ -138,6 +147,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "invalid request, wrong identity",
url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestWrongIdentity",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 400,
expectedBody: "{\n \"error\": {\n \"message\": \"unexpected value for parameter 'msi_res_id': azureTestWrongIdentity\"\n }\n}",
Expand All @@ -146,6 +156,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
name: "well-formatted request",
url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestIdentity",
headers: map[string]string{"Metadata": "true"},
privateKey: privateKey,
expectedHandle: true,
expectedCode: 200,
verifyBody: func(t *testing.T, body []byte) {
Expand Down Expand Up @@ -182,7 +193,7 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
return key.VerifyAzureToken(token)
}

claims, err := fromJWT(req.AccessToken, m.Key)
claims, err := fromJWT(req.AccessToken, privateKey)
require.NoError(t, err)
require.Equal(t, jwt.AzureTokenClaims{
TenantID: "cafecafe-cafe-4aaa-cafe-cafecafecafe",
Expand All @@ -202,10 +213,21 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith
require.Equal(t, expected.NotBefore, req.NotBefore)
},
},
{
name: "no private key set",
url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestIdentity",
headers: map[string]string{"Metadata": "true"},
privateKey: nil,
expectedHandle: true,
expectedCode: 500,
expectedBody: "{\n \"error\": {\n \"message\": \"missing private key set in AzureMSIMiddleware\"\n }\n}",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m.SetPrivateKey(tt.privateKey)

// prepare request
req, err := http.NewRequest("GET", tt.url, strings.NewReader(""))
require.NoError(t, err)
Expand Down
7 changes: 7 additions & 0 deletions lib/srv/alpnproxy/local_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ type LocalProxyConfig struct {
CheckCertNeeded bool
// verifyUpstreamConnection is a callback function to verify upstream connection state.
verifyUpstreamConnection func(tls.ConnectionState) error
// onSetCert is a callback when lp.SetCert is called.
onSetCert func(tls.Certificate)
}

// LocalProxyMiddleware provides callback functions for LocalProxy.
Expand Down Expand Up @@ -484,6 +486,11 @@ func (l *LocalProxy) SetCert(cert tls.Certificate) {
l.certMu.Lock()
defer l.certMu.Unlock()
l.cfg.Cert = cert

// Callback, if any.
if l.cfg.onSetCert != nil {
l.cfg.onSetCert(cert)
}
}

// getCertForConn determines if certificates should be used when dialing
Expand Down
8 changes: 8 additions & 0 deletions lib/srv/alpnproxy/local_proxy_config_opt.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,11 @@ func mySQLVersionToProto(database types.Database) string {
// Include MySQL server version
return string(common.ProtocolMySQLWithVerPrefix) + versionBase64
}

// WithOnSetCert provides a callback when lp.SetCert is called.
func WithOnSetCert(callback func(tls.Certificate)) LocalProxyConfigOpt {
return func(config *LocalProxyConfig) error {
config.onSetCert = callback
return nil
}
}
23 changes: 14 additions & 9 deletions tool/tsh/common/app_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package common
import (
"context"
"crypto"
"crypto/tls"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -72,17 +73,11 @@ type azureApp struct {
*localProxyApp

cf *CLIConf
signer crypto.Signer
msiSecret string
}

// newAzureApp creates a new Azure app.
func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azureApp, error) {
keyRing, err := tc.LocalAgent().GetCoreKeyRing()
if err != nil {
return nil, trace.Wrap(err)
}

msiSecret, err := getMSISecret()
if err != nil {
return nil, err
Expand All @@ -91,7 +86,6 @@ func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azu
return &azureApp{
localProxyApp: newLocalProxyApp(tc, appInfo, cf.LocalProxyPort, cf.InsecureSkipVerify),
cf: cf,
signer: keyRing.TLSPrivateKey,
msiSecret: msiSecret,
}, nil
}
Expand Down Expand Up @@ -133,7 +127,6 @@ func getMSISecret() (string, error) {
// These calls are served entirely locally, which helps the overall performance experienced by the user.
func (a *azureApp) StartLocalProxies(ctx context.Context) error {
azureMiddleware := &alpnproxy.AzureMSIMiddleware{
Key: a.signer,
Secret: a.msiSecret,
// we could, in principle, get the actual TenantID either from live data or from static configuration,
// but at this moment there is no clear advantage over simply issuing a new random identifier.
Expand All @@ -143,7 +136,19 @@ func (a *azureApp) StartLocalProxies(ctx context.Context) error {
}

// HTTPS proxy mode
err := a.StartLocalProxyWithForwarder(ctx, alpnproxy.MatchAzureRequests, alpnproxy.WithHTTPMiddleware(azureMiddleware))
err := a.StartLocalProxyWithForwarder(ctx,
alpnproxy.MatchAzureRequests,
alpnproxy.WithHTTPMiddleware(azureMiddleware),
alpnproxy.WithOnSetCert(func(cert tls.Certificate) {
// Note that the PrivateKey is most likely set by api/utils/keys.TLSCertificateForSigner.
signer, ok := cert.PrivateKey.(crypto.Signer)
if ok {
azureMiddleware.SetPrivateKey(signer)
} else {
log.Warn("Provided tls.Certificate has no valid private key.")
Tener marked this conversation as resolved.
Show resolved Hide resolved
}
}),
)
return trace.Wrap(err)
}

Expand Down
Loading