Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for SMEMBERS.WATCH command #1289

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
178 changes: 178 additions & 0 deletions integration_tests/commands/resp/smemberswatch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package resp

import (
"context"
"fmt"
"net"
"testing"
"time"

"github.com/dicedb/dice/internal/clientio"
dicedb "github.com/dicedb/dicedb-go"
"gotest.tools/v3/assert"
testifyAssert "github.com/stretchr/testify/assert"
)

type smembersWatchTestCase struct {
key string
val string
result any
}

const (
smembersCommand = "SMEMBERS"
smembersWatchKey = "smemberswatchkey"
smembersWatchQuery = "SMEMBERS.WATCH %s"
smembersWatchFingerPrint = "3660753939"
)

var smembersWatchTestCases = []smembersWatchTestCase{
{smembersWatchKey, "member1", []any{"member1"}},
{smembersWatchKey, "member2", []any{"member1", "member2"}},
{smembersWatchKey, "member3", []any{"member1", "member2", "member3"}},
}

func TestSMEMBERSWATCH(t *testing.T) {
publisher := getLocalConnection()
subscribers := setupSubscribers(3)

FireCommand(publisher, fmt.Sprintf("DEL %s", smembersWatchKey))

defer func() {
if err := publisher.Close(); err != nil {
t.Errorf("Error closing publisher connection: %v", err)
}
for _, sub := range subscribers {
time.Sleep(100 * time.Millisecond)
if err := sub.Close(); err != nil {
t.Errorf("Error closing subscriber connection: %v", err)
}
}
}()

respParsers := setUpSmembersRespParsers(t, subscribers)

t.Run("Basic Set Operations", func(t *testing.T) {
testSetOperations(t, publisher, respParsers)
})
}

func setUpSmembersRespParsers(t *testing.T, subscribers []net.Conn) []*clientio.RESPParser {
respParsers := make([]*clientio.RESPParser, len(subscribers))
for i, subscriber := range subscribers {
rp := fireCommandAndGetRESPParser(subscriber, fmt.Sprintf(smembersWatchQuery, smembersWatchKey))
assert.Assert(t, rp != nil)
respParsers[i] = rp

v, err := rp.DecodeOne()
assert.NilError(t, err)
castedValue, ok := v.([]interface{})
if !ok {
t.Errorf("Type assertion to []interface{} failed for value: %v", v)
}
assert.Equal(t, 3, len(castedValue))
}
return respParsers
}

func testSetOperations(t *testing.T, publisher net.Conn, respParsers []*clientio.RESPParser) {
for _, tc := range smembersWatchTestCases {
res := FireCommand(publisher, fmt.Sprintf("SADD %s %s", tc.key, tc.val))
assert.Equal(t, int64(1), res)
verifySmembersWatchResults(t, respParsers, tc.result)
}
}

func verifySmembersWatchResults(t *testing.T, respParsers []*clientio.RESPParser, expected any) {
for _, rp := range respParsers {
v, err := rp.DecodeOne()
assert.NilError(t, err)
castedValue, ok := v.([]interface{})
if !ok {
t.Errorf("Type assertion to []interface{} failed for value: %v", v)
}
assert.Equal(t, 3, len(castedValue))
assert.Equal(t, smembersCommand, castedValue[0])
assert.Equal(t, smembersWatchFingerPrint, castedValue[1])
testifyAssert.ElementsMatch(t, expected, castedValue[2])
}
}

type smembersWatchSDKTestCase struct {
key string
val string
result []string
}

var smembersWatchSDKTestCases = []smembersWatchSDKTestCase{
{smembersWatchKey, "member1", []string{"member1"}},
{smembersWatchKey, "member2", []string{"member1", "member2"}},
{smembersWatchKey, "member3", []string{"member1", "member2", "member3"}},
}

func TestSMEMBERSWATCHWithSDK(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()

publisher := getLocalSdk()
subscribers := setupSubscribersSDK(3)
defer cleanupSubscribersSDK(subscribers)

Check failure on line 119 in integration_tests/commands/resp/smemberswatch_test.go

View workflow job for this annotation

GitHub Actions / build

undefined: cleanupSubscribersSDK

publisher.Del(ctx, smembersWatchKey)

channels := setUpSmembersWatchChannelsSDK(t, ctx, subscribers)

t.Run("Basic Set Operations", func(t *testing.T) {
testSetOperationsSDK(t, ctx, channels, publisher)
})
}

func setUpSmembersWatchChannelsSDK(t *testing.T, ctx context.Context, subscribers []WatchSubscriber) []<-chan *dicedb.WatchResult {
channels := make([]<-chan *dicedb.WatchResult, len(subscribers))
for i, subscriber := range subscribers {
watch := subscriber.client.WatchConn(ctx)
subscribers[i].watch = watch
assert.Assert(t, watch != nil)
firstMsg, err := watch.Watch(ctx, smembersCommand, smembersWatchKey)
assert.NilError(t, err)
assert.Equal(t, firstMsg.Command, smembersCommand)
channels[i] = watch.Channel()
}
return channels
}

