Skip to content

Commit

Permalink
fix: add isolation level in transaction repository
Browse files Browse the repository at this point in the history
Currently there were chances to have race conditions while writing
to transaction repository. I have added a test to verify it doesn't
happen. Database is using repeatable read as the isolation level
to avoid overlapping transactions.

Signed-off-by: Kush Sharma <[email protected]>
  • Loading branch information
kushsharma committed Nov 29, 2024
1 parent 2968371 commit b33447d
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 142 deletions.
8 changes: 6 additions & 2 deletions internal/api/v1beta1/org.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package v1beta1
import (
"context"

"github.com/raystack/frontier/core/serviceuser"

"github.com/raystack/frontier/core/authenticate"

"go.uber.org/zap"
Expand Down Expand Up @@ -318,13 +320,15 @@ func (h Handler) ListOrganizationServiceUsers(ctx context.Context, request *fron
}
}

users, err := h.serviceUserService.ListByOrg(ctx, orgResp.ID)
usersList, err := h.serviceUserService.List(ctx, serviceuser.Filter{
OrgID: orgResp.ID,
})
if err != nil {
return nil, err
}

var usersPB []*frontierv1beta1.ServiceUser
for _, rel := range users {
for _, rel := range usersList {
u, err := transformServiceUserToPB(rel)
if err != nil {
return nil, err
Expand Down
4 changes: 3 additions & 1 deletion internal/api/v1beta1/org_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,9 @@ func TestHandler_ListOrganizationServiceUsers(t *testing.T) {
for _, u := range testUserMap {
testUserList = append(testUserList, u)
}
us.EXPECT().ListByOrg(mock.AnythingOfType("context.backgroundCtx"), testOrgID).Return([]serviceuser.ServiceUser{
us.EXPECT().List(mock.AnythingOfType("context.backgroundCtx"), serviceuser.Filter{
OrgID: testOrgID,
}).Return([]serviceuser.ServiceUser{
{
ID: "9f256f86-31a3-11ec-8d3d-0242ac130003",
Title: "Sample Service User",
Expand Down
228 changes: 125 additions & 103 deletions internal/store/postgres/billing_transactions_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"strings"
"time"

"github.com/raystack/frontier/billing/customer"

"github.com/raystack/frontier/internal/bootstrap/schema"

"github.com/jackc/pgconn"
Expand Down Expand Up @@ -81,131 +83,143 @@ func NewBillingTransactionRepository(dbc *db.Client) *BillingTransactionReposito
}
}

var (
maxRetries = 5
// Error codes from https://www.postgresql.org/docs/current/errcodes-appendix.html
serializationFailureCode = "40001"
deadlockDetectedCode = "40P01"
)

func (r BillingTransactionRepository) withRetry(ctx context.Context, fn func() error) error {
var lastErr error
for i := 0; i < maxRetries && ctx.Err() == nil; i++ {
err := fn()
if err == nil {
return nil
}

var pqErr *pgconn.PgError
if errors.As(err, &pqErr) {
// Retry on serialization failures or deadlocks
if pqErr.Code == serializationFailureCode || pqErr.Code == deadlockDetectedCode {
lastErr = err
// Exponential backoff with jitter
backoff := time.Duration(1<<uint(i)) * 100 * time.Millisecond
jitter := time.Duration(rand.Int63n(int64(backoff / 2)))
time.Sleep(backoff + jitter)
continue
}
}
return err // Return immediately for other errors
}
return fmt.Errorf("max retries exceeded: %w", lastErr)
}

func (r BillingTransactionRepository) CreateEntry(ctx context.Context, debitEntry credit.Transaction,
creditEntry credit.Transaction) ([]credit.Transaction, error) {
var customerAcc customer.Customer
txOpts := sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: false,
}

var err error
var debitModel Transaction
var creditModel Transaction
var customerAcc customer.Customer

if debitEntry.CustomerID != schema.PlatformOrgID.String() {
// only fetch if it's a customer debit entry
customerAcc, err = r.customerRepo.GetByID(ctx, debitEntry.CustomerID)
if err != nil {
return nil, fmt.Errorf("failed to get customer account: %w", err)
}
}

if debitEntry.Metadata == nil {
debitEntry.Metadata = make(map[string]any)
var creditReturnedEntry, debitReturnedEntry credit.Transaction
err = r.withRetry(ctx, func() error {
return r.dbc.WithTxn(ctx, txOpts, func(tx *sqlx.Tx) error {
if debitEntry.CustomerID != schema.PlatformOrgID.String() {
// check for balance only when deducting from customer account
currentBalance, err := r.getBalanceInTx(ctx, tx, debitEntry.CustomerID)
if err != nil {
return fmt.Errorf("failed to get balance: %w", err)
}

if err := isSufficientBalance(customerAcc.CreditMin, currentBalance, debitEntry.Amount); err != nil {
return err
}
}

if err := r.createTransactionEntry(ctx, tx, debitEntry, &debitModel); err != nil {
return fmt.Errorf("failed to create debit entry: %w", err)
}
if err := r.createTransactionEntry(ctx, tx, creditEntry, &creditModel); err != nil {
return fmt.Errorf("failed to create credit entry: %w", err)
}
return nil
})
})
if err != nil {
if errors.Is(err, credit.ErrAlreadyApplied) {
return nil, credit.ErrAlreadyApplied
} else if errors.Is(err, credit.ErrInsufficientCredits) {
return nil, credit.ErrInsufficientCredits
}
return nil, fmt.Errorf("failed to create transaction entry: %w", err)
}
debitMetadata, err := json.Marshal(debitEntry.Metadata)

creditReturnedEntry, err = creditModel.transform()
if err != nil {
return nil, err
}
debitRecord := goqu.Record{
"account_id": debitEntry.CustomerID,
"description": debitEntry.Description,
"type": debitEntry.Type,
"source": debitEntry.Source,
"amount": debitEntry.Amount,
"user_id": debitEntry.UserID,
"metadata": debitMetadata,
"created_at": goqu.L("now()"),
"updated_at": goqu.L("now()"),
return nil, fmt.Errorf("failed to transform credit entry: %w", err)
}
if debitEntry.ID != "" {
debitRecord["id"] = debitEntry.ID
debitReturnedEntry, err = debitModel.transform()
if err != nil {
return nil, fmt.Errorf("failed to transform debit entry: %w", err)
}
return []credit.Transaction{debitReturnedEntry, creditReturnedEntry}, nil
}

if creditEntry.Metadata == nil {
creditEntry.Metadata = make(map[string]any)
func (r BillingTransactionRepository) createTransactionEntry(ctx context.Context, tx *sqlx.Tx, entry credit.Transaction, model *Transaction) error {
if entry.Metadata == nil {
entry.Metadata = make(map[string]any)
}
creditMetadata, err := json.Marshal(creditEntry.Metadata)
metadata, err := json.Marshal(entry.Metadata)
if err != nil {
return nil, err
}
creditRecord := goqu.Record{
"account_id": creditEntry.CustomerID,
"description": creditEntry.Description,
"type": creditEntry.Type,
"source": creditEntry.Source,
"amount": creditEntry.Amount,
"user_id": creditEntry.UserID,
"metadata": creditMetadata,
return err
}

record := goqu.Record{
"account_id": entry.CustomerID,
"description": entry.Description,
"type": entry.Type,
"source": entry.Source,
"amount": entry.Amount,
"user_id": entry.UserID,
"metadata": metadata,
"created_at": goqu.L("now()"),
"updated_at": goqu.L("now()"),
}
if creditEntry.ID != "" {
creditRecord["id"] = creditEntry.ID
if entry.ID != "" {
record["id"] = entry.ID
}

var creditReturnedEntry, debitReturnedEntry credit.Transaction
if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
// check if balance is enough if it's a customer entry
if customerAcc.ID != "" {
currentBalance, err := r.getBalanceInTx(ctx, tx, customerAcc.ID)
if err != nil {
return fmt.Errorf("failed to apply transaction: %w", err)
}
if err := isSufficientBalance(customerAcc.CreditMin, currentBalance, debitEntry.Amount); err != nil {
return err
}
}

var debitModel Transaction
var creditModel Transaction
query, params, err := dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(debitRecord).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %s", parseErr, err)
}
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&debitModel)
}); err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") { // handle unique key violations
if pqErr.ConstraintName == "billing_transactions_pkey" { // primary key violation
return credit.ErrAlreadyApplied
}
// add other specific unique key violations here if needed
}
return fmt.Errorf("%w: %s", dbErr, err)
}

query, params, err = dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(creditRecord).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %s", parseErr, err)
}
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&creditModel)
}); err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") { // handle unique key violations
if pqErr.ConstraintName == "billing_transactions_pkey" { // primary key violation
return credit.ErrAlreadyApplied
}
// add other specific unique key violations here if needed
}
return fmt.Errorf("%w: %s", dbErr, err)
}

creditReturnedEntry, err = creditModel.transform()
if err != nil {
return fmt.Errorf("failed to transform credit entry: %w", err)
}
debitReturnedEntry, err = debitModel.transform()
if err != nil {
return fmt.Errorf("failed to transform debit entry: %w", err)
}
query, params, err := dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(record).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %w", parseErr, err)
}

return nil
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return tx.QueryRowxContext(ctx, query, params...).StructScan(model)
}); err != nil {
if errors.Is(err, credit.ErrAlreadyApplied) {
return nil, credit.ErrAlreadyApplied
} else if errors.Is(err, credit.ErrInsufficientCredits) {
return nil, credit.ErrInsufficientCredits
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") {
if pqErr.ConstraintName == "billing_transactions_pkey" {
return credit.ErrAlreadyApplied
}
}
return nil, fmt.Errorf("failed to create transaction entry: %w", err)
return fmt.Errorf("%w: %w", dbErr, err)
}

return []credit.Transaction{debitReturnedEntry, creditReturnedEntry}, nil
return nil
}

// isSufficientBalance checks if the customer has enough balance to perform the transaction.
Expand Down Expand Up @@ -328,6 +342,7 @@ func (r BillingTransactionRepository) getDebitBalance(ctx context.Context, tx *s
"account_id": accountID,
"type": credit.DebitType,
})

