Skip to content

Commit

Permalink
Fix param validation support for min/max length and regex patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
AmirAgassi committed Dec 27, 2024
1 parent c1671d4 commit 338e434
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 13 deletions.
7 changes: 6 additions & 1 deletion backend/.sqlc/queries/projects.sql
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,9 @@ UPDATE projects
SET
status = $1,
updated_at = extract(epoch from now())
WHERE id = $2;
WHERE id = $2;

-- name: GetQuestionByAnswerID :one
SELECT q.* FROM project_questions q
JOIN project_answers a ON a.question_id = q.id
WHERE a.id = $1;
21 changes: 21 additions & 0 deletions backend/db/projects.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

75 changes: 70 additions & 5 deletions backend/internal/tests/projects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

"KonferCA/SPUR/db"
"KonferCA/SPUR/internal/server"
"KonferCA/SPUR/internal/v1/v1_common"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -171,9 +170,9 @@ func TestProjectEndpoints(t *testing.T) {
case "Company website":
answer = "https://example.com"
case "What is the core product or service, and what problem does it solve?":
answer = "Our product is a blockchain-based authentication system that solves identity verification issues."
answer = "Our product is a revolutionary blockchain-based authentication system that solves critical identity verification issues in the digital age. We provide a secure, scalable solution that eliminates fraud while maintaining user privacy and compliance with international regulations. Our system uses advanced cryptography and distributed ledger technology to ensure tamper-proof identity verification."
case "What is the unique value proposition?":
answer = "We provide secure, decentralized identity verification that's faster and more reliable than traditional methods."
answer = "We provide secure, decentralized identity verification that's faster and more reliable than traditional methods. Our solution reduces verification time by 90% while increasing security and reducing costs for businesses. We are the only solution that combines biometric verification with blockchain immutability at scale."
}

// Patch the answer
Expand Down Expand Up @@ -220,6 +219,39 @@ func TestProjectEndpoints(t *testing.T) {

// Error cases
t.Run("Error Cases", func(t *testing.T) {
// First get the questions/answers to get real IDs
path := fmt.Sprintf("/api/v1/project/%s/answers", projectID)
req := httptest.NewRequest(http.MethodGet, path, nil)
req.Header.Set("Authorization", "Bearer "+accessToken)
rec := httptest.NewRecorder()
s.GetEcho().ServeHTTP(rec, req)

var answersResp struct {
Answers []struct {
ID string `json:"id"`
QuestionID string `json:"question_id"`
Question string `json:"question"`
} `json:"answers"`
}
err := json.NewDecoder(rec.Body).Decode(&answersResp)
assert.NoError(t, err)

// Find answer ID for the core product question (which has min length validation)
var coreQuestionAnswerID string
var websiteQuestionAnswerID string
for _, a := range answersResp.Answers {
if strings.Contains(a.Question, "core product") {
coreQuestionAnswerID = a.ID
}
if strings.Contains(a.Question, "website") {
websiteQuestionAnswerID = a.ID
}
}

// Ensure we found the questions we need
assert.NotEmpty(t, coreQuestionAnswerID, "Should find core product question")
assert.NotEmpty(t, websiteQuestionAnswerID, "Should find website question")

tests := []struct {
name string
method string
Expand Down Expand Up @@ -249,6 +281,28 @@ func TestProjectEndpoints(t *testing.T) {
expectedCode: http.StatusUnauthorized,
expectedError: "missing authorization header",
},
{
name: "Invalid Answer Length",
method: http.MethodPatch,
path: fmt.Sprintf("/api/v1/project/%s/answers", projectID),
body: fmt.Sprintf(`{"content": "too short", "answer_id": "%s"}`, coreQuestionAnswerID),
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "Bearer "+accessToken)
},
expectedCode: http.StatusBadRequest,
expectedError: "Must be at least",
},
{
name: "Invalid URL Format",
method: http.MethodPatch,
path: fmt.Sprintf("/api/v1/project/%s/answers", projectID),
body: fmt.Sprintf(`{"content": "not-a-url", "answer_id": "%s"}`, websiteQuestionAnswerID),
setupAuth: func(req *http.Request) {
req.Header.Set("Authorization", "Bearer "+accessToken)
},
expectedCode: http.StatusBadRequest,
expectedError: "Must be a valid URL",
},
}

for _, tc := range tests {
Expand All @@ -269,10 +323,21 @@ func TestProjectEndpoints(t *testing.T) {

assert.Equal(t, tc.expectedCode, rec.Code)

var errResp v1_common.APIError
var errResp struct {
Message string `json:"message"`
ValidationErrors []struct {
Question string `json:"question"`
Message string `json:"message"`
} `json:"validation_errors"`
}
err := json.NewDecoder(rec.Body).Decode(&errResp)
assert.NoError(t, err)
assert.Contains(t, errResp.Message, tc.expectedError)

if len(errResp.ValidationErrors) > 0 {
assert.Contains(t, errResp.ValidationErrors[0].Message, tc.expectedError)
} else {
assert.Contains(t, errResp.Message, tc.expectedError)
}
})
}
})
Expand Down
21 changes: 21 additions & 0 deletions backend/internal/v1/v1_projects/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,27 @@ func (h *Handler) handlePatchProjectAnswer(c echo.Context) error {
return v1_common.Fail(c, 404, "Company not found", err)
}

