diff --git a/internal/server/static.go b/internal/server/static.go index 454bfd3..0445b6e 100644 --- a/internal/server/static.go +++ b/internal/server/static.go @@ -2,6 +2,8 @@ package server import ( "mime" + "net/http" + "net/url" "path/filepath" "strings" @@ -29,7 +31,27 @@ func (s *Server) setupStaticRoutes() { // serve assets with correct mime types s.echoInstance.GET("/assets/*", func(c echo.Context) error { - path := filepath.Join(staticDir, "assets", c.Param("*")) + // url decode and clean the path + requestedPath, err := url.QueryUnescape(c.Param("*")) + if err != nil { + return echo.NewHTTPError(http.StatusForbidden, "invalid path") + } + + requestedPath = filepath.Clean(requestedPath) + if strings.Contains(requestedPath, "..") { + return echo.NewHTTPError(http.StatusForbidden, "invalid path") + } + + // create a safe path within static directory + path := filepath.Join(staticDir, "assets", requestedPath) + + // verify the final path is still within the static directory + absStaticDir, _ := filepath.Abs(staticDir) + absPath, _ := filepath.Abs(path) + if !strings.HasPrefix(absPath, absStaticDir) { + return echo.NewHTTPError(http.StatusForbidden, "invalid path") + } + return c.File(path) }) diff --git a/internal/server/static_test.go b/internal/server/static_test.go new file mode 100644 index 0000000..731ef46 --- /dev/null +++ b/internal/server/static_test.go @@ -0,0 +1,69 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStaticFileServing(t *testing.T) { + // setup test directory structure + staticDir := "static/dist" + assetsDir := filepath.Join(staticDir, "assets", "img") + err := os.MkdirAll(assetsDir, 0755) + assert.NoError(t, err) + defer os.RemoveAll("static") // cleanup after test + + // create a test file + testFile := filepath.Join(assetsDir, "logo.png") + err = os.WriteFile(testFile, []byte("test content"), 0644) + assert.NoError(t, err) + + // create index.html + err = os.WriteFile(filepath.Join(staticDir, "index.html"), []byte("test"), 0644) + assert.NoError(t, err) + + // setup test server + s, err := New(true) + assert.NoError(t, err) + + tests := []struct { + name string + path string + expectedCode int + }{ + { + name: "normal asset path", + path: "/assets/img/logo.png", + expectedCode: http.StatusOK, + }, + { + name: "path traversal attempt 1", + path: "/assets/../../../etc/passwd", + expectedCode: http.StatusForbidden, + }, + { + name: "path traversal attempt 2", + path: "/assets/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd", + expectedCode: http.StatusForbidden, + }, + { + name: "path traversal attempt 3", + path: "/assets/..%2f..%2f..%2fetc%2fpasswd", + expectedCode: http.StatusForbidden, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tt.path, nil) + rec := httptest.NewRecorder() + s.echoInstance.ServeHTTP(rec, req) + assert.Equal(t, tt.expectedCode, rec.Code) + }) + } +}