Skip to content

Commit

Permalink
chore: store pending registration requests in db
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiebens committed May 28, 2022
1 parent 2b5439b commit 85656c1
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 62 deletions.
1 change: 1 addition & 0 deletions internal/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func migrate(db *gorm.DB, repository domain.Repository) error {
&domain.User{},
&domain.AuthKey{},
&domain.Machine{},
&domain.RegistrationRequest{},
)

if err != nil {
Expand Down
96 changes: 96 additions & 0 deletions internal/domain/registration_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package domain

import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"tailscale.com/tailcfg"
"time"
)

type RegistrationRequest struct {
MachineKey string `gorm:"primary_key;autoIncrement:false"`
Key string `gorm:"type:varchar(64);unique_index"`
Data RegistrationRequestData
CreatedAt time.Time
Authenticated bool
Error string
}

func (r *RegistrationRequest) IsFinished() bool {
return r.Authenticated || len(r.Error) != 0
}

type RegistrationRequestData tailcfg.RegisterRequest

func (hi *RegistrationRequestData) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, hi)
default:
return fmt.Errorf("unexpected data type %T", destination)
}
}

func (hi RegistrationRequestData) Value() (driver.Value, error) {
bytes, err := json.Marshal(hi)
return bytes, err
}

// GormDataType gorm common data type
func (RegistrationRequestData) GormDataType() string {
return "json"
}

// GormDBDataType gorm db data type
func (RegistrationRequestData) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
}
return ""
}

func (r *repository) SaveRegistrationRequest(ctx context.Context, request *RegistrationRequest) error {
tx := r.withContext(ctx).Save(request)

if tx.Error != nil {
return tx.Error
}

return nil
}

func (r *repository) GetRegistrationRequestByKey(ctx context.Context, key string) (*RegistrationRequest, error) {
var m RegistrationRequest
tx := r.withContext(ctx).First(&m, "key = ?", key)

if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
}

if tx.Error != nil {
return nil, tx.Error
}

return &m, nil
}

func (r *repository) GetRegistrationRequestByMachineKey(ctx context.Context, key string) (*RegistrationRequest, error) {
var m RegistrationRequest
tx := r.withContext(ctx).First(&m, "machine_key = ?", key)

if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
return nil, nil
}

if tx.Error != nil {
return nil, tx.Error
}

return &m, nil
}
4 changes: 4 additions & 0 deletions internal/domain/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ type Repository interface {
SetMachineLastSeen(ctx context.Context, machineID uint64) error
ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error)

SaveRegistrationRequest(ctx context.Context, request *RegistrationRequest) error
GetRegistrationRequestByKey(ctx context.Context, key string) (*RegistrationRequest, error)
GetRegistrationRequestByMachineKey(ctx context.Context, key string) (*RegistrationRequest, error)

Transaction(func(rp Repository) error) error
}

Expand Down
75 changes: 48 additions & 27 deletions internal/handlers/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/mr-tron/base58"
"net/http"
"strconv"
"tailscale.com/tailcfg"
"time"

