From fea8488ea310e8a3889119c847ed23ab58b65f2e Mon Sep 17 00:00:00 2001 From: yumaojun03 <719118794@qq.com> Date: Mon, 4 Nov 2024 19:36:10 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=85try=20lock?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ioc/apps/metric/restful/metric.go | 4 +- ioc/config/lock/go_cache.go | 11 +++++ ioc/config/lock/interface.go | 2 + ioc/config/lock/lock_test.go | 23 ++++++++++ ioc/config/lock/options.go | 8 ++++ ioc/config/lock/redis_lock.go | 60 ++++++++++++++++++++++----- ioc/config/lock/redis_lua/obtain.lua | 40 ++++++++++++++++++ ioc/config/lock/redis_lua/pttl.lua | 21 ++++++++++ ioc/config/lock/redis_lua/refresh.lua | 17 ++++++++ ioc/config/lock/redis_lua/release.lua | 16 +++++++ 10 files changed, 190 insertions(+), 12 deletions(-) create mode 100644 ioc/config/lock/redis_lua/obtain.lua create mode 100644 ioc/config/lock/redis_lua/pttl.lua create mode 100644 ioc/config/lock/redis_lua/refresh.lua create mode 100644 ioc/config/lock/redis_lua/release.lua diff --git a/ioc/apps/metric/restful/metric.go b/ioc/apps/metric/restful/metric.go index 56c42f7..caf7d9e 100644 --- a/ioc/apps/metric/restful/metric.go +++ b/ioc/apps/metric/restful/metric.go @@ -81,12 +81,12 @@ func (h *restfulHandler) AddApiCollector() { } func (h *restfulHandler) Registry() { - tags := []string{"健康检查"} + tags := []string{"应用指标"} ws := ioc_rest.ObjectRouter(h) ws.Route(ws. GET("/"). To(h.MetricHandleFunc). - Doc("健康检查"). + Doc("Prometheus指标"). Metadata(restfulspec.KeyOpenAPITags, tags). Metadata(restfulspec.KeyOpenAPITags, tags), ) diff --git a/ioc/config/lock/go_cache.go b/ioc/config/lock/go_cache.go index f33b446..56def99 100644 --- a/ioc/config/lock/go_cache.go +++ b/ioc/config/lock/go_cache.go @@ -76,6 +76,17 @@ func (m *GoCacheLock) Lock(ctx context.Context) error { } } +// TryLock +func (m *GoCacheLock) TryLock(ctx context.Context) error { + ok, err := m.obtain(ctx) + if err != nil { + return err + } else if ok { + return nil + } + return ErrNotObtained +} + func (m *GoCacheLock) obtain(context.Context) (bool, error) { if m.cache.Has(m.key) { return false, nil diff --git a/ioc/config/lock/interface.go b/ioc/config/lock/interface.go index 128c0e2..211ed44 100644 --- a/ioc/config/lock/interface.go +++ b/ioc/config/lock/interface.go @@ -39,6 +39,8 @@ type LockFactory interface { type Lock interface { // 锁配置 WithOpt(opt *Options) Lock + // TryLock + TryLock(ctx context.Context) error // 获取锁 Lock(ctx context.Context) error // 释放锁 diff --git a/ioc/config/lock/lock_test.go b/ioc/config/lock/lock_test.go index c144ff1..f6da400 100644 --- a/ioc/config/lock/lock_test.go +++ b/ioc/config/lock/lock_test.go @@ -28,6 +28,17 @@ func TestRedisLock(t *testing.T) { time.Sleep(10 * time.Second) } +func TestRedisTryLock(t *testing.T) { + os.Setenv("LOCK_PROVIDER", lock.PROVIDER_REDIS) + ioc.DevelopmentSetup() + g := &sync.WaitGroup{} + for i := range 9 { + go TryLockTest(i, g) + } + g.Wait() + time.Sleep(10 * time.Second) +} + func TestGoCacheRedisLock(t *testing.T) { ioc.DevelopmentSetup() g := &sync.WaitGroup{} @@ -49,6 +60,18 @@ func LockTest(number int, g *sync.WaitGroup) { fmt.Println(number, "down") } +func TryLockTest(number int, g *sync.WaitGroup) { + fmt.Println(number, "start") + g.Add(1) + defer g.Done() + m := lock.L().New("test", 1*time.Second) + if err := m.TryLock(ctx); err != nil { + fmt.Println(number, err) + return + } + fmt.Println(number, "obtained lock") +} + func TestDefaultConfig(t *testing.T) { file.MustToToml( lock.AppName, diff --git a/ioc/config/lock/options.go b/ioc/config/lock/options.go index fd83d33..6e7b4f6 100644 --- a/ioc/config/lock/options.go +++ b/ioc/config/lock/options.go @@ -20,6 +20,9 @@ type Options struct { // Token is a unique value that is used to identify the lock. By default, a random tokens are generated. Use this // option to provide a custom token instead. Token string + + // 超时时间 + Timeout time.Duration } func (o *Options) getMetadata() string { @@ -42,3 +45,8 @@ func (o *Options) getRetryStrategy() RetryStrategy { } return NoRetry() } + +func (o *Options) SetTimeout(t time.Duration) *Options { + o.Timeout = t + return o +} diff --git a/ioc/config/lock/redis_lock.go b/ioc/config/lock/redis_lock.go index 110b68c..11c3e48 100644 --- a/ioc/config/lock/redis_lock.go +++ b/ioc/config/lock/redis_lock.go @@ -3,6 +3,7 @@ package lock import ( "context" "crypto/rand" + _ "embed" "encoding/base64" "io" "strconv" @@ -13,16 +14,23 @@ import ( "github.com/redis/go-redis/v9" ) +//go:embed redis_lua/release.lua +var luaReleaseScript string + +//go:embed redis_lua/refresh.lua +var luaRefreshScript string + +//go:embed redis_lua/pttl.lua +var luaPTTLScript string + +//go:embed redis_lua/obtain.lua +var luaObtainScript string + var ( - luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`) - luaRelease = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end`) - luaPTTL = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pttl", KEYS[1]) else return -3 end`) - luaObtain = redis.NewScript(` -if redis.call("set", KEYS[1], ARGV[1], "NX", "PX", ARGV[3]) then return redis.status_reply("OK") end - -local offset = tonumber(ARGV[2]) -if redis.call("getrange", KEYS[1], 0, offset-1) == string.sub(ARGV[1], 1, offset) then return redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[3]) end -`) + luaRefresh = redis.NewScript(luaRefreshScript) + luaRelease = redis.NewScript(luaReleaseScript) + luaPTTL = redis.NewScript(luaPTTLScript) + luaObtain = redis.NewScript(luaObtainScript) ) func NewRedisLockProvider() *RedisLockProvider { @@ -53,6 +61,13 @@ type RedisLock struct { tmpMu sync.Mutex } +func (l *RedisLock) getTimeout() time.Duration { + if l.opt.Timeout > 0 { + return l.opt.Timeout + } + return l.ttl * 3 +} + func (l *RedisLock) TTLValueString() string { return strconv.FormatInt(int64(l.ttl/time.Millisecond), 10) } @@ -81,7 +96,7 @@ func (l *RedisLock) Lock(ctx context.Context) error { // make sure we don't retry forever if _, ok := ctx.Deadline(); !ok { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, l.ttl*3) + ctx, cancel = context.WithTimeout(ctx, l.getTimeout()) defer cancel() } @@ -114,6 +129,31 @@ func (l *RedisLock) Lock(ctx context.Context) error { } } +// 获取锁 +func (l *RedisLock) TryLock(ctx context.Context) error { + token := l.opt.getToken() + + // Create a random token + if token == "" { + var err error + if token, err = l.randomToken(); err != nil { + return err + } + } + + value := token + l.opt.getMetadata() + ok, err := l.obtain(ctx, l.key, value, len(token)) + if err != nil { + return err + } + + if !ok { + return ErrNotObtained + } + + return nil +} + func (c *RedisLock) obtain(ctx context.Context, key, value string, tokenLen int) (bool, error) { _, err := luaObtain.Run(ctx, c.client, []string{key}, value, tokenLen, c.TTLValueString()).Result() if err == redis.Nil { diff --git a/ioc/config/lock/redis_lua/obtain.lua b/ioc/config/lock/redis_lua/obtain.lua new file mode 100644 index 0000000..11e55d6 --- /dev/null +++ b/ioc/config/lock/redis_lua/obtain.lua @@ -0,0 +1,40 @@ +-- obtain.lua: arguments => [value, tokenLen, ttl] +-- Obtain.lua try to set provided keys's with value and ttl if they do not exists. +-- Keys can be overriden if they already exists and the correct value+tokenLen is provided. + +local function pexpire(ttl) + -- Update keys ttls. + for _, key in ipairs(KEYS) do + redis.call("pexpire", key, ttl) + end +end + +-- canOverrideLock check either or not the provided token match +-- previously set lock's tokens. +local function canOverrideKeys() + local offset = tonumber(ARGV[2]) + + for _, key in ipairs(KEYS) do + if redis.call("getrange", key, 0, offset-1) ~= string.sub(ARGV[1], 1, offset) then + return false + end + end + return true +end + +-- Prepare mset arguments. +local setArgs = {} +for _, key in ipairs(KEYS) do + table.insert(setArgs, key) + table.insert(setArgs, ARGV[1]) +end + +if redis.call("msetnx", unpack(setArgs)) ~= 1 then + if canOverrideKeys() == false then + return false + end + redis.call("mset", unpack(setArgs)) +end + +pexpire(ARGV[3]) +return redis.status_reply("OK") \ No newline at end of file diff --git a/ioc/config/lock/redis_lua/pttl.lua b/ioc/config/lock/redis_lua/pttl.lua new file mode 100644 index 0000000..0756001 --- /dev/null +++ b/ioc/config/lock/redis_lua/pttl.lua @@ -0,0 +1,21 @@ +-- pttl.lua: => Arguments: [value] +-- pttl.lua returns provided keys's ttls if all their values match the input. + +-- Check all keys values matches provided input. +local values = redis.call("mget", unpack(KEYS)) +for i, _ in ipairs(KEYS) do + if values[i] ~= ARGV[1] then + return false + end +end + +-- Find and return shortest TTL among keys. +local minTTL = 0 +for _, key in ipairs(KEYS) do + local ttl = redis.call("pttl", key) + -- Note: ttl < 0 probably means the key no longer exists. + if ttl > 0 and (minTTL == 0 or ttl < minTTL) then + minTTL = ttl + end +end +return minTTL \ No newline at end of file diff --git a/ioc/config/lock/redis_lua/refresh.lua b/ioc/config/lock/redis_lua/refresh.lua new file mode 100644 index 0000000..a7ca443 --- /dev/null +++ b/ioc/config/lock/redis_lua/refresh.lua @@ -0,0 +1,17 @@ +-- refresh.lua: => Arguments: [value, ttl] +-- refresh.lua refreshes provided keys's ttls if all their values match the input. + +-- Check all keys values matches provided input. +local values = redis.call("mget", unpack(KEYS)) +for i, _ in ipairs(KEYS) do + if values[i] ~= ARGV[1] then + return false + end +end + +-- Update keys ttls. +for _, key in ipairs(KEYS) do + redis.call("pexpire", key, ARGV[2]) +end + +return redis.status_reply("OK") \ No newline at end of file diff --git a/ioc/config/lock/redis_lua/release.lua b/ioc/config/lock/redis_lua/release.lua new file mode 100644 index 0000000..a0ffd6f --- /dev/null +++ b/ioc/config/lock/redis_lua/release.lua @@ -0,0 +1,16 @@ + +-- release.lua: => Arguments: [value] +-- Release.lua deletes provided keys if all their values match the input. + +-- Check all keys values matches provided input. +local values = redis.call("mget", unpack(KEYS)) +for i, _ in ipairs(KEYS) do + if values[i] ~= ARGV[1] then + return false + end +end + +-- Delete keys. +redis.call("del", unpack(KEYS)) + +return redis.status_reply("OK")