From 85656c19a7063fdc688e5c094834d4cc328de2c6 Mon Sep 17 00:00:00 2001 From: Johan Siebens Date: Sat, 28 May 2022 08:43:48 +0200 Subject: [PATCH] chore: store pending registration requests in db --- internal/database/database.go | 1 + internal/domain/registration_request.go | 96 +++++++++++++++++++++++++ internal/domain/repository.go | 4 ++ internal/handlers/authentication.go | 75 ++++++++++++------- internal/handlers/registration.go | 59 +++++++-------- internal/server/server.go | 8 +-- 6 files changed, 181 insertions(+), 62 deletions(-) create mode 100644 internal/domain/registration_request.go diff --git a/internal/database/database.go b/internal/database/database.go index e05cbee1..9ca9afee 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -51,6 +51,7 @@ func migrate(db *gorm.DB, repository domain.Repository) error { &domain.User{}, &domain.AuthKey{}, &domain.Machine{}, + &domain.RegistrationRequest{}, ) if err != nil { diff --git a/internal/domain/registration_request.go b/internal/domain/registration_request.go new file mode 100644 index 00000000..31f35be4 --- /dev/null +++ b/internal/domain/registration_request.go @@ -0,0 +1,96 @@ +package domain + +import ( + "context" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "gorm.io/gorm" + "gorm.io/gorm/schema" + "tailscale.com/tailcfg" + "time" +) + +type RegistrationRequest struct { + MachineKey string `gorm:"primary_key;autoIncrement:false"` + Key string `gorm:"type:varchar(64);unique_index"` + Data RegistrationRequestData + CreatedAt time.Time + Authenticated bool + Error string +} + +func (r *RegistrationRequest) IsFinished() bool { + return r.Authenticated || len(r.Error) != 0 +} + +type RegistrationRequestData tailcfg.RegisterRequest + +func (hi *RegistrationRequestData) Scan(destination interface{}) error { + switch value := destination.(type) { + case []byte: + return json.Unmarshal(value, hi) + default: + return fmt.Errorf("unexpected data type %T", destination) + } +} + +func (hi RegistrationRequestData) Value() (driver.Value, error) { + bytes, err := json.Marshal(hi) + return bytes, err +} + +// GormDataType gorm common data type +func (RegistrationRequestData) GormDataType() string { + return "json" +} + +// GormDBDataType gorm db data type +func (RegistrationRequestData) GormDBDataType(db *gorm.DB, field *schema.Field) string { + switch db.Dialector.Name() { + case "sqlite": + return "JSON" + } + return "" +} + +func (r *repository) SaveRegistrationRequest(ctx context.Context, request *RegistrationRequest) error { + tx := r.withContext(ctx).Save(request) + + if tx.Error != nil { + return tx.Error + } + + return nil +} + +func (r *repository) GetRegistrationRequestByKey(ctx context.Context, key string) (*RegistrationRequest, error) { + var m RegistrationRequest + tx := r.withContext(ctx).First(&m, "key = ?", key) + + if errors.Is(tx.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + + if tx.Error != nil { + return nil, tx.Error + } + + return &m, nil +} + +func (r *repository) GetRegistrationRequestByMachineKey(ctx context.Context, key string) (*RegistrationRequest, error) { + var m RegistrationRequest + tx := r.withContext(ctx).First(&m, "machine_key = ?", key) + + if errors.Is(tx.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + + if tx.Error != nil { + return nil, tx.Error + } + + return &m, nil +} diff --git a/internal/domain/repository.go b/internal/domain/repository.go index c8b034a3..ef2ed52c 100644 --- a/internal/domain/repository.go +++ b/internal/domain/repository.go @@ -67,6 +67,10 @@ type Repository interface { SetMachineLastSeen(ctx context.Context, machineID uint64) error ExpireMachineByAuthMethod(ctx context.Context, authMethodID uint64) (int64, error) + SaveRegistrationRequest(ctx context.Context, request *RegistrationRequest) error + GetRegistrationRequestByKey(ctx context.Context, key string) (*RegistrationRequest, error) + GetRegistrationRequestByMachineKey(ctx context.Context, key string) (*RegistrationRequest, error) + Transaction(func(rp Repository) error) error } diff --git a/internal/handlers/authentication.go b/internal/handlers/authentication.go index ffc0a34b..6fc3e522 100644 --- a/internal/handlers/authentication.go +++ b/internal/handlers/authentication.go @@ -8,6 +8,7 @@ import ( "github.com/mr-tron/base58" "net/http" "strconv" + "tailscale.com/tailcfg" "time" "github.com/jsiebens/ionscale/internal/config" @@ -20,21 +21,18 @@ import ( func NewAuthenticationHandlers( config *config.Config, - repository domain.Repository, - pendingMachineRegistrationRequests *cache.Cache) *AuthenticationHandlers { + repository domain.Repository) *AuthenticationHandlers { return &AuthenticationHandlers{ - config: config, - repository: repository, - pendingMachineRegistrationRequests: pendingMachineRegistrationRequests, - pendingOAuthUsers: cache.New(5*time.Minute, 10*time.Minute), + config: config, + repository: repository, + pendingOAuthUsers: cache.New(5*time.Minute, 10*time.Minute), } } type AuthenticationHandlers struct { - repository domain.Repository - config *config.Config - pendingMachineRegistrationRequests *cache.Cache - pendingOAuthUsers *cache.Cache + repository domain.Repository + config *config.Config + pendingOAuthUsers *cache.Cache } type AuthFormData struct { @@ -54,7 +52,7 @@ func (h *AuthenticationHandlers) StartAuth(c echo.Context) error { ctx := c.Request().Context() key := c.Param("key") - if _, ok := h.pendingMachineRegistrationRequests.Get(key); !ok { + if req, err := h.repository.GetRegistrationRequestByKey(ctx, key); err != nil || req == nil { return c.Redirect(http.StatusFound, "/a/error") } @@ -73,12 +71,13 @@ func (h *AuthenticationHandlers) ProcessAuth(c echo.Context) error { authKey := c.FormValue("ak") authMethodId := c.FormValue("s") - if _, ok := h.pendingMachineRegistrationRequests.Get(key); !ok { + req, err := h.repository.GetRegistrationRequestByKey(ctx, key) + if err != nil || req == nil { return c.Redirect(http.StatusFound, "/a/error") } if authKey != "" { - return h.endMachineRegistrationFlow(c, &oauthState{Key: key}) + return h.endMachineRegistrationFlow(c, req, &oauthState{Key: key}) } if authMethodId != "" { @@ -146,12 +145,19 @@ func (h *AuthenticationHandlers) Callback(c echo.Context) error { } func (h *AuthenticationHandlers) EndOAuth(c echo.Context) error { + ctx := c.Request().Context() + state, err := h.readState(c.QueryParam("state")) if err != nil { - return err + return c.Redirect(http.StatusFound, "/a/error") + } + + req, err := h.repository.GetRegistrationRequestByKey(ctx, state.Key) + if err != nil || req == nil { + return c.Redirect(http.StatusFound, "/a/error") } - return h.endMachineRegistrationFlow(c, state) + return h.endMachineRegistrationFlow(c, req, state) } func (h *AuthenticationHandlers) Success(c echo.Context) error { @@ -169,22 +175,14 @@ func (h *AuthenticationHandlers) Error(c echo.Context) error { return c.Render(http.StatusOK, "error.html", nil) } -func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, state *oauthState) error { +func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, registrationRequest *domain.RegistrationRequest, state *oauthState) error { ctx := c.Request().Context() - defer h.pendingMachineRegistrationRequests.Delete(state.Key) - - preqItem, preqOK := h.pendingMachineRegistrationRequests.Get(state.Key) - if !preqOK { - return c.Redirect(http.StatusFound, "/a/error") - } - authKeyParam := c.FormValue("ak") tailnetIDParam := c.FormValue("s") - preq := preqItem.(*pendingMachineRegistrationRequest) - req := preq.request - machineKey := preq.machineKey + req := tailcfg.RegisterRequest(registrationRequest.Data) + machineKey := registrationRequest.MachineKey nodeKey := req.NodeKey.String() var tailnet *domain.Tailnet @@ -199,6 +197,14 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, stat } if authKey == nil { + + registrationRequest.Authenticated = false + registrationRequest.Error = "invalid auth key" + + if err := h.repository.SaveRegistrationRequest(ctx, registrationRequest); err != nil { + return c.Redirect(http.StatusFound, "/a/error") + } + return c.Redirect(http.StatusFound, "/a/error?e=iak") } @@ -315,7 +321,22 @@ func (h *AuthenticationHandlers) endMachineRegistrationFlow(c echo.Context, stat m.ExpiresAt = nil } - if err := h.repository.SaveMachine(ctx, m); err != nil { + err = h.repository.Transaction(func(rp domain.Repository) error { + registrationRequest.Authenticated = true + registrationRequest.Error = "" + + if err := rp.SaveMachine(ctx, m); err != nil { + return err + } + + if err := rp.SaveRegistrationRequest(ctx, registrationRequest); err != nil { + return err + } + + return nil + }) + + if err != nil { return err } diff --git a/internal/handlers/registration.go b/internal/handlers/registration.go index 27a0fc58..626117c1 100644 --- a/internal/handlers/registration.go +++ b/internal/handlers/registration.go @@ -9,7 +9,6 @@ import ( "github.com/jsiebens/ionscale/internal/domain" "github.com/jsiebens/ionscale/internal/util" "github.com/labstack/echo/v4" - "github.com/patrickmn/go-cache" "inet.af/netaddr" "net/http" "tailscale.com/tailcfg" @@ -21,28 +20,20 @@ func NewRegistrationHandlers( createBinder bind.Factory, config *config.Config, brokers *broker.BrokerPool, - repository domain.Repository, - pendingMachineRegistrationRequests *cache.Cache) *RegistrationHandlers { + repository domain.Repository) *RegistrationHandlers { return &RegistrationHandlers{ - createBinder: createBinder, - brokers: brokers.Get, - repository: repository, - config: config, - pendingMachineRegistrationRequests: pendingMachineRegistrationRequests, + createBinder: createBinder, + brokers: brokers.Get, + repository: repository, + config: config, } } -type pendingMachineRegistrationRequest struct { - machineKey string - request *tailcfg.RegisterRequest -} - type RegistrationHandlers struct { - createBinder bind.Factory - repository domain.Repository - brokers func(uint64) broker.Broker - config *config.Config - pendingMachineRegistrationRequests *cache.Cache + createBinder bind.Factory + repository domain.Repository + brokers func(uint64) broker.Broker + config *config.Config } func (h *RegistrationHandlers) Register(c echo.Context) error { @@ -113,6 +104,8 @@ func (h *RegistrationHandlers) Register(c echo.Context) error { } func (h *RegistrationHandlers) authenticateMachine(c echo.Context, binder bind.Binder, machineKey string, req *tailcfg.RegisterRequest) error { + ctx := c.Request().Context() + if req.Followup != "" { return h.followup(c, binder, req) } @@ -121,10 +114,18 @@ func (h *RegistrationHandlers) authenticateMachine(c echo.Context, binder bind.B key := util.RandStringBytes(8) authUrl := h.config.CreateUrl("/a/%s", key) - h.pendingMachineRegistrationRequests.Set(key, &pendingMachineRegistrationRequest{ - machineKey: machineKey, - request: req, - }, cache.DefaultExpiration) + request := domain.RegistrationRequest{ + MachineKey: machineKey, + Key: key, + CreatedAt: time.Now().UTC(), + Data: domain.RegistrationRequestData(*req), + } + + err := h.repository.SaveRegistrationRequest(ctx, &request) + if err != nil { + response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "something went wrong"} + return binder.WriteResponse(c, http.StatusOK, response) + } response := tailcfg.RegisterResponse{AuthURL: authUrl} return binder.WriteResponse(c, http.StatusOK, response) @@ -232,24 +233,24 @@ func (h *RegistrationHandlers) followup(c echo.Context, binder bind.Binder, req // Listen to connection close ctx := c.Request().Context() notify := ctx.Done() - tick := time.NewTicker(5 * time.Second) + tick := time.NewTicker(2 * time.Second) defer func() { tick.Stop() }() machineKey := binder.Peer().String() - nodeKey := req.NodeKey.String() for { select { case <-tick.C: - m, err := h.repository.GetMachineByKeys(ctx, machineKey, nodeKey) + m, err := h.repository.GetRegistrationRequestByMachineKey(ctx, machineKey) - if err != nil { - return err + if err != nil || m == nil { + response := tailcfg.RegisterResponse{MachineAuthorized: false, Error: "something went wrong"} + return binder.WriteResponse(c, http.StatusOK, response) } - if m != nil { - response := tailcfg.RegisterResponse{MachineAuthorized: true} + if m != nil && m.IsFinished() { + response := tailcfg.RegisterResponse{MachineAuthorized: len(m.Error) != 0, Error: m.Error} return binder.WriteResponse(c, http.StatusOK, response) } case <-notify: diff --git a/internal/server/server.go b/internal/server/server.go index 8533d3ae..5ea4bd21 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,6 @@ import ( "github.com/jsiebens/ionscale/pkg/gen/api" echo_prometheus "github.com/labstack/echo-contrib/prometheus" "github.com/labstack/echo/v4" - "github.com/patrickmn/go-cache" "github.com/soheilhy/cmux" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -28,7 +27,6 @@ import ( "os" "strings" "tailscale.com/types/key" - "time" ) func Start(config *config.Config) error { @@ -54,7 +52,6 @@ func Start(config *config.Config) error { return err } - pendingMachineRegistrationRequests := cache.New(5*time.Minute, 10*time.Minute) brokers := broker.NewBrokerPool() offlineTimers := handlers.NewOfflineTimers(repository, brokers) reaper := handlers.NewReaper(brokers, repository) @@ -81,7 +78,7 @@ func Start(config *config.Config) error { } createPeerHandler := func(p key.MachinePublic) http.Handler { - registrationHandlers := handlers.NewRegistrationHandlers(bind.DefaultBinder(p), config, brokers, repository, pendingMachineRegistrationRequests) + registrationHandlers := handlers.NewRegistrationHandlers(bind.DefaultBinder(p), config, brokers, repository) pollNetMapHandler := handlers.NewPollNetMapHandler(bind.DefaultBinder(p), brokers, repository, offlineTimers) e := echo.New() @@ -94,12 +91,11 @@ func Start(config *config.Config) error { } noiseHandlers := handlers.NewNoiseHandlers(controlKeys.ControlKey, createPeerHandler) - registrationHandlers := handlers.NewRegistrationHandlers(bind.BoxBinder(controlKeys.LegacyControlKey), config, brokers, repository, pendingMachineRegistrationRequests) + registrationHandlers := handlers.NewRegistrationHandlers(bind.BoxBinder(controlKeys.LegacyControlKey), config, brokers, repository) pollNetMapHandler := handlers.NewPollNetMapHandler(bind.BoxBinder(controlKeys.LegacyControlKey), brokers, repository, offlineTimers) authenticationHandlers := handlers.NewAuthenticationHandlers( config, repository, - pendingMachineRegistrationRequests, ) p := echo_prometheus.NewPrometheus("http", nil)