From c7f492f83f169a1a75170539010882bb330ee6d8 Mon Sep 17 00:00:00 2001 From: Tejesh Kumar S <81950114+zenc0derr@users.noreply.github.com> Date: Thu, 24 Oct 2024 23:05:19 +0530 Subject: [PATCH] Implementation of Copy command (#141) * Added COPY command - @zenc0derr --------- Co-authored-by: Tejesh Kumar S Co-authored-by: Kelvin Clement Mwinuka --- internal/modules/generic/commands.go | 57 ++++++++ internal/modules/generic/commands_test.go | 164 +++++++++++++++++++++- internal/modules/generic/key_funcs.go | 12 ++ internal/modules/generic/utils.go | 34 +++++ sugardb/api_generic.go | 37 +++++ sugardb/api_generic_test.go | 106 +++++++++++++- sugardb/test_helpers.go | 12 ++ 7 files changed, 420 insertions(+), 2 deletions(-) diff --git a/internal/modules/generic/commands.go b/internal/modules/generic/commands.go index b42f41a4..9f70a6ac 100644 --- a/internal/modules/generic/commands.go +++ b/internal/modules/generic/commands.go @@ -595,6 +595,7 @@ func handleIncrByFloat(params internal.HandlerFuncParams) ([]byte, error) { response := fmt.Sprintf("$%d\r\n%g\r\n", len(fmt.Sprintf("%g", newValue)), newValue) return []byte(response), nil } + func handleDecrBy(params internal.HandlerFuncParams) ([]byte, error) { // Extract key from command keys, err := decrByKeyFunc(params.Command) @@ -870,6 +871,50 @@ func handleObjIdleTime(params internal.HandlerFuncParams) ([]byte, error) { return []byte(fmt.Sprintf("+%v\r\n", idletime)), nil } +func handleCopy(params internal.HandlerFuncParams) ([]byte, error) { + keys, err := copyKeyFunc(params.Command) + if err != nil { + return nil, err + } + + options, err := getCopyCommandOptions(params.Command[3:], CopyOptions{}) + if err != nil { + return nil, err + } + sourceKey := keys.ReadKeys[0] + destinationKey := keys.WriteKeys[0] + sourceKeyExists := params.KeysExist(params.Context, []string{sourceKey})[sourceKey] + + if !sourceKeyExists { + return []byte(":0\r\n"), nil + } + + if !options.replace { + destinationKeyExists := params.KeysExist(params.Context, []string{destinationKey})[destinationKey] + + if destinationKeyExists { + return []byte(":0\r\n"), nil + } + } + + value := params.GetValues(params.Context, []string{sourceKey})[sourceKey] + + ctx := context.WithoutCancel(params.Context) + + if options.database != "" { + database, _ := strconv.Atoi(options.database) + ctx = context.WithValue(ctx, "Database", database) + } + + if err = params.SetValues(ctx, map[string]interface{}{ + destinationKey: value, + }); err != nil { + return nil, err + } + + return []byte(":1\r\n"), nil +} + func handleMove(params internal.HandlerFuncParams) ([]byte, error) { keys, err := moveKeyFunc(params.Command) if err != nil { @@ -1255,6 +1300,18 @@ The command is only available when the maxmemory-policy configuration directive KeyExtractionFunc: objIdleTimeKeyFunc, HandlerFunc: handleObjIdleTime, }, + { + Command: "copy", + Module: constants.GenericModule, + Categories: []string{constants.KeyspaceCategory, constants.WriteCategory, constants.SlowCategory}, + Description: `(COPY source destination [DB destination-db] [REPLACE]) +Copies the value stored at the source key to the destination key. +The command returns zero when the destination key already exists. +The REPLACE option removes the destination key before copying the value to it.`, + Sync: false, + KeyExtractionFunc: copyKeyFunc, + HandlerFunc: handleCopy, + }, { Command: "move", Module: constants.GenericModule, diff --git a/internal/modules/generic/commands_test.go b/internal/modules/generic/commands_test.go index 8106f6f4..059cb0c9 100644 --- a/internal/modules/generic/commands_test.go +++ b/internal/modules/generic/commands_test.go @@ -3333,6 +3333,168 @@ func Test_Generic(t *testing.T) { } }) + t.Run("Test_HandleCOPY", func(t *testing.T) { + t.Parallel() + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + sourceKeyPresetValue interface{} + sourcekey string + destKeyPresetValue interface{} + destinationKey string + database string + replace bool + expectedValue string + expectedResponse string + }{ + { + name: "1. Copy Value into non existing key", + sourceKeyPresetValue: "value1", + sourcekey: "skey1", + destKeyPresetValue: nil, + destinationKey: "dkey1", + database: "0", + replace: false, + expectedValue: "value1", + expectedResponse: "1", + }, + { + name: "2. Copy Value into existing key without replace option", + sourceKeyPresetValue: "value2", + sourcekey: "skey2", + destKeyPresetValue: "dValue2", + destinationKey: "dkey2", + database: "0", + replace: false, + expectedValue: "dValue2", + expectedResponse: "0", + }, + { + name: "3. Copy Value into existing key with replace option", + sourceKeyPresetValue: "value3", + sourcekey: "skey3", + destKeyPresetValue: "dValue3", + destinationKey: "dkey3", + database: "0", + replace: true, + expectedValue: "value3", + expectedResponse: "1", + }, + { + name: "4. Copy Value into different database", + sourceKeyPresetValue: "value4", + sourcekey: "skey4", + destKeyPresetValue: nil, + destinationKey: "dkey4", + database: "1", + replace: true, + expectedValue: "value4", + expectedResponse: "1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sourceKeyPresetValue != nil { + cmd := []resp.Value{resp.StringValue("Set"), resp.StringValue(tt.sourcekey), resp.StringValue(tt.sourceKeyPresetValue.(string))} + + err := client.WriteArray(cmd) + if err != nil { + t.Error(err) + } + + rd, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(rd.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got %s", rd.String()) + } + } + + if tt.destKeyPresetValue != nil { + cmd := []resp.Value{resp.StringValue("Set"), resp.StringValue(tt.destinationKey), resp.StringValue(tt.destKeyPresetValue.(string))} + + err := client.WriteArray(cmd) + if err != nil { + t.Error(err) + } + + rd, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(rd.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got %s", rd.String()) + } + } + + command := []resp.Value{resp.StringValue("COPY"), resp.StringValue(tt.sourcekey), resp.StringValue(tt.destinationKey)} + + if tt.database != "0" { + command = append(command, resp.StringValue("DB"), resp.StringValue(tt.database)) + } + + if tt.replace { + command = append(command, resp.StringValue("REPLACE")) + } + + err := client.WriteArray(command) + if err != nil { + t.Error(err) + } + + rd, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(rd.String(), tt.expectedResponse) { + t.Errorf("expected response to be %s, but got %s", tt.expectedResponse, rd.String()) + } + + if tt.database != "0" { + selectCommand := []resp.Value{resp.StringValue("SELECT"), resp.StringValue(tt.database)} + + err := client.WriteArray(selectCommand) + if err != nil { + t.Error(err) + } + _, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + } + + getCommand := []resp.Value{resp.StringValue("GET"), resp.StringValue(tt.destinationKey)} + + err = client.WriteArray(getCommand) + if err != nil { + t.Error(err) + } + + rd, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(rd.String(), tt.expectedValue) { + t.Errorf("expected value in destinaton key to be %s, but got %s", tt.expectedValue, rd.String()) + } + }) + } + + }) + t.Run("Test_HandleMOVE", func(t *testing.T) { t.Parallel() @@ -3474,7 +3636,7 @@ func Test_Generic(t *testing.T) { } // Certain commands will need to be tested in a server with an eviction policy. -// This is for testing against an LFU evictiona policy. +// This is for testing against an LFU eviction policy. func Test_LFU_Generic(t *testing.T) { // mockClock := clock.NewClock() port, err := internal.GetFreePort() diff --git a/internal/modules/generic/key_funcs.go b/internal/modules/generic/key_funcs.go index 9e53a6f1..ae94a13c 100644 --- a/internal/modules/generic/key_funcs.go +++ b/internal/modules/generic/key_funcs.go @@ -268,6 +268,18 @@ func objIdleTimeKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) }, nil } +func copyKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { + if len(cmd) < 3 && len(cmd)>6{ + return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) + } + + return internal.KeyExtractionFuncResult{ + Channels: make([]string, 0), + ReadKeys: cmd[1:2], + WriteKeys: cmd[2:3], + }, nil +} + func moveKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { if len(cmd) != 3 { return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) diff --git a/internal/modules/generic/utils.go b/internal/modules/generic/utils.go index 184b36c0..3c0add3c 100644 --- a/internal/modules/generic/utils.go +++ b/internal/modules/generic/utils.go @@ -29,6 +29,11 @@ type SetOptions struct { expireAt interface{} // Exact expireAt time un unix milliseconds } +type CopyOptions struct { + database string + replace bool +} + func getSetCommandOptions(clock clock.Clock, cmd []string, options SetOptions) (SetOptions, error) { if len(cmd) == 0 { return options, nil @@ -116,3 +121,32 @@ func getSetCommandOptions(clock clock.Clock, cmd []string, options SetOptions) ( return SetOptions{}, fmt.Errorf("unknown option %s for set command", strings.ToUpper(cmd[0])) } } + +func getCopyCommandOptions(cmd []string, options CopyOptions) (CopyOptions, error) { + if len(cmd) == 0 { + return options, nil + } + + switch strings.ToLower(cmd[0]){ + case "replace": + options.replace = true + return getCopyCommandOptions(cmd[1:], options) + + case "db": + if len(cmd) < 2 { + return CopyOptions{}, errors.New("syntax error") + } + + _, err := strconv.Atoi(cmd[1]) + if err != nil { + return CopyOptions{}, errors.New("value is not an integer or out of range") + } + + options.database = cmd [1] + return getCopyCommandOptions(cmd[2:], options) + + + default: + return CopyOptions{}, fmt.Errorf("unknown option %s for copy command", strings.ToUpper(cmd[0])) + } +} \ No newline at end of file diff --git a/sugardb/api_generic.go b/sugardb/api_generic.go index 738bb7c7..ccbc38cf 100644 --- a/sugardb/api_generic.go +++ b/sugardb/api_generic.go @@ -137,6 +137,16 @@ type GetExOption interface { func (x GetExOpt) isGetExOpt() GetExOpt { return x } +// COPYOptions is a struct wrapper for all optional parameters of the Copy command. +// +// `Database` - string - Logical database index +// +// `Replace` - bool - Whether to replace the destination key if it exists +type COPYOptions struct { + Database string + Replace bool +} + // Set creates or modifies the value at the given key. // // Parameters: @@ -719,6 +729,33 @@ func (server *SugarDB) Type(key string) (string, error) { return internal.ParseStringResponse(b) } +// Copy copies a value of a source key to destination key. +// +// Parameters: +// +// `source` - string - the source key from which data is to be copied +// +// `destination` - string - the destination key where data should be copied +// +// Returns: 1 if the copy is successful. 0 if the copy is unsuccessful +func (server *SugarDB) Copy(sourceKey, destinationKey string, options COPYOptions) (int, error) { + cmd := []string{"COPY", sourceKey, destinationKey} + + if options.Database != "" { + cmd = append(cmd, "db", options.Database) + } + + if options.Replace { + cmd = append(cmd, "replace") + } + + b, err := server.handleCommand(server.context, internal.EncodeCommand(cmd), nil, false, true) + if err != nil { + return 0, err + } + return internal.ParseIntegerResponse(b) +} + // Move key from currently selected database to specified destination database and return 1. // When key already exists in the destination database, or it does not exist in the source database, it does nothing and returns 0. // diff --git a/sugardb/api_generic_test.go b/sugardb/api_generic_test.go index 1ccd96b4..0c708820 100644 --- a/sugardb/api_generic_test.go +++ b/sugardb/api_generic_test.go @@ -1826,6 +1826,111 @@ func TestSugarDB_TYPE(t *testing.T) { } } +func TestSugarDB_COPY(t *testing.T) { + server := createSugarDB() + + CopyOptions := func(DB string, R bool) COPYOptions { + return COPYOptions{ + Database: DB, + Replace: R, + } + } + + tests := []struct { + name string + sourceKeyPresetValue interface{} + sourcekey string + destKeyPresetValue interface{} + destinationKey string + options COPYOptions + expectedValue string + want int + wantErr bool + }{ + { + name: "Copy Value into non existing key", + sourceKeyPresetValue: "value1", + sourcekey: "skey1", + destKeyPresetValue: nil, + destinationKey: "dkey1", + options: CopyOptions("0", false), + expectedValue: "value1", + want: 1, + wantErr: false, + }, + { + name: "Copy Value into existing key without replace option", + sourceKeyPresetValue: "value2", + sourcekey: "skey2", + destKeyPresetValue: "dValue2", + destinationKey: "dkey2", + options: CopyOptions("0", false), + expectedValue: "dValue2", + want: 0, + wantErr: false, + }, + { + name: "Copy Value into existing key with replace option", + sourceKeyPresetValue: "value3", + sourcekey: "skey3", + destKeyPresetValue: "dValue3", + destinationKey: "dkey3", + options: CopyOptions("0", true), + expectedValue: "value3", + want: 1, + wantErr: false, + }, + { + name: "Copy Value into different database", + sourceKeyPresetValue: "value4", + sourcekey: "skey4", + destKeyPresetValue: nil, + destinationKey: "dkey4", + options: CopyOptions("1", false), + expectedValue: "value4", + want: 1, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sourceKeyPresetValue != nil { + err := presetValue(server, context.Background(), tt.sourcekey, tt.sourceKeyPresetValue) + if err != nil { + t.Error(err) + return + } + } + if tt.destKeyPresetValue != nil { + err := presetValue(server, context.Background(), tt.destinationKey, tt.destKeyPresetValue) + if err != nil { + t.Error(err) + return + } + } + + got, err := server.Copy(tt.sourcekey, tt.destinationKey, tt.options) + if (err != nil) != tt.wantErr { + t.Errorf("COPY() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("COPY() got = %v, want %v", got, tt.want) + } + + val, err := getValue(server, context.Background(), tt.destinationKey, tt.options.Database) + if err != nil { + t.Error(err) + return + } + + if val != tt.expectedValue { + t.Errorf("COPY() value in destionation key: %v, should be: %v", val, tt.expectedValue) + } + }) + } +} + func TestSugarDB_MOVE(t *testing.T) { server := createSugarDB() @@ -1867,7 +1972,6 @@ func TestSugarDB_MOVE(t *testing.T) { if got != tt.want { t.Errorf("MOVE() got %v, want %v", got, tt.want) } - }) } } diff --git a/sugardb/test_helpers.go b/sugardb/test_helpers.go index 052dbde2..054c41f5 100644 --- a/sugardb/test_helpers.go +++ b/sugardb/test_helpers.go @@ -2,6 +2,8 @@ package sugardb import ( "context" + "strconv" + "github.com/echovault/sugardb/internal" "github.com/echovault/sugardb/internal/config" "github.com/echovault/sugardb/internal/constants" @@ -37,3 +39,13 @@ func presetKeyData(server *SugarDB, ctx context.Context, key string, data intern _ = server.setValues(ctx, map[string]interface{}{key: data.Value}) server.setExpiry(ctx, key, data.ExpireAt, false) } + +func getValue (server *SugarDB, ctx context.Context, key string, database string) (interface{}, error) { + db, err := strconv.Atoi(database) + if err != nil { + return nil, err + } + ctx = context.WithValue(ctx, "Database", db) + + return server.getValues(ctx, []string{key})[key], err +} \ No newline at end of file