diff --git a/internal/middleware/req_validator.go b/internal/middleware/req_validator.go new file mode 100644 index 0000000..feef71c --- /dev/null +++ b/internal/middleware/req_validator.go @@ -0,0 +1,54 @@ +package middleware + +import ( + "fmt" + "net/http" + "reflect" + + "github.com/go-playground/validator/v10" + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" +) + +// Struct solely exists to comply with Echo's interface to add a custom validator... +type RequestBodyValidator struct { + validator *validator.Validate +} + +func (rv *RequestBodyValidator) Validate(i interface{}) error { + log.Info().Msgf("Validating struct: %+v\n", i) + if err := rv.validator.Struct(i); err != nil { + log.Error().Err(err).Msg("Validation error") + return err + } + + return nil +} + +// Creates a new request validator that can be set to an Echo instance +// and used for validating request bodies with c.Validate() +func NewRequestBodyValidator() *RequestBodyValidator { + return &RequestBodyValidator{validator: validator.New()} +} + +// Middleware that validates the incoming request body with the given structType. +func ValidateRequestBody(structType reflect.Type) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + reqStruct := reflect.New(structType) + + if err := c.Bind(reqStruct.Interface()); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid request body: %v", err)) + } + + if err := c.Validate(reqStruct.Interface()); err != nil { + // this will let the global error handler handle + // the ValidationError and get error string for + // the each invalid field. + return err + } + + return next(c) + } + } +} diff --git a/internal/middleware/req_validator_test.go b/internal/middleware/req_validator_test.go new file mode 100644 index 0000000..9f232e2 --- /dev/null +++ b/internal/middleware/req_validator_test.go @@ -0,0 +1,103 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestRequestBodyValidator(t *testing.T) { + type testStruct struct { + TestField bool `json:"test_field" validate:"required"` + } + + e := echo.New() + e.Validator = NewRequestBodyValidator() + e.POST("/", handler, ValidateRequestBody(reflect.TypeOf(testStruct{}))) + + tests := []struct { + name string + payload interface{} + expectedCode int + }{ + { + name: "Valid request body", + payload: testStruct{ + TestField: true, + }, + expectedCode: http.StatusOK, + }, + { + name: "Invalid request body - validation error", + payload: testStruct{ + // will fail required validation + TestField: false, + }, + // expecting 500 since the middleware its expected to return + // the original ValidationErrors from validator pkg + expectedCode: http.StatusInternalServerError, + }, + { + name: "Empty request body", + payload: nil, + // expecting 500 since the middleware its expected to return + // the original ValidationErrors from validator pkg + expectedCode: http.StatusInternalServerError, + }, + { + name: "Invalid JSON format", + payload: `{ + "test_field": invalid + }`, + expectedCode: http.StatusBadRequest, + }, + { + name: "Wrong type in JSON", + payload: map[string]interface{}{ + "test_field": "not a boolean", + }, + expectedCode: http.StatusBadRequest, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var req *http.Request + + if tc.payload != nil { + var payload []byte + var err error + + // handle string payloads (for invalid JSON tests) + if strPayload, ok := tc.payload.(string); ok { + payload = []byte(strPayload) + } else { + payload, err = json.Marshal(tc.payload) + assert.NoError(t, err) + } + + req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(payload)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + } else { + req = httptest.NewRequest(http.MethodPost, "/", nil) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + } + + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedCode, rec.Code) + }) + } +} + +// test handler +func handler(c echo.Context) error { + return c.String(http.StatusOK, "pass") +} diff --git a/internal/server/handler_helpers.go b/internal/server/handler_helpers.go index 066e69e..62c13e4 100644 --- a/internal/server/handler_helpers.go +++ b/internal/server/handler_helpers.go @@ -14,7 +14,10 @@ func validateBody(c echo.Context, requestBodyType interface{}) error { } if err := c.Validate(requestBodyType); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + // this will let the global error handler handle + // the ValidationError and get error string for + // the each invalid field. + return err } return nil diff --git a/internal/server/index.go b/internal/server/index.go index 22e160f..b3a8a50 100644 --- a/internal/server/index.go +++ b/internal/server/index.go @@ -69,7 +69,7 @@ func New(testing bool) (*Server, error) { e.Use(echoMiddleware.Recover()) e.Use(apiLimiter.RateLimit()) // global rate limit - customValidator := NewCustomValidator() + customValidator := middleware.NewRequestBodyValidator() fmt.Printf("Initializing validator: %+v\n", customValidator) e.Validator = customValidator diff --git a/internal/server/types.go b/internal/server/types.go index 7f516ad..fcbed28 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -1,12 +1,7 @@ package server import ( - "fmt" - "net/http" "time" - - "github.com/go-playground/validator/v10" - "github.com/labstack/echo/v4" ) type DatabaseInfo struct { @@ -43,25 +38,6 @@ type CreateResourceRequestRequest struct { Status string `json:"status" validate:"required"` } -type CustomValidator struct { - validator *validator.Validate -} - -func NewCustomValidator() *CustomValidator { - v := validator.New() - return &CustomValidator{validator: v} -} - -func (cv *CustomValidator) Validate(i interface{}) error { - fmt.Printf("Validating struct: %+v\n", i) - if err := cv.validator.Struct(i); err != nil { - fmt.Printf("Validation error: %v\n", err) - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - return nil -} - type SignupRequest struct { Email string `json:"email" validate:"required,email"` Password string `json:"password" validate:"required,min=8"`