Skip to content

Commit

Permalink
feat: new algorithm for generating token IDs with high performance
Browse files Browse the repository at this point in the history
  • Loading branch information
Mmx233 committed May 16, 2024
1 parent a92c9bc commit 6c55e90
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 39 deletions.
4 changes: 1 addition & 3 deletions internal/db/redis/MfaLogin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,4 @@ import (
"github.com/ncuhome/GeniusAuthoritarian/pkg/tokenStore"
)

func NewMfaLogin() tokenStore.TokenStore[jwtClaims.MfaRedis] {
return tokenStore.NewTokenStore[jwtClaims.MfaRedis](Client, keyUserMfaLogin.String())
}
var NewMfaLogin = tokenStore.NewTokenStoreFactory[jwtClaims.MfaRedis](Client, keyUserMfaLogin.String())
4 changes: 3 additions & 1 deletion internal/db/redis/RefreshToken.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ type RefreshTokenStore struct {
tokenStore.TokenStore[types.Nil]
}

var _NewRefreshTokenStore = tokenStore.NewTokenStoreFactory[types.Nil](Client, keyRecordedToken.String())

func NewRecordedToken() RefreshTokenStore {
return RefreshTokenStore{
tokenStore.NewTokenStore[types.Nil](Client, keyRecordedToken.String()),
_NewRefreshTokenStore(),
}
}

Expand Down
4 changes: 1 addition & 3 deletions internal/db/redis/ThirdPartyLogin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,4 @@ import (
"github.com/ncuhome/GeniusAuthoritarian/pkg/tokenStore"
)

func NewThirdPartyLogin() tokenStore.TokenStore[jwtClaims.LoginRedis] {
return tokenStore.NewTokenStore[jwtClaims.LoginRedis](Client, keyThirdPartyLogin.String())
}
var NewThirdPartyLogin = tokenStore.NewTokenStoreFactory[jwtClaims.LoginRedis](Client, keyThirdPartyLogin.String())
4 changes: 1 addition & 3 deletions internal/db/redis/U2F.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,4 @@ import (
"go/types"
)

func NewU2F() tokenStore.TokenStore[types.Nil] {
return tokenStore.NewTokenStore[types.Nil](Client, keyU2F.String())
}
var NewU2F = tokenStore.NewTokenStoreFactory[types.Nil](Client, keyU2F.String())
114 changes: 85 additions & 29 deletions pkg/tokenStore/tokenStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,86 @@ import (
"encoding/json"
"fmt"
"github.com/go-redis/redis/v8"
"sync"
"sync/atomic"
"time"
)

