diff --git a/echovault/api_generic.go b/echovault/api_generic.go index dac2754b..0f280d28 100644 --- a/echovault/api_generic.go +++ b/echovault/api_generic.go @@ -485,6 +485,29 @@ func (server *EchoVault) IncrBy(key string, value string) (int, error) { return internal.ParseIntegerResponse(b) } +// IncrByFloat increments the floating-point value of the specified key by the given increment. +// If the key does not exist, it is created with an initial value of 0 before incrementing. +// If the value stored at the key is not a float, an error is returned. +// +// Parameters: +// +// `key` - string - The key whose value is to be incremented. +// +// `increment` - float64 - The amount by which to increment the key's value. This can be a positive or negative float. +// +// Returns: The new value of the key after the increment operation as a float64. +func (server *EchoVault) IncrByFloat(key string, value string) (float64, error) { + // Construct the command + cmd := []string{"INCRBYFLOAT", key, value} + // Execute the command + b, err := server.handleCommand(server.context, internal.EncodeCommand(cmd), nil, false, true) + if err != nil { + return 0, err + } + // Parse the float response + return internal.ParseFloatResponse(b) +} + // DecrBy decrements the integer value of the specified key by the given increment. // If the key does not exist, it is created with an initial value of 0 before decrementing. // If the value stored at the key is not an integer, an error is returned. diff --git a/echovault/api_generic_test.go b/echovault/api_generic_test.go index c0ffd3b2..ed3f934d 100644 --- a/echovault/api_generic_test.go +++ b/echovault/api_generic_test.go @@ -1117,6 +1117,75 @@ func TestEchoVault_INCRBY(t *testing.T) { } } +func TestEchoVault_INCRBYFLOAT(t *testing.T) { + server := createEchoVault() + + tests := []struct { + name string + key string + increment string + presetValues map[string]internal.KeyData + want float64 + wantErr bool + }{ + { + name: "1. Increment non-existent key by 2.5", + key: "IncrByFloatKey1", + increment: "2.5", + presetValues: nil, + want: 2.5, + wantErr: false, + }, + { + name: "2. Increment existing key with integer value by 1.2", + key: "IncrByFloatKey2", + increment: "1.2", + presetValues: map[string]internal.KeyData{ + "IncrByFloatKey2": {Value: "5"}, + }, + want: 6.2, + wantErr: false, + }, + { + name: "3. Increment existing key with float value by 0.7", + key: "IncrByFloatKey4", + increment: "0.7", + presetValues: map[string]internal.KeyData{ + "IncrByFloatKey4": {Value: "10.0"}, + }, + want: 10.7, + wantErr: false, + }, + { + name: "4. Increment existing key with scientific notation value by 200", + key: "IncrByFloatKey5", + increment: "200", + presetValues: map[string]internal.KeyData{ + "IncrByFloatKey5": {Value: "5.0e3"}, + }, + want: 5200, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.presetValues != nil { + for k, d := range tt.presetValues { + presetKeyData(server, context.Background(), k, d) + } + } + got, err := server.IncrByFloat(tt.key, tt.increment) + if (err != nil) != tt.wantErr { + t.Errorf("IncrByFloat() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil && got != tt.want { + t.Errorf("IncrByFloat() got = %v, want %v", got, tt.want) + } + }) + } +} + func TestEchoVault_DECRBY(t *testing.T) { server := createEchoVault() diff --git a/internal/modules/generic/commands.go b/internal/modules/generic/commands.go index bd0c253b..f28d9f2b 100644 --- a/internal/modules/generic/commands.go +++ b/internal/modules/generic/commands.go @@ -529,6 +529,64 @@ func handleIncrBy(params internal.HandlerFuncParams) ([]byte, error) { return []byte(fmt.Sprintf(":%d\r\n", newValue)), nil } +func handleIncrByFloat(params internal.HandlerFuncParams) ([]byte, error) { + // Extract key from command + keys, err := incrByFloatKeyFunc(params.Command) + if err != nil { + return nil, err + } + + // Parse increment value + incrValue, err := strconv.ParseFloat(params.Command[2], 64) + if err != nil { + return nil, errors.New("increment value is not a float or out of range") + } + + key := keys.WriteKeys[0] + values := params.GetValues(params.Context, []string{key}) // Get the current values for the specified keys + currentValue, ok := values[key] // Check if the key exists + + var newValue float64 + var currentValueFloat float64 + + // Check if the key exists and its current value + if !ok || currentValue == nil { + // If key does not exist, initialize it with the increment value + newValue = incrValue + } else { + // Use type switch to handle different types of currentValue + switch v := currentValue.(type) { + case string: + currentValueFloat, err = strconv.ParseFloat(v, 64) // Parse the string to float64 + if err != nil { + currentValueInt, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return nil, errors.New("value is not a float or integer") + } + currentValueFloat = float64(currentValueInt) + } + case float64: + currentValueFloat = v // Use float64 value directly + case int64: + currentValueFloat = float64(v) // Convert int64 to float64 + case int: + currentValueFloat = float64(v) // Convert int to float64 + default: + fmt.Printf("unexpected type for currentValue: %T\n", currentValue) + return nil, errors.New("unexpected type for currentValue") // Handle unexpected types + } + newValue = currentValueFloat + incrValue // Increment the value by the specified amount + } + + // Set the new incremented value + if err := params.SetValues(params.Context, map[string]interface{}{key: fmt.Sprintf("%g", newValue)}); err != nil { + return nil, err + } + + // Prepare response with the actual new value in bulk string format + response := fmt.Sprintf("$%d\r\n%g\r\n", len(fmt.Sprintf("%g", newValue)), newValue) + return []byte(response), nil +} func handleDecrBy(params internal.HandlerFuncParams) ([]byte, error) { // Extract key from command keys, err := decrByKeyFunc(params.Command) @@ -829,6 +887,17 @@ An error is returned if the key contains a value of the wrong type or contains a KeyExtractionFunc: incrByKeyFunc, HandlerFunc: handleIncrBy, }, + { + Command: "incrbyfloat", + Module: constants.GenericModule, + Categories: []string{constants.WriteCategory, constants.FastCategory}, + Description: `(INCRBYFLOAT key increment) +Increments the number stored at key by increment. If the key does not exist, it is set to 0 before performing the operation. +An error is returned if the key contains a value of the wrong type or contains a string that cannot be represented as float.`, + Sync: true, + KeyExtractionFunc: incrByFloatKeyFunc, + HandlerFunc: handleIncrByFloat, + }, { Command: "decrby", Module: constants.GenericModule, @@ -845,7 +914,7 @@ If the key's value is not of the correct type or cannot be represented as an int Command: "rename", Module: constants.GenericModule, Categories: []string{constants.KeyspaceCategory, constants.WriteCategory, constants.FastCategory}, - Description: `(RENAME key newkey) + Description: `(RENAME key newkey) Renames key to newkey. If newkey already exists, it is overwritten. If key does not exist, an error is returned.`, Sync: true, KeyExtractionFunc: renameKeyFunc, @@ -878,7 +947,7 @@ Renames key to newkey. If newkey already exists, it is overwritten. If key does constants.SlowCategory, constants.DangerousCategory, }, - Description: `(FLUSHDB) + Description: `(FLUSHDB) Delete all the keys in the currently selected database. This command is always synchronous.`, Sync: true, KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) { diff --git a/internal/modules/generic/commands_test.go b/internal/modules/generic/commands_test.go index e28fb0c6..aa903a9e 100644 --- a/internal/modules/generic/commands_test.go +++ b/internal/modules/generic/commands_test.go @@ -17,16 +17,17 @@ package generic_test import ( "errors" "fmt" + "strconv" + "strings" + "testing" + "time" + "github.com/echovault/echovault/echovault" "github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal/clock" "github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/constants" "github.com/tidwall/resp" - "strconv" - "strings" - "testing" - "time" ) type KeyData struct { @@ -2268,6 +2269,126 @@ func Test_Generic(t *testing.T) { } }) + t.Run("Test_HandlerINCRBYFLOAT", func(t *testing.T) { + t.Parallel() + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + increment string + presetValue interface{} + command []resp.Value + expectedResponse float64 + expectedError error + }{ + { + name: "1. Increment non-existent key by 2.5", + key: "IncrByFloatKey1", + increment: "2.5", + presetValue: nil, + command: []resp.Value{resp.StringValue("INCRBYFLOAT"), resp.StringValue("IncrByFloatKey1"), resp.StringValue("2.5")}, + expectedResponse: 2.5, + expectedError: nil, + }, + { + name: "2. Increment existing key with integer value by 1.2", + key: "IncrByFloatKey2", + increment: "1.2", + presetValue: "5", + command: []resp.Value{resp.StringValue("INCRBYFLOAT"), resp.StringValue("IncrByFloatKey2"), resp.StringValue("1.2")}, + expectedResponse: 6.2, + expectedError: nil, + }, + { + name: "3. Increment existing key with float value by 0.7", + key: "IncrByFloatKey4", + increment: "0.7", + presetValue: "10.0", + command: []resp.Value{resp.StringValue("INCRBYFLOAT"), resp.StringValue("IncrByFloatKey4"), resp.StringValue("0.7")}, + expectedResponse: 10.7, + expectedError: nil, + }, + { + name: "4. Command too short", + key: "IncrByFloatKey5", + increment: "5", + presetValue: nil, + command: []resp.Value{resp.StringValue("INCRBYFLOAT"), resp.StringValue("IncrByFloatKey5")}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Command too long", + key: "IncrByFloatKey6", + increment: "5", + presetValue: nil, + command: []resp.Value{ + resp.StringValue("INCRBYFLOAT"), + resp.StringValue("IncrByFloatKey6"), + resp.StringValue("5"), + resp.StringValue("extra_arg"), + }, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(test.key), resp.StringValue(fmt.Sprintf("%v", test.presetValue))} + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(res.String(), "OK") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } + + if err = client.WriteArray(test.command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error()) + } + return + } + + if err != nil { + t.Error(err) + } else { + responseFloat, err := strconv.ParseFloat(res.String(), 64) + if err != nil { + t.Errorf("error parsing response to float64: %s", err) + } + if responseFloat != test.expectedResponse { + t.Errorf("expected response %f, got %f", test.expectedResponse, responseFloat) + } + } + }) + } + }) + t.Run("Test_HandlerDECRBY", func(t *testing.T) { t.Parallel() conn, err := internal.GetConnection("localhost", port) diff --git a/internal/modules/generic/key_funcs.go b/internal/modules/generic/key_funcs.go index 150d0ea6..1fea0108 100644 --- a/internal/modules/generic/key_funcs.go +++ b/internal/modules/generic/key_funcs.go @@ -164,6 +164,15 @@ func incrByKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { }, nil } +func incrByFloatKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { + if len(cmd) != 3 { + return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) + } + return internal.KeyExtractionFuncResult{ + WriteKeys: []string{cmd[1]}, + }, nil +} + func decrByKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { if len(cmd) != 3 { return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse)