diff --git a/.sqlc/migrations/20241108045258_create_users_table.sql b/.sqlc/migrations/20241108045258_create_users_table.sql index c56009c..beb085f 100644 --- a/.sqlc/migrations/20241108045258_create_users_table.sql +++ b/.sqlc/migrations/20241108045258_create_users_table.sql @@ -18,4 +18,4 @@ CREATE TABLE users ( -- +goose Down -- +goose StatementBegin DROP TABLE users; --- +goose StatementEnd \ No newline at end of file +-- +goose StatementEnd diff --git a/.sqlc/migrations/20241116191633_user_role_enum.sql b/.sqlc/migrations/20241116191633_user_role_enum.sql new file mode 100644 index 0000000..5bd3ddb --- /dev/null +++ b/.sqlc/migrations/20241116191633_user_role_enum.sql @@ -0,0 +1,17 @@ +-- +goose Up +-- +goose StatementBegin +BEGIN; + +CREATE TYPE user_role AS ENUM ('admin', 'startup_owner', 'investor'); + +COMMIT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +BEGIN; + +DROP TYPE user_role; + +COMMIT; +-- +goose StatementEnd diff --git a/.sqlc/migrations/20241116192218_alter_users_role_col.sql b/.sqlc/migrations/20241116192218_alter_users_role_col.sql new file mode 100644 index 0000000..7934b05 --- /dev/null +++ b/.sqlc/migrations/20241116192218_alter_users_role_col.sql @@ -0,0 +1,29 @@ +-- +goose Up +-- +goose StatementBegin +BEGIN; + +ALTER TABLE users ADD COLUMN role_enum user_role NOT NULL; + +UPDATE users +SET role_enum = role::user_role; + +ALTER TABLE users DROP COLUMN role; +ALTER TABLE users RENAME COLUMN role_enum TO role; + +COMMIT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +BEGIN; + +ALTER TABLE users ADD COLUMN role_varchar VARCHAR(50) NOT NULL; + +UPDATE users +SET role_varchar = role::text; + +ALTER TABLE users DROP COLUMN role; +ALTER TABLE users RENAME COLUMN role_varchar TO role; + +COMMIT; +-- +goose StatementEnd diff --git a/db/models.go b/db/models.go index 2d1cd0f..2f75579 100644 --- a/db/models.go +++ b/db/models.go @@ -5,11 +5,74 @@ package db import ( + "database/sql/driver" + "fmt" "time" "github.com/jackc/pgx/v5/pgtype" ) +type UserRole string + +const ( + UserRoleAdmin UserRole = "admin" + UserRoleStartupOwner UserRole = "startup_owner" + UserRoleInvestor UserRole = "investor" +) + +func (e *UserRole) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = UserRole(s) + case string: + *e = UserRole(s) + default: + return fmt.Errorf("unsupported scan type for UserRole: %T", src) + } + return nil +} + +type NullUserRole struct { + UserRole UserRole + Valid bool // Valid is true if UserRole is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullUserRole) Scan(value interface{}) error { + if value == nil { + ns.UserRole, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.UserRole.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullUserRole) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.UserRole), nil +} + +func (e UserRole) Valid() bool { + switch e { + case UserRoleAdmin, + UserRoleStartupOwner, + UserRoleInvestor: + return true + } + return false +} + +func AllUserRoleValues() []UserRole { + return []UserRole{ + UserRoleAdmin, + UserRoleStartupOwner, + UserRoleInvestor, + } +} + type Company struct { ID string OwnerUserID string @@ -167,8 +230,8 @@ type User struct { PasswordHash string FirstName *string LastName *string - Role string WalletAddress *string CreatedAt pgtype.Timestamp UpdatedAt pgtype.Timestamp + Role UserRole } diff --git a/db/users.sql.go b/db/users.sql.go index 7265ade..de3ab9d 100644 --- a/db/users.sql.go +++ b/db/users.sql.go @@ -18,7 +18,7 @@ INSERT INTO users ( role ) VALUES ( $1, $2, $3, $4, $5 -) RETURNING id, email, password_hash, first_name, last_name, role, wallet_address, created_at, updated_at +) RETURNING id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role ` type CreateUserParams struct { @@ -26,7 +26,7 @@ type CreateUserParams struct { PasswordHash string FirstName *string LastName *string - Role string + Role UserRole } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { @@ -44,16 +44,16 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e &i.PasswordHash, &i.FirstName, &i.LastName, - &i.Role, &i.WalletAddress, &i.CreatedAt, &i.UpdatedAt, + &i.Role, ) return i, err } const getUserByEmail = `-- name: GetUserByEmail :one -SELECT id, email, password_hash, first_name, last_name, role, wallet_address, created_at, updated_at FROM users +SELECT id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role FROM users WHERE email = $1 LIMIT 1 ` @@ -66,16 +66,16 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error &i.PasswordHash, &i.FirstName, &i.LastName, - &i.Role, &i.WalletAddress, &i.CreatedAt, &i.UpdatedAt, + &i.Role, ) return i, err } const getUserByID = `-- name: GetUserByID :one -SELECT id, email, password_hash, first_name, last_name, role, wallet_address, created_at, updated_at FROM users +SELECT id, email, password_hash, first_name, last_name, wallet_address, created_at, updated_at, role FROM users WHERE id = $1 LIMIT 1 ` @@ -88,10 +88,10 @@ func (q *Queries) GetUserByID(ctx context.Context, id string) (User, error) { &i.PasswordHash, &i.FirstName, &i.LastName, - &i.Role, &i.WalletAddress, &i.CreatedAt, &i.UpdatedAt, + &i.Role, ) return i, err } diff --git a/internal/jwt/generate.go b/internal/jwt/generate.go index a2258db..a620d5d 100644 --- a/internal/jwt/generate.go +++ b/internal/jwt/generate.go @@ -4,6 +4,7 @@ import ( "os" "time" + "github.com/KonferCA/NoKap/db" golangJWT "github.com/golang-jwt/jwt/v5" ) @@ -13,7 +14,7 @@ const ( ) // Generates JWT tokens for the given user. Returns the access token, refresh token and error (nil if no error) -func Generate(userID string, role string) (string, string, error) { +func Generate(userID string, role db.UserRole) (string, string, error) { accessToken, err := generateToken(userID, role, ACCESS_TOKEN_TYPE, time.Now().Add(10*time.Minute)) if err != nil { return "", "", err @@ -28,7 +29,7 @@ func Generate(userID string, role string) (string, string, error) { } // Private helper method to generate a token. -func generateToken(userID, role, tokenType string, exp time.Time) (string, error) { +func generateToken(userID string, role db.UserRole, tokenType string, exp time.Time) (string, error) { claims := JWTClaims{ UserID: userID, Role: role, diff --git a/internal/jwt/jwt_test.go b/internal/jwt/jwt_test.go index 0972f98..d39207e 100644 --- a/internal/jwt/jwt_test.go +++ b/internal/jwt/jwt_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/KonferCA/NoKap/db" golangJWT "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" ) @@ -14,7 +15,7 @@ func TestJWT(t *testing.T) { os.Setenv("JWT_SECRET", "secret") userID := "some-user-id" - role := "user" + role := db.UserRole("user") exp := time.Now().Add(5 * time.Minute) t.Run("generate access token", func(t *testing.T) { diff --git a/internal/jwt/types.go b/internal/jwt/types.go index 11ac0ea..a758825 100644 --- a/internal/jwt/types.go +++ b/internal/jwt/types.go @@ -1,10 +1,13 @@ package jwt -import golangJWT "github.com/golang-jwt/jwt/v5" +import ( + "github.com/KonferCA/NoKap/db" + golangJWT "github.com/golang-jwt/jwt/v5" +) type JWTClaims struct { - UserID string `json:"user_id"` - Role string `json:"role"` - TokenType string `json:"token_type"` + UserID string `json:"user_id"` + Role db.UserRole `json:"role"` + TokenType string `json:"token_type"` golangJWT.RegisteredClaims } diff --git a/internal/middleware/jwt_test.go b/internal/middleware/jwt_test.go index 4e1e70a..51437af 100644 --- a/internal/middleware/jwt_test.go +++ b/internal/middleware/jwt_test.go @@ -7,6 +7,7 @@ import ( "os" "testing" + "github.com/KonferCA/NoKap/db" "github.com/KonferCA/NoKap/internal/jwt" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" @@ -24,7 +25,7 @@ func TestProtectAPIForAccessToken(t *testing.T) { // generate valid tokens userID := "user-id" - role := "user-role" + role := db.UserRole("user-role") accessToken, refreshToken, err := jwt.Generate(userID, role) assert.Nil(t, err) @@ -103,7 +104,7 @@ func TestProtectAPIForRefreshToken(t *testing.T) { // generate valid tokens userID := "user-id" - role := "user-role" + role := db.UserRole("user-role") accessToken, refreshToken, err := jwt.Generate(userID, role) assert.Nil(t, err) diff --git a/internal/middleware/req_validator.go b/internal/middleware/req_validator.go index fe16cf2..bcd0b26 100644 --- a/internal/middleware/req_validator.go +++ b/internal/middleware/req_validator.go @@ -5,6 +5,7 @@ import ( "net/http" "reflect" + "github.com/KonferCA/NoKap/db" "github.com/go-playground/validator/v10" "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" @@ -29,7 +30,9 @@ func (rv *RequestBodyValidator) Validate(i interface{}) error { // Creates a new request validator that can be set to an Echo instance // and used for validating request bodies with c.Validate() func NewRequestBodyValidator() *RequestBodyValidator { - return &RequestBodyValidator{validator: validator.New()} + v := validator.New() + v.RegisterValidation("valid_user_role", validateUserRole) + return &RequestBodyValidator{validator: v} } // Middleware that validates the incoming request body with the given structType. @@ -57,3 +60,31 @@ func ValidateRequestBody(structType reflect.Type) echo.MiddlewareFunc { } } } + +// validateUserRole validates the "valid_user_role" tag using the +// the generated valid method from SQLc. +func validateUserRole(fl validator.FieldLevel) bool { + field := fl.Field() + + // handle string type + if field.Kind() == reflect.String { + str := field.String() + ur := db.UserRole(str) + return ur.Valid() + } + + // handle db.UserRole type + if field.Type() == reflect.TypeOf(db.UserRole("")) { + ur := field.Interface().(db.UserRole) + return ur.Valid() + + } + + // handle pointer to db.UserRole + if field.Type() == reflect.TypeOf((*db.UserRole)(nil)) && !field.IsNil() { + ur := field.Interface().(*db.UserRole) + return ur.Valid() + } + + return false +} diff --git a/internal/server/auth.go b/internal/server/auth.go index d59c2fa..f61e8ee 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -3,9 +3,11 @@ package server import ( "context" "net/http" + "reflect" "github.com/KonferCA/NoKap/db" "github.com/KonferCA/NoKap/internal/jwt" + mw "github.com/KonferCA/NoKap/internal/middleware" "github.com/jackc/pgx/v5/pgtype" "github.com/labstack/echo/v4" "golang.org/x/crypto/bcrypt" @@ -14,21 +16,19 @@ import ( func (s *Server) setupAuthRoutes() { auth := s.apiV1.Group("/auth") auth.Use(s.authLimiter.RateLimit()) // special rate limit for auth routes - auth.POST("/signup", s.handleSignup) - auth.POST("/signin", s.handleSignin) + auth.POST("/signup", s.handleSignup, mw.ValidateRequestBody(reflect.TypeOf(SignupRequest{}))) + auth.POST("/signin", s.handleSignin, mw.ValidateRequestBody(reflect.TypeOf(SigninRequest{}))) } func (s *Server) handleSignup(c echo.Context) error { - var req SignupRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid request body") + var req *SignupRequest + req, ok := c.Get(mw.REQUEST_BODY_KEY).(*SignupRequest) + if !ok { + // not good... no bueno + return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } - if err := c.Validate(&req); err != nil { - return err - } - - ctx := context.Background() + ctx := c.Request().Context() existingUser, err := s.queries.GetUserByEmail(ctx, req.Email) if err == nil && existingUser.ID != "" { return echo.NewHTTPError(http.StatusConflict, "email already registered") @@ -70,13 +70,12 @@ func (s *Server) handleSignup(c echo.Context) error { } func (s *Server) handleSignin(c echo.Context) error { - var req SigninRequest - if err := c.Bind(&req); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "invalid request body") - } - - if err := c.Validate(&req); err != nil { - return err + req, ok := c.Get(mw.REQUEST_BODY_KEY).(*SigninRequest) + if !ok { + // no bueno... + // should never really reach this state since the validator should reject + // the request body if it is not a proper SigninRequest type + return echo.NewHTTPError(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } ctx := context.Background() diff --git a/internal/server/types.go b/internal/server/types.go index f0c09a7..3cd983a 100644 --- a/internal/server/types.go +++ b/internal/server/types.go @@ -2,6 +2,8 @@ package server import ( "time" + + "github.com/KonferCA/NoKap/db" ) // TODO: Reorder types @@ -42,11 +44,11 @@ type CreateResourceRequestRequest struct { } type SignupRequest struct { - Email string `json:"email" validate:"required,email"` - Password string `json:"password" validate:"required,min=8"` - FirstName string `json:"first_name" validate:"required"` - LastName string `json:"last_name" validate:"required"` - Role string `json:"role" validate:"required,oneof=startup_owner admin investor"` + Email string `json:"email" validate:"required,email"` + Password string `json:"password" validate:"required,min=8"` + FirstName string `json:"first_name" validate:"required"` + LastName string `json:"last_name" validate:"required"` + Role db.UserRole `json:"role" validate:"required,valid_user_role"` } type SigninRequest struct { @@ -61,12 +63,12 @@ type AuthResponse struct { } type User struct { - ID string `json:"id"` - Email string `json:"email"` - FirstName string `json:"first_name"` - LastName string `json:"last_name"` - Role string `json:"role"` - WalletAddress *string `json:"wallet_address,omitempty"` + ID string `json:"id"` + Email string `json:"email"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Role db.UserRole `json:"role"` + WalletAddress *string `json:"wallet_address,omitempty"` } type CreateCompanyFinancialsRequest struct { diff --git a/sqlc.yml b/sqlc.yml index cd97b74..67cc367 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -9,6 +9,8 @@ sql: out: "db" sql_package: "pgx/v5" emit_pointers_for_null_types: true + emit_enum_valid_method: true + emit_all_enum_values: true overrides: - db_type: "uuid" go_type: "string"