Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore/module tests #13

Merged
merged 12 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
bin
volumes
/config/
/coverage/

dist/
src/modules/*/aof
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ run:
make build && docker-compose up --build

test:
go clean -testcache && go test ./src/... -coverprofile coverage/cover.out
go clean -testcache && go test ./src/... -coverprofile coverage/coverage.out
2,807 changes: 2,807 additions & 0 deletions coverage/coverage.out

Large diffs are not rendered by default.

58 changes: 46 additions & 12 deletions src/modules/acl/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@
"reflect"
"slices"
"strings"
"sync"
"time"
)

type Connection struct {
Authenticated bool
User *User
Authenticated bool // Whether the connection has been authenticated
User *User // The user the connection is associated with
}

type ACL struct {
Users []*User
Connections map[*net.Conn]Connection
Config utils.Config
Users []*User // List of ACL user profiles
UsersMutex sync.RWMutex // RWMutex for concurrency control when accessing ACL profile list
Connections map[*net.Conn]Connection // Connections to the server that are currently registered with the ACL module
Config utils.Config // Server configuration that contains the relevant ACL config options
GlobPatterns map[string]glob.Glob
}

Expand Down Expand Up @@ -93,6 +95,7 @@

acl := ACL{
Users: users,
UsersMutex: sync.RWMutex{},
Connections: make(map[*net.Conn]Connection),
Config: config,
GlobPatterns: make(map[string]glob.Glob),
Expand All @@ -104,6 +107,9 @@
}

func (acl *ACL) RegisterConnection(conn *net.Conn) {
acl.LockUsers()
defer acl.UnlockUsers()

// This is called only when a connection is established.
defaultUserIdx := slices.IndexFunc(acl.Users, func(user *User) bool {
return user.Username == "default"
Expand All @@ -115,7 +121,10 @@
}
}

func (acl *ACL) SetUser(ctx context.Context, cmd []string) error {
func (acl *ACL) SetUser(cmd []string) error {
acl.LockUsers()
defer acl.UnlockUsers()

// Check if user with the given username already exists
// If it does, replace user variable with this user
for _, user := range acl.Users {
Expand Down Expand Up @@ -144,7 +153,10 @@
return nil
}

func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error {
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
acl.LockUsers()
defer acl.UnlockUsers()

var user *User
for _, username := range usernames {
if username == "default" {
Expand All @@ -164,18 +176,21 @@
// Terminate every connection attached to this user
for connRef, connection := range acl.Connections {
if connection.User.Username == user.Username {
(*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))
_ = (*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second))

Check warning on line 179 in src/modules/acl/acl.go

View check run for this annotation

Codecov / codecov/patch

src/modules/acl/acl.go#L179

Added line #L179 was not covered by tests
}
}
// Delete the user from the ACL
acl.Users = slices.DeleteFunc(acl.Users, func(u *User) bool {
return u.Username != user.Username
return u.Username == user.Username
})
}
return nil
}

func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd []string) error {
func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error {
acl.RLockUsers()
defer acl.RUnlockUsers()

var passwords []Password
var user *User

Expand All @@ -194,6 +209,7 @@
})
user = acl.Users[idx]
}

if len(cmd) == 3 {
// Process AUTH <username> <password>
h.Write([]byte(cmd[2]))
Expand Down Expand Up @@ -248,6 +264,9 @@
}

func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.Command, subCommand utils.SubCommand) error {
acl.RLockUsers()
defer acl.RUnlockUsers()

// Extract command, categories, and keys
comm := command.Command
categories := command.Categories
Expand Down Expand Up @@ -278,7 +297,6 @@

// If the command is 'auth', then return early and allow it
if strings.EqualFold(comm, "auth") {
// TODO: Add rate limiting to prevent auth spamming
return nil
}

Expand Down Expand Up @@ -339,7 +357,7 @@
return fmt.Errorf("not authorised to run %s command", comm)
}

// 6. PUBSUB authorisation comes first because it has slightly different handling.
// 6. PUBSUB authorisation.
if slices.Contains(categories, utils.PubSubCategory) {
// In PUBSUB, KeyExtractionFunc returns channels so keys[0] is aliased to channel
channel := keys[0]
Expand Down Expand Up @@ -421,3 +439,19 @@
}
}
}

func (acl *ACL) LockUsers() {
acl.UsersMutex.Lock()
}

func (acl *ACL) UnlockUsers() {
acl.UsersMutex.Unlock()
}

func (acl *ACL) RLockUsers() {
acl.UsersMutex.RLock()
}

func (acl *ACL) RUnlockUsers() {
acl.UsersMutex.RUnlock()
}
65 changes: 36 additions & 29 deletions src/modules/acl/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"fmt"
"github.com/echovault/echovault/src/utils"
"gopkg.in/yaml.v3"
"log"
"net"
"os"
"path"
Expand All @@ -28,7 +29,7 @@
return []byte(utils.OkResponse), nil
}

func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleGetUser(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) != 3 {
return nil, errors.New(utils.WrongArgsResponse)
}
Expand Down Expand Up @@ -110,22 +111,23 @@

// keys
allKeys := user.IncludedReadKeys
for _, key := range user.IncludedWriteKeys {
for _, key := range append(user.IncludedWriteKeys, user.IncludedReadKeys...) {
if !slices.Contains(allKeys, key) {
allKeys = append(allKeys, key)
}
}
res = res + fmt.Sprintf("\r\n+keys\r\n*%d", len(allKeys))
for _, key := range user.IncludedReadKeys {
if slices.Contains(user.IncludedWriteKeys, key) {
for _, key := range allKeys {
switch {
case slices.Contains(user.IncludedWriteKeys, key) && slices.Contains(user.IncludedReadKeys, key):
// Key is RW
res = res + fmt.Sprintf("\r\n+%s~%s", "%RW", key)
continue
}
res = res + fmt.Sprintf("\r\n+%s~%s", "%R", key)
}
for _, key := range user.IncludedWriteKeys {
if !slices.Contains(user.IncludedReadKeys, key) {
case slices.Contains(user.IncludedWriteKeys, key):
// Keys is W-Only
res = res + fmt.Sprintf("\r\n+%s~%s", "%W", key)
case slices.Contains(user.IncludedReadKeys, key):
// Key is R-Only
res = res + fmt.Sprintf("\r\n+%s~%s", "%R", key)
}
}

Expand All @@ -144,7 +146,7 @@
return []byte(res), nil
}

func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 3 {
return nil, errors.New(utils.WrongArgsResponse)
}
Expand Down Expand Up @@ -201,10 +203,10 @@
}
}

return nil, errors.New("category not found")
return nil, fmt.Errorf("category %s not found", strings.ToUpper(cmd[2]))
}

func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
Expand All @@ -217,18 +219,18 @@
return []byte(res), nil
}

func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleSetUser(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
if err := acl.SetUser(ctx, cmd[2:]); err != nil {
if err := acl.SetUser(cmd[2:]); err != nil {
return nil, err
}
return []byte(utils.OkResponse), nil
}

func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleDelUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) < 3 {
return nil, errors.New(utils.WrongArgsResponse)
}
Expand All @@ -242,7 +244,7 @@
return []byte(utils.OkResponse), nil
}

func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleWhoAmI(_ context.Context, _ []string, server utils.Server, conn *net.Conn) ([]byte, error) {
acl, ok := server.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
Expand All @@ -251,7 +253,7 @@
return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil
}

func handleList(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleList(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if len(cmd) > 2 {
return nil, errors.New(utils.WrongArgsResponse)
}
Expand Down Expand Up @@ -347,7 +349,7 @@
return []byte(res), nil
}

func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleLoad(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {

Check warning on line 352 in src/modules/acl/commands.go

View check run for this annotation

Codecov / codecov/patch

src/modules/acl/commands.go#L352

Added line #L352 was not covered by tests
if len(cmd) != 3 {
return nil, errors.New(utils.WrongArgsResponse)
}
Expand All @@ -357,15 +359,17 @@
return nil, errors.New("could not load ACL")
}

acl.LockUsers()
defer acl.RUnlockUsers()

Check warning on line 363 in src/modules/acl/commands.go

View check run for this annotation

Codecov / codecov/patch

src/modules/acl/commands.go#L362-L363

Added lines #L362 - L363 were not covered by tests

f, err := os.Open(acl.Config.AclConfig)
if err != nil {
return nil, err
}

defer func() {
if err := f.Close(); err != nil {
// TODO: Log file close error with context
fmt.Println(err)
log.Println(err)

Check warning on line 372 in src/modules/acl/commands.go

View check run for this annotation

Codecov / codecov/patch

src/modules/acl/commands.go#L372

Added line #L372 was not covered by tests
}
}()

Expand Down Expand Up @@ -412,7 +416,7 @@
return []byte(utils.OkResponse), nil
}

func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
func handleSave(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {

Check warning on line 419 in src/modules/acl/commands.go

View check run for this annotation

Codecov / codecov/patch

src/modules/acl/commands.go#L419

Added line #L419 was not covered by tests
if len(cmd) > 2 {
return nil, errors.New(utils.WrongArgsResponse)
}
Expand All @@ -422,15 +426,17 @@
return nil, errors.New("could not load ACL")
}

acl.RLockUsers()
acl.RUnlockUsers()

Check warning on line 430 in src/modules/acl/commands.go

View check run for this annotation

Codecov / codecov/patch

src/modules/acl/commands.go#L429-L430

Added lines #L429 - L430 were not covered by tests

f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend)
if err != nil {
return nil, err
}

defer func() {
if err := f.Close(); err != nil {
// TODO: Log file close error with context
fmt.Println(err)
log.Println(err)

Check warning on line 439 in src/modules/acl/commands.go

View check run for this annotation

Codecov / codecov/patch

src/modules/acl/commands.go#L439

Added line #L439 was not covered by tests
}
}()

Expand Down Expand Up @@ -490,10 +496,11 @@
},
SubCommands: []utils.SubCommand{
{
Command: "cat",
Categories: []string{utils.SlowCategory},
Description: "(ACL CAT [category]) List all the categories and commands inside a category.",
Sync: false,
Command: "cat",
Categories: []string{utils.SlowCategory},
Description: `(ACL CAT [category]) List all the categories.
If the optional category is provided, list all the commands in the category`,
Sync: false,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
},
Expand Down Expand Up @@ -565,7 +572,7 @@
Description: `
(ACL LOAD <MERGE | REPLACE>) Reloads the rules from the configured ACL config file.
When 'MERGE' is passed, users from config file who share a username with users in memory will be merged.
When 'REPLACED' is passed, users from config file who share a username with users in memory will replace the user in memory.`,
When 'REPLACE' is passed, users from config file who share a username with users in memory will replace the user in memory.`,
Sync: true,
KeyExtractionFunc: func(cmd []string) ([]string, error) {
return []string{}, nil
Expand Down
Loading
Loading