query, params, err := stmt.ToSQL()
if err != nil {
return nil, fmt.Errorf("%w: %s", parseErr, err)
Expand All @@ -347,6 +362,7 @@ func (r BillingTransactionRepository) getCreditBalance(ctx context.Context, tx *
"account_id": accountID,
"type": credit.CreditType,
})

query, params, err := stmt.ToSQL()
if err != nil {
return nil, fmt.Errorf("%w: %s", parseErr, err)
Expand Down Expand Up @@ -388,11 +404,17 @@ func (r BillingTransactionRepository) getBalanceInTx(ctx context.Context, tx *sq
// in transaction table till now.
func (r BillingTransactionRepository) GetBalance(ctx context.Context, accountID string) (int64, error) {
var amount int64
if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
var err error
amount, err = r.getBalanceInTx(ctx, tx, accountID)
return err
}); err != nil {
err := r.withRetry(ctx, func() error {
return r.dbc.WithTxn(ctx, sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: true,
}, func(tx *sqlx.Tx) error {
var err error
amount, err = r.getBalanceInTx(ctx, tx, accountID)
return err
})
})
if err != nil {
return 0, fmt.Errorf("failed to get balance: %w", err)
}
return amount, nil
Expand Down
Loading

0 comments on commit b33447d

Please sign in to comment.