func testSetOperationsSDK(t *testing.T, ctx context.Context, channels []<-chan *dicedb.WatchResult, publisher *dicedb.Client) {
for _, tc := range smembersWatchSDKTestCases {
err := publisher.SAdd(ctx, tc.key, tc.val).Err()
assert.NilError(t, err)
verifySmembersWatchResultsSDK(t, channels, tc.result)
}
}

func verifySmembersWatchResultsSDK(t *testing.T, channels []<-chan *dicedb.WatchResult, expected []string) {
for _, channel := range channels {
select {
case v := <-channel:
assert.Equal(t, smembersCommand, v.Command)
assert.Equal(t, smembersWatchFingerPrint, v.Fingerprint)

received, ok := v.Data.([]interface{})
if !ok {
t.Fatalf("Expected []interface{}, got %T", v.Data)
}

receivedStrings := make([]string, len(received))
for i, item := range received {
str, ok := item.(string)
if !ok {
t.Fatalf("Expected string, got %T", item)
}
receivedStrings[i] = str
}

testifyAssert.ElementsMatch(t, expected, receivedStrings)
case <-time.After(defaultTimeout):
t.Fatal("timeout waiting for watch result")
}
}
}
90 changes: 46 additions & 44 deletions internal/eval/store_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -4753,50 +4753,52 @@ func evalHDEL(args []string, store *dstore.Store) *EvalResponse {
// Returns an integer which represents the number of members that were added to the set, not including
// the members that were already present
func evalSADD(args []string, store *dstore.Store) *EvalResponse {
if len(args) < 2 {
return &EvalResponse{
Result: nil,
Error: diceerrors.ErrWrongArgumentCount("SADD"),
}
}
key := args[0]

// Get the set object from the store.
obj := store.Get(key)
lengthOfItems := len(args[1:])

var count = 0
if obj == nil {
var exDurationMs int64 = -1
var keepttl = false
// If the object does not exist, create a new set object.
value := make(map[string]struct{}, lengthOfItems)
// Create a new object.
obj = store.NewObj(value, exDurationMs, object.ObjTypeSet)
store.Put(key, obj, dstore.WithKeepTTL(keepttl))
}

if err := object.AssertType(obj.Type, object.ObjTypeSet); err != nil {
return &EvalResponse{
Result: nil,
Error: diceerrors.ErrWrongTypeOperation,
}
}

// Get the set object.
set := obj.Value.(map[string]struct{})

for _, arg := range args[1:] {
if _, ok := set[arg]; !ok {
set[arg] = struct{}{}
count++
}
}

return &EvalResponse{
Result: count,
Error: nil,
}
if len(args) < 2 {
return &EvalResponse{
Result: nil,
Error: diceerrors.ErrWrongArgumentCount("SADD"),
}
}
key := args[0]

// Get the set object from the store.
obj := store.Get(key)
lengthOfItems := len(args[1:])

var count = 0
var set map[string]struct{}

if obj == nil {
// If the object does not exist, create a new set
set = make(map[string]struct{}, lengthOfItems)
} else {
// Type checks
if err := object.AssertType(obj.Type, object.ObjTypeSet); err != nil {
return &EvalResponse{
Result: nil,
Error: diceerrors.ErrWrongTypeOperation,
}
}

set = obj.Value.(map[string]struct{})
}

// Add elements to the set
for _, arg := range args[1:] {
if _, ok := set[arg]; !ok {
set[arg] = struct{}{}
count++
}
}

// Single Put operation at the end
obj = store.NewObj(set, -1, object.ObjTypeSet)
store.Put(key, obj, dstore.WithKeepTTL(false), dstore.WithPutCmd(dstore.SADD))

return &EvalResponse{
Result: count,
Error: nil,
}
}

// evalSREM removes one or more members from a set
Expand Down
4 changes: 4 additions & 0 deletions internal/iothread/cmd_meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ const (
CmdSrem = "SREM"
CmdScard = "SCARD"
CmdSmembers = "SMEMBERS"
CmdSMembersWatch = "SMEMBERS.WATCH"
CmdDump = "DUMP"
CmdRestore = "RESTORE"
CmdGeoAdd = "GEOADD"
Expand Down Expand Up @@ -675,6 +676,9 @@ var CommandsMeta = map[string]CmdMeta{
CmdPFCountWatch: {
CmdType: Watch,
},
CmdSMembersWatch: {
CmdType: Watch,
},

// Unwatch commands
CmdGetUnWatch: {
Expand Down
2 changes: 2 additions & 0 deletions internal/store/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const (
PFMERGE string = "PFMERGE"
KEYSPERSHARD string = "KEYSPERSHARD"
Evict string = "EVICT"
SADD string = "SADD"
SMEMBERS string = "SMEMBERS"
SingleShardSize string = "SINGLEDBSIZE"
SingleShardTouch string = "SINGLETOUCH"
SingleShardKeys string = "SINGLEKEYS"
Expand Down
1 change: 1 addition & 0 deletions internal/watchmanager/watch_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var (
dstore.ZAdd: {dstore.ZRange: struct{}{}},
dstore.PFADD: {dstore.PFCOUNT: struct{}{}},
dstore.PFMERGE: {dstore.PFCOUNT: struct{}{}},
dstore.SADD: {dstore.SMEMBERS: struct{}{}},
}
)

Expand Down
Loading