diff --git a/README.md b/README.md
index 735e5a1c..4c55a538 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
[![Go Coverage](https://github.com/EchoVault/EchoVault/wiki/coverage.svg)](https://raw.githack.com/wiki/EchoVault/EchoVault/coverage.html)
[![GitHub Release](https://img.shields.io/github/v/release/EchoVault/EchoVault)]()
-[![License: GPL v2](https://img.shields.io/badge/License-GPL_v2-blue.svg)](https://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html)
+[![License: GPL v3](https://img.shields.io/badge/License-GPL_v3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0.en.html)
[![Discord](https://img.shields.io/discord/1211815152291414037?style=flat&logo=discord&link=https%3A%2F%2Fdiscord.gg%2Fvt45CKfF)](https://discord.gg/vt45CKfF)
diff --git a/docker-compose.yaml b/docker-compose.yaml
index c4bd72ea..07945174 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -19,7 +19,7 @@ services:
- DATA_DIR=/var/lib/echovault
- IN_MEMORY=false
- TLS=true
- - MTLS=true
+ - MTLS=false
- BOOTSTRAP_CLUSTER=false
- ACL_CONFIG=/etc/config/echovault/acl.yml
- REQUIRE_PASS=true
@@ -36,7 +36,7 @@ services:
# List of client certificate authorities
- CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt
ports:
- - "7479:7480"
+ - "7480:7480"
- "7946:7946"
- "7999:8000"
volumes:
@@ -76,7 +76,7 @@ services:
# List of client certificate authorities
- CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt
ports:
- - "7480:7480"
+ - "7481:7480"
- "7945:7946"
- "8000:8000"
volumes:
@@ -117,7 +117,7 @@ services:
# List of client certificate authorities
- CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt
ports:
- - "7481:7480"
+ - "7482:7480"
- "7947:7946"
- "8001:8000"
volumes:
@@ -158,7 +158,7 @@ services:
# List of client certificate authorities
- CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt
ports:
- - "7482:7480"
+ - "7483:7480"
- "7948:7946"
- "8002:8000"
volumes:
@@ -199,7 +199,7 @@ services:
# List of client certificate authorities
- CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt
ports:
- - "7483:7480"
+ - "7484:7480"
- "7949:7946"
- "8003:8000"
volumes:
@@ -240,7 +240,7 @@ services:
# List of client certificate authorities
- CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt
ports:
- - "7484:7480"
+ - "7485:7480"
- "7950:7946"
- "8004:8000"
volumes:
diff --git a/src/memberlist/delegate.go b/src/memberlist/delegate.go
index 01314cdb..8bf78b9a 100644
--- a/src/memberlist/delegate.go
+++ b/src/memberlist/delegate.go
@@ -7,6 +7,7 @@ import (
"github.com/echovault/echovault/src/utils"
"github.com/hashicorp/memberlist"
"github.com/hashicorp/raft"
+ "log"
"time"
)
@@ -79,14 +80,12 @@ func (delegate *Delegate) NotifyMsg(msgBytes []byte) {
cmd, err := utils.Decode(msg.Content)
if err != nil {
- // TODO: Log error to configured logger
- fmt.Println(err)
+ log.Println(err)
return
}
if _, err := delegate.options.applyMutate(ctx, cmd); err != nil {
- // TODO: Log error to configured logger
- fmt.Println(err)
+ log.Println(err)
}
}
}
diff --git a/src/modules/acl/acl.go b/src/modules/acl/acl.go
index 7ec8f40f..0b08a2e3 100644
--- a/src/modules/acl/acl.go
+++ b/src/modules/acl/acl.go
@@ -105,9 +105,10 @@ func NewACL(config utils.Config) *ACL {
func (acl *ACL) RegisterConnection(conn *net.Conn) {
// This is called only when a connection is established.
- defaultUser := utils.Filter(acl.Users, func(elem *User) bool {
- return elem.Username == "default"
- })[0]
+ defaultUserIdx := slices.IndexFunc(acl.Users, func(user *User) bool {
+ return user.Username == "default"
+ })
+ defaultUser := acl.Users[defaultUserIdx]
acl.Connections[conn] = Connection{
Authenticated: defaultUser.NoPassword,
User: defaultUser,
@@ -167,7 +168,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
}
}
// Delete the user from the ACL
- acl.Users = utils.Filter(acl.Users, func(u *User) bool {
+ acl.Users = slices.DeleteFunc(acl.Users, func(u *User) bool {
return u.Username != user.Username
})
}
@@ -188,9 +189,10 @@ func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd
{PasswordType: "SHA256", PasswordValue: string(h.Sum(nil))},
}
// Authenticate with default user
- user = utils.Filter(acl.Users, func(user *User) bool {
+ idx := slices.IndexFunc(acl.Users, func(user *User) bool {
return user.Username == "default"
- })[0]
+ })
+ user = acl.Users[idx]
}
if len(cmd) == 3 {
// Process AUTH
diff --git a/src/modules/acl/commands.go b/src/modules/acl/commands.go
index 6db9ccaa..ae4764e8 100644
--- a/src/modules/acl/commands.go
+++ b/src/modules/acl/commands.go
@@ -139,7 +139,7 @@ func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn
res = res + fmt.Sprintf("\r\n+-&%s", channel)
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -179,7 +179,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net
for i, cat := range cats {
res = fmt.Sprintf("%s\r\n+%s", res, cat)
if i == len(cats)-1 {
- res = res + "\r\n\r\n"
+ res = res + "\r\n"
}
}
return []byte(res), nil
@@ -193,7 +193,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net
for i, command := range commands {
res = fmt.Sprintf("%s\r\n+%s", res, command)
if i == len(commands)-1 {
- res = res + "\r\n\r\n"
+ res = res + "\r\n"
}
}
return []byte(res), nil
@@ -213,7 +213,7 @@ func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *n
for _, user := range acl.Users {
res += fmt.Sprintf("\r\n$%d\r\n%s", len(user.Username), user.Username)
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -248,7 +248,7 @@ func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *
return nil, errors.New("could not load ACL")
}
connectionInfo := acl.Connections[conn]
- return []byte(fmt.Sprintf("+%s\r\n\r\n", connectionInfo.User.Username)), nil
+ return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
}
func handleList(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -343,7 +343,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne
res = res + fmt.Sprintf("\r\n$%d\r\n%s", len(s), s)
}
- res = res + "\r\n\r\n"
+ res = res + "\r\n"
return []byte(res), nil
}
diff --git a/src/modules/acl/user.go b/src/modules/acl/user.go
index 1d7dcddb..51370356 100644
--- a/src/modules/acl/user.go
+++ b/src/modules/acl/user.go
@@ -1,7 +1,6 @@
package acl
import (
- "github.com/echovault/echovault/src/utils"
"slices"
"strings"
)
@@ -105,18 +104,18 @@ func (user *User) UpdateUser(cmd []string) error {
continue
}
if str[0] == '<' {
- user.Passwords = utils.Filter(user.Passwords, func(password Password) bool {
+ user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, "SHA256") {
- return true
+ return false
}
return password.PasswordValue == str[1:]
})
continue
}
if str[0] == '!' {
- user.Passwords = utils.Filter(user.Passwords, func(password Password) bool {
+ user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, "plaintext") {
- return true
+ return false
}
return password.PasswordValue == str[1:]
})
diff --git a/src/modules/admin/commands.go b/src/modules/admin/commands.go
index 4e80b00b..7bdca3d5 100644
--- a/src/modules/admin/commands.go
+++ b/src/modules/admin/commands.go
@@ -5,7 +5,10 @@ import (
"errors"
"fmt"
"github.com/echovault/echovault/src/utils"
+ "github.com/gobwas/glob"
"net"
+ "slices"
+ "strings"
)
func handleGetAllCommands(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
@@ -48,11 +51,123 @@ func handleGetAllCommands(ctx context.Context, cmd []string, server utils.Server
}
}
- res = fmt.Sprintf("*%d\r\n%s\r\n", commandCount, res)
+ res = fmt.Sprintf("*%d\r\n%s", commandCount, res)
return []byte(res), nil
}
+func handleCommandCount(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
+ var count int
+
+ commands := server.GetAllCommands(ctx)
+ for _, command := range commands {
+ if command.SubCommands != nil && len(command.SubCommands) > 0 {
+ for _, _ = range command.SubCommands {
+ count += 1
+ }
+ continue
+ }
+ count += 1
+ }
+
+ return []byte(fmt.Sprintf(":%d\r\n", count)), nil
+}
+
+func handleCommandList(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
+ switch len(cmd) {
+ case 2:
+ // Command is COMMAND LIST
+ var count int
+ var res string
+ commands := server.GetAllCommands(ctx)
+ for _, command := range commands {
+ if command.SubCommands != nil && len(command.SubCommands) > 0 {
+ for _, subcommand := range command.SubCommands {
+ comm := fmt.Sprintf("%s %s", command.Command, subcommand.Command)
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(comm), comm)
+ count += 1
+ }
+ continue
+ }
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(command.Command), command.Command)
+ count += 1
+ }
+ res = fmt.Sprintf("*%d\r\n%s", count, res)
+ return []byte(res), nil
+
+ case 5:
+ var count int
+ var res string
+ // Command has filter
+ if !strings.EqualFold("FILTERBY", cmd[2]) {
+ return nil, fmt.Errorf("expected FILTERBY, got %s", strings.ToUpper(cmd[2]))
+ }
+ if strings.EqualFold("ACLCAT", cmd[3]) {
+ // ACL Category filter
+ commands := server.GetAllCommands(ctx)
+ category := strings.ToLower(cmd[4])
+ for _, command := range commands {
+ if command.SubCommands != nil && len(command.SubCommands) > 0 {
+ for _, subcommand := range command.SubCommands {
+ if slices.Contains(subcommand.Categories, category) {
+ comm := fmt.Sprintf("%s %s", command.Command, subcommand.Command)
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(comm), comm)
+ count += 1
+ }
+ }
+ continue
+ }
+ if slices.Contains(command.Categories, category) {
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(command.Command), command.Command)
+ count += 1
+ }
+ }
+ } else if strings.EqualFold("PATTERN", cmd[3]) {
+ // Pattern filter
+ commands := server.GetAllCommands(ctx)
+ g := glob.MustCompile(cmd[4])
+ for _, command := range commands {
+ if command.SubCommands != nil && len(command.SubCommands) > 0 {
+ for _, subcommand := range command.SubCommands {
+ comm := fmt.Sprintf("%s %s", command.Command, subcommand.Command)
+ if g.Match(comm) {
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(comm), comm)
+ count += 1
+ }
+ }
+ continue
+ }
+ if g.Match(command.Command) {
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(command.Command), command.Command)
+ count += 1
+ }
+ }
+ } else {
+ return nil, fmt.Errorf("expected filter to be ACLCAT or PATTERN, got %s", strings.ToUpper(cmd[3]))
+ }
+ res = fmt.Sprintf("*%d\r\n%s", count, res)
+ return []byte(res), nil
+ default:
+ return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
+ }
+}
+
+func handleCommandDocs(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
+ return []byte("*0\r\n"), nil
+}
+
+// func handleConfigGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
+// return nil, errors.New("command not yet implemented")
+// }
+//
+// func handleConfigRewrite(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) {
+// return nil, errors.New("command not yet implemented")
+// }
+//
+// func handleConfigSet(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) {
+// return nil, errors.New("command not yet implemented")
+// }
+
func Commands() []utils.Command {
return []utils.Command{
{
@@ -63,6 +178,42 @@ func Commands() []utils.Command {
KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
HandlerFunc: handleGetAllCommands,
},
+ {
+ Command: "command",
+ Categories: []string{},
+ Description: "Commands pertaining to echovault commands",
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) {
+ return []string{}, nil
+ },
+ SubCommands: []utils.SubCommand{
+ {
+ Command: "docs",
+ Categories: []string{utils.SlowCategory, utils.ConnectionCategory},
+ Description: "Get command documentation",
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
+ HandlerFunc: handleCommandDocs,
+ },
+ {
+ Command: "count",
+ Categories: []string{utils.SlowCategory},
+ Description: "Get the dumber of commands in the server",
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
+ HandlerFunc: handleCommandCount,
+ },
+ {
+ Command: "list",
+ Categories: []string{utils.SlowCategory},
+ Description: `(COMMAND LIST [FILTERBY ]) Get the list of command names.
+Allows for filtering by ACL category or glob pattern.`,
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
+ HandlerFunc: handleCommandList,
+ },
+ },
+ },
{
Command: "save",
Categories: []string{utils.AdminCategory, utils.SlowCategory, utils.DangerousCategory},
@@ -91,7 +242,7 @@ func Commands() []utils.Command {
if msec == 0 {
return nil, errors.New("no snapshot")
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", msec)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", msec)), nil
},
},
{
diff --git a/src/modules/etc/commands.go b/src/modules/etc/commands.go
index 4e8993f0..266663f1 100644
--- a/src/modules/etc/commands.go
+++ b/src/modules/etc/commands.go
@@ -2,6 +2,7 @@ package etc
import (
"context"
+ "errors"
"fmt"
"github.com/echovault/echovault/src/utils"
"net"
@@ -50,9 +51,7 @@ func handleSetNX(ctx context.Context, cmd []string, server utils.Server, conn *n
if server.KeyExists(key) {
return nil, fmt.Errorf("key %s already exists", key)
}
- // TODO: Retry CreateKeyAndLock until we manage to obtain the key
- _, err = server.CreateKeyAndLock(ctx, key)
- if err != nil {
+ if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err
}
server.SetValue(ctx, key, utils.AdaptType(cmd[2]))
@@ -115,6 +114,10 @@ func handleMSet(ctx context.Context, cmd []string, server utils.Server, conn *ne
return []byte(utils.OK_RESPONSE), nil
}
+func handleCopy(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) {
+ return nil, errors.New("command not yet implemented")
+}
+
func Commands() []utils.Command {
return []utils.Command{
{
diff --git a/src/modules/get/commands.go b/src/modules/get/commands.go
index ddc3679d..d0b9bcbd 100644
--- a/src/modules/get/commands.go
+++ b/src/modules/get/commands.go
@@ -15,7 +15,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net
key := keys[0]
if !server.KeyExists(key) {
- return []byte("+nil\r\n\r\n"), nil
+ return []byte("$-1\r\n"), nil
}
_, err = server.KeyRLock(ctx, key)
@@ -26,7 +26,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net
value := server.GetValue(key)
- return []byte(fmt.Sprintf("+%v\r\n\r\n", value)), nil
+ return []byte(fmt.Sprintf("+%v\r\n", value)), nil
}
func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -51,7 +51,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *ne
locks[key] = true
continue
}
- values[key] = "nil"
+ values[key] = ""
}
defer func() {
for key, locked := range locks {
@@ -69,11 +69,13 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *ne
bytes := []byte(fmt.Sprintf("*%d\r\n", len(cmd[1:])))
for _, key := range cmd[1:] {
+ if values[key] == "" {
+ bytes = append(bytes, []byte("$-1\r\n")...)
+ continue
+ }
bytes = append(bytes, []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(values[key]), values[key]))...)
}
- bytes = append(bytes, []byte("\r\n")...)
-
return bytes, nil
}
diff --git a/src/modules/get/commands_test.go b/src/modules/get/commands_test.go
index a431d0e2..97418006 100644
--- a/src/modules/get/commands_test.go
+++ b/src/modules/get/commands_test.go
@@ -47,8 +47,8 @@ func Test_HandleGET(t *testing.T) {
if err != nil {
t.Error(err)
}
- if !bytes.Equal(res, []byte(fmt.Sprintf("+%v\r\n\r\n", value))) {
- t.Errorf("expected %s, got: %s", fmt.Sprintf("+%v\r\n\r\n", value), string(res))
+ if !bytes.Equal(res, []byte(fmt.Sprintf("+%v\r\n", value))) {
+ t.Errorf("expected %s, got: %s", fmt.Sprintf("+%v\r\n", value), string(res))
}
}(test.key, test.value)
}
@@ -58,8 +58,8 @@ func Test_HandleGET(t *testing.T) {
if err != nil {
t.Error(err)
}
- if !bytes.Equal(res, []byte("+nil\r\n\r\n")) {
- t.Errorf("expected %+v, got: %+v", "+nil\r\n\r\n", res)
+ if !bytes.Equal(res, []byte("$-1\r\n")) {
+ t.Errorf("expected %+v, got: %+v", "+nil\r\n", res)
}
errorTests := []struct {
@@ -93,21 +93,21 @@ func Test_HandleMGET(t *testing.T) {
presetKeys []string
presetValues []string
command []string
- expected []string
+ expected []interface{}
expectedError error
}{
{
presetKeys: []string{"test1", "test2", "test3", "test4"},
presetValues: []string{"value1", "value2", "value3", "value4"},
command: []string{"MGET", "test1", "test4", "test2", "test3", "test1"},
- expected: []string{"value1", "value4", "value2", "value3", "value1"},
+ expected: []interface{}{"value1", "value4", "value2", "value3", "value1"},
expectedError: nil,
},
{
presetKeys: []string{"test5", "test6", "test7"},
presetValues: []string{"value5", "value6", "value7"},
command: []string{"MGET", "test5", "test6", "non-existent", "non-existent", "test7", "non-existent"},
- expected: []string{"value5", "value6", "nil", "nil", "value7", "nil"},
+ expected: []interface{}{"value5", "value6", nil, nil, "value7", nil},
expectedError: nil,
},
{
@@ -150,6 +150,12 @@ func Test_HandleMGET(t *testing.T) {
t.Errorf("expected type Array, got: %s", rv.Type().String())
}
for i, value := range rv.Array() {
+ if test.expected[i] == nil {
+ if !value.IsNull() {
+ t.Errorf("expected nil value, got %+v", value)
+ }
+ continue
+ }
if value.String() != test.expected[i] {
t.Errorf("expected value %s, got: %s", test.expected[i], value.String())
}
diff --git a/src/modules/hash/commands.go b/src/modules/hash/commands.go
index 300b6db0..fe102684 100644
--- a/src/modules/hash/commands.go
+++ b/src/modules/hash/commands.go
@@ -36,7 +36,7 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne
}
defer server.KeyUnlock(key)
server.SetValue(ctx, key, entries)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(entries))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(entries))), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -63,7 +63,7 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne
}
server.SetValue(ctx, key, hash)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -76,7 +76,7 @@ func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *ne
fields := cmd[2:]
if !server.KeyExists(key) {
- return []byte("$-1\r\n\r\n"), nil
+ return []byte("$-1\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -113,7 +113,6 @@ func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *ne
}
res += fmt.Sprintf("$-1\r\n")
}
- res += "\r\n"
return []byte(res), nil
}
@@ -128,7 +127,7 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server utils.Server, conn
fields := cmd[2:]
if !server.KeyExists(key) {
- return []byte("$-1\r\n\r\n"), nil
+ return []byte("$-1\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -165,7 +164,6 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server utils.Server, conn
}
res += ":0\r\n"
}
- res += "\r\n"
return []byte(res), nil
}
@@ -179,7 +177,7 @@ func handleHVALS(ctx context.Context, cmd []string, server utils.Server, conn *n
key := keys[0]
if !server.KeyExists(key) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -207,7 +205,6 @@ func handleHVALS(ctx context.Context, cmd []string, server utils.Server, conn *n
res += fmt.Sprintf(":%d\r\n", d)
}
}
- res += "\r\n"
return []byte(res), nil
}
@@ -227,7 +224,7 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co
return nil, errors.New("count must be an integer")
}
if c == 0 {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
count = c
}
@@ -242,7 +239,7 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co
}
if !server.KeyExists(key) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -279,7 +276,6 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co
}
}
}
- res += "\r\n"
return []byte(res), nil
}
@@ -325,7 +321,6 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co
}
}
}
- res += "\r\n"
return []byte(res), nil
}
@@ -339,7 +334,7 @@ func handleHLEN(ctx context.Context, cmd []string, server utils.Server, conn *ne
key := keys[0]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -352,7 +347,7 @@ func handleHLEN(ctx context.Context, cmd []string, server utils.Server, conn *ne
return nil, fmt.Errorf("value at %s is not a hash", key)
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(hash))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(hash))), nil
}
func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -364,7 +359,7 @@ func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *n
key := keys[0]
if !server.KeyExists(key) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -381,7 +376,6 @@ func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *n
for field, _ := range hash {
res += fmt.Sprintf("$%d\r\n%s\r\n", len(field), field)
}
- res += "\r\n"
return []byte(res), nil
}
@@ -421,11 +415,11 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn
if strings.EqualFold(cmd[0], "hincrbyfloat") {
hash[field] = floatIncrement
server.SetValue(ctx, key, hash)
- return []byte(fmt.Sprintf("+%s\r\n\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil
+ return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(floatIncrement, 'f', -1, 64))), nil
} else {
hash[field] = intIncrement
server.SetValue(ctx, key, hash)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", intIncrement)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", intIncrement)), nil
}
}
@@ -465,11 +459,11 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn
server.SetValue(ctx, key, hash)
if f, ok := hash[field].(float64); ok {
- return []byte(fmt.Sprintf("+%s\r\n\r\n", strconv.FormatFloat(f, 'f', -1, 64))), nil
+ return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(f, 'f', -1, 64))), nil
}
i, _ := hash[field].(int)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", i)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", i)), nil
}
func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -481,7 +475,7 @@ func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn
key := keys[0]
if !server.KeyExists(key) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -508,7 +502,6 @@ func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn
res += fmt.Sprintf(":%d\r\n", d)
}
}
- res += "\r\n"
return []byte(res), nil
}
@@ -523,7 +516,7 @@ func handleHEXISTS(ctx context.Context, cmd []string, server utils.Server, conn
field := cmd[2]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -537,10 +530,10 @@ func handleHEXISTS(ctx context.Context, cmd []string, server utils.Server, conn
}
if hash[field] != nil {
- return []byte(":1\r\n\r\n"), nil
+ return []byte(":1\r\n"), nil
}
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -553,7 +546,7 @@ func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *ne
fields := cmd[2:]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -577,7 +570,7 @@ func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *ne
server.SetValue(ctx, key, hash)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func Commands() []utils.Command {
diff --git a/src/modules/list/commands.go b/src/modules/list/commands.go
index e1e10873..a9a13017 100644
--- a/src/modules/list/commands.go
+++ b/src/modules/list/commands.go
@@ -11,7 +11,7 @@ import (
"strings"
)
-func handleLLen(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
+func handleLLen(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
keys, err := llenKeyFunc(cmd)
if err != nil {
return nil, err
@@ -21,7 +21,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, conn *ne
if !server.KeyExists(key) {
// If key does not exist, return 0
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -30,7 +30,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, conn *ne
defer server.KeyRUnlock(key)
if list, ok := server.GetValue(key).([]interface{}); ok {
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(list))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil
}
return nil, errors.New("LLEN command on non-list item")
@@ -67,7 +67,7 @@ func handleLIndex(ctx context.Context, cmd []string, server utils.Server, conn *
return nil, errors.New("index must be within list range")
}
- return []byte(fmt.Sprintf("+%s\r\n\r\n", list[index])), nil
+ return []byte(fmt.Sprintf("+%s\r\n", list[index])), nil
}
func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -117,7 +117,6 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn *
str := fmt.Sprintf("%v", list[i])
bytes = append(bytes, []byte("$"+fmt.Sprint(len(str))+"\r\n"+str+"\r\n")...)
}
- bytes = append(bytes, []byte("\r\n")...)
return bytes, nil
}
@@ -145,11 +144,8 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn *
} else {
i--
}
-
}
- bytes = append(bytes, []byte("\r\n")...)
-
return bytes, nil
}
@@ -377,7 +373,6 @@ func handleLPush(ctx context.Context, cmd []string, server utils.Server, conn *n
case "lpushx":
return nil, errors.New("LPUSHX command on non-list item")
default:
- // TODO: Retry CreateKeyAndLock until we obtain the key lock
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err
}
@@ -420,7 +415,6 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n
case "rpushx":
return nil, errors.New("RPUSHX command on non-list item")
default:
- // TODO: Retry CreateKeyAndLock until we managed to obtain the key
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err
}
@@ -470,10 +464,10 @@ func handlePop(ctx context.Context, cmd []string, server utils.Server, conn *net
switch strings.ToLower(cmd[0]) {
default:
server.SetValue(ctx, key, list[1:])
- return []byte(fmt.Sprintf("+%v\r\n\r\n", list[0])), nil
+ return []byte(fmt.Sprintf("+%v\r\n", list[0])), nil
case "rpop":
server.SetValue(ctx, key, list[:len(list)-1])
- return []byte(fmt.Sprintf("+%v\r\n\r\n", list[len(list)-1])), nil
+ return []byte(fmt.Sprintf("+%v\r\n", list[len(list)-1])), nil
}
}
diff --git a/src/modules/ping/commands.go b/src/modules/ping/commands.go
index 4779d533..509b755f 100644
--- a/src/modules/ping/commands.go
+++ b/src/modules/ping/commands.go
@@ -13,9 +13,9 @@ func handlePing(ctx context.Context, cmd []string, server utils.Server, conn *ne
default:
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
case 1:
- return []byte("+PONG\r\n\r\n"), nil
+ return []byte("+PONG\r\n"), nil
case 2:
- return []byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(cmd[1]), cmd[1])), nil
+ return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(cmd[1]), cmd[1])), nil
}
}
@@ -31,17 +31,5 @@ func Commands() []utils.Command {
},
HandlerFunc: handlePing,
},
- {
- Command: "ack",
- Categories: []string{},
- Description: "",
- Sync: false,
- KeyExtractionFunc: func(cmd []string) ([]string, error) {
- return []string{}, nil
- },
- HandlerFunc: func(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
- return []byte("$-1\r\n\r\n"), nil
- },
- },
}
}
diff --git a/src/modules/pubsub/channel.go b/src/modules/pubsub/channel.go
new file mode 100644
index 00000000..41ae70f4
--- /dev/null
+++ b/src/modules/pubsub/channel.go
@@ -0,0 +1,113 @@
+package pubsub
+
+import (
+ "fmt"
+ "github.com/gobwas/glob"
+ "io"
+ "log"
+ "net"
+ "slices"
+ "sync"
+)
+
+// Channel - A channel can be subscribed to directly, or via a consumer group.
+// All direct subscribers to the channel will receive any message published to the channel.
+// Only one subscriber of a channel's consumer group will receive a message posted to the channel.
+type Channel struct {
+ name string
+ pattern glob.Glob
+ subscribersRWMut sync.RWMutex
+ subscribers []*net.Conn
+ messageChan *chan string
+}
+
+func WithName(name string) func(channel *Channel) {
+ return func(channel *Channel) {
+ channel.name = name
+ }
+}
+
+func WithPattern(pattern string) func(channel *Channel) {
+ return func(channel *Channel) {
+ channel.name = pattern
+ channel.pattern = glob.MustCompile(pattern)
+ }
+}
+
+func NewChannel(options ...func(channel *Channel)) *Channel {
+ messageChan := make(chan string, 4096)
+
+ channel := &Channel{
+ name: "",
+ pattern: nil,
+ subscribersRWMut: sync.RWMutex{},
+ subscribers: []*net.Conn{},
+ messageChan: &messageChan,
+ }
+
+ for _, option := range options {
+ option(channel)
+ }
+
+ return channel
+}
+
+func (ch *Channel) Start() {
+ go func() {
+ for {
+ message := <-*ch.messageChan
+
+ ch.subscribersRWMut.RLock()
+
+ for _, conn := range ch.subscribers {
+ go func(conn *net.Conn) {
+ w := io.Writer(*conn)
+
+ if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n", len(message), message))); err != nil {
+ log.Println(err)
+ }
+ }(conn)
+ }
+
+ ch.subscribersRWMut.RUnlock()
+ }
+ }()
+}
+
+func (ch *Channel) Subscribe(conn *net.Conn) {
+ if !slices.Contains(ch.subscribers, conn) {
+ ch.subscribersRWMut.Lock()
+ defer ch.subscribersRWMut.Unlock()
+
+ ch.subscribers = append(ch.subscribers, conn)
+ }
+}
+
+func (ch *Channel) Unsubscribe(conn *net.Conn) bool {
+ ch.subscribersRWMut.Lock()
+ defer ch.subscribersRWMut.Unlock()
+
+ var removed bool
+
+ ch.subscribers = slices.DeleteFunc(ch.subscribers, func(c *net.Conn) bool {
+ if c == conn {
+ removed = true
+ return true
+ }
+ return false
+ })
+
+ return removed
+}
+
+func (ch *Channel) Publish(message string) {
+ *ch.messageChan <- message
+}
+
+func (ch *Channel) IsActive() bool {
+ return len(ch.subscribers) > 0
+}
+
+func (ch *Channel) NumSubs() int {
+ return len(ch.subscribers)
+}
diff --git a/src/modules/pubsub/commands.go b/src/modules/pubsub/commands.go
index 1b550dc5..5ed322b9 100644
--- a/src/modules/pubsub/commands.go
+++ b/src/modules/pubsub/commands.go
@@ -3,48 +3,56 @@ package pubsub
import (
"context"
"errors"
+ "fmt"
"github.com/echovault/echovault/src/utils"
"net"
+ "strings"
)
func handleSubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
- return nil, errors.New("could not load pubsub")
+ return nil, errors.New("could not load pubsub module")
}
- switch len(cmd) {
- case 2:
- // Subscribe to specified channel
- pubsub.Subscribe(ctx, conn, cmd[1], nil)
- case 3:
- // Subscribe to specified channel and specified consumer group
- pubsub.Subscribe(ctx, conn, cmd[1], cmd[2])
- default:
+
+ channels := cmd[1:]
+
+ if len(channels) == 0 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
- return []byte("+SUBSCRIBE_OK\r\n\r\n"), nil
+
+ switch strings.ToLower(cmd[0]) {
+ case "subscribe":
+ return pubsub.Subscribe(ctx, conn, channels, false), nil
+ case "psubscribe":
+ return pubsub.Subscribe(ctx, conn, channels, true), nil
+ }
+
+ return []byte{}, nil
}
func handleUnsubscribe(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
- return nil, errors.New("could not load pubsub")
+ return nil, errors.New("could not load pubsub module")
}
- switch len(cmd) {
- case 1:
- pubsub.Unsubscribe(ctx, conn, nil)
- case 2:
- pubsub.Unsubscribe(ctx, conn, cmd[1])
+
+ channels := cmd[1:]
+
+ switch strings.ToLower(cmd[0]) {
+ case "unsubscribe":
+ return pubsub.Unsubscribe(ctx, conn, channels, false), nil
+ case "punsubscribe":
+ return pubsub.Unsubscribe(ctx, conn, channels, true), nil
default:
- return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
+ return []byte{}, nil
}
- return []byte(utils.OK_RESPONSE), nil
}
func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
pubsub, ok := server.GetPubSub().(*PubSub)
if !ok {
- return nil, errors.New("could not load pubsub")
+ return nil, errors.New("could not load pubsub module")
}
if len(cmd) != 3 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
@@ -53,49 +61,149 @@ func handlePublish(ctx context.Context, cmd []string, server utils.Server, conn
return []byte(utils.OK_RESPONSE), nil
}
+func handlePubSubChannels(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
+ if len(cmd) > 3 {
+ return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
+ }
+
+ pubsub, ok := server.GetPubSub().(*PubSub)
+ if !ok {
+ return nil, errors.New("could not load pubsub module")
+ }
+
+ pattern := ""
+ if len(cmd) == 3 {
+ pattern = cmd[2]
+ }
+
+ return pubsub.Channels(ctx, pattern), nil
+}
+
+func handlePubSubNumPat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
+ pubsub, ok := server.GetPubSub().(*PubSub)
+ if !ok {
+ return nil, errors.New("could not load pubsub module")
+ }
+ num := pubsub.NumPat(ctx)
+ return []byte(fmt.Sprintf(":%d\r\n", num)), nil
+}
+
+func handlePubSubNumSubs(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
+ pubsub, ok := server.GetPubSub().(*PubSub)
+ if !ok {
+ return nil, errors.New("could not load pubsub module")
+ }
+ return pubsub.NumSub(ctx, cmd[2:]), nil
+}
+
func Commands() []utils.Command {
return []utils.Command{
{
- Command: "publish",
- Categories: []string{utils.PubSubCategory, utils.FastCategory},
- Description: "(PUBLISH channel message) Publish a message to the specified channel.",
- Sync: true,
+ Command: "subscribe",
+ Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
+ Description: "(SUBSCRIBE channel [channel ...]) Subscribe to one or more channels.",
+ Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
- // Treat the channel as a key
- if len(cmd) != 3 {
+ // Treat the channels as keys
+ if len(cmd) < 2 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
- return []string{cmd[1]}, nil
+ return cmd[1:], nil
},
- HandlerFunc: handlePublish,
+ HandlerFunc: handleSubscribe,
},
{
- Command: "subscribe",
+ Command: "psubscribe",
Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
- Description: "(SUBSCRIBE channel [consumer_group]) Subscribe to a channel with an option to join a consumer group on the channel.",
+ Description: "(PSUBSCRIBE pattern [pattern ...]) Subscribe to one or more glob patterns.",
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
- // Treat the channel as a key
+ // Treat the patterns as keys
if len(cmd) < 2 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
- return []string{cmd[1]}, nil
+ return cmd[1:], nil
},
HandlerFunc: handleSubscribe,
},
{
- Command: "unsubscribe",
- Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
- Description: "(UNSUBSCRIBE channel) Unsubscribe from a channel.",
- Sync: false,
+ Command: "publish",
+ Categories: []string{utils.PubSubCategory, utils.FastCategory},
+ Description: "(PUBLISH channel message) Publish a message to the specified channel.",
+ Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
// Treat the channel as a key
- if len(cmd) != 2 {
+ if len(cmd) != 3 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
}
return []string{cmd[1]}, nil
},
+ HandlerFunc: handlePublish,
+ },
+ {
+ Command: "unsubscribe",
+ Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
+ Description: `(UNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels.
+If the channel list is not provided, then the connection will be unsubscribed from all the channels that
+it's currently subscribe to.`,
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) {
+ // Treat the channels as keys
+ return cmd[1:], nil
+ },
HandlerFunc: handleUnsubscribe,
},
+ {
+ Command: "punsubscribe",
+ Categories: []string{utils.PubSubCategory, utils.ConnectionCategory, utils.SlowCategory},
+ Description: `(PUNSUBSCRIBE [channel [channel ...]]) Unsubscribe from a list of channels using patterns.
+If the pattern list is not provided, then the connection will be unsubscribed from all the patterns that
+it's currently subscribe to.`,
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) {
+ // Treat the channels as keys
+ return cmd[1:], nil
+ },
+ HandlerFunc: handleUnsubscribe,
+ },
+ {
+ Command: "pubsub",
+ Categories: []string{},
+ Description: "",
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
+ HandlerFunc: func(_ context.Context, _ []string, _ utils.Server, _ *net.Conn) ([]byte, error) {
+ return nil, errors.New("provide CHANNELS, NUMPAT, or NUMSUB subcommand")
+ },
+ SubCommands: []utils.SubCommand{
+ {
+ Command: "channels",
+ Categories: []string{utils.PubSubCategory, utils.SlowCategory},
+ Description: `(PUBSUB CHANNELS [pattern]) Returns an array containing the list of channels that
+match the given pattern. If no pattern is provided, all active channels are returned. Active channels are
+channels with 1 or more subscribers.`,
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
+ HandlerFunc: handlePubSubChannels,
+ },
+ {
+ Command: "numpat",
+ Categories: []string{utils.PubSubCategory, utils.SlowCategory},
+ Description: `(PUBSUB NUMPAT) Return the number of patterns that are currently subscribed to by clients.`,
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil },
+ HandlerFunc: handlePubSubNumPat,
+ },
+ {
+ Command: "numsub",
+ Categories: []string{utils.PubSubCategory, utils.SlowCategory},
+ Description: `(PUBSUB NUMSUB [channel [channel ...]]) Return an array of arrays containing the provided
+channel name and how many clients are currently subscribed to the channel.`,
+ Sync: false,
+ KeyExtractionFunc: func(cmd []string) ([]string, error) { return cmd[2:], nil },
+ HandlerFunc: handlePubSubNumSubs,
+ },
+ },
+ },
}
}
diff --git a/src/modules/pubsub/pubsub.go b/src/modules/pubsub/pubsub.go
index 82fed04d..433e4477 100644
--- a/src/modules/pubsub/pubsub.go
+++ b/src/modules/pubsub/pubsub.go
@@ -1,308 +1,203 @@
package pubsub
import (
- "bytes"
- "container/ring"
"context"
"fmt"
- "github.com/echovault/echovault/src/utils"
- "io"
+ "github.com/gobwas/glob"
"net"
"slices"
"sync"
- "time"
)
-// ConsumerGroup allows multiple subscribers to share the consumption load of a channel.
-// Only one subscriber in the consumer group will receive messages published to the channel.
-type ConsumerGroup struct {
- name string
- subscribersRWMut sync.RWMutex
- subscribers *ring.Ring
- messageChan *chan string
+// PubSub container
+type PubSub struct {
+ channels []*Channel
+ channelsRWMut sync.RWMutex
}
-func NewConsumerGroup(name string) *ConsumerGroup {
- messageChan := make(chan string)
-
- return &ConsumerGroup{
- name: name,
- subscribersRWMut: sync.RWMutex{},
- subscribers: nil,
- messageChan: &messageChan,
+func NewPubSub() *PubSub {
+ return &PubSub{
+ channels: []*Channel{},
+ channelsRWMut: sync.RWMutex{},
}
}
-func (cg *ConsumerGroup) SendMessage(message string) {
- cg.subscribersRWMut.RLock()
+func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
+ res := fmt.Sprintf("*%d\r\n", len(channels))
- conn := cg.subscribers.Value.(*net.Conn)
+ for i := 0; i < len(channels); i++ {
+ // Check if channel with given name exists
+ // If it does, subscribe the connection to the channel
+ // If it does not, create the channel and subscribe to it
+ channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
+ return channel.name == channels[i]
+ })
- cg.subscribersRWMut.RUnlock()
-
- w, r := io.Writer(*conn), io.Reader(*conn)
-
- if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\n", len(message), message))); err != nil {
- // TODO: Log error at configured logger
- fmt.Println(err)
- }
- // Wait for an ACK
- // If no ACK is received within a time limit, remove this connection from subscribers and retry
- if err := (*conn).SetReadDeadline(time.Now().Add(250 * time.Millisecond)); err != nil {
- // TODO: Log error at configured logger
- fmt.Println(err)
- }
- if msg, err := utils.ReadMessage(r); err != nil {
- // Remove the connection from subscribers list
- cg.Unsubscribe(conn)
- // Reset the deadline
- if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
- // TODO: Log error at configured logger
- fmt.Println(err)
- }
- // Retry sending the message
- cg.SendMessage(message)
- } else {
- if !bytes.Equal(bytes.TrimSpace(msg), []byte("+ACK")) {
- cg.Unsubscribe(conn)
- if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
- // TODO: Log error at configured logger
- fmt.Println(err)
+ if channelIdx == -1 {
+ // Create new channel, start it, and subscribe to it
+ var newChan *Channel
+ if withPattern {
+ newChan = NewChannel(WithPattern(channels[i]))
+ } else {
+ newChan = NewChannel(WithName(channels[i]))
}
- cg.SendMessage(message)
+ newChan.Start()
+ newChan.Subscribe(conn)
+ ps.channels = append(ps.channels, newChan)
+ } else {
+ // Subscribe to existing channel
+ ps.channels[channelIdx].Subscribe(conn)
}
- }
- if err := (*conn).SetDeadline(time.Time{}); err != nil {
- // TODO: Log error at configured logger
- fmt.Println(err)
+ if len(channels) > 1 {
+ // If subscribing to more than one channel, write array to verify the subscription of this channel
+ res += fmt.Sprintf("*3\r\n+subscribe\r\n$%d\r\n%s\r\n:%d\r\n", len(channels[i]), channels[i], i+1)
+ } else {
+ // Ony one channel, simply send "subscribe" simple string response
+ res = "+subscribe\r\n"
+ }
}
- cg.subscribers = cg.subscribers.Next()
-}
-func (cg *ConsumerGroup) Start() {
- go func() {
- for {
- message := <-*cg.messageChan
- if cg.subscribers != nil {
- cg.SendMessage(message)
- }
- }
- }()
+ return []byte(res)
}
-func (cg *ConsumerGroup) Subscribe(conn *net.Conn) {
- cg.subscribersRWMut.Lock()
- defer cg.subscribersRWMut.Unlock()
-
- r := ring.New(1)
- for i := 0; i < r.Len(); i++ {
- r.Value = conn
- r = r.Next()
- }
+func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channels []string, withPattern bool) []byte {
+ ps.channelsRWMut.RLock()
+ ps.channelsRWMut.RUnlock()
- if cg.subscribers == nil {
- cg.subscribers = r
- return
+ action := "unsubscribe"
+ if withPattern {
+ action = "subscribe"
}
- cg.subscribers = cg.subscribers.Link(r)
-}
-
-func (cg *ConsumerGroup) Unsubscribe(conn *net.Conn) {
- cg.subscribersRWMut.Lock()
- defer cg.subscribersRWMut.Unlock()
+ unsubscribed := make(map[int]string)
+ count := 1
- // If length is 1 and the connection passed is the one contained within, unlink it
- if cg.subscribers.Len() == 1 {
- if cg.subscribers.Value == conn {
- cg.subscribers = nil
+ // If the channels slice is empty, unsubscribe from all channels.
+ if len(channels) <= 0 {
+ for _, channel := range ps.channels {
+ if channel.Unsubscribe(conn) {
+ unsubscribed[1] = channel.name
+ count += 1
+ }
}
- return
}
- for i := 0; i < cg.subscribers.Len(); i++ {
- if cg.subscribers.Value == conn {
- cg.subscribers = cg.subscribers.Prev()
- cg.subscribers.Unlink(1)
- break
+ // If withPattern is false, unsubscribe from channels where the name exactly matches channel name.
+ if !withPattern {
+ for _, channel := range ps.channels { // For each channel in PubSub
+ for _, c := range channels { // For each channel name provided
+ if channel.name == c && channel.Unsubscribe(conn) {
+ unsubscribed[count] = channel.name
+ count += 1
+ }
+ }
}
- cg.subscribers = cg.subscribers.Next()
- }
-}
-
-func (cg *ConsumerGroup) Publish(message string) {
- *cg.messageChan <- message
-}
-
-// Channel - A channel can be subscribed to directly, or via a consumer group.
-// All direct subscribers to the channel will receive any message published to the channel.
-// Only one subscriber of a channel's consumer group will receive a message posted to the channel.
-type Channel struct {
- name string
- subscribersRWMut sync.RWMutex
- subscribers []*net.Conn
- consumerGroups []*ConsumerGroup
- messageChan *chan string
-}
-
-func NewChannel(name string) *Channel {
- messageChan := make(chan string)
-
- return &Channel{
- name: name,
- subscribersRWMut: sync.RWMutex{},
- subscribers: []*net.Conn{},
- consumerGroups: []*ConsumerGroup{},
- messageChan: &messageChan,
}
-}
-
-func (ch *Channel) Start() {
- go func() {
- for {
- message := <-*ch.messageChan
-
- ch.subscribersRWMut.RLock()
-
- for _, conn := range ch.subscribers {
- go func(conn *net.Conn) {
- w, r := io.Writer(*conn), io.Reader(*conn)
- if _, err := w.Write([]byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(message), message))); err != nil {
- // TODO: Log error at configured logger
- fmt.Println(err)
- }
-
- if err := (*conn).SetReadDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
- // TODO: Log error at configured logger
- fmt.Println(err)
- ch.Unsubscribe(conn)
- }
- defer func() {
- if err := (*conn).SetReadDeadline(time.Time{}); err != nil {
- // TODO: Log error at configured logger
- fmt.Println(err)
- ch.Unsubscribe(conn)
- }
- }()
-
- if msg, err := utils.ReadMessage(r); err != nil {
- ch.Unsubscribe(conn)
- } else {
- if !bytes.EqualFold(bytes.TrimSpace(msg), []byte("+ACK")) {
- ch.Unsubscribe(conn)
- }
- }
- }(conn)
+ // If withPattern is true, unsubscribe from channels where pattern matches pattern provided,
+ // also unsubscribe from channels where the name matches the given pattern.
+ if withPattern {
+ for _, pattern := range channels {
+ g := glob.MustCompile(pattern)
+ for _, channel := range ps.channels {
+ // If it's a pattern channel, directly compare the patterns
+ if channel.pattern != nil && channel.name == pattern {
+ unsubscribed[count] = channel.name
+ count += 1
+ continue
+ }
+ // If this is a regular channel, check if the channel name matches the pattern given
+ if g.Match(channel.name) {
+ unsubscribed[count] = channel.name
+ count += 1
+ }
}
-
- ch.subscribersRWMut.RUnlock()
}
- }()
-}
-
-func (ch *Channel) Subscribe(conn *net.Conn, consumerGroupName interface{}) {
- if consumerGroupName == nil && !slices.Contains(ch.subscribers, conn) {
- ch.subscribersRWMut.Lock()
- defer ch.subscribersRWMut.Unlock()
- ch.subscribers = append(ch.subscribers, conn)
- return
}
- groups := utils.Filter[*ConsumerGroup](ch.consumerGroups, func(group *ConsumerGroup) bool {
- return group.name == consumerGroupName.(string)
- })
-
- if len(groups) == 0 {
- go func() {
- newGroup := NewConsumerGroup(consumerGroupName.(string))
- newGroup.Start()
- newGroup.Subscribe(conn)
- ch.consumerGroups = append(ch.consumerGroups, newGroup)
- }()
- return
+ res := fmt.Sprintf("*%d\r\n", len(unsubscribed))
+ for key, value := range unsubscribed {
+ res += fmt.Sprintf("*3\r\n+%s\r\n$%d\r\n%s\r\n:%d\r\n", action, len(value), value, key)
}
- for _, group := range groups {
- go group.Subscribe(conn)
- }
+ return []byte(res)
}
-func (ch *Channel) Unsubscribe(conn *net.Conn) {
- ch.subscribersRWMut.Lock()
- defer ch.subscribersRWMut.Unlock()
-
- ch.subscribers = utils.Filter[*net.Conn](ch.subscribers, func(c *net.Conn) bool {
- return c != conn
- })
-
- for _, group := range ch.consumerGroups {
- go group.Unsubscribe(conn)
+func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
+ ps.channelsRWMut.RLock()
+ defer ps.channelsRWMut.RUnlock()
+
+ for _, channel := range ps.channels {
+ // If it's a regular channel, check if the channel name matches the name given
+ if channel.pattern == nil {
+ if channel.name == channelName {
+ channel.Publish(message)
+ }
+ continue
+ }
+ // If it's a glob pattern channel, check if the name matches the pattern
+ if channel.pattern.Match(channelName) {
+ channel.Publish(message)
+ }
}
}
-func (ch *Channel) Publish(message string) {
- for _, group := range ch.consumerGroups {
- go group.Publish(message)
- }
- *ch.messageChan <- message
-}
+func (ps *PubSub) Channels(ctx context.Context, pattern string) []byte {
+ var count int
+ var res string
-// PubSub container
-type PubSub struct {
- channels []*Channel
-}
+ if pattern == "" {
+ for _, channel := range ps.channels {
+ if channel.IsActive() {
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
+ count += 1
+ }
+ }
-func NewPubSub() *PubSub {
- return &PubSub{
- channels: []*Channel{},
+ res = fmt.Sprintf("*%d\r\n%s", count, res)
+ return []byte(res)
}
-}
-func (ps *PubSub) Subscribe(ctx context.Context, conn *net.Conn, channelName string, consumerGroup interface{}) {
- // Check if channel with given name exists
- // If it does, subscribe the connection to the channel
- // If it does not, create the channel and subscribe to it
- channelIdx := slices.IndexFunc(ps.channels, func(channel *Channel) bool {
- return channel.name == channelName
- })
+ g := glob.MustCompile(pattern)
- if channelIdx == -1 {
- go func() {
- newChan := NewChannel(channelName)
- newChan.Start()
- newChan.Subscribe(conn, consumerGroup)
- ps.channels = append(ps.channels, newChan)
- }()
- return
+ for _, channel := range ps.channels {
+ // If channel is a pattern channel, then directly compare the channel name to pattern
+ if channel.pattern != nil && channel.name == pattern && channel.IsActive() {
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
+ count += 1
+ continue
+ }
+ if g.Match(channel.name) && channel.IsActive() {
+ res += fmt.Sprintf("$%d\r\n%s\r\n", len(channel.name), channel.name)
+ count += 1
+ }
}
- go ps.channels[channelIdx].Subscribe(conn, consumerGroup)
+ return []byte(res)
}
-func (ps *PubSub) Unsubscribe(ctx context.Context, conn *net.Conn, channelName interface{}) {
- if channelName == nil {
- for _, channel := range ps.channels {
- go channel.Unsubscribe(conn)
+func (ps *PubSub) NumPat(ctx context.Context) int {
+ var count int
+ for _, channel := range ps.channels {
+ if channel.pattern != nil {
+ count += 1
}
- return
- }
-
- channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
- return c.name == channelName
- })
-
- for _, channel := range channels {
- go channel.Unsubscribe(conn)
}
+ return count
}
-func (ps *PubSub) Publish(ctx context.Context, message string, channelName string) {
- channels := utils.Filter[*Channel](ps.channels, func(c *Channel) bool {
- return c.name == channelName
- })
+func (ps *PubSub) NumSub(ctx context.Context, channels []string) []byte {
+ res := fmt.Sprintf("*%d\r\n", len(channels))
for _, channel := range channels {
- go channel.Publish(message)
+ chanIdx := slices.IndexFunc(ps.channels, func(c *Channel) bool {
+ return c.name == channel
+ })
+ if chanIdx == -1 {
+ res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:0\r\n", len(channel), channel)
+ continue
+ }
+ res += fmt.Sprintf("*2\r\n$%d\r\n%s\r\n:%d\r\n", len(channel), channel, ps.channels[chanIdx].NumSubs())
}
+ return []byte(res)
}
diff --git a/src/modules/set/commands.go b/src/modules/set/commands.go
index c425fbee..e3699ab5 100644
--- a/src/modules/set/commands.go
+++ b/src/modules/set/commands.go
@@ -27,7 +27,7 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
}
server.SetValue(ctx, key, set)
server.KeyUnlock(key)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(cmd[2:]))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(cmd[2:]))), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -42,7 +42,7 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
count := set.Add(cmd[2:])
- return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -54,7 +54,7 @@ func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *n
key := keys[0]
if !server.KeyExists(key) {
- return []byte(fmt.Sprintf(":0\r\n\r\n")), nil
+ return []byte(fmt.Sprintf(":0\r\n")), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -69,7 +69,7 @@ func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *n
cardinality := set.Cardinality()
- return []byte(fmt.Sprintf(":%d\r\n\r\n", cardinality)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", cardinality)), nil
}
func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -126,7 +126,7 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n
for i, e := range elems {
res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(e), e)
if i == len(elems)-1 {
- res += "\r\n\r\n"
+ res += "\r\n"
}
}
@@ -185,7 +185,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
diff := baseSet.Subtract(sets)
elems := diff.GetAll()
- res := fmt.Sprintf(":%d\r\n\r\n", len(elems))
+ res := fmt.Sprintf(":%d\r\n", len(elems))
if server.KeyExists(destination) {
if _, err = server.KeyLock(ctx, destination); err != nil {
@@ -223,7 +223,7 @@ func handleSINTER(ctx context.Context, cmd []string, server utils.Server, conn *
for _, key := range keys[0:] {
if !server.KeyExists(key) {
// If key does not exist, then there is no intersection
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err
@@ -253,7 +253,7 @@ func handleSINTER(ctx context.Context, cmd []string, server utils.Server, conn *
for i, e := range elems {
res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(e), e)
if i == len(elems)-1 {
- res += "\r\n\r\n"
+ res += "\r\n"
}
}
@@ -299,7 +299,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server utils.Server, co
for _, key := range keys {
if !server.KeyExists(key) {
// If key does not exist, then there is no intersection
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err
@@ -324,7 +324,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server utils.Server, co
intersect, _ := Intersection(limit, sets...)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", intersect.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -345,7 +345,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
for _, key := range keys[1:] {
if !server.KeyExists(key) {
// If key does not exist, then there is no intersection
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err
@@ -380,7 +380,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
server.SetValue(ctx, destination, intersect)
server.KeyUnlock(destination)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", intersect.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -392,7 +392,7 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, con
key := keys[0]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -406,10 +406,10 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, con
}
if !set.Contains(cmd[2]) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
- return []byte(":1\r\n\r\n"), nil
+ return []byte(":1\r\n"), nil
}
func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -421,7 +421,7 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn
key := keys[0]
if !server.KeyExists(key) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -440,7 +440,7 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn
for i, e := range elems {
res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(e), e)
if i == len(elems)-1 {
- res += "\r\n\r\n"
+ res += "\r\n"
}
}
@@ -461,7 +461,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co
for i, _ := range members {
res = fmt.Sprintf("%s\r\n:0", res)
if i == len(members)-1 {
- res += "\r\n\r\n"
+ res += "\r\n"
}
}
return []byte(res), nil
@@ -485,7 +485,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co
res += "\r\n:0"
}
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -501,7 +501,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n
member := cmd[3]
if !server.KeyExists(source) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, source); err != nil {
@@ -539,7 +539,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n
res := sourceSet.Move(destinationSet, member)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", res)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", res)), nil
}
func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -560,7 +560,7 @@ func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne
}
if !server.KeyExists(key) {
- return []byte("*-1\r\n\r\n"), nil
+ return []byte("*-1\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -579,7 +579,7 @@ func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne
for i, m := range members {
res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(m), m)
if i == len(members)-1 {
- res += "\r\n\r\n"
+ res += "\r\n"
}
}
@@ -604,7 +604,7 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c
}
if !server.KeyExists(key) {
- return []byte("*-1\r\n\r\n"), nil
+ return []byte("*-1\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -623,7 +623,7 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c
for i, m := range members {
res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(m), m)
if i == len(members)-1 {
- res += "\r\n\r\n"
+ res += "\r\n"
}
}
@@ -640,7 +640,7 @@ func handleSREM(ctx context.Context, cmd []string, server utils.Server, conn *ne
members := cmd[2:]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -655,7 +655,7 @@ func handleSREM(ctx context.Context, cmd []string, server utils.Server, conn *ne
count := set.Remove(members)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleSUNION(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -702,7 +702,7 @@ func handleSUNION(ctx context.Context, cmd []string, server utils.Server, conn *
for i, e := range union.GetAll() {
res = fmt.Sprintf("%s\r\n$%d\r\n%s", res, len(e), e)
if i == len(union.GetAll())-1 {
- res += "\r\n\r\n"
+ res += "\r\n"
}
}
@@ -763,7 +763,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
defer server.KeyUnlock(destination)
server.SetValue(ctx, destination, union)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", union.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil
}
func Commands() []utils.Command {
diff --git a/src/modules/set/set.go b/src/modules/set/set.go
index 379e225e..75bcbb26 100644
--- a/src/modules/set/set.go
+++ b/src/modules/set/set.go
@@ -75,8 +75,8 @@ func (set *Set) GetRandom(count int) []string {
n = rand.Intn(len(keys))
if !slices.Contains(res, keys[n]) {
res = append(res, keys[n])
- keys = utils.Filter(keys, func(elem string) bool {
- return elem != keys[n]
+ keys = slices.DeleteFunc(keys, func(elem string) bool {
+ return elem == keys[n]
})
i++
}
diff --git a/src/modules/sorted_set/commands.go b/src/modules/sorted_set/commands.go
index 1a0abeb6..b08ce3c5 100644
--- a/src/modules/sorted_set/commands.go
+++ b/src/modules/sorted_set/commands.go
@@ -128,7 +128,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
if server.KeyExists(key) {
// Key exists
- _, err := server.KeyLock(ctx, key)
+ _, err = server.KeyLock(ctx, key)
if err != nil {
return nil, err
}
@@ -144,10 +144,10 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
// If INCR option is provided, return the new score value
if incr != nil {
m := set.Get(members[0].value)
- return []byte(fmt.Sprintf("+%f\r\n\r\n", m.score)), nil
+ return []byte(fmt.Sprintf("+%f\r\n", m.score)), nil
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
// Key does not exist
@@ -159,7 +159,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
set := NewSortedSet(members)
server.SetValue(ctx, key, set)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", set.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil
}
func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -170,7 +170,7 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *n
key := keys[0]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -183,7 +183,7 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *n
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", set.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", set.Cardinality())), nil
}
func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -231,7 +231,7 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn *
}
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -251,7 +251,7 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn *
}
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(members))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(members))), nil
}
func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -265,7 +265,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con
maximum := cmd[3]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -283,7 +283,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con
// Check if all members has the same score
for i := 0; i < len(members)-2; i++ {
if members[i].score != members[i+1].score {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
}
@@ -296,7 +296,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con
}
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", count)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", count)), nil
}
func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -324,7 +324,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n
// Extract base set
if !server.KeyExists(keys[0]) {
// If base set does not exist, return an empty array
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, keys[0]); err != nil {
return nil, err
@@ -367,7 +367,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n
}
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -392,7 +392,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
// Extract base set
if !server.KeyExists(keys[0]) {
// If base set does not exist, return 0
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, keys[0]); err != nil {
return nil, err
@@ -433,7 +433,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
server.SetValue(ctx, destination, diff)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", diff.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", diff.Cardinality())), nil
}
func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -473,7 +473,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn
}
server.SetValue(ctx, key, NewSortedSet([]MemberParam{{value: member, score: increment}}))
server.KeyUnlock(key)
- return []byte(fmt.Sprintf("+%s\r\n\r\n", strconv.FormatFloat(float64(increment), 'f', -1, 64))), nil
+ return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(float64(increment), 'f', -1, 64))), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -493,7 +493,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn
"incr"); err != nil {
return nil, err
}
- return []byte(fmt.Sprintf("+%s\r\n\r\n",
+ return []byte(fmt.Sprintf("+%s\r\n",
strconv.FormatFloat(float64(set.Get(member).score), 'f', -1, 64))), nil
}
@@ -522,7 +522,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn *
for i := 0; i < len(keys); i++ {
if !server.KeyExists(keys[i]) {
// If any of the keys is non-existent, return an empty array as there's no intersect
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
return nil, err
@@ -552,7 +552,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn *
}
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -588,7 +588,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
for i := 0; i < len(keys); i++ {
if !server.KeyExists(keys[i]) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
return nil, err
@@ -619,7 +619,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
server.SetValue(ctx, destination, intersect)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", intersect.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
}
func handleZMPOP(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -691,13 +691,13 @@ func handleZMPOP(ctx context.Context, cmd []string, server utils.Server, conn *n
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
}
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -723,7 +723,7 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne
}
if !server.KeyExists(key) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -746,7 +746,7 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -760,7 +760,7 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.Server, conn
key := keys[0]
if !server.KeyExists(key) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -788,7 +788,7 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.Server, conn
}
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -819,7 +819,7 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c
}
if !server.KeyExists(key) {
- return []byte("$-1\r\n\r\n"), nil
+ return []byte("$-1\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -843,7 +843,7 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c
}
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -863,7 +863,7 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.Server, conn *n
}
if !server.KeyExists(key) {
- return []byte("$-1\r\n\r\n"), nil
+ return []byte("$-1\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -888,14 +888,14 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.Server, conn *n
if members[i].value == Value(member) {
if withscores {
score := strconv.FormatFloat(float64(members[i].score), 'f', -1, 64)
- return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n\r\n", i, len(score), score)), nil
+ return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n", i, len(score), score)), nil
} else {
- return []byte(fmt.Sprintf("*1\r\n:%d\r\n\r\n", i)), nil
+ return []byte(fmt.Sprintf("*1\r\n:%d\r\n", i)), nil
}
}
}
- return []byte("$-1\r\n\r\n"), nil
+ return []byte("$-1\r\n"), nil
}
func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -907,7 +907,7 @@ func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *ne
key := keys[0]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -927,7 +927,7 @@ func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *ne
}
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", deletedCount)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
}
func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -939,7 +939,7 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn *
key := keys[0]
if !server.KeyExists(key) {
- return []byte("$-1\r\n\r\n"), nil
+ return []byte("$-1\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err
@@ -951,12 +951,12 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn *
}
member := set.Get(Value(cmd[2]))
if !member.exists {
- return []byte("$-1\r\n\r\n"), nil
+ return []byte("$-1\r\n"), nil
}
score := strconv.FormatFloat(float64(member.score), 'f', -1, 64)
- return []byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(score), score)), nil
+ return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(score), score)), nil
}
func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -980,7 +980,7 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Serv
}
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -1000,7 +1000,7 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Serv
}
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", deletedCount)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
}
func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -1022,7 +1022,7 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Serve
}
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -1065,7 +1065,7 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Serve
}
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", deletedCount)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
}
func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -1079,7 +1079,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server
maximum := cmd[3]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err = server.KeyLock(ctx, key); err != nil {
@@ -1097,7 +1097,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server
// Check if all the members have the same score. If not, return 0
for i := 0; i < len(members)-1; i++ {
if members[i].score != members[i+1].score {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
}
@@ -1112,7 +1112,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server
}
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", deletedCount)), nil
+ return []byte(fmt.Sprintf(":%d\r\n", deletedCount)), nil
}
func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -1177,7 +1177,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn *
}
if !server.KeyExists(key) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -1191,7 +1191,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn *
}
if offset > set.Cardinality() {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if count < 0 {
count = set.Cardinality() - offset
@@ -1211,7 +1211,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn *
// If policy is BYLEX, all the elements must have the same score
for i := 0; i < len(members)-1; i++ {
if members[i].score != members[i+1].score {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
}
slices.SortFunc(members, func(a, b MemberParam) int {
@@ -1241,9 +1241,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn *
}
res := fmt.Sprintf("*%d", len(resultMembers))
- if len(resultMembers) == 0 {
- res += "\r\n\r\n"
- }
+
for _, m := range resultMembers {
if withscores {
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
@@ -1252,7 +1250,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn *
}
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -1316,7 +1314,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c
}
if !server.KeyExists(source) {
- return []byte("*0\r\n\r\n"), nil
+ return []byte("*0\r\n"), nil
}
if _, err = server.KeyRLock(ctx, source); err != nil {
@@ -1330,7 +1328,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c
}
if offset > set.Cardinality() {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if count < 0 {
count = set.Cardinality() - offset
@@ -1350,7 +1348,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c
// If policy is BYLEX, all the elements must have the same score
for i := 0; i < len(members)-1; i++ {
if members[i].score != members[i+1].score {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
}
slices.SortFunc(members, func(a, b MemberParam) int {
@@ -1394,7 +1392,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c
server.SetValue(ctx, destination, newSortedSet)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", newSortedSet.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", newSortedSet.Cardinality())), nil
}
func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -1446,7 +1444,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn *
}
}
- res += "\r\n\r\n"
+ res += "\r\n"
return []byte(res), nil
}
@@ -1512,7 +1510,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
server.SetValue(ctx, destination, union)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", union.Cardinality())), nil
+ return []byte(fmt.Sprintf(":%d\r\n", union.Cardinality())), nil
}
func Commands() []utils.Command {
diff --git a/src/modules/string/commands.go b/src/modules/string/commands.go
index 3ebe060b..296eac3b 100644
--- a/src/modules/string/commands.go
+++ b/src/modules/string/commands.go
@@ -29,7 +29,7 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn
}
server.SetValue(ctx, key, newStr)
server.KeyUnlock(key)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(newStr))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil
}
if _, err := server.KeyLock(ctx, key); err != nil {
@@ -46,14 +46,14 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn
if offset >= len(str) {
newStr = str + newStr
server.SetValue(ctx, key, newStr)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(newStr))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil
}
// If the offset is < 0, prepend the new string to the old one.
if offset < 0 {
newStr = newStr + str
server.SetValue(ctx, key, newStr)
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(newStr))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil
}
strRunes := []rune(str)
@@ -72,7 +72,7 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn
server.SetValue(ctx, key, string(strRunes))
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(strRunes))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(strRunes))), nil
}
func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -84,7 +84,7 @@ func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn *
key := keys[0]
if !server.KeyExists(key) {
- return []byte(":0\r\n\r\n"), nil
+ return []byte(":0\r\n"), nil
}
if _, err := server.KeyRLock(ctx, key); err != nil {
@@ -98,7 +98,7 @@ func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn *
return nil, fmt.Errorf("value at key %s is not a string", key)
}
- return []byte(fmt.Sprintf(":%d\r\n\r\n", len(value))), nil
+ return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil
}
func handleSubStr(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
@@ -161,7 +161,7 @@ func handleSubStr(ctx context.Context, cmd []string, server utils.Server, conn *
str = res
}
- return []byte(fmt.Sprintf("$%d\r\n%s\r\n\r\n", len(str), str)), nil
+ return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(str), str)), nil
}
func Commands() []utils.Command {
diff --git a/src/server/aof/log/store.go b/src/server/aof/log/store.go
index 314e094f..a0830e48 100644
--- a/src/server/aof/log/store.go
+++ b/src/server/aof/log/store.go
@@ -15,7 +15,7 @@ import (
)
type AppendReadWriter interface {
- io.ReadWriter
+ io.ReadWriteSeeker
io.Closer
Truncate(size int64) error
Sync() error
@@ -101,7 +101,9 @@ func NewAppendStore(options ...func(store *AppendStore)) *AppendStore {
func (store *AppendStore) Write(command []byte) error {
store.mut.Lock()
defer store.mut.Unlock()
- if _, err := store.rw.Write(command); err != nil {
+ // Add new line before writing to AOF file.
+ out := append(command, []byte("\r\n")...)
+ if _, err := store.rw.Write(out); err != nil {
return err
}
if strings.EqualFold(store.strategy, "always") {
@@ -160,6 +162,10 @@ func (store *AppendStore) Truncate() error {
if err := store.rw.Truncate(0); err != nil {
return err
}
+ // Seek to the beginning of the file after truncating
+ if _, err := store.rw.Seek(0, 0); err != nil {
+ return err
+ }
return nil
}
diff --git a/src/server/aof/preamble/store.go b/src/server/aof/preamble/store.go
index 2a36d778..08f22651 100644
--- a/src/server/aof/preamble/store.go
+++ b/src/server/aof/preamble/store.go
@@ -125,6 +125,10 @@ func (store *PreambleStore) Restore() error {
return err
}
+ if len(b) <= 0 {
+ return nil
+ }
+
state := make(map[string]interface{})
if err = json.Unmarshal(b, &state); err != nil {
diff --git a/src/server/modules.go b/src/server/modules.go
index 55b04bad..0dbb704c 100644
--- a/src/server/modules.go
+++ b/src/server/modules.go
@@ -36,6 +36,8 @@ func (server *Server) handleCommand(ctx context.Context, message []byte, conn *n
return nil, err
}
+ fmt.Println(cmd)
+
command, err := server.getCommand(cmd[0])
if err != nil {
return nil, err
diff --git a/src/server/server.go b/src/server/server.go
index 50023727..cb366473 100644
--- a/src/server/server.go
+++ b/src/server/server.go
@@ -208,6 +208,7 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
if err != nil && errors.Is(err, io.EOF) {
// Connection closed
+ log.Println(err)
break
}
@@ -218,8 +219,12 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
res, err := server.handleCommand(ctx, message, &conn, false)
+ if err != nil && errors.Is(err, io.EOF) {
+ break
+ }
+
if err != nil {
- if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n\r\n", err.Error()))); err != nil {
+ if _, err = w.Write([]byte(fmt.Sprintf("-Error %s\r\n", err.Error()))); err != nil {
log.Println(err)
}
continue
@@ -227,11 +232,13 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
chunkSize := 1024
+ // If the length of the response is 0, return nothing to the client
+ if len(res) == 0 {
+ continue
+ }
+
if len(res) <= chunkSize {
- _, err = w.Write(res)
- if err != nil {
- log.Println(err)
- }
+ _, _ = w.Write(res)
continue
}
@@ -240,7 +247,10 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
for {
// If the current start index is less than chunkSize from length, return the remaining bytes.
if len(res)-1-startIndex < chunkSize {
- _, _ = w.Write(res[startIndex:])
+ _, err = w.Write(res[startIndex:])
+ if err != nil {
+ log.Println(err)
+ }
break
}
n, _ := w.Write(res[startIndex : startIndex+chunkSize])
diff --git a/src/utils/const.go b/src/utils/const.go
index b7418c8e..2c7b5c7c 100644
--- a/src/utils/const.go
+++ b/src/utils/const.go
@@ -25,6 +25,6 @@ const (
)
const (
- OK_RESPONSE = "+OK\r\n\r\n"
+ OK_RESPONSE = "+OK\r\n"
WRONG_ARGS_RESPONSE = "wrong number of arguments"
)
diff --git a/src/utils/utils.go b/src/utils/utils.go
index d005354c..ed32c13a 100644
--- a/src/utils/utils.go
+++ b/src/utils/utils.go
@@ -1,7 +1,9 @@
package utils
import (
+ "bufio"
"bytes"
+ "errors"
"io"
"log"
"math/big"
@@ -32,60 +34,45 @@ func AdaptType(s string) interface{} {
return f
}
-func Filter[T any](arr []T, test func(elem T) bool) (res []T) {
- for _, e := range arr {
- if test(e) {
- res = append(res, e)
- }
- }
- return
-}
-
func Decode(raw []byte) ([]string, error) {
- rd := resp.NewReader(bytes.NewBuffer(raw))
- var res []string
-
- v, _, err := rd.ReadValue()
+ reader := resp.NewReader(bytes.NewReader(raw))
+ value, _, err := reader.ReadValue()
if err != nil {
return nil, err
}
- if slices.Contains([]string{"SimpleString", "Integer", "Error"}, v.Type().String()) {
- return []string{v.String()}, nil
- }
-
- if v.Type().String() == "Array" {
- for _, elem := range v.Array() {
- res = append(res, elem.String())
- }
+ var res []string
+ for i := 0; i < len(value.Array()); i++ {
+ res = append(res, value.Array()[i].String())
}
return res, nil
}
func ReadMessage(r io.Reader) ([]byte, error) {
- delim := []byte{'\r', '\n', '\r', '\n'}
- buffSize := 8
- buff := make([]byte, buffSize)
+ reader := bufio.NewReader(r)
- var n int
- var err error
var res []byte
+ chunk := make([]byte, 8192)
+
for {
- n, err = r.Read(buff)
- res = append(res, buff...)
- if n < buffSize || err != nil {
+ n, err := reader.Read(chunk)
+ if err != nil && errors.Is(err, io.EOF) {
break
}
- if bytes.Equal(buff[len(buff)-4:], delim) {
+ if err != nil {
+ return nil, err
+ }
+ res = append(res, chunk...)
+ if n < len(chunk) {
break
}
- clear(buff)
+ clear(chunk)
}
- return res, err
+ return bytes.Trim(res, "\x00"), nil
}
func RetryBackoff(b retry.Backoff, maxRetries uint64, jitter, cappedDuration, maxDuration time.Duration) retry.Backoff {