From bca91d71e519fabe8dd5595700175414a3ea69dc Mon Sep 17 00:00:00 2001 From: Anton Date: Fri, 19 Jan 2024 20:47:09 +0500 Subject: [PATCH] fix: prevent directory traversal attack when writing request files --- internal/engine/docker.go | 13 +++-- internal/engine/docker_test.go | 24 +++++++++ internal/engine/engine.go | 19 +++++++ internal/fileio/fileio.go | 29 +++++++++++ internal/fileio/fileio_test.go | 94 ++++++++++++++++++++++++++++++++++ 5 files changed, 176 insertions(+), 3 deletions(-) diff --git a/internal/engine/docker.go b/internal/engine/docker.go index 216e67b..cde048c 100644 --- a/internal/engine/docker.go +++ b/internal/engine/docker.go @@ -8,7 +8,6 @@ import ( "io" "os" "os/exec" - "path/filepath" "strconv" "strings" "time" @@ -54,7 +53,10 @@ func (e *Docker) Exec(req Request) Execution { if e.cmd.Entry != "" { // write request files to the temp directory err = e.writeFiles(dir, req.Files) - if err != nil { + var argErr ArgumentError + if errors.As(err, &argErr) { + return Fail(req.ID, err) + } else if err != nil { err = NewExecutionError("write files to temp dir", err) return Fail(req.ID, err) } @@ -171,7 +173,12 @@ func (e *Docker) writeFiles(dir string, files Files) error { if name == "" { name = e.cmd.Entry } - path := filepath.Join(dir, name) + var path string + path, err = fileio.JoinDir(dir, name) + if err != nil { + err = NewArgumentError(fmt.Sprintf("files[%s]", name), err) + return false + } err = fileio.WriteFile(path, content, 0444) return err == nil }) diff --git a/internal/engine/docker_test.go b/internal/engine/docker_test.go index 883f25a..d841045 100644 --- a/internal/engine/docker_test.go +++ b/internal/engine/docker_test.go @@ -1,6 +1,7 @@ package engine import ( + "fmt" "strings" "testing" @@ -231,6 +232,29 @@ func TestDockerRun(t *testing.T) { t.Errorf("Stderr: unexpected value: %s", out.Stderr) } }) + + t.Run("directory traversal attack", func(t *testing.T) { + mem.Clear() + const fileName = "../../opt/codapi/codapi" + engine := NewDocker(dockerCfg, "python", "run") + req := Request{ + ID: "http_42", + Sandbox: "python", + Command: "run", + Files: map[string]string{ + "": "print('hello world')", + fileName: "hehe", + }, + } + out := engine.Exec(req) + if out.OK { + t.Error("OK: expected false") + } + want := fmt.Sprintf("files[%s]: invalid name", fileName) + if out.Stderr != want { + t.Errorf("Stderr: unexpected value: %s", out.Stderr) + } + }) } func TestDockerExec(t *testing.T) { diff --git a/internal/engine/engine.go b/internal/engine/engine.go index e40eabb..15d05d2 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -62,6 +62,25 @@ func (err ExecutionError) Unwrap() error { return err.inner } +// An ArgumentError is returned if code execution failed +// due to the invalid value of the request agrument. +type ArgumentError struct { + name string + reason error +} + +func NewArgumentError(name string, reason error) ArgumentError { + return ArgumentError{name: name, reason: reason} +} + +func (err ArgumentError) Error() string { + return err.name + ": " + err.reason.Error() +} + +func (err ArgumentError) Unwrap() error { + return err.reason +} + // Files are a collection of files to be executed by the engine. type Files map[string]string diff --git a/internal/fileio/fileio.go b/internal/fileio/fileio.go index f04b09a..ada82b8 100644 --- a/internal/fileio/fileio.go +++ b/internal/fileio/fileio.go @@ -79,6 +79,35 @@ func WriteFile(path, content string, perm fs.FileMode) (err error) { return os.WriteFile(path, data, perm) } +// JoinDir joins a directory path with a relative file path, +// making sure that the resulting path is still inside the directory. +// Returns an error otherwise. +func JoinDir(dir string, name string) (string, error) { + if dir == "" { + return "", errors.New("invalid dir") + } + + cleanName := filepath.Clean(name) + if cleanName == "" { + return "", errors.New("invalid name") + } + if cleanName == "." || cleanName == "/" || filepath.IsAbs(cleanName) { + return "", errors.New("invalid name") + } + + path := filepath.Join(dir, cleanName) + + dirPrefix := filepath.Clean(dir) + if dirPrefix != "/" { + dirPrefix += string(os.PathSeparator) + } + if !strings.HasPrefix(path, dirPrefix) { + return "", errors.New("invalid name") + } + + return path, nil +} + // MkdirTemp creates a new temporary directory with given permissions // and returns the pathname of the new directory. func MkdirTemp(perm fs.FileMode) (string, error) { diff --git a/internal/fileio/fileio_test.go b/internal/fileio/fileio_test.go index bfef134..4c886de 100644 --- a/internal/fileio/fileio_test.go +++ b/internal/fileio/fileio_test.go @@ -160,6 +160,100 @@ func TestWriteFile(t *testing.T) { }) } +func TestJoinDir(t *testing.T) { + tests := []struct { + name string + dir string + filename string + want string + wantErr bool + }{ + { + name: "regular join", + dir: "/home/user", + filename: "docs/report.txt", + want: "/home/user/docs/report.txt", + wantErr: false, + }, + { + name: "join with dot", + dir: "/home/user", + filename: ".", + want: "", + wantErr: true, + }, + { + name: "join with absolute path", + dir: "/home/user", + filename: "/etc/passwd", + want: "", + wantErr: true, + }, + { + name: "join with parent directory", + dir: "/home/user", + filename: "../user2/docs/report.txt", + want: "", + wantErr: true, + }, + { + name: "empty directory", + dir: "", + filename: "report.txt", + want: "", + wantErr: true, + }, + { + name: "empty filename", + dir: "/home/user", + filename: "", + want: "", + wantErr: true, + }, + { + name: "directory with trailing slash", + dir: "/home/user/", + filename: "docs/report.txt", + want: "/home/user/docs/report.txt", + wantErr: false, + }, + { + name: "filename with leading slash", + dir: "/home/user", + filename: "/docs/report.txt", + want: "", + wantErr: true, + }, + { + name: "root directory", + dir: "/", + filename: "report.txt", + want: "/report.txt", + wantErr: false, + }, + { + name: "dot dot slash filename", + dir: "/home/user", + filename: "..", + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := JoinDir(tt.dir, tt.filename) + if (err != nil) != tt.wantErr { + t.Errorf("JoinDir() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("JoinDir() = %v, want %v", got, tt.want) + } + }) + } +} + func TestMkdirTemp(t *testing.T) { t.Run("default permissions", func(t *testing.T) { const perm = 0755