Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add file check middleware #269

Merged
merged 5 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions backend/internal/middleware/file_check.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package middleware

import (
"fmt"
"mime/multipart"
"net/http"
"strings"

"github.com/gabriel-vasile/mimetype"
"github.com/labstack/echo/v4"
)

/*
FileConfig holds configuration for the file validation middleware.
MinSize and MaxSize are in bytes.

Example:
1MB = 1 * 1024 * 1024 bytes
10MB = 10 * 1024 * 1024 bytes
*/
type FileConfig struct {
MinSize int64
MaxSize int64
AllowedTypes []string // ex. ["image/jpeg", "image/png", "application/pdf"]
StrictValidation bool // If true, always verify content type matches header
}

/*
FileCheck middleware ensures uploaded files meet specified criteria:
- Size limits (via Content-Length header and actual file size)
- MIME type validation

Usage:
e.POST("/upload", handler, middleware.FileCheck(middleware.FileConfig{
MinSize: 1024, // 1KB minimum
MaxSize: 10485760, // 10MB maximum
AllowedTypes: []string{
"image/jpeg",
"image/png",
"application/pdf",
},
}))
*/
func FileCheck(config FileConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// only check on requests that might have file uploads
if c.Request().Method != http.MethodPost && c.Request().Method != http.MethodPut {
return next(c)
}

// first check content-length as early rejection
contentLength := c.Request().ContentLength
if contentLength == -1 {
return echo.NewHTTPError(http.StatusBadRequest, "content length required")
}

if contentLength > config.MaxSize {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge,
fmt.Sprintf("file size %d exceeds maximum allowed size of %d", contentLength, config.MaxSize))
}

// parse multipart form with max size limit to prevent memory exhaustion
if err := c.Request().ParseMultipartForm(config.MaxSize); err != nil {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge, "file too large")
}

// check actual file sizes and MIME types
form := c.Request().MultipartForm
if form != nil && form.File != nil {
for _, files := range form.File {
for _, file := range files {
if err := validateFile(file, config); err != nil {
return err
}
}
}
}

return next(c)
}
}
}

// validateFile checks both size and MIME type of a single file
func validateFile(file *multipart.FileHeader, config FileConfig) error {
// Check file size
size := file.Size
if size > config.MaxSize {
return echo.NewHTTPError(http.StatusRequestEntityTooLarge,
fmt.Sprintf("file %s size %d exceeds maximum allowed size of %d", file.Filename, size, config.MaxSize))
}
if size < config.MinSize {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("file %s size %d below minimum required size of %d", file.Filename, size, config.MinSize))
}

// Check MIME type if restrictions are specified
if len(config.AllowedTypes) > 0 {
declaredType := file.Header.Get("Content-Type")
declaredType = strings.Split(declaredType, ";")[0] // Remove parameters

// If no Content-Type header or strict validation is enabled, check actual content
if declaredType == "" || config.StrictValidation {
f, err := file.Open()
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "could not read file")
}
defer f.Close()

mime, err := mimetype.DetectReader(f)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "could not detect file type")
}

actualType := mime.String()

// If we have both types, verify they match (when strict validation is enabled)
if declaredType != "" && config.StrictValidation && !strings.EqualFold(declaredType, actualType) {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("declared Content-Type (%s) doesn't match actual content type (%s)",
declaredType, actualType))
}

// Use actual type if no declared type, otherwise use declared type
if declaredType == "" {
declaredType = actualType
}
}

isAllowed := false
for _, allowed := range config.AllowedTypes {
if strings.EqualFold(declaredType, allowed) {
isAllowed = true
break
}
}

if !isAllowed {
return echo.NewHTTPError(http.StatusBadRequest,
fmt.Sprintf("file type %s not allowed for %s. Allowed types: %v",
declaredType, file.Filename, config.AllowedTypes))
}
}

return nil
}
216 changes: 216 additions & 0 deletions backend/internal/middleware/file_check_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
package middleware

