From 338e434bf37d278ef8b295323ebbdee94be0ba89 Mon Sep 17 00:00:00 2001 From: AmirAgassi <33383085+AmirAgassi@users.noreply.github.com> Date: Thu, 26 Dec 2024 23:39:06 -0500 Subject: [PATCH] Fix param validation support for min/max length and regex patterns --- backend/.sqlc/queries/projects.sql | 7 +- backend/db/projects.sql.go | 21 ++++++ backend/internal/tests/projects_test.go | 75 +++++++++++++++++-- backend/internal/v1/v1_projects/projects.go | 21 ++++++ backend/internal/v1/v1_projects/validation.go | 62 +++++++++++++-- 5 files changed, 173 insertions(+), 13 deletions(-) diff --git a/backend/.sqlc/queries/projects.sql b/backend/.sqlc/queries/projects.sql index 33b160f1..887464bf 100644 --- a/backend/.sqlc/queries/projects.sql +++ b/backend/.sqlc/queries/projects.sql @@ -118,4 +118,9 @@ UPDATE projects SET status = $1, updated_at = extract(epoch from now()) -WHERE id = $2; \ No newline at end of file +WHERE id = $2; + +-- name: GetQuestionByAnswerID :one +SELECT q.* FROM project_questions q +JOIN project_answers a ON a.question_id = q.id +WHERE a.id = $1; \ No newline at end of file diff --git a/backend/db/projects.sql.go b/backend/db/projects.sql.go index bd828107..571cc5b9 100644 --- a/backend/db/projects.sql.go +++ b/backend/db/projects.sql.go @@ -395,6 +395,27 @@ func (q *Queries) GetProjectsByCompanyID(ctx context.Context, companyID string) return items, nil } +const getQuestionByAnswerID = `-- name: GetQuestionByAnswerID :one +SELECT q.id, q.question, q.section, q.required, q.validations, q.created_at, q.updated_at FROM project_questions q +JOIN project_answers a ON a.question_id = q.id +WHERE a.id = $1 +` + +func (q *Queries) GetQuestionByAnswerID(ctx context.Context, id string) (ProjectQuestion, error) { + row := q.db.QueryRow(ctx, getQuestionByAnswerID, id) + var i ProjectQuestion + err := row.Scan( + &i.ID, + &i.Question, + &i.Section, + &i.Required, + &i.Validations, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const listCompanyProjects = `-- name: ListCompanyProjects :many SELECT projects.id, projects.company_id, projects.title, projects.description, projects.status, projects.created_at, projects.updated_at FROM projects WHERE company_id = $1 diff --git a/backend/internal/tests/projects_test.go b/backend/internal/tests/projects_test.go index d2e62851..d9a7b9da 100644 --- a/backend/internal/tests/projects_test.go +++ b/backend/internal/tests/projects_test.go @@ -14,7 +14,6 @@ import ( "KonferCA/SPUR/db" "KonferCA/SPUR/internal/server" - "KonferCA/SPUR/internal/v1/v1_common" "github.com/stretchr/testify/assert" ) @@ -171,9 +170,9 @@ func TestProjectEndpoints(t *testing.T) { case "Company website": answer = "https://example.com" case "What is the core product or service, and what problem does it solve?": - answer = "Our product is a blockchain-based authentication system that solves identity verification issues." + answer = "Our product is a revolutionary blockchain-based authentication system that solves critical identity verification issues in the digital age. We provide a secure, scalable solution that eliminates fraud while maintaining user privacy and compliance with international regulations. Our system uses advanced cryptography and distributed ledger technology to ensure tamper-proof identity verification." case "What is the unique value proposition?": - answer = "We provide secure, decentralized identity verification that's faster and more reliable than traditional methods." + answer = "We provide secure, decentralized identity verification that's faster and more reliable than traditional methods. Our solution reduces verification time by 90% while increasing security and reducing costs for businesses. We are the only solution that combines biometric verification with blockchain immutability at scale." } // Patch the answer @@ -220,6 +219,39 @@ func TestProjectEndpoints(t *testing.T) { // Error cases t.Run("Error Cases", func(t *testing.T) { + // First get the questions/answers to get real IDs + path := fmt.Sprintf("/api/v1/project/%s/answers", projectID) + req := httptest.NewRequest(http.MethodGet, path, nil) + req.Header.Set("Authorization", "Bearer "+accessToken) + rec := httptest.NewRecorder() + s.GetEcho().ServeHTTP(rec, req) + + var answersResp struct { + Answers []struct { + ID string `json:"id"` + QuestionID string `json:"question_id"` + Question string `json:"question"` + } `json:"answers"` + } + err := json.NewDecoder(rec.Body).Decode(&answersResp) + assert.NoError(t, err) + + // Find answer ID for the core product question (which has min length validation) + var coreQuestionAnswerID string + var websiteQuestionAnswerID string + for _, a := range answersResp.Answers { + if strings.Contains(a.Question, "core product") { + coreQuestionAnswerID = a.ID + } + if strings.Contains(a.Question, "website") { + websiteQuestionAnswerID = a.ID + } + } + + // Ensure we found the questions we need + assert.NotEmpty(t, coreQuestionAnswerID, "Should find core product question") + assert.NotEmpty(t, websiteQuestionAnswerID, "Should find website question") + tests := []struct { name string method string @@ -249,6 +281,28 @@ func TestProjectEndpoints(t *testing.T) { expectedCode: http.StatusUnauthorized, expectedError: "missing authorization header", }, + { + name: "Invalid Answer Length", + method: http.MethodPatch, + path: fmt.Sprintf("/api/v1/project/%s/answers", projectID), + body: fmt.Sprintf(`{"content": "too short", "answer_id": "%s"}`, coreQuestionAnswerID), + setupAuth: func(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+accessToken) + }, + expectedCode: http.StatusBadRequest, + expectedError: "Must be at least", + }, + { + name: "Invalid URL Format", + method: http.MethodPatch, + path: fmt.Sprintf("/api/v1/project/%s/answers", projectID), + body: fmt.Sprintf(`{"content": "not-a-url", "answer_id": "%s"}`, websiteQuestionAnswerID), + setupAuth: func(req *http.Request) { + req.Header.Set("Authorization", "Bearer "+accessToken) + }, + expectedCode: http.StatusBadRequest, + expectedError: "Must be a valid URL", + }, } for _, tc := range tests { @@ -269,10 +323,21 @@ func TestProjectEndpoints(t *testing.T) { assert.Equal(t, tc.expectedCode, rec.Code) - var errResp v1_common.APIError + var errResp struct { + Message string `json:"message"` + ValidationErrors []struct { + Question string `json:"question"` + Message string `json:"message"` + } `json:"validation_errors"` + } err := json.NewDecoder(rec.Body).Decode(&errResp) assert.NoError(t, err) - assert.Contains(t, errResp.Message, tc.expectedError) + + if len(errResp.ValidationErrors) > 0 { + assert.Contains(t, errResp.ValidationErrors[0].Message, tc.expectedError) + } else { + assert.Contains(t, errResp.Message, tc.expectedError) + } }) } }) diff --git a/backend/internal/v1/v1_projects/projects.go b/backend/internal/v1/v1_projects/projects.go index d8cfd6fe..8ea2829c 100644 --- a/backend/internal/v1/v1_projects/projects.go +++ b/backend/internal/v1/v1_projects/projects.go @@ -197,6 +197,27 @@ func (h *Handler) handlePatchProjectAnswer(c echo.Context) error { return v1_common.Fail(c, 404, "Company not found", err) } + // Get the question for this answer to check validations + question, err := h.server.GetQueries().GetQuestionByAnswerID(c.Request().Context(), req.AnswerID) + if err != nil { + return v1_common.Fail(c, 404, "Question not found", err) + } + + // Validate the answer if validations exist + if question.Validations != nil && *question.Validations != "" { + if !isValidAnswer(req.Content, *question.Validations) { + return c.JSON(http.StatusBadRequest, map[string]interface{}{ + "message": "Validation failed", + "validation_errors": []ValidationError{ + { + Question: question.Question, + Message: getValidationMessage(*question.Validations), + }, + }, + }) + } + } + // Update the answer _, err = h.server.GetQueries().UpdateProjectAnswer(c.Request().Context(), db.UpdateProjectAnswerParams{ Answer: req.Content, diff --git a/backend/internal/v1/v1_projects/validation.go b/backend/internal/v1/v1_projects/validation.go index e3e80ec0..a0a532c9 100644 --- a/backend/internal/v1/v1_projects/validation.go +++ b/backend/internal/v1/v1_projects/validation.go @@ -3,18 +3,20 @@ package v1_projects import ( "net/url" "strings" + "strconv" + "regexp" ) type validationType struct { Name string - Validate func(string) bool + Validate func(string, string) bool // (answer, param) Message string } var validationTypes = []validationType{ { Name: "url", - Validate: func(answer string) bool { + Validate: func(answer string, _ string) bool { _, err := url.ParseRequestURI(answer) return err == nil }, @@ -22,15 +24,14 @@ var validationTypes = []validationType{ }, { Name: "email", - Validate: func(answer string) bool { + Validate: func(answer string, _ string) bool { return strings.Contains(answer, "@") && strings.Contains(answer, ".") }, Message: "Must be a valid email address", }, { Name: "phone", - Validate: func(answer string) bool { - // Simple check for now - frontend will do proper formatting lol + Validate: func(answer string, _ string) bool { cleaned := strings.Map(func(r rune) rune { if r >= '0' && r <= '9' { return r @@ -41,14 +42,57 @@ var validationTypes = []validationType{ }, Message: "Must be a valid phone number", }, + { + Name: "min", + Validate: func(answer string, param string) bool { + minLen, err := strconv.Atoi(param) + if err != nil { + return false + } + return len(answer) >= minLen + }, + Message: "Must be at least %s characters long", + }, + { + Name: "max", + Validate: func(answer string, param string) bool { + maxLen, err := strconv.Atoi(param) + if err != nil { + return false + } + return len(answer) <= maxLen + }, + Message: "Must be at most %s characters long", + }, + { + Name: "regex", + Validate: func(answer string, pattern string) bool { + re, err := regexp.Compile(pattern) + if err != nil { + return false + } + return re.MatchString(answer) + }, + Message: "Must match the required format", + }, +} + +func parseValidationRule(rule string) (name string, param string) { + parts := strings.SplitN(rule, "=", 2) + name = strings.TrimSpace(parts[0]) + if len(parts) > 1 { + param = strings.TrimSpace(parts[1]) + } + return } func isValidAnswer(answer string, validations string) bool { rules := strings.Split(validations, ",") for _, rule := range rules { + name, param := parseValidationRule(rule) for _, vType := range validationTypes { - if rule == vType.Name && !vType.Validate(answer) { + if name == vType.Name && !vType.Validate(answer, param) { return false } } @@ -61,8 +105,12 @@ func getValidationMessage(validations string) string { rules := strings.Split(validations, ",") for _, rule := range rules { + name, param := parseValidationRule(rule) for _, vType := range validationTypes { - if rule == vType.Name { + if name == vType.Name { + if strings.Contains(vType.Message, "%s") { + return strings.Replace(vType.Message, "%s", param, 1) + } return vType.Message } }