diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index 4c54114a96522..6a8bd417bba8e 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -388,7 +388,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { t.Run("Azure method", func(t *testing.T) { subID := uuid.NewString() resourceGroup := "rg" - rsID := resourceID(subID, resourceGroup, "test-vm") + rsID := vmResourceID(subID, resourceGroup, "test-vm") vmID := "vmID" accessToken, err := makeToken(rsID, a.clock.Now()) @@ -408,13 +408,20 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { require.NoError(t, err) require.NoError(t, a.UpsertToken(ctx, azureToken)) - vmClient := &mockAzureVMClient{vm: &azure.VirtualMachine{ - ID: rsID, - Name: "test-vm", - Subscription: subID, - ResourceGroup: resourceGroup, - VMID: vmID, - }} + vmClient := &mockAzureVMClient{ + vms: map[string]*azure.VirtualMachine{ + rsID: { + ID: rsID, + Name: "test-vm", + Subscription: subID, + ResourceGroup: resourceGroup, + VMID: vmID, + }, + }, + } + getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ + subID: vmClient, + }) tlsConfig, err := fixtures.LocalTLSConfig() require.NoError(t, err) @@ -456,7 +463,7 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { AccessToken: accessToken, } return req, nil - }, withCerts([]*x509.Certificate{tlsConfig.Certificate}), withVerifyFunc(mockVerifyToken(nil)), withVMClient(vmClient)) + }, withCerts([]*x509.Certificate{tlsConfig.Certificate}), withVerifyFunc(mockVerifyToken(nil)), withVMClientGetter(getVMClient)) require.NoError(t, err) checkCertLoginIP(t, certs.TLS, remoteAddr) }) diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index e044d4e810a69..4ce6311a6970f 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -83,11 +83,13 @@ type accessTokenClaims struct { type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error) +type vmClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) + type azureRegisterConfig struct { clock clockwork.Clock certificateAuthorities []*x509.Certificate verify azureVerifyTokenFunc - vmClient azure.VirtualMachinesClient + getVMClient vmClientGetter } func azureVerifyFuncFromOIDCVerifier(cfg *oidc.Config) azureVerifyTokenFunc { @@ -140,6 +142,12 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error { } cfg.certificateAuthorities = certs } + if cfg.getVMClient == nil { + cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) { + client, err := azure.NewVirtualMachinesClient(subscriptionID, token, nil) + return client, trace.Wrap(err) + } + } return nil } @@ -148,42 +156,42 @@ type azureRegisterOption func(cfg *azureRegisterConfig) // parseAndVeryAttestedData verifies that an attested data document was signed // by Azure. If verification is successful, it returns the ID of the VM that // produced the document. -func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (string, error) { +func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (subscriptionID, vmID string, err error) { var signedAD signedAttestedData if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } if signedAD.Encoding != "pkcs7" { - return "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding) + return "", "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding) } sigPEM := "-----BEGIN PKCS7-----\n" + signedAD.Signature + "\n-----END PKCS7-----" sigBER, _ := pem.Decode([]byte(sigPEM)) if sigBER == nil { - return "", trace.AccessDenied("unable to decode attested data document") + return "", "", trace.AccessDenied("unable to decode attested data document") } p7, err := pkcs7.Parse(sigBER.Bytes) if err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } var ad attestedData if err := utils.FastUnmarshal(p7.Content, &ad); err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } if ad.Nonce != challenge { - return "", trace.AccessDenied("challenge is missing or does not match") + return "", "", trace.AccessDenied("challenge is missing or does not match") } if len(p7.Certificates) == 0 { - return "", trace.AccessDenied("no certificates for signature") + return "", "", trace.AccessDenied("no certificates for signature") } fixAzureSigningAlgorithm(p7) // Azure only sends the leaf cert, so we have to fetch the intermediate. intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0]) if err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } if intermediate != nil { p7.Certificates = append(p7.Certificates, intermediate) @@ -195,15 +203,15 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s } if err := p7.VerifyWithChain(pool); err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } - return ad.ID, nil + return ad.SubscriptionID, ad.ID, nil } // verifyVMIdentity verifies that the provided access token came from the // correct Azure VM. -func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken, vmID string, requestStart time.Time) (*azure.VirtualMachine, error) { +func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken, subscriptionID, vmID string, requestStart time.Time) (*azure.VirtualMachine, error) { tokenClaims, err := cfg.verify(ctx, accessToken) if err != nil { return nil, trace.Wrap(err) @@ -231,24 +239,15 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken return nil, trace.Wrap(err) } - rsID, err := arm.ParseResourceID(tokenClaims.ResourceID) + tokenCredential := azure.NewStaticCredential(azcore.AccessToken{ + Token: accessToken, + ExpiresOn: tokenClaims.Expiry.Time(), + }) + vmClient, err := cfg.getVMClient(subscriptionID, tokenCredential) if err != nil { return nil, trace.Wrap(err) } - vmClient := cfg.vmClient - if vmClient == nil { - tokenCredential := azure.NewStaticCredential(azcore.AccessToken{ - Token: accessToken, - ExpiresOn: tokenClaims.Expiry.Time(), - }) - var err error - vmClient, err = azure.NewVirtualMachinesClient(rsID.SubscriptionID, tokenCredential, nil) - if err != nil { - return nil, trace.Wrap(err) - } - } - resourceID, err := arm.ParseResourceID(tokenClaims.ResourceID) if err != nil { return nil, trace.Wrap(err) @@ -324,12 +323,12 @@ func (a *Server) checkAzureRequest(ctx context.Context, challenge string, req *p return trace.AccessDenied("this token does not support the Azure join method") } - vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities) + subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities) if err != nil { return trace.Wrap(err) } - vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, vmID, requestStart) + vm, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart) if err != nil { return trace.Wrap(err) } diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index 5fe5d487fcf3a..faa9bb6f0cb95 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -54,23 +54,41 @@ func withVerifyFunc(verify azureVerifyTokenFunc) azureRegisterOption { } } -func withVMClient(vmClient azure.VirtualMachinesClient) azureRegisterOption { +func withVMClientGetter(getVMClient vmClientGetter) azureRegisterOption { return func(cfg *azureRegisterConfig) { - cfg.vmClient = vmClient + cfg.getVMClient = getVMClient } } type mockAzureVMClient struct { azure.VirtualMachinesClient - vm *azure.VirtualMachine + vms map[string]*azure.VirtualMachine } -func (m *mockAzureVMClient) Get(_ context.Context, _ string) (*azure.VirtualMachine, error) { - return m.vm, nil +func (m *mockAzureVMClient) Get(_ context.Context, resourceID string) (*azure.VirtualMachine, error) { + vm, ok := m.vms[resourceID] + if !ok { + return nil, trace.NotFound("no vm with resource id %q", resourceID) + } + return vm, nil +} + +func (m *mockAzureVMClient) GetByVMID(_ context.Context, resourceGroup, vmID string) (*azure.VirtualMachine, error) { + for _, vm := range m.vms { + if vm.VMID == vmID && (resourceGroup == types.Wildcard || vm.ResourceGroup == resourceGroup) { + return vm, nil + } + } + return nil, trace.NotFound("no vm in groups %q with id %q", resourceGroup, vmID) } -func (m *mockAzureVMClient) GetByVMID(_ context.Context, _, _ string) (*azure.VirtualMachine, error) { - return m.vm, nil +func makeVMClientGetter(clients map[string]*mockAzureVMClient) vmClientGetter { + return func(subscriptionID string, _ *azure.StaticCredential) (azure.VirtualMachinesClient, error) { + if client, ok := clients[subscriptionID]; ok { + return client, nil + } + return nil, trace.NotFound("no client for subscription %q", subscriptionID) + } } type azureChallengeResponseConfig struct { @@ -85,10 +103,14 @@ func withChallengeAzure(challenge string) azureChallengeResponseOption { } } -func resourceID(subscription, resourceGroup, name string) string { +func vmResourceID(subscription, resourceGroup, name string) string { + return resourceID("virtualMachines", subscription, resourceGroup, name) +} + +func resourceID(resourceType, subscription, resourceGroup, name string) string { return fmt.Sprintf( - "/subscriptions/%v/resourcegroups/%v/providers/Microsoft.Compute/virtualMachines/%v", - subscription, resourceGroup, name, + "/subscriptions/%v/resourcegroups/%v/providers/Microsoft.Compute/%v/%v", + subscription, resourceGroup, resourceType, name, ) } @@ -161,43 +183,47 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey) require.NoError(t, err) - isAccessDenied := func(t require.TestingT, err error, _ ...interface{}) { + isAccessDenied := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err) } - isBadParameter := func(t require.TestingT, err error, _ ...interface{}) { + isBadParameter := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsBadParameter(err), "expected Bad Parameter error, actual error: %v", err) } + isNotFound := func(t require.TestingT, err error, _ ...any) { + require.True(t, trace.IsNotFound(err), "expected Not Found error, actual error: %v", err) + } - subID := uuid.NewString() + defaultSubscription := uuid.NewString() + defaultResourceGroup := "my-resource-group" + defaultName := "test-vm" + defaultVMID := "my-vm-id" + defaultResourceID := vmResourceID(defaultSubscription, defaultResourceGroup, defaultName) tests := []struct { name string - subscription string - resourceGroup string - vmID string - tokenName string + tokenResourceID string + tokenSubscription string + tokenVMID string requestTokenName string tokenSpec types.ProvisionTokenSpecV2 challengeResponseOptions []azureChallengeResponseOption challengeResponseErr error certs []*x509.Certificate verify azureVerifyTokenFunc - vmResult *azure.VirtualMachine assertError require.ErrorAssertionFunc }{ { - name: "basic passing case", - tokenName: "test-token", - requestTokenName: "test-token", - subscription: subID, - resourceGroup: "RG", + name: "basic passing case", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, - ResourceGroups: []string{"rg"}, + Subscription: defaultSubscription, + ResourceGroups: []string{defaultResourceGroup}, }, }, }, @@ -208,18 +234,17 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: require.NoError, }, { - name: "resource group is case insensitive", - tokenName: "test-token", - requestTokenName: "test-token", - subscription: subID, - resourceGroup: "my-RESOURCE-GROUP", + name: "resource group is case insensitive", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, - ResourceGroups: []string{"MY-resource-group"}, + Subscription: defaultSubscription, + ResourceGroups: []string{"MY-resource-GROUP"}, }, }, }, @@ -230,17 +255,16 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: require.NoError, }, { - name: "wrong token", - tokenName: "test-token", - requestTokenName: "wrong-token", - subscription: subID, - resourceGroup: "RG", + name: "wrong token", + requestTokenName: "wrong-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, + Subscription: defaultSubscription, }, }, }, @@ -251,17 +275,16 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isAccessDenied, }, { - name: "challenge response error", - tokenName: "test-token", - requestTokenName: "test-token", - subscription: subID, - resourceGroup: "RG", + name: "challenge response error", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, + Subscription: defaultSubscription, }, }, }, @@ -273,17 +296,16 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isBadParameter, }, { - name: "wrong subscription", - tokenName: "test-token", - requestTokenName: "test-token", - subscription: "some-junk", - resourceGroup: "RG", + name: "wrong subscription", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, + Subscription: "alternate-subscription-id", }, }, }, @@ -294,18 +316,17 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isAccessDenied, }, { - name: "wrong resource group", - tokenName: "test-token", - requestTokenName: "test-token", - subscription: subID, - resourceGroup: "WRONG-RG", + name: "wrong resource group", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, - ResourceGroups: []string{"rg"}, + Subscription: defaultSubscription, + ResourceGroups: []string{"alternate-resource-group"}, }, }, }, @@ -316,17 +337,16 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isAccessDenied, }, { - name: "wrong challenge", - tokenName: "test-token", - requestTokenName: "test-token", - subscription: subID, - resourceGroup: "RG", + name: "wrong challenge", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, + Subscription: defaultSubscription, }, }, }, @@ -340,17 +360,16 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: isAccessDenied, }, { - name: "invalid signature", - tokenName: "test-token", - requestTokenName: "test-token", - subscription: subID, - resourceGroup: "RG", + name: "invalid signature", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, + Subscription: defaultSubscription, }, }, }, @@ -361,38 +380,94 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { assertError: require.Error, }, { - name: "attested data and access token from different VMs", - tokenName: "test-token", - requestTokenName: "test-token", - subscription: subID, - resourceGroup: "RG", - vmID: "vm-id", + name: "attested data and access token from different VMs", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: "some-other-vm-id", tokenSpec: types.ProvisionTokenSpecV2{ Roles: []types.SystemRole{types.RoleNode}, Azure: &types.ProvisionTokenSpecV2Azure{ Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ { - Subscription: subID, + Subscription: defaultSubscription, }, }, }, JoinMethod: types.JoinMethodAzure, }, - vmResult: &azure.VirtualMachine{ - Subscription: subID, - ResourceGroup: "RG", - VMID: "different-id", - }, verify: mockVerifyToken(nil), certs: []*x509.Certificate{tlsConfig.Certificate}, assertError: isAccessDenied, }, + { + name: "vm not found", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, + tokenResourceID: vmResourceID(defaultSubscription, "nonexistent-group", defaultName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: defaultSubscription, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: isNotFound, + }, + { + name: "lookup vm by id", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, + tokenResourceID: resourceID("some.other.provider", defaultSubscription, defaultResourceGroup, defaultName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: defaultSubscription, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: require.NoError, + }, + { + name: "vm is in a different subscription than the token it provides", + requestTokenName: "test-token", + tokenSubscription: defaultSubscription, + tokenVMID: defaultVMID, + tokenResourceID: resourceID("some.other.provider", "some-other-subscription", defaultResourceGroup, defaultName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: defaultSubscription, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: require.NoError, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { token, err := types.NewProvisionTokenFromSpec( - tc.tokenName, + "test-token", time.Now().Add(time.Minute), tc.tokenSpec) require.NoError(t, err) @@ -401,23 +476,28 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { require.NoError(t, a.DeleteToken(ctx, token.GetName())) }) - rsID := resourceID(tc.subscription, tc.resourceGroup, "test-vm") + rsID := tc.tokenResourceID + if rsID == "" { + rsID = vmResourceID(defaultSubscription, defaultResourceGroup, defaultName) + } accessToken, err := makeToken(rsID, a.clock.Now()) require.NoError(t, err) - vmResult := tc.vmResult - if vmResult == nil { - vmResult = &azure.VirtualMachine{ - ID: rsID, - Name: "test-vm", - Subscription: tc.subscription, - ResourceGroup: tc.resourceGroup, - VMID: tc.vmID, - } + vmClient := &mockAzureVMClient{ + vms: map[string]*azure.VirtualMachine{ + defaultResourceID: { + ID: defaultResourceID, + Name: defaultName, + Subscription: defaultSubscription, + ResourceGroup: defaultResourceGroup, + VMID: defaultVMID, + }, + }, } - - vmClient := &mockAzureVMClient{vm: vmResult} + getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ + defaultSubscription: vmClient, + }) _, err = a.RegisterUsingAzureMethod(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { cfg := &azureChallengeResponseConfig{Challenge: challenge} @@ -427,8 +507,8 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { ad := attestedData{ Nonce: cfg.Challenge, - SubscriptionID: subID, - ID: tc.vmID, + SubscriptionID: tc.tokenSubscription, + ID: tc.tokenVMID, } adBytes, err := json.Marshal(&ad) require.NoError(t, err) @@ -456,7 +536,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { AccessToken: accessToken, } return req, tc.challengeResponseErr - }, withCerts(tc.certs), withVerifyFunc(tc.verify), withVMClient(vmClient)) + }, withCerts(tc.certs), withVerifyFunc(tc.verify), withVMClientGetter(getVMClient)) tc.assertError(t, err) }) }