Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add isolation level in transaction repository #824

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions internal/api/v1beta1/org.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package v1beta1
import (
"context"

"github.com/raystack/frontier/core/serviceuser"

"github.com/raystack/frontier/core/authenticate"

"go.uber.org/zap"
Expand Down Expand Up @@ -318,13 +320,15 @@ func (h Handler) ListOrganizationServiceUsers(ctx context.Context, request *fron
}
}

users, err := h.serviceUserService.ListByOrg(ctx, orgResp.ID)
usersList, err := h.serviceUserService.List(ctx, serviceuser.Filter{
OrgID: orgResp.ID,
})
if err != nil {
return nil, err
}

var usersPB []*frontierv1beta1.ServiceUser
for _, rel := range users {
for _, rel := range usersList {
u, err := transformServiceUserToPB(rel)
if err != nil {
return nil, err
Expand Down
4 changes: 3 additions & 1 deletion internal/api/v1beta1/org_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,9 @@ func TestHandler_ListOrganizationServiceUsers(t *testing.T) {
for _, u := range testUserMap {
testUserList = append(testUserList, u)
}
us.EXPECT().ListByOrg(mock.AnythingOfType("context.backgroundCtx"), testOrgID).Return([]serviceuser.ServiceUser{
us.EXPECT().List(mock.AnythingOfType("context.backgroundCtx"), serviceuser.Filter{
OrgID: testOrgID,
}).Return([]serviceuser.ServiceUser{
{
ID: "9f256f86-31a3-11ec-8d3d-0242ac130003",
Title: "Sample Service User",
Expand Down
228 changes: 125 additions & 103 deletions internal/store/postgres/billing_transactions_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"strings"
"time"

"github.com/raystack/frontier/billing/customer"

"github.com/raystack/frontier/internal/bootstrap/schema"

"github.com/jackc/pgconn"
Expand Down Expand Up @@ -81,131 +83,143 @@ func NewBillingTransactionRepository(dbc *db.Client) *BillingTransactionReposito
}
}

var (
maxRetries = 5
// Error codes from https://www.postgresql.org/docs/current/errcodes-appendix.html
serializationFailureCode = "40001"
deadlockDetectedCode = "40P01"
)

func (r BillingTransactionRepository) withRetry(ctx context.Context, fn func() error) error {
var lastErr error
for i := 0; i < maxRetries && ctx.Err() == nil; i++ {
err := fn()
if err == nil {
return nil
}

var pqErr *pgconn.PgError
if errors.As(err, &pqErr) {
// Retry on serialization failures or deadlocks
if pqErr.Code == serializationFailureCode || pqErr.Code == deadlockDetectedCode {
lastErr = err
// Exponential backoff with jitter
backoff := time.Duration(1<<uint(i)) * 100 * time.Millisecond
jitter := time.Duration(rand.Int63n(int64(backoff / 2)))
time.Sleep(backoff + jitter)
continue
}
}
return err // Return immediately for other errors
}
return fmt.Errorf("max retries exceeded: %w", lastErr)
}

func (r BillingTransactionRepository) CreateEntry(ctx context.Context, debitEntry credit.Transaction,
creditEntry credit.Transaction) ([]credit.Transaction, error) {
var customerAcc customer.Customer
txOpts := sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
ReadOnly: false,
}

var err error
var debitModel Transaction
var creditModel Transaction
var customerAcc customer.Customer

if debitEntry.CustomerID != schema.PlatformOrgID.String() {
// only fetch if it's a customer debit entry
customerAcc, err = r.customerRepo.GetByID(ctx, debitEntry.CustomerID)
if err != nil {
return nil, fmt.Errorf("failed to get customer account: %w", err)
}
}

if debitEntry.Metadata == nil {
debitEntry.Metadata = make(map[string]any)
var creditReturnedEntry, debitReturnedEntry credit.Transaction
err = r.withRetry(ctx, func() error {
return r.dbc.WithTxn(ctx, txOpts, func(tx *sqlx.Tx) error {
if debitEntry.CustomerID != schema.PlatformOrgID.String() {
// check for balance only when deducting from customer account
currentBalance, err := r.getBalanceInTx(ctx, tx, debitEntry.CustomerID)
if err != nil {
return fmt.Errorf("failed to get balance: %w", err)
}

if err := isSufficientBalance(customerAcc.CreditMin, currentBalance, debitEntry.Amount); err != nil {
return err
}
}

if err := r.createTransactionEntry(ctx, tx, debitEntry, &debitModel); err != nil {
return fmt.Errorf("failed to create debit entry: %w", err)
}
if err := r.createTransactionEntry(ctx, tx, creditEntry, &creditModel); err != nil {
return fmt.Errorf("failed to create credit entry: %w", err)
}
return nil
})
})
if err != nil {
if errors.Is(err, credit.ErrAlreadyApplied) {
return nil, credit.ErrAlreadyApplied
} else if errors.Is(err, credit.ErrInsufficientCredits) {
return nil, credit.ErrInsufficientCredits
}
return nil, fmt.Errorf("failed to create transaction entry: %w", err)
}
debitMetadata, err := json.Marshal(debitEntry.Metadata)

creditReturnedEntry, err = creditModel.transform()
if err != nil {
return nil, err
}
debitRecord := goqu.Record{
"account_id": debitEntry.CustomerID,
"description": debitEntry.Description,
"type": debitEntry.Type,
"source": debitEntry.Source,
"amount": debitEntry.Amount,
"user_id": debitEntry.UserID,
"metadata": debitMetadata,
"created_at": goqu.L("now()"),
"updated_at": goqu.L("now()"),
return nil, fmt.Errorf("failed to transform credit entry: %w", err)
}
if debitEntry.ID != "" {
debitRecord["id"] = debitEntry.ID
debitReturnedEntry, err = debitModel.transform()
if err != nil {
return nil, fmt.Errorf("failed to transform debit entry: %w", err)
}
return []credit.Transaction{debitReturnedEntry, creditReturnedEntry}, nil
}

if creditEntry.Metadata == nil {
creditEntry.Metadata = make(map[string]any)
func (r BillingTransactionRepository) createTransactionEntry(ctx context.Context, tx *sqlx.Tx, entry credit.Transaction, model *Transaction) error {
if entry.Metadata == nil {
entry.Metadata = make(map[string]any)
}
creditMetadata, err := json.Marshal(creditEntry.Metadata)
metadata, err := json.Marshal(entry.Metadata)
if err != nil {
return nil, err
}
creditRecord := goqu.Record{
"account_id": creditEntry.CustomerID,
"description": creditEntry.Description,
"type": creditEntry.Type,
"source": creditEntry.Source,
"amount": creditEntry.Amount,
"user_id": creditEntry.UserID,
"metadata": creditMetadata,
return err
}

record := goqu.Record{
"account_id": entry.CustomerID,
"description": entry.Description,
"type": entry.Type,
"source": entry.Source,
"amount": entry.Amount,
"user_id": entry.UserID,
"metadata": metadata,
"created_at": goqu.L("now()"),
"updated_at": goqu.L("now()"),
}
if creditEntry.ID != "" {
creditRecord["id"] = creditEntry.ID
if entry.ID != "" {
record["id"] = entry.ID
}

var creditReturnedEntry, debitReturnedEntry credit.Transaction
if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
// check if balance is enough if it's a customer entry
if customerAcc.ID != "" {
currentBalance, err := r.getBalanceInTx(ctx, tx, customerAcc.ID)
if err != nil {
return fmt.Errorf("failed to apply transaction: %w", err)
}
if err := isSufficientBalance(customerAcc.CreditMin, currentBalance, debitEntry.Amount); err != nil {
return err
}
}

var debitModel Transaction
var creditModel Transaction
query, params, err := dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(debitRecord).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %s", parseErr, err)
}
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&debitModel)
}); err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") { // handle unique key violations
if pqErr.ConstraintName == "billing_transactions_pkey" { // primary key violation
return credit.ErrAlreadyApplied
}
// add other specific unique key violations here if needed
}
return fmt.Errorf("%w: %s", dbErr, err)
}

query, params, err = dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(creditRecord).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %s", parseErr, err)
}
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return r.dbc.QueryRowxContext(ctx, query, params...).StructScan(&creditModel)
}); err != nil {
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") { // handle unique key violations
if pqErr.ConstraintName == "billing_transactions_pkey" { // primary key violation
return credit.ErrAlreadyApplied
}
// add other specific unique key violations here if needed
}
return fmt.Errorf("%w: %s", dbErr, err)
}

creditReturnedEntry, err = creditModel.transform()
if err != nil {
return fmt.Errorf("failed to transform credit entry: %w", err)
}
debitReturnedEntry, err = debitModel.transform()
if err != nil {
return fmt.Errorf("failed to transform debit entry: %w", err)
}
query, params, err := dialect.Insert(TABLE_BILLING_TRANSACTIONS).Rows(record).Returning(&Transaction{}).ToSQL()
if err != nil {
return fmt.Errorf("%w: %w", parseErr, err)
}

return nil
if err = r.dbc.WithTimeout(ctx, TABLE_BILLING_TRANSACTIONS, "Create", func(ctx context.Context) error {
return tx.QueryRowxContext(ctx, query, params...).StructScan(model)
}); err != nil {
if errors.Is(err, credit.ErrAlreadyApplied) {
return nil, credit.ErrAlreadyApplied
} else if errors.Is(err, credit.ErrInsufficientCredits) {
return nil, credit.ErrInsufficientCredits
var pqErr *pgconn.PgError
if errors.As(err, &pqErr) && (pqErr.Code == "23505") {
if pqErr.ConstraintName == "billing_transactions_pkey" {
return credit.ErrAlreadyApplied
}
}
return nil, fmt.Errorf("failed to create transaction entry: %w", err)
return fmt.Errorf("%w: %w", dbErr, err)
}

return []credit.Transaction{debitReturnedEntry, creditReturnedEntry}, nil
return nil
}

// isSufficientBalance checks if the customer has enough balance to perform the transaction.
Expand Down Expand Up @@ -328,6 +342,7 @@ func (r BillingTransactionRepository) getDebitBalance(ctx context.Context, tx *s
"account_id": accountID,
"type": credit.DebitType,
})

