diff --git a/README.md b/README.md index 735e5a1c..4c55a538 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Go Coverage](https://github.com/EchoVault/EchoVault/wiki/coverage.svg)](https://raw.githack.com/wiki/EchoVault/EchoVault/coverage.html) [![GitHub Release](https://img.shields.io/github/v/release/EchoVault/EchoVault)]()
-[![License: GPL v2](https://img.shields.io/badge/License-GPL_v2-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html) +[![License: GPL v3](https://img.shields.io/badge/License-GPL_v3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0.en.html)
[![Discord](https://img.shields.io/discord/1211815152291414037?style=flat&logo=discord&link=https%3A%2F%2Fdiscord.gg%2Fvt45CKfF)](https://discord.gg/vt45CKfF) diff --git a/docker-compose.yaml b/docker-compose.yaml index c4bd72ea..07945174 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -19,7 +19,7 @@ services: - DATA_DIR=/var/lib/echovault - IN_MEMORY=false - TLS=true - - MTLS=true + - MTLS=false - BOOTSTRAP_CLUSTER=false - ACL_CONFIG=/etc/config/echovault/acl.yml - REQUIRE_PASS=true @@ -36,7 +36,7 @@ services: # List of client certificate authorities - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt ports: - - "7479:7480" + - "7480:7480" - "7946:7946" - "7999:8000" volumes: @@ -76,7 +76,7 @@ services: # List of client certificate authorities - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt ports: - - "7480:7480" + - "7481:7480" - "7945:7946" - "8000:8000" volumes: @@ -117,7 +117,7 @@ services: # List of client certificate authorities - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt ports: - - "7481:7480" + - "7482:7480" - "7947:7946" - "8001:8000" volumes: @@ -158,7 +158,7 @@ services: # List of client certificate authorities - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt ports: - - "7482:7480" + - "7483:7480" - "7948:7946" - "8002:8000" volumes: @@ -199,7 +199,7 @@ services: # List of client certificate authorities - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt ports: - - "7483:7480" + - "7484:7480" - "7949:7946" - "8003:8000" volumes: @@ -240,7 +240,7 @@ services: # List of client certificate authorities - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt ports: - - "7484:7480" + - "7485:7480" - "7950:7946" - "8004:8000" volumes: diff --git a/src/memberlist/delegate.go b/src/memberlist/delegate.go index 01314cdb..8bf78b9a 100644 --- a/src/memberlist/delegate.go +++ b/src/memberlist/delegate.go @@ -7,6 +7,7 @@ import ( "github.com/echovault/echovault/src/utils" "github.com/hashicorp/memberlist" "github.com/hashicorp/raft" + "log" "time" ) @@ -79,14 +80,12 @@ func (delegate *Delegate) NotifyMsg(msgBytes []byte) { cmd, err := utils.Decode(msg.Content) if err != nil { - // TODO: Log error to configured logger - fmt.Println(err) + log.Println(err) return } if _, err := delegate.options.applyMutate(ctx, cmd); err != nil { - // TODO: Log error to configured logger - fmt.Println(err) + log.Println(err) } } } diff --git a/src/modules/acl/acl.go b/src/modules/acl/acl.go index 7ec8f40f..0b08a2e3 100644 --- a/src/modules/acl/acl.go +++ b/src/modules/acl/acl.go @@ -105,9 +105,10 @@ func NewACL(config utils.Config) *ACL { func (acl *ACL) RegisterConnection(conn *net.Conn) { // This is called only when a connection is established. - defaultUser := utils.Filter(acl.Users, func(elem *User) bool { - return elem.Username == "default" - })[0] + defaultUserIdx := slices.IndexFunc(acl.Users, func(user *User) bool { + return user.Username == "default" + }) + defaultUser := acl.Users[defaultUserIdx] acl.Connections[conn] = Connection{ Authenticated: defaultUser.NoPassword, User: defaultUser, @@ -167,7 +168,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error { } } // Delete the user from the ACL - acl.Users = utils.Filter(acl.Users, func(u *User) bool { + acl.Users = slices.DeleteFunc(acl.Users, func(u *User) bool { return u.Username != user.Username }) } @@ -188,9 +189,10 @@ func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd {PasswordType: "SHA256", PasswordValue: string(h.Sum(nil))}, } // Authenticate with default user - user = utils.Filter(acl.Users, func(user *User) bool { + idx := slices.IndexFunc(acl.Users, func(user *User) bool { return user.Username == "default" - })[0] + }) + user = acl.Users[idx] } if len(cmd) == 3 { // Process AUTH diff --git a/src/modules/acl/commands.go b/src/modules/acl/commands.go index 6db9ccaa..ae4764e8 100644 --- a/src/modules/acl/commands.go +++ b/src/modules/acl/commands.go @@ -139,7 +139,7 @@ func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn res = res + fmt.Sprintf("\r\n+-&%s", channel) } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -179,7 +179,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net for i, cat := range cats { res = fmt.Sprintf("%s\r\n+%s", res, cat) if i == len(cats)-1 { - res = res + "\r\n\r\n" + res = res + "\r\n" } } return []byte(res), nil @@ -193,7 +193,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net for i, command := range commands { res = fmt.Sprintf("%s\r\n+%s", res, command) if i == len(commands)-1 { - res = res + "\r\n\r\n" + res = res + "\r\n" } } return []byte(res), nil @@ -213,7 +213,7 @@ func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *n for _, user := range acl.Users { res += fmt.Sprintf("\r\n$%d\r\n%s", len(user.Username), user.Username) } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -248,7 +248,7 @@ func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn * return nil, errors.New("could not load ACL") } connectionInfo := acl.Connections[conn] - return []byte(fmt.Sprintf("+%s\r\n\r\n", connectionInfo.User.Username)), nil + return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil } func handleList(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -343,7 +343,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne res = res + fmt.Sprintf("\r\n$%d\r\n%s", len(s), s) } - res = res + "\r\n\r\n" + res = res + "\r\n" return []byte(res), nil } diff --git a/src/modules/acl/user.go b/src/modules/acl/user.go index 1d7dcddb..51370356 100644 --- a/src/modules/acl/user.go +++ b/src/modules/acl/user.go @@ -1,7 +1,6 @@ package acl import ( - "github.com/echovault/echovault/src/utils" "slices" "strings" ) @@ -105,18 +104,18 @@ func (user *User) UpdateUser(cmd []string) error { continue } if str[0] == '<' { - user.Passwords = utils.Filter(user.Passwords, func(password Password) bool { + user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool { if strings.EqualFold(password.PasswordType, "SHA256") { - return true + return false } return password.PasswordValue == str[1:] }) continue } if str[0] == '!' { - user.Passwords = utils.Filter(user.Passwords, func(password Password) bool { + user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool { if strings.EqualFold(password.PasswordType, "plaintext") { - return true + return false } return password.PasswordValue == str[1:] }) diff --git a/src/modules/admin/commands.go b/src/modules/admin/commands.go index 4e80b00b..7bdca3d5 100644 --- a/src/modules/admin/commands.go +++ b/src/modules/admin/commands.go @@ -5,7 +5,10 @@ import ( "errors" "fmt" "github.com/echovault/echovault/src/utils" + "github.com/gobwas/glob" "net" + "slices" + "strings" ) func handleGetAllCommands(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { @@ -48,11 +51,123 @@ func handleGetAllCommands(ctx context.Context, cmd []string, server utils.Server } } - res = fmt.Sprintf("*%d\r\n%s\r\n", commandCount, res) + res = fmt.Sprintf("*%d\r\n%s", commandCount, res) return []byte(res), nil } +func handleCommandCount(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { + var count int + + commands := server.GetAllCommands(ctx) + for _, command := range commands { + if command.SubCommands != nil && len(command.SubCommands) > 0 { + for _, _ = range command.SubCommands { + count += 1 + } + continue + } + count += 1 + } + + return []byte(fmt.Sprintf(":%d\r\n", count)), nil +} + +func handleCommandList(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { + switch len(cmd) { + case 2: + // Command is COMMAND LIST + var count int + var res string + commands := server.GetAllCommands(ctx) + for _, command := range commands { + if command.SubCommands != nil && len(command.SubCommands) > 0 { + for _, subcommand := range command.SubCommands { + comm := fmt.Sprintf("%s %s", command.Command, subcommand.Command) + res += fmt.Sprintf("$%d\r\n%s\r\n", len(comm), comm) + count += 1 + } + continue + } + res += fmt.Sprintf("$%d\r\n%s\r\n", len(command.Command), command.Command) + count += 1 + } + res = fmt.Sprintf("*%d\r\n%s", count, res) + return []byte(res), nil + + case 5: + var count int + var res string + // Command has filter + if !strings.EqualFold("FILTERBY", cmd[2]) { + return nil, fmt.Errorf("expected FILTERBY, got %s", strings.ToUpper(cmd[2])) + } + if strings.EqualFold("ACLCAT", cmd[3]) { + // ACL Category filter + commands := server.GetAllCommands(ctx) + category := strings.ToLower(cmd[4]) + for _, command := range commands { + if command.SubCommands != nil && len(command.SubCommands) > 0 { + for _, subcommand := range command.SubCommands { + if slices.Contains(subcommand.Categories, category) { + comm := fmt.Sprintf("%s %s", command.Command, subcommand.Command) + res += fmt.Sprintf("$%d\r\n%s\r\n", len(comm), comm) + count += 1 + } + } + continue + } + if slices.Contains(command.Categories, category) { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(command.Command), command.Command) + count += 1 + } + } + } else if strings.EqualFold("PATTERN", cmd[3]) { + // Pattern filter + commands := server.GetAllCommands(ctx) + g := glob.MustCompile(cmd[4]) + for _, command := range commands { + if command.SubCommands != nil && len(command.SubCommands) > 0 { + for _, subcommand := range command.SubCommands { + comm := fmt.Sprintf("%s %s", command.Command, subcommand.Command) + if g.Match(comm) { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(comm), comm) + count += 1 + } + } + continue + } + if g.Match(command.Command) { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(command.Command), command.Command) + count += 1 + } + } + } else { + return nil, fmt.Errorf("expected filter to be ACLCAT or PATTERN, got %s", strings.ToUpper(cmd[3])) + } + res = fmt.Sprintf("*%d\r\n%s", count, res) + return []byte(res), nil + default: + return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + } +} + +func handleCommandDocs(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { + return []byte("*0\r\n"), nil +} + +// func handleConfigGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { +// return nil, errors.New("command not yet implemented") +// } +// +// func handleConfigRewrite(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) { +// return nil, errors.New("command not yet implemented") +// } +// +// func handleConfigSet(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) { +// return nil, errors.New("command not yet implemented") +// } + func Commands() []utils.Command { return []utils.Command{ { @@ -63,6 +178,42 @@ func Commands() []utils.Command { KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, HandlerFunc: handleGetAllCommands, }, + { + Command: "command", + Categories: []string{}, + Description: "Commands pertaining to echovault commands", + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { + return []string{}, nil + }, + SubCommands: []utils.SubCommand{ + { + Command: "docs", + Categories: []string{utils.SlowCategory, utils.ConnectionCategory}, + Description: "Get command documentation", + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: handleCommandDocs, + }, + { + Command: "count", + Categories: []string{utils.SlowCategory}, + Description: "Get the dumber of commands in the server", + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: handleCommandCount, + }, + { + Command: "list", + Categories: []string{utils.SlowCategory}, + Description: `(COMMAND LIST [FILTERBY ]) Get the list of command names. +Allows for filtering by ACL category or glob pattern.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: handleCommandList, + }, + }, + }, { Command: "save", Categories: []string{utils.AdminCategory, utils.SlowCategory, utils.DangerousCategory}, @@ -91,7 +242,7 @@ func Commands() []utils.Command { if msec == 0 { return nil, errors.New("no snapshot") } - return []byte(fmt.Sprintf(":%d\r\n\r\n", msec)), nil + return []byte(fmt.Sprintf(":%d\r\n", msec)), nil }, }, { diff --git a/src/modules/etc/commands.go b/src/modules/etc/commands.go index 4e8993f0..266663f1 100644 --- a/src/modules/etc/commands.go +++ b/src/modules/etc/commands.go @@ -2,6 +2,7 @@ package etc import ( "context" + "errors" "fmt" "github.com/echovault/echovault/src/utils" "net" @@ -50,9 +51,7 @@ func handleSetNX(ctx context.Context, cmd []string, server utils.Server, conn *n if server.KeyExists(key) { return nil, fmt.Errorf("key %s already exists", key) } - // TODO: Retry CreateKeyAndLock until we manage to obtain the key - _, err = server.CreateKeyAndLock(ctx, key) - if err != nil { + if _, err = server.CreateKeyAndLock(ctx, key); err != nil { return nil, err } server.SetValue(ctx, key, utils.AdaptType(cmd[2])) @@ -115,6 +114,10 @@ func handleMSet(ctx context.Context, cmd []string, server utils.Server, conn *ne return []byte(utils.OK_RESPONSE), nil } +func handleCopy(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) { + return nil, errors.New("command not yet implemented") +} + func Commands() []utils.Command { return []utils.Command{ { diff --git a/src/modules/get/commands.go b/src/modules/get/commands.go index ddc3679d..d0b9bcbd 100644 --- a/src/modules/get/commands.go +++ b/src/modules/get/commands.go @@ -15,7 +15,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net key := keys[0] if !server.KeyExists(key) { - return []byte("+nil\r\n\r\n"), nil + return []byte("$-1\r\n"), nil } _, err = server.KeyRLock(ctx, key) @@ -26,7 +26,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net value := server.GetValue(key) - return []byte(fmt.Sprintf("+%v\r\n\r\n", value)), nil + return []byte(fmt.Sprintf("+%v\r\n", value)), nil } func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -51,7 +51,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *ne locks[key] = true continue } - values[key] = "nil" + values[key] = "" } defer func() { for key, locked := range locks { @@ -69,11 +69,13 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *ne bytes := []byte(fmt.Sprintf("*%d\r\n", len(cmd[1:]))) for _, key := range cmd[1:] { + if values[key] == "" { + bytes = append(bytes, []byte("$-1\r\n")...) + continue + } bytes = append(bytes, []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(values[key]), values[key]))...) } - bytes = append(bytes, []byte("\r\n")...) - return bytes, nil } diff --git a/src/modules/get/commands_test.go b/src/modules/get/commands_test.go index a431d0e2..97418006 100644 --- a/src/modules/get/commands_test.go +++ b/src/modules/get/commands_test.go @@ -47,8 +47,8 @@ func Test_HandleGET(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(res, []byte(fmt.Sprintf("+%v\r\n\r\n", value))) { - t.Errorf("expected %s, got: %s", fmt.Sprintf("+%v\r\n\r\n", value), string(res)) + if !bytes.Equal(res, []byte(fmt.Sprintf("+%v\r\n", value))) { + t.Errorf("expected %s, got: %s", fmt.Sprintf("+%v\r\n", value), string(res)) } }(test.key, test.value) } @@ -58,8 +58,8 @@ func Test_HandleGET(t *testing.T) { if err != nil { t.Error(err) } - if !bytes.Equal(res, []byte("+nil\r\n\r\n")) { - t.Errorf("expected %+v, got: %+v", "+nil\r\n\r\n", res) + if !bytes.Equal(res, []byte("$-1\r\n")) { + t.Errorf("expected %+v, got: %+v", "+nil\r\n", res) } errorTests := []struct { @@ -93,21 +93,21 @@ func Test_HandleMGET(t *testing.T) { presetKeys []string presetValues []string command []string - expected []string + expected []interface{} expectedError error }{ { presetKeys: []string{"test1", "test2", "test3", "test4"}, presetValues: []string{"value1", "value2", "value3", "value4"}, command: []string{"MGET", "test1", "test4", "test2", "test3", "test1"}, - expected: []string{"value1", "value4", "value2", "value3", "value1"}, + expected: []interface{}{"value1", "value4", "value2", "value3", "value1"}, expectedError: nil, }, { presetKeys: []string{"test5", "test6", "test7"}, presetValues: []string{"value5", "value6", "value7"}, command: []string{"MGET", "test5", "test6", "non-existent", "non-existent", "test7", "non-existent"}, - expected: []string{"value5", "value6", "nil", "nil", "value7", "nil"}, + expected: []interface{}{"value5", "value6", nil, nil, "value7", nil}, expectedError: nil, }, { @@ -150,6 +150,12 @@ func Test_HandleMGET(t *testing.T) { t.Errorf("expected type Array, got: %s", rv.Type().String()) } for i, value := range rv.Array() { + if test.expected[i] == nil { + if !value.IsNull() { + t.Errorf("expected nil value, got %+v", value) + } + continue + } if value.String() != test.expected[i] { t.Errorf("expected value %s, got: %s", test.expected[i], value.String()) } diff --git a/src/modules/hash/commands.go b/src/modules/hash/commands.go index 300b6db0..fe102684 100644 --- a/src/modules/hash/commands.go +++ b/src/modules/hash/commands.go @@ -36,7 +36,7 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) server.SetValue(ctx, key, entries) - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(entries))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -63,7 +63,7 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne } server.SetValue(ctx, key, hash) - return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil + return []byte(fmt.Sprintf(":%d\r\n", count)), nil } func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -76,7 +76,7 @@ func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *ne fields := cmd[2:] if !server.KeyExists(key) { - return []byte("$-1\r\n\r\n"), nil + return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -113,7 +113,6 @@ func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *ne } res += fmt.Sprintf("$-1\r\n") } - res += "\r\n" return []byte(res), nil } @@ -128,7 +127,7 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server utils.Server, conn fields := cmd[2:] if !server.KeyExists(key) { - return []byte("$-1\r\n\r\n"), nil + return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -165,7 +164,6 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server utils.Server, conn } res += ":0\r\n" } - res += "\r\n" return []byte(res), nil } @@ -179,7 +177,7 @@ func handleHVALS(ctx context.Context, cmd []string, server utils.Server, conn *n key := keys[0] if !server.KeyExists(key) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -207,7 +205,6 @@ func handleHVALS(ctx context.Context, cmd []string, server utils.Server, conn *n res += fmt.Sprintf(":%d\r\n", d) } } - res += "\r\n" return []byte(res), nil } @@ -227,7 +224,7 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co return nil, errors.New("count must be an integer") } if c == 0 { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } count = c } @@ -242,7 +239,7 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co } if !server.KeyExists(key) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -279,7 +276,6 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co } } } - res += "\r\n" return []byte(res), nil } @@ -325,7 +321,6 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co } } } - res += "\r\n" return []byte(res), nil } @@ -339,7 +334,7 @@ func handleHLEN(ctx context.Context, cmd []string, server utils.Server, conn *ne key := keys[0] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -352,7 +347,7 @@ func handleHLEN(ctx context.Context, cmd []string, server utils.Server, conn *ne return nil, fmt.Errorf("value at %s is not a hash", key) } - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(hash))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(hash))), nil } func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -364,7 +359,7 @@ func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *n key := keys[0] if !server.KeyExists(key) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -381,7 +376,6 @@ func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *n for field, _ := range hash { res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field) } - res += "\r\n" return []byte(res), nil } @@ -421,11 +415,11 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn if strings.EqualFold(cmd[0], "hincrbyfloat") { hash[field] = floatIncrement server.SetValue(ctx, key, hash) - return []byte(fmt.Sprintf("+%s\r\n\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil + return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil } else { hash[field] = intIncrement server.SetValue(ctx, key, hash) - return []byte(fmt.Sprintf(":%d\r\n\r\n", intIncrement)), nil + return []byte(fmt.Sprintf(":%d\r\n", intIncrement)), nil } } @@ -465,11 +459,11 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn server.SetValue(ctx, key, hash) if f, ok := hash[field].(float64); ok { - return []byte(fmt.Sprintf("+%s\r\n\r\n", strconv.FormatFloat(f, 'f', -1, 64))), nil + return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(f, 'f', -1, 64))), nil } i, _ := hash[field].(int) - return []byte(fmt.Sprintf(":%d\r\n\r\n", i)), nil + return []byte(fmt.Sprintf(":%d\r\n", i)), nil } func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -481,7 +475,7 @@ func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn key := keys[0] if !server.KeyExists(key) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -508,7 +502,6 @@ func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn res += fmt.Sprintf(":%d\r\n", d) } } - res += "\r\n" return []byte(res), nil } @@ -523,7 +516,7 @@ func handleHEXISTS(ctx context.Context, cmd []string, server utils.Server, conn field := cmd[2] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -537,10 +530,10 @@ func handleHEXISTS(ctx context.Context, cmd []string, server utils.Server, conn } if hash[field] != nil { - return []byte(":1\r\n\r\n"), nil + return []byte(":1\r\n"), nil } - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -553,7 +546,7 @@ func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *ne fields := cmd[2:] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -577,7 +570,7 @@ func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *ne server.SetValue(ctx, key, hash) - return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil + return []byte(fmt.Sprintf(":%d\r\n", count)), nil } func Commands() []utils.Command { diff --git a/src/modules/list/commands.go b/src/modules/list/commands.go index e1e10873..a9a13017 100644 --- a/src/modules/list/commands.go +++ b/src/modules/list/commands.go @@ -11,7 +11,7 @@ import ( "strings" ) -func handleLLen(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleLLen(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { keys, err := llenKeyFunc(cmd) if err != nil { return nil, err @@ -21,7 +21,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, conn *ne if !server.KeyExists(key) { // If key does not exist, return 0 - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -30,7 +30,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, conn *ne defer server.KeyRUnlock(key) if list, ok := server.GetValue(key).([]interface{}); ok { - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(list))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil } return nil, errors.New("LLEN command on non-list item") @@ -67,7 +67,7 @@ func handleLIndex(ctx context.Context, cmd []string, server utils.Server, conn * return nil, errors.New("index must be within list range") } - return []byte(fmt.Sprintf("+%s\r\n\r\n", list[index])), nil + return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil } func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -117,7 +117,6 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn * str := fmt.Sprintf("%v", list[i]) bytes = append(bytes, []byte("$"+fmt.Sprint(len(str))+"\r\n"+str+"\r\n")...) } - bytes = append(bytes, []byte("\r\n")...) return bytes, nil } @@ -145,11 +144,8 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn * } else { i-- } - } - bytes = append(bytes, []byte("\r\n")...) - return bytes, nil } @@ -377,7 +373,6 @@ func handleLPush(ctx context.Context, cmd []string, server utils.Server, conn *n case "lpushx": return nil, errors.New("LPUSHX command on non-list item") default: - // TODO: Retry CreateKeyAndLock until we obtain the key lock if _, err = server.CreateKeyAndLock(ctx, key); err != nil { return nil, err } @@ -420,7 +415,6 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n case "rpushx": return nil, errors.New("RPUSHX command on non-list item") default: - // TODO: Retry CreateKeyAndLock until we managed to obtain the key if _, err = server.CreateKeyAndLock(ctx, key); err != nil { return nil, err } @@ -470,10 +464,10 @@ func handlePop(ctx context.Context, cmd []string, server utils.Server, conn *net switch strings.ToLower(cmd[0]) { default: server.SetValue(ctx, key, list[1:]) - return []byte(fmt.Sprintf("+%v\r\n\r\n", list[0])), nil + return []byte(fmt.Sprintf("+%v\r\n", list[0])), nil case "rpop": server.SetValue(ctx, key, list[:len(list)-1]) - return []byte(fmt.Sprintf("+%v\r\n\r\n", list[len(list)-1])), nil + return []byte(fmt.Sprintf("+%v\r\n", list[len(list)-1])), nil } } diff --git a/src/modules/ping/commands.go b/src/modules/ping/commands.go index 4779d533..509b755f 100644 --- a/src/modules/ping/commands.go +++ b/src/modules/ping/commands.go @@ -13,9 +13,9 @@ func handlePing(ctx context.Context, cmd []string, server utils.Server, conn *ne default: return nil, errors.New(utils.WRONG_ARGS_RESPONSE) case 1: - return []byte("+PONG\r\n\r\n"), nil + return []byte("+PONG\r\n"), nil case 2: - return []byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(cmd[1]), cmd[1])), nil + return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(cmd[1]), cmd[1])), nil } } @@ -31,17 +31,5 @@ func Commands() []utils.Command { }, HandlerFunc: handlePing, }, - { - Command: "ack", - Categories: []string{}, - Description: "", - Sync: false, - KeyExtractionFunc: func(cmd []string) ([]string, error) { - return []string{}, nil - }, - HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { - return []byte("$-1\r\n\r\n"), nil - }, - }, } } diff --git a/src/modules/pubsub/channel.go b/src/modules/pubsub/channel.go new file mode 100644 index 00000000..41ae70f4 --- /dev/null +++ b/src/modules/pubsub/channel.go @@ -0,0 +1,113 @@ +package pubsub + +import ( + "fmt" + "github.com/gobwas/glob" + "io" + "log" + "net" + "slices" + "sync" +) + +// Channel - A channel can be subscribed to directly, or via a consumer group. +// All direct subscribers to the channel will receive any message published to the channel. +// Only one subscriber of a channel's consumer group will receive a message posted to the channel. +type Channel struct { + name string + pattern glob.Glob + subscribersRWMut sync.RWMutex + subscribers []*net.Conn + messageChan *chan string +} + +func WithName(name string) func(channel *Channel) { + return func(channel *Channel) { + channel.name = name + } +} + +func WithPattern(pattern string) func(channel *Channel) { + return func(channel *Channel) { + channel.name = pattern + channel.pattern = glob.MustCompile(pattern) + } +} + +func NewChannel(options ...func(channel *Channel)) *Channel { + messageChan := make(chan string, 4096) + + channel := &Channel{ + name: "", + pattern: nil, + subscribersRWMut: sync.RWMutex{}, + subscribers: []*net.Conn{}, + messageChan: &messageChan, + } + + for _, option := range options { + option(channel) + } + + return channel +} + +func (ch *Channel) Start() { + go func() { + for { + message := <-*ch.messageChan + + ch.subscribersRWMut.RLock() + + for _, conn := range ch.subscribers { + go func(conn *net.Conn) { + w := io.Writer(*conn) + + if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(message), message))); err != nil { + log.Println(err) + } + }(conn) + } + + ch.subscribersRWMut.RUnlock() + } + }() +} + +func (ch *Channel) Subscribe(conn *net.Conn) { + if !slices.Contains(ch.subscribers, conn) { + ch.subscribersRWMut.Lock() + defer ch.subscribersRWMut.Unlock() + + ch.subscribers = append(ch.subscribers, conn) + } +} + +func (ch *Channel) Unsubscribe(conn *net.Conn) bool { + ch.subscribersRWMut.Lock() + defer ch.subscribersRWMut.Unlock() + + var removed bool + + ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool { + if c == conn { + removed = true + return true + } + return false + }) + + return removed +} + +func (ch *Channel) Publish(message string) { + *ch.messageChan <- message +} + +func (ch *Channel) IsActive() bool { + return len(ch.subscribers) > 0 +} + +func (ch *Channel) NumSubs() int { + return len(ch.subscribers) +} diff --git a/src/modules/pubsub/commands.go b/src/modules/pubsub/commands.go index 1b550dc5..5ed322b9 100644 --- a/src/modules/pubsub/commands.go +++ b/src/modules/pubsub/commands.go @@ -3,48 +3,56 @@ package pubsub import ( "context" "errors" + "fmt" "github.com/echovault/echovault/src/utils" "net" + "strings" ) func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { pubsub, ok := server.GetPubSub().(*PubSub) if !ok { - return nil, errors.New("could not load pubsub") + return nil, errors.New("could not load pubsub module") } - switch len(cmd) { - case 2: - // Subscribe to specified channel - pubsub.Subscribe(ctx, conn, cmd[1], nil) - case 3: - // Subscribe to specified channel and specified consumer group - pubsub.Subscribe(ctx, conn, cmd[1], cmd[2]) - default: + + channels := cmd[1:] + + if len(channels) == 0 { return nil, errors.New(utils.WRONG_ARGS_RESPONSE) } - return []byte("+SUBSCRIBE_OK\r\n\r\n"), nil + + switch strings.ToLower(cmd[0]) { + case "subscribe": + return pubsub.Subscribe(ctx, conn, channels, false), nil + case "psubscribe": + return pubsub.Subscribe(ctx, conn, channels, true), nil + } + + return []byte{}, nil } func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { pubsub, ok := server.GetPubSub().(*PubSub) if !ok { - return nil, errors.New("could not load pubsub") + return nil, errors.New("could not load pubsub module") } - switch len(cmd) { - case 1: - pubsub.Unsubscribe(ctx, conn, nil) - case 2: - pubsub.Unsubscribe(ctx, conn, cmd[1]) + + channels := cmd[1:] + + switch strings.ToLower(cmd[0]) { + case "unsubscribe": + return pubsub.Unsubscribe(ctx, conn, channels, false), nil + case "punsubscribe": + return pubsub.Unsubscribe(ctx, conn, channels, true), nil default: - return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + return []byte{}, nil } - return []byte(utils.OK_RESPONSE), nil } func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { pubsub, ok := server.GetPubSub().(*PubSub) if !ok { - return nil, errors.New("could not load pubsub") + return nil, errors.New("could not load pubsub module") } if len(cmd) != 3 { return nil, errors.New(utils.WRONG_ARGS_RESPONSE) @@ -53,49 +61,149 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn return []byte(utils.OK_RESPONSE), nil } +func handlePubSubChannels(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { + if len(cmd) > 3 { + return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + } + + pubsub, ok := server.GetPubSub().(*PubSub) + if !ok { + return nil, errors.New("could not load pubsub module") + } + + pattern := "" + if len(cmd) == 3 { + pattern = cmd[2] + } + + return pubsub.Channels(ctx, pattern), nil +} + +func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { + pubsub, ok := server.GetPubSub().(*PubSub) + if !ok { + return nil, errors.New("could not load pubsub module") + } + num := pubsub.NumPat(ctx) + return []byte(fmt.Sprintf(":%d\r\n", num)), nil +} + +func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { + pubsub, ok := server.GetPubSub().(*PubSub) + if !ok { + return nil, errors.New("could not load pubsub module") + } + return pubsub.NumSub(ctx, cmd[2:]), nil +} + func Commands() []utils.Command { return []utils.Command{ { - Command: "publish", - Categories: []string{utils.PubSubCategory, utils.FastCategory}, - Description: "(PUBLISH channel message) Publish a message to the specified channel.", - Sync: true, + Command: "subscribe", + Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, + Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.", + Sync: false, KeyExtractionFunc: func(cmd []string) ([]string, error) { - // Treat the channel as a key - if len(cmd) != 3 { + // Treat the channels as keys + if len(cmd) < 2 { return nil, errors.New(utils.WRONG_ARGS_RESPONSE) } - return []string{cmd[1]}, nil + return cmd[1:], nil }, - HandlerFunc: handlePublish, + HandlerFunc: handleSubscribe, }, { - Command: "subscribe", + Command: "psubscribe", Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, - Description: "(SUBSCRIBE channel [consumer_group]) Subscribe to a channel with an option to join a consumer group on the channel.", + Description: "(PSUBSCRIBE pattern [pattern ...]) Subscribe to one or more glob patterns.", Sync: false, KeyExtractionFunc: func(cmd []string) ([]string, error) { - // Treat the channel as a key + // Treat the patterns as keys if len(cmd) < 2 { return nil, errors.New(utils.WRONG_ARGS_RESPONSE) } - return []string{cmd[1]}, nil + return cmd[1:], nil }, HandlerFunc: handleSubscribe, }, { - Command: "unsubscribe", - Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, - Description: "(UNSUBSCRIBE channel) Unsubscribe from a channel.", - Sync: false, + Command: "publish", + Categories: []string{utils.PubSubCategory, utils.FastCategory}, + Description: "(PUBLISH channel message) Publish a message to the specified channel.", + Sync: true, KeyExtractionFunc: func(cmd []string) ([]string, error) { // Treat the channel as a key - if len(cmd) != 2 { + if len(cmd) != 3 { return nil, errors.New(utils.WRONG_ARGS_RESPONSE) } return []string{cmd[1]}, nil }, + HandlerFunc: handlePublish, + }, + { + Command: "unsubscribe", + Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, + Description: `(UNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels. +If the channel list is not provided, then the connection will be unsubscribed from all the channels that +it's currently subscribe to.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { + // Treat the channels as keys + return cmd[1:], nil + }, HandlerFunc: handleUnsubscribe, }, + { + Command: "punsubscribe", + Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory}, + Description: `(PUNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels using patterns. +If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that +it's currently subscribe to.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { + // Treat the channels as keys + return cmd[1:], nil + }, + HandlerFunc: handleUnsubscribe, + }, + { + Command: "pubsub", + Categories: []string{}, + Description: "", + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: func(_ context.Context, _ []string, _ utils.Server, _ *net.Conn) ([]byte, error) { + return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand") + }, + SubCommands: []utils.SubCommand{ + { + Command: "channels", + Categories: []string{utils.PubSubCategory, utils.SlowCategory}, + Description: `(PUBSUB CHANNELS [pattern]) Returns an array containing the list of channels that +match the given pattern. If no pattern is provided, all active channels are returned. Active channels are +channels with 1 or more subscribers.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: handlePubSubChannels, + }, + { + Command: "numpat", + Categories: []string{utils.PubSubCategory, utils.SlowCategory}, + Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, + HandlerFunc: handlePubSubNumPat, + }, + { + Command: "numsub", + Categories: []string{utils.PubSubCategory, utils.SlowCategory}, + Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided +channel name and how many clients are currently subscribed to the channel.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) ([]string, error) { return cmd[2:], nil }, + HandlerFunc: handlePubSubNumSubs, + }, + }, + }, } } diff --git a/src/modules/pubsub/pubsub.go b/src/modules/pubsub/pubsub.go index 82fed04d..433e4477 100644 --- a/src/modules/pubsub/pubsub.go +++ b/src/modules/pubsub/pubsub.go @@ -1,308 +1,203 @@ package pubsub import ( - "bytes" - "container/ring" "context" "fmt" - "github.com/echovault/echovault/src/utils" - "io" + "github.com/gobwas/glob" "net" "slices" "sync" - "time" ) -// ConsumerGroup allows multiple subscribers to share the consumption load of a channel. -// Only one subscriber in the consumer group will receive messages published to the channel. -type ConsumerGroup struct { - name string - subscribersRWMut sync.RWMutex - subscribers *ring.Ring - messageChan *chan string +// PubSub container +type PubSub struct { + channels []*Channel + channelsRWMut sync.RWMutex } -func NewConsumerGroup(name string) *ConsumerGroup { - messageChan := make(chan string) - - return &ConsumerGroup{ - name: name, - subscribersRWMut: sync.RWMutex{}, - subscribers: nil, - messageChan: &messageChan, +func NewPubSub() *PubSub { + return &PubSub{ + channels: []*Channel{}, + channelsRWMut: sync.RWMutex{}, } } -func (cg *ConsumerGroup) SendMessage(message string) { - cg.subscribersRWMut.RLock() +func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { + res := fmt.Sprintf("*%d\r\n", len(channels)) - conn := cg.subscribers.Value.(*net.Conn) + for i := 0; i < len(channels); i++ { + // Check if channel with given name exists + // If it does, subscribe the connection to the channel + // If it does not, create the channel and subscribe to it + channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool { + return channel.name == channels[i] + }) - cg.subscribersRWMut.RUnlock() - - w, r := io.Writer(*conn), io.Reader(*conn) - - if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))); err != nil { - // TODO: Log error at configured logger - fmt.Println(err) - } - // Wait for an ACK - // If no ACK is received within a time limit, remove this connection from subscribers and retry - if err := (*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)); err != nil { - // TODO: Log error at configured logger - fmt.Println(err) - } - if msg, err := utils.ReadMessage(r); err != nil { - // Remove the connection from subscribers list - cg.Unsubscribe(conn) - // Reset the deadline - if err := (*conn).SetReadDeadline(time.Time{}); err != nil { - // TODO: Log error at configured logger - fmt.Println(err) - } - // Retry sending the message - cg.SendMessage(message) - } else { - if !bytes.Equal(bytes.TrimSpace(msg), []byte("+ACK")) { - cg.Unsubscribe(conn) - if err := (*conn).SetReadDeadline(time.Time{}); err != nil { - // TODO: Log error at configured logger - fmt.Println(err) + if channelIdx == -1 { + // Create new channel, start it, and subscribe to it + var newChan *Channel + if withPattern { + newChan = NewChannel(WithPattern(channels[i])) + } else { + newChan = NewChannel(WithName(channels[i])) } - cg.SendMessage(message) + newChan.Start() + newChan.Subscribe(conn) + ps.channels = append(ps.channels, newChan) + } else { + // Subscribe to existing channel + ps.channels[channelIdx].Subscribe(conn) } - } - if err := (*conn).SetDeadline(time.Time{}); err != nil { - // TODO: Log error at configured logger - fmt.Println(err) + if len(channels) > 1 { + // If subscribing to more than one channel, write array to verify the subscription of this channel + res += fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1) + } else { + // Ony one channel, simply send "subscribe" simple string response + res = "+subscribe\r\n" + } } - cg.subscribers = cg.subscribers.Next() -} -func (cg *ConsumerGroup) Start() { - go func() { - for { - message := <-*cg.messageChan - if cg.subscribers != nil { - cg.SendMessage(message) - } - } - }() + return []byte(res) } -func (cg *ConsumerGroup) Subscribe(conn *net.Conn) { - cg.subscribersRWMut.Lock() - defer cg.subscribersRWMut.Unlock() - - r := ring.New(1) - for i := 0; i < r.Len(); i++ { - r.Value = conn - r = r.Next() - } +func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte { + ps.channelsRWMut.RLock() + ps.channelsRWMut.RUnlock() - if cg.subscribers == nil { - cg.subscribers = r - return + action := "unsubscribe" + if withPattern { + action = "subscribe" } - cg.subscribers = cg.subscribers.Link(r) -} - -func (cg *ConsumerGroup) Unsubscribe(conn *net.Conn) { - cg.subscribersRWMut.Lock() - defer cg.subscribersRWMut.Unlock() + unsubscribed := make(map[int]string) + count := 1 - // If length is 1 and the connection passed is the one contained within, unlink it - if cg.subscribers.Len() == 1 { - if cg.subscribers.Value == conn { - cg.subscribers = nil + // If the channels slice is empty, unsubscribe from all channels. + if len(channels) <= 0 { + for _, channel := range ps.channels { + if channel.Unsubscribe(conn) { + unsubscribed[1] = channel.name + count += 1 + } } - return } - for i := 0; i < cg.subscribers.Len(); i++ { - if cg.subscribers.Value == conn { - cg.subscribers = cg.subscribers.Prev() - cg.subscribers.Unlink(1) - break + // If withPattern is false, unsubscribe from channels where the name exactly matches channel name. + if !withPattern { + for _, channel := range ps.channels { // For each channel in PubSub + for _, c := range channels { // For each channel name provided + if channel.name == c && channel.Unsubscribe(conn) { + unsubscribed[count] = channel.name + count += 1 + } + } } - cg.subscribers = cg.subscribers.Next() - } -} - -func (cg *ConsumerGroup) Publish(message string) { - *cg.messageChan <- message -} - -// Channel - A channel can be subscribed to directly, or via a consumer group. -// All direct subscribers to the channel will receive any message published to the channel. -// Only one subscriber of a channel's consumer group will receive a message posted to the channel. -type Channel struct { - name string - subscribersRWMut sync.RWMutex - subscribers []*net.Conn - consumerGroups []*ConsumerGroup - messageChan *chan string -} - -func NewChannel(name string) *Channel { - messageChan := make(chan string) - - return &Channel{ - name: name, - subscribersRWMut: sync.RWMutex{}, - subscribers: []*net.Conn{}, - consumerGroups: []*ConsumerGroup{}, - messageChan: &messageChan, } -} - -func (ch *Channel) Start() { - go func() { - for { - message := <-*ch.messageChan - - ch.subscribersRWMut.RLock() - - for _, conn := range ch.subscribers { - go func(conn *net.Conn) { - w, r := io.Writer(*conn), io.Reader(*conn) - if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(message), message))); err != nil { - // TODO: Log error at configured logger - fmt.Println(err) - } - - if err := (*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil { - // TODO: Log error at configured logger - fmt.Println(err) - ch.Unsubscribe(conn) - } - defer func() { - if err := (*conn).SetReadDeadline(time.Time{}); err != nil { - // TODO: Log error at configured logger - fmt.Println(err) - ch.Unsubscribe(conn) - } - }() - - if msg, err := utils.ReadMessage(r); err != nil { - ch.Unsubscribe(conn) - } else { - if !bytes.EqualFold(bytes.TrimSpace(msg), []byte("+ACK")) { - ch.Unsubscribe(conn) - } - } - }(conn) + // If withPattern is true, unsubscribe from channels where pattern matches pattern provided, + // also unsubscribe from channels where the name matches the given pattern. + if withPattern { + for _, pattern := range channels { + g := glob.MustCompile(pattern) + for _, channel := range ps.channels { + // If it's a pattern channel, directly compare the patterns + if channel.pattern != nil && channel.name == pattern { + unsubscribed[count] = channel.name + count += 1 + continue + } + // If this is a regular channel, check if the channel name matches the pattern given + if g.Match(channel.name) { + unsubscribed[count] = channel.name + count += 1 + } } - - ch.subscribersRWMut.RUnlock() } - }() -} - -func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) { - if consumerGroupName == nil && !slices.Contains(ch.subscribers, conn) { - ch.subscribersRWMut.Lock() - defer ch.subscribersRWMut.Unlock() - ch.subscribers = append(ch.subscribers, conn) - return } - groups := utils.Filter[*ConsumerGroup](ch.consumerGroups, func(group *ConsumerGroup) bool { - return group.name == consumerGroupName.(string) - }) - - if len(groups) == 0 { - go func() { - newGroup := NewConsumerGroup(consumerGroupName.(string)) - newGroup.Start() - newGroup.Subscribe(conn) - ch.consumerGroups = append(ch.consumerGroups, newGroup) - }() - return + res := fmt.Sprintf("*%d\r\n", len(unsubscribed)) + for key, value := range unsubscribed { + res += fmt.Sprintf("*3\r\n+%s\r\n$%d\r\n%s\r\n:%d\r\n", action, len(value), value, key) } - for _, group := range groups { - go group.Subscribe(conn) - } + return []byte(res) } -func (ch *Channel) Unsubscribe(conn *net.Conn) { - ch.subscribersRWMut.Lock() - defer ch.subscribersRWMut.Unlock() - - ch.subscribers = utils.Filter[*net.Conn](ch.subscribers, func(c *net.Conn) bool { - return c != conn - }) - - for _, group := range ch.consumerGroups { - go group.Unsubscribe(conn) +func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) { + ps.channelsRWMut.RLock() + defer ps.channelsRWMut.RUnlock() + + for _, channel := range ps.channels { + // If it's a regular channel, check if the channel name matches the name given + if channel.pattern == nil { + if channel.name == channelName { + channel.Publish(message) + } + continue + } + // If it's a glob pattern channel, check if the name matches the pattern + if channel.pattern.Match(channelName) { + channel.Publish(message) + } } } -func (ch *Channel) Publish(message string) { - for _, group := range ch.consumerGroups { - go group.Publish(message) - } - *ch.messageChan <- message -} +func (ps *PubSub) Channels(ctx context.Context, pattern string) []byte { + var count int + var res string -// PubSub container -type PubSub struct { - channels []*Channel -} + if pattern == "" { + for _, channel := range ps.channels { + if channel.IsActive() { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name) + count += 1 + } + } -func NewPubSub() *PubSub { - return &PubSub{ - channels: []*Channel{}, + res = fmt.Sprintf("*%d\r\n%s", count, res) + return []byte(res) } -} -func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, consumerGroup interface{}) { - // Check if channel with given name exists - // If it does, subscribe the connection to the channel - // If it does not, create the channel and subscribe to it - channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool { - return channel.name == channelName - }) + g := glob.MustCompile(pattern) - if channelIdx == -1 { - go func() { - newChan := NewChannel(channelName) - newChan.Start() - newChan.Subscribe(conn, consumerGroup) - ps.channels = append(ps.channels, newChan) - }() - return + for _, channel := range ps.channels { + // If channel is a pattern channel, then directly compare the channel name to pattern + if channel.pattern != nil && channel.name == pattern && channel.IsActive() { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name) + count += 1 + continue + } + if g.Match(channel.name) && channel.IsActive() { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name) + count += 1 + } } - go ps.channels[channelIdx].Subscribe(conn, consumerGroup) + return []byte(res) } -func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) { - if channelName == nil { - for _, channel := range ps.channels { - go channel.Unsubscribe(conn) +func (ps *PubSub) NumPat(ctx context.Context) int { + var count int + for _, channel := range ps.channels { + if channel.pattern != nil { + count += 1 } - return - } - - channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool { - return c.name == channelName - }) - - for _, channel := range channels { - go channel.Unsubscribe(conn) } + return count } -func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) { - channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool { - return c.name == channelName - }) +func (ps *PubSub) NumSub(ctx context.Context, channels []string) []byte { + res := fmt.Sprintf("*%d\r\n", len(channels)) for _, channel := range channels { - go channel.Publish(message) + chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool { + return c.name == channel + }) + if chanIdx == -1 { + res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:0\r\n", len(channel), channel) + continue + } + res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:%d\r\n", len(channel), channel, ps.channels[chanIdx].NumSubs()) } + return []byte(res) } diff --git a/src/modules/set/commands.go b/src/modules/set/commands.go index c425fbee..e3699ab5 100644 --- a/src/modules/set/commands.go +++ b/src/modules/set/commands.go @@ -27,7 +27,7 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne } server.SetValue(ctx, key, set) server.KeyUnlock(key) - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(cmd[2:]))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(cmd[2:]))), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -42,7 +42,7 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne count := set.Add(cmd[2:]) - return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil + return []byte(fmt.Sprintf(":%d\r\n", count)), nil } func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -54,7 +54,7 @@ func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *n key := keys[0] if !server.KeyExists(key) { - return []byte(fmt.Sprintf(":0\r\n\r\n")), nil + return []byte(fmt.Sprintf(":0\r\n")), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -69,7 +69,7 @@ func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *n cardinality := set.Cardinality() - return []byte(fmt.Sprintf(":%d\r\n\r\n", cardinality)), nil + return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil } func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -126,7 +126,7 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n for i, e := range elems { res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(e), e) if i == len(elems)-1 { - res += "\r\n\r\n" + res += "\r\n" } } @@ -185,7 +185,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co diff := baseSet.Subtract(sets) elems := diff.GetAll() - res := fmt.Sprintf(":%d\r\n\r\n", len(elems)) + res := fmt.Sprintf(":%d\r\n", len(elems)) if server.KeyExists(destination) { if _, err = server.KeyLock(ctx, destination); err != nil { @@ -223,7 +223,7 @@ func handleSINTER(ctx context.Context, cmd []string, server utils.Server, conn * for _, key := range keys[0:] { if !server.KeyExists(key) { // If key does not exist, then there is no intersection - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err @@ -253,7 +253,7 @@ func handleSINTER(ctx context.Context, cmd []string, server utils.Server, conn * for i, e := range elems { res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(e), e) if i == len(elems)-1 { - res += "\r\n\r\n" + res += "\r\n" } } @@ -299,7 +299,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server utils.Server, co for _, key := range keys { if !server.KeyExists(key) { // If key does not exist, then there is no intersection - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err @@ -324,7 +324,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server utils.Server, co intersect, _ := Intersection(limit, sets...) - return []byte(fmt.Sprintf(":%d\r\n\r\n", intersect.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil } func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -345,7 +345,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c for _, key := range keys[1:] { if !server.KeyExists(key) { // If key does not exist, then there is no intersection - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err @@ -380,7 +380,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c server.SetValue(ctx, destination, intersect) server.KeyUnlock(destination) - return []byte(fmt.Sprintf(":%d\r\n\r\n", intersect.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil } func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -392,7 +392,7 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, con key := keys[0] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -406,10 +406,10 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, con } if !set.Contains(cmd[2]) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } - return []byte(":1\r\n\r\n"), nil + return []byte(":1\r\n"), nil } func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -421,7 +421,7 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn key := keys[0] if !server.KeyExists(key) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -440,7 +440,7 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn for i, e := range elems { res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(e), e) if i == len(elems)-1 { - res += "\r\n\r\n" + res += "\r\n" } } @@ -461,7 +461,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co for i, _ := range members { res = fmt.Sprintf("%s\r\n:0", res) if i == len(members)-1 { - res += "\r\n\r\n" + res += "\r\n" } } return []byte(res), nil @@ -485,7 +485,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co res += "\r\n:0" } } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -501,7 +501,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n member := cmd[3] if !server.KeyExists(source) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, source); err != nil { @@ -539,7 +539,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n res := sourceSet.Move(destinationSet, member) - return []byte(fmt.Sprintf(":%d\r\n\r\n", res)), nil + return []byte(fmt.Sprintf(":%d\r\n", res)), nil } func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -560,7 +560,7 @@ func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne } if !server.KeyExists(key) { - return []byte("*-1\r\n\r\n"), nil + return []byte("*-1\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -579,7 +579,7 @@ func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne for i, m := range members { res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(m), m) if i == len(members)-1 { - res += "\r\n\r\n" + res += "\r\n" } } @@ -604,7 +604,7 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c } if !server.KeyExists(key) { - return []byte("*-1\r\n\r\n"), nil + return []byte("*-1\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -623,7 +623,7 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c for i, m := range members { res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(m), m) if i == len(members)-1 { - res += "\r\n\r\n" + res += "\r\n" } } @@ -640,7 +640,7 @@ func handleSREM(ctx context.Context, cmd []string, server utils.Server, conn *ne members := cmd[2:] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -655,7 +655,7 @@ func handleSREM(ctx context.Context, cmd []string, server utils.Server, conn *ne count := set.Remove(members) - return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil + return []byte(fmt.Sprintf(":%d\r\n", count)), nil } func handleSUNION(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -702,7 +702,7 @@ func handleSUNION(ctx context.Context, cmd []string, server utils.Server, conn * for i, e := range union.GetAll() { res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(e), e) if i == len(union.GetAll())-1 { - res += "\r\n\r\n" + res += "\r\n" } } @@ -763,7 +763,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c defer server.KeyUnlock(destination) server.SetValue(ctx, destination, union) - return []byte(fmt.Sprintf(":%d\r\n\r\n", union.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil } func Commands() []utils.Command { diff --git a/src/modules/set/set.go b/src/modules/set/set.go index 379e225e..75bcbb26 100644 --- a/src/modules/set/set.go +++ b/src/modules/set/set.go @@ -75,8 +75,8 @@ func (set *Set) GetRandom(count int) []string { n = rand.Intn(len(keys)) if !slices.Contains(res, keys[n]) { res = append(res, keys[n]) - keys = utils.Filter(keys, func(elem string) bool { - return elem != keys[n] + keys = slices.DeleteFunc(keys, func(elem string) bool { + return elem == keys[n] }) i++ } diff --git a/src/modules/sorted_set/commands.go b/src/modules/sorted_set/commands.go index 1a0abeb6..b08ce3c5 100644 --- a/src/modules/sorted_set/commands.go +++ b/src/modules/sorted_set/commands.go @@ -128,7 +128,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne if server.KeyExists(key) { // Key exists - _, err := server.KeyLock(ctx, key) + _, err = server.KeyLock(ctx, key) if err != nil { return nil, err } @@ -144,10 +144,10 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne // If INCR option is provided, return the new score value if incr != nil { m := set.Get(members[0].value) - return []byte(fmt.Sprintf("+%f\r\n\r\n", m.score)), nil + return []byte(fmt.Sprintf("+%f\r\n", m.score)), nil } - return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil + return []byte(fmt.Sprintf(":%d\r\n", count)), nil } // Key does not exist @@ -159,7 +159,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne set := NewSortedSet(members) server.SetValue(ctx, key, set) - return []byte(fmt.Sprintf(":%d\r\n\r\n", set.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil } func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -170,7 +170,7 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *n key := keys[0] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -183,7 +183,7 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, fmt.Errorf("value at %s is not a sorted set", key) } - return []byte(fmt.Sprintf(":%d\r\n\r\n", set.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil } func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -231,7 +231,7 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn * } if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -251,7 +251,7 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn * } } - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(members))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(members))), nil } func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -265,7 +265,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con maximum := cmd[3] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -283,7 +283,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con // Check if all members has the same score for i := 0; i < len(members)-2; i++ { if members[i].score != members[i+1].score { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } } @@ -296,7 +296,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con } } - return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil + return []byte(fmt.Sprintf(":%d\r\n", count)), nil } func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -324,7 +324,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n // Extract base set if !server.KeyExists(keys[0]) { // If base set does not exist, return an empty array - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, keys[0]); err != nil { return nil, err @@ -367,7 +367,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n } } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -392,7 +392,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co // Extract base set if !server.KeyExists(keys[0]) { // If base set does not exist, return 0 - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, keys[0]); err != nil { return nil, err @@ -433,7 +433,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co server.SetValue(ctx, destination, diff) - return []byte(fmt.Sprintf(":%d\r\n\r\n", diff.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", diff.Cardinality())), nil } func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -473,7 +473,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn } server.SetValue(ctx, key, NewSortedSet([]MemberParam{{value: member, score: increment}})) server.KeyUnlock(key) - return []byte(fmt.Sprintf("+%s\r\n\r\n", strconv.FormatFloat(float64(increment), 'f', -1, 64))), nil + return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(float64(increment), 'f', -1, 64))), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -493,7 +493,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn "incr"); err != nil { return nil, err } - return []byte(fmt.Sprintf("+%s\r\n\r\n", + return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(float64(set.Get(member).score), 'f', -1, 64))), nil } @@ -522,7 +522,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * for i := 0; i < len(keys); i++ { if !server.KeyExists(keys[i]) { // If any of the keys is non-existent, return an empty array as there's no intersect - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err @@ -552,7 +552,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * } } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -588,7 +588,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c for i := 0; i < len(keys); i++ { if !server.KeyExists(keys[i]) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err @@ -619,7 +619,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c server.SetValue(ctx, destination, intersect) - return []byte(fmt.Sprintf(":%d\r\n\r\n", intersect.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil } func handleZMPOP(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -691,13 +691,13 @@ func handleZMPOP(ctx context.Context, cmd []string, server utils.Server, conn *n res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64)) } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } } - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -723,7 +723,7 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne } if !server.KeyExists(key) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -746,7 +746,7 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64)) } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -760,7 +760,7 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.Server, conn key := keys[0] if !server.KeyExists(key) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -788,7 +788,7 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.Server, conn } } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -819,7 +819,7 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c } if !server.KeyExists(key) { - return []byte("$-1\r\n\r\n"), nil + return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -843,7 +843,7 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c } } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -863,7 +863,7 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.Server, conn *n } if !server.KeyExists(key) { - return []byte("$-1\r\n\r\n"), nil + return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -888,14 +888,14 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.Server, conn *n if members[i].value == Value(member) { if withscores { score := strconv.FormatFloat(float64(members[i].score), 'f', -1, 64) - return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n\r\n", i, len(score), score)), nil + return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n", i, len(score), score)), nil } else { - return []byte(fmt.Sprintf("*1\r\n:%d\r\n\r\n", i)), nil + return []byte(fmt.Sprintf("*1\r\n:%d\r\n", i)), nil } } } - return []byte("$-1\r\n\r\n"), nil + return []byte("$-1\r\n"), nil } func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -907,7 +907,7 @@ func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *ne key := keys[0] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -927,7 +927,7 @@ func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *ne } } - return []byte(fmt.Sprintf(":%d\r\n\r\n", deletedCount)), nil + return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil } func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -939,7 +939,7 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn * key := keys[0] if !server.KeyExists(key) { - return []byte("$-1\r\n\r\n"), nil + return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err @@ -951,12 +951,12 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn * } member := set.Get(Value(cmd[2])) if !member.exists { - return []byte("$-1\r\n\r\n"), nil + return []byte("$-1\r\n"), nil } score := strconv.FormatFloat(float64(member.score), 'f', -1, 64) - return []byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(score), score)), nil + return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(score), score)), nil } func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -980,7 +980,7 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Serv } if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -1000,7 +1000,7 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Serv } } - return []byte(fmt.Sprintf(":%d\r\n\r\n", deletedCount)), nil + return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil } func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -1022,7 +1022,7 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Serve } if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -1065,7 +1065,7 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Serve } } - return []byte(fmt.Sprintf(":%d\r\n\r\n", deletedCount)), nil + return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil } func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -1079,7 +1079,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server maximum := cmd[3] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { @@ -1097,7 +1097,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server // Check if all the members have the same score. If not, return 0 for i := 0; i < len(members)-1; i++ { if members[i].score != members[i+1].score { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } } @@ -1112,7 +1112,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server } } - return []byte(fmt.Sprintf(":%d\r\n\r\n", deletedCount)), nil + return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil } func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -1177,7 +1177,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn * } if !server.KeyExists(key) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -1191,7 +1191,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn * } if offset > set.Cardinality() { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if count < 0 { count = set.Cardinality() - offset @@ -1211,7 +1211,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn * // If policy is BYLEX, all the elements must have the same score for i := 0; i < len(members)-1; i++ { if members[i].score != members[i+1].score { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } } slices.SortFunc(members, func(a, b MemberParam) int { @@ -1241,9 +1241,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn * } res := fmt.Sprintf("*%d", len(resultMembers)) - if len(resultMembers) == 0 { - res += "\r\n\r\n" - } + for _, m := range resultMembers { if withscores { res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64)) @@ -1252,7 +1250,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn * } } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -1316,7 +1314,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c } if !server.KeyExists(source) { - return []byte("*0\r\n\r\n"), nil + return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, source); err != nil { @@ -1330,7 +1328,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c } if offset > set.Cardinality() { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if count < 0 { count = set.Cardinality() - offset @@ -1350,7 +1348,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c // If policy is BYLEX, all the elements must have the same score for i := 0; i < len(members)-1; i++ { if members[i].score != members[i+1].score { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } } slices.SortFunc(members, func(a, b MemberParam) int { @@ -1394,7 +1392,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c server.SetValue(ctx, destination, newSortedSet) - return []byte(fmt.Sprintf(":%d\r\n\r\n", newSortedSet.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", newSortedSet.Cardinality())), nil } func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -1446,7 +1444,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn * } } - res += "\r\n\r\n" + res += "\r\n" return []byte(res), nil } @@ -1512,7 +1510,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c server.SetValue(ctx, destination, union) - return []byte(fmt.Sprintf(":%d\r\n\r\n", union.Cardinality())), nil + return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil } func Commands() []utils.Command { diff --git a/src/modules/string/commands.go b/src/modules/string/commands.go index 3ebe060b..296eac3b 100644 --- a/src/modules/string/commands.go +++ b/src/modules/string/commands.go @@ -29,7 +29,7 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn } server.SetValue(ctx, key, newStr) server.KeyUnlock(key) - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(newStr))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil } if _, err := server.KeyLock(ctx, key); err != nil { @@ -46,14 +46,14 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn if offset >= len(str) { newStr = str + newStr server.SetValue(ctx, key, newStr) - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(newStr))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil } // If the offset is < 0, prepend the new string to the old one. if offset < 0 { newStr = newStr + str server.SetValue(ctx, key, newStr) - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(newStr))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil } strRunes := []rune(str) @@ -72,7 +72,7 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn server.SetValue(ctx, key, string(strRunes)) - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(strRunes))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(strRunes))), nil } func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -84,7 +84,7 @@ func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn * key := keys[0] if !server.KeyExists(key) { - return []byte(":0\r\n\r\n"), nil + return []byte(":0\r\n"), nil } if _, err := server.KeyRLock(ctx, key); err != nil { @@ -98,7 +98,7 @@ func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn * return nil, fmt.Errorf("value at key %s is not a string", key) } - return []byte(fmt.Sprintf(":%d\r\n\r\n", len(value))), nil + return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil } func handleSubStr(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { @@ -161,7 +161,7 @@ func handleSubStr(ctx context.Context, cmd []string, server utils.Server, conn * str = res } - return []byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(str), str)), nil + return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(str), str)), nil } func Commands() []utils.Command { diff --git a/src/server/aof/log/store.go b/src/server/aof/log/store.go index 314e094f..a0830e48 100644 --- a/src/server/aof/log/store.go +++ b/src/server/aof/log/store.go @@ -15,7 +15,7 @@ import ( ) type AppendReadWriter interface { - io.ReadWriter + io.ReadWriteSeeker io.Closer Truncate(size int64) error Sync() error @@ -101,7 +101,9 @@ func NewAppendStore(options ...func(store *AppendStore)) *AppendStore { func (store *AppendStore) Write(command []byte) error { store.mut.Lock() defer store.mut.Unlock() - if _, err := store.rw.Write(command); err != nil { + // Add new line before writing to AOF file. + out := append(command, []byte("\r\n")...) + if _, err := store.rw.Write(out); err != nil { return err } if strings.EqualFold(store.strategy, "always") { @@ -160,6 +162,10 @@ func (store *AppendStore) Truncate() error { if err := store.rw.Truncate(0); err != nil { return err } + // Seek to the beginning of the file after truncating + if _, err := store.rw.Seek(0, 0); err != nil { + return err + } return nil } diff --git a/src/server/aof/preamble/store.go b/src/server/aof/preamble/store.go index 2a36d778..08f22651 100644 --- a/src/server/aof/preamble/store.go +++ b/src/server/aof/preamble/store.go @@ -125,6 +125,10 @@ func (store *PreambleStore) Restore() error { return err } + if len(b) <= 0 { + return nil + } + state := make(map[string]interface{}) if err = json.Unmarshal(b, &state); err != nil { diff --git a/src/server/modules.go b/src/server/modules.go index 55b04bad..0dbb704c 100644 --- a/src/server/modules.go +++ b/src/server/modules.go @@ -36,6 +36,8 @@ func (server *Server) handleCommand(ctx context.Context, message []byte, conn *n return nil, err } + fmt.Println(cmd) + command, err := server.getCommand(cmd[0]) if err != nil { return nil, err diff --git a/src/server/server.go b/src/server/server.go index 50023727..cb366473 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -208,6 +208,7 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { if err != nil && errors.Is(err, io.EOF) { // Connection closed + log.Println(err) break } @@ -218,8 +219,12 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { res, err := server.handleCommand(ctx, message, &conn, false) + if err != nil && errors.Is(err, io.EOF) { + break + } + if err != nil { - if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil { + if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil { log.Println(err) } continue @@ -227,11 +232,13 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { chunkSize := 1024 + // If the length of the response is 0, return nothing to the client + if len(res) == 0 { + continue + } + if len(res) <= chunkSize { - _, err = w.Write(res) - if err != nil { - log.Println(err) - } + _, _ = w.Write(res) continue } @@ -240,7 +247,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { for { // If the current start index is less than chunkSize from length, return the remaining bytes. if len(res)-1-startIndex < chunkSize { - _, _ = w.Write(res[startIndex:]) + _, err = w.Write(res[startIndex:]) + if err != nil { + log.Println(err) + } break } n, _ := w.Write(res[startIndex : startIndex+chunkSize]) diff --git a/src/utils/const.go b/src/utils/const.go index b7418c8e..2c7b5c7c 100644 --- a/src/utils/const.go +++ b/src/utils/const.go @@ -25,6 +25,6 @@ const ( ) const ( - OK_RESPONSE = "+OK\r\n\r\n" + OK_RESPONSE = "+OK\r\n" WRONG_ARGS_RESPONSE = "wrong number of arguments" ) diff --git a/src/utils/utils.go b/src/utils/utils.go index d005354c..ed32c13a 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -1,7 +1,9 @@ package utils import ( + "bufio" "bytes" + "errors" "io" "log" "math/big" @@ -32,60 +34,45 @@ func AdaptType(s string) interface{} { return f } -func Filter[T any](arr []T, test func(elem T) bool) (res []T) { - for _, e := range arr { - if test(e) { - res = append(res, e) - } - } - return -} - func Decode(raw []byte) ([]string, error) { - rd := resp.NewReader(bytes.NewBuffer(raw)) - var res []string - - v, _, err := rd.ReadValue() + reader := resp.NewReader(bytes.NewReader(raw)) + value, _, err := reader.ReadValue() if err != nil { return nil, err } - if slices.Contains([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) { - return []string{v.String()}, nil - } - - if v.Type().String() == "Array" { - for _, elem := range v.Array() { - res = append(res, elem.String()) - } + var res []string + for i := 0; i < len(value.Array()); i++ { + res = append(res, value.Array()[i].String()) } return res, nil } func ReadMessage(r io.Reader) ([]byte, error) { - delim := []byte{'\r', '\n', '\r', '\n'} - buffSize := 8 - buff := make([]byte, buffSize) + reader := bufio.NewReader(r) - var n int - var err error var res []byte + chunk := make([]byte, 8192) + for { - n, err = r.Read(buff) - res = append(res, buff...) - if n < buffSize || err != nil { + n, err := reader.Read(chunk) + if err != nil && errors.Is(err, io.EOF) { break } - if bytes.Equal(buff[len(buff)-4:], delim) { + if err != nil { + return nil, err + } + res = append(res, chunk...) + if n < len(chunk) { break } - clear(buff) + clear(chunk) } - return res, err + return bytes.Trim(res, "\x00"), nil } func RetryBackoff(b retry.Backoff, maxRetries uint64, jitter, cappedDuration, maxDuration time.Duration) retry.Backoff {