-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
363 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package middleware | ||
|
||
import ( | ||
"fmt" | ||
"mime/multipart" | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/gabriel-vasile/mimetype" | ||
"github.com/labstack/echo/v4" | ||
) | ||
|
||
/* | ||
FileConfig holds configuration for the file validation middleware. | ||
MinSize and MaxSize are in bytes. | ||
Example: | ||
1MB = 1 * 1024 * 1024 bytes | ||
10MB = 10 * 1024 * 1024 bytes | ||
*/ | ||
type FileConfig struct { | ||
MinSize int64 | ||
MaxSize int64 | ||
AllowedTypes []string // ex. ["image/jpeg", "image/png", "application/pdf"] | ||
StrictValidation bool // If true, always verify content type matches header | ||
} | ||
|
||
/* | ||
FileCheck middleware ensures uploaded files meet specified criteria: | ||
- Size limits (via Content-Length header and actual file size) | ||
- MIME type validation | ||
Usage: | ||
e.POST("/upload", handler, middleware.FileCheck(middleware.FileConfig{ | ||
MinSize: 1024, // 1KB minimum | ||
MaxSize: 10485760, // 10MB maximum | ||
AllowedTypes: []string{ | ||
"image/jpeg", | ||
"image/png", | ||
"application/pdf", | ||
}, | ||
})) | ||
*/ | ||
func FileCheck(config FileConfig) echo.MiddlewareFunc { | ||
return func(next echo.HandlerFunc) echo.HandlerFunc { | ||
return func(c echo.Context) error { | ||
// only check on requests that might have file uploads | ||
if c.Request().Method != http.MethodPost && c.Request().Method != http.MethodPut { | ||
return next(c) | ||
} | ||
|
||
// first check content-length as early rejection | ||
contentLength := c.Request().ContentLength | ||
if contentLength == -1 { | ||
return echo.NewHTTPError(http.StatusBadRequest, "content length required") | ||
} | ||
|
||
if contentLength > config.MaxSize { | ||
return echo.NewHTTPError(http.StatusRequestEntityTooLarge, | ||
fmt.Sprintf("file size %d exceeds maximum allowed size of %d", contentLength, config.MaxSize)) | ||
} | ||
|
||
// parse multipart form with max size limit to prevent memory exhaustion | ||
if err := c.Request().ParseMultipartForm(config.MaxSize); err != nil { | ||
return echo.NewHTTPError(http.StatusRequestEntityTooLarge, "file too large") | ||
} | ||
|
||
// check actual file sizes and MIME types | ||
form := c.Request().MultipartForm | ||
if form != nil && form.File != nil { | ||
for _, files := range form.File { | ||
for _, file := range files { | ||
if err := validateFile(file, config); err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
} | ||
|
||
return next(c) | ||
} | ||
} | ||
} | ||
|
||
// validateFile checks both size and MIME type of a single file | ||
func validateFile(file *multipart.FileHeader, config FileConfig) error { | ||
// Check file size | ||
size := file.Size | ||
if size > config.MaxSize { | ||
return echo.NewHTTPError(http.StatusRequestEntityTooLarge, | ||
fmt.Sprintf("file %s size %d exceeds maximum allowed size of %d", file.Filename, size, config.MaxSize)) | ||
} | ||
if size < config.MinSize { | ||
return echo.NewHTTPError(http.StatusBadRequest, | ||
fmt.Sprintf("file %s size %d below minimum required size of %d", file.Filename, size, config.MinSize)) | ||
} | ||
|
||
// Check MIME type if restrictions are specified | ||
if len(config.AllowedTypes) > 0 { | ||
declaredType := file.Header.Get("Content-Type") | ||
declaredType = strings.Split(declaredType, ";")[0] // Remove parameters | ||
|
||
// If no Content-Type header or strict validation is enabled, check actual content | ||
if declaredType == "" || config.StrictValidation { | ||
f, err := file.Open() | ||
if err != nil { | ||
return echo.NewHTTPError(http.StatusBadRequest, "could not read file") | ||
} | ||
defer f.Close() | ||
|
||
mime, err := mimetype.DetectReader(f) | ||
if err != nil { | ||
return echo.NewHTTPError(http.StatusBadRequest, "could not detect file type") | ||
} | ||
|
||
actualType := mime.String() | ||
|
||
// If we have both types, verify they match (when strict validation is enabled) | ||
if declaredType != "" && config.StrictValidation && !strings.EqualFold(declaredType, actualType) { | ||
return echo.NewHTTPError(http.StatusBadRequest, | ||
fmt.Sprintf("declared Content-Type (%s) doesn't match actual content type (%s)", | ||
declaredType, actualType)) | ||
} | ||
|
||
// Use actual type if no declared type, otherwise use declared type | ||
if declaredType == "" { | ||
declaredType = actualType | ||
} | ||
} | ||
|
||
isAllowed := false | ||
for _, allowed := range config.AllowedTypes { | ||
if strings.EqualFold(declaredType, allowed) { | ||
isAllowed = true | ||
break | ||
} | ||
} | ||
|
||
if !isAllowed { | ||
return echo.NewHTTPError(http.StatusBadRequest, | ||
fmt.Sprintf("file type %s not allowed for %s. Allowed types: %v", | ||
declaredType, file.Filename, config.AllowedTypes)) | ||
} | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
package middleware | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"mime/multipart" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/textproto" | ||
"testing" | ||
|
||
"github.com/labstack/echo/v4" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestFileCheck(t *testing.T) { | ||
e := echo.New() | ||
|
||
handler := func(c echo.Context) error { | ||
return c.String(http.StatusOK, "success") | ||
} | ||
|
||
// helper to create multipart request with a file and optional content type | ||
createMultipartRequest := func(filename string, content []byte, contentType string) (*http.Request, error) { | ||
body := new(bytes.Buffer) | ||
writer := multipart.NewWriter(body) | ||
|
||
// Create form file with headers | ||
h := make(textproto.MIMEHeader) | ||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, "file", filename)) | ||
if contentType != "" { | ||
h.Set("Content-Type", contentType) | ||
} | ||
|
||
part, err := writer.CreatePart(h) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
part.Write(content) | ||
writer.Close() | ||
|
||
req := httptest.NewRequest(http.MethodPost, "/", body) | ||
req.Header.Set("Content-Type", writer.FormDataContentType()) | ||
req.ContentLength = int64(body.Len()) | ||
return req, nil | ||
} | ||
|
||
// Sample file contents with proper headers | ||
jpegHeader := []byte{ | ||
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, | ||
0x49, 0x46, 0x00, 0x01, | ||
} | ||
pngHeader := []byte{ | ||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, | ||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, | ||
} | ||
pdfHeader := []byte{ | ||
0x25, 0x50, 0x44, 0x46, 0x2D, 0x31, 0x2E, 0x34, | ||
0x0A, 0x25, 0xC7, 0xEC, 0x8F, 0xA2, 0x0A, | ||
} | ||
|
||
tests := []struct { | ||
name string | ||
config FileConfig | ||
filename string | ||
content []byte | ||
contentType string | ||
expectedStatus int | ||
expectedError string | ||
}{ | ||
{ | ||
name: "valid jpeg with matching content type", | ||
config: FileConfig{ | ||
MinSize: 3, | ||
MaxSize: 1024, | ||
AllowedTypes: []string{"image/jpeg"}, | ||
StrictValidation: true, | ||
}, | ||
filename: "test.jpg", | ||
content: append(jpegHeader, []byte("dummy content")...), | ||
contentType: "image/jpeg", | ||
expectedStatus: http.StatusOK, | ||
}, | ||
{ | ||
name: "valid png without content type header", | ||
config: FileConfig{ | ||
MinSize: 4, | ||
MaxSize: 1024, | ||
AllowedTypes: []string{"image/png"}, | ||
StrictValidation: false, | ||
}, | ||
filename: "test.png", | ||
content: append(pngHeader, []byte("dummy content")...), | ||
expectedStatus: http.StatusOK, | ||
}, | ||
{ | ||
name: "mismatched content type with strict validation", | ||
config: FileConfig{ | ||
MinSize: 5, | ||
MaxSize: 1024, | ||
AllowedTypes: []string{"image/jpeg", "image/png"}, | ||
StrictValidation: true, | ||
}, | ||
filename: "test.jpg", | ||
content: append(pngHeader, []byte("dummy content")...), | ||
contentType: "image/jpeg", | ||
expectedStatus: http.StatusBadRequest, | ||
expectedError: "doesn't match actual content type", | ||
}, | ||
{ | ||
name: "file too large", | ||
config: FileConfig{ | ||
MinSize: 5, | ||
MaxSize: 100, | ||
AllowedTypes: []string{"image/jpeg"}, | ||
}, | ||
filename: "large.jpg", | ||
content: append(jpegHeader, bytes.Repeat([]byte("a"), 150)...), | ||
contentType: "image/jpeg", | ||
expectedStatus: http.StatusRequestEntityTooLarge, | ||
expectedError: "file size", | ||
}, | ||
{ | ||
name: "file too small", | ||
config: FileConfig{ | ||
MinSize: 50, | ||
MaxSize: 1024, | ||
AllowedTypes: []string{"image/jpeg"}, | ||
}, | ||
filename: "small.jpg", | ||
content: append(jpegHeader, []byte("tiny")...), | ||
contentType: "image/jpeg", | ||
expectedStatus: http.StatusBadRequest, | ||
expectedError: "below minimum required size", | ||
}, | ||
{ | ||
name: "wrong mime type", | ||
config: FileConfig{ | ||
MinSize: 5, | ||
MaxSize: 1024, | ||
AllowedTypes: []string{"image/jpeg", "image/png"}, | ||
}, | ||
filename: "document.pdf", | ||
content: append(pdfHeader, []byte("dummy content")...), | ||
contentType: "application/pdf", | ||
expectedStatus: http.StatusBadRequest, | ||
expectedError: "file type", | ||
}, | ||
{ | ||
name: "multiple allowed types", | ||
config: FileConfig{ | ||
MinSize: 5, | ||
MaxSize: 1024, | ||
AllowedTypes: []string{"image/jpeg", "image/png", "application/pdf"}, | ||
}, | ||
filename: "document.pdf", | ||
content: append(pdfHeader, []byte("dummy content")...), | ||
contentType: "application/pdf", | ||
expectedStatus: http.StatusOK, | ||
}, | ||
{ | ||
name: "strict validation success", | ||
config: FileConfig{ | ||
MinSize: 5, | ||
MaxSize: 1024, | ||
AllowedTypes: []string{"application/pdf"}, | ||
StrictValidation: true, | ||
}, | ||
filename: "document.pdf", | ||
content: append(pdfHeader, []byte("dummy content")...), | ||
contentType: "application/pdf", | ||
expectedStatus: http.StatusOK, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
req, err := createMultipartRequest(tt.filename, tt.content, tt.contentType) | ||
assert.NoError(t, err) | ||
|
||
rec := httptest.NewRecorder() | ||
c := e.NewContext(req, rec) | ||
|
||
h := FileCheck(tt.config)(handler) | ||
err = h(c) | ||
|
||
if tt.expectedStatus != http.StatusOK { | ||
he, ok := err.(*echo.HTTPError) | ||
assert.True(t, ok) | ||
assert.Equal(t, tt.expectedStatus, he.Code) | ||
if tt.expectedError != "" { | ||
assert.Contains(t, he.Message, tt.expectedError) | ||
} | ||
} else { | ||
assert.NoError(t, err) | ||
} | ||
}) | ||
} | ||
|
||
// Test GET request (should skip check) | ||
t.Run("skip check for GET request", func(t *testing.T) { | ||
req := httptest.NewRequest(http.MethodGet, "/", nil) | ||
rec := httptest.NewRecorder() | ||
c := e.NewContext(req, rec) | ||
|
||
h := FileCheck(FileConfig{ | ||
MinSize: 5, | ||
MaxSize: 100, | ||
})(handler) | ||
|
||
err := h(c) | ||
assert.NoError(t, err) | ||
assert.Equal(t, http.StatusOK, rec.Code) | ||
}) | ||
} |