Skip to content

Commit

Permalink
Add file check middleware (#269)
Browse files Browse the repository at this point in the history
  • Loading branch information
AmirAgassi authored Dec 18, 2024
2 parents 61d9792 + 65b4652 commit 1ef1cbe
Show file tree
Hide file tree
Showing 2 changed files with 363 additions and 0 deletions.
147 changes: 147 additions & 0 deletions backend/internal/middleware/file_check.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package middleware

import (
"fmt"
"mime/multipart"
"net/http"
"strings"

"github.com/gabriel-vasile/mimetype"
"github.com/labstack/echo/v4"
)

/*
FileConfig holds configuration for the file validation middleware.
MinSize and MaxSize are in bytes.
Example:
1MB = 1 * 1024 * 1024 bytes
10MB = 10 * 1024 * 1024 bytes
*/
type FileConfig struct {
MinSize int64
MaxSize int64
AllowedTypes []string // ex. ["image/jpeg", "image/png", "application/pdf"]
StrictValidation bool // If true, always verify content type matches header
}

/*
FileCheck middleware ensures uploaded files meet specified criteria:
- Size limits (via Content-Length header and actual file size)
- MIME type validation
Usage:
e.POST("/upload", handler, middleware.FileCheck(middleware.FileConfig{
MinSize: 1024, // 1KB minimum
MaxSize: 10485760, // 10MB maximum
AllowedTypes: []string{
"image/jpeg",
"image/png",
"application/pdf",
},
}))
*/
func FileCheck(config FileConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// only check on requests that might have file uploads
if c.Request().Method != http.MethodPost && c.Request().Method != http.MethodPut {
return next(c)
}

// first check content-length as early rejection
contentLength := c.Request().ContentLength
if contentLength == -1 {
return echo.NewHTTPError(http.StatusBadRequest, "content length required")
}

if contentLength > config.MaxSize {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge,
fmt.Sprintf("file size %d exceeds maximum allowed size of %d", contentLength, config.MaxSize))
}

// parse multipart form with max size limit to prevent memory exhaustion
if err := c.Request().ParseMultipartForm(config.MaxSize); err != nil {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge, "file too large")
}

// check actual file sizes and MIME types
form := c.Request().MultipartForm
if form != nil && form.File != nil {
for _, files := range form.File {
for _, file := range files {
if err := validateFile(file, config); err != nil {
return err
}
}
}
}

return next(c)
}
}
}

// validateFile checks both size and MIME type of a single file
func validateFile(file *multipart.FileHeader, config FileConfig) error {
// Check file size
size := file.Size
if size > config.MaxSize {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge,
fmt.Sprintf("file %s size %d exceeds maximum allowed size of %d", file.Filename, size, config.MaxSize))
}
if size < config.MinSize {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("file %s size %d below minimum required size of %d", file.Filename, size, config.MinSize))
}

// Check MIME type if restrictions are specified
if len(config.AllowedTypes) > 0 {
declaredType := file.Header.Get("Content-Type")
declaredType = strings.Split(declaredType, ";")[0] // Remove parameters

// If no Content-Type header or strict validation is enabled, check actual content
if declaredType == "" || config.StrictValidation {
f, err := file.Open()
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "could not read file")
}
defer f.Close()

mime, err := mimetype.DetectReader(f)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "could not detect file type")
}

actualType := mime.String()

// If we have both types, verify they match (when strict validation is enabled)
if declaredType != "" && config.StrictValidation && !strings.EqualFold(declaredType, actualType) {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("declared Content-Type (%s) doesn't match actual content type (%s)",
declaredType, actualType))
}

// Use actual type if no declared type, otherwise use declared type
if declaredType == "" {
declaredType = actualType
}
}

isAllowed := false
for _, allowed := range config.AllowedTypes {
if strings.EqualFold(declaredType, allowed) {
isAllowed = true
break
}
}

if !isAllowed {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("file type %s not allowed for %s. Allowed types: %v",
declaredType, file.Filename, config.AllowedTypes))
}
}

return nil
}
216 changes: 216 additions & 0 deletions backend/internal/middleware/file_check_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
package middleware

