From 2787a4d3a8675df81be94c79ce7ef3798e9b3a16 Mon Sep 17 00:00:00 2001 From: Artur Kondas Date: Mon, 14 Aug 2023 21:02:11 +0200 Subject: [PATCH] feat: add demo5, with bad functions --- demo5/data/note_service.go | 48 +++++++++ demo5/data/note_service_test.go | 66 ++++++++++++ demo5/data/user_service.go | 62 +++++++++++ demo5/data/user_service_test.go | 84 +++++++++++++++ demo5/db/db.go | 99 +++++++++++++++++ demo5/handlers/auth.go | 53 ++++++++++ demo5/handlers/login.go | 27 +++++ demo5/handlers/login_test.go | 42 ++++++++ demo5/handlers/note.go | 66 ++++++++++++ demo5/handlers/note_test.go | 87 +++++++++++++++ demo5/handlers/restricted.go | 21 ++++ demo5/handlers/restricted_test.go | 56 ++++++++++ demo5/handlers/sign_up.go | 29 +++++ demo5/handlers/sign_up_test.go | 57 ++++++++++ demo5/mocks/db.go | 38 +++++++ demo5/server.go | 169 ++++++++++++++++++++++++++++++ demo5/server_test.go | 44 ++++++++ 17 files changed, 1048 insertions(+) create mode 100644 demo5/data/note_service.go create mode 100644 demo5/data/note_service_test.go create mode 100644 demo5/data/user_service.go create mode 100644 demo5/data/user_service_test.go create mode 100644 demo5/db/db.go create mode 100644 demo5/handlers/auth.go create mode 100644 demo5/handlers/login.go create mode 100644 demo5/handlers/login_test.go create mode 100644 demo5/handlers/note.go create mode 100644 demo5/handlers/note_test.go create mode 100644 demo5/handlers/restricted.go create mode 100644 demo5/handlers/restricted_test.go create mode 100644 demo5/handlers/sign_up.go create mode 100644 demo5/handlers/sign_up_test.go create mode 100644 demo5/mocks/db.go create mode 100644 demo5/server.go create mode 100644 demo5/server_test.go diff --git a/demo5/data/note_service.go b/demo5/data/note_service.go new file mode 100644 index 0000000..d06ecb4 --- /dev/null +++ b/demo5/data/note_service.go @@ -0,0 +1,48 @@ +package data + +import ( + "github.com/addetz/secure-code-go/demo4/db" + "github.com/google/uuid" +) + +type SecretNote struct { + ID string `json:"id"` + Username string `json:"username"` + Text string `json:"text"` +} + +// SecretNoteService maintains the user notes. +type SecretNoteService struct { + dbService db.DatabaseService +} + +// NewSecretNoteService creates a SecretNoteService that is ready to use. +func NewSecretNoteService(dbService db.DatabaseService) *SecretNoteService { + return &SecretNoteService{ + dbService: dbService, + } +} + +// Add adds a new SecretNote for the given user by using the SecretNoteService. +func (ns *SecretNoteService) Add(user string, n SecretNote) error { + id := uuid.New().String() + return ns.dbService.AddNote(id, user, n.Text) +} + +// Get returns all the SecretNotes of a given user by using the SecretNoteService. +func (ns *SecretNoteService) GetAll(user string) ([]SecretNote, error) { + dbNotes, err := ns.dbService.GetUserNotes(user) + if err != nil { + return nil, err + } + var notes []SecretNote + for _, n := range dbNotes { + notes = append(notes, SecretNote{ + ID: n.ID, + Username: n.Username, + Text: n.Text, + }) + } + + return notes, nil +} diff --git a/demo5/data/note_service_test.go b/demo5/data/note_service_test.go new file mode 100644 index 0000000..43d31eb --- /dev/null +++ b/demo5/data/note_service_test.go @@ -0,0 +1,66 @@ +package data_test + +import ( + "errors" + "testing" + + "github.com/addetz/secure-code-go/demo4/data" + "github.com/addetz/secure-code-go/demo4/db" + "github.com/addetz/secure-code-go/demo4/mocks" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestAddNote(t *testing.T) { + mockDB := new(mocks.DatabaseServiceMock) + notes := data.NewSecretNoteService(mockDB) + user := "user1" + note := data.SecretNote{ + Text: "My Secret Note", + } + mockDB.On("AddNote", mock.AnythingOfType("string"), user, note.Text).Return(nil) + err := notes.Add(user, note) + assert.Nil(t, err) +} + +func TestGetAllNotes(t *testing.T) { + t.Run("no notes found", func(t *testing.T) { + user := "user1" + mockDB := new(mocks.DatabaseServiceMock) + noteService := data.NewSecretNoteService(mockDB) + mockDB.On("GetUserNotes", user).Return(nil, errors.New("no notes found")) + notes, err := noteService.GetAll(user) + assert.Nil(t, notes) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "no notes found") + }) + t.Run("notes found", func(t *testing.T) { + user := "user1" + dbNotes := []db.Note{ + { + ID: uuid.New().String(), + Username: user, + Text: "My first note", + }, + { + ID: uuid.New().String(), + Username: user, + Text: "My second note", + }, + } + mockDB := new(mocks.DatabaseServiceMock) + noteService := data.NewSecretNoteService(mockDB) + mockDB.On("GetUserNotes", user).Return(dbNotes, nil) + notes, err := noteService.GetAll(user) + assert.Nil(t, err) + assert.NotNil(t, notes) + assert.Equal(t, len(dbNotes), len(notes)) + for i, n := range dbNotes { + assert.Equal(t, n.ID, notes[i].ID) + assert.Equal(t, n.Username, notes[i].Username) + assert.Equal(t, n.Text, notes[i].Text) + + } + }) +} diff --git a/demo5/data/user_service.go b/demo5/data/user_service.go new file mode 100644 index 0000000..ec23f4c --- /dev/null +++ b/demo5/data/user_service.go @@ -0,0 +1,62 @@ +package data + +import ( + "github.com/addetz/secure-code-go/demo4/db" + "github.com/pkg/errors" + + passwordvalidator "github.com/wagslane/go-password-validator" + "golang.org/x/crypto/bcrypt" +) + +const minEntropyBits = 60 + +type User struct { + Username, Password string +} + +// UserService holds +type UserService struct { + dbService db.DatabaseService +} + +// NewUserService creates a ready to use user service. +func NewUserService(dbService db.DatabaseService) *UserService { + return &UserService{ + dbService: dbService, + } +} + +// Add validates a user password and creates a new user. +func (us *UserService) Add(name, password string) error { + if _, err := us.dbService.GetUser(name); err == nil { + return errors.New("user exists already, please log in instead") + } + + err := passwordvalidator.Validate(password, minEntropyBits) + if err != nil { + return errors.Wrap(err, "validate new user password") + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return err + } + return us.dbService.AddUser(name, string(hashedPassword)) +} + +// ValidatePassword checks the provided password of an existing user. +func (us *UserService) ValidatePassword(name, providedPwd string) error { + user, err := us.dbService.GetUser(name) + if err != nil { + return errors.Wrap(err, "user does not exist") + } + return bcrypt.CompareHashAndPassword([]byte(user.Pwd), []byte(providedPwd)) +} + +// ValidateUser checks the provided username belongs to an existing user. +func (us *UserService) ValidateUser(name string) error { + if _, err := us.dbService.GetUser(name); err != nil { + return errors.New("user not found") + } + return nil +} diff --git a/demo5/data/user_service_test.go b/demo5/data/user_service_test.go new file mode 100644 index 0000000..4cbd41a --- /dev/null +++ b/demo5/data/user_service_test.go @@ -0,0 +1,84 @@ +package data_test + +import ( + "errors" + "testing" + + "github.com/addetz/secure-code-go/demo4/data" + "github.com/addetz/secure-code-go/demo4/db" + "github.com/addetz/secure-code-go/demo4/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/bcrypt" +) + +func TestAdd(t *testing.T) { + t.Run("insufficient password", func(t *testing.T) { + name := "user1" + mockDB := new(mocks.DatabaseServiceMock) + us := data.NewUserService(mockDB) + mockDB.On("GetUser", name).Return(nil, errors.New("no user exists")) + err := us.Add(name, "test") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "validate new user password: insecure password") + }) + t.Run("successful add", func(t *testing.T) { + name := "user1" + password := "test-horse-pen-clam" + mockDB := new(mocks.DatabaseServiceMock) + us := data.NewUserService(mockDB) + mockDB.On("GetUser", name).Return(nil, errors.New("no user exists")) + mockDB.On("AddUser", name, mock.AnythingOfType("string")).Return(nil) + err := us.Add(name, password) + assert.Nil(t, err) + }) + t.Run("duplicate user", func(t *testing.T) { + name := "user1" + password := "test-horse-pen-clam" + mockDB := new(mocks.DatabaseServiceMock) + us := data.NewUserService(mockDB) + mockDB.On("GetUser", name).Return(nil, nil) + err := us.Add(name, password) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "user exists already") + }) +} + +func TestValidate(t *testing.T) { + t.Run("successful validate", func(t *testing.T) { + name := "user1" + password := "test-horse-pen-clam" + expected, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + mockDB := new(mocks.DatabaseServiceMock) + us := data.NewUserService(mockDB) + mockDB.On("GetUser", name).Return(&db.User{ + Username: name, + Pwd: string(expected), + }, nil) + err := us.ValidatePassword(name, password) + assert.Nil(t, err) + }) + t.Run("failed validate", func(t *testing.T) { + name := "user1" + password := "test-horse-pen-clam" + expected, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + mockDB := new(mocks.DatabaseServiceMock) + mockDB.On("GetUser", name).Return(&db.User{ + Username: name, + Pwd: string(expected), + }, nil) + us := data.NewUserService(mockDB) + err := us.ValidatePassword(name, "garbage-password") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "hashedPassword is not the hash of the given password") + }) + t.Run("inexistent user", func(t *testing.T) { + name := "user1" + mockDB := new(mocks.DatabaseServiceMock) + us := data.NewUserService(mockDB) + mockDB.On("GetUser", name).Return(nil, errors.New("no user exists")) + err := us.ValidatePassword(name, "garbage-password") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "user does not exist") + }) +} diff --git a/demo5/db/db.go b/demo5/db/db.go new file mode 100644 index 0000000..803ed5b --- /dev/null +++ b/demo5/db/db.go @@ -0,0 +1,99 @@ +package db + +import ( + "database/sql" + "log" +) + +type User struct { + Username, Pwd string +} + +type Note struct { + ID, Username, Text string +} + +type dbService struct { + db *sql.DB +} + +type DatabaseService interface { + AddUser(username, pwd string) error + GetUser(username string) (*User, error) + AddNote(id, username, text string) error + GetUserNotes(username string) ([]Note, error) +} + +// NewDatabaseService initialises a DatabaseService given its dependencies. +func NewDatabaseService(db *sql.DB) *dbService { + return &dbService{ + db: db, + } +} + +// AddUser creates a new user in the DB +func (ds *dbService) AddUser(username, pwd string) error { + stmt, err := ds.db.Prepare("INSERT INTO users (username, pwd) VALUES( $1, $2 )") + if err != nil { + log.Println("error1", err) + return err + } + defer stmt.Close() + if _, err := stmt.Exec(username, pwd); err != nil { + log.Println("error2", err) + return err + } + return nil +} + +// GetUser returns a user from the database or an error if none exists. +func (ds *dbService) GetUser(username string) (*User, error) { + var user User + stmt, err := ds.db.Prepare("SELECT * FROM users WHERE username = $1 ") + if err != nil { + log.Println("error3", err) + return nil, err + } + defer stmt.Close() + if err := stmt.QueryRow(username).Scan(&user.Username, &user.Pwd); err != nil { + log.Println("error4", err) + return nil, err + } + return &user, nil +} + +// AddNote creates a new note in the DB +func (ds *dbService) AddNote(id, username, text string) error { + stmt, err := ds.db.Prepare("INSERT INTO notes(id, username, noteText) VALUES($1, $2, $3)") + if err != nil { + return err + } + defer stmt.Close() + if _, err := stmt.Exec(id, username, text); err != nil { + return err + } + return nil +} + +// GetUserNotes returns all the notes of a given user from the database or an error. +func (ds *dbService) GetUserNotes(username string) ([]Note, error) { + var notes []Note + stmt, err := ds.db.Prepare("SELECT * FROM notes WHERE username = $1") + if err != nil { + return nil, err + } + defer stmt.Close() + rows, err := stmt.Query(username) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + n := Note{} + if err := rows.Scan(&n.ID, &n.Username, &n.Text); err != nil { + return nil, err + } + notes = append(notes, n) + } + return notes, nil +} diff --git a/demo5/handlers/auth.go b/demo5/handlers/auth.go new file mode 100644 index 0000000..0518a82 --- /dev/null +++ b/demo5/handlers/auth.go @@ -0,0 +1,53 @@ +package handlers + +import ( + "time" + + "github.com/addetz/secure-code-go/demo4/data" + "github.com/addetz/secure-code-go/demo4/db" + "github.com/golang-jwt/jwt/v5" +) + +// UserRequest represents a login or sign up request +type UserRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// JWTCustomClaims are custom claims extending default ones. +type JWTCustomClaims struct { + Username string `json:"username"` + jwt.RegisteredClaims +} + +type UserAuthService struct { + userService *data.UserService + secretNotesService *data.SecretNoteService + secret string +} + +func NewUserAuthService(secret string, dbService db.DatabaseService) *UserAuthService { + us := data.NewUserService(dbService) + ns := data.NewSecretNoteService(dbService) + return &UserAuthService{ + userService: us, + secretNotesService: ns, + secret: secret, + } +} + +func (us *UserAuthService) EncodeToken(username string) (string, error) { + // Set custom claims + claims := &JWTCustomClaims{ + username, + jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 1)), + }, + } + + // Create token with claims + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + // Generate encoded token and send it as response. + return token.SignedString([]byte(us.secret)) +} diff --git a/demo5/handlers/login.go b/demo5/handlers/login.go new file mode 100644 index 0000000..f9286a1 --- /dev/null +++ b/demo5/handlers/login.go @@ -0,0 +1,27 @@ +package handlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/pkg/errors" +) + +func (authService *UserAuthService) Login(c echo.Context) error { + u := new(UserRequest) + if err := c.Bind(u); err != nil { + return c.String(http.StatusBadRequest, "bad request") + } + if err := authService.userService.ValidatePassword(u.Username, u.Password); err != nil { + return errors.Wrap(err, "login") + } + + t, err := authService.EncodeToken(u.Username) + if err != nil { + return errors.Wrap(err, "login") + } + + return c.JSON(http.StatusOK, echo.Map{ + "token": t, + }) +} diff --git a/demo5/handlers/login_test.go b/demo5/handlers/login_test.go new file mode 100644 index 0000000..0a3a386 --- /dev/null +++ b/demo5/handlers/login_test.go @@ -0,0 +1,42 @@ +package handlers_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/addetz/secure-code-go/demo4/db" + "github.com/addetz/secure-code-go/demo4/handlers" + "github.com/addetz/secure-code-go/demo4/mocks" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" +) + +func TestLogin(t *testing.T) { + password := "potato-cheese-entropy-romania" + username := "user1" + successfulUser := fmt.Sprintf(`{"username":"%s","password":"%s"}`,username, password) + expected, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + + // Setup + e := echo.New() + mockDB := new(mocks.DatabaseServiceMock) + userAuthService := handlers.NewUserAuthService("testing-signing-key", mockDB) + mockDB.On("GetUser", username).Return(&db.User{ + Username: username, + Pwd: string(expected), + }, nil) + + // Login + reqLogin := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(successfulUser)) + reqLogin.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + recLogin := httptest.NewRecorder() + cLogin := e.NewContext(reqLogin, recLogin) + err := userAuthService.Login(cLogin) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, recLogin.Code) + assert.Contains(t, recLogin.Body.String(), "token") +} diff --git a/demo5/handlers/note.go b/demo5/handlers/note.go new file mode 100644 index 0000000..78f209b --- /dev/null +++ b/demo5/handlers/note.go @@ -0,0 +1,66 @@ +package handlers + +import ( + "errors" + "net/http" + + "github.com/addetz/secure-code-go/demo4/data" + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v4" +) + +// GetUserNotes returns all the notes of a given user. +func (authService *UserAuthService) GetUserNotes(c echo.Context) error { + user := c.Get("user").(*jwt.Token) + claims := user.Claims.(*JWTCustomClaims) + name := claims.Username + if err := authService.userService.ValidateUser(name); err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, err) + } + paramName := c.Param("id") + if name != paramName { + return echo.NewHTTPError(http.StatusUnauthorized, errors.New("not logged in as notes owner")) + } + secretNotes, err := authService.secretNotesService.GetAll(paramName) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, err) + } + + return c.JSON(http.StatusOK, echo.Map{ + "username": name, + "notes": secretNotes, + }) +} + +// AddUserNote adds a note belonging to the given user +func (authService *UserAuthService) AddUserNote(c echo.Context) error { + user := c.Get("user").(*jwt.Token) + claims := user.Claims.(*JWTCustomClaims) + name := claims.Username + if err := authService.userService.ValidateUser(name); err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, err) + } + paramName := c.Param("id") + if name != paramName { + return echo.NewHTTPError(http.StatusUnauthorized, errors.New("not logged in as notes owner")) + } + + newNote := new(data.SecretNote) + if err := c.Bind(newNote); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err) + } + + //add the note + if err := authService.secretNotesService.Add(paramName, *newNote); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err) + } + secretNotes, err := authService.secretNotesService.GetAll(paramName) + if err != nil { + return echo.NewHTTPError(http.StatusNotFound, err) + } + + return c.JSON(http.StatusCreated, echo.Map{ + "username": name, + "notes": secretNotes, + }) +} diff --git a/demo5/handlers/note_test.go b/demo5/handlers/note_test.go new file mode 100644 index 0000000..0712d25 --- /dev/null +++ b/demo5/handlers/note_test.go @@ -0,0 +1,87 @@ +package handlers_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/addetz/secure-code-go/demo4/db" + "github.com/addetz/secure-code-go/demo4/handlers" + "github.com/addetz/secure-code-go/demo4/mocks" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + echojwt "github.com/labstack/echo-jwt/v4" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNotes(t *testing.T) { + username := "user1" + noteText := "my super duper secret" + mockDB := new(mocks.DatabaseServiceMock) + userAuthService := handlers.NewUserAuthService("testing-signing-key", mockDB) + mockDB.On("GetUser", username).Return(nil, nil) + token, err := userAuthService.EncodeToken(username) + assert.Nil(t, err) + mockDB.On("AddNote", mock.AnythingOfType("string"), username, noteText).Return(nil) + mockDB.On("GetUserNotes", username).Return([]db.Note{ + { + ID: uuid.New().String(), + Username: username, + Text: noteText, + }, + }, nil) + + // set up restricted path middleware + e := echo.New() + e.POST("/secretNotes/:id", func(c echo.Context) error { + return userAuthService.AddUserNote(c) + }) + e.GET("/secretNotes/:id", func(c echo.Context) error { + return userAuthService.GetUserNotes(c) + }) + + e.Use(echojwt.WithConfig(echojwt.Config{ + NewClaimsFunc: func(c echo.Context) jwt.Claims { + return new(handlers.JWTCustomClaims) + }, + SigningKey: []byte("testing-signing-key"), + })) + + t.Run("successful add note", func(t *testing.T) { + newNote := `{"text":"my super duper secret"}` + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/secretNotes/%s", username), strings.NewReader(newNote)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", token)) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusCreated, res.Code) + assert.Contains(t, res.Body.String(), username) + assert.Contains(t, res.Body.String(), noteText) + }) + + t.Run("successful get notes", func(t *testing.T) { + reqGet := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/secretNotes/%s", username), nil) + reqGet.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + reqGet.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", token)) + resGet := httptest.NewRecorder() + e.ServeHTTP(resGet, reqGet) + + assert.Equal(t, http.StatusOK, resGet.Code) + assert.Contains(t, resGet.Body.String(), username) + assert.Contains(t, resGet.Body.String(), noteText) + }) + + t.Run("no token", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/secretNotes/%s", username), nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusUnauthorized, res.Code) + assert.Contains(t, res.Body.String(), "missing or malformed jwt") + }) +} diff --git a/demo5/handlers/restricted.go b/demo5/handlers/restricted.go new file mode 100644 index 0000000..0899d9d --- /dev/null +++ b/demo5/handlers/restricted.go @@ -0,0 +1,21 @@ +package handlers + +import ( + "fmt" + "net/http" + + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v4" +) + +func (authService *UserAuthService) RestrictedPath(c echo.Context) error { + user := c.Get("user").(*jwt.Token) + claims := user.Claims.(*JWTCustomClaims) + name := claims.Username + if err := authService.userService.ValidateUser(name); err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, err) + } + return c.JSON(http.StatusOK, echo.Map{ + "message": fmt.Sprintf("You're logged in %s!", name), + }) +} diff --git a/demo5/handlers/restricted_test.go b/demo5/handlers/restricted_test.go new file mode 100644 index 0000000..7cee634 --- /dev/null +++ b/demo5/handlers/restricted_test.go @@ -0,0 +1,56 @@ +package handlers_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/addetz/secure-code-go/demo4/handlers" + "github.com/addetz/secure-code-go/demo4/mocks" + "github.com/golang-jwt/jwt/v5" + echojwt "github.com/labstack/echo-jwt/v4" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestRestricted(t *testing.T) { + user := "user1" + mockDB := new(mocks.DatabaseServiceMock) + userAuthService := handlers.NewUserAuthService("testing-signing-key", mockDB) + mockDB.On("GetUser", user).Return(nil, nil) + token, err := userAuthService.EncodeToken(user) + assert.Nil(t, err) + + // set up restricted path middleware + e := echo.New() + e.GET("/restricted", func(c echo.Context) error { + return userAuthService.RestrictedPath(c) + }) + + e.Use(echojwt.WithConfig(echojwt.Config{ + NewClaimsFunc: func(c echo.Context) jwt.Claims { + return new(handlers.JWTCustomClaims) + }, + SigningKey: []byte("testing-signing-key"), + })) + + t.Run("successful restricted", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/restricted", nil) + req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", token)) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) + assert.Contains(t, res.Body.String(), "You're logged in") + }) + + t.Run("no token", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/restricted", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusUnauthorized, res.Code) + assert.Contains(t, res.Body.String(), "missing or malformed jwt") + }) +} diff --git a/demo5/handlers/sign_up.go b/demo5/handlers/sign_up.go new file mode 100644 index 0000000..c29cb2e --- /dev/null +++ b/demo5/handlers/sign_up.go @@ -0,0 +1,29 @@ +package handlers + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/pkg/errors" +) + +func (authService *UserAuthService) SignUp(c echo.Context) error { + // Read user request + u := new(UserRequest) + if err := c.Bind(u); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err) + } + // Send user data to the user service + if err := authService.userService.Add(u.Username, u.Password); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, errors.Wrap(err, "sign up")) + } + + t, err := authService.EncodeToken(u.Username) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err) + } + + return c.JSON(http.StatusCreated, echo.Map{ + "token": t, + }) +} diff --git a/demo5/handlers/sign_up_test.go b/demo5/handlers/sign_up_test.go new file mode 100644 index 0000000..3856cf0 --- /dev/null +++ b/demo5/handlers/sign_up_test.go @@ -0,0 +1,57 @@ +package handlers_test + +import ( + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/addetz/secure-code-go/demo4/handlers" + "github.com/addetz/secure-code-go/demo4/mocks" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestSignUp(t *testing.T) { + password := "potato-cheese-entropy-romania" + username := "user1" + successfulUser := fmt.Sprintf(`{"username":"%s","password":"%s"}`, username, password) + + t.Run("successful sign up", func(t *testing.T) { + // Setup + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(successfulUser)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mockDB := new(mocks.DatabaseServiceMock) + userAuthService := handlers.NewUserAuthService("testing-signing-key", mockDB) + mockDB.On("GetUser", username).Return(nil, errors.New("no user")) + mockDB.On("AddUser", username, mock.AnythingOfType("string")).Return(nil) + + // Assertions + err := userAuthService.SignUp(c) + assert.Nil(t, err) + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Contains(t, rec.Body.String(), "token") + }) + + t.Run("repeated sign up", func(t *testing.T) { + // Setup + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/signup", strings.NewReader(successfulUser)) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mockDB := new(mocks.DatabaseServiceMock) + userAuthService := handlers.NewUserAuthService("testing-signing-key", mockDB) + mockDB.On("GetUser", username).Return(nil, nil) + + // Assertions + err := userAuthService.SignUp(c) + assert.NotNil(t, err) + }) +} diff --git a/demo5/mocks/db.go b/demo5/mocks/db.go new file mode 100644 index 0000000..1dc0bfd --- /dev/null +++ b/demo5/mocks/db.go @@ -0,0 +1,38 @@ +package mocks + +import ( + "github.com/addetz/secure-code-go/demo4/db" + "github.com/stretchr/testify/mock" +) + +type DatabaseServiceMock struct { + mock.Mock +} + +func (m *DatabaseServiceMock) AddUser(username, pwd string) error { + args := m.Called(username, pwd) + return args.Error(0) +} + +func (m *DatabaseServiceMock) GetUser(username string) (*db.User, error) { + args := m.Called(username) + if args.Get(0) == nil { + return nil, args.Error(1) + } + userArg := args.Get(0).(*db.User) + return userArg, args.Error(1) +} + +func (m *DatabaseServiceMock) AddNote(id, username, text string) error { + args := m.Called(id, username, text) + return args.Error(0) +} + +func (m *DatabaseServiceMock) GetUserNotes(username string) ([]db.Note, error) { + args := m.Called(username) + if args.Get(0) == nil { + return nil, args.Error(1) + } + notesArg := args.Get(0).([]db.Note) + return notesArg, args.Error(1) +} diff --git a/demo5/server.go b/demo5/server.go new file mode 100644 index 0000000..ecdc7d8 --- /dev/null +++ b/demo5/server.go @@ -0,0 +1,169 @@ +package main + +import ( + "database/sql" + "fmt" + "log" + "net/http" + "os" + "time" + + "github.com/addetz/secure-code-go/demo4/db" + "github.com/addetz/secure-code-go/demo4/handlers" + echojwt "github.com/labstack/echo-jwt/v4" + + "github.com/golang-jwt/jwt/v5" + echo "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + _ "github.com/lib/pq" +) + +const TIMEOUT = 3 * time.Second + +type Response struct { + Message string `json:"message"` +} + +func main() { + // Read paths to certificate & private key from environment variables + certFile, ok := os.LookupEnv("SERVER_CERT_FILE") + if !ok { + log.Fatal("SERVER_CERT_FILE variable must be set") + } + keyFile, ok := os.LookupEnv("SERVER_KEY_FILE") + if !ok { + log.Fatal("SERVER_KEY_FILE variable must be set") + } + signingKey, ok := os.LookupEnv("SIGNING_KEY") + if !ok { + log.Fatal("SIGNING_KEY variable must be set") + } + // Read port if one is set + port := readPort() + + // Connect to database + dbConn := connectDatabase() + // Shut down connection when server shuts down + defer func() { + dbConn.Close() + }() + dbService := db.NewDatabaseService(dbConn) + + // Set up internal services + userAuthService := handlers.NewUserAuthService(signingKey, dbService) + + // Initialise echo + e := echo.New() + e.Use(middleware.Logger()) + e.Use(middleware.Recover()) + + // Configure server + s := http.Server{ + Addr: fmt.Sprintf(":%s", port), + Handler: e, + ReadTimeout: TIMEOUT, + ReadHeaderTimeout: TIMEOUT, + WriteTimeout: TIMEOUT, + IdleTimeout: TIMEOUT, + } + + // Set up the root route + e.GET("/", func(c echo.Context) error { + return c.JSON(http.StatusOK, &Response{ + Message: "Hello, Gophers!", + }) + }) + + // Set up authentication routes + e.POST("/signup", func(c echo.Context) error { + return userAuthService.SignUp(c) + }) + e.POST("/login", func(c echo.Context) error { + return userAuthService.Login(c) + }) + + // Restricted route only for logged in users + r := e.Group("/restricted") + config := echojwt.Config{ + NewClaimsFunc: func(c echo.Context) jwt.Claims { + return new(handlers.JWTCustomClaims) + }, + SigningKey: []byte(signingKey), + } + + r.Use(echojwt.WithConfig(config)) + r.GET("", func(c echo.Context) error { + return userAuthService.RestrictedPath(c) + }) + // Get all user's notes + r.GET("/secretNotes/:id", func(c echo.Context) error { + return userAuthService.GetUserNotes(c) + }) + // Add new note for user + r.POST("/secretNotes/:id", func(c echo.Context) error { + return userAuthService.AddUserNote(c) + }) + + // we're not checking an error here, + // so even our IDE is mad + str, _ := stringAndError() + + // this error will never be checked + defer onlyError() + + log.Printf("Listening on :%s...\n", port) + if err := s.ListenAndServeTLS(certFile, keyFile); err != http.ErrServerClosed { + log.Fatal(err) + } +} + +func stringAndError() (string, error) { + // This function will return a string and an error + return "I am not checked", nil +} + +func onlyError() error { + return fmt.Errorf("I am an error!!!") +} + +func readPort() string { + port, ok := os.LookupEnv("SERVER_PORT") + if !ok { + return "1323" + } + return port +} + +func connectDatabase() *sql.DB { + user, ok := os.LookupEnv("POSTGRES_USER") + if !ok { + log.Fatal("POSTGRES_USER variable must be set") + } + pwd, ok := os.LookupEnv("POSTGRES_PWD") + if !ok { + log.Fatal("POSTGRES_PWD variable must be set") + } + db, ok := os.LookupEnv("POSTGRES_DB") + if !ok { + log.Fatal("POSTGRES_DB variable must be set") + } + + connectionStr := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pwd, db) + conn, err := sql.Open("postgres", connectionStr) + if err != nil { + log.Fatal("connection error", err) + } + if err := conn.Ping(); err != nil { + log.Fatal("ping error", err) + } + _, err = conn.Exec("CREATE TABLE IF NOT EXISTS users (username VARCHAR(50) PRIMARY KEY, pwd VARCHAR(100) NOT NULL)") + if err != nil { + log.Fatal("create users", err) + } + _, err = conn.Exec("CREATE TABLE IF NOT EXISTS notes ( id VARCHAR (50) PRIMARY KEY," + + "username VARCHAR(50) REFERENCES users (username), noteText VARCHAR (500) NOT NULL)") + if err != nil { + panic(err) + } + return conn +} diff --git a/demo5/server_test.go b/demo5/server_test.go new file mode 100644 index 0000000..a1ea371 --- /dev/null +++ b/demo5/server_test.go @@ -0,0 +1,44 @@ +package main_test + +import ( + "fmt" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRouterHTTP(t *testing.T) { + if os.Getenv("E2E") == "" { + t.Skip("Skipping TestRouterHTTP in short mode.") + } + // Send an HTTP request + port := readPort() + resp, err := http.Get(fmt.Sprintf("http://localhost:%s", port)) + assert.Nil(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) +} + +func TestRouterHTTPS(t *testing.T) { + if os.Getenv("E2E") == "" { + t.Skip("Skipping TestRouterHTTPS in short mode.") + } + // Send an HTTPS request + port := readPort() + resp, err := http.Get(fmt.Sprintf("https://localhost:%s", port)) + assert.Nil(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func readPort() string { + port, ok := os.LookupEnv("SERVER_PORT") + if !ok { + return "1323" + } + return port +}