From 199f98cee1bcf47938dd9ffb27056596f38175f6 Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Thu, 12 Oct 2023 16:16:47 +0200 Subject: [PATCH] WIP: refactor Transactional interface --- handler/oauth2/flow_authorize_code_token.go | 34 +++----- handler/oauth2/flow_refresh.go | 93 ++++++++------------- storage/transactional.go | 58 +++---------- 3 files changed, 58 insertions(+), 127 deletions(-) diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index dceb8300..dc669eed 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -152,26 +152,21 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex } } - ctx, err = storage.MaybeBeginTx(ctx, c.CoreStorage) - if err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } - defer func() { - if err != nil { - if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { - err = errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) - } + if err := storage.MaybeTransaction(ctx, c.CoreStorage, func(ctx context.Context) error { + if err := c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } - }() - - if err = c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } else if err = c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } else if refreshSignature != "" { - if err = c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { + if err := c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } + if refreshSignature != "" { + if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) + } + } + return nil + }); err != nil { + return err // error already wrapped inside tx callback } responder.SetAccessToken(access) @@ -182,11 +177,6 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex if refresh != "" { responder.SetExtra("refresh_token", refresh) } - - if err = storage.MaybeCommitTx(ctx, c.CoreStorage); err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } - return nil } diff --git a/handler/oauth2/flow_refresh.go b/handler/oauth2/flow_refresh.go index 5bf68070..acfaa35b 100644 --- a/handler/oauth2/flow_refresh.go +++ b/handler/oauth2/flow_refresh.go @@ -126,34 +126,29 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con signature := c.RefreshTokenStrategy.RefreshTokenSignature(ctx, requester.GetRequestForm().Get("refresh_token")) - ctx, err = storage.MaybeBeginTx(ctx, c.TokenRevocationStorage) - if err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } defer func() { err = c.handleRefreshTokenEndpointStorageError(ctx, err) }() - ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil) - if err != nil { - return err - } else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil { - return err - } - - if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil { - return err - } - - storeReq := requester.Sanitize([]string{}) - storeReq.SetID(ts.GetID()) - - if err = c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil { - return err - } - - if err = c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil { - return err + if err := storage.MaybeTransaction(ctx, c.TokenRevocationStorage, func(ctx context.Context) error { + ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil) + if err != nil { + return err + } + if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil { + return err + } + if err := c.TokenRevocationStorage.RevokeRefreshTokenMaybeGracePeriod(ctx, ts.GetID(), signature); err != nil { + return err + } + storeReq := requester.Sanitize([]string{}) + storeReq.SetID(ts.GetID()) + if err := c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil { + return err + } + return c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq) + }); err != nil { + return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } responder.SetAccessToken(accessToken) @@ -163,10 +158,6 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con responder.SetScopes(requester.GetGrantedScopes()) responder.SetExtra("refresh_token", refreshToken) - if err = storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil { - return err - } - return nil } @@ -179,32 +170,20 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con // legitimate client is trying to access, in case of such an access // attempt the valid refresh token and the access authorization // associated with it are both revoked. -func (c *RefreshTokenGrantHandler) handleRefreshTokenReuse(ctx context.Context, signature string, req fosite.Requester) (err error) { - ctx, err = storage.MaybeBeginTx(ctx, c.TokenRevocationStorage) - if err != nil { - return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) - } - defer func() { - err = c.handleRefreshTokenEndpointStorageError(ctx, err) - }() - - if err = c.TokenRevocationStorage.DeleteRefreshTokenSession(ctx, signature); err != nil { - return err - } else if err = c.TokenRevocationStorage.RevokeRefreshToken( - ctx, req.GetID(), - ); err != nil && !errors.Is(err, fosite.ErrNotFound) { - return err - } else if err = c.TokenRevocationStorage.RevokeAccessToken( - ctx, req.GetID(), - ); err != nil && !errors.Is(err, fosite.ErrNotFound) { - return err - } - - if err = storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil { - return err - } - - return nil +func (c *RefreshTokenGrantHandler) handleRefreshTokenReuse(ctx context.Context, signature string, req fosite.Requester) error { + err := storage.MaybeTransaction(ctx, c.TokenRevocationStorage, func(ctx context.Context) error { + if err := c.TokenRevocationStorage.DeleteRefreshTokenSession(ctx, signature); err != nil { + return err + } + if err := c.TokenRevocationStorage.RevokeRefreshToken(ctx, req.GetID()); err != nil && !errors.Is(err, fosite.ErrNotFound) { + return err + } + if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, req.GetID()); err != nil && !errors.Is(err, fosite.ErrNotFound) { + return err + } + return nil + }) + return c.handleRefreshTokenEndpointStorageError(ctx, err) } func (c *RefreshTokenGrantHandler) handleRefreshTokenEndpointStorageError(ctx context.Context, storageErr error) (err error) { @@ -212,12 +191,6 @@ func (c *RefreshTokenGrantHandler) handleRefreshTokenEndpointStorageError(ctx co return nil } - defer func() { - if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rollBackTxnErr != nil { - err = errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) - } - }() - if errors.Is(storageErr, fosite.ErrSerializationFailure) { return errorsx.WithStack(fosite.ErrInvalidRequest. WithDebugf(storageErr.Error()). diff --git a/storage/transactional.go b/storage/transactional.go index f144340f..87d57d9e 100644 --- a/storage/transactional.go +++ b/storage/transactional.go @@ -5,54 +5,22 @@ package storage import "context" -// A storage provider that has support for transactions should implement this interface to ensure atomicity for certain flows -// that require transactional semantics. Fosite will call these methods (when atomicity is required) if and only if the storage -// provider has implemented `Transactional`. It is expected that the storage provider will examine context for an existing transaction -// each time a database operation is to be performed. +// A storage provider that has support for transactions should implement this +// interface to ensure atomicity for certain flows that require transactional +// semantics. When atomicity is required, Fosite will group calls to the storage +// provider in a function and passes that to Transaction. Implementations are +// expected to execute these calls in a transactional manner. Typically, a +// handle to the transaction will be stored in the context. // -// An implementation of `BeginTX` should attempt to initiate a new transaction and store that under a unique key -// in the context that can be accessible by `Commit` and `Rollback`. The "transactional aware" context will then be -// returned for further propagation, eventually to be consumed by `Commit` or `Rollback` to finish the transaction. -// -// Implementations for `Commit` & `Rollback` should look for the transaction object inside the supplied context using the same -// key used by `BeginTX`. If these methods have been called, it is expected that a txn object should be available in the provided -// context. +// Implementations should rollback (or retry) the transaction if the callback +// returns an error. type Transactional interface { - BeginTX(ctx context.Context) (context.Context, error) - Commit(ctx context.Context) error - Rollback(ctx context.Context) error -} - -// MaybeBeginTx is a helper function that can be used to initiate a transaction if the supplied storage -// implements the `Transactional` interface. -func MaybeBeginTx(ctx context.Context, storage interface{}) (context.Context, error) { - // the type assertion checks whether the dynamic type of `storage` implements `Transactional` - txnStorage, transactional := storage.(Transactional) - if transactional { - return txnStorage.BeginTX(ctx) - } else { - return ctx, nil - } -} - -// MaybeCommitTx is a helper function that can be used to commit a transaction if the supplied storage -// implements the `Transactional` interface. -func MaybeCommitTx(ctx context.Context, storage interface{}) error { - txnStorage, transactional := storage.(Transactional) - if transactional { - return txnStorage.Commit(ctx) - } else { - return nil - } + Transaction(context.Context, func(context.Context) error) error } -// MaybeRollbackTx is a helper function that can be used to rollback a transaction if the supplied storage -// implements the `Transactional` interface. -func MaybeRollbackTx(ctx context.Context, storage interface{}) error { - txnStorage, transactional := storage.(Transactional) - if transactional { - return txnStorage.Rollback(ctx) - } else { - return nil +func MaybeTransaction(ctx context.Context, storage any, f func(context.Context) error) error { + if tx, ok := storage.(Transactional); ok { + return tx.Transaction(ctx, f) } + return f(ctx) }