func NewTokenStore[C any](Client *redis.Client, keyPrefix string) TokenStore[C] {
return TokenStore[C]{
client: Client,
keyPrefix: keyPrefix,
keyID: keyPrefix + "id",
type Node struct {
// this is a unique id for each node process, it
// will be reallocated every day.
// ID must be smaller than 100.
ID uint64

client *redis.Client
keyNodeIDPrefix string

// use for refresh fields.
Lock *sync.Mutex
IDTimeMark uint64
TokenID *atomic.Uint64
}

func (node *Node) keyNodeID(timeMark uint64) string {
return fmt.Sprintf("%s-%d", node.keyNodeIDPrefix, timeMark)
}

func (node *Node) currentTimeMark() uint64 {
return uint64(time.Now().YearDay())
}

func (node *Node) GenID(ctx context.Context) (uint64, error) {
currentTimeMark := node.currentTimeMark()
if node.IDTimeMark != currentTimeMark {
node.Lock.Lock()
if node.IDTimeMark == currentTimeMark {
node.Lock.Unlock()
return node.GenID(ctx)
}
defer node.Lock.Unlock()
newNodeID, err := node.client.Incr(ctx, node.keyNodeID(currentTimeMark)).Uint64()
if err != nil {
return 0, err
}
node.ID = newNodeID % 100
node.TokenID.Store(0)
node.IDTimeMark = currentTimeMark
}
tokenID := node.TokenID.Add(1)
tokenID = (tokenID << 5) + (node.ID << 3) + node.IDTimeMark
return tokenID, nil
}

func NewTokenStoreFactory[C any](Client *redis.Client, keyPrefix string) func() TokenStore[C] {
node := Node{
client: Client,
keyNodeIDPrefix: keyPrefix + "id",
Lock: &sync.Mutex{},
TokenID: &atomic.Uint64{},
}
return func() TokenStore[C] {
return TokenStore[C]{
client: Client,
keyPrefix: keyPrefix,
node: &node,
}
}
}

type TokenStore[C any] struct {
client *redis.Client
// token 有效校验的 key 前缀
keyPrefix string
// redis ID 字段 key,用于给 token 分配不一样的 ID
keyID string

// node should be set with static variable
node *Node
}

func (a TokenStore[C]) genKey(id uint64) string {
return a.keyPrefix + fmt.Sprint(id)
func (store TokenStore[C]) genKey(id uint64) string {
return store.keyPrefix + fmt.Sprint(id)
}

func (a TokenStore[C]) CreateStorePointWithID(ctx context.Context, id uint64, valid time.Duration, claims *C) error {
func (store TokenStore[C]) CreateStorePointWithID(ctx context.Context, id uint64, valid time.Duration, claims *C) error {
var value []byte
var err error
if claims != nil {
Expand All @@ -39,29 +95,29 @@ func (a TokenStore[C]) CreateStorePointWithID(ctx context.Context, id uint64, va
} else {
value = []byte{'1'}
}
return a.client.Set(ctx, a.genKey(id), value, valid).Err()
return store.client.Set(ctx, store.genKey(id), value, valid).Err()
}

func (a TokenStore[C]) CreateStorePoint(ctx context.Context, valid time.Duration, claims *C) (uint64, error) {
id, err := a.client.Incr(ctx, a.keyID).Uint64()
func (store TokenStore[C]) CreateStorePoint(ctx context.Context, valid time.Duration, claims *C) (uint64, error) {
id, err := store.node.GenID(ctx)
if err != nil {
return 0, err
}
return id, a.CreateStorePointWithID(ctx, id, valid, claims)
return id, store.CreateStorePointWithID(ctx, id, valid, claims)
}

func (a TokenStore[C]) MPointGet(ctx context.Context, ids ...uint64) ([]interface{}, error) {
func (store TokenStore[C]) MPointGet(ctx context.Context, ids ...uint64) ([]interface{}, error) {
keys := make([]string, len(ids))
for i, id := range ids {
keys[i] = a.genKey(id)
keys[i] = store.genKey(id)
}
return a.client.MGet(ctx, keys...).Result()
return store.client.MGet(ctx, keys...).Result()
}

func (a TokenStore[C]) NewStorePoint(id uint64) Point[C] {
func (store TokenStore[C]) NewStorePoint(id uint64) Point[C] {
return Point[C]{
s: a,
key: a.genKey(id),
s: store,
key: store.genKey(id),
}
}

Expand All @@ -70,31 +126,31 @@ type Point[C any] struct {
key string
}

func (a Point[C]) parsePoint(data []byte, claims *C) error {
func (point Point[C]) parsePoint(data []byte, claims *C) error {
if claims != nil {
return json.Unmarshal(data, claims)
}
return nil
}

func (a Point[C]) GetAndDestroy(ctx context.Context, claims *C) error {
value, err := a.s.client.GetDel(ctx, a.key).Bytes()
func (point Point[C]) GetAndDestroy(ctx context.Context, claims *C) error {
value, err := point.s.client.GetDel(ctx, point.key).Bytes()
if err != nil {
return err
}

return a.parsePoint(value, claims)
return point.parsePoint(value, claims)
}

func (a Point[C]) Get(ctx context.Context, claims *C) error {
value, err := a.s.client.Get(ctx, a.key).Bytes()
func (point Point[C]) Get(ctx context.Context, claims *C) error {
value, err := point.s.client.Get(ctx, point.key).Bytes()
if err != nil {
return err
}

return a.parsePoint(value, claims)
return point.parsePoint(value, claims)
}

func (a Point[C]) Destroy(ctx context.Context) error {
return a.s.client.Del(ctx, a.key).Err()
func (point Point[C]) Destroy(ctx context.Context) error {
return point.s.client.Del(ctx, point.key).Err()
}

0 comments on commit 6c55e90

Please sign in to comment.