diff --git a/go.mod b/go.mod index ab3240c..79033cd 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/go-logr/logr v1.3.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/google/uuid v1.5.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.15.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index b359725..5e47701 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= +github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.15.0 h1:1JYBfzqrWPcCclBwxFCPAou9n+q86mfnu7NAeHfte7A= github.com/grpc-ecosystem/grpc-gateway/v2 v2.15.0/go.mod h1:YDZoGHuwE+ov0c8smSH49WLF3F2LaWnYYuDVd+EWrc0= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= diff --git a/pkg/statestore/redis.go b/pkg/statestore/redis.go index cee5a6a..3b8b8e1 100644 --- a/pkg/statestore/redis.go +++ b/pkg/statestore/redis.go @@ -122,11 +122,18 @@ func (s *RedisStore) CreateTicket(ctx context.Context, ticket *pb.Ticket) error } func (s *RedisStore) DeleteTicket(ctx context.Context, ticketID string) error { + lockedCtx, unlock, err := s.locker.WithContext(ctx, redisKeyFetchTicketsLock(s.opts.keyPrefix)) + if err != nil { + return fmt.Errorf("failed to acquire fetch tickets lock: %w", err) + } + defer unlock() + queries := []rueidis.Completed{ s.client.B().Del().Key(redisKeyTicketData(s.opts.keyPrefix, ticketID)).Build(), s.client.B().Srem().Key(redisKeyTicketIndex(s.opts.keyPrefix)).Member(ticketID).Build(), + s.client.B().Zrem().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).Member(ticketID).Build(), } - for _, resp := range s.client.DoMulti(ctx, queries...) { + for _, resp := range s.client.DoMulti(lockedCtx, queries...) { if err := resp.Error(); err != nil { return fmt.Errorf("failed to delete ticket: %w", err) } @@ -225,7 +232,13 @@ func (s *RedisStore) setTicketsToPending(ctx context.Context, ticketIDs []string } func (s *RedisStore) ReleaseTickets(ctx context.Context, ticketIDs []string) error { - resp := s.client.Do(ctx, s.client.B().Zrem().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).Member(ticketIDs...).Build()) + lockedCtx, unlock, err := s.locker.WithContext(ctx, redisKeyFetchTicketsLock(s.opts.keyPrefix)) + if err != nil { + return fmt.Errorf("failed to acquire fetch tickets lock: %w", err) + } + defer unlock() + + resp := s.client.Do(lockedCtx, s.client.B().Zrem().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).Member(ticketIDs...).Build()) if err := resp.Error(); err != nil { return fmt.Errorf("failed to release tickets: %w", err) } @@ -239,10 +252,8 @@ func (s *RedisStore) AssignTickets(ctx context.Context, asgs []*pb.AssignmentGro continue } // deindex assigned tickets - for _, resp := range s.client.DoMulti(ctx, s.deIndexTickets(asg.TicketIds)...) { - if err := resp.Error(); err != nil { - return fmt.Errorf("failed to deindex assigned tickets: %w", err) - } + if err := s.deIndexTickets(ctx, asg.TicketIds); err != nil { + return fmt.Errorf("failed to deindex assigned tickets: %w", err) } // set assignment to a tickets redis := s.client @@ -364,11 +375,23 @@ func (s *RedisStore) setTicketsExpiration(ctx context.Context, ticketIDs []strin return nil } -func (s *RedisStore) deIndexTickets(ticketIDs []string) []rueidis.Completed { - return []rueidis.Completed{ +func (s *RedisStore) deIndexTickets(ctx context.Context, ticketIDs []string) error { + lockedCtx, unlock, err := s.locker.WithContext(ctx, redisKeyFetchTicketsLock(s.opts.keyPrefix)) + if err != nil { + return fmt.Errorf("failed to acquire fetch tickets lock: %w", err) + } + defer unlock() + + cmds := []rueidis.Completed{ s.client.B().Zrem().Key(redisKeyPendingTicketIndex(s.opts.keyPrefix)).Member(ticketIDs...).Build(), s.client.B().Srem().Key(redisKeyTicketIndex(s.opts.keyPrefix)).Member(ticketIDs...).Build(), } + for _, resp := range s.client.DoMulti(lockedCtx, cmds...) { + if err := resp.Error(); err != nil { + return fmt.Errorf("failed to deindex tickets: %w", err) + } + } + return nil } //nolint:unused diff --git a/pkg/statestore/redis_test.go b/pkg/statestore/redis_test.go index 25dac85..bd0a97b 100644 --- a/pkg/statestore/redis_test.go +++ b/pkg/statestore/redis_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/alicebob/miniredis/v2" + "github.com/google/uuid" "github.com/redis/rueidis" "github.com/redis/rueidis/rueidislock" "github.com/rs/xid" @@ -165,19 +166,21 @@ func TestTicketTTL(t *testing.T) { } func TestConcurrentFetchActiveTickets(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() mr := miniredis.RunT(t) store := newTestRedisStore(t, mr.Addr()) - for i := 0; i < 1000; i++ { + ticketCount := 1000 + concurrency := 1000 + for i := 0; i < ticketCount; i++ { require.NoError(t, store.CreateTicket(ctx, &pb.Ticket{Id: xid.New().String()})) } eg, _ := errgroup.WithContext(ctx) var mu sync.Mutex duplicateMap := map[string]struct{}{} - for i := 0; i < 1000; i++ { + for i := 0; i < concurrency; i++ { eg.Go(func() error { ticketIDs, err := store.GetActiveTicketIDs(ctx, 1000) if err != nil { @@ -197,3 +200,60 @@ func TestConcurrentFetchActiveTickets(t *testing.T) { } require.NoError(t, eg.Wait()) } + +func TestConcurrentFetchAndAssign(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + mr := miniredis.RunT(t) + store := newTestRedisStore(t, mr.Addr()) + + ticketCount := 1000 + concurrency := 1000 + for i := 0; i < ticketCount; i++ { + ticket := &pb.Ticket{Id: xid.New().String()} + require.NoError(t, store.CreateTicket(ctx, ticket)) + } + + var mu sync.Mutex + duplicateMap := map[string]struct{}{} + eg, _ := errgroup.WithContext(ctx) + for i := 0; i < concurrency; i++ { + eg.Go(func() error { + ticketIDs, err := store.GetActiveTicketIDs(ctx, 1000) + if err != nil { + return err + } + var asgs []*pb.AssignmentGroup + matches := chunkBy(ticketIDs[:len(ticketIDs)/2], 2) + for _, match := range matches { + if len(match) >= 2 { + asgs = append(asgs, &pb.AssignmentGroup{TicketIds: match, Assignment: &pb.Assignment{Connection: uuid.New().String()}}) + } + } + for _, asg := range asgs { + for _, tid := range asg.TicketIds { + mu.Lock() + if _, ok := duplicateMap[tid]; ok { + mu.Unlock() + return fmt.Errorf("duplicated! ticket id: %s", tid) + } + duplicateMap[tid] = struct{}{} + mu.Unlock() + } + } + if err := store.AssignTickets(ctx, asgs); err != nil { + return err + } + return nil + }) + } + require.NoError(t, eg.Wait()) +} + +// https://stackoverflow.com/a/72408490 +func chunkBy[T any](items []T, chunkSize int) (chunks [][]T) { + for chunkSize < len(items) { + items, chunks = items[chunkSize:], append(chunks, items[0:chunkSize:chunkSize]) + } + return append(chunks, items) +}