diff --git a/Dockerfile.dev b/Dockerfile.dev index 737d3982..ad1db84b 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -8,7 +8,7 @@ COPY . ./ ENV CGO_ENABLED=1 CC=gcc GOOS=linux GOARCH=amd64 ENV DEST=volumes/modules -RUN CGO_ENABLED=$CGO_ENABLED CC=$CC GOOS=$GOOS GOARCH=$GOARCH go build -buildmode=plugin -o $DEST/module_set/module_set.so ./internal/volumes/modules/module_set/module_set.go +RUN CGO_ENABLED=$CGO_ENABLED CC=$CC GOOS=$GOOS GOARCH=$GOARCH go build -buildmode=plugin -o $DEST/module_set/module_set.so ./internal/volumes/modules/module_set/module_set.go RUN CGO_ENABLED=$CGO_ENABLED CC=$CC GOOS=$GOOS GOARCH=$GOARCH go build -buildmode=plugin -o $DEST/module_get/module_get.so ./internal/volumes/modules/module_get/module_get.go ENV DEST=bin diff --git a/echovault/api_hash.go b/echovault/api_hash.go index a45d3e0a..aec63fef 100644 --- a/echovault/api_hash.go +++ b/echovault/api_hash.go @@ -115,6 +115,34 @@ func (server *EchoVault) HGet(key string, fields ...string) ([]string, error) { return internal.ParseStringArrayResponse(b) } +// HMGet retrieves the values corresponding to the provided fields. +// +// Parameters: +// +// `key` - string - the key to the hash map. +// +// `fields` - ...string - the list of fields to fetch. +// +// Returns: A string slice of the values corresponding to the fields in the same order the fields were provided. +// +// Errors: +// +// "value at is not a hash" - when the provided key does not exist or is not a hash. +func (server *EchoVault) HMGet(key string, fields ...string) ([]string, error) { + b, err := server.handleCommand( + server.context, + internal.EncodeCommand(append([]string{"HMGET", key}, fields...)), + nil, + false, + true, + ) + if err != nil { + return nil, err + } + + return internal.ParseStringArrayResponse(b) +} + // HStrLen returns the length of the values held at the specified fields of a hash map. // // Parameters: diff --git a/echovault/api_hash_test.go b/echovault/api_hash_test.go index 7f18e015..f23d0f33 100644 --- a/echovault/api_hash_test.go +++ b/echovault/api_hash_test.go @@ -826,3 +826,59 @@ func TestEchoVault_HGet(t *testing.T) { }) } } + +func TestEchoVault_HMGet(t *testing.T) { + server := createEchoVault() + tests := []struct { + name string + presetValue interface{} + key string + fields []string + want []string + wantErr bool + }{ + { + name: "1. Get values from existing hash.", + key: "HgetKey1", + presetValue: map[string]interface{}{"field1": "value1", "field2": 365, "field3": 3.142}, + fields: []string{"field1", "field2", "field3", "field4"}, + want: []string{"value1", "365", "3.142", ""}, + wantErr: false, + }, + { + name: "2. Return empty slice when attempting to get from non-existed key", + presetValue: nil, + key: "HgetKey2", + fields: []string{"field1"}, + want: []string{}, + wantErr: false, + }, + { + name: "3. Error when trying to get from a value that is not a hash map", + presetValue: "Default Value", + key: "HgetKey3", + fields: []string{"field1"}, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.presetValue != nil { + err := presetValue(server, context.Background(), tt.key, tt.presetValue) + if err != nil { + t.Error(err) + return + } + } + got, err := server.HGet(tt.key, tt.fields...) + if (err != nil) != tt.wantErr { + t.Errorf("HMGet() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("HMGet() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/modules/hash/commands.go b/internal/modules/hash/commands.go index 1c2178bd..f455eb30 100644 --- a/internal/modules/hash/commands.go +++ b/internal/modules/hash/commands.go @@ -138,6 +138,53 @@ func handleHGET(params internal.HandlerFuncParams) ([]byte, error) { return []byte(res), nil } +func handleHMGET(params internal.HandlerFuncParams) ([]byte, error) { + keys, err := hmgetKeyFunc(params.Command) + if err != nil { + return nil, err + } + + key := keys.ReadKeys[0] + keyExists := params.KeysExist(params.Context, keys.ReadKeys)[key] + if !keyExists { + return []byte("*0\r\n"), nil + } + + hash, ok := params.GetValues(params.Context, []string{key})[key].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("value at %s is not a hash", key) + } + + fields := params.Command[2:] + + var value interface{} + res := fmt.Sprintf("*%d\r\n", len(fields)) + for _, field := range fields { + value, ok = hash[field] + if !ok { + res += "$-1\r\n" + continue + } + + if s, ok := value.(string); ok { + res += fmt.Sprintf("$%d\r\n%s\r\n", len(s), s) + continue + } + if d, ok := value.(int); ok { + res += fmt.Sprintf(":%d\r\n", d) + continue + } + if f, ok := value.(float64); ok { + fs := strconv.FormatFloat(f, 'f', -1, 64) + res += fmt.Sprintf("$%d\r\n%s\r\n", len(fs), fs) + continue + } + res += fmt.Sprintf("$-1\r\n") + + } + return []byte(res), nil +} + func handleHSTRLEN(params internal.HandlerFuncParams) ([]byte, error) { keys, err := hstrlenKeyFunc(params.Command) if err != nil { @@ -594,6 +641,16 @@ Retrieve the value of each of the listed fields from the hash.`, KeyExtractionFunc: hgetKeyFunc, HandlerFunc: handleHGET, }, + { + Command: "hmget", + Module: constants.HashModule, + Categories: []string{constants.HashCategory, constants.ReadCategory, constants.FastCategory}, + Description: `(HMGET key field [field ...]) +Retrieve the value of each of the listed fields from the hash.`, + Sync: false, + KeyExtractionFunc: hmgetKeyFunc, + HandlerFunc: handleHMGET, + }, { Command: "hstrlen", Module: constants.HashModule, diff --git a/internal/modules/hash/commands_test.go b/internal/modules/hash/commands_test.go index 0cc27f64..d9238834 100644 --- a/internal/modules/hash/commands_test.go +++ b/internal/modules/hash/commands_test.go @@ -597,6 +597,163 @@ func Test_Hash(t *testing.T) { } }) + t.Run("Test_HandleHMGET", func(t *testing.T) { + t.Parallel() + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []string // Change count + expectedValue map[string]string + expectedError error + }{ + { + name: "1. Get values from existing hash.", + key: "HmgetKey1", + presetValue: map[string]string{"field1": "value1", "field2": "365", "field3": "3.142"}, + command: []string{"HMGET", "HmgetKey1", "field1", "field2", "field3", "field4"}, + expectedResponse: []string{"value1", "365", "3.142", ""}, + expectedValue: map[string]string{"field1": "value1", "field2": "365", "field3": "3.142"}, + expectedError: nil, + }, + { + name: "2. Return nil when attempting to get from non-existed key", + key: "HmgetKey2", + presetValue: nil, + command: []string{"HMGET", "HmgetKey2", "field1"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: nil, + }, + { + name: "3. Error when trying to get from a value that is not a hash map", + key: "HmgetKey3", + presetValue: "Default Value", + command: []string{"HMGET", "HmgetKey3", "field1"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New("value at HgetKey3 is not a hash"), + }, + { + name: "4. Command too short", + key: "HmgetKey4", + presetValue: nil, + command: []string{"HMGET", "HmgetKey4"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if test.expectedResponse == nil { + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) + } + return + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) + } + } + } + }) + } + }) + t.Run("Test_HandleHSTRLEN", func(t *testing.T) { t.Parallel() conn, err := internal.GetConnection("localhost", port) diff --git a/internal/modules/hash/key_funcs.go b/internal/modules/hash/key_funcs.go index e259c208..6d9f17df 100644 --- a/internal/modules/hash/key_funcs.go +++ b/internal/modules/hash/key_funcs.go @@ -53,6 +53,17 @@ func hgetKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { }, nil } +func hmgetKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { + if len(cmd) < 3 { + return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) + } + return internal.KeyExtractionFuncResult{ + Channels: make([]string, 0), + ReadKeys: cmd[1:2], + WriteKeys: make([]string, 0), + }, nil +} + func hstrlenKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { if len(cmd) < 3 { return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse)