Skip to content

Commit

Permalink
Merge pull request #104 from tupyy/agent_auth
Browse files Browse the repository at this point in the history
agent/auth: Implement agent authentication
  • Loading branch information
tupyy authored Dec 16, 2024
2 parents 349a8db + 3b78a2a commit b3dcc8f
Show file tree
Hide file tree
Showing 13 changed files with 105 additions and 23 deletions.
21 changes: 15 additions & 6 deletions cmd/planner-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

const (
agentFilename = "agent_id"
jwtFilename = "jwt.json"
)

func main() {
Expand Down Expand Up @@ -76,23 +77,31 @@ func (a *agentCmd) Execute() error {
undo := zap.ReplaceGlobals(logger)
defer undo()

agentID, err := a.getAgentID()
agentID, err := a.readFile(agentFilename)
if err != nil {
zap.S().Fatalf("failed to retreive agent_id: %v", err)
}

agentInstance := agent.New(uuid.MustParse(agentID), a.config)
// Try to read jwt from file.
// We're assuming the jwt is valid.
// The agent will not try to validate the jwt. The backend is responsible for validating the token.
jwt, err := a.readFile(jwtFilename)
if err != nil {
zap.S().Errorf("failed to read jwt: %v", err)
}

agentInstance := agent.New(uuid.MustParse(agentID), jwt, a.config)
if err := agentInstance.Run(context.Background()); err != nil {
zap.S().Fatalf("running device agent: %v", err)
}
return nil
}

func (a *agentCmd) getAgentID() (string, error) {
func (a *agentCmd) readFile(filename string) (string, error) {
// look for it in data dir
dataDirPath := path.Join(a.config.DataDir, agentFilename)
if _, err := os.Stat(dataDirPath); err == nil {
content, err := os.ReadFile(dataDirPath)
confDirPath := path.Join(a.config.DataDir, filename)
if _, err := os.Stat(confDirPath); err == nil {
content, err := os.ReadFile(confDirPath)
if err != nil {
return "", err
}
Expand Down
7 changes: 7 additions & 0 deletions data/ignition.template
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ storage:
location = "{{.InsecureRegistry}}"
insecure = true
{{end}}
{{ if .Token }}
- path: /home/core/.migration-planner/config/jwt.json
mode: 0644
contents:
inline: |
{{ .Token }}
{{ end }}
- path: /var/lib/systemd/linger/core
mode: 0644
contents:
Expand Down
12 changes: 10 additions & 2 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/google/uuid"
api "github.com/kubev2v/migration-planner/api/v1alpha1"
"github.com/kubev2v/migration-planner/internal/agent/client"
"github.com/kubev2v/migration-planner/internal/agent/common"
"github.com/kubev2v/migration-planner/internal/agent/config"
"github.com/kubev2v/migration-planner/internal/agent/service"
"github.com/lthibault/jitterbug"
Expand All @@ -31,11 +32,12 @@ const (
var version string

// New creates a new agent.
func New(id uuid.UUID, config *config.Config) *Agent {
func New(id uuid.UUID, jwt string, config *config.Config) *Agent {
return &Agent{
config: config,
healtCheckStopCh: make(chan chan any),
id: id,
jwt: jwt,
}
}

Expand All @@ -45,6 +47,7 @@ type Agent struct {
healtCheckStopCh chan chan any
credUrl string
id uuid.UUID
jwt string
}

func (a *Agent) Run(ctx context.Context) error {
Expand Down Expand Up @@ -96,6 +99,11 @@ func (a *Agent) start(ctx context.Context, plannerClient client.Planner) {
inventoryUpdater := service.NewInventoryUpdater(a.id, plannerClient)
statusUpdater := service.NewStatusUpdater(a.id, version, a.credUrl, a.config, plannerClient)

// insert jwt into the context if any
if a.jwt != "" {
ctx = context.WithValue(ctx, common.JwtKey, a.jwt)
}

// start server
a.server = NewServer(defaultAgentPort, a.config)
go a.server.Start(statusUpdater)
Expand All @@ -114,7 +122,7 @@ func (a *Agent) start(ctx context.Context, plannerClient client.Planner) {
}

// TODO refactor health checker to call it from the main goroutine
healthChecker.Start(a.healtCheckStopCh)
healthChecker.Start(ctx, a.healtCheckStopCh)

collector := service.NewCollector(a.config.DataDir)
collector.Collect(ctx)
Expand Down
3 changes: 2 additions & 1 deletion internal/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ var _ = Describe("Agent", func() {
}
config.PlannerService.Service.Server = testHttpServer.URL

a := agent.New(agentID, &config)
jwt := ""
a := agent.New(agentID, jwt, &config)
ctx, cancel := context.WithTimeout(context.TODO(), 30*time.Second)
go func() {
err := a.Run(ctx)
Expand Down
31 changes: 28 additions & 3 deletions internal/agent/client/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/google/uuid"
api "github.com/kubev2v/migration-planner/api/v1alpha1/agent"
"github.com/kubev2v/migration-planner/internal/agent/common"
client "github.com/kubev2v/migration-planner/internal/api/client/agent"
)

Expand All @@ -29,7 +30,12 @@ type planner struct {
}

func (p *planner) UpdateSourceStatus(ctx context.Context, id uuid.UUID, params api.SourceStatusUpdate) error {
resp, err := p.client.ReplaceSourceStatusWithResponse(ctx, id, params)
resp, err := p.client.ReplaceSourceStatusWithResponse(ctx, id, params, func(ctx context.Context, req *http.Request) error {
if jwt, found := p.jwtFromContext(ctx); found {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", jwt))
}
return nil
})
if err != nil {
return err
}
Expand All @@ -44,7 +50,13 @@ func (p *planner) UpdateSourceStatus(ctx context.Context, id uuid.UUID, params a
}

func (p *planner) Health(ctx context.Context) error {
resp, err := p.client.HealthWithResponse(ctx)
resp, err := p.client.HealthWithResponse(ctx, func(ctx context.Context, req *http.Request) error {
if jwt, found := p.jwtFromContext(ctx); found {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", jwt))
}
return nil
})

if err != nil {
return err
}
Expand All @@ -58,7 +70,12 @@ func (p *planner) Health(ctx context.Context) error {
}

func (p *planner) UpdateAgentStatus(ctx context.Context, id uuid.UUID, params api.AgentStatusUpdate) error {
resp, err := p.client.UpdateAgentStatusWithResponse(ctx, id, params)
resp, err := p.client.UpdateAgentStatusWithResponse(ctx, id, params, func(ctx context.Context, req *http.Request) error {
if jwt, found := p.jwtFromContext(ctx); found {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", jwt))
}
return nil
})
if err != nil {
return err
}
Expand All @@ -73,3 +90,11 @@ func (p *planner) UpdateAgentStatus(ctx context.Context, id uuid.UUID, params ap
}
return nil
}

func (p *planner) jwtFromContext(ctx context.Context) (string, bool) {
val := ctx.Value(common.JwtKey)
if val == nil {
return "", false
}
return val.(string), true
}
7 changes: 7 additions & 0 deletions internal/agent/common/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package common

type jwtContextKeyType struct{}

var (
JwtKey jwtContextKeyType
)
10 changes: 5 additions & 5 deletions internal/agent/service/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ func NewHealthChecker(client client.Planner, logFolder string, checkInterval tim
// initialInterval represents the time after which the check is started.
// checkInterval represents the time to wait between checks.
// closeCh is the channel used to close the goroutine.
func (h *HealthChecker) Start(closeCh chan chan any) {
h.do()
func (h *HealthChecker) Start(ctx context.Context, closeCh chan chan any) {
h.do(ctx)

h.once.Do(func() {
go func() {
Expand All @@ -96,7 +96,7 @@ func (h *HealthChecker) Start(closeCh chan chan any) {
close(c)
return
case <-t.C:
h.do()
h.do(ctx)
}
}
}()
Expand All @@ -109,8 +109,8 @@ func (h *HealthChecker) State() AgentHealthState {
return h.state
}

func (h *HealthChecker) do() {
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout*time.Second)
func (h *HealthChecker) do(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, defaultTimeout*time.Second)
defer cancel()

err := h.client.Health(ctx)
Expand Down
10 changes: 5 additions & 5 deletions internal/agent/service/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ var _ = Describe("Health checker", func() {

It("should close OK", func() {
closeCh := make(chan chan any)
hc.Start(closeCh)
hc.Start(context.TODO(), closeCh)
<-time.After(2 * time.Second)

c := make(chan any, 1)
Expand All @@ -81,7 +81,7 @@ var _ = Describe("Health checker", func() {

It("should call health endpoint", func() {
closeCh := make(chan chan any)
hc.Start(closeCh)
hc.Start(context.TODO(), closeCh)
<-time.After(5 * time.Second)

c := make(chan any, 1)
Expand Down Expand Up @@ -136,7 +136,7 @@ var _ = Describe("Health checker", func() {
It("should write OK -- only failures", func() {
closeCh := make(chan chan any)
testClient.ShouldReturnError = true
hc.Start(closeCh)
hc.Start(context.TODO(), closeCh)

<-time.After(5 * time.Second)

Expand All @@ -155,7 +155,7 @@ var _ = Describe("Health checker", func() {
It("should write OK -- failures and one OK line", func() {
closeCh := make(chan chan any)
testClient.ShouldReturnError = true
hc.Start(closeCh)
hc.Start(context.TODO(), closeCh)

<-time.After(2 * time.Second)
testClient.ShouldReturnError = false
Expand All @@ -182,7 +182,7 @@ var _ = Describe("Health checker", func() {
It("should write OK -- failures and 2 OK lines", func() {
closeCh := make(chan chan any)
testClient.ShouldReturnError = true
hc.Start(closeCh)
hc.Start(context.TODO(), closeCh)

<-time.After(2 * time.Second)
testClient.ShouldReturnError = false
Expand Down
7 changes: 7 additions & 0 deletions internal/api_server/agentserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/go-chi/chi/v5/middleware"
api "github.com/kubev2v/migration-planner/api/v1alpha1/agent"
server "github.com/kubev2v/migration-planner/internal/api/server/agent"
"github.com/kubev2v/migration-planner/internal/auth"
"github.com/kubev2v/migration-planner/internal/config"
"github.com/kubev2v/migration-planner/internal/events"
service "github.com/kubev2v/migration-planner/internal/service/agent"
Expand Down Expand Up @@ -64,8 +65,14 @@ func (s *AgentServer) Run(ctx context.Context) error {
ErrorHandler: oapiErrorHandler,
}

authenticator, err := auth.NewAuthenticator(s.cfg.Service.Auth)
if err != nil {
return fmt.Errorf("failed to create authenticator: %w", err)
}

router := chi.NewRouter()
router.Use(
authenticator.Authenticator,
middleware.RequestID,
zapchi.Logger(zap.S(), "router_agent"),
middleware.Recoverer,
Expand Down
1 change: 1 addition & 0 deletions internal/auth/rhsso_authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func (rh *RHSSOAuthenticator) parseToken(userToken *jwt.Token) (User, error) {
Username: claims["username"].(string),
Organization: claims["org_id"].(string),
ClientID: claims["client_id"].(string),
Token: userToken,
}, nil
}

Expand Down
7 changes: 6 additions & 1 deletion internal/auth/user.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package auth

import "context"
import (
"context"

"github.com/golang-jwt/jwt/v5"
)

type usernameKeyType struct{}

Expand All @@ -24,4 +28,5 @@ type User struct {
Username string
Organization string
ClientID string
Token *jwt.Token
}
6 changes: 6 additions & 0 deletions internal/image/ova.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/coreos/butane/config"
"github.com/coreos/butane/config/common"
"github.com/golang-jwt/jwt/v5"
"github.com/kubev2v/migration-planner/internal/util"
"github.com/openshift/assisted-image-service/pkg/isoeditor"
"github.com/openshift/assisted-image-service/pkg/overlay"
Expand All @@ -24,6 +25,7 @@ const ResponseWriterKey Key = 0
type Ova struct {
Writer io.Writer
SshKey *string
Jwt *jwt.Token
}

// IgnitionData defines modifiable fields in ignition config
Expand All @@ -33,6 +35,7 @@ type IgnitionData struct {
PlannerService string
MigrationPlannerAgentImage string
InsecureRegistry string
Token string
}

type Image interface {
Expand Down Expand Up @@ -148,6 +151,9 @@ func (o *Ova) generateIgnition() (string, error) {
if o.SshKey != nil {
ignData.SshKey = *o.SshKey
}
if o.Jwt != nil {
ignData.Token = o.Jwt.Raw
}

if insecureRegistry := os.Getenv("INSECURE_REGISTRY"); insecureRegistry != "" {
ignData.InsecureRegistry = insecureRegistry
Expand Down
6 changes: 6 additions & 0 deletions internal/service/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"

"github.com/kubev2v/migration-planner/internal/api/server"
"github.com/kubev2v/migration-planner/internal/auth"
"github.com/kubev2v/migration-planner/internal/image"
)

Expand All @@ -15,7 +16,12 @@ func (h *ServiceHandler) GetImage(ctx context.Context, request server.GetImageRe
if !ok {
return server.GetImage500JSONResponse{Message: "error creating the HTTP stream"}, nil
}

ova := &image.Ova{SshKey: request.Params.SshKey, Writer: writer}
// get token if any
if user, found := auth.UserFromContext(ctx); found {
ova.Jwt = user.Token
}
if err := ova.Generate(); err != nil {
return server.GetImage500JSONResponse{Message: fmt.Sprintf("error generating image %s", err)}, nil
}
Expand Down

0 comments on commit b3dcc8f

Please sign in to comment.