import (
"bytes"
"fmt"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/textproto"
"testing"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func TestFileCheck(t *testing.T) {
e := echo.New()

handler := func(c echo.Context) error {
return c.String(http.StatusOK, "success")
}

// helper to create multipart request with a file and optional content type
createMultipartRequest := func(filename string, content []byte, contentType string) (*http.Request, error) {
body := new(bytes.Buffer)
writer := multipart.NewWriter(body)

// Create form file with headers
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, "file", filename))
if contentType != "" {
h.Set("Content-Type", contentType)
}

part, err := writer.CreatePart(h)
if err != nil {
return nil, err
}

part.Write(content)
writer.Close()

req := httptest.NewRequest(http.MethodPost, "/", body)
req.Header.Set("Content-Type", writer.FormDataContentType())
req.ContentLength = int64(body.Len())
return req, nil
}

// Sample file contents with proper headers
jpegHeader := []byte{
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46,
0x49, 0x46, 0x00, 0x01,
}
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
}
pdfHeader := []byte{
0x25, 0x50, 0x44, 0x46, 0x2D, 0x31, 0x2E, 0x34,
0x0A, 0x25, 0xC7, 0xEC, 0x8F, 0xA2, 0x0A,
}

tests := []struct {
name string
config FileConfig
filename string
content []byte
contentType string
expectedStatus int
expectedError string
}{
{
name: "valid jpeg with matching content type",
config: FileConfig{
MinSize: 3,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg"},
StrictValidation: true,
},
filename: "test.jpg",
content: append(jpegHeader, []byte("dummy content")...),
contentType: "image/jpeg",
expectedStatus: http.StatusOK,
},
{
name: "valid png without content type header",
config: FileConfig{
MinSize: 4,
MaxSize: 1024,
AllowedTypes: []string{"image/png"},
StrictValidation: false,
},
filename: "test.png",
content: append(pngHeader, []byte("dummy content")...),
expectedStatus: http.StatusOK,
},
{
name: "mismatched content type with strict validation",
config: FileConfig{
MinSize: 5,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg", "image/png"},
StrictValidation: true,
},
filename: "test.jpg",
content: append(pngHeader, []byte("dummy content")...),
contentType: "image/jpeg",
expectedStatus: http.StatusBadRequest,
expectedError: "doesn't match actual content type",
},
{
name: "file too large",
config: FileConfig{
MinSize: 5,
MaxSize: 100,
AllowedTypes: []string{"image/jpeg"},
},
filename: "large.jpg",
content: append(jpegHeader, bytes.Repeat([]byte("a"), 150)...),
contentType: "image/jpeg",
expectedStatus: http.StatusRequestEntityTooLarge,
expectedError: "file size",
},
{
name: "file too small",
config: FileConfig{
MinSize: 50,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg"},
},
filename: "small.jpg",
content: append(jpegHeader, []byte("tiny")...),
contentType: "image/jpeg",
expectedStatus: http.StatusBadRequest,
expectedError: "below minimum required size",
},
{
name: "wrong mime type",
config: FileConfig{
MinSize: 5,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg", "image/png"},
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusBadRequest,
expectedError: "file type",
},
{
name: "multiple allowed types",
config: FileConfig{
MinSize: 5,
MaxSize: 1024,
AllowedTypes: []string{"image/jpeg", "image/png", "application/pdf"},
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusOK,
},
{
name: "strict validation success",
config: FileConfig{
MinSize: 5,
MaxSize: 1024,
AllowedTypes: []string{"application/pdf"},
StrictValidation: true,
},
filename: "document.pdf",
content: append(pdfHeader, []byte("dummy content")...),
contentType: "application/pdf",
expectedStatus: http.StatusOK,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := createMultipartRequest(tt.filename, tt.content, tt.contentType)
assert.NoError(t, err)

rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

h := FileCheck(tt.config)(handler)
err = h(c)

if tt.expectedStatus != http.StatusOK {
he, ok := err.(*echo.HTTPError)
assert.True(t, ok)
assert.Equal(t, tt.expectedStatus, he.Code)
if tt.expectedError != "" {
assert.Contains(t, he.Message, tt.expectedError)
}
} else {
assert.NoError(t, err)
}
})
}

// Test GET request (should skip check)
t.Run("skip check for GET request", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

h := FileCheck(FileConfig{
MinSize: 5,
MaxSize: 100,
})(handler)

err := h(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, rec.Code)
})
}
Loading