query, params, err := stmt.ToSQL()
if err != nil {
return nil, fmt.Errorf("%w: %s", parseErr, err)
Expand All @@ -347,6 +362,7 @@ func (r BillingTransactionRepository) getCreditBalance(ctx context.Context, tx *
"account_id": accountID,
"type": credit.CreditType,
})

query, params, err := stmt.ToSQL()
if err != nil {
return nil, fmt.Errorf("%w: %s", parseErr, err)
Expand Down Expand Up @@ -388,11 +404,17 @@ func (r BillingTransactionRepository) getBalanceInTx(ctx context.Context, tx *sq
// in transaction table till now.
func (r BillingTransactionRepository) GetBalance(ctx context.Context, accountID string) (int64, error) {
var amount int64
if err := r.dbc.WithTxn(ctx, sql.TxOptions{}, func(tx *sqlx.Tx) error {
var err error
amount, err = r.getBalanceInTx(ctx, tx, accountID)
return err
}); err != nil {
err := r.withRetry(ctx, func() error {
return r.dbc.WithTxn(ctx, sql.TxOptions{
Isolation: sql.LevelRepeatableRead,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if I understand correctly, we check balance by always looping over the table and calculating difference. We insert new rows for debit/credit. How will sql.LevelRepeatableRead prevent race condition here? Phantom reads and Write Skew are allowed in Repeatable Read as far I can see. It won't prevent new rows being written and hence change of balance mid transaction.
Can you test if this isolation is working fine? Maybe write a Python or Go code and hit the API concurrently 20-30 times for the same org. If let's say starting balance is 105, and we deduct 10 credit each call then only 10 calls should succeed and the balance should be 5 at the end.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, Serializable would be best and safest(with some added latency), that's what I proposed in the last PR. Although I don't see the issue happening in the test https://github.com/raystack/frontier/pull/824/files#diff-aaeca6450e0dcd11871799ac2b14da4225d7a76f94c04a9d4636e44ae780773eR1115

ReadOnly: true,
}, func(tx *sqlx.Tx) error {
var err error
amount, err = r.getBalanceInTx(ctx, tx, accountID)
return err
})
})
if err != nil {
return 0, fmt.Errorf("failed to get balance: %w", err)
}
return amount, nil
Expand Down
Loading
Loading