// Get the question for this answer to check validations
question, err := h.server.GetQueries().GetQuestionByAnswerID(c.Request().Context(), req.AnswerID)
if err != nil {
return v1_common.Fail(c, 404, "Question not found", err)
}

// Validate the answer if validations exist
if question.Validations != nil && *question.Validations != "" {
if !isValidAnswer(req.Content, *question.Validations) {
return c.JSON(http.StatusBadRequest, map[string]interface{}{
"message": "Validation failed",
"validation_errors": []ValidationError{
{
Question: question.Question,
Message: getValidationMessage(*question.Validations),
},
},
})
}
}

// Update the answer
_, err = h.server.GetQueries().UpdateProjectAnswer(c.Request().Context(), db.UpdateProjectAnswerParams{
Answer: req.Content,
Expand Down
62 changes: 55 additions & 7 deletions backend/internal/v1/v1_projects/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,35 @@ package v1_projects
import (
"net/url"
"strings"
"strconv"
"regexp"
)

type validationType struct {
Name string
Validate func(string) bool
Validate func(string, string) bool // (answer, param)
Message string
}

var validationTypes = []validationType{
{
Name: "url",
Validate: func(answer string) bool {
Validate: func(answer string, _ string) bool {
_, err := url.ParseRequestURI(answer)
return err == nil
},
Message: "Must be a valid URL",
},
{
Name: "email",
Validate: func(answer string) bool {
Validate: func(answer string, _ string) bool {
return strings.Contains(answer, "@") && strings.Contains(answer, ".")
},
Message: "Must be a valid email address",
},
{
Name: "phone",
Validate: func(answer string) bool {
// Simple check for now - frontend will do proper formatting lol
Validate: func(answer string, _ string) bool {
cleaned := strings.Map(func(r rune) rune {
if r >= '0' && r <= '9' {
return r
Expand All @@ -41,14 +42,57 @@ var validationTypes = []validationType{
},
Message: "Must be a valid phone number",
},
{
Name: "min",
Validate: func(answer string, param string) bool {
minLen, err := strconv.Atoi(param)
if err != nil {
return false
}
return len(answer) >= minLen
},
Message: "Must be at least %s characters long",
},
{
Name: "max",
Validate: func(answer string, param string) bool {
maxLen, err := strconv.Atoi(param)
if err != nil {
return false
}
return len(answer) <= maxLen
},
Message: "Must be at most %s characters long",
},
{
Name: "regex",
Validate: func(answer string, pattern string) bool {
re, err := regexp.Compile(pattern)
if err != nil {
return false
}
return re.MatchString(answer)
},
Message: "Must match the required format",
},
}

func parseValidationRule(rule string) (name string, param string) {
parts := strings.SplitN(rule, "=", 2)
name = strings.TrimSpace(parts[0])
if len(parts) > 1 {
param = strings.TrimSpace(parts[1])
}
return
}

func isValidAnswer(answer string, validations string) bool {
rules := strings.Split(validations, ",")

for _, rule := range rules {
name, param := parseValidationRule(rule)
for _, vType := range validationTypes {
if rule == vType.Name && !vType.Validate(answer) {
if name == vType.Name && !vType.Validate(answer, param) {
return false
}
}
Expand All @@ -61,8 +105,12 @@ func getValidationMessage(validations string) string {
rules := strings.Split(validations, ",")

for _, rule := range rules {
name, param := parseValidationRule(rule)
for _, vType := range validationTypes {
if rule == vType.Name {
if name == vType.Name {
if strings.Contains(vType.Message, "%s") {
return strings.Replace(vType.Message, "%s", param, 1)
}
return vType.Message
}
}
Expand Down

0 comments on commit 338e434

Please sign in to comment.