"github.com/jsiebens/ionscale/internal/config"
Expand All @@ -20,21 +21,18 @@ import (

func NewAuthenticationHandlers(
config *config.Config,
repository domain.Repository,
pendingMachineRegistrationRequests *cache.Cache) *AuthenticationHandlers {
repository domain.Repository) *AuthenticationHandlers {
return &AuthenticationHandlers{
config: config,
repository: repository,
pendingMachineRegistrationRequests: pendingMachineRegistrationRequests,
pendingOAuthUsers: cache.New(5*time.Minute, 10*time.Minute),
config: config,
repository: repository,
pendingOAuthUsers: cache.New(5*time.Minute, 10*time.Minute),
}
}

type AuthenticationHandlers struct {
repository domain.Repository
config *config.Config
pendingMachineRegistrationRequests *cache.Cache
pendingOAuthUsers *cache.Cache
repository domain.Repository
config *config.Config
pendingOAuthUsers *cache.Cache
}

type AuthFormData struct {
Expand All @@ -54,7 +52,7 @@ func (h *AuthenticationHandlers) StartAuth(c echo.Context) error {
ctx := c.Request().Context()
key := c.Param("key")

if _, ok := h.pendingMachineRegistrationRequests.Get(key); !ok {
if req, err := h.repository.GetRegistrationRequestByKey(ctx, key); err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
}

Expand All @@ -73,12 +71,13 @@ func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error {
authKey := c.FormValue("ak")
authMethodId := c.FormValue("s")

if _, ok := h.pendingMachineRegistrationRequests.Get(key); !ok {
req, err := h.repository.GetRegistrationRequestByKey(ctx, key)
if err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
}

if authKey != "" {
return h.endMachineRegistrationFlow(c, &oauthState{Key: key})
return h.endMachineRegistrationFlow(c, req, &oauthState{Key: key})
}

if authMethodId != "" {
Expand Down Expand Up @@ -146,12 +145,19 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error {
}

func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error {
ctx := c.Request().Context()

state, err := h.readState(c.QueryParam("state"))
if err != nil {
return err
return c.Redirect(http.StatusFound, "/a/error")
}

req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key)
if err != nil || req == nil {
return c.Redirect(http.StatusFound, "/a/error")
}

return h.endMachineRegistrationFlow(c, state)
return h.endMachineRegistrationFlow(c, req, state)
}

func (h *AuthenticationHandlers) Success(c echo.Context) error {
Expand All @@ -169,22 +175,14 @@ func (h *AuthenticationHandlers) Error(c echo.Context) error {
return c.Render(http.StatusOK, "error.html", nil)
}

func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, state *oauthState) error {
func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, registrationRequest *domain.RegistrationRequest, state *oauthState) error {
ctx := c.Request().Context()

defer h.pendingMachineRegistrationRequests.Delete(state.Key)

preqItem, preqOK := h.pendingMachineRegistrationRequests.Get(state.Key)
if !preqOK {
return c.Redirect(http.StatusFound, "/a/error")
}

authKeyParam := c.FormValue("ak")
tailnetIDParam := c.FormValue("s")

preq := preqItem.(*pendingMachineRegistrationRequest)
req := preq.request
machineKey := preq.machineKey
req := tailcfg.RegisterRequest(registrationRequest.Data)
machineKey := registrationRequest.MachineKey
nodeKey := req.NodeKey.String()

var tailnet *domain.Tailnet
Expand All @@ -199,6 +197,14 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, stat
}

if authKey == nil {

registrationRequest.Authenticated = false
registrationRequest.Error = "invalid auth key"

if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
return c.Redirect(http.StatusFound, "/a/error")
}

return c.Redirect(http.StatusFound, "/a/error?e=iak")
}

Expand Down Expand Up @@ -315,7 +321,22 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, stat
m.ExpiresAt = nil
}

if err := h.repository.SaveMachine(ctx, m); err != nil {
err = h.repository.Transaction(func(rp domain.Repository) error {
registrationRequest.Authenticated = true
registrationRequest.Error = ""

if err := rp.SaveMachine(ctx, m); err != nil {
return err
}

if err := rp.SaveRegistrationRequest(ctx, registrationRequest); err != nil {
return err
}

return nil
})

if err != nil {
return err
}

Expand Down
59 changes: 30 additions & 29 deletions internal/handlers/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"github.com/jsiebens/ionscale/internal/domain"
"github.com/jsiebens/ionscale/internal/util"
"github.com/labstack/echo/v4"
"github.com/patrickmn/go-cache"
"inet.af/netaddr"
"net/http"
"tailscale.com/tailcfg"
Expand All @@ -21,28 +20,20 @@ func NewRegistrationHandlers(
createBinder bind.Factory,
config *config.Config,
brokers *broker.BrokerPool,
repository domain.Repository,
pendingMachineRegistrationRequests *cache.Cache) *RegistrationHandlers {
repository domain.Repository) *RegistrationHandlers {
return &RegistrationHandlers{
createBinder: createBinder,
brokers: brokers.Get,
repository: repository,
config: config,
pendingMachineRegistrationRequests: pendingMachineRegistrationRequests,
createBinder: createBinder,
brokers: brokers.Get,
repository: repository,
config: config,
}
}

type pendingMachineRegistrationRequest struct {
machineKey string
request *tailcfg.RegisterRequest
}

type RegistrationHandlers struct {
createBinder bind.Factory
repository domain.Repository
brokers func(uint64) broker.Broker
config *config.Config
pendingMachineRegistrationRequests *cache.Cache
createBinder bind.Factory
repository domain.Repository
brokers func(uint64) broker.Broker
config *config.Config
}

func (h *RegistrationHandlers) Register(c echo.Context) error {
Expand Down Expand Up @@ -113,6 +104,8 @@ func (h *RegistrationHandlers) Register(c echo.Context) error {
}

func (h *RegistrationHandlers) authenticateMachine(c echo.Context, binder bind.Binder, machineKey string, req *tailcfg.RegisterRequest) error {
ctx := c.Request().Context()

if req.Followup != "" {
return h.followup(c, binder, req)
}
Expand All @@ -121,10 +114,18 @@ func (h *RegistrationHandlers) authenticateMachine(c echo.Context, binder bind.B
key := util.RandStringBytes(8)
authUrl := h.config.CreateUrl("/a/%s", key)

h.pendingMachineRegistrationRequests.Set(key, &pendingMachineRegistrationRequest{
machineKey: machineKey,
request: req,
}, cache.DefaultExpiration)
request := domain.RegistrationRequest{
MachineKey: machineKey,
Key: key,
CreatedAt: time.Now().UTC(),
Data: domain.RegistrationRequestData(*req),
}

err := h.repository.SaveRegistrationRequest(ctx, &request)
if err != nil {
response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "something went wrong"}
return binder.WriteResponse(c, http.StatusOK, response)
}

response := tailcfg.RegisterResponse{AuthURL: authUrl}
return binder.WriteResponse(c, http.StatusOK, response)
Expand Down Expand Up @@ -232,24 +233,24 @@ func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req
// Listen to connection close
ctx := c.Request().Context()
notify := ctx.Done()
tick := time.NewTicker(5 * time.Second)
tick := time.NewTicker(2 * time.Second)

defer func() { tick.Stop() }()

machineKey := binder.Peer().String()
nodeKey := req.NodeKey.String()

for {
select {
case <-tick.C:
m, err := h.repository.GetMachineByKeys(ctx, machineKey, nodeKey)
m, err := h.repository.GetRegistrationRequestByMachineKey(ctx, machineKey)

if err != nil {
return err
if err != nil || m == nil {
response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "something went wrong"}
return binder.WriteResponse(c, http.StatusOK, response)
}

if m != nil {
response := tailcfg.RegisterResponse{MachineAuthorized: true}
if m != nil && m.IsFinished() {
response := tailcfg.RegisterResponse{MachineAuthorized: len(m.Error) != 0, Error: m.Error}
return binder.WriteResponse(c, http.StatusOK, response)
}
case <-notify:
Expand Down
Loading

0 comments on commit 85656c1

Please sign in to comment.