Skip to content

Commit

Permalink
Merge pull request #72 from KonferCA/71-be-patch-directory-traversal-…
Browse files Browse the repository at this point in the history
…in-static-file-serving

Fix/71/prevent-directory-traversal
  • Loading branch information
AmirAgassi authored Nov 18, 2024
2 parents 525efb4 + 5838b2f commit feea0dd
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 1 deletion.
24 changes: 23 additions & 1 deletion internal/server/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package server

import (
"mime"
"net/http"
"net/url"
"path/filepath"
"strings"

Expand Down Expand Up @@ -29,7 +31,27 @@ func (s *Server) setupStaticRoutes() {

// serve assets with correct mime types
s.echoInstance.GET("/assets/*", func(c echo.Context) error {
path := filepath.Join(staticDir, "assets", c.Param("*"))
// url decode and clean the path
requestedPath, err := url.QueryUnescape(c.Param("*"))
if err != nil {
return echo.NewHTTPError(http.StatusForbidden, "invalid path")
}

requestedPath = filepath.Clean(requestedPath)
if strings.Contains(requestedPath, "..") {
return echo.NewHTTPError(http.StatusForbidden, "invalid path")
}

// create a safe path within static directory
path := filepath.Join(staticDir, "assets", requestedPath)

// verify the final path is still within the static directory
absStaticDir, _ := filepath.Abs(staticDir)
absPath, _ := filepath.Abs(path)
if !strings.HasPrefix(absPath, absStaticDir) {
return echo.NewHTTPError(http.StatusForbidden, "invalid path")
}

return c.File(path)
})

Expand Down
69 changes: 69 additions & 0 deletions internal/server/static_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package server

import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func TestStaticFileServing(t *testing.T) {
// setup test directory structure
staticDir := "static/dist"
assetsDir := filepath.Join(staticDir, "assets", "img")
err := os.MkdirAll(assetsDir, 0755)
assert.NoError(t, err)
defer os.RemoveAll("static") // cleanup after test

// create a test file
testFile := filepath.Join(assetsDir, "logo.png")
err = os.WriteFile(testFile, []byte("test content"), 0644)
assert.NoError(t, err)

// create index.html
err = os.WriteFile(filepath.Join(staticDir, "index.html"), []byte("<html>test</html>"), 0644)
assert.NoError(t, err)

// setup test server
s, err := New(true)
assert.NoError(t, err)

tests := []struct {
name string
path string
expectedCode int
}{
{
name: "normal asset path",
path: "/assets/img/logo.png",
expectedCode: http.StatusOK,
},
{
name: "path traversal attempt 1",
path: "/assets/../../../etc/passwd",
expectedCode: http.StatusForbidden,
},
{
name: "path traversal attempt 2",
path: "/assets/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd",
expectedCode: http.StatusForbidden,
},
{
name: "path traversal attempt 3",
path: "/assets/..%2f..%2f..%2fetc%2fpasswd",
expectedCode: http.StatusForbidden,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, tt.path, nil)
rec := httptest.NewRecorder()
s.echoInstance.ServeHTTP(rec, req)
assert.Equal(t, tt.expectedCode, rec.Code)
})
}
}

0 comments on commit feea0dd

Please sign in to comment.