From 5e335946f2969946b076dc22371d9711b9d9c572 Mon Sep 17 00:00:00 2001 From: Aidan Traboulay Date: Sat, 9 Nov 2024 22:31:11 -0400 Subject: [PATCH] move repetitive error checking to helper functions --- internal/server/company.go | 62 +++++---------- internal/server/handler_helpers.go | 41 ++++++++++ internal/server/resource_request.go | 112 ++++++++-------------------- 3 files changed, 91 insertions(+), 124 deletions(-) diff --git a/internal/server/company.go b/internal/server/company.go index cb65f62..fdc6440 100644 --- a/internal/server/company.go +++ b/internal/server/company.go @@ -2,7 +2,6 @@ package server import ( "context" - "fmt" "net/http" "github.com/KonferCA/NoKap/db" @@ -12,17 +11,13 @@ import ( func (s *Server) handleCreateCompany(c echo.Context) error { var req CreateCompanyRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid request body :(") + if err := validateBody(c, &req); err != nil { + return err } - if err := c.Validate(req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - var ownerUUID pgtype.UUID - if err := ownerUUID.Scan(req.OwnerUserID); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid owner ID format :(") + ownerUUID, err := validateUUID(req.OwnerUserID, "owner") + if err != nil { + return err } queries := db.New(s.DBPool) @@ -34,31 +29,22 @@ func (s *Server) handleCreateCompany(c echo.Context) error { company, err := queries.CreateCompany(context.Background(), params) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to create company: %v", err)) + return handleDBError(err, "create", "company") } return c.JSON(http.StatusCreated, company) } func (s *Server) handleGetCompany(c echo.Context) error { - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "Missing company ID :(") - } - - var companyID pgtype.UUID - if err := companyID.Scan(id); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid company ID format :(") + companyID, err := validateUUID(c.Param("id"), "company") + if err != nil { + return err } queries := db.New(s.DBPool) company, err := queries.GetCompanyByID(context.Background(), companyID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Company not found :(") - } - - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch company :(") + return handleDBError(err, "fetch", "company") } return c.JSON(http.StatusOK, company) @@ -66,43 +52,29 @@ func (s *Server) handleGetCompany(c echo.Context) error { func (s *Server) handleListCompanies(c echo.Context) error { queries := db.New(s.DBPool) - companies, err := queries.ListCompanies(context.Background()) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch companies :(") + return handleDBError(err, "fetch", "companies") } return c.JSON(http.StatusOK, companies) } func (s *Server) handleDeleteCompany(c echo.Context) error { - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "Missing company ID :(") - } - - var companyID pgtype.UUID - if err := companyID.Scan(id); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid company ID format :(") + companyID, err := validateUUID(c.Param("id"), "company") + if err != nil { + return err } queries := db.New(s.DBPool) - _, err := queries.GetCompanyByID(context.Background(), companyID) + _, err = queries.GetCompanyByID(context.Background(), companyID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Company not found :(") - } - - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to verify company :(") + return handleDBError(err, "verify", "company") } err = queries.DeleteCompany(context.Background(), companyID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Company not found :(") - } - - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete company :(") + return handleDBError(err, "delete", "company") } return c.NoContent(http.StatusNoContent) diff --git a/internal/server/handler_helpers.go b/internal/server/handler_helpers.go index dd12c85..066e69e 100644 --- a/internal/server/handler_helpers.go +++ b/internal/server/handler_helpers.go @@ -1,5 +1,46 @@ package server +import ( + "fmt" + "net/http" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/labstack/echo/v4" +) + +func validateBody(c echo.Context, requestBodyType interface{}) error { + if err := c.Bind(requestBodyType); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Invalid request body :(") + } + + if err := c.Validate(requestBodyType); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + + return nil +} + +func validateUUID(id string, fieldName string) (pgtype.UUID, error) { + var uuid pgtype.UUID + if id == "" { + return uuid, echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Missing %s ID :(", fieldName)) + } + + if err := uuid.Scan(id); err != nil { + return uuid, echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid %s ID format :(", fieldName)) + } + + return uuid, nil +} + +func handleDBError(err error, operation string, resourceType string) error { + if isNoRowsError(err) { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("%s not found :(", resourceType)) + } + + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to %s %s :(", operation, resourceType)) +} + func isNoRowsError(err error) bool { return err != nil && err.Error() == "no rows in dis set" } diff --git a/internal/server/resource_request.go b/internal/server/resource_request.go index 2efd171..859b7d6 100644 --- a/internal/server/resource_request.go +++ b/internal/server/resource_request.go @@ -2,7 +2,6 @@ package server import ( "context" - "fmt" "net/http" "github.com/KonferCA/NoKap/db" @@ -12,27 +11,19 @@ import ( func (s *Server) handleCreateResourceRequest(c echo.Context) error { var req CreateResourceRequestRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid request body :(") + if err := validateBody(c, &req); err != nil { + return err } - if err := c.Validate(req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) - } - - var companyID pgtype.UUID - if err := companyID.Scan(req.CompanyID); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid company ID format :(") + companyID, err := validateUUID(req.CompanyID, "company") + if err != nil { + return err } queries := db.New(s.DBPool) - _, err := queries.GetCompanyByID(context.Background(), companyID) + _, err = queries.GetCompanyByID(context.Background(), companyID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Company not found :(") - } - - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to verify company :(") + return handleDBError(err, "verify", "company") } params := db.CreateResourceRequestParams{ @@ -44,30 +35,22 @@ func (s *Server) handleCreateResourceRequest(c echo.Context) error { request, err := queries.CreateResourceRequest(context.Background(), params) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to create resource request: %v", err)) + return handleDBError(err, "create", "resource request") } return c.JSON(http.StatusCreated, request) } func (s *Server) handleGetResourceRequest(c echo.Context) error { - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "Missing resource request ID :(") - } - - var requestID pgtype.UUID - if err := requestID.Scan(id); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid resource request ID format :(") + requestID, err := validateUUID(c.Param("id"), "resource request") + if err != nil { + return err } queries := db.New(s.DBPool) request, err := queries.GetResourceRequestByID(context.Background(), requestID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Resource request not found :(") - } - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource request :(") + return handleDBError(err, "fetch", "resource request") } return c.JSON(http.StatusOK, request) @@ -78,23 +61,19 @@ func (s *Server) handleListResourceRequests(c echo.Context) error { queries := db.New(s.DBPool) if companyID != "" { - var companyUUID pgtype.UUID - if err := companyUUID.Scan(companyID); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid company ID format :(") + companyUUID, err := validateUUID(companyID, "company") + if err != nil { + return err } - _, err := queries.GetCompanyByID(context.Background(), companyUUID) + _, err = queries.GetCompanyByID(context.Background(), companyUUID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Company not found :(") - } - - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to verify company :(") + return handleDBError(err, "verify", "company") } requests, err := queries.ListResourceRequestsByCompany(context.Background(), companyUUID) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource requests :(") + return handleDBError(err, "fetch", "resource requests") } return c.JSON(http.StatusOK, requests) @@ -102,41 +81,29 @@ func (s *Server) handleListResourceRequests(c echo.Context) error { requests, err := queries.ListResourceRequests(context.Background()) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch resource requests :(") + return handleDBError(err, "fetch", "resource requests") } return c.JSON(http.StatusOK, requests) } func (s *Server) handleUpdateResourceRequestStatus(c echo.Context) error { - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "Missing resource request ID :(") - } - - var requestID pgtype.UUID - if err := requestID.Scan(id); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid resource request ID format :(") + requestID, err := validateUUID(c.Param("id"), "resource request") + if err != nil { + return err } var status struct { Status string `json:"status" validate:"required"` } - if err := c.Bind(&status); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid request body :(") - } - if err := c.Validate(status); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + if err := validateBody(c, &status); err != nil { + return err } queries := db.New(s.DBPool) - _, err := queries.GetResourceRequestByID(context.Background(), requestID) + _, err = queries.GetResourceRequestByID(context.Background(), requestID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Resource request not found :(") - } - - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to verify resource request :(") + return handleDBError(err, "verify", "resource request") } request, err := queries.UpdateResourceRequestStatus(context.Background(), db.UpdateResourceRequestStatusParams{ @@ -144,40 +111,27 @@ func (s *Server) handleUpdateResourceRequestStatus(c echo.Context) error { Status: status.Status, }) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to update resource request status :(") + return handleDBError(err, "update", "resource request status") } return c.JSON(http.StatusOK, request) } func (s *Server) handleDeleteResourceRequest(c echo.Context) error { - id := c.Param("id") - if id == "" { - return echo.NewHTTPError(http.StatusBadRequest, "Missing resource request ID :(") - } - - var requestID pgtype.UUID - if err := requestID.Scan(id); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid resource request ID format :(") + requestID, err := validateUUID(c.Param("id"), "resource request") + if err != nil { + return err } queries := db.New(s.DBPool) - _, err := queries.GetResourceRequestByID(context.Background(), requestID) + _, err = queries.GetResourceRequestByID(context.Background(), requestID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Resource request not found :(") - } - - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to verify resource request :(") + return handleDBError(err, "verify", "resource request") } err = queries.DeleteResourceRequest(context.Background(), requestID) if err != nil { - if isNoRowsError(err) { - return echo.NewHTTPError(http.StatusNotFound, "Resource request not found :(") - } - - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete resource request :(") + handleDBError(err, "delete", "resource request") } return c.NoContent(http.StatusNoContent)