import (
"bytes"
"fmt"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/textproto"
"testing"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func TestFileCheck(t *testing.T) {
e := echo.New()

handler := func(c echo.Context) error {
return c.String(http.StatusOK, "success")
}

// helper to create multipart request with a file and optional content type
createMultipartRequest := func(filename string, content []byte, contentType string) (*http.Request, error) {
body := new(bytes.Buffer)
writer := multipart.NewWriter(body)

// Create form file with headers
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, "file", filename))
if contentType != "" {
h.Set("Content-Type", contentType)
}

part, err := writer.CreatePart(h)
if err != nil {
return nil, err
}

part.Write(content)
writer.Close()

req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
req.ContentLength = int64(body.Len())
return req, nil
}

// Sample file contents with proper headers
jpegHeader := []byte{
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46,
0x49, 0x46, 0x00, 0x01,
}
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
}
pdfHeader := []byte{
0x25, 0x50, 0x44, 0x46, 0x2D, 0x31, 0x2E, 0x34,
0x0A, 0x25, 0xC7, 0xEC, 0x8F, 0xA2, 0x0A,
}

tests := []struct {
name string
config FileConfig
filename string
content []byte
contentType string
expectedStatus int
expectedError string
}{
{
name: "valid jpeg with matching content type",
config: FileConfig{
MinSize: 3,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg"},
StrictValidation: true,
},
filename: "test.jpg",
content: append(jpegHeader, []byte("dummy content")...),
contentType: "image/jpeg",
expectedStatus: http.StatusOK,
},
{
name: "valid png without content type header",
config: FileConfig{
MinSize: 4,
MaxSize: 1024,
AllowedTypes: []string{"image/png"},
StrictValidation: false,
},
filename: "test.png",
content: append(pngHeader, []byte("dummy content")...),
expectedStatus: http.StatusOK,
},
{
name: "mismatched content type with strict validation",
config: FileConfig{
MinSize: 5,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg", "image/png"},
StrictValidation: true,
},
filename: "test.jpg",
content: append(pngHeader, []byte("dummy content")...),
contentType: "image/jpeg",
expectedStatus: http.StatusBadRequest,
expectedError: "doesn't match actual content type",
},
{
name: "file too large",
config: FileConfig{
MinSize: 5,
MaxSize: 100,
AllowedTypes: []string{"image/jpeg"},
},
filename: "large.jpg",
content: append(jpegHeader, bytes.Repeat([]byte("a"), 150)...),
contentType: "image/jpeg",
expectedStatus: http.StatusRequestEntityTooLarge,
expectedError: "file size",
},
{
name: "file too small",
config: FileConfig{
MinSize: 50,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg"},
},
filename: "small.jpg",
content: append(jpegHeader, []byte("tiny")...),
contentType: "image/jpeg",
expectedStatus: http.StatusBadRequest,
expectedError: "below minimum required size",
},
{
name: "wrong mime type",
config: FileConfig{
MinSize: 5,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusBadRequest,
expectedError: "file type",
},
{
name: "multiple allowed types",
config: FileConfig{
MinSize: 5,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg", "image/png", "application/pdf"},
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusOK,
},
{
name: "strict validation success",
config: FileConfig{
MinSize: 5,
MaxSize: 1024,
AllowedTypes: []string{"application/pdf"},
StrictValidation: true,
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusOK,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := createMultipartRequest(tt.filename, tt.content, tt.contentType)
assert.NoError(t, err)

rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

h := FileCheck(tt.config)(handler)
err = h(c)

if tt.expectedStatus != http.StatusOK {
he, ok := err.(*echo.HTTPError)
assert.True(t, ok)
assert.Equal(t, tt.expectedStatus, he.Code)
if tt.expectedError != "" {
assert.Contains(t, he.Message, tt.expectedError)
}
} else {
assert.NoError(t, err)
}
})
}

// Test GET request (should skip check)
t.Run("skip check for GET request", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

h := FileCheck(FileConfig{
MinSize: 5,
MaxSize: 100,
})(handler)

err := h(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
})
}

0 comments on commit 1ef1cbe

Please sign in to comment.