Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add file check middleware #269

Merged
merged 5 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions backend/internal/middleware/upload.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package middleware

import (
"fmt"
"net/http"

"github.com/labstack/echo/v4"
)

/*
FileSizeConfig holds configuration for the file size check middleware.
MinSize and MaxSize are in bytes.

Example:
1MB = 1 * 1024 * 1024 bytes
10MB = 10 * 1024 * 1024 bytes
*/
type FileSizeConfig struct {
MinSize int64
MaxSize int64
}

/*
FileSizeCheck middleware ensures uploaded files are within specified size limits.
It checks the Content-Length header and returns 413 if file is too large
or 400 if file is too small.

Usage:
e.POST("/upload", handler, middleware.FileSizeCheck(middleware.FileSizeConfig{
MinSize: 1024, // 1KB minimum
MaxSize: 10485760, // 10MB maximum
}))
*/
func FileSizeCheck(config FileSizeConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// only check on requests that might have file uploads
if c.Request().Method != http.MethodPost && c.Request().Method != http.MethodPut {
return next(c)
}

// first check content-length as early rejection
contentLength := c.Request().ContentLength
if contentLength == -1 {
return echo.NewHTTPError(http.StatusBadRequest, "content length required")
}

if contentLength > config.MaxSize {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge,
fmt.Sprintf("file size %d exceeds maximum allowed size of %d", contentLength, config.MaxSize))
}

// parse multipart form with max size limit to prevent memory exhaustion
if err := c.Request().ParseMultipartForm(config.MaxSize); err != nil {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge, "file too large")
}

// check actual file sizes if the content-length check passed
// (i don't think it would ever happen, but clients can fake a content-length header)
form := c.Request().MultipartForm
if form != nil && form.File != nil {
for _, files := range form.File {
for _, file := range files {
size := file.Size
if size > config.MaxSize {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge,
fmt.Sprintf("file %s size %d exceeds maximum allowed size of %d", file.Filename, size, config.MaxSize))
}
if size < config.MinSize {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("file %s size %d below minimum required size of %d", file.Filename, size, config.MinSize))
}
}
}
}

return next(c)
}
}
}
115 changes: 115 additions & 0 deletions backend/internal/middleware/upload_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package middleware

import (
"bytes"
"mime/multipart"
"net/http"
"net/http/httptest"
"testing"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func TestFileSizeCheck(t *testing.T) {
e := echo.New()

handler := func(c echo.Context) error {
return c.String(http.StatusOK, "success")
}

// helper to create multipart request with a file
createMultipartRequest := func(filename string, content []byte) (*http.Request, *bytes.Buffer, error) {
body := new(bytes.Buffer)
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return nil, nil, err
}
part.Write(content)
writer.Close()

req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
req.ContentLength = int64(body.Len())
return req, body, nil
}

tests := []struct {
name string
config FileSizeConfig
fileSize int
expectedStatus int
}{
{
name: "valid file size",
config: FileSizeConfig{
MinSize: 5,
MaxSize: 1024, // increased to account for form overhead
},
fileSize: 50,
expectedStatus: http.StatusOK,
},
{
name: "file too large",
config: FileSizeConfig{
MinSize: 5,
MaxSize: 100,
},
fileSize: 150,
expectedStatus: http.StatusRequestEntityTooLarge,
},
{
name: "file too small",
config: FileSizeConfig{
MinSize: 50,
MaxSize: 1024,
},
fileSize: 10,
expectedStatus: http.StatusBadRequest,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// create a file with specified size
content := make([]byte, tt.fileSize)
req, body, err := createMultipartRequest("test.txt", content)
assert.NoError(t, err)

// log actual size for debugging
t.Logf("Total request size: %d, File content size: %d", body.Len(), tt.fileSize)

rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

// wrap handler with middleware
h := FileSizeCheck(tt.config)(handler)
err = h(c)

if tt.expectedStatus != http.StatusOK {
he, ok := err.(*echo.HTTPError)
assert.True(t, ok)
assert.Equal(t, tt.expectedStatus, he.Code)
} else {
assert.NoError(t, err)
}
})
}

// Test GET request (should skip check)
t.Run("skip check for GET request", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

h := FileSizeCheck(FileSizeConfig{
MinSize: 5,
MaxSize: 100,
})(handler)

err := h(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
})
}
Loading