Skip to content

Commit

Permalink
Adapt unit tests to use MockDBClient
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Barnes committed Jan 15, 2025
1 parent e6bcd9d commit bad8c9b
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 93 deletions.
66 changes: 38 additions & 28 deletions backend/operations_scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ package main

import (
"context"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
"testing"

cmv1 "github.com/openshift-online/ocm-sdk-go/clustersmgmt/v1"
"go.uber.org/mock/gomock"

"github.com/Azure/ARO-HCP/internal/api/arm"
"github.com/Azure/ARO-HCP/internal/database"
"github.com/Azure/ARO-HCP/internal/mocks"
"github.com/Azure/ARO-HCP/internal/ocm"
)

Expand Down Expand Up @@ -60,6 +61,8 @@ func TestDeleteOperationCompleted(t *testing.T) {
var request *http.Request

ctx := context.Background()
ctrl := gomock.NewController(t)
mockDBClient := mocks.NewMockDBClient(ctrl)

resourceID, err := arm.ParseResourceID("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testGroup/providers/Microsoft.RedHatOpenShift/hcpOpenShiftClusters/testCluster")
if err != nil {
Expand All @@ -74,20 +77,26 @@ func TestDeleteOperationCompleted(t *testing.T) {
defer server.Close()

scanner := &OperationsScanner{
dbClient: database.NewCache(),
dbClient: mockDBClient,
notificationClient: server.Client(),
}

operationDoc := database.NewOperationDocument(database.OperationRequestDelete, resourceID, internalID)
operationDoc.NotificationURI = server.URL
operationDoc.Status = tt.operationStatus

_ = scanner.dbClient.CreateOperationDoc(ctx, operationDoc)
var resourceDocDeleted bool

if tt.resourceDocPresent {
resourceDoc := database.NewResourceDocument(resourceID)
_ = scanner.dbClient.CreateResourceDoc(ctx, resourceDoc)
}
mockDBClient.EXPECT().
DeleteResourceDoc(gomock.Any(), resourceID).
Do(func(ctx context.Context, resourceID *arm.ResourceID) {
resourceDocDeleted = tt.resourceDocPresent
})
mockDBClient.EXPECT().
UpdateOperationDoc(gomock.Any(), operationDoc.ID, gomock.Any()).
DoAndReturn(func(ctx context.Context, operationID string, callback func(*database.OperationDocument) bool) (bool, error) {
return callback(operationDoc), nil
})

err = scanner.deleteOperationCompleted(ctx, slog.Default(), operationDoc)

Expand All @@ -103,18 +112,11 @@ func TestDeleteOperationCompleted(t *testing.T) {
t.Errorf("Got unexpected error: %v", err)
}

if err == nil && tt.resourceDocPresent {
_, getErr := scanner.dbClient.GetResourceDoc(ctx, resourceID)
if !errors.Is(getErr, database.ErrNotFound) {
t.Error("Expected resource document to be deleted")
}
if err == nil && tt.resourceDocPresent && !resourceDocDeleted {
t.Error("Expected resource document to be deleted")
}

if err == nil && tt.expectAsyncNotification {
operationDoc, getErr := scanner.dbClient.GetOperationDoc(ctx, operationDoc.ID)
if getErr != nil {
t.Fatal(getErr)
}
if operationDoc.Status != arm.ProvisioningStateSucceeded {
t.Errorf("Expected updated operation status to be %s but got %s",
arm.ProvisioningStateSucceeded,
Expand Down Expand Up @@ -207,6 +209,8 @@ func TestUpdateOperationStatus(t *testing.T) {
var request *http.Request

ctx := context.Background()
ctrl := gomock.NewController(t)
mockDBClient := mocks.NewMockDBClient(ctrl)

resourceID, err := arm.ParseResourceID("/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/testGroup/providers/Microsoft.RedHatOpenShift/hcpOpenShiftClusters/testCluster")
if err != nil {
Expand All @@ -221,27 +225,41 @@ func TestUpdateOperationStatus(t *testing.T) {
defer server.Close()

scanner := &OperationsScanner{
dbClient: database.NewCache(),
dbClient: mockDBClient,
notificationClient: server.Client(),
}

operationDoc := database.NewOperationDocument(database.OperationRequestCreate, resourceID, internalID)
operationDoc.NotificationURI = server.URL
operationDoc.Status = tt.currentOperationStatus

_ = scanner.dbClient.CreateOperationDoc(ctx, operationDoc)
var resourceDoc *database.ResourceDocument

if tt.resourceDocPresent {
resourceDoc := database.NewResourceDocument(resourceID)
resourceDoc = database.NewResourceDocument(resourceID)
if tt.resourceMatchOperationID {
resourceDoc.ActiveOperationID = operationDoc.ID
} else {
resourceDoc.ActiveOperationID = "another operation"
}
resourceDoc.ProvisioningState = tt.resourceProvisioningState
_ = scanner.dbClient.CreateResourceDoc(ctx, resourceDoc)
}

mockDBClient.EXPECT().
UpdateOperationDoc(gomock.Any(), operationDoc.ID, gomock.Any()).
DoAndReturn(func(ctx context.Context, operationID string, callback func(*database.OperationDocument) bool) (bool, error) {
return callback(operationDoc), nil
})
mockDBClient.EXPECT().
UpdateResourceDoc(gomock.Any(), resourceID, gomock.Any()).
DoAndReturn(func(ctx context.Context, resourceID *arm.ResourceID, callback func(*database.ResourceDocument) bool) (bool, error) {
if resourceDoc != nil {
return callback(resourceDoc), nil
} else {
return false, database.ErrNotFound
}
})

err = scanner.updateOperationStatus(ctx, slog.Default(), operationDoc, tt.updatedOperationStatus, nil)

if request == nil && tt.expectAsyncNotification {
Expand All @@ -257,10 +275,6 @@ func TestUpdateOperationStatus(t *testing.T) {
}

if err == nil && tt.expectAsyncNotification {
operationDoc, getErr := scanner.dbClient.GetOperationDoc(ctx, operationDoc.ID)
if getErr != nil {
t.Fatal(getErr)
}
if operationDoc.Status != tt.updatedOperationStatus {
t.Errorf("Expected updated operation status to be %s but got %s",
tt.updatedOperationStatus,
Expand All @@ -269,10 +283,6 @@ func TestUpdateOperationStatus(t *testing.T) {
}

if err == nil && tt.resourceDocPresent {
resourceDoc, getErr := scanner.dbClient.GetResourceDoc(ctx, resourceID)
if getErr != nil {
t.Fatal(getErr)
}
if resourceDoc.ActiveOperationID == "" && !tt.expectResourceOperationIDCleared {
t.Error("Resource's active operation ID is unexpectedly empty")
} else if resourceDoc.ActiveOperationID != "" && tt.expectResourceOperationIDCleared {
Expand Down
72 changes: 59 additions & 13 deletions frontend/pkg/frontend/frontend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
"time"

"github.com/prometheus/client_golang/prometheus"
"go.uber.org/mock/gomock"

"github.com/Azure/ARO-HCP/internal/api"
"github.com/Azure/ARO-HCP/internal/api/arm"
"github.com/Azure/ARO-HCP/internal/database"
"github.com/Azure/ARO-HCP/internal/mocks"
)

var testLogger = slog.New(slog.NewTextHandler(io.Discard, nil))
Expand All @@ -41,8 +43,11 @@ func TestReadiness(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockDBClient := mocks.NewMockDBClient(ctrl)

f := &Frontend{
dbClient: database.NewCache(),
dbClient: mockDBClient,
metrics: NewPrometheusEmitter(prometheus.NewRegistry()),
}
f.ready.Store(test.ready)
Expand All @@ -51,6 +56,9 @@ func TestReadiness(t *testing.T) {
return ContextWithLogger(context.Background(), testLogger)
}

// Call expected but is irrelevant to the test.
mockDBClient.EXPECT().DBConnectionTest(gomock.Any())

rs, err := ts.Client().Get(ts.URL + "/healthz")
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -92,17 +100,25 @@ func TestSubscriptionsGET(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockDBClient := mocks.NewMockDBClient(ctrl)

f := &Frontend{
dbClient: database.NewCache(),
dbClient: mockDBClient,
metrics: NewPrometheusEmitter(prometheus.NewRegistry()),
}

if test.subDoc != nil {
err := f.dbClient.CreateSubscriptionDoc(context.TODO(), test.subDoc)
if err != nil {
t.Fatal(err)
}
}
// ArmSubscriptionGet and MetricsMiddleware
mockDBClient.EXPECT().
GetSubscriptionDoc(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, string) (*database.SubscriptionDocument, error) {
if test.subDoc != nil {
return test.subDoc, nil
} else {
return nil, database.ErrNotFound
}
}).
Times(2)

ts := httptest.NewServer(f.routes())
ts.Config.BaseContext = func(net.Listener) context.Context {
Expand Down Expand Up @@ -209,8 +225,11 @@ func TestSubscriptionsPUT(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockDBClient := mocks.NewMockDBClient(ctrl)

f := &Frontend{
dbClient: database.NewCache(),
dbClient: mockDBClient,
metrics: NewPrometheusEmitter(prometheus.NewRegistry()),
}

Expand All @@ -219,12 +238,39 @@ func TestSubscriptionsPUT(t *testing.T) {
t.Fatal(err)
}

if test.subDoc != nil {
err := f.dbClient.CreateSubscriptionDoc(context.TODO(), test.subDoc)
if err != nil {
t.Fatal(err)
// MiddlewareLockSubscription
// (except when MiddlewareValidateStatic fails)
mockDBClient.EXPECT().
GetLockClient().
MaxTimes(1)
if test.expectedStatusCode != http.StatusBadRequest {
// ArmSubscriptionPut
mockDBClient.EXPECT().
GetSubscriptionDoc(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, string) (*database.SubscriptionDocument, error) {
if test.subDoc != nil {
return test.subDoc, nil
} else {
return nil, database.ErrNotFound
}
})
// ArmSubscriptionPut
if test.subDoc == nil {
mockDBClient.EXPECT().
CreateSubscriptionDoc(gomock.Any(), gomock.Any())
} else {
mockDBClient.EXPECT().
UpdateSubscriptionDoc(gomock.Any(), gomock.Any(), gomock.Any())
}
}
// MiddlewareMetrics
// (except when MiddlewareValidateStatic fails)
mockDBClient.EXPECT().
GetSubscriptionDoc(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, subscriptionID string) (*database.SubscriptionDocument, error) {
return database.NewSubscriptionDocument(subscriptionID, test.subscription), nil
}).
MaxTimes(1)

ts := httptest.NewServer(f.routes())
ts.Config.BaseContext = func(net.Listener) context.Context {
Expand Down
37 changes: 25 additions & 12 deletions frontend/pkg/frontend/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ import (
"net/http"
"testing"

"go.uber.org/mock/gomock"

"github.com/Azure/ARO-HCP/internal/api/arm"
"github.com/Azure/ARO-HCP/internal/database"
"github.com/Azure/ARO-HCP/internal/mocks"
)

func TestCheckForProvisioningStateConflict(t *testing.T) {
Expand Down Expand Up @@ -149,21 +152,25 @@ func TestCheckForProvisioningStateConflict(t *testing.T) {
name = fmt.Sprintf("%s (directState=%s)", tt.name, directState)
t.Run(name, func(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockDBClient := mocks.NewMockDBClient(ctrl)

frontend := &Frontend{
dbClient: database.NewCache(),
dbClient: mockDBClient,
}

doc := database.NewResourceDocument(resourceID)
doc.ProvisioningState = directState

parentResourceID := resourceID.GetParent()
if parentResourceID.ResourceType.Namespace == resourceID.ResourceType.Namespace {
parentDoc := database.NewResourceDocument(parentResourceID)
// Hold the provisioning state to something benign.
parentDoc.ProvisioningState = arm.ProvisioningStateSucceeded
_ = frontend.dbClient.CreateResourceDoc(ctx, parentDoc)
}
mockDBClient.EXPECT().
GetResourceDoc(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, resourceID *arm.ResourceID) (*database.ResourceDocument, error) {
resourceDoc := database.NewResourceDocument(resourceID)
// Hold the provisioning state to something benign.
resourceDoc.ProvisioningState = arm.ProvisioningStateSucceeded
return resourceDoc, nil
}).
MaxTimes(1)

cloudError := frontend.CheckForProvisioningStateConflict(ctx, tt.operationRequest, doc)

Expand All @@ -183,9 +190,11 @@ func TestCheckForProvisioningStateConflict(t *testing.T) {
name = fmt.Sprintf("%s (parentState=%s)", tt.name, parentState)
t.Run(name, func(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockDBClient := mocks.NewMockDBClient(ctrl)

frontend := &Frontend{
dbClient: database.NewCache(),
dbClient: mockDBClient,
}

doc := database.NewResourceDocument(resourceID)
Expand All @@ -194,9 +203,13 @@ func TestCheckForProvisioningStateConflict(t *testing.T) {

parentResourceID := resourceID.GetParent()
if parentResourceID.ResourceType.Namespace == resourceID.ResourceType.Namespace {
parentDoc := database.NewResourceDocument(parentResourceID)
parentDoc.ProvisioningState = parentState
_ = frontend.dbClient.CreateResourceDoc(ctx, parentDoc)
mockDBClient.EXPECT().
GetResourceDoc(gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, resourceID *arm.ResourceID) (*database.ResourceDocument, error) {
resourceDoc := database.NewResourceDocument(resourceID)
resourceDoc.ProvisioningState = parentState
return resourceDoc, nil
})
} else {
t.Fatalf("Parent resource type namespace (%s) differs from child namespace (%s)",
parentResourceID.ResourceType.Namespace,
Expand Down
Loading

0 comments on commit bad8c9b

Please sign in to comment.