diff --git a/internal/modules/generic/commands.go b/internal/modules/generic/commands.go index 6ad523a9..5c99c652 100644 --- a/internal/modules/generic/commands.go +++ b/internal/modules/generic/commands.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "log" + "reflect" "strconv" "strings" "time" @@ -785,6 +786,46 @@ func handleGetex(params internal.HandlerFuncParams) ([]byte, error) { } +func handleType(params internal.HandlerFuncParams) ([]byte, error) { + keys, err := getKeyFunc(params.Command) + if err != nil { + return nil, err + } + key := keys.ReadKeys[0] + keyExists := params.KeysExist(params.Context, []string{key})[key] + + if !keyExists { + return nil, fmt.Errorf("key %s does not exist", key) + } + + value := params.GetValues(params.Context, []string{key})[key] + t := reflect.TypeOf(value) + type_string := "" + switch t.Kind() { + case reflect.String: + type_string = "string" + case reflect.Int: + type_string = "integer" + case reflect.Float64: + type_string = "float" + case reflect.Slice: + type_string = "list" + case reflect.Map: + type_string = "hash" + case reflect.Pointer: + if t.Elem().Name() == "Set" { + type_string = "set" + } else if t.Elem().Name() == "SortedSet" { + type_string = "zset" + } else { + type_string = t.Elem().Name() + } + default: + type_string = fmt.Sprintf("%T", value) + } + return []byte(fmt.Sprintf("+%v\r\n", type_string)), nil +} + func Commands() []internal.Command { return []internal.Command{ { @@ -1086,5 +1127,14 @@ Delete all the keys in the currently selected database. This command is always s KeyExtractionFunc: getExKeyFunc, HandlerFunc: handleGetex, }, + { + Command: "type", + Module: constants.GenericModule, + Categories: []string{constants.KeyspaceCategory, constants.ReadCategory, constants.FastCategory}, + Description: "(TYPE key) Returns the string representation of the type of the value stored at key. The different types that can be returned are: string, integer, float, list, set, zset, and hash.", + Sync: false, + KeyExtractionFunc: typeKeyFunc, + HandlerFunc: handleType, + }, } } diff --git a/internal/modules/generic/commands_test.go b/internal/modules/generic/commands_test.go index 8fb741d3..c17e1bd9 100644 --- a/internal/modules/generic/commands_test.go +++ b/internal/modules/generic/commands_test.go @@ -27,6 +27,8 @@ import ( "github.com/echovault/echovault/internal/clock" "github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/constants" + "github.com/echovault/echovault/internal/modules/set" + "github.com/echovault/echovault/internal/modules/sorted_set" "github.com/tidwall/resp" ) @@ -3147,4 +3149,186 @@ func Test_Generic(t *testing.T) { } }) + t.Run("Test_HandleType", func(t *testing.T) { + t.Parallel() + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + presetCommand string + command []string + expectedResponse string + expectedError error + }{ + { + name: "Test TYPE with preset string value", + key: "TypeTestString", + presetValue: "Hello", + command: []string{"TYPE", "TypeTestString"}, + expectedResponse: "string", + expectedError: nil, + }, + { + name: "Test TYPE with preset integer value", + key: "TypeTestInteger", + presetValue: 1, + command: []string{"TYPE", "TypeTestInteger"}, + expectedResponse: "integer", + expectedError: nil, + }, + { + name: "Test TYPE with preset float value", + key: "TypeTestFloat", + presetValue: 1.12, + command: []string{"TYPE", "TypeTestFloat"}, + expectedResponse: "float", + expectedError: nil, + }, + { + name: "Test TYPE with preset set value", + key: "TypeTestSet", + presetValue: set.NewSet([]string{"one", "two", "three", "four"}), + command: []string{"TYPE", "TypeTestSet"}, + expectedResponse: "set", + expectedError: nil, + }, + { + name: "Test TYPE with preset list value", + key: "TypeTestList", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"TYPE", "TypeTestList"}, + expectedResponse: "list", + expectedError: nil, + }, + { + name: "Test TYPE with preset list of integers value", + key: "TypeTestList2", + presetValue: []int{1, 2, 3, 4}, + command: []string{"TYPE", "TypeTestList2"}, + expectedResponse: "list", + expectedError: nil, + }, + { + name: "Test TYPE with preset zset of integers value", + key: "TypeTestZSet", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + command: []string{"TYPE", "TypeTestZSet"}, + expectedResponse: "zset", + expectedError: nil, + }, + { + name: "Test TYPE with preset hash of map[string]string", + key: "TypeTestHash", + presetValue: map[string]string{"field1": "value1"}, + command: []string{"TYPE", "TypeTestHash"}, + expectedResponse: "hash", + expectedError: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string, int, float64: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.AnyValue(test.presetValue), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) + case []int: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]int) { + command = append(command, []resp.Value{resp.IntegerValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]int))) + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) + } + }) + } + }) } diff --git a/internal/modules/generic/key_funcs.go b/internal/modules/generic/key_funcs.go index 0cc3893b..57876843 100644 --- a/internal/modules/generic/key_funcs.go +++ b/internal/modules/generic/key_funcs.go @@ -223,3 +223,14 @@ func getExKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { WriteKeys: cmd[1:2], }, nil } + +func typeKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { + if len(cmd) != 2 { + return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) + } + return internal.KeyExtractionFuncResult{ + Channels: make([]string, 0), + ReadKeys: cmd[1:], + WriteKeys: make([]string, 0), + }, nil +}